diff --git a/venv/lib/python3.13/site-packages/transformers/models/albert/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/albert/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..57b5747909e091ede05ff07c98254224fbebed97
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/albert/__init__.py
@@ -0,0 +1,31 @@
+# Copyright 2020 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_albert import *
+ from .modeling_albert import *
+ from .modeling_flax_albert import *
+ from .modeling_tf_albert import *
+ from .tokenization_albert import *
+ from .tokenization_albert_fast import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/venv/lib/python3.13/site-packages/transformers/models/albert/configuration_albert.py b/venv/lib/python3.13/site-packages/transformers/models/albert/configuration_albert.py
new file mode 100644
index 0000000000000000000000000000000000000000..b60c19d504f05f50abb0988341526f53af8ad4db
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/albert/configuration_albert.py
@@ -0,0 +1,170 @@
+# coding=utf-8
+# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""ALBERT model configuration"""
+
+from collections import OrderedDict
+from collections.abc import Mapping
+
+from ...configuration_utils import PretrainedConfig
+from ...onnx import OnnxConfig
+
+
+class AlbertConfig(PretrainedConfig):
+ r"""
+    This is the configuration class to store the configuration of an [`AlbertModel`] or a [`TFAlbertModel`]. It is used
+ to instantiate an ALBERT model according to the specified arguments, defining the model architecture. Instantiating
+ a configuration with the defaults will yield a similar configuration to that of the ALBERT
+ [albert/albert-xxlarge-v2](https://huggingface.co/albert/albert-xxlarge-v2) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 30000):
+ Vocabulary size of the ALBERT model. Defines the number of different tokens that can be represented by the
+            `input_ids` passed when calling [`AlbertModel`] or [`TFAlbertModel`].
+ embedding_size (`int`, *optional*, defaults to 128):
+ Dimensionality of vocabulary embeddings.
+ hidden_size (`int`, *optional*, defaults to 4096):
+ Dimensionality of the encoder layers and the pooler layer.
+ num_hidden_layers (`int`, *optional*, defaults to 12):
+ Number of hidden layers in the Transformer encoder.
+ num_hidden_groups (`int`, *optional*, defaults to 1):
+ Number of groups for the hidden layers, parameters in the same group are shared.
+ num_attention_heads (`int`, *optional*, defaults to 64):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ intermediate_size (`int`, *optional*, defaults to 16384):
+ The dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
+ inner_group_num (`int`, *optional*, defaults to 1):
+            The number of inner repetitions of attention and ffn layers.
+ hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu_new"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
+ hidden_dropout_prob (`float`, *optional*, defaults to 0):
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0):
+ The dropout ratio for the attention probabilities.
+ max_position_embeddings (`int`, *optional*, defaults to 512):
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
+ (e.g., 512 or 1024 or 2048).
+ type_vocab_size (`int`, *optional*, defaults to 2):
+ The vocabulary size of the `token_type_ids` passed when calling [`AlbertModel`] or [`TFAlbertModel`].
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12):
+ The epsilon used by the layer normalization layers.
+ classifier_dropout_prob (`float`, *optional*, defaults to 0.1):
+ The dropout ratio for attached classifiers.
+ position_embedding_type (`str`, *optional*, defaults to `"absolute"`):
+ Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For
+ positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to
+ [Self-Attention with Relative Position Representations (Shaw et al.)](https://huggingface.co/papers/1803.02155).
+ For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models
+ with Better Relative Position Embeddings (Huang et al.)](https://huggingface.co/papers/2009.13658).
+ pad_token_id (`int`, *optional*, defaults to 0):
+ Padding token id.
+ bos_token_id (`int`, *optional*, defaults to 2):
+ Beginning of stream token id.
+ eos_token_id (`int`, *optional*, defaults to 3):
+ End of stream token id.
+
+ Examples:
+
+ ```python
+ >>> from transformers import AlbertConfig, AlbertModel
+
+ >>> # Initializing an ALBERT-xxlarge style configuration
+ >>> albert_xxlarge_configuration = AlbertConfig()
+
+ >>> # Initializing an ALBERT-base style configuration
+ >>> albert_base_configuration = AlbertConfig(
+ ... hidden_size=768,
+ ... num_attention_heads=12,
+ ... intermediate_size=3072,
+ ... )
+
+ >>> # Initializing a model (with random weights) from the ALBERT-base style configuration
+    >>> model = AlbertModel(albert_base_configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "albert"
+
+ def __init__(
+ self,
+ vocab_size=30000,
+ embedding_size=128,
+ hidden_size=4096,
+ num_hidden_layers=12,
+ num_hidden_groups=1,
+ num_attention_heads=64,
+ intermediate_size=16384,
+ inner_group_num=1,
+ hidden_act="gelu_new",
+ hidden_dropout_prob=0,
+ attention_probs_dropout_prob=0,
+ max_position_embeddings=512,
+ type_vocab_size=2,
+ initializer_range=0.02,
+ layer_norm_eps=1e-12,
+ classifier_dropout_prob=0.1,
+ position_embedding_type="absolute",
+ pad_token_id=0,
+ bos_token_id=2,
+ eos_token_id=3,
+ **kwargs,
+ ):
+ super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
+
+ self.vocab_size = vocab_size
+ self.embedding_size = embedding_size
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_hidden_groups = num_hidden_groups
+ self.num_attention_heads = num_attention_heads
+ self.inner_group_num = inner_group_num
+ self.hidden_act = hidden_act
+ self.intermediate_size = intermediate_size
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.max_position_embeddings = max_position_embeddings
+ self.type_vocab_size = type_vocab_size
+ self.initializer_range = initializer_range
+ self.layer_norm_eps = layer_norm_eps
+ self.classifier_dropout_prob = classifier_dropout_prob
+ self.position_embedding_type = position_embedding_type
+
+
+# Copied from transformers.models.bert.configuration_bert.BertOnnxConfig with Roberta->Albert
+class AlbertOnnxConfig(OnnxConfig):
+ @property
+ def inputs(self) -> Mapping[str, Mapping[int, str]]:
+ if self.task == "multiple-choice":
+ dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
+ else:
+ dynamic_axis = {0: "batch", 1: "sequence"}
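+        # These mappings name the ONNX dynamic axes, i.e. the input dimensions that may vary at inference time
+        # (batch and sequence length, plus the choice dimension for multiple-choice tasks).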
+ return OrderedDict(
+ [
+ ("input_ids", dynamic_axis),
+ ("attention_mask", dynamic_axis),
+ ("token_type_ids", dynamic_axis),
+ ]
+ )
+
+
+__all__ = ["AlbertConfig", "AlbertOnnxConfig"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/albert/modeling_albert.py b/venv/lib/python3.13/site-packages/transformers/models/albert/modeling_albert.py
new file mode 100644
index 0000000000000000000000000000000000000000..4cc129366baea19b78ab5e7335fa21c5a371326b
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/albert/modeling_albert.py
@@ -0,0 +1,1349 @@
+# coding=utf-8
+# Copyright 2018 Google AI, Google Brain and the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch ALBERT model."""
+
+import math
+import os
+from dataclasses import dataclass
+from typing import Optional, Union
+
+import torch
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from ...activations import ACT2FN
+from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa
+from ...modeling_outputs import (
+ BaseModelOutput,
+ BaseModelOutputWithPooling,
+ MaskedLMOutput,
+ MultipleChoiceModelOutput,
+ QuestionAnsweringModelOutput,
+ SequenceClassifierOutput,
+ TokenClassifierOutput,
+)
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import (
+ apply_chunking_to_forward,
+ find_pruneable_heads_and_indices,
+ prune_linear_layer,
+)
+from ...utils import ModelOutput, auto_docstring, logging
+from .configuration_albert import AlbertConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+def load_tf_weights_in_albert(model, config, tf_checkpoint_path):
+ """Load tf checkpoints in a pytorch model."""
+ try:
+ import re
+
+ import numpy as np
+ import tensorflow as tf
+ except ImportError:
+ logger.error(
+ "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
+ "https://www.tensorflow.org/install/ for installation instructions."
+ )
+ raise
+ tf_path = os.path.abspath(tf_checkpoint_path)
+ logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
+ # Load weights from TF model
+ init_vars = tf.train.list_variables(tf_path)
+ names = []
+ arrays = []
+ for name, shape in init_vars:
+ logger.info(f"Loading TF weight {name} with shape {shape}")
+ array = tf.train.load_variable(tf_path, name)
+ names.append(name)
+ arrays.append(array)
+
+ for name, array in zip(names, arrays):
+ print(name)
+
+ for name, array in zip(names, arrays):
+ original_name = name
+
+ # If saved from the TF HUB module
+ name = name.replace("module/", "")
+
+ # Renaming and simplifying
+ name = name.replace("ffn_1", "ffn")
+ name = name.replace("bert/", "albert/")
+ name = name.replace("attention_1", "attention")
+ name = name.replace("transform/", "")
+ name = name.replace("LayerNorm_1", "full_layer_layer_norm")
+ name = name.replace("LayerNorm", "attention/LayerNorm")
+ name = name.replace("transformer/", "")
+
+ # The feed forward layer had an 'intermediate' step which has been abstracted away
+ name = name.replace("intermediate/dense/", "")
+ name = name.replace("ffn/intermediate/output/dense/", "ffn_output/")
+
+ # ALBERT attention was split between self and output which have been abstracted away
+ name = name.replace("/output/", "/")
+ name = name.replace("/self/", "/")
+
+ # The pooler is a linear layer
+ name = name.replace("pooler/dense", "pooler")
+
+ # The classifier was simplified to predictions from cls/predictions
+ name = name.replace("cls/predictions", "predictions")
+ name = name.replace("predictions/attention", "predictions")
+
+ # Naming was changed to be more explicit
+ name = name.replace("embeddings/attention", "embeddings")
+ name = name.replace("inner_group_", "albert_layers/")
+ name = name.replace("group_", "albert_layer_groups/")
+
+ # Classifier
+ if len(name.split("/")) == 1 and ("output_bias" in name or "output_weights" in name):
+ name = "classifier/" + name
+
+ # No ALBERT model currently handles the next sentence prediction task
+ if "seq_relationship" in name:
+ name = name.replace("seq_relationship/output_", "sop_classifier/classifier/")
+ name = name.replace("weights", "weight")
+
+ name = name.split("/")
+
+ # Ignore the gradients applied by the LAMB/ADAM optimizers.
+ if (
+ "adam_m" in name
+ or "adam_v" in name
+ or "AdamWeightDecayOptimizer" in name
+ or "AdamWeightDecayOptimizer_1" in name
+ or "global_step" in name
+ ):
+ logger.info(f"Skipping {'/'.join(name)}")
+ continue
+
+ pointer = model
+ for m_name in name:
+ if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
+ scope_names = re.split(r"_(\d+)", m_name)
+ else:
+ scope_names = [m_name]
+
+ if scope_names[0] == "kernel" or scope_names[0] == "gamma":
+ pointer = getattr(pointer, "weight")
+ elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
+ pointer = getattr(pointer, "bias")
+ elif scope_names[0] == "output_weights":
+ pointer = getattr(pointer, "weight")
+ elif scope_names[0] == "squad":
+ pointer = getattr(pointer, "classifier")
+ else:
+ try:
+ pointer = getattr(pointer, scope_names[0])
+ except AttributeError:
+ logger.info(f"Skipping {'/'.join(name)}")
+ continue
+ if len(scope_names) >= 2:
+ num = int(scope_names[1])
+ pointer = pointer[num]
+
+ if m_name[-11:] == "_embeddings":
+ pointer = getattr(pointer, "weight")
+ elif m_name == "kernel":
+ array = np.transpose(array)
+ try:
+ if pointer.shape != array.shape:
+ raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
+ except ValueError as e:
+ e.args += (pointer.shape, array.shape)
+ raise
+ print(f"Initialize PyTorch weight {name} from {original_name}")
+ pointer.data = torch.from_numpy(array)
+
+ return model
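+
+# A minimal usage sketch for the converter above (illustrative only; the file paths are placeholders):
+#
+#     config = AlbertConfig.from_json_file("albert_config.json")
+#     model = AlbertForPreTraining(config)
+#     load_tf_weights_in_albert(model, config, "/path/to/tf_checkpoint")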
+
+
+class AlbertEmbeddings(nn.Module):
+ """
+ Construct the embeddings from word, position and token_type embeddings.
+ """
+
+ def __init__(self, config: AlbertConfig):
+ super().__init__()
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.embedding_size, padding_idx=config.pad_token_id)
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.embedding_size)
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.embedding_size)
+
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
+ # any TensorFlow checkpoint file
+ self.LayerNorm = nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
+ self.register_buffer(
+ "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
+ )
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
+ self.register_buffer(
+ "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
+ )
+
+ # Copied from transformers.models.bert.modeling_bert.BertEmbeddings.forward
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ past_key_values_length: int = 0,
+ ) -> torch.Tensor:
+ if input_ids is not None:
+ input_shape = input_ids.size()
+ else:
+ input_shape = inputs_embeds.size()[:-1]
+
+ seq_length = input_shape[1]
+
+ if position_ids is None:
+ position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
+
+        # Setting the token_type_ids to the registered buffer in the constructor where it is all zeros, which usually
+        # occurs when it is auto-generated. The registered buffer helps users trace the model without passing
+        # token_type_ids and solves issue #5664.
+ if token_type_ids is None:
+ if hasattr(self, "token_type_ids"):
+ buffered_token_type_ids = self.token_type_ids[:, :seq_length]
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
+ token_type_ids = buffered_token_type_ids_expanded
+ else:
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
+
+ if inputs_embeds is None:
+ inputs_embeds = self.word_embeddings(input_ids)
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
+
+ embeddings = inputs_embeds + token_type_embeddings
+ if self.position_embedding_type == "absolute":
+ position_embeddings = self.position_embeddings(position_ids)
+ embeddings += position_embeddings
+ embeddings = self.LayerNorm(embeddings)
+ embeddings = self.dropout(embeddings)
+ return embeddings
+
+
+class AlbertAttention(nn.Module):
+ def __init__(self, config: AlbertConfig):
+ super().__init__()
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
+ raise ValueError(
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
+ f"heads ({config.num_attention_heads}"
+ )
+
+ self.num_attention_heads = config.num_attention_heads
+ self.hidden_size = config.hidden_size
+ self.attention_head_size = config.hidden_size // config.num_attention_heads
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
+
+ self.attention_dropout = nn.Dropout(config.attention_probs_dropout_prob)
+ self.output_dropout = nn.Dropout(config.hidden_dropout_prob)
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.pruned_heads = set()
+
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
+ self.max_position_embeddings = config.max_position_embeddings
+ self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
+
+ def prune_heads(self, heads: list[int]) -> None:
+ if len(heads) == 0:
+ return
+ heads, index = find_pruneable_heads_and_indices(
+ heads, self.num_attention_heads, self.attention_head_size, self.pruned_heads
+ )
+
+ # Prune linear layers
+ self.query = prune_linear_layer(self.query, index)
+ self.key = prune_linear_layer(self.key, index)
+ self.value = prune_linear_layer(self.value, index)
+ self.dense = prune_linear_layer(self.dense, index, dim=1)
+
+ # Update hyper params and store pruned heads
+ self.num_attention_heads = self.num_attention_heads - len(heads)
+ self.all_head_size = self.attention_head_size * self.num_attention_heads
+ self.pruned_heads = self.pruned_heads.union(heads)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ output_attentions: bool = False,
+ ) -> Union[tuple[torch.Tensor], tuple[torch.Tensor, torch.Tensor]]:
+ batch_size, seq_length, _ = hidden_states.shape
+ query_layer = self.query(hidden_states)
+ key_layer = self.key(hidden_states)
+ value_layer = self.value(hidden_states)
+ query_layer = query_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(
+ 1, 2
+ )
+ key_layer = key_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2)
+ value_layer = value_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(
+ 1, 2
+ )
+
+ # Take the dot product between "query" and "key" to get the raw attention scores.
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+
+ if attention_mask is not None:
+            # Apply the attention mask (precomputed for all layers in the AlbertModel forward() function)
+ attention_scores = attention_scores + attention_mask
+
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
+ seq_length = hidden_states.size()[1]
+ position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
+ position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
+ distance = position_ids_l - position_ids_r
+ positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
+ positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
+
+ if self.position_embedding_type == "relative_key":
+ relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
+ attention_scores = attention_scores + relative_position_scores
+ elif self.position_embedding_type == "relative_key_query":
+ relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
+ relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
+ attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
+
+ # Normalize the attention scores to probabilities.
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
+
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ attention_probs = self.attention_dropout(attention_probs)
+
+ # Mask heads if we want to
+ if head_mask is not None:
+ attention_probs = attention_probs * head_mask
+
+ context_layer = torch.matmul(attention_probs, value_layer)
+ context_layer = context_layer.transpose(2, 1).flatten(2)
+
+ projected_context_layer = self.dense(context_layer)
+ projected_context_layer_dropout = self.output_dropout(projected_context_layer)
+ layernormed_context_layer = self.LayerNorm(hidden_states + projected_context_layer_dropout)
+ return (layernormed_context_layer, attention_probs) if output_attentions else (layernormed_context_layer,)
+
+
+class AlbertSdpaAttention(AlbertAttention):
+ def __init__(self, config):
+ super().__init__(config)
+ self.dropout_prob = config.attention_probs_dropout_prob
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ output_attentions: bool = False,
+ ) -> Union[tuple[torch.Tensor], tuple[torch.Tensor, torch.Tensor]]:
+ if self.position_embedding_type != "absolute" or output_attentions:
+ logger.warning(
+ "AlbertSdpaAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
+ "non-absolute `position_embedding_type` or `output_attentions=True` . Falling back to "
+ "the eager attention implementation, but specifying the eager implementation will be required from "
+ "Transformers version v5.0.0 onwards. This warning can be removed using the argument "
+ '`attn_implementation="eager"` when loading the model.'
+ )
+ return super().forward(hidden_states, attention_mask, output_attentions=output_attentions)
+
+ batch_size, seq_len, _ = hidden_states.size()
+ query_layer = (
+ self.query(hidden_states)
+ .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
+ .transpose(1, 2)
+ )
+ key_layer = (
+ self.key(hidden_states)
+ .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
+ .transpose(1, 2)
+ )
+ value_layer = (
+ self.value(hidden_states)
+ .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
+ .transpose(1, 2)
+ )
+
+ attention_output = torch.nn.functional.scaled_dot_product_attention(
+ query=query_layer,
+ key=key_layer,
+ value=value_layer,
+ attn_mask=attention_mask,
+ dropout_p=self.dropout_prob if self.training else 0.0,
+ is_causal=False,
+ )
+
+ attention_output = attention_output.transpose(1, 2)
+ attention_output = attention_output.reshape(batch_size, seq_len, self.all_head_size)
+
+ projected_context_layer = self.dense(attention_output)
+ projected_context_layer_dropout = self.output_dropout(projected_context_layer)
+ layernormed_context_layer = self.LayerNorm(hidden_states + projected_context_layer_dropout)
+ return (layernormed_context_layer,)
+
+
+ALBERT_ATTENTION_CLASSES = {
+ "eager": AlbertAttention,
+ "sdpa": AlbertSdpaAttention,
+}
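+
+# The concrete attention class is chosen per model from `config._attn_implementation`
+# (e.g. "eager" or "sdpa"); see the lookup in `AlbertLayer.__init__` below.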
+
+
+class AlbertLayer(nn.Module):
+ def __init__(self, config: AlbertConfig):
+ super().__init__()
+
+ self.config = config
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
+ self.seq_len_dim = 1
+ self.full_layer_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.attention = ALBERT_ATTENTION_CLASSES[config._attn_implementation](config)
+ self.ffn = nn.Linear(config.hidden_size, config.intermediate_size)
+ self.ffn_output = nn.Linear(config.intermediate_size, config.hidden_size)
+ self.activation = ACT2FN[config.hidden_act]
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ attention_output = self.attention(hidden_states, attention_mask, head_mask, output_attentions)
+
+ ffn_output = apply_chunking_to_forward(
+ self.ff_chunk,
+ self.chunk_size_feed_forward,
+ self.seq_len_dim,
+ attention_output[0],
+ )
+ hidden_states = self.full_layer_layer_norm(ffn_output + attention_output[0])
+
+ return (hidden_states,) + attention_output[1:] # add attentions if we output them
+
+ def ff_chunk(self, attention_output: torch.Tensor) -> torch.Tensor:
+ ffn_output = self.ffn(attention_output)
+ ffn_output = self.activation(ffn_output)
+ ffn_output = self.ffn_output(ffn_output)
+ return ffn_output
+
+
+class AlbertLayerGroup(nn.Module):
+ def __init__(self, config: AlbertConfig):
+ super().__init__()
+
+ self.albert_layers = nn.ModuleList([AlbertLayer(config) for _ in range(config.inner_group_num)])
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ ) -> tuple[Union[torch.Tensor, tuple[torch.Tensor]], ...]:
+ layer_hidden_states = ()
+ layer_attentions = ()
+
+ for layer_index, albert_layer in enumerate(self.albert_layers):
+ layer_output = albert_layer(hidden_states, attention_mask, head_mask[layer_index], output_attentions)
+ hidden_states = layer_output[0]
+
+ if output_attentions:
+ layer_attentions = layer_attentions + (layer_output[1],)
+
+ if output_hidden_states:
+ layer_hidden_states = layer_hidden_states + (hidden_states,)
+
+ outputs = (hidden_states,)
+ if output_hidden_states:
+ outputs = outputs + (layer_hidden_states,)
+ if output_attentions:
+ outputs = outputs + (layer_attentions,)
+ return outputs # last-layer hidden state, (layer hidden states), (layer attentions)
+
+
+class AlbertTransformer(nn.Module):
+ def __init__(self, config: AlbertConfig):
+ super().__init__()
+
+ self.config = config
+ self.embedding_hidden_mapping_in = nn.Linear(config.embedding_size, config.hidden_size)
+ self.albert_layer_groups = nn.ModuleList([AlbertLayerGroup(config) for _ in range(config.num_hidden_groups)])
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ return_dict: bool = True,
+ ) -> Union[BaseModelOutput, tuple]:
+ hidden_states = self.embedding_hidden_mapping_in(hidden_states)
+
+ all_hidden_states = (hidden_states,) if output_hidden_states else None
+ all_attentions = () if output_attentions else None
+
+ head_mask = [None] * self.config.num_hidden_layers if head_mask is None else head_mask
+
+ for i in range(self.config.num_hidden_layers):
+ # Number of layers in a hidden group
+ layers_per_group = int(self.config.num_hidden_layers / self.config.num_hidden_groups)
+
+ # Index of the hidden group
+ group_idx = int(i / (self.config.num_hidden_layers / self.config.num_hidden_groups))
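+            # For example, with num_hidden_layers=12 and num_hidden_groups=1 (the defaults), layers_per_group is 12
+            # and every layer i maps to group_idx=0, so the single shared parameter group is applied 12 times.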
+
+ layer_group_output = self.albert_layer_groups[group_idx](
+ hidden_states,
+ attention_mask,
+ head_mask[group_idx * layers_per_group : (group_idx + 1) * layers_per_group],
+ output_attentions,
+ output_hidden_states,
+ )
+ hidden_states = layer_group_output[0]
+
+ if output_attentions:
+ all_attentions = all_attentions + layer_group_output[-1]
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
+ return BaseModelOutput(
+ last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
+ )
+
+
+@auto_docstring
+class AlbertPreTrainedModel(PreTrainedModel):
+ config: AlbertConfig
+ load_tf_weights = load_tf_weights_in_albert
+ base_model_prefix = "albert"
+ _supports_sdpa = True
+
+ def _init_weights(self, module):
+ """Initialize the weights."""
+ if isinstance(module, nn.Linear):
+ # Slightly different from the TF version which uses truncated_normal for initialization
+ # cf https://github.com/pytorch/pytorch/pull/5617
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+ elif isinstance(module, AlbertMLMHead):
+ module.bias.data.zero_()
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Output type of [`AlbertForPreTraining`].
+ """
+)
+class AlbertForPreTrainingOutput(ModelOutput):
+ r"""
+ loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
+        Total loss as the sum of the masked language modeling loss and the sentence order prediction
+        (classification) loss.
+ prediction_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ sop_logits (`torch.FloatTensor` of shape `(batch_size, 2)`):
+        Prediction scores of the sentence order prediction (classification) head (scores of original/switched order
+ before SoftMax).
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ prediction_logits: Optional[torch.FloatTensor] = None
+ sop_logits: Optional[torch.FloatTensor] = None
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
+ attentions: Optional[tuple[torch.FloatTensor]] = None
+
+
+@auto_docstring
+class AlbertModel(AlbertPreTrainedModel):
+ config: AlbertConfig
+ base_model_prefix = "albert"
+
+ def __init__(self, config: AlbertConfig, add_pooling_layer: bool = True):
+ r"""
+ add_pooling_layer (bool, *optional*, defaults to `True`):
+ Whether to add a pooling layer
+ """
+ super().__init__(config)
+
+ self.config = config
+ self.embeddings = AlbertEmbeddings(config)
+ self.encoder = AlbertTransformer(config)
+ if add_pooling_layer:
+ self.pooler = nn.Linear(config.hidden_size, config.hidden_size)
+ self.pooler_activation = nn.Tanh()
+ else:
+ self.pooler = None
+ self.pooler_activation = None
+
+ self.attn_implementation = config._attn_implementation
+ self.position_embedding_type = config.position_embedding_type
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self) -> nn.Embedding:
+ return self.embeddings.word_embeddings
+
+ def set_input_embeddings(self, value: nn.Embedding) -> None:
+ self.embeddings.word_embeddings = value
+
+ def _prune_heads(self, heads_to_prune: dict[int, list[int]]) -> None:
+ """
+        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}. ALBERT
+        has a different architecture in that its layers are shared across groups, which in turn contain inner groups.
+        If an ALBERT model has 12 hidden layers and 2 hidden groups, each with two inner groups, there is a total of 4
+        different layers.
+
+        These layers are flattened: the indices [0,1] correspond to the two inner groups of the first hidden group,
+        while [2,3] correspond to the two inner groups of the second hidden group.
+
+        Any layer with an index other than [0,1,2,3] will result in an error. See base class PreTrainedModel for more
+ information about head pruning
+ """
+ for layer, heads in heads_to_prune.items():
+ group_idx = int(layer / self.config.inner_group_num)
+ inner_group_idx = int(layer - group_idx * self.config.inner_group_num)
+ self.encoder.albert_layer_groups[group_idx].albert_layers[inner_group_idx].attention.prune_heads(heads)
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[BaseModelOutputWithPooling, tuple]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
+ input_shape = input_ids.size()
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ batch_size, seq_length = input_shape
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+
+ if attention_mask is None:
+ attention_mask = torch.ones(input_shape, device=device)
+ if token_type_ids is None:
+ if hasattr(self.embeddings, "token_type_ids"):
+ buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
+ token_type_ids = buffered_token_type_ids_expanded
+ else:
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
+
+ embedding_output = self.embeddings(
+ input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
+ )
+
+ use_sdpa_attention_mask = (
+ self.attn_implementation == "sdpa"
+ and self.position_embedding_type == "absolute"
+ and head_mask is None
+ and not output_attentions
+ )
+
+ if use_sdpa_attention_mask:
+ extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(
+ attention_mask, embedding_output.dtype, tgt_len=seq_length
+ )
+ else:
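+            # Eager attention expects an additive mask broadcastable to [batch, heads, query, key]:
+            # 0.0 for positions to attend to and a large negative value for masked positions.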
+ extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
+ extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
+ extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(self.dtype).min
+
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+ encoder_outputs = self.encoder(
+ embedding_output,
+ extended_attention_mask,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output = encoder_outputs[0]
+
+ pooled_output = self.pooler_activation(self.pooler(sequence_output[:, 0])) if self.pooler is not None else None
+
+ if not return_dict:
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
+
+ return BaseModelOutputWithPooling(
+ last_hidden_state=sequence_output,
+ pooler_output=pooled_output,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ Albert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a
+ `sentence order prediction (classification)` head.
+ """
+)
+class AlbertForPreTraining(AlbertPreTrainedModel):
+ _tied_weights_keys = ["predictions.decoder.bias", "predictions.decoder.weight"]
+
+ def __init__(self, config: AlbertConfig):
+ super().__init__(config)
+
+ self.albert = AlbertModel(config)
+ self.predictions = AlbertMLMHead(config)
+ self.sop_classifier = AlbertSOPHead(config)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_output_embeddings(self) -> nn.Linear:
+ return self.predictions.decoder
+
+ def set_output_embeddings(self, new_embeddings: nn.Linear) -> None:
+ self.predictions.decoder = new_embeddings
+
+ def get_input_embeddings(self) -> nn.Embedding:
+ return self.albert.embeddings.word_embeddings
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ sentence_order_label: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[AlbertForPreTrainingOutput, tuple]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
+            config.vocab_size]` (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked);
+            the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+ sentence_order_label (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the sentence order prediction (classification) loss. Input should be a sequence pair
+ (see `input_ids` docstring) Indices should be in `[0, 1]`. `0` indicates original order (sequence A, then
+ sequence B), `1` indicates switched order (sequence B, then sequence A).
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, AlbertForPreTraining
+ >>> import torch
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("albert/albert-base-v2")
+ >>> model = AlbertForPreTraining.from_pretrained("albert/albert-base-v2")
+
+ >>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)
+ >>> # Batch size 1
+ >>> outputs = model(input_ids)
+
+ >>> prediction_logits = outputs.prediction_logits
+ >>> sop_logits = outputs.sop_logits
+ ```"""
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.albert(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output, pooled_output = outputs[:2]
+
+ prediction_scores = self.predictions(sequence_output)
+ sop_scores = self.sop_classifier(pooled_output)
+
+ total_loss = None
+ if labels is not None and sentence_order_label is not None:
+ loss_fct = CrossEntropyLoss()
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
+ sentence_order_loss = loss_fct(sop_scores.view(-1, 2), sentence_order_label.view(-1))
+ total_loss = masked_lm_loss + sentence_order_loss
+
+ if not return_dict:
+ output = (prediction_scores, sop_scores) + outputs[2:]
+ return ((total_loss,) + output) if total_loss is not None else output
+
+ return AlbertForPreTrainingOutput(
+ loss=total_loss,
+ prediction_logits=prediction_scores,
+ sop_logits=sop_scores,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+class AlbertMLMHead(nn.Module):
+ def __init__(self, config: AlbertConfig):
+ super().__init__()
+
+ self.LayerNorm = nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps)
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
+ self.dense = nn.Linear(config.hidden_size, config.embedding_size)
+ self.decoder = nn.Linear(config.embedding_size, config.vocab_size)
+ self.activation = ACT2FN[config.hidden_act]
+ self.decoder.bias = self.bias
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.activation(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states)
+ hidden_states = self.decoder(hidden_states)
+
+ prediction_scores = hidden_states
+
+ return prediction_scores
+
+ def _tie_weights(self) -> None:
+ # For accelerate compatibility and to not break backward compatibility
+ if self.decoder.bias.device.type == "meta":
+ self.decoder.bias = self.bias
+ else:
+ # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
+ self.bias = self.decoder.bias
+
+
+class AlbertSOPHead(nn.Module):
+ def __init__(self, config: AlbertConfig):
+ super().__init__()
+
+ self.dropout = nn.Dropout(config.classifier_dropout_prob)
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+ def forward(self, pooled_output: torch.Tensor) -> torch.Tensor:
+ dropout_pooled_output = self.dropout(pooled_output)
+ logits = self.classifier(dropout_pooled_output)
+ return logits
+
+
+@auto_docstring
+class AlbertForMaskedLM(AlbertPreTrainedModel):
+ _tied_weights_keys = ["predictions.decoder.bias", "predictions.decoder.weight"]
+
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.albert = AlbertModel(config, add_pooling_layer=False)
+ self.predictions = AlbertMLMHead(config)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_output_embeddings(self) -> nn.Linear:
+ return self.predictions.decoder
+
+ def set_output_embeddings(self, new_embeddings: nn.Linear) -> None:
+ self.predictions.decoder = new_embeddings
+ self.predictions.bias = new_embeddings.bias
+
+ def get_input_embeddings(self) -> nn.Embedding:
+ return self.albert.embeddings.word_embeddings
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[MaskedLMOutput, tuple]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
+            config.vocab_size]` (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked);
+            the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Example:
+
+ ```python
+ >>> import torch
+ >>> from transformers import AutoTokenizer, AlbertForMaskedLM
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("albert/albert-base-v2")
+ >>> model = AlbertForMaskedLM.from_pretrained("albert/albert-base-v2")
+
+ >>> # add mask_token
+ >>> inputs = tokenizer("The capital of [MASK] is Paris.", return_tensors="pt")
+ >>> with torch.no_grad():
+ ... logits = model(**inputs).logits
+
+ >>> # retrieve index of [MASK]
+ >>> mask_token_index = (inputs.input_ids == tokenizer.mask_token_id)[0].nonzero(as_tuple=True)[0]
+ >>> predicted_token_id = logits[0, mask_token_index].argmax(axis=-1)
+ >>> tokenizer.decode(predicted_token_id)
+ 'france'
+ ```
+
+ ```python
+ >>> labels = tokenizer("The capital of France is Paris.", return_tensors="pt")["input_ids"]
+ >>> labels = torch.where(inputs.input_ids == tokenizer.mask_token_id, labels, -100)
+ >>> outputs = model(**inputs, labels=labels)
+ >>> round(outputs.loss.item(), 2)
+ 0.81
+ ```
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.albert(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ sequence_outputs = outputs[0]
+
+ prediction_scores = self.predictions(sequence_outputs)
+
+ masked_lm_loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
+
+ if not return_dict:
+ output = (prediction_scores,) + outputs[2:]
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
+
+ return MaskedLMOutput(
+ loss=masked_lm_loss,
+ logits=prediction_scores,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ Albert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
+ output) e.g. for GLUE tasks.
+ """
+)
+class AlbertForSequenceClassification(AlbertPreTrainedModel):
+ def __init__(self, config: AlbertConfig):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.config = config
+
+ self.albert = AlbertModel(config)
+ self.dropout = nn.Dropout(config.classifier_dropout_prob)
+ self.classifier = nn.Linear(config.hidden_size, self.config.num_labels)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[SequenceClassifierOutput, tuple]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss). If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.albert(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ pooled_output = outputs[1]
+
+ pooled_output = self.dropout(pooled_output)
+ logits = self.classifier(pooled_output)
+
+ loss = None
+ if labels is not None:
+ if self.config.problem_type is None:
+ if self.num_labels == 1:
+ self.config.problem_type = "regression"
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+ self.config.problem_type = "single_label_classification"
+ else:
+ self.config.problem_type = "multi_label_classification"
+
+ if self.config.problem_type == "regression":
+ loss_fct = MSELoss()
+ if self.num_labels == 1:
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
+ else:
+ loss = loss_fct(logits, labels)
+ elif self.config.problem_type == "single_label_classification":
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+ elif self.config.problem_type == "multi_label_classification":
+ loss_fct = BCEWithLogitsLoss()
+ loss = loss_fct(logits, labels)
+
+ if not return_dict:
+ output = (logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return SequenceClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@auto_docstring
+class AlbertForTokenClassification(AlbertPreTrainedModel):
+ def __init__(self, config: AlbertConfig):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+
+ self.albert = AlbertModel(config, add_pooling_layer=False)
+ classifier_dropout_prob = (
+ config.classifier_dropout_prob
+ if config.classifier_dropout_prob is not None
+ else config.hidden_dropout_prob
+ )
+ self.dropout = nn.Dropout(classifier_dropout_prob)
+ self.classifier = nn.Linear(config.hidden_size, self.config.num_labels)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[TokenClassifierOutput, tuple]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.albert(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output = outputs[0]
+
+ sequence_output = self.dropout(sequence_output)
+ logits = self.classifier(sequence_output)
+
+ loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+ if not return_dict:
+ output = (logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return TokenClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@auto_docstring
+class AlbertForQuestionAnswering(AlbertPreTrainedModel):
+ def __init__(self, config: AlbertConfig):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+
+ self.albert = AlbertModel(config, add_pooling_layer=False)
+ self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ start_positions: Optional[torch.LongTensor] = None,
+ end_positions: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[AlbertForPreTrainingOutput, tuple]:
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.albert(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output = outputs[0]
+
+ logits: torch.Tensor = self.qa_outputs(sequence_output)
+ start_logits, end_logits = logits.split(1, dim=-1)
+ start_logits = start_logits.squeeze(-1).contiguous()
+ end_logits = end_logits.squeeze(-1).contiguous()
+
+ total_loss = None
+ if start_positions is not None and end_positions is not None:
+ # If we are on multi-GPU, split add a dimension
+ if len(start_positions.size()) > 1:
+ start_positions = start_positions.squeeze(-1)
+ if len(end_positions.size()) > 1:
+ end_positions = end_positions.squeeze(-1)
+            # Sometimes the start/end positions are outside our model inputs; we ignore these terms
+ ignored_index = start_logits.size(1)
+ start_positions = start_positions.clamp(0, ignored_index)
+ end_positions = end_positions.clamp(0, ignored_index)
+
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
+ start_loss = loss_fct(start_logits, start_positions)
+ end_loss = loss_fct(end_logits, end_positions)
+ total_loss = (start_loss + end_loss) / 2
+
+ if not return_dict:
+ output = (start_logits, end_logits) + outputs[2:]
+ return ((total_loss,) + output) if total_loss is not None else output
+
+ return QuestionAnsweringModelOutput(
+ loss=total_loss,
+ start_logits=start_logits,
+ end_logits=end_logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@auto_docstring
+class AlbertForMultipleChoice(AlbertPreTrainedModel):
+ def __init__(self, config: AlbertConfig):
+ super().__init__(config)
+
+ self.albert = AlbertModel(config)
+ self.dropout = nn.Dropout(config.classifier_dropout_prob)
+ self.classifier = nn.Linear(config.hidden_size, 1)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[AlbertForPreTrainingOutput, tuple]:
+ r"""
+ input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and
+ [`PreTrainedTokenizer.encode`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ token_type_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
+ 1]`:
+
+ - 0 corresponds to a *sentence A* token,
+ - 1 corresponds to a *sentence B* token.
+
+ [What are token type IDs?](../glossary#token-type-ids)
+ position_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+ config.max_position_embeddings - 1]`.
+
+ [What are position IDs?](../glossary#position-ids)
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+ model's internal embedding lookup matrix.
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
+ num_choices-1]` where *num_choices* is the size of the second dimension of the input tensors. (see
+ *input_ids* above)
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
+
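+        # Flatten the choice dimension so each choice is scored as an independent sequence:
+        # (batch_size, num_choices, seq_len) -> (batch_size * num_choices, seq_len)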
+ input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
+ attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
+ token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
+ position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
+ inputs_embeds = (
+ inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
+ if inputs_embeds is not None
+ else None
+ )
+ outputs = self.albert(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ pooled_output = outputs[1]
+
+ pooled_output = self.dropout(pooled_output)
+ logits: torch.Tensor = self.classifier(pooled_output)
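+        # Un-flatten back to (batch_size, num_choices) so each row holds the scores for all choices of one example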
+ reshaped_logits = logits.view(-1, num_choices)
+
+ loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(reshaped_logits, labels)
+
+ if not return_dict:
+ output = (reshaped_logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return MultipleChoiceModelOutput(
+ loss=loss,
+ logits=reshaped_logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+__all__ = [
+ "load_tf_weights_in_albert",
+ "AlbertPreTrainedModel",
+ "AlbertModel",
+ "AlbertForPreTraining",
+ "AlbertForMaskedLM",
+ "AlbertForSequenceClassification",
+ "AlbertForTokenClassification",
+ "AlbertForQuestionAnswering",
+ "AlbertForMultipleChoice",
+]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/albert/modeling_flax_albert.py b/venv/lib/python3.13/site-packages/transformers/models/albert/modeling_flax_albert.py
new file mode 100644
index 0000000000000000000000000000000000000000..f2f19cb27716fb3f8846ef88e870e3eb1188a4bf
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/albert/modeling_flax_albert.py
@@ -0,0 +1,1132 @@
+# coding=utf-8
+# Copyright 2021 Google AI, Google Brain and the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Callable, Optional
+
+import flax
+import flax.linen as nn
+import jax
+import jax.numpy as jnp
+import numpy as np
+from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
+from flax.linen.attention import dot_product_attention_weights
+from flax.traverse_util import flatten_dict, unflatten_dict
+from jax import lax
+
+from ...modeling_flax_outputs import (
+ FlaxBaseModelOutput,
+ FlaxBaseModelOutputWithPooling,
+ FlaxMaskedLMOutput,
+ FlaxMultipleChoiceModelOutput,
+ FlaxQuestionAnsweringModelOutput,
+ FlaxSequenceClassifierOutput,
+ FlaxTokenClassifierOutput,
+)
+from ...modeling_flax_utils import (
+ ACT2FN,
+ FlaxPreTrainedModel,
+ append_call_sample_docstring,
+ append_replace_return_docstrings,
+ overwrite_call_docstring,
+)
+from ...utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, logging
+from .configuration_albert import AlbertConfig
+
+
+logger = logging.get_logger(__name__)
+
+_CHECKPOINT_FOR_DOC = "albert/albert-base-v2"
+_CONFIG_FOR_DOC = "AlbertConfig"
+
+
+@flax.struct.dataclass
+class FlaxAlbertForPreTrainingOutput(ModelOutput):
+ """
+ Output type of [`FlaxAlbertForPreTraining`].
+
+ Args:
+ prediction_logits (`jnp.ndarray` of shape `(batch_size, sequence_length, config.vocab_size)`):
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ sop_logits (`jnp.ndarray` of shape `(batch_size, 2)`):
+            Prediction scores of the sentence order prediction (classification) head (scores of original vs. swapped
+            order before SoftMax).
+ hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
+ `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+        Attention weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ prediction_logits: jnp.ndarray = None
+ sop_logits: jnp.ndarray = None
+ hidden_states: Optional[tuple[jnp.ndarray]] = None
+ attentions: Optional[tuple[jnp.ndarray]] = None
+
+
+ALBERT_START_DOCSTRING = r"""
+
+ This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its models (such as downloading, saving and converting weights from PyTorch models).
+
+ This model is also a
+ [flax.linen.Module](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html) subclass. Use it as
+    a regular Flax linen Module and refer to the Flax documentation for all matters related to general usage and
+ behavior.
+
+ Finally, this model supports inherent JAX features such as:
+
+ - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
+ - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
+ - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
+ - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
+
+ Parameters:
+ config ([`AlbertConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
+ dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
+ The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
+ `jax.numpy.bfloat16` (on TPUs).
+
+ This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
+        specified, all the computation will be performed with the given `dtype`.
+
+ **Note that this only specifies the dtype of the computation and does not influence the dtype of model
+ parameters.**
+
+ If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
+ [`~FlaxPreTrainedModel.to_bf16`].
+"""
+
+ALBERT_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`numpy.ndarray` of shape `({0})`):
+ Indices of input sequence tokens in the vocabulary.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`numpy.ndarray` of shape `({0})`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ token_type_ids (`numpy.ndarray` of shape `({0})`, *optional*):
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
+ 1]`:
+
+ - 0 corresponds to a *sentence A* token,
+ - 1 corresponds to a *sentence B* token.
+
+ [What are token type IDs?](../glossary#token-type-ids)
+ position_ids (`numpy.ndarray` of shape `({0})`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+ config.max_position_embeddings - 1]`.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+
+"""
+
+
+class FlaxAlbertEmbeddings(nn.Module):
+ """Construct the embeddings from word, position and token_type embeddings."""
+
+ config: AlbertConfig
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
+
+ def setup(self):
+ self.word_embeddings = nn.Embed(
+ self.config.vocab_size,
+ self.config.embedding_size,
+ embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
+ )
+ self.position_embeddings = nn.Embed(
+ self.config.max_position_embeddings,
+ self.config.embedding_size,
+ embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
+ )
+ self.token_type_embeddings = nn.Embed(
+ self.config.type_vocab_size,
+ self.config.embedding_size,
+ embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
+ )
+ self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
+ self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
+
+ def __call__(self, input_ids, token_type_ids, position_ids, deterministic: bool = True):
+ # Embed
+ inputs_embeds = self.word_embeddings(input_ids.astype("i4"))
+ position_embeds = self.position_embeddings(position_ids.astype("i4"))
+ token_type_embeddings = self.token_type_embeddings(token_type_ids.astype("i4"))
+
+ # Sum all embeddings
+ hidden_states = inputs_embeds + token_type_embeddings + position_embeds
+
+ # Layer Norm
+ hidden_states = self.LayerNorm(hidden_states)
+ hidden_states = self.dropout(hidden_states, deterministic=deterministic)
+ return hidden_states
+
+
+class FlaxAlbertSelfAttention(nn.Module):
+ config: AlbertConfig
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
+
+ def setup(self):
+ if self.config.hidden_size % self.config.num_attention_heads != 0:
+ raise ValueError(
+ "`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads` "
+ " : {self.config.num_attention_heads}"
+ )
+
+ self.query = nn.Dense(
+ self.config.hidden_size,
+ dtype=self.dtype,
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
+ )
+ self.key = nn.Dense(
+ self.config.hidden_size,
+ dtype=self.dtype,
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
+ )
+ self.value = nn.Dense(
+ self.config.hidden_size,
+ dtype=self.dtype,
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
+ )
+ self.dense = nn.Dense(
+ self.config.hidden_size,
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
+ dtype=self.dtype,
+ )
+ self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
+ self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
+
+ def __call__(self, hidden_states, attention_mask, deterministic=True, output_attentions: bool = False):
+ head_dim = self.config.hidden_size // self.config.num_attention_heads
+
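+        # Project to queries/keys/values and reshape to (batch, seq_len, num_attention_heads, head_dim)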
+ query_states = self.query(hidden_states).reshape(
+ hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
+ )
+ value_states = self.value(hidden_states).reshape(
+ hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
+ )
+ key_states = self.key(hidden_states).reshape(
+ hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
+ )
+
+ # Convert the boolean attention mask to an attention bias.
+ if attention_mask is not None:
+ # attention mask in the form of attention bias
+ attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
+ attention_bias = lax.select(
+ attention_mask > 0,
+ jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
+ jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
+ )
+ else:
+ attention_bias = None
+
+ dropout_rng = None
+ if not deterministic and self.config.attention_probs_dropout_prob > 0.0:
+ dropout_rng = self.make_rng("dropout")
+
+ attn_weights = dot_product_attention_weights(
+ query_states,
+ key_states,
+ bias=attention_bias,
+ dropout_rng=dropout_rng,
+ dropout_rate=self.config.attention_probs_dropout_prob,
+ broadcast_dropout=True,
+ deterministic=deterministic,
+ dtype=self.dtype,
+ precision=None,
+ )
+
+ attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
+ attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,))
+
+ projected_attn_output = self.dense(attn_output)
+ projected_attn_output = self.dropout(projected_attn_output, deterministic=deterministic)
+ layernormed_attn_output = self.LayerNorm(projected_attn_output + hidden_states)
+ outputs = (layernormed_attn_output, attn_weights) if output_attentions else (layernormed_attn_output,)
+ return outputs
+
+
+class FlaxAlbertLayer(nn.Module):
+ config: AlbertConfig
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
+
+ def setup(self):
+ self.attention = FlaxAlbertSelfAttention(self.config, dtype=self.dtype)
+ self.ffn = nn.Dense(
+ self.config.intermediate_size,
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
+ dtype=self.dtype,
+ )
+ self.activation = ACT2FN[self.config.hidden_act]
+ self.ffn_output = nn.Dense(
+ self.config.hidden_size,
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
+ dtype=self.dtype,
+ )
+ self.full_layer_layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
+ self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
+
+ def __call__(
+ self,
+ hidden_states,
+ attention_mask,
+ deterministic: bool = True,
+ output_attentions: bool = False,
+ ):
+ attention_outputs = self.attention(
+ hidden_states, attention_mask, deterministic=deterministic, output_attentions=output_attentions
+ )
+ attention_output = attention_outputs[0]
+ ffn_output = self.ffn(attention_output)
+ ffn_output = self.activation(ffn_output)
+ ffn_output = self.ffn_output(ffn_output)
+ ffn_output = self.dropout(ffn_output, deterministic=deterministic)
+ hidden_states = self.full_layer_layer_norm(ffn_output + attention_output)
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (attention_outputs[1],)
+ return outputs
+
+
+class FlaxAlbertLayerCollection(nn.Module):
+ config: AlbertConfig
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
+
+ def setup(self):
+ self.layers = [
+ FlaxAlbertLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.inner_group_num)
+ ]
+
+ def __call__(
+ self,
+ hidden_states,
+ attention_mask,
+ deterministic: bool = True,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ ):
+ layer_hidden_states = ()
+ layer_attentions = ()
+
+ for layer_index, albert_layer in enumerate(self.layers):
+ layer_output = albert_layer(
+ hidden_states,
+ attention_mask,
+ deterministic=deterministic,
+ output_attentions=output_attentions,
+ )
+ hidden_states = layer_output[0]
+
+ if output_attentions:
+ layer_attentions = layer_attentions + (layer_output[1],)
+
+ if output_hidden_states:
+ layer_hidden_states = layer_hidden_states + (hidden_states,)
+
+ outputs = (hidden_states,)
+ if output_hidden_states:
+ outputs = outputs + (layer_hidden_states,)
+ if output_attentions:
+ outputs = outputs + (layer_attentions,)
+ return outputs # last-layer hidden state, (layer hidden states), (layer attentions)
+
+
+class FlaxAlbertLayerCollections(nn.Module):
+ config: AlbertConfig
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
+ layer_index: Optional[str] = None
+
+ def setup(self):
+ self.albert_layers = FlaxAlbertLayerCollection(self.config, dtype=self.dtype)
+
+ def __call__(
+ self,
+ hidden_states,
+ attention_mask,
+ deterministic: bool = True,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ ):
+ outputs = self.albert_layers(
+ hidden_states,
+ attention_mask,
+ deterministic=deterministic,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ )
+ return outputs
+
+
+class FlaxAlbertLayerGroups(nn.Module):
+ config: AlbertConfig
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
+
+ def setup(self):
+ self.layers = [
+ FlaxAlbertLayerCollections(self.config, name=str(i), layer_index=str(i), dtype=self.dtype)
+ for i in range(self.config.num_hidden_groups)
+ ]
+
+ def __call__(
+ self,
+ hidden_states,
+ attention_mask,
+ deterministic: bool = True,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ return_dict: bool = True,
+ ):
+ all_attentions = () if output_attentions else None
+ all_hidden_states = (hidden_states,) if output_hidden_states else None
+
+ for i in range(self.config.num_hidden_layers):
+ # Index of the hidden group
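+            # ALBERT shares parameters across layers: each of the `num_hidden_layers` iterations reuses one of the
+            # `num_hidden_groups` parameter groups (with the default of a single group, every layer reuses group 0)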
+ group_idx = int(i / (self.config.num_hidden_layers / self.config.num_hidden_groups))
+ layer_group_output = self.layers[group_idx](
+ hidden_states,
+ attention_mask,
+ deterministic=deterministic,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ )
+ hidden_states = layer_group_output[0]
+
+ if output_attentions:
+ all_attentions = all_attentions + layer_group_output[-1]
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
+ return FlaxBaseModelOutput(
+ last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
+ )
+
+
+class FlaxAlbertEncoder(nn.Module):
+ config: AlbertConfig
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
+
+ def setup(self):
+ self.embedding_hidden_mapping_in = nn.Dense(
+ self.config.hidden_size,
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
+ dtype=self.dtype,
+ )
+ self.albert_layer_groups = FlaxAlbertLayerGroups(self.config, dtype=self.dtype)
+
+ def __call__(
+ self,
+ hidden_states,
+ attention_mask,
+ deterministic: bool = True,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ return_dict: bool = True,
+ ):
+ hidden_states = self.embedding_hidden_mapping_in(hidden_states)
+ return self.albert_layer_groups(
+ hidden_states,
+ attention_mask,
+ deterministic=deterministic,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ )
+
+
+class FlaxAlbertOnlyMLMHead(nn.Module):
+ config: AlbertConfig
+ dtype: jnp.dtype = jnp.float32
+ bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros
+
+ def setup(self):
+ self.dense = nn.Dense(self.config.embedding_size, dtype=self.dtype)
+ self.activation = ACT2FN[self.config.hidden_act]
+ self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
+ self.decoder = nn.Dense(self.config.vocab_size, dtype=self.dtype, use_bias=False)
+ self.bias = self.param("bias", self.bias_init, (self.config.vocab_size,))
+
+ def __call__(self, hidden_states, shared_embedding=None):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.activation(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states)
+
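+        # With tied word embeddings, reuse the (transposed) embedding matrix as the decoder kernel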
+ if shared_embedding is not None:
+ hidden_states = self.decoder.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
+ else:
+ hidden_states = self.decoder(hidden_states)
+
+ hidden_states += self.bias
+ return hidden_states
+
+
+class FlaxAlbertSOPHead(nn.Module):
+ config: AlbertConfig
+ dtype: jnp.dtype = jnp.float32
+
+ def setup(self):
+ self.dropout = nn.Dropout(self.config.classifier_dropout_prob)
+ self.classifier = nn.Dense(2, dtype=self.dtype)
+
+ def __call__(self, pooled_output, deterministic=True):
+ pooled_output = self.dropout(pooled_output, deterministic=deterministic)
+ logits = self.classifier(pooled_output)
+ return logits
+
+
+class FlaxAlbertPreTrainedModel(FlaxPreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = AlbertConfig
+ base_model_prefix = "albert"
+ module_class: nn.Module = None
+
+ def __init__(
+ self,
+ config: AlbertConfig,
+ input_shape: tuple = (1, 1),
+ seed: int = 0,
+ dtype: jnp.dtype = jnp.float32,
+ _do_init: bool = True,
+ **kwargs,
+ ):
+ module = self.module_class(config=config, dtype=dtype, **kwargs)
+ super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
+
+ def init_weights(self, rng: jax.random.PRNGKey, input_shape: tuple, params: FrozenDict = None) -> FrozenDict:
+ # init input tensors
+ input_ids = jnp.zeros(input_shape, dtype="i4")
+ token_type_ids = jnp.zeros_like(input_ids)
+ position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)
+ attention_mask = jnp.ones_like(input_ids)
+
+ params_rng, dropout_rng = jax.random.split(rng)
+ rngs = {"params": params_rng, "dropout": dropout_rng}
+
+ random_params = self.module.init(
+ rngs, input_ids, attention_mask, token_type_ids, position_ids, return_dict=False
+ )["params"]
+
+ if params is not None:
+ random_params = flatten_dict(unfreeze(random_params))
+ params = flatten_dict(unfreeze(params))
+ for missing_key in self._missing_keys:
+ params[missing_key] = random_params[missing_key]
+ self._missing_keys = set()
+ return freeze(unflatten_dict(params))
+ else:
+ return random_params
+
+ @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ def __call__(
+ self,
+ input_ids,
+ attention_mask=None,
+ token_type_ids=None,
+ position_ids=None,
+ params: Optional[dict] = None,
+ dropout_rng: jax.random.PRNGKey = None,
+ train: bool = False,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ):
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
+
+ # init input tensors if not passed
+ if token_type_ids is None:
+ token_type_ids = jnp.zeros_like(input_ids)
+
+ if position_ids is None:
+ position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
+
+ if attention_mask is None:
+ attention_mask = jnp.ones_like(input_ids)
+
+ # Handle any PRNG if needed
+ rngs = {}
+ if dropout_rng is not None:
+ rngs["dropout"] = dropout_rng
+
+ return self.module.apply(
+ {"params": params or self.params},
+ jnp.array(input_ids, dtype="i4"),
+ jnp.array(attention_mask, dtype="i4"),
+ jnp.array(token_type_ids, dtype="i4"),
+ jnp.array(position_ids, dtype="i4"),
+ not train,
+ output_attentions,
+ output_hidden_states,
+ return_dict,
+ rngs=rngs,
+ )
+
+
+class FlaxAlbertModule(nn.Module):
+ config: AlbertConfig
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
+ add_pooling_layer: bool = True
+
+ def setup(self):
+ self.embeddings = FlaxAlbertEmbeddings(self.config, dtype=self.dtype)
+ self.encoder = FlaxAlbertEncoder(self.config, dtype=self.dtype)
+ if self.add_pooling_layer:
+ self.pooler = nn.Dense(
+ self.config.hidden_size,
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
+ dtype=self.dtype,
+ name="pooler",
+ )
+ self.pooler_activation = nn.tanh
+ else:
+ self.pooler = None
+ self.pooler_activation = None
+
+ def __call__(
+ self,
+ input_ids,
+ attention_mask,
+ token_type_ids: Optional[np.ndarray] = None,
+ position_ids: Optional[np.ndarray] = None,
+ deterministic: bool = True,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ return_dict: bool = True,
+ ):
+ # make sure `token_type_ids` is correctly initialized when not passed
+ if token_type_ids is None:
+ token_type_ids = jnp.zeros_like(input_ids)
+
+ # make sure `position_ids` is correctly initialized when not passed
+ if position_ids is None:
+ position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
+
+ hidden_states = self.embeddings(input_ids, token_type_ids, position_ids, deterministic=deterministic)
+
+ outputs = self.encoder(
+ hidden_states,
+ attention_mask,
+ deterministic=deterministic,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ hidden_states = outputs[0]
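+        # When pooling is enabled, pass the first ([CLS]) token's hidden state through a tanh-activated dense layer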
+ if self.add_pooling_layer:
+ pooled = self.pooler(hidden_states[:, 0])
+ pooled = self.pooler_activation(pooled)
+ else:
+ pooled = None
+
+ if not return_dict:
+ # if pooled is None, don't return it
+ if pooled is None:
+ return (hidden_states,) + outputs[1:]
+ return (hidden_states, pooled) + outputs[1:]
+
+ return FlaxBaseModelOutputWithPooling(
+ last_hidden_state=hidden_states,
+ pooler_output=pooled,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@add_start_docstrings(
+ "The bare Albert Model transformer outputting raw hidden-states without any specific head on top.",
+ ALBERT_START_DOCSTRING,
+)
+class FlaxAlbertModel(FlaxAlbertPreTrainedModel):
+ module_class = FlaxAlbertModule
+
+
+append_call_sample_docstring(FlaxAlbertModel, _CHECKPOINT_FOR_DOC, FlaxBaseModelOutputWithPooling, _CONFIG_FOR_DOC)
+
+
+class FlaxAlbertForPreTrainingModule(nn.Module):
+ config: AlbertConfig
+ dtype: jnp.dtype = jnp.float32
+
+ def setup(self):
+ self.albert = FlaxAlbertModule(config=self.config, dtype=self.dtype)
+ self.predictions = FlaxAlbertOnlyMLMHead(config=self.config, dtype=self.dtype)
+ self.sop_classifier = FlaxAlbertSOPHead(config=self.config, dtype=self.dtype)
+
+ def __call__(
+ self,
+ input_ids,
+ attention_mask,
+ token_type_ids,
+ position_ids,
+ deterministic: bool = True,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ return_dict: bool = True,
+ ):
+ # Model
+ outputs = self.albert(
+ input_ids,
+ attention_mask,
+ token_type_ids,
+ position_ids,
+ deterministic=deterministic,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ if self.config.tie_word_embeddings:
+ shared_embedding = self.albert.variables["params"]["embeddings"]["word_embeddings"]["embedding"]
+ else:
+ shared_embedding = None
+
+ hidden_states = outputs[0]
+ pooled_output = outputs[1]
+
+ prediction_scores = self.predictions(hidden_states, shared_embedding=shared_embedding)
+ sop_scores = self.sop_classifier(pooled_output, deterministic=deterministic)
+
+ if not return_dict:
+ return (prediction_scores, sop_scores) + outputs[2:]
+
+ return FlaxAlbertForPreTrainingOutput(
+ prediction_logits=prediction_scores,
+ sop_logits=sop_scores,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@add_start_docstrings(
+ """
+ Albert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a
+ `sentence order prediction (classification)` head.
+ """,
+ ALBERT_START_DOCSTRING,
+)
+class FlaxAlbertForPreTraining(FlaxAlbertPreTrainedModel):
+ module_class = FlaxAlbertForPreTrainingModule
+
+
+FLAX_ALBERT_FOR_PRETRAINING_DOCSTRING = """
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, FlaxAlbertForPreTraining
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("albert/albert-base-v2")
+ >>> model = FlaxAlbertForPreTraining.from_pretrained("albert/albert-base-v2")
+
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="np")
+ >>> outputs = model(**inputs)
+
+ >>> prediction_logits = outputs.prediction_logits
+ >>> seq_relationship_logits = outputs.sop_logits
+ ```
+"""
+
+overwrite_call_docstring(
+ FlaxAlbertForPreTraining,
+ ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length") + FLAX_ALBERT_FOR_PRETRAINING_DOCSTRING,
+)
+append_replace_return_docstrings(
+ FlaxAlbertForPreTraining, output_type=FlaxAlbertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC
+)
+
+
+class FlaxAlbertForMaskedLMModule(nn.Module):
+ config: AlbertConfig
+ dtype: jnp.dtype = jnp.float32
+
+ def setup(self):
+ self.albert = FlaxAlbertModule(config=self.config, add_pooling_layer=False, dtype=self.dtype)
+ self.predictions = FlaxAlbertOnlyMLMHead(config=self.config, dtype=self.dtype)
+
+ def __call__(
+ self,
+ input_ids,
+ attention_mask,
+ token_type_ids,
+ position_ids,
+ deterministic: bool = True,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ return_dict: bool = True,
+ ):
+ # Model
+ outputs = self.albert(
+ input_ids,
+ attention_mask,
+ token_type_ids,
+ position_ids,
+ deterministic=deterministic,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = outputs[0]
+ if self.config.tie_word_embeddings:
+ shared_embedding = self.albert.variables["params"]["embeddings"]["word_embeddings"]["embedding"]
+ else:
+ shared_embedding = None
+
+ # Compute the prediction scores
+ logits = self.predictions(hidden_states, shared_embedding=shared_embedding)
+
+ if not return_dict:
+ return (logits,) + outputs[1:]
+
+ return FlaxMaskedLMOutput(
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@add_start_docstrings("""Albert Model with a `language modeling` head on top.""", ALBERT_START_DOCSTRING)
+class FlaxAlbertForMaskedLM(FlaxAlbertPreTrainedModel):
+ module_class = FlaxAlbertForMaskedLMModule
+
+
+append_call_sample_docstring(
+ FlaxAlbertForMaskedLM, _CHECKPOINT_FOR_DOC, FlaxMaskedLMOutput, _CONFIG_FOR_DOC, revision="refs/pr/11"
+)
+
+
+class FlaxAlbertForSequenceClassificationModule(nn.Module):
+ config: AlbertConfig
+ dtype: jnp.dtype = jnp.float32
+
+ def setup(self):
+ self.albert = FlaxAlbertModule(config=self.config, dtype=self.dtype)
+ classifier_dropout = (
+ self.config.classifier_dropout_prob
+ if self.config.classifier_dropout_prob is not None
+ else self.config.hidden_dropout_prob
+ )
+ self.dropout = nn.Dropout(rate=classifier_dropout)
+ self.classifier = nn.Dense(
+ self.config.num_labels,
+ dtype=self.dtype,
+ )
+
+ def __call__(
+ self,
+ input_ids,
+ attention_mask,
+ token_type_ids,
+ position_ids,
+ deterministic: bool = True,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ return_dict: bool = True,
+ ):
+ # Model
+ outputs = self.albert(
+ input_ids,
+ attention_mask,
+ token_type_ids,
+ position_ids,
+ deterministic=deterministic,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ pooled_output = outputs[1]
+ pooled_output = self.dropout(pooled_output, deterministic=deterministic)
+ logits = self.classifier(pooled_output)
+
+ if not return_dict:
+ return (logits,) + outputs[2:]
+
+ return FlaxSequenceClassifierOutput(
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@add_start_docstrings(
+ """
+ Albert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
+ output) e.g. for GLUE tasks.
+ """,
+ ALBERT_START_DOCSTRING,
+)
+class FlaxAlbertForSequenceClassification(FlaxAlbertPreTrainedModel):
+ module_class = FlaxAlbertForSequenceClassificationModule
+
+
+append_call_sample_docstring(
+ FlaxAlbertForSequenceClassification,
+ _CHECKPOINT_FOR_DOC,
+ FlaxSequenceClassifierOutput,
+ _CONFIG_FOR_DOC,
+)
+
+
+class FlaxAlbertForMultipleChoiceModule(nn.Module):
+ config: AlbertConfig
+ dtype: jnp.dtype = jnp.float32
+
+ def setup(self):
+ self.albert = FlaxAlbertModule(config=self.config, dtype=self.dtype)
+ self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
+ self.classifier = nn.Dense(1, dtype=self.dtype)
+
+ def __call__(
+ self,
+ input_ids,
+ attention_mask,
+ token_type_ids,
+ position_ids,
+ deterministic: bool = True,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ return_dict: bool = True,
+ ):
+ num_choices = input_ids.shape[1]
+ input_ids = input_ids.reshape(-1, input_ids.shape[-1]) if input_ids is not None else None
+ attention_mask = attention_mask.reshape(-1, attention_mask.shape[-1]) if attention_mask is not None else None
+ token_type_ids = token_type_ids.reshape(-1, token_type_ids.shape[-1]) if token_type_ids is not None else None
+ position_ids = position_ids.reshape(-1, position_ids.shape[-1]) if position_ids is not None else None
+
+ # Model
+ outputs = self.albert(
+ input_ids,
+ attention_mask,
+ token_type_ids,
+ position_ids,
+ deterministic=deterministic,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ pooled_output = outputs[1]
+ pooled_output = self.dropout(pooled_output, deterministic=deterministic)
+ logits = self.classifier(pooled_output)
+
+ reshaped_logits = logits.reshape(-1, num_choices)
+
+ if not return_dict:
+ return (reshaped_logits,) + outputs[2:]
+
+ return FlaxMultipleChoiceModelOutput(
+ logits=reshaped_logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@add_start_docstrings(
+ """
+ Albert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
+ softmax) e.g. for RocStories/SWAG tasks.
+ """,
+ ALBERT_START_DOCSTRING,
+)
+class FlaxAlbertForMultipleChoice(FlaxAlbertPreTrainedModel):
+ module_class = FlaxAlbertForMultipleChoiceModule
+
+
+overwrite_call_docstring(
+ FlaxAlbertForMultipleChoice, ALBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
+)
+append_call_sample_docstring(
+ FlaxAlbertForMultipleChoice,
+ _CHECKPOINT_FOR_DOC,
+ FlaxMultipleChoiceModelOutput,
+ _CONFIG_FOR_DOC,
+)
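+
+# A minimal sketch (not part of this module) of preparing multiple-choice inputs, assuming `tokenizer` and `model`
+# were loaded from "albert/albert-base-v2" as in the sample docstrings: each choice is paired with the same prompt
+# and the encodings are stacked into a `(batch_size, num_choices, seq_len)` batch, e.g.
+#
+#     prompt = "France has a bread law."
+#     choices = ["The law is about baguettes.", "The law is about cheese."]
+#     encoding = tokenizer([prompt] * len(choices), choices, return_tensors="np", padding=True)
+#     outputs = model(**{k: v[None, :] for k, v in encoding.items()})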
+
+
+class FlaxAlbertForTokenClassificationModule(nn.Module):
+ config: AlbertConfig
+ dtype: jnp.dtype = jnp.float32
+
+ def setup(self):
+ self.albert = FlaxAlbertModule(config=self.config, dtype=self.dtype, add_pooling_layer=False)
+ classifier_dropout = (
+ self.config.classifier_dropout_prob
+ if self.config.classifier_dropout_prob is not None
+ else self.config.hidden_dropout_prob
+ )
+ self.dropout = nn.Dropout(rate=classifier_dropout)
+ self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype)
+
+ def __call__(
+ self,
+ input_ids,
+ attention_mask,
+ token_type_ids,
+ position_ids,
+ deterministic: bool = True,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ return_dict: bool = True,
+ ):
+ # Model
+ outputs = self.albert(
+ input_ids,
+ attention_mask,
+ token_type_ids,
+ position_ids,
+ deterministic=deterministic,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = outputs[0]
+ hidden_states = self.dropout(hidden_states, deterministic=deterministic)
+ logits = self.classifier(hidden_states)
+
+ if not return_dict:
+ return (logits,) + outputs[1:]
+
+ return FlaxTokenClassifierOutput(
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@add_start_docstrings(
+ """
+ Albert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
+ Named-Entity-Recognition (NER) tasks.
+ """,
+ ALBERT_START_DOCSTRING,
+)
+class FlaxAlbertForTokenClassification(FlaxAlbertPreTrainedModel):
+ module_class = FlaxAlbertForTokenClassificationModule
+
+
+append_call_sample_docstring(
+ FlaxAlbertForTokenClassification,
+ _CHECKPOINT_FOR_DOC,
+ FlaxTokenClassifierOutput,
+ _CONFIG_FOR_DOC,
+)
+
+
+class FlaxAlbertForQuestionAnsweringModule(nn.Module):
+ config: AlbertConfig
+ dtype: jnp.dtype = jnp.float32
+
+ def setup(self):
+ self.albert = FlaxAlbertModule(config=self.config, dtype=self.dtype, add_pooling_layer=False)
+ self.qa_outputs = nn.Dense(self.config.num_labels, dtype=self.dtype)
+
+ def __call__(
+ self,
+ input_ids,
+ attention_mask,
+ token_type_ids,
+ position_ids,
+ deterministic: bool = True,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ return_dict: bool = True,
+ ):
+ # Model
+ outputs = self.albert(
+ input_ids,
+ attention_mask,
+ token_type_ids,
+ position_ids,
+ deterministic=deterministic,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = outputs[0]
+
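+        # `qa_outputs` maps each token to `config.num_labels` scores (2 for extractive QA), split below into
+        # span-start and span-end logits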
+ logits = self.qa_outputs(hidden_states)
+ start_logits, end_logits = jnp.split(logits, self.config.num_labels, axis=-1)
+ start_logits = start_logits.squeeze(-1)
+ end_logits = end_logits.squeeze(-1)
+
+ if not return_dict:
+ return (start_logits, end_logits) + outputs[1:]
+
+ return FlaxQuestionAnsweringModelOutput(
+ start_logits=start_logits,
+ end_logits=end_logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@add_start_docstrings(
+ """
+    Albert Model with a span classification head on top for extractive question-answering tasks like SQuAD (linear
+    layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
+ """,
+ ALBERT_START_DOCSTRING,
+)
+class FlaxAlbertForQuestionAnswering(FlaxAlbertPreTrainedModel):
+ module_class = FlaxAlbertForQuestionAnsweringModule
+
+
+append_call_sample_docstring(
+ FlaxAlbertForQuestionAnswering,
+ _CHECKPOINT_FOR_DOC,
+ FlaxQuestionAnsweringModelOutput,
+ _CONFIG_FOR_DOC,
+)
+
+__all__ = [
+ "FlaxAlbertPreTrainedModel",
+ "FlaxAlbertModel",
+ "FlaxAlbertForPreTraining",
+ "FlaxAlbertForMaskedLM",
+ "FlaxAlbertForSequenceClassification",
+ "FlaxAlbertForMultipleChoice",
+ "FlaxAlbertForTokenClassification",
+ "FlaxAlbertForQuestionAnswering",
+]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/albert/modeling_tf_albert.py b/venv/lib/python3.13/site-packages/transformers/models/albert/modeling_tf_albert.py
new file mode 100644
index 0000000000000000000000000000000000000000..101ab63dc0545992fe68a53205f4ad81c607d9ca
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/albert/modeling_tf_albert.py
@@ -0,0 +1,1572 @@
+# coding=utf-8
+# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""TF 2.0 ALBERT model."""
+
+from __future__ import annotations
+
+import math
+from dataclasses import dataclass
+
+import numpy as np
+import tensorflow as tf
+
+from ...activations_tf import get_tf_activation
+from ...modeling_tf_outputs import (
+ TFBaseModelOutput,
+ TFBaseModelOutputWithPooling,
+ TFMaskedLMOutput,
+ TFMultipleChoiceModelOutput,
+ TFQuestionAnsweringModelOutput,
+ TFSequenceClassifierOutput,
+ TFTokenClassifierOutput,
+)
+from ...modeling_tf_utils import (
+ TFMaskedLanguageModelingLoss,
+ TFModelInputType,
+ TFMultipleChoiceLoss,
+ TFPreTrainedModel,
+ TFQuestionAnsweringLoss,
+ TFSequenceClassificationLoss,
+ TFTokenClassificationLoss,
+ get_initializer,
+ keras,
+ keras_serializable,
+ unpack_inputs,
+)
+from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax
+from ...utils import (
+ ModelOutput,
+ add_code_sample_docstrings,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+ replace_return_docstrings,
+)
+from .configuration_albert import AlbertConfig
+
+
+logger = logging.get_logger(__name__)
+
+_CHECKPOINT_FOR_DOC = "albert/albert-base-v2"
+_CONFIG_FOR_DOC = "AlbertConfig"
+
+
+class TFAlbertPreTrainingLoss:
+ """
+    Loss function suitable for ALBERT pretraining, that is, the task of pretraining a language model by combining
+    SOP + MLM.
+
+    .. note:: Any label of -100 will be ignored (along with the corresponding logits) in the loss computation.
+ """
+
+ def hf_compute_loss(self, labels: tf.Tensor, logits: tf.Tensor) -> tf.Tensor:
+ loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction=keras.losses.Reduction.NONE)
+ if self.config.tf_legacy_loss:
+            # make sure only labels that are not equal to -100
+            # are taken into account for the loss computation
+ masked_lm_active_loss = tf.not_equal(tf.reshape(tensor=labels["labels"], shape=(-1,)), -100)
+ masked_lm_reduced_logits = tf.boolean_mask(
+ tensor=tf.reshape(tensor=logits[0], shape=(-1, shape_list(logits[0])[2])),
+ mask=masked_lm_active_loss,
+ )
+ masked_lm_labels = tf.boolean_mask(
+ tensor=tf.reshape(tensor=labels["labels"], shape=(-1,)), mask=masked_lm_active_loss
+ )
+ sentence_order_active_loss = tf.not_equal(
+ tf.reshape(tensor=labels["sentence_order_label"], shape=(-1,)), -100
+ )
+ sentence_order_reduced_logits = tf.boolean_mask(
+ tensor=tf.reshape(tensor=logits[1], shape=(-1, 2)), mask=sentence_order_active_loss
+ )
+ sentence_order_label = tf.boolean_mask(
+ tensor=tf.reshape(tensor=labels["sentence_order_label"], shape=(-1,)), mask=sentence_order_active_loss
+ )
+ masked_lm_loss = loss_fn(y_true=masked_lm_labels, y_pred=masked_lm_reduced_logits)
+ sentence_order_loss = loss_fn(y_true=sentence_order_label, y_pred=sentence_order_reduced_logits)
+ masked_lm_loss = tf.reshape(tensor=masked_lm_loss, shape=(-1, shape_list(sentence_order_loss)[0]))
+ masked_lm_loss = tf.reduce_mean(input_tensor=masked_lm_loss, axis=0)
+
+ return masked_lm_loss + sentence_order_loss
+
+ # Clip negative labels to zero here to avoid NaNs and errors - those positions will get masked later anyway
+ unmasked_lm_losses = loss_fn(y_true=tf.nn.relu(labels["labels"]), y_pred=logits[0])
+ # make sure only labels that are not equal to -100
+ # are taken into account for the loss computation
+ lm_loss_mask = tf.cast(labels["labels"] != -100, dtype=unmasked_lm_losses.dtype)
+ masked_lm_losses = unmasked_lm_losses * lm_loss_mask
+ reduced_masked_lm_loss = tf.reduce_sum(masked_lm_losses) / tf.reduce_sum(lm_loss_mask)
+
+ sop_logits = tf.reshape(logits[1], (-1, 2))
+ # Clip negative labels to zero here to avoid NaNs and errors - those positions will get masked later anyway
+ unmasked_sop_loss = loss_fn(y_true=tf.nn.relu(labels["sentence_order_label"]), y_pred=sop_logits)
+ sop_loss_mask = tf.cast(labels["sentence_order_label"] != -100, dtype=unmasked_sop_loss.dtype)
+
+ masked_sop_loss = unmasked_sop_loss * sop_loss_mask
+ reduced_masked_sop_loss = tf.reduce_sum(masked_sop_loss) / tf.reduce_sum(sop_loss_mask)
+
+ return tf.reshape(reduced_masked_lm_loss + reduced_masked_sop_loss, (1,))
+
+
+class TFAlbertEmbeddings(keras.layers.Layer):
+ """Construct the embeddings from word, position and token_type embeddings."""
+
+ def __init__(self, config: AlbertConfig, **kwargs):
+ super().__init__(**kwargs)
+
+ self.config = config
+ self.embedding_size = config.embedding_size
+ self.max_position_embeddings = config.max_position_embeddings
+ self.initializer_range = config.initializer_range
+ self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
+ self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
+
+ def build(self, input_shape=None):
+ with tf.name_scope("word_embeddings"):
+ self.weight = self.add_weight(
+ name="weight",
+ shape=[self.config.vocab_size, self.embedding_size],
+ initializer=get_initializer(self.initializer_range),
+ )
+
+ with tf.name_scope("token_type_embeddings"):
+ self.token_type_embeddings = self.add_weight(
+ name="embeddings",
+ shape=[self.config.type_vocab_size, self.embedding_size],
+ initializer=get_initializer(self.initializer_range),
+ )
+
+ with tf.name_scope("position_embeddings"):
+ self.position_embeddings = self.add_weight(
+ name="embeddings",
+ shape=[self.max_position_embeddings, self.embedding_size],
+ initializer=get_initializer(self.initializer_range),
+ )
+
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "LayerNorm", None) is not None:
+ with tf.name_scope(self.LayerNorm.name):
+ self.LayerNorm.build([None, None, self.config.embedding_size])
+
+ # Copied from transformers.models.bert.modeling_tf_bert.TFBertEmbeddings.call
+ def call(
+ self,
+ input_ids: tf.Tensor | None = None,
+ position_ids: tf.Tensor | None = None,
+ token_type_ids: tf.Tensor | None = None,
+ inputs_embeds: tf.Tensor | None = None,
+ past_key_values_length=0,
+ training: bool = False,
+ ) -> tf.Tensor:
+ """
+ Applies embedding based on inputs tensor.
+
+ Returns:
+ final_embeddings (`tf.Tensor`): output embedding tensor.
+ """
+ if input_ids is None and inputs_embeds is None:
+            raise ValueError("Need to provide either `input_ids` or `inputs_embeds`.")
+
+ if input_ids is not None:
+ check_embeddings_within_bounds(input_ids, self.config.vocab_size)
+ inputs_embeds = tf.gather(params=self.weight, indices=input_ids)
+
+ input_shape = shape_list(inputs_embeds)[:-1]
+
+ if token_type_ids is None:
+ token_type_ids = tf.fill(dims=input_shape, value=0)
+
+ if position_ids is None:
+ position_ids = tf.expand_dims(
+ tf.range(start=past_key_values_length, limit=input_shape[1] + past_key_values_length), axis=0
+ )
+
+ position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids)
+ token_type_embeds = tf.gather(params=self.token_type_embeddings, indices=token_type_ids)
+ final_embeddings = inputs_embeds + position_embeds + token_type_embeds
+ final_embeddings = self.LayerNorm(inputs=final_embeddings)
+ final_embeddings = self.dropout(inputs=final_embeddings, training=training)
+
+ return final_embeddings
+
+
+class TFAlbertAttention(keras.layers.Layer):
+ """Contains the complete attention sublayer, including both dropouts and layer norm."""
+
+ def __init__(self, config: AlbertConfig, **kwargs):
+ super().__init__(**kwargs)
+
+ if config.hidden_size % config.num_attention_heads != 0:
+ raise ValueError(
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number "
+ f"of attention heads ({config.num_attention_heads})"
+ )
+
+ self.num_attention_heads = config.num_attention_heads
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
+ self.sqrt_att_head_size = math.sqrt(self.attention_head_size)
+ self.output_attentions = config.output_attentions
+
+ self.query = keras.layers.Dense(
+ units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query"
+ )
+ self.key = keras.layers.Dense(
+ units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key"
+ )
+ self.value = keras.layers.Dense(
+ units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value"
+ )
+ self.dense = keras.layers.Dense(
+ units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
+ )
+ self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
+ # Two different dropout probabilities; see https://github.com/google-research/albert/blob/master/modeling.py#L971-L993
+ self.attention_dropout = keras.layers.Dropout(rate=config.attention_probs_dropout_prob)
+ self.output_dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
+ self.config = config
+
+ def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor:
+ # Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size]
+ tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size))
+
+ # Transpose the tensor from [batch_size, seq_length, num_attention_heads, attention_head_size] to [batch_size, num_attention_heads, seq_length, attention_head_size]
+ return tf.transpose(tensor, perm=[0, 2, 1, 3])
+
+ def call(
+ self,
+ input_tensor: tf.Tensor,
+ attention_mask: tf.Tensor,
+ head_mask: tf.Tensor,
+ output_attentions: bool,
+ training: bool = False,
+ ) -> tuple[tf.Tensor]:
+ batch_size = shape_list(input_tensor)[0]
+ mixed_query_layer = self.query(inputs=input_tensor)
+ mixed_key_layer = self.key(inputs=input_tensor)
+ mixed_value_layer = self.value(inputs=input_tensor)
+ query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)
+ key_layer = self.transpose_for_scores(mixed_key_layer, batch_size)
+ value_layer = self.transpose_for_scores(mixed_value_layer, batch_size)
+
+ # Take the dot product between "query" and "key" to get the raw attention scores.
+ # (batch size, num_heads, seq_len_q, seq_len_k)
+ attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
+ dk = tf.cast(self.sqrt_att_head_size, dtype=attention_scores.dtype)
+ attention_scores = tf.divide(attention_scores, dk)
+
+ if attention_mask is not None:
+            # Apply the attention mask (precomputed for all layers in the TFAlbertModel call() function)
+ attention_scores = tf.add(attention_scores, attention_mask)
+
+ # Normalize the attention scores to probabilities.
+ attention_probs = stable_softmax(logits=attention_scores, axis=-1)
+
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ attention_probs = self.attention_dropout(inputs=attention_probs, training=training)
+
+ # Mask heads if we want to
+ if head_mask is not None:
+ attention_probs = tf.multiply(attention_probs, head_mask)
+
+ context_layer = tf.matmul(attention_probs, value_layer)
+ context_layer = tf.transpose(context_layer, perm=[0, 2, 1, 3])
+
+ # (batch_size, seq_len_q, all_head_size)
+ context_layer = tf.reshape(tensor=context_layer, shape=(batch_size, -1, self.all_head_size))
+ self_outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
+ hidden_states = self_outputs[0]
+ hidden_states = self.dense(inputs=hidden_states)
+ hidden_states = self.output_dropout(inputs=hidden_states, training=training)
+ attention_output = self.LayerNorm(inputs=hidden_states + input_tensor)
+
+ # add attentions if we output them
+ outputs = (attention_output,) + self_outputs[1:]
+
+ return outputs
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "query", None) is not None:
+ with tf.name_scope(self.query.name):
+ self.query.build([None, None, self.config.hidden_size])
+ if getattr(self, "key", None) is not None:
+ with tf.name_scope(self.key.name):
+ self.key.build([None, None, self.config.hidden_size])
+ if getattr(self, "value", None) is not None:
+ with tf.name_scope(self.value.name):
+ self.value.build([None, None, self.config.hidden_size])
+ if getattr(self, "dense", None) is not None:
+ with tf.name_scope(self.dense.name):
+ self.dense.build([None, None, self.config.hidden_size])
+ if getattr(self, "LayerNorm", None) is not None:
+ with tf.name_scope(self.LayerNorm.name):
+ self.LayerNorm.build([None, None, self.config.hidden_size])
+
+
+class TFAlbertLayer(keras.layers.Layer):
+ def __init__(self, config: AlbertConfig, **kwargs):
+ super().__init__(**kwargs)
+
+ self.attention = TFAlbertAttention(config, name="attention")
+ self.ffn = keras.layers.Dense(
+ units=config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="ffn"
+ )
+
+ if isinstance(config.hidden_act, str):
+ self.activation = get_tf_activation(config.hidden_act)
+ else:
+ self.activation = config.hidden_act
+
+ self.ffn_output = keras.layers.Dense(
+ units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="ffn_output"
+ )
+ self.full_layer_layer_norm = keras.layers.LayerNormalization(
+ epsilon=config.layer_norm_eps, name="full_layer_layer_norm"
+ )
+ self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
+ self.config = config
+
+ def call(
+ self,
+ hidden_states: tf.Tensor,
+ attention_mask: tf.Tensor,
+ head_mask: tf.Tensor,
+ output_attentions: bool,
+ training: bool = False,
+ ) -> tuple[tf.Tensor]:
+ attention_outputs = self.attention(
+ input_tensor=hidden_states,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ training=training,
+ )
+ ffn_output = self.ffn(inputs=attention_outputs[0])
+ ffn_output = self.activation(ffn_output)
+ ffn_output = self.ffn_output(inputs=ffn_output)
+ ffn_output = self.dropout(inputs=ffn_output, training=training)
+ hidden_states = self.full_layer_layer_norm(inputs=ffn_output + attention_outputs[0])
+
+ # add attentions if we output them
+ outputs = (hidden_states,) + attention_outputs[1:]
+
+ return outputs
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "attention", None) is not None:
+ with tf.name_scope(self.attention.name):
+ self.attention.build(None)
+ if getattr(self, "ffn", None) is not None:
+ with tf.name_scope(self.ffn.name):
+ self.ffn.build([None, None, self.config.hidden_size])
+ if getattr(self, "ffn_output", None) is not None:
+ with tf.name_scope(self.ffn_output.name):
+ self.ffn_output.build([None, None, self.config.intermediate_size])
+ if getattr(self, "full_layer_layer_norm", None) is not None:
+ with tf.name_scope(self.full_layer_layer_norm.name):
+ self.full_layer_layer_norm.build([None, None, self.config.hidden_size])
+
+
+class TFAlbertLayerGroup(keras.layers.Layer):
+ def __init__(self, config: AlbertConfig, **kwargs):
+ super().__init__(**kwargs)
+
+ self.albert_layers = [
+ TFAlbertLayer(config, name=f"albert_layers_._{i}") for i in range(config.inner_group_num)
+ ]
+
+ def call(
+ self,
+ hidden_states: tf.Tensor,
+ attention_mask: tf.Tensor,
+ head_mask: tf.Tensor,
+ output_attentions: bool,
+ output_hidden_states: bool,
+ training: bool = False,
+ ) -> TFBaseModelOutput | tuple[tf.Tensor]:
+ layer_hidden_states = () if output_hidden_states else None
+ layer_attentions = () if output_attentions else None
+
+ for layer_index, albert_layer in enumerate(self.albert_layers):
+ if output_hidden_states:
+ layer_hidden_states = layer_hidden_states + (hidden_states,)
+
+ layer_output = albert_layer(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ head_mask=head_mask[layer_index],
+ output_attentions=output_attentions,
+ training=training,
+ )
+ hidden_states = layer_output[0]
+
+ if output_attentions:
+ layer_attentions = layer_attentions + (layer_output[1],)
+
+ # Add last layer
+ if output_hidden_states:
+ layer_hidden_states = layer_hidden_states + (hidden_states,)
+
+ return tuple(v for v in [hidden_states, layer_hidden_states, layer_attentions] if v is not None)
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "albert_layers", None) is not None:
+ for layer in self.albert_layers:
+ with tf.name_scope(layer.name):
+ layer.build(None)
+
+
+class TFAlbertTransformer(keras.layers.Layer):
+ def __init__(self, config: AlbertConfig, **kwargs):
+ super().__init__(**kwargs)
+
+ self.num_hidden_layers = config.num_hidden_layers
+ self.num_hidden_groups = config.num_hidden_groups
+ # Number of layers in a hidden group
+ self.layers_per_group = int(config.num_hidden_layers / config.num_hidden_groups)
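+        # e.g. with the default config (num_hidden_layers=12, num_hidden_groups=1) there is a single group whose
+        # parameters are reused by all 12 layers -- this is ALBERT's cross-layer parameter sharing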
+ self.embedding_hidden_mapping_in = keras.layers.Dense(
+ units=config.hidden_size,
+ kernel_initializer=get_initializer(config.initializer_range),
+ name="embedding_hidden_mapping_in",
+ )
+ self.albert_layer_groups = [
+ TFAlbertLayerGroup(config, name=f"albert_layer_groups_._{i}") for i in range(config.num_hidden_groups)
+ ]
+ self.config = config
+
+ def call(
+ self,
+ hidden_states: tf.Tensor,
+ attention_mask: tf.Tensor,
+ head_mask: tf.Tensor,
+ output_attentions: bool,
+ output_hidden_states: bool,
+ return_dict: bool,
+ training: bool = False,
+ ) -> TFBaseModelOutput | tuple[tf.Tensor]:
+ hidden_states = self.embedding_hidden_mapping_in(inputs=hidden_states)
+ all_attentions = () if output_attentions else None
+ all_hidden_states = (hidden_states,) if output_hidden_states else None
+
+ for i in range(self.num_hidden_layers):
+ # Index of the hidden group
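+            # e.g. with 12 layers and 3 groups, layers 0-3 map to group 0, layers 4-7 to group 1, layers 8-11 to group 2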
+ group_idx = int(i / (self.num_hidden_layers / self.num_hidden_groups))
+ layer_group_output = self.albert_layer_groups[group_idx](
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ head_mask=head_mask[group_idx * self.layers_per_group : (group_idx + 1) * self.layers_per_group],
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ training=training,
+ )
+ hidden_states = layer_group_output[0]
+
+ if output_attentions:
+ all_attentions = all_attentions + layer_group_output[-1]
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
+
+ return TFBaseModelOutput(
+ last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "embedding_hidden_mapping_in", None) is not None:
+ with tf.name_scope(self.embedding_hidden_mapping_in.name):
+ self.embedding_hidden_mapping_in.build([None, None, self.config.embedding_size])
+ if getattr(self, "albert_layer_groups", None) is not None:
+ for layer in self.albert_layer_groups:
+ with tf.name_scope(layer.name):
+ layer.build(None)
+
+
+class TFAlbertPreTrainedModel(TFPreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = AlbertConfig
+ base_model_prefix = "albert"
+
+
+class TFAlbertMLMHead(keras.layers.Layer):
+ def __init__(self, config: AlbertConfig, input_embeddings: keras.layers.Layer, **kwargs):
+ super().__init__(**kwargs)
+
+ self.config = config
+ self.embedding_size = config.embedding_size
+ self.dense = keras.layers.Dense(
+ config.embedding_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
+ )
+ if isinstance(config.hidden_act, str):
+ self.activation = get_tf_activation(config.hidden_act)
+ else:
+ self.activation = config.hidden_act
+
+ self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
+
+ # The output weights are the same as the input embeddings, but there is
+ # an output-only bias for each token.
+ self.decoder = input_embeddings
+
+ def build(self, input_shape=None):
+ self.bias = self.add_weight(shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="bias")
+ self.decoder_bias = self.add_weight(
+ shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="decoder/bias"
+ )
+
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "dense", None) is not None:
+ with tf.name_scope(self.dense.name):
+ self.dense.build([None, None, self.config.hidden_size])
+ if getattr(self, "LayerNorm", None) is not None:
+ with tf.name_scope(self.LayerNorm.name):
+ self.LayerNorm.build([None, None, self.config.embedding_size])
+
+ def get_output_embeddings(self) -> keras.layers.Layer:
+ return self.decoder
+
+ def set_output_embeddings(self, value: tf.Variable):
+ self.decoder.weight = value
+ self.decoder.vocab_size = shape_list(value)[0]
+
+ def get_bias(self) -> dict[str, tf.Variable]:
+ return {"bias": self.bias, "decoder_bias": self.decoder_bias}
+
+ def set_bias(self, value: tf.Variable):
+ self.bias = value["bias"]
+ self.decoder_bias = value["decoder_bias"]
+ self.config.vocab_size = shape_list(value["bias"])[0]
+
+ def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
+ hidden_states = self.dense(inputs=hidden_states)
+ hidden_states = self.activation(hidden_states)
+ hidden_states = self.LayerNorm(inputs=hidden_states)
+ seq_length = shape_list(tensor=hidden_states)[1]
+ hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, self.embedding_size])
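+        # project back onto the vocabulary: `decoder.weight` is the shared [vocab_size, embedding_size] input
+        # embedding matrix (weight tying), so the matmul against its transpose yields vocabulary logits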
+ hidden_states = tf.matmul(a=hidden_states, b=self.decoder.weight, transpose_b=True)
+ hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, seq_length, self.config.vocab_size])
+ hidden_states = tf.nn.bias_add(value=hidden_states, bias=self.decoder_bias)
+
+ return hidden_states
+
+
+@keras_serializable
+class TFAlbertMainLayer(keras.layers.Layer):
+ config_class = AlbertConfig
+
+ def __init__(self, config: AlbertConfig, add_pooling_layer: bool = True, **kwargs):
+ super().__init__(**kwargs)
+
+ self.config = config
+
+ self.embeddings = TFAlbertEmbeddings(config, name="embeddings")
+ self.encoder = TFAlbertTransformer(config, name="encoder")
+ self.pooler = (
+ keras.layers.Dense(
+ units=config.hidden_size,
+ kernel_initializer=get_initializer(config.initializer_range),
+ activation="tanh",
+ name="pooler",
+ )
+ if add_pooling_layer
+ else None
+ )
+
+ def get_input_embeddings(self) -> keras.layers.Layer:
+ return self.embeddings
+
+ def set_input_embeddings(self, value: tf.Variable):
+ self.embeddings.weight = value
+ self.embeddings.vocab_size = shape_list(value)[0]
+
+ def _prune_heads(self, heads_to_prune):
+ """
+        Prunes heads of the model. `heads_to_prune` is a dict of {layer_num: list of heads to prune in this layer}.
+        See the base class `PreTrainedModel`.
+ """
+ raise NotImplementedError
+
+ @unpack_inputs
+ def call(
+ self,
+ input_ids: TFModelInputType | None = None,
+ attention_mask: np.ndarray | tf.Tensor | None = None,
+ token_type_ids: np.ndarray | tf.Tensor | None = None,
+ position_ids: np.ndarray | tf.Tensor | None = None,
+ head_mask: np.ndarray | tf.Tensor | None = None,
+ inputs_embeds: np.ndarray | tf.Tensor | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ training: bool = False,
+ ) -> TFBaseModelOutputWithPooling | tuple[tf.Tensor]:
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ input_shape = shape_list(input_ids)
+ elif inputs_embeds is not None:
+ input_shape = shape_list(inputs_embeds)[:-1]
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ if attention_mask is None:
+ attention_mask = tf.fill(dims=input_shape, value=1)
+
+ if token_type_ids is None:
+ token_type_ids = tf.fill(dims=input_shape, value=0)
+
+ embedding_output = self.embeddings(
+ input_ids=input_ids,
+ position_ids=position_ids,
+ token_type_ids=token_type_ids,
+ inputs_embeds=inputs_embeds,
+ training=training,
+ )
+
+ # We create a 3D attention mask from a 2D tensor mask.
+ # Sizes are [batch_size, 1, 1, to_seq_length]
+ # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
+ # this attention mask is more simple than the triangular masking of causal attention
+ # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
+ extended_attention_mask = tf.reshape(attention_mask, (input_shape[0], 1, 1, input_shape[1]))
+
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
+ # masked positions, this operation will create a tensor which is 0.0 for
+ # positions we want to attend and -10000.0 for masked positions.
+ # Since we are adding it to the raw scores before the softmax, this is
+ # effectively the same as removing these entirely.
+ extended_attention_mask = tf.cast(extended_attention_mask, dtype=embedding_output.dtype)
+ one_cst = tf.constant(1.0, dtype=embedding_output.dtype)
+ ten_thousand_cst = tf.constant(-10000.0, dtype=embedding_output.dtype)
+ extended_attention_mask = tf.multiply(tf.subtract(one_cst, extended_attention_mask), ten_thousand_cst)
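+        # e.g. a padding mask of [1, 1, 0] becomes [0.0, 0.0, -10000.0], effectively removing the padded
+        # position from the softmax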
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+ if head_mask is not None:
+ raise NotImplementedError
+ else:
+ head_mask = [None] * self.config.num_hidden_layers
+
+ encoder_outputs = self.encoder(
+ hidden_states=embedding_output,
+ attention_mask=extended_attention_mask,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+
+ sequence_output = encoder_outputs[0]
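+        # the pooled output is the hidden state of the first ([CLS]) token passed through a tanh-activated Dense layer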
+ pooled_output = self.pooler(inputs=sequence_output[:, 0]) if self.pooler is not None else None
+
+ if not return_dict:
+ return (
+ sequence_output,
+ pooled_output,
+ ) + encoder_outputs[1:]
+
+ return TFBaseModelOutputWithPooling(
+ last_hidden_state=sequence_output,
+ pooler_output=pooled_output,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "embeddings", None) is not None:
+ with tf.name_scope(self.embeddings.name):
+ self.embeddings.build(None)
+ if getattr(self, "encoder", None) is not None:
+ with tf.name_scope(self.encoder.name):
+ self.encoder.build(None)
+ if getattr(self, "pooler", None) is not None:
+ with tf.name_scope(self.pooler.name):
+ self.pooler.build([None, None, self.config.hidden_size])
+
+
+@dataclass
+class TFAlbertForPreTrainingOutput(ModelOutput):
+ """
+ Output type of [`TFAlbertForPreTraining`].
+
+    Args:
+        loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when both `labels` and `sentence_order_label` are provided):
+            Total loss as the sum of the masked language modeling loss and the sentence order prediction
+            (classification) loss.
+        prediction_logits (`tf.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+        sop_logits (`tf.Tensor` of shape `(batch_size, 2)`):
+            Prediction scores of the sentence order prediction (classification) head (scores for original vs. swapped
+            segment order before SoftMax).
+ hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
+ `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ loss: tf.Tensor | None = None
+ prediction_logits: tf.Tensor | None = None
+ sop_logits: tf.Tensor | None = None
+ hidden_states: tuple[tf.Tensor] | None = None
+ attentions: tuple[tf.Tensor] | None = None
+
+
+ALBERT_START_DOCSTRING = r"""
+
+ This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads,
+    etc.)
+
+ This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
+    as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matters related to general usage and
+ behavior.
+
+
+
+ TensorFlow models and layers in `transformers` accept two formats as input:
+
+ - having all inputs as keyword arguments (like PyTorch models), or
+ - having all inputs as a list, tuple or dict in the first positional argument.
+
+ The reason the second format is supported is that Keras methods prefer this format when passing inputs to models
+ and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just
+ pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second
+ format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with
+ the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first
+ positional argument:
+
+ - a single Tensor with `input_ids` only and nothing else: `model(input_ids)`
+ - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
+ `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])`
+ - a dictionary with one or several input Tensors associated to the input names given in the docstring:
+ `model({"input_ids": input_ids, "token_type_ids": token_type_ids})`
+
+ Note that when creating models and layers with
+ [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry
+ about any of this, as you can just pass inputs like you would to any other Python function!
+
+
+
+ Args:
+ config ([`AlbertConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+ALBERT_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`Numpy array` or `tf.Tensor` of shape `({0})`):
+ Indices of input sequence tokens in the vocabulary.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and
+ [`PreTrainedTokenizer.encode`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`Numpy array` or `tf.Tensor` of shape `({0})`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ token_type_ids (`Numpy array` or `tf.Tensor` of shape `({0})`, *optional*):
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
+ 1]`:
+
+ - 0 corresponds to a *sentence A* token,
+ - 1 corresponds to a *sentence B* token.
+
+ [What are token type IDs?](../glossary#token-type-ids)
+ position_ids (`Numpy array` or `tf.Tensor` of shape `({0})`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+ config.max_position_embeddings - 1]`.
+
+ [What are position IDs?](../glossary#position-ids)
+ head_mask (`Numpy array` or `tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ inputs_embeds (`tf.Tensor` of shape `({0}, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+ model's internal embedding lookup matrix.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail. This argument can be used only in eager mode; in graph mode the value in the
+ config will be used instead.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail. This argument can be used only in eager mode; in graph mode the value in the config will be
+ used instead.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in
+            eager mode; in graph mode the value will always be set to `True`.
+ training (`bool`, *optional*, defaults to `False`):
+ Whether or not to use the model in training mode (some modules like dropout modules have different
+ behaviors between training and evaluation).
+"""
+
+
+@add_start_docstrings(
+ "The bare Albert Model transformer outputting raw hidden-states without any specific head on top.",
+ ALBERT_START_DOCSTRING,
+)
+class TFAlbertModel(TFAlbertPreTrainedModel):
+ def __init__(self, config: AlbertConfig, *inputs, **kwargs):
+ super().__init__(config, *inputs, **kwargs)
+
+ self.albert = TFAlbertMainLayer(config, name="albert")
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=TFBaseModelOutputWithPooling,
+ config_class=_CONFIG_FOR_DOC,
+ )
+ def call(
+ self,
+ input_ids: TFModelInputType | None = None,
+ attention_mask: np.ndarray | tf.Tensor | None = None,
+ token_type_ids: np.ndarray | tf.Tensor | None = None,
+ position_ids: np.ndarray | tf.Tensor | None = None,
+ head_mask: np.ndarray | tf.Tensor | None = None,
+ inputs_embeds: np.ndarray | tf.Tensor | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ training: bool | None = False,
+ ) -> TFBaseModelOutputWithPooling | tuple[tf.Tensor]:
+ outputs = self.albert(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+
+ return outputs
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "albert", None) is not None:
+ with tf.name_scope(self.albert.name):
+ self.albert.build(None)
+
+
+@add_start_docstrings(
+ """
+ Albert Model with two heads on top for pretraining: a `masked language modeling` head and a `sentence order
+ prediction` (classification) head.
+ """,
+ ALBERT_START_DOCSTRING,
+)
+class TFAlbertForPreTraining(TFAlbertPreTrainedModel, TFAlbertPreTrainingLoss):
+    # names with a '.' represent the authorized unexpected/missing layers when a TF model is loaded from a PT model
+ _keys_to_ignore_on_load_unexpected = [r"predictions.decoder.weight"]
+
+ def __init__(self, config: AlbertConfig, *inputs, **kwargs):
+ super().__init__(config, *inputs, **kwargs)
+
+ self.num_labels = config.num_labels
+
+ self.albert = TFAlbertMainLayer(config, name="albert")
+ self.predictions = TFAlbertMLMHead(config, input_embeddings=self.albert.embeddings, name="predictions")
+ self.sop_classifier = TFAlbertSOPHead(config, name="sop_classifier")
+
+ def get_lm_head(self) -> keras.layers.Layer:
+ return self.predictions
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @replace_return_docstrings(output_type=TFAlbertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
+ def call(
+ self,
+ input_ids: TFModelInputType | None = None,
+ attention_mask: np.ndarray | tf.Tensor | None = None,
+ token_type_ids: np.ndarray | tf.Tensor | None = None,
+ position_ids: np.ndarray | tf.Tensor | None = None,
+ head_mask: np.ndarray | tf.Tensor | None = None,
+ inputs_embeds: np.ndarray | tf.Tensor | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ labels: np.ndarray | tf.Tensor | None = None,
+ sentence_order_label: np.ndarray | tf.Tensor | None = None,
+ training: bool | None = False,
+ ) -> TFAlbertForPreTrainingOutput | tuple[tf.Tensor]:
+ r"""
+ Return:
+
+ Example:
+
+ ```python
+ >>> import tensorflow as tf
+ >>> from transformers import AutoTokenizer, TFAlbertForPreTraining
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("albert/albert-base-v2")
+ >>> model = TFAlbertForPreTraining.from_pretrained("albert/albert-base-v2")
+
+ >>> input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[None, :]
+ >>> # Batch size 1
+ >>> outputs = model(input_ids)
+
+ >>> prediction_logits = outputs.prediction_logits
+ >>> sop_logits = outputs.sop_logits
+ ```"""
+
+ outputs = self.albert(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+ sequence_output, pooled_output = outputs[:2]
+ prediction_scores = self.predictions(hidden_states=sequence_output)
+ sop_scores = self.sop_classifier(pooled_output=pooled_output, training=training)
+ total_loss = None
+
+ if labels is not None and sentence_order_label is not None:
+ d_labels = {"labels": labels}
+ d_labels["sentence_order_label"] = sentence_order_label
+ total_loss = self.hf_compute_loss(labels=d_labels, logits=(prediction_scores, sop_scores))
+
+ if not return_dict:
+ output = (prediction_scores, sop_scores) + outputs[2:]
+ return ((total_loss,) + output) if total_loss is not None else output
+
+ return TFAlbertForPreTrainingOutput(
+ loss=total_loss,
+ prediction_logits=prediction_scores,
+ sop_logits=sop_scores,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "albert", None) is not None:
+ with tf.name_scope(self.albert.name):
+ self.albert.build(None)
+ if getattr(self, "predictions", None) is not None:
+ with tf.name_scope(self.predictions.name):
+ self.predictions.build(None)
+ if getattr(self, "sop_classifier", None) is not None:
+ with tf.name_scope(self.sop_classifier.name):
+ self.sop_classifier.build(None)
+
+
+class TFAlbertSOPHead(keras.layers.Layer):
+ def __init__(self, config: AlbertConfig, **kwargs):
+ super().__init__(**kwargs)
+
+ self.dropout = keras.layers.Dropout(rate=config.classifier_dropout_prob)
+ self.classifier = keras.layers.Dense(
+ units=config.num_labels,
+ kernel_initializer=get_initializer(config.initializer_range),
+ name="classifier",
+ )
+ self.config = config
+
+ def call(self, pooled_output: tf.Tensor, training: bool) -> tf.Tensor:
+ dropout_pooled_output = self.dropout(inputs=pooled_output, training=training)
+ logits = self.classifier(inputs=dropout_pooled_output)
+
+ return logits
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "classifier", None) is not None:
+ with tf.name_scope(self.classifier.name):
+ self.classifier.build([None, None, self.config.hidden_size])
+
+
+@add_start_docstrings("""Albert Model with a `language modeling` head on top.""", ALBERT_START_DOCSTRING)
+class TFAlbertForMaskedLM(TFAlbertPreTrainedModel, TFMaskedLanguageModelingLoss):
+    # names with a '.' represent the authorized unexpected/missing layers when a TF model is loaded from a PT model
+ _keys_to_ignore_on_load_unexpected = [r"pooler", r"predictions.decoder.weight"]
+
+ def __init__(self, config: AlbertConfig, *inputs, **kwargs):
+ super().__init__(config, *inputs, **kwargs)
+
+ self.albert = TFAlbertMainLayer(config, add_pooling_layer=False, name="albert")
+ self.predictions = TFAlbertMLMHead(config, input_embeddings=self.albert.embeddings, name="predictions")
+
+ def get_lm_head(self) -> keras.layers.Layer:
+ return self.predictions
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @replace_return_docstrings(output_type=TFMaskedLMOutput, config_class=_CONFIG_FOR_DOC)
+ def call(
+ self,
+ input_ids: TFModelInputType | None = None,
+ attention_mask: np.ndarray | tf.Tensor | None = None,
+ token_type_ids: np.ndarray | tf.Tensor | None = None,
+ position_ids: np.ndarray | tf.Tensor | None = None,
+ head_mask: np.ndarray | tf.Tensor | None = None,
+ inputs_embeds: np.ndarray | tf.Tensor | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ labels: np.ndarray | tf.Tensor | None = None,
+ training: bool | None = False,
+ ) -> TFMaskedLMOutput | tuple[tf.Tensor]:
+ r"""
+ labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
+            config.vocab_size]` (see the `input_ids` docstring). Tokens with indices set to `-100` are ignored
+            (masked); the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> import tensorflow as tf
+ >>> from transformers import AutoTokenizer, TFAlbertForMaskedLM
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("albert/albert-base-v2")
+ >>> model = TFAlbertForMaskedLM.from_pretrained("albert/albert-base-v2")
+
+ >>> # add mask_token
+        >>> inputs = tokenizer("The capital of [MASK] is Paris.", return_tensors="tf")
+ >>> logits = model(**inputs).logits
+
+ >>> # retrieve index of [MASK]
+ >>> mask_token_index = tf.where(inputs.input_ids == tokenizer.mask_token_id)[0][1]
+ >>> predicted_token_id = tf.math.argmax(logits[0, mask_token_index], axis=-1)
+ >>> tokenizer.decode(predicted_token_id)
+ 'france'
+ ```
+
+ ```python
+ >>> labels = tokenizer("The capital of France is Paris.", return_tensors="tf")["input_ids"]
+ >>> labels = tf.where(inputs.input_ids == tokenizer.mask_token_id, labels, -100)
+ >>> outputs = model(**inputs, labels=labels)
+ >>> round(float(outputs.loss), 2)
+ 0.81
+ ```
+ """
+ outputs = self.albert(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+ sequence_output = outputs[0]
+ prediction_scores = self.predictions(hidden_states=sequence_output, training=training)
+ loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=prediction_scores)
+
+ if not return_dict:
+ output = (prediction_scores,) + outputs[2:]
+
+ return ((loss,) + output) if loss is not None else output
+
+ return TFMaskedLMOutput(
+ loss=loss,
+ logits=prediction_scores,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "albert", None) is not None:
+ with tf.name_scope(self.albert.name):
+ self.albert.build(None)
+ if getattr(self, "predictions", None) is not None:
+ with tf.name_scope(self.predictions.name):
+ self.predictions.build(None)
+
+
+@add_start_docstrings(
+ """
+ Albert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
+ output) e.g. for GLUE tasks.
+ """,
+ ALBERT_START_DOCSTRING,
+)
+class TFAlbertForSequenceClassification(TFAlbertPreTrainedModel, TFSequenceClassificationLoss):
+    # names with a '.' represent the authorized unexpected/missing layers when a TF model is loaded from a PT model
+ _keys_to_ignore_on_load_unexpected = [r"predictions"]
+ _keys_to_ignore_on_load_missing = [r"dropout"]
+
+ def __init__(self, config: AlbertConfig, *inputs, **kwargs):
+ super().__init__(config, *inputs, **kwargs)
+
+ self.num_labels = config.num_labels
+
+ self.albert = TFAlbertMainLayer(config, name="albert")
+ self.dropout = keras.layers.Dropout(rate=config.classifier_dropout_prob)
+ self.classifier = keras.layers.Dense(
+ units=config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
+ )
+ self.config = config
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @add_code_sample_docstrings(
+ checkpoint="vumichien/albert-base-v2-imdb",
+ output_type=TFSequenceClassifierOutput,
+ config_class=_CONFIG_FOR_DOC,
+ expected_output="'LABEL_1'",
+ expected_loss=0.12,
+ )
+ def call(
+ self,
+ input_ids: TFModelInputType | None = None,
+ attention_mask: np.ndarray | tf.Tensor | None = None,
+ token_type_ids: np.ndarray | tf.Tensor | None = None,
+ position_ids: np.ndarray | tf.Tensor | None = None,
+ head_mask: np.ndarray | tf.Tensor | None = None,
+ inputs_embeds: np.ndarray | tf.Tensor | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ labels: np.ndarray | tf.Tensor | None = None,
+ training: bool | None = False,
+ ) -> TFSequenceClassifierOutput | tuple[tf.Tensor]:
+ r"""
+ labels (`tf.Tensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss); if
+            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+ outputs = self.albert(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+ pooled_output = outputs[1]
+ pooled_output = self.dropout(inputs=pooled_output, training=training)
+ logits = self.classifier(inputs=pooled_output)
+ loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits)
+
+ if not return_dict:
+ output = (logits,) + outputs[2:]
+
+ return ((loss,) + output) if loss is not None else output
+
+ return TFSequenceClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "albert", None) is not None:
+ with tf.name_scope(self.albert.name):
+ self.albert.build(None)
+ if getattr(self, "classifier", None) is not None:
+ with tf.name_scope(self.classifier.name):
+ self.classifier.build([None, None, self.config.hidden_size])
+
+
+@add_start_docstrings(
+ """
+ Albert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
+ Named-Entity-Recognition (NER) tasks.
+ """,
+ ALBERT_START_DOCSTRING,
+)
+class TFAlbertForTokenClassification(TFAlbertPreTrainedModel, TFTokenClassificationLoss):
+    # names with a '.' represent the authorized unexpected/missing layers when a TF model is loaded from a PT model
+ _keys_to_ignore_on_load_unexpected = [r"pooler", r"predictions"]
+ _keys_to_ignore_on_load_missing = [r"dropout"]
+
+ def __init__(self, config: AlbertConfig, *inputs, **kwargs):
+ super().__init__(config, *inputs, **kwargs)
+
+ self.num_labels = config.num_labels
+
+ self.albert = TFAlbertMainLayer(config, add_pooling_layer=False, name="albert")
+ classifier_dropout_prob = (
+ config.classifier_dropout_prob
+ if config.classifier_dropout_prob is not None
+ else config.hidden_dropout_prob
+ )
+ self.dropout = keras.layers.Dropout(rate=classifier_dropout_prob)
+ self.classifier = keras.layers.Dense(
+ units=config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
+ )
+ self.config = config
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=TFTokenClassifierOutput,
+ config_class=_CONFIG_FOR_DOC,
+ )
+ def call(
+ self,
+ input_ids: TFModelInputType | None = None,
+ attention_mask: np.ndarray | tf.Tensor | None = None,
+ token_type_ids: np.ndarray | tf.Tensor | None = None,
+ position_ids: np.ndarray | tf.Tensor | None = None,
+ head_mask: np.ndarray | tf.Tensor | None = None,
+ inputs_embeds: np.ndarray | tf.Tensor | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ labels: np.ndarray | tf.Tensor | None = None,
+ training: bool | None = False,
+ ) -> TFTokenClassifierOutput | tuple[tf.Tensor]:
+ r"""
+ labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
+ """
+ outputs = self.albert(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+ sequence_output = outputs[0]
+ sequence_output = self.dropout(inputs=sequence_output, training=training)
+ logits = self.classifier(inputs=sequence_output)
+ loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits)
+
+ if not return_dict:
+ output = (logits,) + outputs[2:]
+
+ return ((loss,) + output) if loss is not None else output
+
+ return TFTokenClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "albert", None) is not None:
+ with tf.name_scope(self.albert.name):
+ self.albert.build(None)
+ if getattr(self, "classifier", None) is not None:
+ with tf.name_scope(self.classifier.name):
+ self.classifier.build([None, None, self.config.hidden_size])
+
+
+@add_start_docstrings(
+ """
+ Albert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
+ layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
+ """,
+ ALBERT_START_DOCSTRING,
+)
+class TFAlbertForQuestionAnswering(TFAlbertPreTrainedModel, TFQuestionAnsweringLoss):
+    # names with a '.' represent the authorized unexpected/missing layers when a TF model is loaded from a PT model
+ _keys_to_ignore_on_load_unexpected = [r"pooler", r"predictions"]
+
+ def __init__(self, config: AlbertConfig, *inputs, **kwargs):
+ super().__init__(config, *inputs, **kwargs)
+
+ self.num_labels = config.num_labels
+
+ self.albert = TFAlbertMainLayer(config, add_pooling_layer=False, name="albert")
+ self.qa_outputs = keras.layers.Dense(
+ units=config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
+ )
+ self.config = config
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @add_code_sample_docstrings(
+ checkpoint="vumichien/albert-base-v2-squad2",
+ output_type=TFQuestionAnsweringModelOutput,
+ config_class=_CONFIG_FOR_DOC,
+ qa_target_start_index=12,
+ qa_target_end_index=13,
+ expected_output="'a nice puppet'",
+ expected_loss=7.36,
+ )
+ def call(
+ self,
+ input_ids: TFModelInputType | None = None,
+ attention_mask: np.ndarray | tf.Tensor | None = None,
+ token_type_ids: np.ndarray | tf.Tensor | None = None,
+ position_ids: np.ndarray | tf.Tensor | None = None,
+ head_mask: np.ndarray | tf.Tensor | None = None,
+ inputs_embeds: np.ndarray | tf.Tensor | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ start_positions: np.ndarray | tf.Tensor | None = None,
+ end_positions: np.ndarray | tf.Tensor | None = None,
+ training: bool | None = False,
+ ) -> TFQuestionAnsweringModelOutput | tuple[tf.Tensor]:
+ r"""
+ start_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*):
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
+            Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
+ are not taken into account for computing the loss.
+ end_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*):
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
+            Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
+ are not taken into account for computing the loss.
+ """
+ outputs = self.albert(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+ sequence_output = outputs[0]
+ logits = self.qa_outputs(inputs=sequence_output)
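+        # qa_outputs emits two values per token; split them into start and end logits of shape (batch_size, seq_length)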
+ start_logits, end_logits = tf.split(value=logits, num_or_size_splits=2, axis=-1)
+ start_logits = tf.squeeze(input=start_logits, axis=-1)
+ end_logits = tf.squeeze(input=end_logits, axis=-1)
+ loss = None
+
+ if start_positions is not None and end_positions is not None:
+ labels = {"start_position": start_positions}
+ labels["end_position"] = end_positions
+ loss = self.hf_compute_loss(labels=labels, logits=(start_logits, end_logits))
+
+ if not return_dict:
+ output = (start_logits, end_logits) + outputs[2:]
+
+ return ((loss,) + output) if loss is not None else output
+
+ return TFQuestionAnsweringModelOutput(
+ loss=loss,
+ start_logits=start_logits,
+ end_logits=end_logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "albert", None) is not None:
+ with tf.name_scope(self.albert.name):
+ self.albert.build(None)
+ if getattr(self, "qa_outputs", None) is not None:
+ with tf.name_scope(self.qa_outputs.name):
+ self.qa_outputs.build([None, None, self.config.hidden_size])
+
+
+@add_start_docstrings(
+ """
+ Albert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
+ softmax) e.g. for RocStories/SWAG tasks.
+ """,
+ ALBERT_START_DOCSTRING,
+)
+class TFAlbertForMultipleChoice(TFAlbertPreTrainedModel, TFMultipleChoiceLoss):
+    # names with a '.' represent the authorized unexpected/missing layers when a TF model is loaded from a PT model
+ _keys_to_ignore_on_load_unexpected = [r"pooler", r"predictions"]
+ _keys_to_ignore_on_load_missing = [r"dropout"]
+
+ def __init__(self, config: AlbertConfig, *inputs, **kwargs):
+ super().__init__(config, *inputs, **kwargs)
+
+ self.albert = TFAlbertMainLayer(config, name="albert")
+ self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
+ self.classifier = keras.layers.Dense(
+ units=1, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
+ )
+ self.config = config
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=TFMultipleChoiceModelOutput,
+ config_class=_CONFIG_FOR_DOC,
+ )
+ def call(
+ self,
+ input_ids: TFModelInputType | None = None,
+ attention_mask: np.ndarray | tf.Tensor | None = None,
+ token_type_ids: np.ndarray | tf.Tensor | None = None,
+ position_ids: np.ndarray | tf.Tensor | None = None,
+ head_mask: np.ndarray | tf.Tensor | None = None,
+ inputs_embeds: np.ndarray | tf.Tensor | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ labels: np.ndarray | tf.Tensor | None = None,
+ training: bool | None = False,
+ ) -> TFMultipleChoiceModelOutput | tuple[tf.Tensor]:
+ r"""
+ labels (`tf.Tensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices - 1]`
+            where `num_choices` is the size of the second dimension of the input tensors. (See `input_ids` above.)
+ """
+
+ if input_ids is not None:
+ num_choices = shape_list(input_ids)[1]
+ seq_length = shape_list(input_ids)[2]
+ else:
+ num_choices = shape_list(inputs_embeds)[1]
+ seq_length = shape_list(inputs_embeds)[2]
+
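+        # flatten the (batch_size, num_choices, seq_length) inputs to (batch_size * num_choices, seq_length) so every
+        # choice is encoded by the shared ALBERT encoder; the logits are reshaped back to (batch_size, num_choices) below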
+ flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
+ flat_attention_mask = (
+ tf.reshape(tensor=attention_mask, shape=(-1, seq_length)) if attention_mask is not None else None
+ )
+ flat_token_type_ids = (
+ tf.reshape(tensor=token_type_ids, shape=(-1, seq_length)) if token_type_ids is not None else None
+ )
+ flat_position_ids = (
+ tf.reshape(tensor=position_ids, shape=(-1, seq_length)) if position_ids is not None else None
+ )
+ flat_inputs_embeds = (
+ tf.reshape(tensor=inputs_embeds, shape=(-1, seq_length, shape_list(inputs_embeds)[3]))
+ if inputs_embeds is not None
+ else None
+ )
+ outputs = self.albert(
+ input_ids=flat_input_ids,
+ attention_mask=flat_attention_mask,
+ token_type_ids=flat_token_type_ids,
+ position_ids=flat_position_ids,
+ head_mask=head_mask,
+ inputs_embeds=flat_inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+ pooled_output = outputs[1]
+ pooled_output = self.dropout(inputs=pooled_output, training=training)
+ logits = self.classifier(inputs=pooled_output)
+ reshaped_logits = tf.reshape(tensor=logits, shape=(-1, num_choices))
+ loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=reshaped_logits)
+
+ if not return_dict:
+ output = (reshaped_logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return TFMultipleChoiceModelOutput(
+ loss=loss,
+ logits=reshaped_logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "albert", None) is not None:
+ with tf.name_scope(self.albert.name):
+ self.albert.build(None)
+ if getattr(self, "classifier", None) is not None:
+ with tf.name_scope(self.classifier.name):
+ self.classifier.build([None, None, self.config.hidden_size])
+
+
+__all__ = [
+ "TFAlbertPreTrainedModel",
+ "TFAlbertModel",
+ "TFAlbertForPreTraining",
+ "TFAlbertForMaskedLM",
+ "TFAlbertForSequenceClassification",
+ "TFAlbertForTokenClassification",
+ "TFAlbertForQuestionAnswering",
+ "TFAlbertForMultipleChoice",
+ "TFAlbertMainLayer",
+]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/albert/tokenization_albert.py b/venv/lib/python3.13/site-packages/transformers/models/albert/tokenization_albert.py
new file mode 100644
index 0000000000000000000000000000000000000000..011ad689edbdb10f53694eb9c774604d922a0d73
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/albert/tokenization_albert.py
@@ -0,0 +1,320 @@
+# coding=utf-8
+# Copyright 2018 Google AI, Google Brain and the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenization classes for ALBERT model."""
+
+import os
+import unicodedata
+from shutil import copyfile
+from typing import Any, Optional
+
+import sentencepiece as spm
+
+from ...tokenization_utils import AddedToken, PreTrainedTokenizer
+from ...utils import logging
+from ...utils.import_utils import requires
+
+
+logger = logging.get_logger(__name__)
+VOCAB_FILES_NAMES = {"vocab_file": "spiece.model"}
+
+
+SPIECE_UNDERLINE = "▁"
+
+
+@requires(backends=("sentencepiece",))
+class AlbertTokenizer(PreTrainedTokenizer):
+ """
+ Construct an ALBERT tokenizer. Based on [SentencePiece](https://github.com/google/sentencepiece).
+
+ This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
+ this superclass for more information regarding those methods.
+
+ Args:
+ vocab_file (`str`):
+ [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that
+ contains the vocabulary necessary to instantiate a tokenizer.
+ do_lower_case (`bool`, *optional*, defaults to `True`):
+ Whether or not to lowercase the input when tokenizing.
+ remove_space (`bool`, *optional*, defaults to `True`):
+ Whether or not to strip the text when tokenizing (removing excess spaces before and after the string).
+ keep_accents (`bool`, *optional*, defaults to `False`):
+ Whether or not to keep accents when tokenizing.
+ bos_token (`str`, *optional*, defaults to `"[CLS]"`):
+            The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier token.
+
+
+
+ When building a sequence using special tokens, this is not the token that is used for the beginning of
+ sequence. The token used is the `cls_token`.
+
+
+
+ eos_token (`str`, *optional*, defaults to `"[SEP]"`):
+ The end of sequence token.
+
+
+
+ When building a sequence using special tokens, this is not the token that is used for the end of sequence.
+ The token used is the `sep_token`.
+
+
+
+        unk_token (`str`, *optional*, defaults to `"<unk>"`):
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+ token instead.
+ sep_token (`str`, *optional*, defaults to `"[SEP]"`):
+ The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
+ sequence classification or for a text and a question for question answering. It is also used as the last
+ token of a sequence built with special tokens.
+        pad_token (`str`, *optional*, defaults to `"<pad>"`):
+ The token used for padding, for example when batching sequences of different lengths.
+ cls_token (`str`, *optional*, defaults to `"[CLS]"`):
+ The classifier token which is used when doing sequence classification (classification of the whole sequence
+ instead of per-token classification). It is the first token of the sequence when built with special tokens.
+ mask_token (`str`, *optional*, defaults to `"[MASK]"`):
+ The token used for masking values. This is the token used when training this model with masked language
+ modeling. This is the token which the model will try to predict.
+ sp_model_kwargs (`dict`, *optional*):
+ Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
+ SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
+ to set:
+
+ - `enable_sampling`: Enable subword regularization.
+ - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.
+
+ - `nbest_size = {0,1}`: No sampling is performed.
+ - `nbest_size > 1`: samples from the nbest_size results.
+ - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice)
+ using forward-filtering-and-backward-sampling algorithm.
+
+ - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
+ BPE-dropout.
+
+ Attributes:
+ sp_model (`SentencePieceProcessor`):
+ The *SentencePiece* processor that is used for every conversion (string, tokens and IDs).
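+
+    Example (a minimal usage sketch; the exact pieces and ids depend on the SentencePiece model behind `vocab_file`):
+
+    ```python
+    >>> from transformers import AlbertTokenizer
+
+    >>> tokenizer = AlbertTokenizer.from_pretrained("albert/albert-base-v2")
+    >>> tokens = tokenizer.tokenize("Hello World!")  # lowercased and split into SentencePiece pieces
+    >>> ids = tokenizer.convert_tokens_to_ids(tokens)
+    >>> input_ids = tokenizer.build_inputs_with_special_tokens(ids)  # adds [CLS] at the start and [SEP] at the end
+    ```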
+ """
+
+ vocab_files_names = VOCAB_FILES_NAMES
+
+ def __init__(
+ self,
+ vocab_file,
+ do_lower_case=True,
+ remove_space=True,
+ keep_accents=False,
+ bos_token="[CLS]",
+ eos_token="[SEP]",
+        unk_token="<unk>",
+ sep_token="[SEP]",
+        pad_token="<pad>",
+ cls_token="[CLS]",
+ mask_token="[MASK]",
+ sp_model_kwargs: Optional[dict[str, Any]] = None,
+ **kwargs,
+ ) -> None:
+        # The mask token behaves like a normal word, i.e. it includes the space before it and
+        # is included in the raw text, so that there is a match in a non-normalized sentence.
+ mask_token = (
+ AddedToken(mask_token, lstrip=True, rstrip=False, normalized=False)
+ if isinstance(mask_token, str)
+ else mask_token
+ )
+
+ self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
+
+ self.do_lower_case = do_lower_case
+ self.remove_space = remove_space
+ self.keep_accents = keep_accents
+ self.vocab_file = vocab_file
+
+ self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+ self.sp_model.Load(vocab_file)
+
+ super().__init__(
+ do_lower_case=do_lower_case,
+ remove_space=remove_space,
+ keep_accents=keep_accents,
+ bos_token=bos_token,
+ eos_token=eos_token,
+ unk_token=unk_token,
+ sep_token=sep_token,
+ pad_token=pad_token,
+ cls_token=cls_token,
+ mask_token=mask_token,
+ sp_model_kwargs=self.sp_model_kwargs,
+ **kwargs,
+ )
+
+ @property
+ def vocab_size(self) -> int:
+ return len(self.sp_model)
+
+ def get_vocab(self) -> dict[str, int]:
+ vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
+ vocab.update(self.added_tokens_encoder)
+ return vocab
+
+ def __getstate__(self):
+ state = self.__dict__.copy()
+ state["sp_model"] = None
+ return state
+
+ def __setstate__(self, d):
+ self.__dict__ = d
+
+ # for backward compatibility
+ if not hasattr(self, "sp_model_kwargs"):
+ self.sp_model_kwargs = {}
+
+ self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+ self.sp_model.Load(self.vocab_file)
+
+ def preprocess_text(self, inputs):
+ if self.remove_space:
+ outputs = " ".join(inputs.strip().split())
+ else:
+ outputs = inputs
+ outputs = outputs.replace("``", '"').replace("''", '"')
+
+ if not self.keep_accents:
+ outputs = unicodedata.normalize("NFKD", outputs)
+ outputs = "".join([c for c in outputs if not unicodedata.combining(c)])
+ if self.do_lower_case:
+ outputs = outputs.lower()
+
+ return outputs
+
+ def _tokenize(self, text: str) -> list[str]:
+ """Tokenize a string."""
+ text = self.preprocess_text(text)
+ pieces = self.sp_model.encode(text, out_type=str)
+ new_pieces = []
+ for piece in pieces:
+ if len(piece) > 1 and piece[-1] == "," and piece[-2].isdigit():
+ # Logic to handle special cases see https://github.com/google-research/bert/blob/master/README.md#tokenization
+ # `9,9` -> ['▁9', ',', '9'] instead of [`_9,`, '9']
+ cur_pieces = self.sp_model.EncodeAsPieces(piece[:-1].replace(SPIECE_UNDERLINE, ""))
+ if piece[0] != SPIECE_UNDERLINE and cur_pieces[0][0] == SPIECE_UNDERLINE:
+ if len(cur_pieces[0]) == 1:
+ cur_pieces = cur_pieces[1:]
+ else:
+ cur_pieces[0] = cur_pieces[0][1:]
+ cur_pieces.append(piece[-1])
+ new_pieces.extend(cur_pieces)
+ else:
+ new_pieces.append(piece)
+
+ return new_pieces
+
+ def _convert_token_to_id(self, token):
+ """Converts a token (str) in an id using the vocab."""
+ return self.sp_model.PieceToId(token)
+
+ def _convert_id_to_token(self, index):
+ """Converts an index (integer) in a token (str) using the vocab."""
+ return self.sp_model.IdToPiece(index)
+
+ def convert_tokens_to_string(self, tokens):
+ """Converts a sequence of tokens (string) in a single string."""
+ current_sub_tokens = []
+ out_string = ""
+ prev_is_special = False
+ for token in tokens:
+ # make sure that special tokens are not decoded using sentencepiece model
+ if token in self.all_special_tokens:
+ if not prev_is_special:
+ out_string += " "
+ out_string += self.sp_model.decode(current_sub_tokens) + token
+ prev_is_special = True
+ current_sub_tokens = []
+ else:
+ current_sub_tokens.append(token)
+ prev_is_special = False
+ out_string += self.sp_model.decode(current_sub_tokens)
+ return out_string.strip()
+
+ def build_inputs_with_special_tokens(
+ self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
+ ) -> list[int]:
+ """
+        Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating and
+ adding special tokens. An ALBERT sequence has the following format:
+
+ - single sequence: `[CLS] X [SEP]`
+ - pair of sequences: `[CLS] A [SEP] B [SEP]`
+
+ Args:
+ token_ids_0 (`List[int]`):
+ List of IDs to which the special tokens will be added.
+ token_ids_1 (`List[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+
+ Returns:
+ `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
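+
+        Example (illustrative; assumes a tokenizer loaded from `albert/albert-base-v2`, whose vocabulary has
+        `cls_token_id` 2 and `sep_token_id` 3):
+
+        ```python
+        >>> tokenizer.build_inputs_with_special_tokens([10, 11, 12])
+        [2, 10, 11, 12, 3]
+        >>> tokenizer.build_inputs_with_special_tokens([10, 11], [20, 21])
+        [2, 10, 11, 3, 20, 21, 3]
+        ```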
+ """
+ sep = [self.sep_token_id]
+ cls = [self.cls_token_id]
+ if token_ids_1 is None:
+ return cls + token_ids_0 + sep
+ return cls + token_ids_0 + sep + token_ids_1 + sep
+
+ def get_special_tokens_mask(
+ self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None, already_has_special_tokens: bool = False
+ ) -> list[int]:
+ """
+ Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
+ special tokens using the tokenizer `prepare_for_model` method.
+
+ Args:
+ token_ids_0 (`List[int]`):
+ List of IDs.
+ token_ids_1 (`List[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+ already_has_special_tokens (`bool`, *optional*, defaults to `False`):
+ Whether or not the token list is already formatted with special tokens for the model.
+
+ Returns:
+ `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+ """
+
+ if already_has_special_tokens:
+ return super().get_special_tokens_mask(
+ token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
+ )
+
+ if token_ids_1 is not None:
+ return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
+ return [1] + ([0] * len(token_ids_0)) + [1]
+
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
+ if not os.path.isdir(save_directory):
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
+ return
+ out_vocab_file = os.path.join(
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+ )
+
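+        # either copy the original SentencePiece model file, or (if it is no longer available on disk) write out the
+        # serialized model proto held in memory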
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
+ copyfile(self.vocab_file, out_vocab_file)
+ elif not os.path.isfile(self.vocab_file):
+ with open(out_vocab_file, "wb") as fi:
+ content_spiece_model = self.sp_model.serialized_model_proto()
+ fi.write(content_spiece_model)
+
+ return (out_vocab_file,)
+
+
+__all__ = ["AlbertTokenizer"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/albert/tokenization_albert_fast.py b/venv/lib/python3.13/site-packages/transformers/models/albert/tokenization_albert_fast.py
new file mode 100644
index 0000000000000000000000000000000000000000..ed9add51d20743948dc1fe51ad6f5fe0c1ed1543
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/albert/tokenization_albert_fast.py
@@ -0,0 +1,178 @@
+# coding=utf-8
+# Copyright 2018 Google AI, Google Brain and the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenization classes for ALBERT model."""
+
+import os
+from shutil import copyfile
+from typing import Optional
+
+from ...tokenization_utils import AddedToken
+from ...tokenization_utils_fast import PreTrainedTokenizerFast
+from ...utils import is_sentencepiece_available, logging
+
+
+if is_sentencepiece_available():
+ from .tokenization_albert import AlbertTokenizer
+else:
+ AlbertTokenizer = None
+
+logger = logging.get_logger(__name__)
+VOCAB_FILES_NAMES = {"vocab_file": "spiece.model", "tokenizer_file": "tokenizer.json"}
+
+
+SPIECE_UNDERLINE = "▁"
+
+
+class AlbertTokenizerFast(PreTrainedTokenizerFast):
+ """
+ Construct a "fast" ALBERT tokenizer (backed by HuggingFace's *tokenizers* library). Based on
+ [Unigram](https://huggingface.co/docs/tokenizers/python/latest/components.html?highlight=unigram#models). This
+ tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should refer to
+ this superclass for more information regarding those methods.
+
+ Args:
+ vocab_file (`str`):
+ [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that
+ contains the vocabulary necessary to instantiate a tokenizer.
+ do_lower_case (`bool`, *optional*, defaults to `True`):
+ Whether or not to lowercase the input when tokenizing.
+ remove_space (`bool`, *optional*, defaults to `True`):
+ Whether or not to strip the text when tokenizing (removing excess spaces before and after the string).
+ keep_accents (`bool`, *optional*, defaults to `False`):
+ Whether or not to keep accents when tokenizing.
+ bos_token (`str`, *optional*, defaults to `"[CLS]"`):
+ The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier token.
+
+ <Tip>
+
+ When building a sequence using special tokens, this is not the token that is used for the beginning of
+ sequence. The token used is the `cls_token`.
+
+ </Tip>
+
+ eos_token (`str`, *optional*, defaults to `"[SEP]"`):
+ The end of sequence token. .. note:: When building a sequence using special tokens, this is not the token
+ that is used for the end of sequence. The token used is the `sep_token`.
+ unk_token (`str`, *optional*, defaults to `"<unk>"`):
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+ token instead.
+ sep_token (`str`, *optional*, defaults to `"[SEP]"`):
+ The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
+ sequence classification or for a text and a question for question answering. It is also used as the last
+ token of a sequence built with special tokens.
+ pad_token (`str`, *optional*, defaults to `"<pad>"`):
+ The token used for padding, for example when batching sequences of different lengths.
+ cls_token (`str`, *optional*, defaults to `"[CLS]"`):
+ The classifier token which is used when doing sequence classification (classification of the whole sequence
+ instead of per-token classification). It is the first token of the sequence when built with special tokens.
+ mask_token (`str`, *optional*, defaults to `"[MASK]"`):
+ The token used for masking values. This is the token used when training this model with masked language
+ modeling. This is the token which the model will try to predict.
+ """
+
+ vocab_files_names = VOCAB_FILES_NAMES
+ slow_tokenizer_class = AlbertTokenizer
+
+ def __init__(
+ self,
+ vocab_file=None,
+ tokenizer_file=None,
+ do_lower_case=True,
+ remove_space=True,
+ keep_accents=False,
+ bos_token="[CLS]",
+ eos_token="[SEP]",
+ unk_token="<unk>",
+ sep_token="[SEP]",
+ pad_token="<pad>",
+ cls_token="[CLS]",
+ mask_token="[MASK]",
+ **kwargs,
+ ):
+ # Mask token behave like a normal word, i.e. include the space before it and
+ # is included in the raw text, there should be a match in a non-normalized sentence.
+ mask_token = (
+ AddedToken(mask_token, lstrip=True, rstrip=False, normalized=False)
+ if isinstance(mask_token, str)
+ else mask_token
+ )
+
+ super().__init__(
+ vocab_file,
+ tokenizer_file=tokenizer_file,
+ do_lower_case=do_lower_case,
+ remove_space=remove_space,
+ keep_accents=keep_accents,
+ bos_token=bos_token,
+ eos_token=eos_token,
+ unk_token=unk_token,
+ sep_token=sep_token,
+ pad_token=pad_token,
+ cls_token=cls_token,
+ mask_token=mask_token,
+ **kwargs,
+ )
+
+ self.do_lower_case = do_lower_case
+ self.remove_space = remove_space
+ self.keep_accents = keep_accents
+ self.vocab_file = vocab_file
+
+ def build_inputs_with_special_tokens(
+ self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
+ ) -> list[int]:
+ """
+ Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating and
+ adding special tokens. An ALBERT sequence has the following format:
+
+ - single sequence: `[CLS] X [SEP]`
+ - pair of sequences: `[CLS] A [SEP] B [SEP]`
+
+ Args:
+ token_ids_0 (`List[int]`):
+ List of IDs to which the special tokens will be added.
+ token_ids_1 (`List[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+
+ Returns:
+ `List[int]`: list of [input IDs](../glossary#input-ids) with the appropriate special tokens.
+ """
+ sep = [self.sep_token_id]
+ cls = [self.cls_token_id]
+ if token_ids_1 is None:
+ return cls + token_ids_0 + sep
+ return cls + token_ids_0 + sep + token_ids_1 + sep
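+ # Minimal usage sketch (illustrative; the exact ids depend on the SentencePiece vocabulary in use):
+ # >>> tok = AlbertTokenizerFast.from_pretrained("albert/albert-base-v2")
+ # >>> ids = tok("first sentence", "second sentence")["input_ids"]
+ # `ids` is laid out as [CLS] A [SEP] B [SEP], matching build_inputs_with_special_tokens above.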
+
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
+ if not self.can_save_slow_tokenizer:
+ raise ValueError(
+ "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
+ "tokenizer."
+ )
+
+ if not os.path.isdir(save_directory):
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
+ return
+ out_vocab_file = os.path.join(
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+ )
+
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+ copyfile(self.vocab_file, out_vocab_file)
+
+ return (out_vocab_file,)
+
+
+__all__ = ["AlbertTokenizerFast"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/apertus/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/apertus/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..dea6f28438b45454f6a66c04c9b3076e0dedceb8
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/apertus/__init__.py
@@ -0,0 +1,32 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team and the Swiss AI Initiative. All rights reserved.
+#
+# This code is based on HuggingFace's LLaMA implementation in this library.
+# It has been modified from its original forms to accommodate the architectural
+# differences made by the Swiss AI Initiative that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_apertus import *
+ from .modeling_apertus import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/venv/lib/python3.13/site-packages/transformers/models/apertus/configuration_apertus.py b/venv/lib/python3.13/site-packages/transformers/models/apertus/configuration_apertus.py
new file mode 100644
index 0000000000000000000000000000000000000000..180ad756dc8839ff090abf7eb21453d7658d6be4
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/apertus/configuration_apertus.py
@@ -0,0 +1,214 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/apertus/modular_apertus.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_apertus.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2025 the HuggingFace Inc. team and the Swiss AI Initiative. All rights reserved.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ...configuration_utils import PretrainedConfig
+from ...modeling_rope_utils import rope_config_validation
+
+
+class ApertusConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`ApertusModel`]. It is used to instantiate an Apertus
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+ defaults will yield a similar configuration to that of the Apertus-8B.
+ e.g. [swiss-ai/Apertus-8B](https://huggingface.co/swiss-ai/Apertus-8B)
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 131072):
+ Vocabulary size of the Apertus model. Defines the number of different tokens that can be represented by the
+ `inputs_ids` passed when calling [`ApertusModel`]
+ hidden_size (`int`, *optional*, defaults to 4096):
+ Dimension of the hidden representations.
+ intermediate_size (`int`, *optional*, defaults to 14336):
+ Dimension of the MLP representations.
+ num_hidden_layers (`int`, *optional*, defaults to 32):
+ Number of hidden layers in the Transformer decoder.
+ num_attention_heads (`int`, *optional*, defaults to 32):
+ Number of attention heads for each attention layer in the Transformer decoder.
+ num_key_value_heads (`int`, *optional*):
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA); if
+ `num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise, GQA is used. When
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+ by meanpooling all the original heads within that group. For more details, check out [this
+ paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
+ `num_attention_heads`.
+ hidden_act (`str` or `function`, *optional*, defaults to `"xielu"`):
+ The non-linear activation function (function or string) in the decoder.
+ max_position_embeddings (`int`, *optional*, defaults to 65536):
+ The maximum sequence length that this model might ever be used with. Apertus supports up to 65536 tokens.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-05):
+ The epsilon used by the rms normalization layers.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ pad_token_id (`int`, *optional*, defaults to 3):
+ Padding token id.
+ bos_token_id (`int`, *optional*, defaults to 1):
+ Beginning of stream token id.
+ eos_token_id (`int`, *optional*, defaults to 2):
+ End of stream token id.
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether to tie weight embeddings
+ rope_theta (`float`, *optional*, defaults to 12000000.0):
+ The base period of the RoPE embeddings.
+ rope_scaling (`Dict`, *optional*):
+ Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
+ and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
+ accordingly.
+ Expected contents:
+ `rope_type` (`str`):
+ The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
+ 'llama3'], with 'default' being the original RoPE implementation.
+ `factor` (`float`, *optional*):
+ Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
+ most scaling types, a `factor` of x will enable the model to handle sequences of length x *
+ original maximum pre-trained length.
+ `original_max_position_embeddings` (`int`, *optional*):
+ Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
+ pretraining.
+ `attention_factor` (`float`, *optional*):
+ Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
+ computation. If unspecified, it defaults to value recommended by the implementation, using the
+ `factor` field to infer the suggested value.
+ `beta_fast` (`float`, *optional*):
+ Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
+ ramp function. If unspecified, it defaults to 32.
+ `beta_slow` (`float`, *optional*):
+ Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
+ ramp function. If unspecified, it defaults to 1.
+ `short_factor` (`list[float]`, *optional*):
+ Only used with 'longrope'. The scaling factor to be applied to short contexts (<
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+ size divided by the number of attention heads divided by 2
+ `long_factor` (`list[float]`, *optional*):
+ Only used with 'longrope'. The scaling factor to be applied to long contexts (<
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+ size divided by the number of attention heads divided by 2
+ `low_freq_factor` (`float`, *optional*):
+ Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
+ `high_freq_factor` (`float`, *optional*):
+ Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
+ attention_bias (`bool`, *optional*, defaults to `False`):
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+
+ ```python
+ >>> from transformers import ApertusModel, ApertusConfig
+
+ >>> # Initializing an Apertus-8B style configuration
+ >>> configuration = ApertusConfig()
+
+ >>> # Initializing a model from the Apertus-8B style configuration
+ >>> model = ApertusModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "apertus"
+ keys_to_ignore_at_inference = ["past_key_values"]
+ base_model_tp_plan = {
+ "layers.*.self_attn.q_proj": "colwise_rep", # we need to replicate here due to the added norm on q and k
+ "layers.*.self_attn.k_proj": "colwise_rep", # we need to replicate here due to the added norm on q and k
+ "layers.*.self_attn.v_proj": "colwise_rep", # we need to replicate here due to the added norm on q and k
+ "layers.*.self_attn.o_proj": "rowwise_rep", # we need to replicate here due to the added norm on q and k
+ "layers.*.mlp.up_proj": "colwise",
+ "layers.*.mlp.down_proj": "rowwise",
+ "layers.*.mlp.gate_proj": "colwise",
+ }
+ base_model_pp_plan = {
+ "embed_tokens": (["input_ids"], ["inputs_embeds"]),
+ "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
+ "norm": (["hidden_states"], ["hidden_states"]),
+ }
+
+ def __init__(
+ self,
+ vocab_size=131072,
+ hidden_size=4096,
+ intermediate_size=14336,
+ num_hidden_layers=32,
+ num_attention_heads=32,
+ num_key_value_heads=None,
+ hidden_act="xielu",
+ max_position_embeddings=65536,
+ initializer_range=0.02,
+ rms_norm_eps=1e-5,
+ use_cache=True,
+ pad_token_id=3,
+ bos_token_id=1,
+ eos_token_id=2,
+ tie_word_embeddings=False,
+ rope_theta=12000000.0,
+ rope_scaling={
+ "rope_type": "llama3",
+ "factor": 8.0,
+ "original_max_position_embeddings": 8192,
+ "low_freq_factor": 1.0,
+ "high_freq_factor": 4.0,
+ },
+ attention_bias=False,
+ attention_dropout=0.0,
+ **kwargs,
+ ):
+ super().__init__(
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ tie_word_embeddings=tie_word_embeddings,
+ **kwargs,
+ )
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+
+ # for backward compatibility
+ if num_key_value_heads is None:
+ num_key_value_heads = num_attention_heads
+
+ self.num_key_value_heads = num_key_value_heads
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.rms_norm_eps = rms_norm_eps
+ self.use_cache = use_cache
+ self.rope_theta = rope_theta
+ self.rope_scaling = rope_scaling
+ self.attention_bias = attention_bias
+ self.attention_dropout = attention_dropout
+ # Validate the correctness of rotary position embeddings parameters
+ # BC: if there is a 'type' field, copy it to 'rope_type'.
+ if self.rope_scaling is not None and "type" in self.rope_scaling:
+ self.rope_scaling["rope_type"] = self.rope_scaling["type"]
+ rope_config_validation(self)
+
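+ # Illustrative sketch of the backward-compatibility branch above: a legacy "type" key is mirrored
+ # into "rope_type" before `rope_config_validation` runs, so both spellings select the same variant.
+ # >>> cfg = ApertusConfig(rope_scaling={"type": "llama3", "factor": 8.0,
+ # ...     "original_max_position_embeddings": 8192, "low_freq_factor": 1.0, "high_freq_factor": 4.0})
+ # >>> cfg.rope_scaling["rope_type"]
+ # 'llama3'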
+
+__all__ = ["ApertusConfig"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/apertus/modeling_apertus.py b/venv/lib/python3.13/site-packages/transformers/models/apertus/modeling_apertus.py
new file mode 100644
index 0000000000000000000000000000000000000000..a121146d86989fbad6ece183259df0e06be3dbe4
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/apertus/modeling_apertus.py
@@ -0,0 +1,488 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/apertus/modular_apertus.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_apertus.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2025 the HuggingFace Inc. team and the Swiss AI Initiative. All rights reserved.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Callable, Optional, Union
+
+import torch
+from torch import nn
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache
+from ...generation import GenerationMixin
+from ...integrations import use_kernel_forward_from_hub
+from ...masking_utils import create_causal_mask
+from ...modeling_layers import GenericForTokenClassification, GradientCheckpointingLayer
+from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
+from ...utils.deprecation import deprecate_kwarg
+from ...utils.generic import check_model_inputs
+from .configuration_apertus import ApertusConfig
+
+
+class ApertusMLP(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.intermediate_size = config.intermediate_size
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+ self.act_fn = ACT2FN[config.hidden_act]
+
+ def forward(self, x):
+ return self.down_proj(self.act_fn(self.up_proj(x)))
+
+
+@use_kernel_forward_from_hub("RMSNorm")
+class ApertusRMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ ApertusRMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+ def extra_repr(self):
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
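+ # The forward pass above computes, in float32 for stability: y = weight * x / sqrt(mean(x**2, dim=-1) + eps),
+ # i.e. RMS normalization without mean subtraction or a bias term.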
+
+
+class ApertusRotaryEmbedding(nn.Module):
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
+
+ def __init__(self, config: ApertusConfig, device=None):
+ super().__init__()
+ # BC: "rope_type" was originally "type"
+ if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+ else:
+ self.rope_type = "default"
+ self.max_seq_len_cached = config.max_position_embeddings
+ self.original_max_seq_len = config.max_position_embeddings
+
+ self.config = config
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+ self.original_inv_freq = self.inv_freq
+
+ @torch.no_grad()
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
+ def forward(self, x, position_ids):
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
+ position_ids_expanded = position_ids[:, None, :].float()
+
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos = emb.cos() * self.attention_scaling
+ sin = emb.sin() * self.attention_scaling
+
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+ """Applies Rotary Position Embedding to the query and key tensors.
+
+ Args:
+ q (`torch.Tensor`): The query tensor.
+ k (`torch.Tensor`): The key tensor.
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
+ position_ids (`torch.Tensor`, *optional*):
+ Deprecated and unused.
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+ Returns:
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+ """
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
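+ # Shape sketch for the rotation above (illustrative): with q, k of shape [batch, heads, seq, head_dim]
+ # and cos, sin of shape [batch, seq, head_dim], the default unsqueeze_dim=1 inserts the head axis so
+ # cos/sin broadcast against q and k in q * cos + rotate_half(q) * sin.
+ # >>> q = torch.zeros(2, 8, 16, 64); cos = torch.ones(2, 16, 64); sin = torch.zeros(2, 16, 64)
+ # >>> apply_rotary_pos_emb(q, q, cos, sin)[0].shape
+ # torch.Size([2, 8, 16, 64])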
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+ """
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
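+ # Shape sketch (illustrative): 4 key/value heads repeated twice line up with 8 query heads.
+ # >>> kv = torch.randn(1, 4, 16, 64)
+ # >>> repeat_kv(kv, 2).shape
+ # torch.Size([1, 8, 16, 64])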
+
+
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: float,
+ dropout: float = 0.0,
+ **kwargs: Unpack[TransformersKwargs],
+):
+ key_states = repeat_kv(key, module.num_key_value_groups)
+ value_states = repeat_kv(value, module.num_key_value_groups)
+
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+ if attention_mask is not None:
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
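+ # Reference behaviour of the eager path above: key/value heads are expanded to the query head count,
+ # scores are Q @ K^T * scaling plus the additive causal mask sliced to the key length, softmax runs in
+ # float32, and dropout(softmax(scores)) @ V is returned as [batch, seq, heads, head_dim] after the transpose.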
+
+
+class ApertusAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config: ApertusConfig, layer_idx: Optional[int] = None):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
+ self.scaling = self.head_dim**-0.5
+ self.attention_dropout = config.attention_dropout
+ self.is_causal = True
+
+ self.q_proj = nn.Linear(
+ config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.k_proj = nn.Linear(
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.v_proj = nn.Linear(
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.o_proj = nn.Linear(
+ config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
+ )
+ self.q_norm = ApertusRMSNorm(self.head_dim, config.rms_norm_eps)
+ self.k_norm = ApertusRMSNorm(self.head_dim, config.rms_norm_eps)
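+ # The per-head RMSNorm on queries and keys (QK-norm) is the main departure from the Llama attention
+ # block; forward() applies it right after the projections and before the rotary embedding.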
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
+ attention_mask: Optional[torch.Tensor],
+ past_key_values: Optional[Cache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ input_shape = hidden_states.shape[:-1]
+ hidden_shape = (*input_shape, -1, self.head_dim)
+
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ query_states = self.q_norm(query_states)
+ key_states = self.k_norm(key_states)
+
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_values is not None:
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ scaling=self.scaling,
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+ return attn_output, attn_weights
+
+
+class ApertusDecoderLayer(GradientCheckpointingLayer):
+ def __init__(self, config: ApertusConfig, layer_idx: int):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+
+ self.self_attn = ApertusAttention(config=config, layer_idx=layer_idx)
+
+ self.mlp = ApertusMLP(config)
+ self.attention_layernorm = ApertusRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.feedforward_layernorm = ApertusRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple[torch.Tensor]:
+ residual = hidden_states
+ hidden_states = self.attention_layernorm(hidden_states)
+ hidden_states, _ = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+ hidden_states = residual + hidden_states
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.feedforward_layernorm(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+ return hidden_states
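+ # The layer above follows the usual pre-norm residual layout, with the norms renamed to
+ # attention_layernorm / feedforward_layernorm: x = x + Attn(norm1(x)); x = x + MLP(norm2(x)).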
+
+
+@auto_docstring
+class ApertusPreTrainedModel(PreTrainedModel):
+ config: ApertusConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["ApertusDecoderLayer"]
+ _skip_keys_device_placement = ["past_key_values"]
+ _supports_flash_attn = True
+ _supports_sdpa = True
+ _supports_flex_attn = True
+
+ _can_compile_fullgraph = True
+ _supports_attention_backend = True
+ _can_record_outputs = {
+ "hidden_states": ApertusDecoderLayer,
+ "attentions": ApertusAttention,
+ }
+
+
+@auto_docstring
+class ApertusModel(ApertusPreTrainedModel):
+ def __init__(self, config: ApertusConfig):
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+ self.layers = nn.ModuleList(
+ [ApertusDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+ self.norm = ApertusRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.rotary_emb = ApertusRotaryEmbedding(config=config)
+ self.gradient_checkpointing = False
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @check_model_inputs()
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> BaseModelOutputWithPast:
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if inputs_embeds is None:
+ inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)
+
+ if use_cache and past_key_values is None:
+ past_key_values = DynamicCache(config=self.config)
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position: torch.Tensor = torch.arange(
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+ )
+
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0)
+
+ causal_mask = create_causal_mask(
+ config=self.config,
+ input_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ cache_position=cache_position,
+ past_key_values=past_key_values,
+ position_ids=position_ids,
+ )
+
+ hidden_states = inputs_embeds
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+ for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+ hidden_states = decoder_layer(
+ hidden_states,
+ attention_mask=causal_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+
+ hidden_states = self.norm(hidden_states)
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values,
+ )
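+ # Sketch of the forward contract above: exactly one of input_ids / inputs_embeds is accepted, a
+ # DynamicCache is created lazily when use_cache is set and no cache was passed, cache_position and
+ # position_ids are derived from the number of tokens already cached, and create_causal_mask builds
+ # the mask handed to every decoder layer.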
+
+
+@auto_docstring
+class ApertusForCausalLM(ApertusPreTrainedModel, GenerationMixin):
+ _tied_weights_keys = ["lm_head.weight"]
+ _tp_plan = {"lm_head": "colwise_rep"}
+ _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = ApertusModel(config)
+ self.vocab_size = config.vocab_size
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> CausalLMOutputWithPast:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, ApertusForCausalLM
+
+ >>> model = ApertusForCausalLM.from_pretrained("swiss-ai/Apertus-8B")
+ >>> tokenizer = AutoTokenizer.from_pretrained("swiss-ai/Apertus-8B")
+
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+ ```"""
+ outputs: BaseModelOutputWithPast = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ hidden_states = outputs.last_hidden_state
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+class ApertusForTokenClassification(GenericForTokenClassification, ApertusPreTrainedModel):
+ pass
+
+
+__all__ = ["ApertusModel", "ApertusForCausalLM", "ApertusForTokenClassification", "ApertusPreTrainedModel"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/apertus/modular_apertus.py b/venv/lib/python3.13/site-packages/transformers/models/apertus/modular_apertus.py
new file mode 100644
index 0000000000000000000000000000000000000000..e8d1e3f815c0f828d8517da4a87e620528b6dd98
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/apertus/modular_apertus.py
@@ -0,0 +1,371 @@
+# coding=utf-8
+# Copyright 2025 the HuggingFace Inc. team and the Swiss AI Initiative. All rights reserved.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Callable, Optional
+
+import torch
+from torch import nn
+
+from ...cache_utils import Cache
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
+from ...processing_utils import Unpack
+from ...utils import TransformersKwargs, logging
+from ..llama.configuration_llama import LlamaConfig
+from ..llama.modeling_llama import (
+ LlamaAttention,
+ LlamaDecoderLayer,
+ LlamaForCausalLM,
+ LlamaForTokenClassification,
+ LlamaModel,
+ LlamaPreTrainedModel,
+ LlamaRMSNorm,
+ LlamaRotaryEmbedding,
+ apply_rotary_pos_emb,
+ eager_attention_forward,
+)
+from ..nemotron.modeling_nemotron import NemotronMLP
+
+
+logger = logging.get_logger(__name__)
+
+
+class ApertusConfig(LlamaConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`ApertusModel`]. It is used to instantiate an Apertus
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+ defaults will yield a similar configuration to that of the Apertus-8B.
+ e.g. [swiss-ai/Apertus-8B](https://huggingface.co/swiss-ai/Apertus-8B)
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 131072):
+ Vocabulary size of the Apertus model. Defines the number of different tokens that can be represented by the
+ `inputs_ids` passed when calling [`ApertusModel`]
+ hidden_size (`int`, *optional*, defaults to 4096):
+ Dimension of the hidden representations.
+ intermediate_size (`int`, *optional*, defaults to 14336):
+ Dimension of the MLP representations.
+ num_hidden_layers (`int`, *optional*, defaults to 32):
+ Number of hidden layers in the Transformer decoder.
+ num_attention_heads (`int`, *optional*, defaults to 32):
+ Number of attention heads for each attention layer in the Transformer decoder.
+ num_key_value_heads (`int`, *optional*):
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA); if
+ `num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise, GQA is used. When
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+ by meanpooling all the original heads within that group. For more details, check out [this
+ paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
+ `num_attention_heads`.
+ hidden_act (`str` or `function`, *optional*, defaults to `"xielu"`):
+ The non-linear activation function (function or string) in the decoder.
+ max_position_embeddings (`int`, *optional*, defaults to 65536):
+ The maximum sequence length that this model might ever be used with. Apertus supports up to 65536 tokens.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-05):
+ The epsilon used by the rms normalization layers.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ pad_token_id (`int`, *optional*, defaults to 3):
+ Padding token id.
+ bos_token_id (`int`, *optional*, defaults to 1):
+ Beginning of stream token id.
+ eos_token_id (`int`, *optional*, defaults to 2):
+ End of stream token id.
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether to tie weight embeddings
+ rope_theta (`float`, *optional*, defaults to 12000000.0):
+ The base period of the RoPE embeddings.
+ rope_scaling (`Dict`, *optional*):
+ Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
+ and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
+ accordingly.
+ Expected contents:
+ `rope_type` (`str`):
+ The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
+ 'llama3'], with 'default' being the original RoPE implementation.
+ `factor` (`float`, *optional*):
+ Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
+ most scaling types, a `factor` of x will enable the model to handle sequences of length x *
+ original maximum pre-trained length.
+ `original_max_position_embeddings` (`int`, *optional*):
+ Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
+ pretraining.
+ `attention_factor` (`float`, *optional*):
+ Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
+ computation. If unspecified, it defaults to value recommended by the implementation, using the
+ `factor` field to infer the suggested value.
+ `beta_fast` (`float`, *optional*):
+ Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
+ ramp function. If unspecified, it defaults to 32.
+ `beta_slow` (`float`, *optional*):
+ Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
+ ramp function. If unspecified, it defaults to 1.
+ `short_factor` (`list[float]`, *optional*):
+ Only used with 'longrope'. The scaling factor to be applied to short contexts (<
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+ size divided by the number of attention heads divided by 2
+ `long_factor` (`list[float]`, *optional*):
+ Only used with 'longrope'. The scaling factor to be applied to long contexts (<
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+ size divided by the number of attention heads divided by 2
+ `low_freq_factor` (`float`, *optional*):
+ Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
+ `high_freq_factor` (`float`, *optional*):
+ Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
+ attention_bias (`bool`, *optional*, defaults to `False`):
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+
+ ```python
+ >>> from transformers import ApertusModel, ApertusConfig
+
+ >>> # Initializing an Apertus-8B style configuration
+ >>> configuration = ApertusConfig()
+
+ >>> # Initializing a model from the Apertus-8B style configuration
+ >>> model = ApertusModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "apertus"
+ base_model_tp_plan = {
+ "layers.*.self_attn.q_proj": "colwise_rep", # we need to replicate here due to the added norm on q and k
+ "layers.*.self_attn.k_proj": "colwise_rep", # we need to replicate here due to the added norm on q and k
+ "layers.*.self_attn.v_proj": "colwise_rep", # we need to replicate here due to the added norm on q and k
+ "layers.*.self_attn.o_proj": "rowwise_rep", # we need to replicate here due to the added norm on q and k
+ "layers.*.mlp.up_proj": "colwise",
+ "layers.*.mlp.down_proj": "rowwise",
+ "layers.*.mlp.gate_proj": "colwise",
+ }
+
+ def __init__(
+ self,
+ vocab_size=131072,
+ hidden_size=4096,
+ intermediate_size=14336,
+ num_hidden_layers=32,
+ num_attention_heads=32,
+ num_key_value_heads=None,
+ hidden_act="xielu",
+ max_position_embeddings=65536,
+ initializer_range=0.02,
+ rms_norm_eps=1e-5,
+ use_cache=True,
+ pad_token_id=3,
+ bos_token_id=1,
+ eos_token_id=2,
+ tie_word_embeddings=False,
+ rope_theta=12000000.0,
+ rope_scaling={
+ "rope_type": "llama3",
+ "factor": 8.0,
+ "original_max_position_embeddings": 8192,
+ "low_freq_factor": 1.0,
+ "high_freq_factor": 4.0,
+ },
+ attention_bias=False,
+ attention_dropout=0.0,
+ **kwargs,
+ ):
+ super().__init__(
+ vocab_size=vocab_size,
+ hidden_size=hidden_size,
+ intermediate_size=intermediate_size,
+ num_hidden_layers=num_hidden_layers,
+ num_attention_heads=num_attention_heads,
+ num_key_value_heads=num_key_value_heads,
+ hidden_act=hidden_act,
+ max_position_embeddings=max_position_embeddings,
+ initializer_range=initializer_range,
+ rms_norm_eps=rms_norm_eps,
+ use_cache=use_cache,
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ tie_word_embeddings=tie_word_embeddings,
+ rope_theta=rope_theta,
+ rope_scaling=rope_scaling,
+ attention_bias=attention_bias,
+ attention_dropout=attention_dropout,
+ **kwargs,
+ )
+ del self.pretraining_tp
+ del self.mlp_bias
+ del self.head_dim
+
+
+class ApertusMLP(NemotronMLP):
+ def __init__(self, config):
+ super().__init__()
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
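+ # Unlike the gated Llama MLP, Apertus reuses Nemotron's two-matrix MLP (up_proj -> activation ->
+ # down_proj), here with the xIELU activation configured by `hidden_act` and no bias terms.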
+
+
+class ApertusRMSNorm(LlamaRMSNorm):
+ pass
+
+
+class ApertusRotaryEmbedding(LlamaRotaryEmbedding):
+ pass
+
+
+class ApertusAttention(LlamaAttention):
+ def __init__(self, config: ApertusConfig, layer_idx: Optional[int] = None):
+ super().__init__(config, layer_idx)
+ self.q_norm = ApertusRMSNorm(self.head_dim, config.rms_norm_eps)
+ self.k_norm = ApertusRMSNorm(self.head_dim, config.rms_norm_eps)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
+ attention_mask: Optional[torch.Tensor],
+ past_key_values: Optional[Cache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ input_shape = hidden_states.shape[:-1]
+ hidden_shape = (*input_shape, -1, self.head_dim)
+
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ query_states = self.q_norm(query_states)
+ key_states = self.k_norm(key_states)
+
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_values is not None:
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ scaling=self.scaling,
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+ return attn_output, attn_weights
+
+
+class ApertusDecoderLayer(LlamaDecoderLayer):
+ def __init__(self, config: ApertusConfig, layer_idx: int):
+ super().__init__(config, layer_idx)
+ self.attention_layernorm = ApertusRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.feedforward_layernorm = ApertusRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ del self.input_layernorm
+ del self.post_attention_layernorm
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple[torch.Tensor]:
+ residual = hidden_states
+ hidden_states = self.attention_layernorm(hidden_states)
+ hidden_states, _ = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+ hidden_states = residual + hidden_states
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.feedforward_layernorm(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+ return hidden_states
+
+
+class ApertusPreTrainedModel(LlamaPreTrainedModel):
+ pass
+
+
+class ApertusModel(LlamaModel):
+ pass
+
+
+class ApertusForCausalLM(LlamaForCausalLM):
+ def forward(self, **super_kwargs):
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, ApertusForCausalLM
+
+ >>> model = ApertusForCausalLM.from_pretrained("swiss-ai/Apertus-8B")
+ >>> tokenizer = AutoTokenizer.from_pretrained("swiss-ai/Apertus-8B")
+
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+ ```"""
+ return super().forward(**super_kwargs)
+
+
+class ApertusForTokenClassification(LlamaForTokenClassification):
+ pass
+
+
+__all__ = [
+ "ApertusConfig",
+ "ApertusModel",
+ "ApertusForCausalLM",
+ "ApertusForTokenClassification",
+ "ApertusPreTrainedModel",
+]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/arcee/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/arcee/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1c3df45b2a3b14a72b36362c87833298be8fce48
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/arcee/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2025 Arcee AI and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_arcee import *
+ from .modeling_arcee import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/venv/lib/python3.13/site-packages/transformers/models/arcee/configuration_arcee.py b/venv/lib/python3.13/site-packages/transformers/models/arcee/configuration_arcee.py
new file mode 100644
index 0000000000000000000000000000000000000000..5793697311bdc5906836811387c8a47f9db48dd9
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/arcee/configuration_arcee.py
@@ -0,0 +1,201 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/arcee/modular_arcee.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_arcee.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2025 Arcee AI and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ...configuration_utils import PretrainedConfig
+from ...modeling_rope_utils import rope_config_validation
+
+
+class ArceeConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`ArceeModel`]. It is used to instantiate an Arcee
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+ defaults will yield a similar configuration to that of the AFM-4.5B-Base.
+
+ Pre-trained weights are available at
+ [arcee-ai/AFM-4.5B](https://huggingface.co/arcee-ai/AFM-4.5B)
+ and were used to build the examples below.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 32000):
+ Vocabulary size of the Arcee model. Defines the number of different tokens that can be represented by the
+ `inputs_ids` passed when calling [`ArceeModel`]
+ hidden_size (`int`, *optional*, defaults to 2560):
+ Dimension of the hidden representations.
+ intermediate_size (`int`, *optional*, defaults to 18432):
+ Dimension of the MLP representations.
+ num_hidden_layers (`int`, *optional*, defaults to 32):
+ Number of hidden layers in the Transformer decoder.
+ num_attention_heads (`int`, *optional*, defaults to 32):
+ Number of attention heads for each attention layer in the Transformer decoder.
+ num_key_value_heads (`int`, *optional*):
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+            by meanpooling all the original heads within that group. For more details, check out [this
+            paper](https://huggingface.co/papers/2305.13245). If it is not specified, it will default to
+ `num_attention_heads`.
+ hidden_act (`str` or `function`, *optional*, defaults to `"relu2"`):
+ The non-linear activation function (function or string) in the decoder.
+ max_position_embeddings (`int`, *optional*, defaults to 4096):
+ The maximum sequence length that this model might ever be used with. AFM-4.5B-Base supports up to 16384 tokens.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-05):
+ The epsilon used by the rms normalization layers.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ pad_token_id (`int`, *optional*):
+ Padding token id.
+ bos_token_id (`int`, *optional*, defaults to 128000):
+ Beginning of stream token id.
+ eos_token_id (`int`, *optional*, defaults to 128001):
+ End of stream token id.
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether to tie weight embeddings
+ rope_theta (`float`, *optional*, defaults to 10000.0):
+ The base period of the RoPE embeddings.
+ rope_scaling (`Dict`, *optional*):
+            Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply a new rope type
+            and expect the model to work on a longer `max_position_embeddings`, we recommend updating this value
+            accordingly.
+ Expected contents:
+ `rope_type` (`str`):
+ The sub-variant of RoPE to use. Can be one of ['default', 'yarn'], with 'default' being the original RoPE implementation.
+ `factor` (`float`, *optional*):
+ Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
+ most scaling types, a `factor` of x will enable the model to handle sequences of length x *
+ original maximum pre-trained length.
+ `original_max_position_embeddings` (`int`, *optional*):
+ Used with 'yarn'. The original max position embeddings used during pretraining.
+ `attention_factor` (`float`, *optional*):
+ Used with 'yarn'. The scaling factor to be applied on the attention computation. If unspecified,
+                    it defaults to the value recommended by the implementation, using the `factor` field to infer the suggested value.
+ `beta_fast` (`float`, *optional*):
+ Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
+ ramp function. If unspecified, it defaults to 32.
+ `beta_slow` (`float`, *optional*):
+ Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
+ ramp function. If unspecified, it defaults to 1.
+ attention_bias (`bool`, *optional*, defaults to `False`):
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ mlp_bias (`bool`, *optional*, defaults to `False`):
+ Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers.
+ head_dim (`int`, *optional*):
+ The attention head dimension. If None, it will default to hidden_size // num_attention_heads
+
+ ```python
+ >>> from transformers import ArceeModel, ArceeConfig
+
+ >>> # Initializing an Arcee AFM-4.5B-Base style configuration
+ >>> configuration = ArceeConfig()
+
+ >>> # Initializing a model from the AFM-4.5B-Base style configuration
+ >>> model = ArceeModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "arcee"
+ keys_to_ignore_at_inference = ["past_key_values"]
+ base_model_tp_plan = {
+ "layers.*.self_attn.q_proj": "colwise",
+ "layers.*.self_attn.k_proj": "colwise",
+ "layers.*.self_attn.v_proj": "colwise",
+ "layers.*.self_attn.o_proj": "rowwise",
+ "layers.*.mlp.up_proj": "colwise",
+ "layers.*.mlp.down_proj": "rowwise",
+ }
+ base_model_pp_plan = {
+ "embed_tokens": (["input_ids"], ["inputs_embeds"]),
+ "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
+ "norm": (["hidden_states"], ["hidden_states"]),
+ }
+
+ def __init__(
+ self,
+ vocab_size=32000,
+ hidden_size=2560,
+ intermediate_size=18432,
+ num_hidden_layers=32,
+ num_attention_heads=32,
+ num_key_value_heads=None,
+ hidden_act="relu2",
+ max_position_embeddings=4096,
+ initializer_range=0.02,
+ rms_norm_eps=1e-5,
+ use_cache=True,
+ pad_token_id=None,
+ bos_token_id=128000,
+ eos_token_id=128001,
+ tie_word_embeddings=False,
+ rope_theta=10000.0,
+ rope_scaling=None,
+ attention_bias=False,
+ attention_dropout=0.0,
+ mlp_bias=False,
+ head_dim=None,
+ **kwargs,
+ ):
+ super().__init__(
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ tie_word_embeddings=tie_word_embeddings,
+ **kwargs,
+ )
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+
+ # for backward compatibility
+ if num_key_value_heads is None:
+ num_key_value_heads = num_attention_heads
+
+ self.num_key_value_heads = num_key_value_heads
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.rms_norm_eps = rms_norm_eps
+ self.use_cache = use_cache
+ self.rope_theta = rope_theta
+ self.rope_scaling = rope_scaling
+ self.attention_bias = attention_bias
+ self.attention_dropout = attention_dropout
+ self.mlp_bias = mlp_bias
+ self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads
+ # Validate the correctness of rotary position embeddings parameters
+        # BC: if there is a 'type' field, copy it to 'rope_type'.
+ if self.rope_scaling is not None and "type" in self.rope_scaling:
+ self.rope_scaling["rope_type"] = self.rope_scaling["type"]
+ rope_config_validation(self)
+
+
+__all__ = ["ArceeConfig"]
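+
+
+# --- Illustrative sketch (editorial addition, not part of the generated file) ---
+# The `rope_scaling` dict documented above accepts the 'yarn' rope type; a configuration that
+# stretches the 4096-token pretraining window to the documented 16384-token limit could look
+# like this. The scaling factor here is an assumption for illustration, not a shipped default.
+if __name__ == "__main__":
+    extended_config = ArceeConfig(
+        max_position_embeddings=16384,
+        rope_scaling={
+            "rope_type": "yarn",
+            "factor": 4.0,
+            "original_max_position_embeddings": 4096,
+        },
+    )
+    print(extended_config.rope_scaling)  # validated by rope_config_validation in __init__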
diff --git a/venv/lib/python3.13/site-packages/transformers/models/arcee/modeling_arcee.py b/venv/lib/python3.13/site-packages/transformers/models/arcee/modeling_arcee.py
new file mode 100644
index 0000000000000000000000000000000000000000..7dc4ba885af2f704ae275a51b6f4aef0d7451174
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/arcee/modeling_arcee.py
@@ -0,0 +1,506 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/arcee/modular_arcee.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_arcee.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2025 Arcee AI and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Callable, Optional, Union
+
+import torch
+from torch import nn
+
+from transformers.utils import auto_docstring
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache
+from ...generation import GenerationMixin
+from ...integrations import use_kernel_forward_from_hub
+from ...masking_utils import create_causal_mask
+from ...modeling_layers import (
+ GenericForQuestionAnswering,
+ GenericForSequenceClassification,
+ GenericForTokenClassification,
+ GradientCheckpointingLayer,
+)
+from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import TransformersKwargs, can_return_tuple
+from ...utils.deprecation import deprecate_kwarg
+from ...utils.generic import check_model_inputs
+from .configuration_arcee import ArceeConfig
+
+
+class ArceeMLP(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.intermediate_size = config.intermediate_size
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
+ self.act_fn = ACT2FN[config.hidden_act]
+
+ def forward(self, x):
+ return self.down_proj(self.act_fn(self.up_proj(x)))
+
+
+@use_kernel_forward_from_hub("RMSNorm")
+class ArceeRMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ ArceeRMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+ def extra_repr(self):
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+
+class ArceeRotaryEmbedding(nn.Module):
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
+
+ def __init__(self, config: ArceeConfig, device=None):
+ super().__init__()
+ # BC: "rope_type" was originally "type"
+ if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+ else:
+ self.rope_type = "default"
+ self.max_seq_len_cached = config.max_position_embeddings
+ self.original_max_seq_len = config.max_position_embeddings
+
+ self.config = config
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+ self.original_inv_freq = self.inv_freq
+
+ @torch.no_grad()
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
+ def forward(self, x, position_ids):
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
+ position_ids_expanded = position_ids[:, None, :].float()
+
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos = emb.cos() * self.attention_scaling
+ sin = emb.sin() * self.attention_scaling
+
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+ """Applies Rotary Position Embedding to the query and key tensors.
+
+ Args:
+ q (`torch.Tensor`): The query tensor.
+ k (`torch.Tensor`): The key tensor.
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
+ position_ids (`torch.Tensor`, *optional*):
+ Deprecated and unused.
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+ Returns:
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+ """
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
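+
+# Illustrative shape note (editorial addition): with q and k of shape
+# (batch, num_heads, seq_len, head_dim) and cos/sin of shape (batch, seq_len, head_dim),
+# the default unsqueeze_dim=1 expands cos/sin to (batch, 1, seq_len, head_dim) so they
+# broadcast over the heads dimension.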
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+ """
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
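+
+# Illustrative shape note (editorial addition): with n_rep = 4, key/value states of shape
+# (batch=2, num_key_value_heads=8, seq_len=128, head_dim=64) expand to (2, 32, 128, 64),
+# matching 32 query heads in the matmuls of `eager_attention_forward` below.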
+
+
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: float,
+ dropout: float = 0.0,
+ **kwargs: Unpack[TransformersKwargs],
+):
+ key_states = repeat_kv(key, module.num_key_value_groups)
+ value_states = repeat_kv(value, module.num_key_value_groups)
+
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+ if attention_mask is not None:
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
+class ArceeAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config: ArceeConfig, layer_idx: int):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
+ self.scaling = self.head_dim**-0.5
+ self.attention_dropout = config.attention_dropout
+ self.is_causal = True
+
+ self.q_proj = nn.Linear(
+ config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.k_proj = nn.Linear(
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.v_proj = nn.Linear(
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.o_proj = nn.Linear(
+ config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
+ )
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
+ attention_mask: Optional[torch.Tensor],
+ past_key_values: Optional[Cache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ input_shape = hidden_states.shape[:-1]
+ hidden_shape = (*input_shape, -1, self.head_dim)
+
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_values is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ scaling=self.scaling,
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+ return attn_output, attn_weights
+
+
+class ArceeDecoderLayer(GradientCheckpointingLayer):
+ def __init__(self, config: ArceeConfig, layer_idx: int):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+
+ self.self_attn = ArceeAttention(config=config, layer_idx=layer_idx)
+
+ self.mlp = ArceeMLP(config)
+ self.input_layernorm = ArceeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_attention_layernorm = ArceeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> torch.Tensor:
+ residual = hidden_states
+ hidden_states = self.input_layernorm(hidden_states)
+ # Self Attention
+ hidden_states, _ = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+ hidden_states = residual + hidden_states
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+ return hidden_states
+
+
+@auto_docstring
+class ArceePreTrainedModel(PreTrainedModel):
+ config: ArceeConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["ArceeDecoderLayer"]
+ _skip_keys_device_placement = ["past_key_values"]
+ _supports_flash_attn = True
+ _supports_sdpa = True
+ _supports_flex_attn = True
+
+ _can_compile_fullgraph = True
+ _supports_attention_backend = True
+ _can_record_outputs = {
+ "hidden_states": ArceeDecoderLayer,
+ "attentions": ArceeAttention,
+ }
+
+
+@auto_docstring
+class ArceeModel(ArceePreTrainedModel):
+ def __init__(self, config: ArceeConfig):
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+ self.layers = nn.ModuleList(
+ [ArceeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+ self.norm = ArceeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.rotary_emb = ArceeRotaryEmbedding(config=config)
+ self.gradient_checkpointing = False
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @check_model_inputs()
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> BaseModelOutputWithPast:
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if inputs_embeds is None:
+ inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)
+
+ if use_cache and past_key_values is None:
+ past_key_values = DynamicCache(config=self.config)
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position: torch.Tensor = torch.arange(
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+ )
+
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0)
+
+ causal_mask = create_causal_mask(
+ config=self.config,
+ input_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ cache_position=cache_position,
+ past_key_values=past_key_values,
+ position_ids=position_ids,
+ )
+
+ hidden_states = inputs_embeds
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+ for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+ hidden_states = decoder_layer(
+ hidden_states,
+ attention_mask=causal_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+
+ hidden_states = self.norm(hidden_states)
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values,
+ )
+
+
+@auto_docstring(checkpoint="arcee-ai/AFM-4.5B")
+class ArceeForCausalLM(ArceePreTrainedModel, GenerationMixin):
+ _tied_weights_keys = ["lm_head.weight"]
+ _tp_plan = {"lm_head": "colwise_rep"}
+ _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = ArceeModel(config)
+ self.vocab_size = config.vocab_size
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> CausalLMOutputWithPast:
+ r"""
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, ArceeForCausalLM
+
+        >>> model = ArceeForCausalLM.from_pretrained("arcee-ai/AFM-4.5B")
+        >>> tokenizer = AutoTokenizer.from_pretrained("arcee-ai/AFM-4.5B")
+
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+ ```"""
+ outputs: BaseModelOutputWithPast = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ hidden_states = outputs.last_hidden_state
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@auto_docstring(checkpoint="arcee-ai/AFM-4.5B")
+class ArceeForSequenceClassification(GenericForSequenceClassification, ArceePreTrainedModel):
+ pass
+
+
+@auto_docstring(checkpoint="arcee-ai/AFM-4.5B")
+class ArceeForQuestionAnswering(GenericForQuestionAnswering, ArceePreTrainedModel):
+ base_model_prefix = "transformer" # For BC, where `transformer` was used instead of `model`
+
+
+@auto_docstring(checkpoint="arcee-ai/AFM-4.5B")
+class ArceeForTokenClassification(GenericForTokenClassification, ArceePreTrainedModel):
+ pass
+
+
+__all__ = [
+ "ArceeForCausalLM",
+ "ArceeForQuestionAnswering",
+ "ArceeForSequenceClassification",
+ "ArceeForTokenClassification",
+ "ArceeModel",
+ "ArceePreTrainedModel",
+]
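+
+
+# --- Illustrative sketch (editorial addition, not part of the generated file) ---
+# `logits_to_keep` restricts the lm_head projection to the last N positions, which is all that
+# greedy decoding needs; passing 1 here therefore yields logits only for the final token.
+# Running this downloads the arcee-ai/AFM-4.5B weights referenced in the docstrings above.
+if __name__ == "__main__":
+    from transformers import AutoTokenizer
+
+    tokenizer = AutoTokenizer.from_pretrained("arcee-ai/AFM-4.5B")
+    model = ArceeForCausalLM.from_pretrained("arcee-ai/AFM-4.5B")
+    inputs = tokenizer("The capital of France is", return_tensors="pt")
+    outputs = model(**inputs, logits_to_keep=1)  # logits only for the final position
+    print(outputs.logits.shape)  # (batch_size, 1, vocab_size)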
diff --git a/venv/lib/python3.13/site-packages/transformers/models/arcee/modular_arcee.py b/venv/lib/python3.13/site-packages/transformers/models/arcee/modular_arcee.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a35ee8a1373e7df25ce0b2dea73856fd95ec3ec
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/arcee/modular_arcee.py
@@ -0,0 +1,225 @@
+# coding=utf-8
+# Copyright 2025 Arcee AI and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch Arcee model."""
+
+from transformers.utils import auto_docstring, logging
+
+from ..llama.configuration_llama import LlamaConfig
+from ..llama.modeling_llama import (
+ LlamaForCausalLM,
+ LlamaForQuestionAnswering,
+ LlamaForSequenceClassification,
+ LlamaForTokenClassification,
+)
+from ..nemotron.modeling_nemotron import NemotronMLP
+
+
+logger = logging.get_logger(__name__)
+
+
+class ArceeConfig(LlamaConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`ArceeModel`]. It is used to instantiate an Arcee
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+ defaults will yield a similar configuration to that of the AFM-4.5B-Base.
+
+ Pre-trained weights are available at
+ [arcee-ai/AFM-4.5B](https://huggingface.co/arcee-ai/AFM-4.5B)
+ and were used to build the examples below.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 32000):
+ Vocabulary size of the Arcee model. Defines the number of different tokens that can be represented by the
+ `inputs_ids` passed when calling [`ArceeModel`]
+ hidden_size (`int`, *optional*, defaults to 2560):
+ Dimension of the hidden representations.
+ intermediate_size (`int`, *optional*, defaults to 18432):
+ Dimension of the MLP representations.
+ num_hidden_layers (`int`, *optional*, defaults to 32):
+ Number of hidden layers in the Transformer decoder.
+ num_attention_heads (`int`, *optional*, defaults to 32):
+ Number of attention heads for each attention layer in the Transformer decoder.
+ num_key_value_heads (`int`, *optional*):
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+            by meanpooling all the original heads within that group. For more details, check out [this
+            paper](https://huggingface.co/papers/2305.13245). If it is not specified, it will default to
+ `num_attention_heads`.
+ hidden_act (`str` or `function`, *optional*, defaults to `"relu2"`):
+ The non-linear activation function (function or string) in the decoder.
+ max_position_embeddings (`int`, *optional*, defaults to 4096):
+ The maximum sequence length that this model might ever be used with. AFM-4.5B-Base supports up to 16384 tokens.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-05):
+ The epsilon used by the rms normalization layers.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ pad_token_id (`int`, *optional*):
+ Padding token id.
+ bos_token_id (`int`, *optional*, defaults to 128000):
+ Beginning of stream token id.
+ eos_token_id (`int`, *optional*, defaults to 128001):
+ End of stream token id.
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether to tie weight embeddings
+ rope_theta (`float`, *optional*, defaults to 10000.0):
+ The base period of the RoPE embeddings.
+ rope_scaling (`Dict`, *optional*):
+            Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply a new rope type
+            and expect the model to work on a longer `max_position_embeddings`, we recommend updating this value
+            accordingly.
+ Expected contents:
+ `rope_type` (`str`):
+ The sub-variant of RoPE to use. Can be one of ['default', 'yarn'], with 'default' being the original RoPE implementation.
+ `factor` (`float`, *optional*):
+ Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
+ most scaling types, a `factor` of x will enable the model to handle sequences of length x *
+ original maximum pre-trained length.
+ `original_max_position_embeddings` (`int`, *optional*):
+ Used with 'yarn'. The original max position embeddings used during pretraining.
+ `attention_factor` (`float`, *optional*):
+ Used with 'yarn'. The scaling factor to be applied on the attention computation. If unspecified,
+                    it defaults to the value recommended by the implementation, using the `factor` field to infer the suggested value.
+ `beta_fast` (`float`, *optional*):
+ Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
+ ramp function. If unspecified, it defaults to 32.
+ `beta_slow` (`float`, *optional*):
+ Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
+ ramp function. If unspecified, it defaults to 1.
+ attention_bias (`bool`, *optional*, defaults to `False`):
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ mlp_bias (`bool`, *optional*, defaults to `False`):
+ Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers.
+ head_dim (`int`, *optional*):
+ The attention head dimension. If None, it will default to hidden_size // num_attention_heads
+
+ ```python
+ >>> from transformers import ArceeModel, ArceeConfig
+
+ >>> # Initializing an Arcee AFM-4.5B-Base style configuration
+ >>> configuration = ArceeConfig()
+
+ >>> # Initializing a model from the AFM-4.5B-Base style configuration
+ >>> model = ArceeModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "arcee"
+ base_model_tp_plan = {
+ "layers.*.self_attn.q_proj": "colwise",
+ "layers.*.self_attn.k_proj": "colwise",
+ "layers.*.self_attn.v_proj": "colwise",
+ "layers.*.self_attn.o_proj": "rowwise",
+ "layers.*.mlp.up_proj": "colwise",
+ "layers.*.mlp.down_proj": "rowwise",
+ }
+
+ def __init__(
+ self,
+ vocab_size=32000,
+ hidden_size=2560,
+ intermediate_size=18432,
+ num_hidden_layers=32,
+ num_attention_heads=32,
+ num_key_value_heads=None,
+ hidden_act="relu2",
+ max_position_embeddings=4096,
+ initializer_range=0.02,
+ rms_norm_eps=1e-5,
+ use_cache=True,
+ pad_token_id=None,
+ bos_token_id=128000,
+ eos_token_id=128001,
+ tie_word_embeddings=False,
+ rope_theta=10000.0,
+ rope_scaling=None,
+ attention_bias=False,
+ attention_dropout=0.0,
+ mlp_bias=False,
+ head_dim=None,
+ **kwargs,
+ ):
+ super().__init__(
+ vocab_size=vocab_size,
+ hidden_size=hidden_size,
+ intermediate_size=intermediate_size,
+ num_hidden_layers=num_hidden_layers,
+ num_attention_heads=num_attention_heads,
+ num_key_value_heads=num_key_value_heads,
+ hidden_act=hidden_act,
+ max_position_embeddings=max_position_embeddings,
+ initializer_range=initializer_range,
+ rms_norm_eps=rms_norm_eps,
+ use_cache=use_cache,
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ tie_word_embeddings=tie_word_embeddings,
+ rope_theta=rope_theta,
+ rope_scaling=rope_scaling,
+ attention_bias=attention_bias,
+ attention_dropout=attention_dropout,
+ mlp_bias=mlp_bias,
+ head_dim=head_dim,
+ **kwargs,
+ )
+
+ del self.pretraining_tp
+
+
+class ArceeMLP(NemotronMLP):
+ pass
+
+
+@auto_docstring(checkpoint="arcee-ai/AFM-4.5B")
+class ArceeForCausalLM(LlamaForCausalLM):
+ pass
+
+
+@auto_docstring(checkpoint="arcee-ai/AFM-4.5B")
+class ArceeForSequenceClassification(LlamaForSequenceClassification):
+ pass
+
+
+@auto_docstring(checkpoint="arcee-ai/AFM-4.5B")
+class ArceeForQuestionAnswering(LlamaForQuestionAnswering):
+ pass
+
+
+@auto_docstring(checkpoint="arcee-ai/AFM-4.5B")
+class ArceeForTokenClassification(LlamaForTokenClassification):
+ pass
+
+
+__all__ = [
+ "ArceeConfig",
+ "ArceeForCausalLM",
+ "ArceeForQuestionAnswering",
+ "ArceeForSequenceClassification",
+ "ArceeForTokenClassification",
+ "ArceeModel", # noqa: F822
+ "ArceePreTrainedModel", # noqa: F822
+]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/aria/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/aria/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f73301321527c185cfab149b171a38f5fd4f7852
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/aria/__init__.py
@@ -0,0 +1,30 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_aria import *
+ from .image_processing_aria import *
+ from .modeling_aria import *
+ from .processing_aria import *
+
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/venv/lib/python3.13/site-packages/transformers/models/aria/configuration_aria.py b/venv/lib/python3.13/site-packages/transformers/models/aria/configuration_aria.py
new file mode 100644
index 0000000000000000000000000000000000000000..67f023e1dbf4903d6815f6bdf8abbdaeea2239a4
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/aria/configuration_aria.py
@@ -0,0 +1,307 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/aria/modular_aria.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_aria.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2024 The Rhymes-AI Teams Authors and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Optional
+
+from ...configuration_utils import PretrainedConfig
+from ...modeling_rope_utils import rope_config_validation
+from ..auto import CONFIG_MAPPING, AutoConfig
+
+
+class AriaTextConfig(PretrainedConfig):
+ r"""
+ This class handles the configuration for the text component of the Aria model.
+    Instantiating a configuration with the defaults will yield a similar configuration to that of the Aria
+    [rhymes-ai/Aria](https://huggingface.co/rhymes-ai/Aria) architecture.
+ This class extends the LlamaConfig to include additional parameters specific to the Mixture of Experts (MoE) architecture.
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 32000):
+ Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the
+ `inputs_ids` passed when calling [`LlamaModel`]
+ hidden_size (`int`, *optional*, defaults to 4096):
+ Dimension of the hidden representations.
+ intermediate_size (`int`, *optional*, defaults to 4096):
+ The size of the MLP representations.
+ num_hidden_layers (`int`, *optional*, defaults to 32):
+ Number of hidden layers in the Transformer decoder.
+ num_attention_heads (`int`, *optional*, defaults to 32):
+ Number of attention heads for each attention layer in the Transformer decoder.
+ num_key_value_heads (`int`, *optional*):
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+ by meanpooling all the original heads within that group. For more details, check out [this
+ paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
+ `num_attention_heads`.
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+ The non-linear activation function (function or string) in the decoder.
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
+ The maximum sequence length that this model might ever be used with. Llama 1 supports up to 2048 tokens,
+ Llama 2 up to 4096, CodeLlama up to 16384.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
+ The epsilon used by the rms normalization layers.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ pad_token_id (`int`, *optional*, defaults to 2):
+ Padding token id.
+ bos_token_id (`int`, *optional*, defaults to 1):
+ Beginning of stream token id.
+ eos_token_id (`int`, *optional*, defaults to 2):
+ End of stream token id.
+ pretraining_tp (`int`, *optional*, defaults to 1):
+ Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
+ document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) to
+ understand more about it. This value is necessary to ensure exact reproducibility of the pretraining
+ results. Please refer to [this issue](https://github.com/pytorch/pytorch/issues/76232).
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether to tie weight embeddings
+ rope_theta (`float`, *optional*, defaults to 10000.0):
+ The base period of the RoPE embeddings.
+ rope_scaling (`Dict`, *optional*):
+            Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply a new rope type
+            and expect the model to work on a longer `max_position_embeddings`, we recommend updating this value
+            accordingly.
+ Expected contents:
+ `rope_type` (`str`):
+ The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
+ 'llama3'], with 'default' being the original RoPE implementation.
+ `factor` (`float`, *optional*):
+ Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
+ most scaling types, a `factor` of x will enable the model to handle sequences of length x *
+ original maximum pre-trained length.
+ `original_max_position_embeddings` (`int`, *optional*):
+ Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
+ pretraining.
+ `attention_factor` (`float`, *optional*):
+ Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
+                    computation. If unspecified, it defaults to the value recommended by the implementation, using the
+ `factor` field to infer the suggested value.
+ `beta_fast` (`float`, *optional*):
+ Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
+ ramp function. If unspecified, it defaults to 32.
+ `beta_slow` (`float`, *optional*):
+ Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
+ ramp function. If unspecified, it defaults to 1.
+ `short_factor` (`list[float]`, *optional*):
+ Only used with 'longrope'. The scaling factor to be applied to short contexts (<
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+ size divided by the number of attention heads divided by 2
+ `long_factor` (`list[float]`, *optional*):
+                    Only used with 'longrope'. The scaling factor to be applied to long contexts (>
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+ size divided by the number of attention heads divided by 2
+ `low_freq_factor` (`float`, *optional*):
+ Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
+ `high_freq_factor` (`float`, *optional*):
+ Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
+ attention_bias (`bool`, *optional*, defaults to `False`):
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ mlp_bias (`bool`, *optional*, defaults to `False`):
+ Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers.
+ head_dim (`int`, *optional*):
+ The attention head dimension. If None, it will default to hidden_size // num_heads
+ moe_num_experts (`int`, *optional*, defaults to 8):
+ The number of experts in the MoE layer.
+ moe_topk (`int`, *optional*, defaults to 2):
+ The number of top experts to route to for each token.
+ moe_num_shared_experts (`int`, *optional*, defaults to 2):
+ The number of shared experts.
+ """
+
+ model_type = "aria_text"
+ keys_to_ignore_at_inference = ["past_key_values"]
+ # Default tensor parallel plan for base model `AriaTextModel`
+ base_model_tp_plan = {
+ "layers.*.self_attn.q_proj": "colwise",
+ "layers.*.self_attn.k_proj": "colwise",
+ "layers.*.self_attn.v_proj": "colwise",
+ "layers.*.self_attn.o_proj": "rowwise",
+ "layers.*.mlp.gate_proj": "colwise",
+ "layers.*.mlp.up_proj": "colwise",
+ "layers.*.mlp.down_proj": "rowwise",
+ }
+ base_model_pp_plan = {
+ "embed_tokens": (["input_ids"], ["inputs_embeds"]),
+ "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
+ "norm": (["hidden_states"], ["hidden_states"]),
+ }
+ base_config_key = "text_config"
+
+ def __init__(
+ self,
+ vocab_size=32000,
+ hidden_size=4096,
+ intermediate_size: int = 4096,
+ num_hidden_layers=32,
+ num_attention_heads=32,
+ num_key_value_heads=None,
+ hidden_act="silu",
+ max_position_embeddings=2048,
+ initializer_range=0.02,
+ rms_norm_eps=1e-6,
+ use_cache=True,
+ pad_token_id=2,
+ bos_token_id=1,
+ eos_token_id=2,
+ pretraining_tp=1,
+ tie_word_embeddings=False,
+ rope_theta=10000.0,
+ rope_scaling=None,
+ attention_bias=False,
+ attention_dropout=0.0,
+ mlp_bias=False,
+ head_dim=None,
+ moe_num_experts: int = 8,
+ moe_topk: int = 2,
+ moe_num_shared_experts: int = 2,
+ **kwargs,
+ ):
+ super().__init__(
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ tie_word_embeddings=tie_word_embeddings,
+ **kwargs,
+ )
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+
+ # for backward compatibility
+ if num_key_value_heads is None:
+ num_key_value_heads = num_attention_heads
+
+ self.num_key_value_heads = num_key_value_heads
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.rms_norm_eps = rms_norm_eps
+ self.pretraining_tp = pretraining_tp
+ self.use_cache = use_cache
+ self.rope_theta = rope_theta
+ self.rope_scaling = rope_scaling
+ self.attention_bias = attention_bias
+ self.attention_dropout = attention_dropout
+ self.mlp_bias = mlp_bias
+ self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads
+ # Validate the correctness of rotary position embeddings parameters
+        # BC: if there is a 'type' field, copy it to 'rope_type'.
+ if self.rope_scaling is not None and "type" in self.rope_scaling:
+ self.rope_scaling["rope_type"] = self.rope_scaling["type"]
+ rope_config_validation(self)
+ self.moe_num_experts = moe_num_experts
+ self.moe_topk = moe_topk
+ self.moe_num_shared_experts = moe_num_shared_experts
+
+
+class AriaConfig(PretrainedConfig):
+ r"""
+ This class handles the configuration for both vision and text components of the Aria model,
+ as well as additional parameters for image token handling and projector mapping.
+    Instantiating a configuration with the defaults will yield a similar configuration to that of the Aria
+    [rhymes-ai/Aria](https://huggingface.co/rhymes-ai/Aria) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ vision_config (`AriaVisionConfig` or `dict`, *optional*):
+ Configuration for the vision component.
+ vision_feature_layer (`int`, *optional*, defaults to -1):
+ The index of the layer to select the vision feature.
+ text_config (`AriaTextConfig` or `dict`, *optional*):
+ Configuration for the text component.
+ projector_patch_to_query_dict (`dict`, *optional*):
+ Mapping of patch sizes to query dimensions.
+ image_token_index (`int`, *optional*, defaults to 9):
+ Index used to represent image tokens.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated normal initializer for initializing all weight matrices.
+
+ Attributes:
+ model_type (`str`):
+ Type of the model, set to `"aria"`.
+ image_token_index (`int`):
+ Index used to represent image tokens.
+ projector_patch_to_query_dict (`dict`):
+ Mapping of patch sizes to query dimensions.
+ vision_config (`AriaVisionConfig`):
+ Configuration for the vision component.
+ text_config (`AriaTextConfig`):
+ Configuration for the text component.
+ """
+
+ model_type = "aria"
+ attribute_map = {
+ "image_token_id": "image_token_index",
+ }
+ sub_configs = {"text_config": AriaTextConfig, "vision_config": AutoConfig}
+
+ def __init__(
+ self,
+ vision_config=None,
+ vision_feature_layer: int = -1,
+        text_config: Optional[AriaTextConfig] = None,
+ projector_patch_to_query_dict: Optional[dict] = None,
+ image_token_index: int = 9,
+ initializer_range: float = 0.02,
+ **kwargs,
+ ):
+ self.image_token_index = image_token_index
+
+ # Convert the keys and values of projector_patch_to_query_dict to integers
+ # This ensures consistency even if they were provided as strings
+ if projector_patch_to_query_dict is None:
+ projector_patch_to_query_dict = {
+ 1225: 128,
+ 4900: 256,
+ }
+ self.projector_patch_to_query_dict = {int(k): int(v) for k, v in projector_patch_to_query_dict.items()}
+ self.max_value_projector_patch_to_query_dict = max(self.projector_patch_to_query_dict.values())
+ self.vision_feature_layer = vision_feature_layer
+ if isinstance(vision_config, dict):
+ vision_config["model_type"] = "idefics3_vision"
+ vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config)
+ elif vision_config is None:
+ vision_config = CONFIG_MAPPING["idefics3_vision"]()
+
+ self.vision_config = vision_config
+ self.initializer_range = initializer_range
+
+ if isinstance(text_config, dict) and "model_type" in text_config:
+ text_config = AriaTextConfig(**text_config)
+ elif text_config is None:
+ text_config = AriaTextConfig()
+
+ self.text_config = text_config
+
+ super().__init__(**kwargs)
+
+
+__all__ = ["AriaConfig", "AriaTextConfig"]
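+
+
+# --- Illustrative sketch (editorial addition, not part of the generated file) ---
+# Sub-configs accept plain dicts, and the keys/values of `projector_patch_to_query_dict` are
+# coerced to integers as documented above, so string keys coming from a JSON config are safe.
+if __name__ == "__main__":
+    config = AriaConfig(projector_patch_to_query_dict={"1225": 128, "4900": 256})
+    print(type(config.text_config).__name__)  # AriaTextConfig, built from defaults
+    print(config.max_value_projector_patch_to_query_dict)  # 256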
diff --git a/venv/lib/python3.13/site-packages/transformers/models/aria/image_processing_aria.py b/venv/lib/python3.13/site-packages/transformers/models/aria/image_processing_aria.py
new file mode 100644
index 0000000000000000000000000000000000000000..659ed5f112d8d4374973ab13ed8f9248c7b10d69
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/aria/image_processing_aria.py
@@ -0,0 +1,527 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/aria/modular_aria.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_aria.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2024 The Rhymes-AI Teams Authors and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from collections.abc import Iterable
+from typing import Optional, Union
+
+import numpy as np
+
+from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_patch_output_size, select_best_resolution
+from ...image_transforms import PaddingMode, convert_to_rgb, pad, resize, to_channel_dimension_format
+from ...image_utils import (
+ ChannelDimension,
+ ImageInput,
+ PILImageResampling,
+ get_image_size,
+ infer_channel_dimension_format,
+ is_scaled_image,
+ make_flat_list_of_images,
+ to_numpy_array,
+ valid_images,
+ validate_preprocess_arguments,
+)
+from ...utils import TensorType, logging
+
+
+logger = logging.get_logger(__name__)
+
+
+def divide_to_patches(image: np.ndarray, patch_size: int, input_data_format) -> list[np.ndarray]:
+ """
+ Divides an image into patches of a specified size.
+
+ Args:
+ image (`np.ndarray`):
+ The input image.
+ patch_size (`int`):
+ The size of each patch.
+ input_data_format (`ChannelDimension` or `str`):
+ The channel dimension format of the input image.
+
+ Returns:
+ list: A list of np.ndarray representing the patches.
+ """
+ patches = []
+ height, width = get_image_size(image, channel_dim=input_data_format)
+ for i in range(0, height, patch_size):
+ for j in range(0, width, patch_size):
+ if input_data_format == ChannelDimension.LAST:
+ patch = image[i : i + patch_size, j : j + patch_size]
+ else:
+ patch = image[:, i : i + patch_size, j : j + patch_size]
+ patches.append(patch)
+
+ return patches
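+
+# Illustrative note (editorial addition): for a channels-last image of shape (980, 490, 3)
+# and patch_size=490, the loops above yield two patches of shape (490, 490, 3), scanned row by row.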
+
+
+class AriaImageProcessor(BaseImageProcessor):
+ """
+ A vision processor for the Aria model that handles image preprocessing.
+ Initialize the AriaImageProcessor.
+
+ Args:
+ image_mean (`list`, *optional*, defaults to [0.5, 0.5, 0.5]):
+ Mean values for normalization.
+ image_std (`list`, *optional*, defaults to [0.5, 0.5, 0.5]):
+ Standard deviation values for normalization.
+ max_image_size (`int`, *optional*, defaults to 980):
+ Maximum image size.
+ min_image_size (`int`, *optional*, defaults to 336):
+ Minimum image size.
+        split_resolutions (`list`, *optional*, defaults to a list of optimal resolutions as tuples):
+ The optimal resolutions for splitting the image.
+ split_image (`bool`, *optional*, defaults to `False`):
+ Whether to split the image.
+ do_convert_rgb (`bool`, *optional*, defaults to `True`):
+ Whether to convert the image to RGB.
+ do_rescale (`bool`, *optional*, defaults to `True`):
+ Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
+ the `preprocess` method.
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
+ Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
+ method.
+ do_normalize (`bool`, *optional*, defaults to `True`):
+ Whether to normalize the image.
+ resample (PILImageResampling, *optional*, defaults to `BICUBIC`):
+ The resampling filter to use if resizing the image.
+ """
+
+ model_input_names = ["pixel_values", "pixel_mask", "num_crops"]
+
+ def __init__(
+ self,
+ image_mean: Optional[list[float]] = None,
+ image_std: Optional[list[float]] = None,
+ max_image_size: int = 980,
+ min_image_size: int = 336,
+ split_resolutions: Optional[list[tuple[int, int]]] = None,
+ split_image: Optional[bool] = False,
+ do_convert_rgb: Optional[bool] = True,
+ do_rescale: bool = True,
+ rescale_factor: Union[int, float] = 1 / 255,
+ do_normalize: Optional[bool] = True,
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ if image_mean is None:
+ image_mean = [0.5, 0.5, 0.5]
+ if image_std is None:
+ image_std = [0.5, 0.5, 0.5]
+ self.max_image_size = max_image_size
+ self.min_image_size = min_image_size
+ self.image_mean = image_mean
+ self.image_std = image_std
+ self.split_image = split_image
+ if split_resolutions is None:
+ split_resolutions = [(1, 2), (1, 3), (1, 4), (1, 5), (1, 6), (1, 7), (1, 8), (2, 4), (2, 3), (2, 2), (2, 1), (3, 1), (3, 2), (4, 1), (4, 2), (5, 1), (6, 1), (7, 1), (8, 1)] # fmt: skip
+ split_resolutions = [(el[0] * 490, el[1] * 490) for el in split_resolutions]
+ self.split_resolutions = split_resolutions
+ self.do_convert_rgb = do_convert_rgb
+ self.do_rescale = do_rescale
+ self.rescale_factor = rescale_factor
+ self.do_normalize = do_normalize
+ self.resample = resample
+
+ def preprocess(
+ self,
+ images: Union[ImageInput, list[ImageInput]],
+ image_mean: Optional[Union[float, list[float]]] = None,
+ image_std: Optional[Union[float, list[float]]] = None,
+ max_image_size: Optional[int] = None,
+ min_image_size: Optional[int] = None,
+ split_image: Optional[bool] = None,
+ do_convert_rgb: Optional[bool] = None,
+ do_rescale: Optional[bool] = None,
+ rescale_factor: Optional[float] = None,
+ do_normalize: Optional[bool] = None,
+ resample: Optional[PILImageResampling] = None,
+ return_tensors: Optional[Union[str, TensorType]] = "pt",
+ data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ ):
+ """
+ Process a list of images.
+
+ Args:
+ images (ImageInput or list of ImageInput):
+ The input image or a list of images.
+ image_mean (`list`, *optional*, defaults to [0.5, 0.5, 0.5]):
+ Mean values for normalization.
+ image_std (`list`, *optional*, defaults to [0.5, 0.5, 0.5]):
+ Standard deviation values for normalization.
+ max_image_size (`int`, *optional*, defaults to `self.max_image_size` (980)):
+ Maximum image size.
+ min_image_size (`int`, *optional*, defaults to `self.min_image_size` (336)):
+ Minimum image size.
+ split_image (`bool`, *optional*, defaults to `self.split_image` (False)):
+ Whether to split the image.
+ do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb` (True)):
+ Whether to convert the image to RGB.
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+ Whether to rescale the image.
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize` (True)):
+ Whether to normalize the image.
+ resample (PILImageResampling, *optional*, defaults to `self.resample` (BICUBIC)):
+ The resampling filter to use if resizing the image.
+ return_tensors (`str` or `TensorType`, *optional*, defaults to "pt"):
+ The type of tensor to return.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format for the output image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`:
+ image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`:
+ image in (height, width, num_channels) format.
+ If unset, will use same as the input image.
+ input_data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format for the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`:
+ image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`:
+ image in (height, width, num_channels) format.
+ If unset, will use the inferred format of the input image.
+
+ Returns:
+ BatchFeature:
+ A BatchFeature object containing:
+ - 'pixel_values':
+ Tensor of processed image pixel values.
+ - 'pixel_mask':
+ Boolean pixel mask. This mask is a 2D tensor of shape (max_image_size, max_image_size) where:
+ - True (1) values indicate pixels that belong to the original resized image.
+ - False (0) values indicate pixels that are part of the padding.
+ The mask helps distinguish between actual image content and padded areas in subsequent processing steps.
+ - 'num_crops':
+ The maximum number of crops across all images.
+ """
+ image_mean = image_mean if image_mean is not None else self.image_mean
+ image_std = image_std if image_std is not None else self.image_std
+ max_image_size = max_image_size if max_image_size is not None else self.max_image_size
+ min_image_size = min_image_size if min_image_size is not None else self.min_image_size
+ split_image = split_image if split_image is not None else self.split_image
+ do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
+ resample = resample if resample is not None else self.resample
+
+ if max_image_size not in [490, 980]:
+ raise ValueError("max_image_size must be either 490 or 980")
+
+ images = self.fetch_images(images)
+ images = make_flat_list_of_images(images)
+
+ if not valid_images(images):
+ raise ValueError(
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
+ "torch.Tensor, tf.Tensor or jax.ndarray."
+ )
+
+ validate_preprocess_arguments(
+ do_normalize=do_normalize,
+ image_mean=image_mean,
+ image_std=image_std,
+ resample=resample,
+ do_rescale=do_rescale,
+ rescale_factor=rescale_factor,
+ )
+
+ if do_convert_rgb:
+ images = [convert_to_rgb(image) for image in images]
+
+ # All transformations expect numpy arrays.
+ images = [to_numpy_array(image) for image in images]
+
+ if do_rescale and is_scaled_image(images[0]):
+ logger.warning_once(
+ "It looks like you are trying to rescale already rescaled images. If the input"
+ " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
+ )
+
+ if input_data_format is None:
+ # We assume that all images have the same channel dimension format.
+ input_data_format = infer_channel_dimension_format(images[0])
+
+ pixel_values = []
+ pixel_masks = []
+ num_crops = None
+
+ for image in images:
+ if split_image:
+ crop_images = self.get_image_patches(
+ image,
+ self.split_resolutions,
+ max_image_size,
+ resample,
+ data_format=input_data_format,
+ input_data_format=input_data_format,
+ )
+ else:
+ crop_images = [image]
+ if num_crops is None or len(crop_images) > num_crops:
+ num_crops = len(crop_images)
+
+ for crop_image in crop_images:
+ # At this point the scale is the rescaling factor that would bring the image to max_size in its larger dimension
+                h, w = get_image_size(crop_image, channel_dim=input_data_format)
+ scale = max_image_size / max(h, w)
+ if w >= h:
+ new_size = (max(int(h * scale), min_image_size), max_image_size) # h, w
+ else:
+ new_size = (max_image_size, max(int(w * scale), min_image_size)) # h, w
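+                # The longer side is scaled to exactly max_image_size; the shorter side keeps the
+                # aspect ratio but is clamped to at least min_image_size before padding below.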
+
+ crop_image_resized = resize(
+ crop_image,
+ new_size,
+ resample=resample,
+ data_format=input_data_format,
+ input_data_format=input_data_format,
+ )
+
+ padding_bottom, padding_right = max_image_size - new_size[0], max_image_size - new_size[1]
+ crop_image_padded = pad(
+ crop_image_resized,
+ ((0, padding_bottom), (0, padding_right)),
+ data_format=input_data_format,
+ input_data_format=input_data_format,
+ )
+
+ # Create a pixel mask
+ pixel_mask = np.zeros((max_image_size, max_image_size), dtype=bool)
+ pixel_mask[: new_size[0], : new_size[1]] = 1
+ pixel_masks.append(pixel_mask)
+
+ if do_rescale:
+ crop_image_padded = self.rescale(
+ image=crop_image_padded, scale=rescale_factor, input_data_format=input_data_format
+ )
+
+ if do_normalize:
+ crop_image_padded = self.normalize(
+ crop_image_padded,
+                        image_mean,
+                        image_std,
+ data_format=input_data_format,
+ input_data_format=input_data_format,
+ )
+ crop_image_padded = (
+ to_channel_dimension_format(crop_image_padded, data_format, input_data_format)
+ if data_format is not None
+ else crop_image_padded
+ )
+
+ pixel_values.append(crop_image_padded)
+ return BatchFeature(
+ data={
+ "pixel_values": np.stack(pixel_values, axis=0),
+ "pixel_mask": np.stack(pixel_masks, axis=0),
+ "num_crops": num_crops,
+ },
+ tensor_type=return_tensors,
+ )
+
+ def _resize_for_patching(
+ self, image: np.ndarray, target_resolution: tuple, resample, input_data_format: ChannelDimension
+ ) -> np.ndarray:
+ """
+ Resizes an image to a target resolution while maintaining aspect ratio.
+
+ Args:
+ image (np.ndarray):
+ The input image.
+ target_resolution (tuple):
+ The target resolution (height, width) of the image.
+ resample (`PILImageResampling`):
+ Resampling filter to use if resizing the image.
+ input_data_format (`ChannelDimension` or `str`):
+ The channel dimension format of the input image.
+
+ Returns:
+ np.ndarray: The resized and padded image.
+ """
+ new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)
+
+ # Resize the image
+ resized_image = resize(image, (new_height, new_width), resample=resample, input_data_format=input_data_format)
+
+ return resized_image
+
+ def _get_padding_size(self, original_resolution: tuple, target_resolution: tuple):
+ original_height, original_width = original_resolution
+ target_height, target_width = target_resolution
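+        # Distribute the required padding evenly over both sides of each axis; any odd remainder
+        # goes to the bottom/right edge, e.g. a 5-pixel gap becomes (2, 3).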
+ paste_x, r_x = divmod(target_width - original_width, 2)
+ paste_y, r_y = divmod(target_height - original_height, 2)
+ return (paste_y, paste_y + r_y), (paste_x, paste_x + r_x)
+
+ def _pad_for_patching(
+ self, image: np.ndarray, target_resolution: tuple, input_data_format: ChannelDimension
+ ) -> np.ndarray:
+ """
+ Pad an image to a target resolution while maintaining aspect ratio.
+ """
+ new_resolution = get_patch_output_size(image, target_resolution, input_data_format)
+ padding = self._get_padding_size(new_resolution, target_resolution)
+
+ padded_image = self.pad(image, padding=padding)
+
+ return padded_image
+
+ def pad(
+ self,
+ image: np.ndarray,
+ padding: Union[int, tuple[int, int], Iterable[tuple[int, int]]],
+ mode: PaddingMode = PaddingMode.CONSTANT,
+ constant_values: Union[float, Iterable[float]] = 0.0,
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ ) -> np.ndarray:
+ """
+        Pads the `image` with the specified `padding` and `mode`. Padding can be in the (`height`, `width`)
+        dimension or in the (`num_patches`) dimension. In the latter case an iterable of tuples is expected
+        as input.
+
+ Args:
+ image (`np.ndarray`):
+ The image to pad.
+ padding (`int` or `tuple[int, int]` or `Iterable[tuple[int, int]]`):
+ Padding to apply to the edges of the height, width axes. Can be one of three formats:
+ - `((before_height, after_height), (before_width, after_width))` unique pad widths for each axis.
+ - `((before, after),)` yields same before and after pad for height and width.
+ - `(pad,)` or int is a shortcut for before = after = pad width for all axes.
+ mode (`PaddingMode`):
+ The padding mode to use. Can be one of:
+ - `"constant"`: pads with a constant value.
+ - `"reflect"`: pads with the reflection of the vector mirrored on the first and last values of the
+ vector along each axis.
+ - `"replicate"`: pads with the replication of the last value on the edge of the array along each axis.
+ - `"symmetric"`: pads with the reflection of the vector mirrored along the edge of the array.
+ constant_values (`float` or `Iterable[float]`, *optional*):
+ The value to use for the padding if `mode` is `"constant"`.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format for the output image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ If unset, will use same as the input image.
+ input_data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format for the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ If unset, will use the inferred format of the input image.
+
+ Returns:
+ `np.ndarray`: The padded image.
+
+ """
+
+        # call the general `pad` if padding on `height/width`, otherwise it's the `num_patches` dim
+ if isinstance(padding, int) or len(padding) != 4:
+ return pad(image, padding, mode, constant_values, data_format, input_data_format)
+
+ if input_data_format is None:
+ input_data_format = infer_channel_dimension_format(image)
+
+ padding_mode_mapping = {
+ PaddingMode.CONSTANT: "constant",
+ PaddingMode.REFLECT: "reflect",
+ PaddingMode.REPLICATE: "edge",
+ PaddingMode.SYMMETRIC: "symmetric",
+ }
+ image = np.pad(image, padding, mode=padding_mode_mapping[mode], constant_values=constant_values)
+ image = (
+ to_channel_dimension_format(image, data_format, input_data_format) if data_format is not None else image
+ )
+ return image
+
+ def get_image_patches(
+ self,
+ image: np.ndarray,
+ grid_pinpoints: list[tuple[int, int]],
+ patch_size: int,
+ resample: PILImageResampling,
+ data_format: ChannelDimension,
+ input_data_format: ChannelDimension,
+ ) -> list[np.ndarray]:
+ """
+ Process an image with variable resolutions by dividing it into patches.
+
+ Args:
+ image (`np.ndarray`):
+ The input image to be processed.
+ grid_pinpoints (list[tuple[int, int]]):
+ A list of possible resolutions as tuples.
+ patch_size (`int`):
+ Size of the patches to divide the image into.
+ resample (`PILImageResampling`):
+ Resampling filter to use if resizing the image.
+ data_format (`ChannelDimension` or `str`):
+ The channel dimension format for the output image.
+ input_data_format (`ChannelDimension` or `str`):
+ The channel dimension format of the input image.
+
+ Returns:
+ `list[np.ndarray]`: A list of NumPy arrays containing the processed image patches.
+ """
+ if not isinstance(grid_pinpoints, list):
+ raise TypeError("grid_pinpoints must be a list of possible resolutions.")
+
+ possible_resolutions = grid_pinpoints
+
+ image_size = get_image_size(image, channel_dim=input_data_format)
+ best_resolution = select_best_resolution(image_size, possible_resolutions)
+ resized_image = self._resize_for_patching(
+ image, best_resolution, resample=resample, input_data_format=input_data_format
+ )
+ padded_image = self._pad_for_patching(resized_image, best_resolution, input_data_format=input_data_format)
+
+ patches = divide_to_patches(padded_image, patch_size=patch_size, input_data_format=input_data_format)
+
+ # make sure that all patches are in the input data format
+ patches = [
+ to_channel_dimension_format(patch, channel_dim=data_format, input_channel_dim=input_data_format)
+ for patch in patches
+ ]
+ return patches
+
+ def get_number_of_image_patches(self, height: int, width: int, images_kwargs=None):
+ """
+        A utility that returns the number of image patches for a given image size.
+
+ Args:
+ height (`int`):
+ Height of the input image.
+ width (`int`):
+ Width of the input image.
+            images_kwargs (`dict`, *optional*):
+ Any kwargs to override defaults of the image processor.
+ Returns:
+ `int`: Number of patches per image.
+ """
+        images_kwargs = images_kwargs if images_kwargs is not None else {}
+        split_image = images_kwargs.get("split_image", self.split_image)
+ max_image_size = images_kwargs.get("max_image_size", self.max_image_size)
+
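+        # split_resolutions holds multiples of the 490-pixel crop size; e.g. with the default
+        # max_image_size of 980, a best-fit resolution of (980, 1960) yields 1 * 2 = 2 patches.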
+ resized_height, resized_width = select_best_resolution((height, width), self.split_resolutions)
+ num_patches = 1 if not split_image else resized_height // max_image_size * resized_width // max_image_size
+ return num_patches
+
+
+__all__ = ["AriaImageProcessor"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/aria/modeling_aria.py b/venv/lib/python3.13/site-packages/transformers/models/aria/modeling_aria.py
new file mode 100644
index 0000000000000000000000000000000000000000..fb9c49952b665372ceb6b53fb24600cf08d69124
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/aria/modeling_aria.py
@@ -0,0 +1,1275 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/aria/modular_aria.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_aria.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2024 The Rhymes-AI Teams Authors and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from dataclasses import dataclass
+from typing import Callable, Optional, Union
+
+import torch
+from torch import nn
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache
+from ...generation import GenerationMixin
+from ...integrations import use_kernel_forward_from_hub
+from ...masking_utils import create_causal_mask
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, ModelOutput
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
+from ...utils.deprecation import deprecate_kwarg
+from ...utils.generic import check_model_inputs
+from ..auto import AutoModel
+from .configuration_aria import AriaConfig, AriaTextConfig
+
+
+@use_kernel_forward_from_hub("RMSNorm")
+class AriaTextRMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ AriaTextRMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+ def extra_repr(self):
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+
+class AriaProjectorMLP(nn.Module):
+ """
+ Feed-Forward Network module for the Aria Projector.
+
+ Args:
+ in_features (`int`):
+ Input embedding dimension.
+ hidden_features (`int`):
+ Hidden dimension of the feed-forward network.
+ output_dim (`int`):
+ Output dimension.
+ """
+
+ def __init__(self, in_features, hidden_features, output_dim):
+ super().__init__()
+ self.linear_in = nn.Linear(in_features, hidden_features, bias=False)
+ self.linear_out = nn.Linear(hidden_features, output_dim, bias=False)
+ self.act = ACT2FN["gelu_new"]
+
+ def forward(self, hidden_states):
+ hidden_states = self.act(self.linear_in(hidden_states))
+ hidden_states = self.linear_out(hidden_states)
+ return hidden_states
+
+
+class AriaCrossAttention(nn.Module):
+ """
+ Aria Cross-Attention module.
+
+ Args:
+ config (`AriaConfig`):
+ The configuration to use.
+ """
+
+ def __init__(self, config: AriaConfig, dropout_rate: float = 0):
+ super().__init__()
+ hidden_size = config.vision_config.hidden_size
+ num_heads = config.vision_config.num_attention_heads
+ self.num_heads = num_heads
+ self.q_proj = nn.Linear(hidden_size, hidden_size, bias=False)
+ self.k_proj = nn.Linear(hidden_size, hidden_size, bias=False)
+ self.v_proj = nn.Linear(hidden_size, hidden_size, bias=False)
+
+ # Original code here: https://github.com/rhymes-ai/Aria/blob/719ff4e52b727443cba3793b0e27fe64e0244fe1/aria/model/projector.py#L48
+ self.multihead_attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
+ self.linear = nn.Linear(hidden_size, hidden_size)
+ self.dropout = nn.Dropout(dropout_rate)
+
+ self.layer_norm = nn.LayerNorm(hidden_size)
+ self.layer_norm_kv = nn.LayerNorm(hidden_size)
+
+ def forward(self, key_value_states, hidden_states, attn_mask=None):
+ """
+ Forward pass of the AriaCrossAttention module.
+
+ Args:
+ key_value_states (`torch.Tensor`):
+ Input tensor for key and value.
+ hidden_states (`torch.Tensor`):
+ Input tensor for query.
+ attn_mask (`torch.Tensor`, *optional*, defaults to None):
+ Attention mask.
+
+ Returns:
+ torch.Tensor:
+ Output tensor after cross-attention.
+ """
+ query = self.q_proj(self.layer_norm(hidden_states))
+
+ key_value_states = self.layer_norm_kv(key_value_states)
+ key = self.k_proj(key_value_states)
+ value = self.v_proj(key_value_states)
+
+ attn_output, _ = self.multihead_attn(query, key, value, attn_mask=attn_mask)
+
+ attn_output = self.dropout(self.linear(attn_output))
+
+ return attn_output
+
+
+class AriaProjector(nn.Module):
+ """
+ Aria Projector module.
+
+ This module projects vision features into the language model's embedding space, enabling interaction between vision and language components.
+
+ Args:
+ config (`AriaConfig`):
+ Configuration object for the model.
+ """
+
+ def __init__(
+ self,
+ config: AriaConfig,
+ ):
+ super().__init__()
+
+ self.patch_to_query_dict = config.projector_patch_to_query_dict
+ self.in_features = config.vision_config.hidden_size
+ self.num_heads = config.vision_config.num_attention_heads
+ self.kv_dim = config.vision_config.hidden_size
+ self.hidden_features = config.text_config.hidden_size
+ self.output_dim = config.text_config.hidden_size
+
+ self.query = nn.Parameter(torch.zeros(config.max_value_projector_patch_to_query_dict, self.in_features))
+
+ self.cross_attn = AriaCrossAttention(config)
+
+ self.layer_norm = nn.LayerNorm(self.in_features)
+ self.feed_forward = AriaProjectorMLP(self.in_features, self.hidden_features, self.output_dim)
+
+ def forward(self, key_value_states: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
+ """
+ Forward pass of the Projector module.
+
+ Args:
+ key_value_states (`torch.Tensor`):
+ Input tensor of shape (batch_size, num_patches, kv_dim).
+            attn_mask (`torch.Tensor`, *optional*, defaults to `None`):
+ Attention mask.
+
+ Returns:
+ `torch.Tensor`: Output tensor of shape (batch_size, query_number, output_dim).
+ """
+ batch_size, num_patches = key_value_states.shape[0], key_value_states.shape[1]
+
+ if num_patches not in self.patch_to_query_dict:
+ raise KeyError(
+ f"Number of patches {num_patches} not found in patch_to_query_dict amongst possible values {self.patch_to_query_dict.keys()}."
+ )
+ query_num = self.patch_to_query_dict[num_patches]
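+        # patch_to_query_dict maps a supported patch count to the number of learned query tokens
+        # used for that resolution; only the first `query_num` rows of `self.query` are consumed.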
+
+ queries = self.query[:query_num].unsqueeze(0).repeat(batch_size, 1, 1)
+
+ if attn_mask is not None:
+ attn_mask = attn_mask.repeat_interleave(self.num_heads, 0)
+ attn_mask = attn_mask.unsqueeze(1).expand(-1, queries.size(1), -1)
+
+ attention_out = self.cross_attn(key_value_states, queries, attn_mask=attn_mask)
+
+ out = self.feed_forward(self.layer_norm(attention_out))
+
+ return out
+
+
+class AriaSharedExpertsMLP(nn.Module):
+ """
+    MLP for the shared experts.
+
+ Unlike routed experts, shared experts process all tokens without routing.
+ This class reconfigures the intermediate size in comparison to the LlamaMLP.
+
+ Args:
+ config (`AriaTextConfig`): Configuration object for the Aria language model.
+ """
+
+ def __init__(self, config: AriaTextConfig):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.intermediate_size = config.intermediate_size * config.moe_num_shared_experts
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
+ self.act_fn = ACT2FN[config.hidden_act]
+
+ def forward(self, x):
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+ return down_proj
+
+
+def sequential_experts_gemm(token_states, expert_weights, tokens_per_expert):
+ """
+ Compute the matrix multiplication (GEMM) for each expert sequentially. This approach is computationally inefficient, especially when dealing with a large number of experts.
+
+ Args:
+ token_states (torch.Tensor): Input tensor of shape (num_tokens, in_features).
+ expert_weights (torch.Tensor): Weight tensor of shape (num_experts, in_features, out_features).
+ tokens_per_expert (torch.Tensor): Number of tokens assigned to each expert.
+
+ Returns:
+ torch.Tensor: Output tensor of shape (num_tokens, out_features).
+ """
+ num_tokens = token_states.shape[0]
+ out_features = expert_weights.shape[-1]
+ output = torch.zeros(num_tokens, out_features, dtype=token_states.dtype, device=token_states.device)
+
+ cumsum_num_tokens = torch.cumsum(tokens_per_expert, dim=0)
+ # Insert zero at the beginning for offset index's convenience
+ zero_tensor = torch.zeros(1, dtype=torch.long, device=cumsum_num_tokens.device)
+ cumsum_num_tokens = torch.cat((zero_tensor, cumsum_num_tokens))
+
+ for expert_num in range(expert_weights.shape[0]):
+ start = cumsum_num_tokens[expert_num]
+ end = cumsum_num_tokens[expert_num + 1]
+ tokens = token_states[start:end]
+
+ out = torch.matmul(tokens, expert_weights[expert_num])
+ output[start:end] = out
+ return output
+
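+# Illustrative note (not part of the upstream module): `sequential_experts_gemm` assumes the rows
+# of `token_states` are already sorted by expert. With tokens_per_expert = [2, 1] the cumulative
+# boundaries become [0, 2, 3], so rows 0-1 are multiplied by expert_weights[0] and row 2 by
+# expert_weights[1].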
+
+class AriaGroupedExpertsGemm(nn.Module):
+ """
+ Grouped GEMM (General Matrix Multiplication) module for efficient expert computation.
+ This module utilizes the grouped_gemm library (https://github.com/fanshiqing/grouped_gemm)
+ for optimized performance. If the grouped_gemm library is not installed, it gracefully
+ falls back to a sequential GEMM implementation, which may be slower but ensures
+ functionality.
+
+ Args:
+ in_features (`int`):
+ Number of input features.
+ out_features (`int`):
+ Number of output features.
+ groups (`int`):
+ Number of expert groups.
+ """
+
+ def __init__(self, in_features, out_features, groups):
+ super().__init__()
+ self.in_features = in_features
+ self.out_features = out_features
+ self.groups = groups
+ self.weight = nn.Parameter(torch.empty(groups, in_features, out_features))
+
+ def forward(self, input, tokens_per_expert):
+ """
+ Perform grouped matrix multiplication.
+
+ Args:
+ input (`torch.Tensor`):
+ Input tensor of shape (num_tokens, in_features).
+ tokens_per_expert (`torch.Tensor`):
+ Number of tokens assigned to each expert.
+
+ Returns:
+ torch.Tensor: Output tensor of shape (num_tokens, out_features).
+ """
+ return sequential_experts_gemm(
+ input,
+ self.weight,
+ tokens_per_expert.cpu(),
+ )
+
+
+class AriaGroupedExpertsMLP(nn.Module):
+ """
+ Grouped MLP module for Mixture of Experts.
+
+ Args:
+ config (`AriaTextConfig`):
+ Configuration object for the model.
+ """
+
+ def __init__(self, config: AriaTextConfig) -> None:
+ super().__init__()
+ self.config = config
+ self.fc1 = AriaGroupedExpertsGemm(config.hidden_size, config.intermediate_size * 2, config.moe_num_experts)
+ self.fc2 = AriaGroupedExpertsGemm(config.intermediate_size, config.hidden_size, config.moe_num_experts)
+
+ def forward(self, permuted_tokens, tokens_per_expert):
+ """
+ Forward pass of the Grouped MLP.
+
+ Args:
+ permuted_tokens (torch.Tensor): Permuted input tokens.
+ tokens_per_expert (torch.Tensor): Number of tokens assigned to each expert.
+
+ Returns:
+ torch.Tensor: Output tensor after passing through the MLP.
+ """
+ fc1_output = self.fc1(permuted_tokens, tokens_per_expert)
+ projection, gate = torch.chunk(fc1_output, 2, dim=-1)
+ fc1_output = nn.functional.silu(projection) * gate
+ fc2_output = self.fc2(fc1_output, tokens_per_expert)
+ return fc2_output
+
+
+# Token permutation adapted from https://github.com/NVIDIA/Megatron-LM/blob/54f1f78529cbc2b9cddad313e7f9d96ac0420a27/megatron/core/transformer/moe/token_dispatcher.py#L291-L587
+class AriaTextMoELayer(nn.Module):
+ """
+ Aria Text Mixture of Experts (MoE) Layer.
+
+ This layer applies a gating mechanism to route input tokens to different experts.
+
+ Args:
+ config (`AriaTextConfig`):
+ Configuration object for the text component of the model.
+ """
+
+ def __init__(self, config: AriaTextConfig):
+ super().__init__()
+
+ self.router = nn.Linear(config.hidden_size, config.moe_num_experts, bias=False)
+ self.experts = AriaGroupedExpertsMLP(config)
+ self.shared_experts = AriaSharedExpertsMLP(config)
+ self.config = config
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ """
+ Forward pass of the MoE Layer.
+
+ Args:
+ hidden_states (`torch.Tensor`):
+ Input tensor of shape (batch_size, sequence_length, hidden_size).
+
+ Returns:
+ torch.Tensor: Output tensor after passing through the MoE layer.
+
+ Process:
+ 1. Route tokens to experts using the router.
+ 2. Permute tokens based on routing decisions.
+ 3. Process tokens through experts.
+ 4. Unpermute and combine expert outputs.
+ 5. Add shared expert output to the final result.
+ """
+ original_shape = hidden_states.shape
+ hidden_states = hidden_states.view(-1, hidden_states.size(-1))
+
+ # Top K Routing
+ logits = self.router(hidden_states)
+ top_logits, top_indices = torch.topk(logits, k=self.config.moe_topk, dim=1)
+ scores = nn.functional.softmax(top_logits, dim=-1)
+
+ original_dtype = top_indices.dtype
+
+ tokens_per_expert = torch.histc(
+ top_indices.flatten().to(torch.float32),
+ bins=self.config.moe_num_experts,
+ min=0,
+ max=self.config.moe_num_experts - 1,
+ ).to(original_dtype)
+ indices = top_indices
+
+ # Token permutation
+ flatten_indices = indices.view(-1)
+ sorted_indices = torch.argsort(flatten_indices)
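+        # Sorting the flattened (token, top-k slot) assignments groups them by expert id; integer
+        # division by moe_topk maps each sorted position back to its source token row.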
+ permuted_tokens = hidden_states.index_select(0, sorted_indices // self.config.moe_topk)
+
+ # Process through experts
+ expert_output = self.experts(permuted_tokens, tokens_per_expert)
+
+ # Token unpermutation
+ unpermuted_tokens = torch.zeros(
+ (scores.shape[0] * self.config.moe_topk, expert_output.size(1)),
+ dtype=expert_output.dtype,
+ device=expert_output.device,
+ )
+ unpermuted_tokens.index_copy_(0, sorted_indices, expert_output)
+ unpermuted_tokens = unpermuted_tokens.view(-1, self.config.moe_topk, expert_output.size(1))
+
+ output = (unpermuted_tokens * scores.unsqueeze(-1)).sum(dim=1).view(original_shape)
+
+ # Add shared expert output
+ shared_expert_output = self.shared_experts(hidden_states.view(original_shape))
+ return output + shared_expert_output
+
+
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
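+# Illustrative note (not part of the upstream module): for a head dimension of 4 and input
+# [a, b, c, d], rotate_half returns [-c, -d, a, b]; together with the cos/sin terms below this
+# applies the standard RoPE rotation to the (a, c) and (b, d) pairs.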
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+ """Applies Rotary Position Embedding to the query and key tensors.
+
+ Args:
+ q (`torch.Tensor`): The query tensor.
+ k (`torch.Tensor`): The key tensor.
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
+ position_ids (`torch.Tensor`, *optional*):
+ Deprecated and unused.
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+ Returns:
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+ """
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+ """
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: float,
+ dropout: float = 0.0,
+ **kwargs: Unpack[TransformersKwargs],
+):
+ key_states = repeat_kv(key, module.num_key_value_groups)
+ value_states = repeat_kv(value, module.num_key_value_groups)
+
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+ if attention_mask is not None:
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
+class AriaTextAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config: AriaTextConfig, layer_idx: int):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
+ self.scaling = self.head_dim**-0.5
+ self.attention_dropout = config.attention_dropout
+ self.is_causal = True
+
+ self.q_proj = nn.Linear(
+ config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.k_proj = nn.Linear(
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.v_proj = nn.Linear(
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.o_proj = nn.Linear(
+ config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
+ )
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
+ attention_mask: Optional[torch.Tensor],
+ past_key_values: Optional[Cache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ input_shape = hidden_states.shape[:-1]
+ hidden_shape = (*input_shape, -1, self.head_dim)
+
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_values is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ scaling=self.scaling,
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+ return attn_output, attn_weights
+
+
+class AriaTextDecoderLayer(GradientCheckpointingLayer):
+ """
+ Aria Text Decoder Layer.
+
+ This class defines a single decoder layer in the language model, incorporating self-attention and Mixture of Experts (MoE) feed-forward network.
+
+ Args:
+ config (`AriaTextConfig`):
+ Configuration object for the text component of the model.
+ layer_idx (`int`):
+ Index of the layer.
+ """
+
+ def __init__(self, config: AriaTextConfig, layer_idx: int):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+
+ self.self_attn = AriaTextAttention(config=config, layer_idx=layer_idx)
+ self.mlp = AriaTextMoELayer(config)
+ self.input_layernorm = AriaTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_attention_layernorm = AriaTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> torch.Tensor:
+ residual = hidden_states
+ hidden_states = self.input_layernorm(hidden_states)
+ # Self Attention
+ hidden_states, _ = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+ hidden_states = residual + hidden_states
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+ return hidden_states
+
+
+@auto_docstring
+class AriaTextPreTrainedModel(PreTrainedModel):
+ config: AriaTextConfig
+ base_model_prefix = "model"
+ _no_split_modules = ["AriaTextDecoderLayer", "AriaGroupedExpertsGemm"]
+ supports_gradient_checkpointing = True
+ _skip_keys_device_placement = "past_key_values"
+ _supports_flash_attn = True
+ _supports_sdpa = True
+
+ _supports_attention_backend = True
+ _can_record_outputs = {
+ "hidden_states": AriaTextDecoderLayer,
+ "attentions": AriaTextAttention,
+ }
+
+ def _init_weights(self, module):
+ super()._init_weights(module)
+ if isinstance(module, AriaGroupedExpertsGemm):
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+
+
+@auto_docstring
+class AriaPreTrainedModel(PreTrainedModel):
+ config: AriaConfig
+ base_model_prefix = ""
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["AriaDecoderLayer"]
+ _skip_keys_device_placement = ["past_key_values"]
+ _supports_flash_attn = True
+ _supports_sdpa = True
+ _supports_flex_attn = True
+ _can_compile_fullgraph = False # MoE models don't work with torch.compile (dynamic slicing)
+ _supports_attention_backend = True
+ _can_record_outputs = {
+ "hidden_states": AriaTextDecoderLayer,
+ "attentions": AriaTextAttention,
+ }
+
+ def _init_weights(self, module):
+ super()._init_weights(module)
+ if isinstance(module, AriaProjector):
+ nn.init.trunc_normal_(module.query, std=self.config.initializer_range)
+
+
+class AriaTextRotaryEmbedding(nn.Module):
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
+
+ def __init__(self, config: AriaTextConfig, device=None):
+ super().__init__()
+ # BC: "rope_type" was originally "type"
+ if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+ else:
+ self.rope_type = "default"
+ self.max_seq_len_cached = config.max_position_embeddings
+ self.original_max_seq_len = config.max_position_embeddings
+
+ self.config = config
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+ self.original_inv_freq = self.inv_freq
+
+ @torch.no_grad()
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
+ def forward(self, x, position_ids):
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
+ position_ids_expanded = position_ids[:, None, :].float()
+
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos = emb.cos() * self.attention_scaling
+ sin = emb.sin() * self.attention_scaling
+
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+@auto_docstring
+class AriaTextModel(AriaTextPreTrainedModel):
+ def __init__(self, config: AriaTextConfig):
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+ self.layers = nn.ModuleList(
+ [AriaTextDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+ self.norm = AriaTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.rotary_emb = AriaTextRotaryEmbedding(config=config)
+ self.gradient_checkpointing = False
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @check_model_inputs()
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> BaseModelOutputWithPast:
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if inputs_embeds is None:
+ inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)
+
+ if use_cache and past_key_values is None:
+ past_key_values = DynamicCache(config=self.config)
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position: torch.Tensor = torch.arange(
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+ )
+
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0)
+
+ causal_mask = create_causal_mask(
+ config=self.config,
+ input_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ cache_position=cache_position,
+ past_key_values=past_key_values,
+ position_ids=position_ids,
+ )
+
+ hidden_states = inputs_embeds
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+ for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+ hidden_states = decoder_layer(
+ hidden_states,
+ attention_mask=causal_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+
+ hidden_states = self.norm(hidden_states)
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values,
+ )
+
+
+@auto_docstring
+class AriaTextForCausalLM(AriaTextPreTrainedModel, GenerationMixin):
+ _tied_weights_keys = ["lm_head.weight"]
+ _tp_plan = {"lm_head": "colwise_rep"}
+ _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
+
+ def __init__(self, config: AriaTextConfig):
+ super().__init__(config)
+ self.model = AriaTextModel(config)
+ self.vocab_size = config.vocab_size
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> CausalLMOutputWithPast:
+ r"""
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, AriaTextForCausalLM
+
+ >>> model = AriaTextForCausalLM.from_pretrained("meta-aria_text/AriaText-2-7b-hf")
+ >>> tokenizer = AutoTokenizer.from_pretrained("meta-aria_text/AriaText-2-7b-hf")
+
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+ ```"""
+ outputs: BaseModelOutputWithPast = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ hidden_states = outputs.last_hidden_state
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for Aria causal language model (or autoregressive) outputs.
+ """
+)
+class AriaCausalLMOutputWithPast(ModelOutput):
+ r"""
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Language modeling loss (for next-token prediction).
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ image_hidden_states (`torch.FloatTensor`, *optional*):
+ A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
+ image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: Optional[torch.FloatTensor] = None
+ past_key_values: Optional[Cache] = None
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
+ attentions: Optional[tuple[torch.FloatTensor]] = None
+ image_hidden_states: Optional[torch.FloatTensor] = None
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for Aria outputs, with hidden states and attentions.
+ """
+)
+class AriaModelOutputWithPast(BaseModelOutputWithPast):
+ r"""
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ image_hidden_states (`torch.FloatTensor`, *optional*):
+ A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
+ image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
+ """
+
+ image_hidden_states: Optional[torch.FloatTensor] = None
+
+
+@auto_docstring(
+ custom_intro="""
+ The Aria model which consists of a vision backbone and a language model, without a language modeling head.
+ """
+)
+class AriaModel(AriaPreTrainedModel):
+ _checkpoint_conversion_mapping = {"language_model.model": "language_model"}
+
+ def __init__(self, config: AriaConfig):
+ super().__init__(config)
+ self.vision_tower = AutoModel.from_config(config.vision_config)
+ self.multi_modal_projector = AriaProjector(config)
+ self.language_model = AutoModel.from_config(config.text_config)
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.language_model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.language_model.set_input_embeddings(value)
+
+ def set_decoder(self, decoder):
+ self.language_model = decoder
+
+ def get_decoder(self):
+ return self.language_model
+
+ def get_image_features(
+ self,
+ pixel_values: torch.FloatTensor,
+ pixel_mask: Optional[torch.FloatTensor] = None,
+ vision_feature_layer: int = -1,
+ ):
+ """
+        Obtains image last hidden states from the vision tower and applies the multimodal projection.
+
+ Args:
+            pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`):
+ The tensors corresponding to the input images.
+            pixel_mask (`torch.FloatTensor`, *optional*):
+ The tensors corresponding to the input image mask.
+ vision_feature_layer (`Union[int, list[int]]`, *optional*):
+ The index of the layer to select the vision feature. If multiple indices are provided,
+ the vision feature of the corresponding indices will be concatenated to form the
+ vision features.
+ Returns:
+            image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`.
+ """
+ vision_feature_layer = (
+ vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
+ )
+ patch_attention_mask = self._create_patch_attention_mask(pixel_mask)
+ image_outputs = self.vision_tower(
+ pixel_values, patch_attention_mask=patch_attention_mask, output_hidden_states=True
+ )
+ image_attn_mask = None
+ if patch_attention_mask is not None:
+ flattened_mask = patch_attention_mask.flatten(1)
+ image_attn_mask = torch.logical_not(flattened_mask)
+
+ selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
+ image_features = self.multi_modal_projector(selected_image_feature, attn_mask=image_attn_mask)
+ return image_features
+
+ def get_placeholder_mask(
+ self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor
+ ):
+ """
+ Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
+ equal to the length of multimodal features. If the lengths are different, an error is raised.
+ """
+ if input_ids is None:
+ special_image_mask = inputs_embeds == self.get_input_embeddings()(
+ torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ special_image_mask = special_image_mask.all(-1)
+ else:
+ special_image_mask = input_ids == self.config.image_token_id
+
+ n_image_tokens = special_image_mask.sum()
+ special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+ n_image_features = image_features.shape[0] * image_features.shape[1]
+ if inputs_embeds[special_image_mask].numel() != image_features.numel():
+ raise ValueError(
+ f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+ )
+ return special_image_mask
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ pixel_mask: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> Union[tuple, AriaModelOutputWithPast]:
+ if inputs_embeds is None:
+ inputs_embeds = self.get_input_embeddings()(input_ids)
+
+ # 2. Merge text and images
+ if pixel_values is not None and inputs_embeds.shape[1] != 1:
+ image_features = self.get_image_features(
+ pixel_values=pixel_values,
+ pixel_mask=pixel_mask,
+ vision_feature_layer=self.config.vision_feature_layer,
+ )
+ image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+ special_image_mask = self.get_placeholder_mask(
+ input_ids, inputs_embeds=inputs_embeds, image_features=image_features
+ )
+ inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
+
+ outputs = self.language_model(
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ return AriaModelOutputWithPast(
+ last_hidden_state=outputs.last_hidden_state,
+ past_key_values=outputs.past_key_values if use_cache else None,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ image_hidden_states=image_features if pixel_values is not None else None,
+ )
+
+ def _create_patch_attention_mask(self, pixel_mask):
+ if pixel_mask is None:
+ return None
+
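+        # Tile the pixel mask into non-overlapping patch_size x patch_size windows; a patch is
+        # marked valid (True) if it covers at least one non-padded pixel of the resized image.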
+ patches_subgrid = pixel_mask.unfold(
+ dimension=1,
+ size=self.vision_tower.config.patch_size,
+ step=self.vision_tower.config.patch_size,
+ )
+ patches_subgrid = patches_subgrid.unfold(
+ dimension=2,
+ size=self.vision_tower.config.patch_size,
+ step=self.vision_tower.config.patch_size,
+ )
+ return (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
+
+
+@auto_docstring(
+ custom_intro="""
+ Aria model for conditional generation tasks.
+
+ This model combines a vision tower, a multi-modal projector, and a language model
+ to perform tasks that involve both image and text inputs.
+ """
+)
+class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
+ _checkpoint_conversion_mapping = {
+ "^language_model.model": "model.language_model",
+ "^vision_tower": "model.vision_tower",
+ "^multi_modal_projector": "model.multi_modal_projector",
+ "^language_model.lm_head": "lm_head",
+ }
+ _tied_weights_keys = ["lm_head.weight"]
+
+ def __init__(self, config: AriaConfig):
+ super().__init__(config)
+ self.model = AriaModel(config)
+ self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.model.set_input_embeddings(value)
+
+ def get_output_embeddings(self) -> nn.Module:
+ return self.lm_head
+
+ def set_decoder(self, decoder):
+ self.model.set_decoder(decoder)
+
+ def get_decoder(self):
+ return self.model.get_decoder()
+
+ def get_image_features(
+ self,
+ pixel_values: torch.FloatTensor,
+ pixel_mask: Optional[torch.FloatTensor] = None,
+ vision_feature_layer: int = -1,
+ ):
+ return self.model.get_image_features(
+ pixel_values=pixel_values,
+ pixel_mask=pixel_mask,
+ vision_feature_layer=vision_feature_layer,
+ )
+
+ # Make modules available through conditional class for BC
+ @property
+ def language_model(self):
+ return self.model.language_model
+
+ @property
+ def vision_tower(self):
+ return self.model.vision_tower
+
+ @property
+ def multi_modal_projector(self):
+ return self.model.multi_modal_projector
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ pixel_mask: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> Union[tuple, AriaCausalLMOutputWithPast]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or `model.image_token_id` (where `model` is your instance of `AriaForConditionalGeneration`).
+ Tokens with indices set to `model.image_token_id` are ignored (masked), the loss is only
+ computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Example:
+
+ ```python
+ >>> import requests
+ >>> import torch
+ >>> from PIL import Image
+ >>> from io import BytesIO
+
+ >>> from transformers import AutoProcessor, AutoModel
+ >>> from transformers.image_utils import load_image
+
+ >>> # Note that passing the image urls (instead of the actual pil images) to the processor is also possible
+ >>> image1 = load_image("https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg")
+ >>> image2 = load_image("https://cdn.britannica.com/59/94459-050-DBA42467/Skyline-Chicago.jpg")
+ >>> image3 = load_image("https://cdn.britannica.com/68/170868-050-8DDE8263/Golden-Gate-Bridge-San-Francisco.jpg")
+
+ >>> processor = AutoProcessor.from_pretrained("Rhymes-AI/Aria")
+ >>> model = AutoModel.from_pretrained("Rhymes-AI/Aria", dtype=torch.bfloat16, device_map="auto")
+
+ >>> # Create inputs
+ >>> messages = [
+ ... {
+ ... "role": "user",
+ ... "content": [
+ ... {"type": "image"},
+ ... {"type": "text", "text": "In this image, we can see the city of New York, and more specifically the Statue of Liberty."},
+ ... {"type": "image"},
+ ... {"type": "text", "text": "What can we see in this image?"},
+ ... ]
+ ... },
+ ... {
+ ... "role": "user",
+ ... "content": [
+ ... {"type": "image"},
+ ... {"type": "text", "text": "In which city is that bridge located?"},
+ ... ]
+ ... }
+ ... ]
+
+ >>> prompts = [processor.apply_chat_template([message], add_generation_prompt=True) for message in messages]
+ >>> images = [[image1, image2], [image3]]
+ >>> inputs = processor(text=prompts, images=images, padding=True, return_tensors="pt").to(model.device)
+
+ >>> # Generate
+ >>> generated_ids = model.generate(**inputs, max_new_tokens=256)
+ >>> generated_texts = processor.batch_decode(generated_ids, skip_special_tokens=True)
+
+ >>> print(generated_texts[0])
+ Assistant: There are buildings, trees, lights, and water visible in this image.
+
+ >>> print(generated_texts[1])
+ Assistant: The bridge is in San Francisco.
+ ```"""
+ outputs = self.model(
+ input_ids=input_ids,
+ pixel_values=pixel_values,
+ pixel_mask=pixel_mask,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ hidden_states = outputs[0]
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(
+ logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
+ )
+
+ return AriaCausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids,
+ past_key_values=None,
+ inputs_embeds=None,
+ pixel_values=None,
+ pixel_mask=None,
+ attention_mask=None,
+ cache_position=None,
+ logits_to_keep=None,
+ **kwargs,
+ ):
+ model_inputs = super().prepare_inputs_for_generation(
+ input_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ cache_position=cache_position,
+ logits_to_keep=logits_to_keep,
+ **kwargs,
+ )
+
+ if cache_position[0] == 0:
+ # Pixel values are only needed on the first (prefill) forward pass, when the input ids still contain
+ # the special image tokens; during cached decoding they are no longer required and are omitted
+ model_inputs["pixel_values"] = pixel_values
+ model_inputs["pixel_mask"] = pixel_mask
+
+ return model_inputs
+
+
+__all__ = [
+ "AriaForConditionalGeneration",
+ "AriaPreTrainedModel",
+ "AriaTextPreTrainedModel",
+ "AriaTextModel",
+ "AriaModel",
+ "AriaTextForCausalLM",
+]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/aria/modular_aria.py b/venv/lib/python3.13/site-packages/transformers/models/aria/modular_aria.py
new file mode 100644
index 0000000000000000000000000000000000000000..02f2f884dadf9469f0fad049ade8981252666e65
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/aria/modular_aria.py
@@ -0,0 +1,1610 @@
+# coding=utf-8
+# Copyright 2024 The Rhymes-AI Teams Authors and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from collections.abc import Iterable
+from typing import Optional, Union
+
+import numpy as np
+import torch
+from torch import nn
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache
+from ...configuration_utils import PretrainedConfig
+from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_patch_output_size, select_best_resolution
+from ...image_transforms import PaddingMode, convert_to_rgb, pad, resize, to_channel_dimension_format
+from ...image_utils import (
+ ChannelDimension,
+ ImageInput,
+ PILImageResampling,
+ get_image_size,
+ infer_channel_dimension_format,
+ is_scaled_image,
+ make_flat_list_of_images,
+ to_numpy_array,
+ valid_images,
+ validate_preprocess_arguments,
+)
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_utils import PreTrainedModel
+from ...processing_utils import MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack
+from ...tokenization_utils import PreTokenizedInput, TextInput
+from ...utils import TensorType, TransformersKwargs, auto_docstring, can_return_tuple, logging
+from ..auto import CONFIG_MAPPING, AutoConfig, AutoTokenizer
+from ..llama.configuration_llama import LlamaConfig
+from ..llama.modeling_llama import (
+ LlamaAttention,
+ LlamaDecoderLayer,
+ LlamaForCausalLM,
+ LlamaMLP,
+ LlamaModel,
+ LlamaPreTrainedModel,
+ LlamaRMSNorm,
+)
+from ..llava.modeling_llava import (
+ LlavaCausalLMOutputWithPast,
+ LlavaForConditionalGeneration,
+ LlavaModel,
+ LlavaModelOutputWithPast,
+)
+from ..llava_next.image_processing_llava_next import divide_to_patches
+
+
+logger = logging.get_logger(__name__)
+
+
+def sequential_experts_gemm(token_states, expert_weights, tokens_per_expert):
+ """
+ Compute the matrix multiplication (GEMM) for each expert sequentially. This approach is computationally inefficient, especially when dealing with a large number of experts.
+
+ Args:
+ token_states (torch.Tensor): Input tensor of shape (num_tokens, in_features).
+ expert_weights (torch.Tensor): Weight tensor of shape (num_experts, in_features, out_features).
+ tokens_per_expert (torch.Tensor): Number of tokens assigned to each expert.
+
+ Returns:
+ torch.Tensor: Output tensor of shape (num_tokens, out_features).
+ """
+ num_tokens = token_states.shape[0]
+ out_features = expert_weights.shape[-1]
+ output = torch.zeros(num_tokens, out_features, dtype=token_states.dtype, device=token_states.device)
+
+ cumsum_num_tokens = torch.cumsum(tokens_per_expert, dim=0)
+ # Insert a zero at the beginning for convenient offset indexing
+ zero_tensor = torch.zeros(1, dtype=torch.long, device=cumsum_num_tokens.device)
+ cumsum_num_tokens = torch.cat((zero_tensor, cumsum_num_tokens))
+
+ for expert_num in range(expert_weights.shape[0]):
+ start = cumsum_num_tokens[expert_num]
+ end = cumsum_num_tokens[expert_num + 1]
+ tokens = token_states[start:end]
+
+ out = torch.matmul(tokens, expert_weights[expert_num])
+ output[start:end] = out
+ return output
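+
+ # A minimal usage sketch (illustrative shapes only): tokens are assumed to already be sorted by
+ # expert, and `tokens_per_expert` gives the size of each expert's contiguous slice.
+ #
+ #   token_states = torch.randn(4, 16)          # 4 tokens, in_features=16
+ #   expert_weights = torch.randn(2, 16, 32)    # 2 experts, out_features=32
+ #   tokens_per_expert = torch.tensor([1, 3])   # expert 0 gets 1 token, expert 1 gets 3
+ #   out = sequential_experts_gemm(token_states, expert_weights, tokens_per_expert)  # (4, 32)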
+
+
+class AriaTextConfig(LlamaConfig):
+ r"""
+ This class handles the configuration for the text component of the Aria model.
+ Instantiating a configuration with the defaults will yield a configuration similar to that of the Aria
+ [rhymes-ai/Aria](https://huggingface.co/rhymes-ai/Aria) architecture.
+ This class extends the LlamaConfig to include additional parameters specific to the Mixture of Experts (MoE) architecture.
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 32000):
+ Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the
+ `inputs_ids` passed when calling [`LlamaModel`]
+ hidden_size (`int`, *optional*, defaults to 4096):
+ Dimension of the hidden representations.
+ intermediate_size (`int`, *optional*, defaults to 4096):
+ The size of the MLP representations.
+ num_hidden_layers (`int`, *optional*, defaults to 32):
+ Number of hidden layers in the Transformer decoder.
+ num_attention_heads (`int`, *optional*, defaults to 32):
+ Number of attention heads for each attention layer in the Transformer decoder.
+ num_key_value_heads (`int`, *optional*):
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+ by meanpooling all the original heads within that group. For more details, check out [this
+ paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
+ `num_attention_heads`.
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+ The non-linear activation function (function or string) in the decoder.
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
+ The maximum sequence length that this model might ever be used with. Llama 1 supports up to 2048 tokens,
+ Llama 2 up to 4096, CodeLlama up to 16384.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
+ The epsilon used by the rms normalization layers.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ pad_token_id (`int`, *optional*, defaults to 2):
+ Padding token id.
+ bos_token_id (`int`, *optional*, defaults to 1):
+ Beginning of stream token id.
+ eos_token_id (`int`, *optional*, defaults to 2):
+ End of stream token id.
+ pretraining_tp (`int`, *optional*, defaults to 1):
+ Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
+ document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) to
+ understand more about it. This value is necessary to ensure exact reproducibility of the pretraining
+ results. Please refer to [this issue](https://github.com/pytorch/pytorch/issues/76232).
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether to tie weight embeddings
+ rope_theta (`float`, *optional*, defaults to 10000.0):
+ The base period of the RoPE embeddings.
+ rope_scaling (`Dict`, *optional*):
+ Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
+ and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
+ accordingly.
+ Expected contents:
+ `rope_type` (`str`):
+ The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
+ 'llama3'], with 'default' being the original RoPE implementation.
+ `factor` (`float`, *optional*):
+ Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
+ most scaling types, a `factor` of x will enable the model to handle sequences of length x *
+ original maximum pre-trained length.
+ `original_max_position_embeddings` (`int`, *optional*):
+ Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
+ pretraining.
+ `attention_factor` (`float`, *optional*):
+ Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
+ computation. If unspecified, it defaults to value recommended by the implementation, using the
+ `factor` field to infer the suggested value.
+ `beta_fast` (`float`, *optional*):
+ Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
+ ramp function. If unspecified, it defaults to 32.
+ `beta_slow` (`float`, *optional*):
+ Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
+ ramp function. If unspecified, it defaults to 1.
+ `short_factor` (`list[float]`, *optional*):
+ Only used with 'longrope'. The scaling factor to be applied to short contexts (<
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+ size divided by the number of attention heads divided by 2
+ `long_factor` (`list[float]`, *optional*):
+ Only used with 'longrope'. The scaling factor to be applied to long contexts (>
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+ size divided by the number of attention heads divided by 2
+ `low_freq_factor` (`float`, *optional*):
+ Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
+ `high_freq_factor` (`float`, *optional*):
+ Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
+ attention_bias (`bool`, *optional*, defaults to `False`):
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ mlp_bias (`bool`, *optional*, defaults to `False`):
+ Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers.
+ head_dim (`int`, *optional*):
+ The attention head dimension. If None, it will default to hidden_size // num_heads
+ moe_num_experts (`int`, *optional*, defaults to 8):
+ The number of experts in the MoE layer.
+ moe_topk (`int`, *optional*, defaults to 2):
+ The number of top experts to route to for each token.
+ moe_num_shared_experts (`int`, *optional*, defaults to 2):
+ The number of shared experts.
+ """
+
+ model_type = "aria_text"
+ base_config_key = "text_config"
+
+ def __init__(
+ self,
+ intermediate_size: int = 4096,
+ moe_num_experts: int = 8,
+ moe_topk: int = 2,
+ moe_num_shared_experts: int = 2,
+ pad_token_id=2,
+ **super_kwargs,
+ ):
+ super().__init__(pad_token_id=pad_token_id, **super_kwargs)
+ self.intermediate_size = intermediate_size
+ self.moe_num_experts = moe_num_experts
+ self.moe_topk = moe_topk
+ self.moe_num_shared_experts = moe_num_shared_experts
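+
+ # A minimal construction sketch (illustrative values only): the MoE-specific arguments extend the
+ # usual LlamaConfig arguments, which are accepted via **super_kwargs.
+ #
+ #   config = AriaTextConfig(
+ #       hidden_size=256,
+ #       num_hidden_layers=2,
+ #       moe_num_experts=4,
+ #       moe_topk=2,
+ #       moe_num_shared_experts=1,
+ #   )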
+
+
+class AriaConfig(PretrainedConfig):
+ r"""
+ This class handles the configuration for both vision and text components of the Aria model,
+ as well as additional parameters for image token handling and projector mapping.
+ Instantiating a configuration with the defaults will yield a configuration similar to that of the Aria
+ [rhymes-ai/Aria](https://huggingface.co/rhymes-ai/Aria) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ vision_config (`AriaVisionConfig` or `dict`, *optional*):
+ Configuration for the vision component.
+ vision_feature_layer (`int`, *optional*, defaults to -1):
+ The index of the layer to select the vision feature.
+ text_config (`AriaTextConfig` or `dict`, *optional*):
+ Configuration for the text component.
+ projector_patch_to_query_dict (`dict`, *optional*):
+ Mapping of patch sizes to query dimensions.
+ image_token_index (`int`, *optional*, defaults to 9):
+ Index used to represent image tokens.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated normal initializer for initializing all weight matrices.
+
+ Attributes:
+ model_type (`str`):
+ Type of the model, set to `"aria"`.
+ image_token_index (`int`):
+ Index used to represent image tokens.
+ projector_patch_to_query_dict (`dict`):
+ Mapping of patch sizes to query dimensions.
+ vision_config (`AriaVisionConfig`):
+ Configuration for the vision component.
+ text_config (`AriaTextConfig`):
+ Configuration for the text component.
+ """
+
+ model_type = "aria"
+ attribute_map = {
+ "image_token_id": "image_token_index",
+ }
+ sub_configs = {"text_config": AriaTextConfig, "vision_config": AutoConfig}
+
+ def __init__(
+ self,
+ vision_config=None,
+ vision_feature_layer: int = -1,
+ text_config: AriaTextConfig = None,
+ projector_patch_to_query_dict: Optional[dict] = None,
+ image_token_index: int = 9,
+ initializer_range: float = 0.02,
+ **kwargs,
+ ):
+ self.image_token_index = image_token_index
+
+ # Convert the keys and values of projector_patch_to_query_dict to integers
+ # This ensures consistency even if they were provided as strings
+ if projector_patch_to_query_dict is None:
+ projector_patch_to_query_dict = {
+ 1225: 128,
+ 4900: 256,
+ }
+ self.projector_patch_to_query_dict = {int(k): int(v) for k, v in projector_patch_to_query_dict.items()}
+ self.max_value_projector_patch_to_query_dict = max(self.projector_patch_to_query_dict.values())
+ self.vision_feature_layer = vision_feature_layer
+ if isinstance(vision_config, dict):
+ vision_config["model_type"] = "idefics3_vision"
+ vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config)
+ elif vision_config is None:
+ vision_config = CONFIG_MAPPING["idefics3_vision"]()
+
+ self.vision_config = vision_config
+ self.initializer_range = initializer_range
+
+ if isinstance(text_config, dict) and "model_type" in text_config:
+ text_config = AriaTextConfig(**text_config)
+ elif text_config is None:
+ text_config = AriaTextConfig()
+
+ self.text_config = text_config
+
+ super().__init__(**kwargs)
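+
+ # A minimal construction sketch (illustrative): sub-configs may be passed as plain dicts, and string
+ # keys of `projector_patch_to_query_dict` are cast to int.
+ #
+ #   config = AriaConfig(
+ #       text_config={"model_type": "aria_text", "hidden_size": 256, "num_hidden_layers": 2},
+ #       projector_patch_to_query_dict={"1225": 128, "4900": 256},
+ #   )
+ #   assert config.projector_patch_to_query_dict[1225] == 128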
+
+
+class AriaTextRMSNorm(LlamaRMSNorm):
+ pass
+
+
+class AriaProjectorMLP(nn.Module):
+ """
+ Feed-Forward Network module for the Aria Projector.
+
+ Args:
+ in_features (`int`):
+ Input embedding dimension.
+ hidden_features (`int`):
+ Hidden dimension of the feed-forward network.
+ output_dim (`int`):
+ Output dimension.
+ """
+
+ def __init__(self, in_features, hidden_features, output_dim):
+ super().__init__()
+ self.linear_in = nn.Linear(in_features, hidden_features, bias=False)
+ self.linear_out = nn.Linear(hidden_features, output_dim, bias=False)
+ self.act = ACT2FN["gelu_new"]
+
+ def forward(self, hidden_states):
+ hidden_states = self.act(self.linear_in(hidden_states))
+ hidden_states = self.linear_out(hidden_states)
+ return hidden_states
+
+
+class AriaCrossAttention(nn.Module):
+ """
+ Aria Cross-Attention module.
+
+ Args:
+ config (`AriaConfig`):
+ The configuration to use.
+ """
+
+ def __init__(self, config: AriaConfig, dropout_rate: float = 0):
+ super().__init__()
+ hidden_size = config.vision_config.hidden_size
+ num_heads = config.vision_config.num_attention_heads
+ self.num_heads = num_heads
+ self.q_proj = nn.Linear(hidden_size, hidden_size, bias=False)
+ self.k_proj = nn.Linear(hidden_size, hidden_size, bias=False)
+ self.v_proj = nn.Linear(hidden_size, hidden_size, bias=False)
+
+ # Original code here: https://github.com/rhymes-ai/Aria/blob/719ff4e52b727443cba3793b0e27fe64e0244fe1/aria/model/projector.py#L48
+ self.multihead_attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
+ self.linear = nn.Linear(hidden_size, hidden_size)
+ self.dropout = nn.Dropout(dropout_rate)
+
+ self.layer_norm = nn.LayerNorm(hidden_size)
+ self.layer_norm_kv = nn.LayerNorm(hidden_size)
+
+ def forward(self, key_value_states, hidden_states, attn_mask=None):
+ """
+ Forward pass of the AriaCrossAttention module.
+
+ Args:
+ key_value_states (`torch.Tensor`):
+ Input tensor for key and value.
+ hidden_states (`torch.Tensor`):
+ Input tensor for query.
+ attn_mask (`torch.Tensor`, *optional*, defaults to None):
+ Attention mask.
+
+ Returns:
+ torch.Tensor:
+ Output tensor after cross-attention.
+ """
+ query = self.q_proj(self.layer_norm(hidden_states))
+
+ key_value_states = self.layer_norm_kv(key_value_states)
+ key = self.k_proj(key_value_states)
+ value = self.v_proj(key_value_states)
+
+ attn_output, _ = self.multihead_attn(query, key, value, attn_mask=attn_mask)
+
+ attn_output = self.dropout(self.linear(attn_output))
+
+ return attn_output
+
+
+class AriaProjector(nn.Module):
+ """
+ Aria Projector module.
+
+ This module projects vision features into the language model's embedding space, enabling interaction between vision and language components.
+
+ Args:
+ config (`AriaConfig`):
+ Configuration object for the model.
+ """
+
+ def __init__(
+ self,
+ config: AriaConfig,
+ ):
+ super().__init__()
+
+ self.patch_to_query_dict = config.projector_patch_to_query_dict
+ self.in_features = config.vision_config.hidden_size
+ self.num_heads = config.vision_config.num_attention_heads
+ self.kv_dim = config.vision_config.hidden_size
+ self.hidden_features = config.text_config.hidden_size
+ self.output_dim = config.text_config.hidden_size
+
+ self.query = nn.Parameter(torch.zeros(config.max_value_projector_patch_to_query_dict, self.in_features))
+
+ self.cross_attn = AriaCrossAttention(config)
+
+ self.layer_norm = nn.LayerNorm(self.in_features)
+ self.feed_forward = AriaProjectorMLP(self.in_features, self.hidden_features, self.output_dim)
+
+ def forward(self, key_value_states: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
+ """
+ Forward pass of the Projector module.
+
+ Args:
+ key_value_states (`torch.Tensor`):
+ Input tensor of shape (batch_size, num_patches, kv_dim).
+ attn_mask (`torch.Tensor`, *optional*, defaults to `None`):
+ Attention mask.
+
+ Returns:
+ `torch.Tensor`: Output tensor of shape (batch_size, query_number, output_dim).
+ """
+ batch_size, num_patches = key_value_states.shape[0], key_value_states.shape[1]
+
+ if num_patches not in self.patch_to_query_dict:
+ raise KeyError(
+ f"Number of patches {num_patches} not found in patch_to_query_dict amongst possible values {self.patch_to_query_dict.keys()}."
+ )
+ query_num = self.patch_to_query_dict[num_patches]
+
+ queries = self.query[:query_num].unsqueeze(0).repeat(batch_size, 1, 1)
+
+ if attn_mask is not None:
+ attn_mask = attn_mask.repeat_interleave(self.num_heads, 0)
+ attn_mask = attn_mask.unsqueeze(1).expand(-1, queries.size(1), -1)
+
+ attention_out = self.cross_attn(key_value_states, queries, attn_mask=attn_mask)
+
+ out = self.feed_forward(self.layer_norm(attention_out))
+
+ return out
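+
+ # A minimal call sketch (illustrative; assumes the default patch_to_query_dict {1225: 128, 4900: 256}):
+ # 1225 vision patches per image are compressed into 128 learned queries in the text hidden size.
+ #
+ #   config = AriaConfig()
+ #   projector = AriaProjector(config)
+ #   patches = torch.randn(2, 1225, config.vision_config.hidden_size)
+ #   out = projector(patches)                 # (2, 128, config.text_config.hidden_size)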
+
+
+class AriaImageProcessor(BaseImageProcessor):
+ """
+ A vision processor for the Aria model that handles image preprocessing.
+ Initialize the AriaImageProcessor.
+
+ Args:
+ image_mean (`list`, *optional*, defaults to [0.5, 0.5, 0.5]):
+ Mean values for normalization.
+ image_std (`list`, *optional*, defaults to [0.5, 0.5, 0.5]):
+ Standard deviation values for normalization.
+ max_image_size (`int`, *optional*, defaults to 980):
+ Maximum image size.
+ min_image_size (`int`, *optional*, defaults to 336):
+ Minimum image size.
+ split_resolutions (`list`, *optional*, defaults to a list of optimal resolutions as tuples):
+ The optimal resolutions for splitting the image.
+ split_image (`bool`, *optional*, defaults to `False`):
+ Whether to split the image.
+ do_convert_rgb (`bool`, *optional*, defaults to `True`):
+ Whether to convert the image to RGB.
+ do_rescale (`bool`, *optional*, defaults to `True`):
+ Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
+ the `preprocess` method.
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
+ Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
+ method.
+ do_normalize (`bool`, *optional*, defaults to `True`):
+ Whether to normalize the image.
+ resample (PILImageResampling, *optional*, defaults to `BICUBIC`):
+ The resampling filter to use if resizing the image.
+ """
+
+ model_input_names = ["pixel_values", "pixel_mask", "num_crops"]
+
+ def __init__(
+ self,
+ image_mean: Optional[list[float]] = None,
+ image_std: Optional[list[float]] = None,
+ max_image_size: int = 980,
+ min_image_size: int = 336,
+ split_resolutions: Optional[list[tuple[int, int]]] = None,
+ split_image: Optional[bool] = False,
+ do_convert_rgb: Optional[bool] = True,
+ do_rescale: bool = True,
+ rescale_factor: Union[int, float] = 1 / 255,
+ do_normalize: Optional[bool] = True,
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ if image_mean is None:
+ image_mean = [0.5, 0.5, 0.5]
+ if image_std is None:
+ image_std = [0.5, 0.5, 0.5]
+ self.max_image_size = max_image_size
+ self.min_image_size = min_image_size
+ self.image_mean = image_mean
+ self.image_std = image_std
+ self.split_image = split_image
+ if split_resolutions is None:
+ split_resolutions = [(1, 2), (1, 3), (1, 4), (1, 5), (1, 6), (1, 7), (1, 8), (2, 4), (2, 3), (2, 2), (2, 1), (3, 1), (3, 2), (4, 1), (4, 2), (5, 1), (6, 1), (7, 1), (8, 1)] # fmt: skip
+ split_resolutions = [(el[0] * 490, el[1] * 490) for el in split_resolutions]
+ self.split_resolutions = split_resolutions
+ self.do_convert_rgb = do_convert_rgb
+ self.do_rescale = do_rescale
+ self.rescale_factor = rescale_factor
+ self.do_normalize = do_normalize
+ self.resample = resample
+
+ def preprocess(
+ self,
+ images: Union[ImageInput, list[ImageInput]],
+ image_mean: Optional[Union[float, list[float]]] = None,
+ image_std: Optional[Union[float, list[float]]] = None,
+ max_image_size: Optional[int] = None,
+ min_image_size: Optional[int] = None,
+ split_image: Optional[bool] = None,
+ do_convert_rgb: Optional[bool] = None,
+ do_rescale: Optional[bool] = None,
+ rescale_factor: Optional[float] = None,
+ do_normalize: Optional[bool] = None,
+ resample: Optional[PILImageResampling] = None,
+ return_tensors: Optional[Union[str, TensorType]] = "pt",
+ data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ ):
+ """
+ Process a list of images.
+
+ Args:
+ images (ImageInput or list of ImageInput):
+ The input image or a list of images.
+ image_mean (`list`, *optional*, defaults to [0.5, 0.5, 0.5]):
+ Mean values for normalization.
+ image_std (`list`, *optional*, defaults to [0.5, 0.5, 0.5]):
+ Standard deviation values for normalization.
+ max_image_size (`int`, *optional*, defaults to `self.max_image_size` (980)):
+ Maximum image size.
+ min_image_size (`int`, *optional*, defaults to `self.min_image_size` (336)):
+ Minimum image size.
+ split_image (`bool`, *optional*, defaults to `self.split_image` (False)):
+ Whether to split the image.
+ do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb` (True)):
+ Whether to convert the image to RGB.
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+ Whether to rescale the image.
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize` (True)):
+ Whether to normalize the image.
+ resample (PILImageResampling, *optional*, defaults to `self.resample` (BICUBIC)):
+ The resampling filter to use if resizing the image.
+ return_tensors (`str` or `TensorType`, *optional*, defaults to "pt"):
+ The type of tensor to return.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format for the output image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`:
+ image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`:
+ image in (height, width, num_channels) format.
+ If unset, will use same as the input image.
+ input_data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format for the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`:
+ image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`:
+ image in (height, width, num_channels) format.
+ If unset, will use the inferred format of the input image.
+
+ Returns:
+ BatchFeature:
+ A BatchFeature object containing:
+ - 'pixel_values':
+ Tensor of processed image pixel values.
+ - 'pixel_mask':
+ Boolean pixel mask. This mask is a 2D tensor of shape (max_image_size, max_image_size) where:
+ - True (1) values indicate pixels that belong to the original resized image.
+ - False (0) values indicate pixels that are part of the padding.
+ The mask helps distinguish between actual image content and padded areas in subsequent processing steps.
+ - 'num_crops':
+ The maximum number of crops across all images.
+ """
+ image_mean = image_mean if image_mean is not None else self.image_mean
+ image_std = image_std if image_std is not None else self.image_std
+ max_image_size = max_image_size if max_image_size is not None else self.max_image_size
+ min_image_size = min_image_size if min_image_size is not None else self.min_image_size
+ split_image = split_image if split_image is not None else self.split_image
+ do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
+ resample = resample if resample is not None else self.resample
+
+ if max_image_size not in [490, 980]:
+ raise ValueError("max_image_size must be either 490 or 980")
+
+ images = self.fetch_images(images)
+ images = make_flat_list_of_images(images)
+
+ if not valid_images(images):
+ raise ValueError(
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
+ "torch.Tensor, tf.Tensor or jax.ndarray."
+ )
+
+ validate_preprocess_arguments(
+ do_normalize=do_normalize,
+ image_mean=image_mean,
+ image_std=image_std,
+ resample=resample,
+ do_rescale=do_rescale,
+ rescale_factor=rescale_factor,
+ )
+
+ if do_convert_rgb:
+ images = [convert_to_rgb(image) for image in images]
+
+ # All transformations expect numpy arrays.
+ images = [to_numpy_array(image) for image in images]
+
+ if do_rescale and is_scaled_image(images[0]):
+ logger.warning_once(
+ "It looks like you are trying to rescale already rescaled images. If the input"
+ " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
+ )
+
+ if input_data_format is None:
+ # We assume that all images have the same channel dimension format.
+ input_data_format = infer_channel_dimension_format(images[0])
+
+ pixel_values = []
+ pixel_masks = []
+ num_crops = None
+
+ for image in images:
+ if split_image:
+ crop_images = self.get_image_patches(
+ image,
+ self.split_resolutions,
+ max_image_size,
+ resample,
+ data_format=input_data_format,
+ input_data_format=input_data_format,
+ )
+ else:
+ crop_images = [image]
+ if num_crops is None or len(crop_images) > num_crops:
+ num_crops = len(crop_images)
+
+ for crop_image in crop_images:
+ # Compute the rescaling factor that brings the image's larger dimension to max_image_size
+ h, w = get_image_size(crop_image)
+ scale = max_image_size / max(h, w)
+ if w >= h:
+ new_size = (max(int(h * scale), min_image_size), max_image_size) # h, w
+ else:
+ new_size = (max_image_size, max(int(w * scale), min_image_size)) # h, w
+
+ crop_image_resized = resize(
+ crop_image,
+ new_size,
+ resample=resample,
+ data_format=input_data_format,
+ input_data_format=input_data_format,
+ )
+
+ padding_bottom, padding_right = max_image_size - new_size[0], max_image_size - new_size[1]
+ crop_image_padded = pad(
+ crop_image_resized,
+ ((0, padding_bottom), (0, padding_right)),
+ data_format=input_data_format,
+ input_data_format=input_data_format,
+ )
+
+ # Create a pixel mask
+ pixel_mask = np.zeros((max_image_size, max_image_size), dtype=bool)
+ pixel_mask[: new_size[0], : new_size[1]] = 1
+ pixel_masks.append(pixel_mask)
+
+ if do_rescale:
+ crop_image_padded = self.rescale(
+ image=crop_image_padded, scale=rescale_factor, input_data_format=input_data_format
+ )
+
+ if do_normalize:
+ crop_image_padded = self.normalize(
+ crop_image_padded,
+ self.image_mean,
+ self.image_std,
+ data_format=input_data_format,
+ input_data_format=input_data_format,
+ )
+ crop_image_padded = (
+ to_channel_dimension_format(crop_image_padded, data_format, input_data_format)
+ if data_format is not None
+ else crop_image_padded
+ )
+
+ pixel_values.append(crop_image_padded)
+ return BatchFeature(
+ data={
+ "pixel_values": np.stack(pixel_values, axis=0),
+ "pixel_mask": np.stack(pixel_masks, axis=0),
+ "num_crops": num_crops,
+ },
+ tensor_type=return_tensors,
+ )
+
+ def _resize_for_patching(
+ self, image: np.ndarray, target_resolution: tuple, resample, input_data_format: ChannelDimension
+ ) -> np.ndarray:
+ """
+ Resizes an image to a target resolution while maintaining aspect ratio.
+
+ Args:
+ image (np.ndarray):
+ The input image.
+ target_resolution (tuple):
+ The target resolution (height, width) of the image.
+ resample (`PILImageResampling`):
+ Resampling filter to use if resizing the image.
+ input_data_format (`ChannelDimension` or `str`):
+ The channel dimension format of the input image.
+
+ Returns:
+ np.ndarray: The resized and padded image.
+ """
+ new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)
+
+ # Resize the image
+ resized_image = resize(image, (new_height, new_width), resample=resample, input_data_format=input_data_format)
+
+ return resized_image
+
+ def _get_padding_size(self, original_resolution: tuple, target_resolution: tuple):
+ original_height, original_width = original_resolution
+ target_height, target_width = target_resolution
+ paste_x, r_x = divmod(target_width - original_width, 2)
+ paste_y, r_y = divmod(target_height - original_height, 2)
+ return (paste_y, paste_y + r_y), (paste_x, paste_x + r_x)
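+
+ # A worked example (illustrative): padding a 3x5 image up to an 8x8 target centers it and puts the
+ # odd leftover pixel after the image:
+ #
+ #   self._get_padding_size((3, 5), (8, 8))   # -> ((2, 3), (1, 2))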
+
+ def _pad_for_patching(
+ self, image: np.ndarray, target_resolution: tuple, input_data_format: ChannelDimension
+ ) -> np.ndarray:
+ """
+ Pad an image to a target resolution while maintaining aspect ratio.
+ """
+ new_resolution = get_patch_output_size(image, target_resolution, input_data_format)
+ padding = self._get_padding_size(new_resolution, target_resolution)
+
+ padded_image = self.pad(image, padding=padding)
+
+ return padded_image
+
+ def pad(
+ self,
+ image: np.ndarray,
+ padding: Union[int, tuple[int, int], Iterable[tuple[int, int]]],
+ mode: PaddingMode = PaddingMode.CONSTANT,
+ constant_values: Union[float, Iterable[float]] = 0.0,
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ ) -> np.ndarray:
+ """
+ Pads the `image` with the specified `padding` and `mode`. Padding can be in the (`height`, `width`)
+ dimensions or in the (`num_patches`) dimension. In the second case an iterable of tuples is expected
+ as input.
+
+ Args:
+ image (`np.ndarray`):
+ The image to pad.
+ padding (`int` or `tuple[int, int]` or `Iterable[tuple[int, int]]`):
+ Padding to apply to the edges of the height, width axes. Can be one of three formats:
+ - `((before_height, after_height), (before_width, after_width))` unique pad widths for each axis.
+ - `((before, after),)` yields same before and after pad for height and width.
+ - `(pad,)` or int is a shortcut for before = after = pad width for all axes.
+ mode (`PaddingMode`):
+ The padding mode to use. Can be one of:
+ - `"constant"`: pads with a constant value.
+ - `"reflect"`: pads with the reflection of the vector mirrored on the first and last values of the
+ vector along each axis.
+ - `"replicate"`: pads with the replication of the last value on the edge of the array along each axis.
+ - `"symmetric"`: pads with the reflection of the vector mirrored along the edge of the array.
+ constant_values (`float` or `Iterable[float]`, *optional*):
+ The value to use for the padding if `mode` is `"constant"`.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format for the output image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ If unset, will use same as the input image.
+ input_data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format for the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ If unset, will use the inferred format of the input image.
+
+ Returns:
+ `np.ndarray`: The padded image.
+
+ """
+
+ # Call the general `pad` if padding on `height`/`width`; otherwise pad along the `num_patches` dim
+ if isinstance(padding, int) or len(padding) != 4:
+ return pad(image, padding, mode, constant_values, data_format, input_data_format)
+
+ if input_data_format is None:
+ input_data_format = infer_channel_dimension_format(image)
+
+ padding_mode_mapping = {
+ PaddingMode.CONSTANT: "constant",
+ PaddingMode.REFLECT: "reflect",
+ PaddingMode.REPLICATE: "edge",
+ PaddingMode.SYMMETRIC: "symmetric",
+ }
+ image = np.pad(image, padding, mode=padding_mode_mapping[mode], constant_values=constant_values)
+ image = (
+ to_channel_dimension_format(image, data_format, input_data_format) if data_format is not None else image
+ )
+ return image
+
+ def get_image_patches(
+ self,
+ image: np.ndarray,
+ grid_pinpoints: list[tuple[int, int]],
+ patch_size: int,
+ resample: PILImageResampling,
+ data_format: ChannelDimension,
+ input_data_format: ChannelDimension,
+ ) -> list[np.ndarray]:
+ """
+ Process an image with variable resolutions by dividing it into patches.
+
+ Args:
+ image (`np.ndarray`):
+ The input image to be processed.
+ grid_pinpoints (list[tuple[int, int]]):
+ A list of possible resolutions as tuples.
+ patch_size (`int`):
+ Size of the patches to divide the image into.
+ resample (`PILImageResampling`):
+ Resampling filter to use if resizing the image.
+ data_format (`ChannelDimension` or `str`):
+ The channel dimension format for the output image.
+ input_data_format (`ChannelDimension` or `str`):
+ The channel dimension format of the input image.
+
+ Returns:
+ `list[np.ndarray]`: A list of NumPy arrays containing the processed image patches.
+ """
+ if not isinstance(grid_pinpoints, list):
+ raise TypeError("grid_pinpoints must be a list of possible resolutions.")
+
+ possible_resolutions = grid_pinpoints
+
+ image_size = get_image_size(image, channel_dim=input_data_format)
+ best_resolution = select_best_resolution(image_size, possible_resolutions)
+ resized_image = self._resize_for_patching(
+ image, best_resolution, resample=resample, input_data_format=input_data_format
+ )
+ padded_image = self._pad_for_patching(resized_image, best_resolution, input_data_format=input_data_format)
+
+ patches = divide_to_patches(padded_image, patch_size=patch_size, input_data_format=input_data_format)
+
+ # make sure that all patches are in the input data format
+ patches = [
+ to_channel_dimension_format(patch, channel_dim=data_format, input_channel_dim=input_data_format)
+ for patch in patches
+ ]
+ return patches
+
+ def get_number_of_image_patches(self, height: int, width: int, images_kwargs=None):
+ """
+ A utility that returns the number of image patches for a given image size.
+
+ Args:
+ height (`int`):
+ Height of the input image.
+ width (`int`):
+ Width of the input image.
+ images_kwargs (`dict`, *optional*):
+ Any kwargs to override defaults of the image processor.
+ Returns:
+ `int`: Number of patches per image.
+ """
+ images_kwargs = images_kwargs if images_kwargs is not None else {}
+ split_image = images_kwargs.get("split_image", self.split_image)
+ max_image_size = images_kwargs.get("max_image_size", self.max_image_size)
+
+ resized_height, resized_width = select_best_resolution((height, width), self.split_resolutions)
+ num_patches = 1 if not split_image else resized_height // max_image_size * resized_width // max_image_size
+ return num_patches
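+
+ # A usage sketch (illustrative numbers): when splitting is enabled, the image is first matched to the
+ # best split resolution and the patch count is (resized_height // max_image_size) * (resized_width // max_image_size).
+ #
+ #   processor = AriaImageProcessor()
+ #   processor.get_number_of_image_patches(800, 1300, {"split_image": True, "max_image_size": 490})
+ #   processor.get_number_of_image_patches(800, 1300, {"split_image": False})   # always 1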
+
+
+class AriaProcessorKwargs(ProcessingKwargs, total=False):
+ _defaults = {
+ "text_kwargs": {
+ "padding": False,
+ "return_mm_token_type_ids": False,
+ },
+ "images_kwargs": {
+ "max_image_size": 980,
+ "split_image": False,
+ },
+ "return_tensors": TensorType.PYTORCH,
+ }
+
+
+class AriaProcessor(ProcessorMixin):
+ """
+ AriaProcessor is a processor for the Aria model which wraps the Aria image preprocessor and the Llama slow tokenizer.
+
+ Args:
+ image_processor (`AriaImageProcessor`, *optional*):
+ The AriaImageProcessor to use for image preprocessing.
+ tokenizer (`PreTrainedTokenizerBase`, *optional*):
+ An instance of [`PreTrainedTokenizerBase`]. This should correspond with the model's text model. The tokenizer is a required input.
+ chat_template (`str`, *optional*):
+ A Jinja template which will be used to convert lists of messages in a chat into a tokenizable string.
+ size_conversion (`Dict`, *optional*):
+ A dictionary indicating size conversions for images.
+ """
+
+ attributes = ["image_processor", "tokenizer"]
+ image_processor_class = "AriaImageProcessor"
+ tokenizer_class = "AutoTokenizer"
+
+ def __init__(
+ self,
+ image_processor=None,
+ tokenizer: Union[AutoTokenizer, str] = None,
+ chat_template: Optional[str] = None,
+ size_conversion: Optional[dict[Union[float, int], int]] = None,
+ ):
+ if size_conversion is None:
+ size_conversion = {490: 128, 980: 256}
+ self.size_conversion = {int(k): v for k, v in size_conversion.items()}
+
+ self.image_token = tokenizer.image_token
+ self.image_token_id = tokenizer.image_token_id
+ if tokenizer is not None and tokenizer.pad_token is None:
+ tokenizer.pad_token = tokenizer.unk_token
+
+ super().__init__(image_processor, tokenizer, chat_template=chat_template)
+
+ def __call__(
+ self,
+ text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]],
+ images: Optional[ImageInput] = None,
+ audio=None,
+ videos=None,
+ **kwargs: Unpack[AriaProcessorKwargs],
+ ) -> BatchFeature:
+ """
+ Main method to prepare one or several sequence(s) and image(s) for the model.
+
+ Args:
+ text (`TextInput`, `PreTokenizedInput`, `list[TextInput]`, `list[PreTokenizedInput]`):
+ The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
+ (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
+ `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
+ images (`ImageInput`):
+ The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
+ tensor. Both channels-first and channels-last formats are supported.
+
+
+ Returns:
+ [`BatchFeature`]: A [`BatchFeature`] with the following fields:
+ - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
+ - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
+ `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
+ `None`).
+ - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
+ - **pixel_mask** -- Pixel mask to be fed to a model. Returned when `images` is not `None`.
+ """
+ output_kwargs = self._merge_kwargs(
+ AriaProcessorKwargs,
+ tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+ **kwargs,
+ )
+
+ if isinstance(text, str):
+ text = [text]
+ elif not isinstance(text, list) and not isinstance(text[0], str):
+ raise TypeError("Invalid input text. Please provide a string, or a list of strings")
+
+ if images is not None:
+ image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
+ # expand the image_token according to the num_crops and tokens per image
+ tokens_per_image = self.size_conversion[image_inputs.pixel_values.shape[2]]
+ prompt_strings = []
+ num_crops = image_inputs.pop("num_crops") * tokens_per_image
+ for sample in text:
+ sample = sample.replace(self.tokenizer.image_token, self.tokenizer.image_token * num_crops)
+ prompt_strings.append(sample)
+
+ else:
+ image_inputs = {}
+ prompt_strings = text
+
+ return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
+ return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False)
+ text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"], return_tensors=None)
+ self._check_special_mm_tokens(prompt_strings, text_inputs, modalities=["image"])
+
+ if return_mm_token_type_ids:
+ array_ids = np.array(text_inputs["input_ids"])
+ mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
+ mm_token_type_ids[array_ids == self.image_token_id] = 1
+ text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()
+
+ return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors)
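+
+ # An expansion sketch (illustrative; `prompt_with_image_token` and `image` are placeholders): with
+ # max_image_size=980 each crop maps to size_conversion[980] = 256 placeholder tokens, so an image
+ # split into 2 crops expands a single image token in the prompt into 2 * 256 repeated image tokens
+ # before tokenization.
+ #
+ #   inputs = processor(text=prompt_with_image_token, images=[image], max_image_size=980)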
+
+ def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
+ """
+ Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
+ Args:
+ image_sizes (`list[list[int]]`, *optional*):
+ The input sizes formatted as (height, width) per each image.
+ Returns:
+ `MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
+ input modalities, along with other useful data.
+ """
+
+ vision_data = {}
+ if image_sizes is not None:
+ images_kwargs = AriaProcessorKwargs._defaults.get("images_kwargs", {})
+ images_kwargs.update(kwargs)
+
+ max_size = images_kwargs.get("max_image_size", None) or self.image_processor.max_image_size
+ num_image_patches = [
+ self.image_processor.get_number_of_image_patches(*image_size, images_kwargs)
+ for image_size in image_sizes
+ ]
+ num_image_tokens = [self.size_conversion[max_size] * num_patches for num_patches in num_image_patches]
+ vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
+
+ return MultiModalData(**vision_data)
+
+ @property
+ def model_input_names(self):
+ tokenizer_input_names = self.tokenizer.model_input_names
+ image_processor_input_names = self.image_processor.model_input_names
+
+ # Remove `num_crops`: it is popped and used only during processing. Build a new list when removing it,
+ # otherwise `self.image_processor.model_input_names` would also be modified
+ image_processor_input_names = [name for name in image_processor_input_names if name != "num_crops"]
+ return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
+
+
+class AriaSharedExpertsMLP(LlamaMLP):
+ """
+ MLP used by the shared experts.
+
+ Unlike routed experts, shared experts process all tokens without routing.
+ This class reconfigures the intermediate size in comparison to the LlamaMLP.
+
+ Args:
+ config (`AriaTextConfig`): Configuration object for the Aria language model.
+ """
+
+ def __init__(self, config: AriaTextConfig):
+ super().__init__(config)
+ self.intermediate_size = config.intermediate_size * config.moe_num_shared_experts
+
+
+class AriaGroupedExpertsGemm(nn.Module):
+ """
+ Grouped GEMM (General Matrix Multiplication) module for expert computation.
+ This implementation computes each expert's GEMM sequentially via `sequential_experts_gemm`,
+ which is slower than the optimized grouped_gemm library
+ (https://github.com/fanshiqing/grouped_gemm) but avoids the extra dependency.
+
+ Args:
+ in_features (`int`):
+ Number of input features.
+ out_features (`int`):
+ Number of output features.
+ groups (`int`):
+ Number of expert groups.
+ """
+
+ def __init__(self, in_features, out_features, groups):
+ super().__init__()
+ self.in_features = in_features
+ self.out_features = out_features
+ self.groups = groups
+ self.weight = nn.Parameter(torch.empty(groups, in_features, out_features))
+
+ def forward(self, input, tokens_per_expert):
+ """
+ Perform grouped matrix multiplication.
+
+ Args:
+ input (`torch.Tensor`):
+ Input tensor of shape (num_tokens, in_features).
+ tokens_per_expert (`torch.Tensor`):
+ Number of tokens assigned to each expert.
+
+ Returns:
+ torch.Tensor: Output tensor of shape (num_tokens, out_features).
+ """
+ return sequential_experts_gemm(
+ input,
+ self.weight,
+ tokens_per_expert.cpu(),
+ )
+
+
+class AriaGroupedExpertsMLP(nn.Module):
+ """
+ Grouped MLP module for Mixture of Experts.
+
+ Args:
+ config (`AriaTextConfig`):
+ Configuration object for the model.
+ """
+
+ def __init__(self, config: AriaTextConfig) -> None:
+ super().__init__()
+ self.config = config
+ self.fc1 = AriaGroupedExpertsGemm(config.hidden_size, config.intermediate_size * 2, config.moe_num_experts)
+ self.fc2 = AriaGroupedExpertsGemm(config.intermediate_size, config.hidden_size, config.moe_num_experts)
+
+ def forward(self, permuted_tokens, tokens_per_expert):
+ """
+ Forward pass of the Grouped MLP.
+
+ Args:
+ permuted_tokens (torch.Tensor): Permuted input tokens.
+ tokens_per_expert (torch.Tensor): Number of tokens assigned to each expert.
+
+ Returns:
+ torch.Tensor: Output tensor after passing through the MLP.
+ """
+ fc1_output = self.fc1(permuted_tokens, tokens_per_expert)
+ projection, gate = torch.chunk(fc1_output, 2, dim=-1)
+ fc1_output = nn.functional.silu(projection) * gate
+ fc2_output = self.fc2(fc1_output, tokens_per_expert)
+ return fc2_output
+
+
+# Token permutation adapted from https://github.com/NVIDIA/Megatron-LM/blob/54f1f78529cbc2b9cddad313e7f9d96ac0420a27/megatron/core/transformer/moe/token_dispatcher.py#L291-L587
+class AriaTextMoELayer(nn.Module):
+ """
+ Aria Text Mixture of Experts (MoE) Layer.
+
+ This layer applies a gating mechanism to route input tokens to different experts.
+
+ Args:
+ config (`AriaTextConfig`):
+ Configuration object for the text component of the model.
+ """
+
+ def __init__(self, config: AriaTextConfig):
+ super().__init__()
+
+ self.router = nn.Linear(config.hidden_size, config.moe_num_experts, bias=False)
+ self.experts = AriaGroupedExpertsMLP(config)
+ self.shared_experts = AriaSharedExpertsMLP(config)
+ self.config = config
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ """
+ Forward pass of the MoE Layer.
+
+ Args:
+ hidden_states (`torch.Tensor`):
+ Input tensor of shape (batch_size, sequence_length, hidden_size).
+
+ Returns:
+ torch.Tensor: Output tensor after passing through the MoE layer.
+
+ Process:
+ 1. Route tokens to experts using the router.
+ 2. Permute tokens based on routing decisions.
+ 3. Process tokens through experts.
+ 4. Unpermute and combine expert outputs.
+ 5. Add shared expert output to the final result.
+ """
+ original_shape = hidden_states.shape
+ hidden_states = hidden_states.view(-1, hidden_states.size(-1))
+
+ # Top K Routing
+ logits = self.router(hidden_states)
+ top_logits, top_indices = torch.topk(logits, k=self.config.moe_topk, dim=1)
+ scores = nn.functional.softmax(top_logits, dim=-1)
+
+ original_dtype = top_indices.dtype
+
+ tokens_per_expert = torch.histc(
+ top_indices.flatten().to(torch.float32),
+ bins=self.config.moe_num_experts,
+ min=0,
+ max=self.config.moe_num_experts - 1,
+ ).to(original_dtype)
+ indices = top_indices
+
+ # Token permutation
+ flatten_indices = indices.view(-1)
+ sorted_indices = torch.argsort(flatten_indices)
+ permuted_tokens = hidden_states.index_select(0, sorted_indices // self.config.moe_topk)
+
+ # Process through experts
+ expert_output = self.experts(permuted_tokens, tokens_per_expert)
+
+ # Token unpermutation
+ unpermuted_tokens = torch.zeros(
+ (scores.shape[0] * self.config.moe_topk, expert_output.size(1)),
+ dtype=expert_output.dtype,
+ device=expert_output.device,
+ )
+ unpermuted_tokens.index_copy_(0, sorted_indices, expert_output)
+ unpermuted_tokens = unpermuted_tokens.view(-1, self.config.moe_topk, expert_output.size(1))
+
+ output = (unpermuted_tokens * scores.unsqueeze(-1)).sum(dim=1).view(original_shape)
+
+ # Add shared expert output
+ shared_expert_output = self.shared_experts(hidden_states.view(original_shape))
+ return output + shared_expert_output
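+
+ # A routing sketch (illustrative numbers): with moe_num_experts=4 and moe_topk=2, each token picks its
+ # two highest-scoring experts, and `tokens_per_expert` is the per-expert load fed to the grouped GEMM.
+ #
+ #   logits = torch.tensor([[2.0, 1.0, 0.1, 0.0],    # token 0 -> experts 0 and 1
+ #                          [0.0, 0.1, 1.0, 2.0]])   # token 1 -> experts 3 and 2
+ #   top_logits, top_indices = torch.topk(logits, k=2, dim=1)
+ #   scores = torch.softmax(top_logits, dim=-1)      # mixing weights for the two picks
+ #   tokens_per_expert = torch.histc(top_indices.float(), bins=4, min=0, max=3)   # tensor([1., 1., 1., 1.])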
+
+
+class AriaTextAttention(LlamaAttention):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ pass
+
+
+class AriaTextDecoderLayer(LlamaDecoderLayer):
+ """
+ Aria Text Decoder Layer.
+
+ This class defines a single decoder layer in the language model, incorporating self-attention and Mixture of Experts (MoE) feed-forward network.
+
+ Args:
+ config (`AriaTextConfig`):
+ Configuration object for the text component of the model.
+ layer_idx (`int`):
+ Index of the layer.
+ """
+
+ def __init__(self, config: AriaTextConfig, layer_idx: int):
+ super().__init__(config, layer_idx)
+ self.mlp = AriaTextMoELayer(config)
+
+
+@auto_docstring
+class AriaTextPreTrainedModel(PreTrainedModel):
+ config: AriaTextConfig
+ base_model_prefix = "model"
+ _no_split_modules = ["AriaTextDecoderLayer", "AriaGroupedExpertsGemm"]
+ supports_gradient_checkpointing = True
+ _skip_keys_device_placement = "past_key_values"
+ _supports_flash_attn = True
+ _supports_sdpa = True
+
+ _supports_attention_backend = True
+ _can_record_outputs = {
+ "hidden_states": AriaTextDecoderLayer,
+ "attentions": AriaTextAttention,
+ }
+
+ def _init_weights(self, module):
+ super()._init_weights(module)
+ if isinstance(module, AriaGroupedExpertsGemm):
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+
+
+class AriaPreTrainedModel(LlamaPreTrainedModel):
+ config: AriaConfig
+ base_model_prefix = ""
+ _can_compile_fullgraph = False # MoE models don't work with torch.compile (dynamic slicing)
+ _supports_attention_backend = True
+
+ def _init_weights(self, module):
+ PreTrainedModel._init_weights(self, module)
+ if isinstance(module, AriaProjector):
+ nn.init.trunc_normal_(module.query, std=self.config.initializer_range)
+
+
+class AriaTextModel(LlamaModel):
+ def __init__(self, config: AriaTextConfig):
+ super().__init__(config)
+ self.layers = nn.ModuleList(
+ [AriaTextDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+ self.gradient_checkpointing = False
+ self.post_init()
+
+
+class AriaTextForCausalLM(AriaTextPreTrainedModel, LlamaForCausalLM):
+ _tied_weights_keys = ["lm_head.weight"]
+
+ def __init__(self, config: AriaTextConfig):
+ super().__init__(config)
+ self.model = AriaTextModel(config)
+ self.vocab_size = config.vocab_size
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(self, **super_kwargs):
+ return super().forward(**super_kwargs)
+
+
+class AriaCausalLMOutputWithPast(LlavaCausalLMOutputWithPast):
+ pass
+
+
+class AriaModelOutputWithPast(LlavaModelOutputWithPast):
+ pass
+
+
+class AriaModel(LlavaModel):
+ def __init__(self, config: AriaConfig):
+ super().__init__(config)
+ self.multi_modal_projector = AriaProjector(config)
+
+ def _create_patch_attention_mask(self, pixel_mask):
+ if pixel_mask is None:
+ return None
+
+ patches_subgrid = pixel_mask.unfold(
+ dimension=1,
+ size=self.vision_tower.config.patch_size,
+ step=self.vision_tower.config.patch_size,
+ )
+ patches_subgrid = patches_subgrid.unfold(
+ dimension=2,
+ size=self.vision_tower.config.patch_size,
+ step=self.vision_tower.config.patch_size,
+ )
+ return (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
+
+ def get_image_features(
+ self,
+ pixel_values: torch.FloatTensor,
+ pixel_mask: Optional[torch.FloatTensor] = None,
+ vision_feature_layer: int = -1,
+ ):
+ """
+ Obtains image last hidden states from the vision tower and applies multimodal projection.
+
+ Args:
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`):
+ The tensors corresponding to the input images.
+ pixel_mask (`torch.FloatTensor`, *optional*):
+ The tensors corresponding to the input image mask.
+ vision_feature_layer (`Union[int, list[int]]`, *optional*):
+ The index of the layer to select the vision feature. If multiple indices are provided,
+ the vision feature of the corresponding indices will be concatenated to form the
+ vision features.
+ Returns:
+ image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`.
+ """
+ vision_feature_layer = (
+ vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
+ )
+ patch_attention_mask = self._create_patch_attention_mask(pixel_mask)
+ image_outputs = self.vision_tower(
+ pixel_values, patch_attention_mask=patch_attention_mask, output_hidden_states=True
+ )
+ image_attn_mask = None
+ if patch_attention_mask is not None:
+ flattened_mask = patch_attention_mask.flatten(1)
+ image_attn_mask = torch.logical_not(flattened_mask)
+
+ selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
+ image_features = self.multi_modal_projector(selected_image_feature, attn_mask=image_attn_mask)
+ return image_features
+
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ pixel_mask: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> Union[tuple, AriaModelOutputWithPast]:
+ if inputs_embeds is None:
+ inputs_embeds = self.get_input_embeddings()(input_ids)
+
+ # 2. Merge text and images
+ if pixel_values is not None and inputs_embeds.shape[1] != 1:
+ image_features = self.get_image_features(
+ pixel_values=pixel_values,
+ pixel_mask=pixel_mask,
+ vision_feature_layer=self.config.vision_feature_layer,
+ )
+ image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+ special_image_mask = self.get_placeholder_mask(
+ input_ids, inputs_embeds=inputs_embeds, image_features=image_features
+ )
+ inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
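+ # `masked_scatter` splices the projected image features into the text embeddings at the
+ # placeholder-token positions. A minimal sketch (toy sizes, illustrative only):
+ #
+ #     >>> import torch
+ #     >>> embeds = torch.zeros(1, 4, 2)  # (batch, seq, hidden)
+ #     >>> mask = torch.tensor([[False, True, True, False]]).unsqueeze(-1).expand_as(embeds)
+ #     >>> embeds.masked_scatter(mask, torch.ones(2, 2))
+ #     tensor([[[0., 0.],
+ #              [1., 1.],
+ #              [1., 1.],
+ #              [0., 0.]]])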
+
+ outputs = self.language_model(
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ return AriaModelOutputWithPast(
+ last_hidden_state=outputs.last_hidden_state,
+ past_key_values=outputs.past_key_values if use_cache else None,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ image_hidden_states=image_features if pixel_values is not None else None,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ Aria model for conditional generation tasks.
+
+ This model combines a vision tower, a multi-modal projector, and a language model
+ to perform tasks that involve both image and text inputs.
+ """
+)
+class AriaForConditionalGeneration(LlavaForConditionalGeneration):
+ def get_image_features(
+ self,
+ pixel_values: torch.FloatTensor,
+ pixel_mask: Optional[torch.FloatTensor] = None,
+ vision_feature_layer: int = -1,
+ ):
+ return self.model.get_image_features(
+ pixel_values=pixel_values,
+ pixel_mask=pixel_mask,
+ vision_feature_layer=vision_feature_layer,
+ )
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ pixel_mask: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> Union[tuple, AriaCausalLMOutputWithPast]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or `model.image_token_id` (where `model` is your instance of `AriaForConditionalGeneration`).
+ Tokens with indices set to `model.image_token_id` are ignored (masked), the loss is only
+ computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Example:
+
+ ```python
+ >>> import requests
+ >>> import torch
+ >>> from PIL import Image
+ >>> from io import BytesIO
+
+ >>> from transformers import AutoProcessor, AutoModel
+ >>> from transformers.image_utils import load_image
+
+ >>> # Note that passing the image urls (instead of the actual pil images) to the processor is also possible
+ >>> image1 = load_image("https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg")
+ >>> image2 = load_image("https://cdn.britannica.com/59/94459-050-DBA42467/Skyline-Chicago.jpg")
+ >>> image3 = load_image("https://cdn.britannica.com/68/170868-050-8DDE8263/Golden-Gate-Bridge-San-Francisco.jpg")
+
+ >>> processor = AutoProcessor.from_pretrained("Rhymes-AI/Aria")
+ >>> model = AutoModel.from_pretrained("Rhymes-AI/Aria", dtype=torch.bfloat16, device_map="auto")
+
+ >>> # Create inputs
+ >>> messages = [
+ ... {
+ ... "role": "user",
+ ... "content": [
+ ... {"type": "image"},
+ ... {"type": "text", "text": "In this image, we can see the city of New York, and more specifically the Statue of Liberty."},
+ ... {"type": "image"},
+ ... {"type": "text", "text": "What can we see in this image?"},
+ ... ]
+ ... },
+ ... {
+ ... "role": "user",
+ ... "content": [
+ ... {"type": "image"},
+ ... {"type": "text", "text": "In which city is that bridge located?"},
+ ... ]
+ ... }
+ ... ]
+
+ >>> prompts = [processor.apply_chat_template([message], add_generation_prompt=True) for message in messages]
+ >>> images = [[image1, image2], [image3]]
+ >>> inputs = processor(text=prompts, images=images, padding=True, return_tensors="pt").to(model.device)
+
+ >>> # Generate
+ >>> generated_ids = model.generate(**inputs, max_new_tokens=256)
+ >>> generated_texts = processor.batch_decode(generated_ids, skip_special_tokens=True)
+
+ >>> print(generated_texts[0])
+ Assistant: There are buildings, trees, lights, and water visible in this image.
+
+ >>> print(generated_texts[1])
+ Assistant: The bridge is in San Francisco.
+ ```"""
+ outputs = self.model(
+ input_ids=input_ids,
+ pixel_values=pixel_values,
+ pixel_mask=pixel_mask,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ hidden_states = outputs[0]
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(
+ logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
+ )
+
+ return AriaCausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids,
+ past_key_values=None,
+ inputs_embeds=None,
+ pixel_values=None,
+ pixel_mask=None,
+ attention_mask=None,
+ cache_position=None,
+ logits_to_keep=None,
+ **kwargs,
+ ):
+ model_inputs = super().prepare_inputs_for_generation(
+ input_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ cache_position=cache_position,
+ logits_to_keep=logits_to_keep,
+ **kwargs,
+ )
+
+ if cache_position[0] == 0:
+ # Pixel values are only needed on the first (prefill) forward pass. During cached decoding the
+ # input ids no longer contain the special image tokens, so pixel values should stay `None`.
+ model_inputs["pixel_values"] = pixel_values
+ model_inputs["pixel_mask"] = pixel_mask
+
+ return model_inputs
+
+
+__all__ = [
+ "AriaConfig",
+ "AriaTextConfig",
+ "AriaImageProcessor",
+ "AriaProcessor",
+ "AriaForConditionalGeneration",
+ "AriaPreTrainedModel",
+ "AriaTextPreTrainedModel",
+ "AriaTextModel",
+ "AriaModel",
+ "AriaTextForCausalLM",
+]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/aria/processing_aria.py b/venv/lib/python3.13/site-packages/transformers/models/aria/processing_aria.py
new file mode 100644
index 0000000000000000000000000000000000000000..9264776e80fdab173276ce38ccd8f89965b161f2
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/aria/processing_aria.py
@@ -0,0 +1,189 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/aria/modular_aria.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_aria.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2024 The Rhymes-AI Teams Authors and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Optional, Union
+
+import numpy as np
+
+from ...image_processing_utils import BatchFeature
+from ...image_utils import ImageInput
+from ...processing_utils import MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack
+from ...tokenization_utils import PreTokenizedInput, TextInput
+from ...utils import TensorType
+from ..auto import AutoTokenizer
+
+
+class AriaProcessorKwargs(ProcessingKwargs, total=False):
+ _defaults = {
+ "text_kwargs": {
+ "padding": False,
+ "return_mm_token_type_ids": False,
+ },
+ "images_kwargs": {
+ "max_image_size": 980,
+ "split_image": False,
+ },
+ "return_tensors": TensorType.PYTORCH,
+ }
+
+
+class AriaProcessor(ProcessorMixin):
+ """
+ AriaProcessor is a processor for the Aria model which wraps the Aria image processor and the LLaMA slow tokenizer.
+
+ Args:
+ image_processor (`AriaImageProcessor`, *optional*):
+ The AriaImageProcessor to use for image preprocessing.
+ tokenizer (`PreTrainedTokenizerBase`, *optional*):
+ An instance of [`PreTrainedTokenizerBase`]. This should correspond with the model's text model. The tokenizer is a required input.
+ chat_template (`str`, *optional*):
+ A Jinja template which will be used to convert lists of messages in a chat into a tokenizable string.
+ size_conversion (`Dict`, *optional*):
+ A dictionary indicating size conversions for images.
+ """
+
+ attributes = ["image_processor", "tokenizer"]
+ image_processor_class = "AriaImageProcessor"
+ tokenizer_class = "AutoTokenizer"
+
+ def __init__(
+ self,
+ image_processor=None,
+ tokenizer: Union[AutoTokenizer, str] = None,
+ chat_template: Optional[str] = None,
+ size_conversion: Optional[dict[Union[float, int], int]] = None,
+ ):
+ if size_conversion is None:
+ size_conversion = {490: 128, 980: 256}
+ self.size_conversion = {int(k): v for k, v in size_conversion.items()}
+
+ self.image_token = tokenizer.image_token
+ self.image_token_id = tokenizer.image_token_id
+ if tokenizer is not None and tokenizer.pad_token is None:
+ tokenizer.pad_token = tokenizer.unk_token
+
+ super().__init__(image_processor, tokenizer, chat_template=chat_template)
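+
+ # A small illustration (not a library API) of why the keys are cast to `int` above: a mapping whose
+ # keys were serialized as floats still resolves by pixel size after normalization.
+ #
+ #     >>> {int(k): v for k, v in {490.0: 128, 980.0: 256}.items()}
+ #     {490: 128, 980: 256}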
+
+ def __call__(
+ self,
+ text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]],
+ images: Optional[ImageInput] = None,
+ audio=None,
+ videos=None,
+ **kwargs: Unpack[AriaProcessorKwargs],
+ ) -> BatchFeature:
+ """
+ Main method to prepare one or several sequence(s) and image(s) for the model.
+
+ Args:
+ text (`TextInput`, `PreTokenizedInput`, `list[TextInput]`, `list[PreTokenizedInput]`):
+ The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
+ (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
+ `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
+ images (`ImageInput`):
+ The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
+ tensor. Both channels-first and channels-last formats are supported.
+
+
+ Returns:
+ [`BatchFeature`]: A [`BatchFeature`] with the following fields:
+ - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
+ - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
+ `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
+ `None`).
+ - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
+ - **pixel_mask** -- Pixel mask to be fed to a model. Returned when `images` is not `None`.
+ """
+ output_kwargs = self._merge_kwargs(
+ AriaProcessorKwargs,
+ tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+ **kwargs,
+ )
+
+ if isinstance(text, str):
+ text = [text]
+ elif not isinstance(text, list) and not isinstance(text[0], str):
+ raise TypeError("Invalid input text. Please provide a string, or a list of strings")
+
+ if images is not None:
+ image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
+ # expand the image_token according to the num_crops and tokens per image
+ tokens_per_image = self.size_conversion[image_inputs.pixel_values.shape[2]]
+ prompt_strings = []
+ num_crops = image_inputs.pop("num_crops") * tokens_per_image
+ for sample in text:
+ sample = sample.replace(self.tokenizer.image_token, self.tokenizer.image_token * num_crops)
+ prompt_strings.append(sample)
+
+ else:
+ image_inputs = {}
+ prompt_strings = text
+
+ return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
+ return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False)
+ text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"], return_tensors=None)
+ self._check_special_mm_tokens(prompt_strings, text_inputs, modalities=["image"])
+
+ if return_mm_token_type_ids:
+ array_ids = np.array(text_inputs["input_ids"])
+ mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
+ mm_token_type_ids[array_ids == self.image_token_id] = 1
+ text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()
+
+ return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors)
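+
+ # A minimal sketch (illustrative only; `<image>` stands in for `tokenizer.image_token`) of the
+ # placeholder expansion performed in `__call__`: with the default `max_image_size=980` each crop
+ # maps to 256 tokens, so a single placeholder covering 2 crops is repeated 2 * 256 = 512 times.
+ #
+ #     >>> prompt = "Describe <image> please"
+ #     >>> expanded = prompt.replace("<image>", "<image>" * (2 * 256))
+ #     >>> expanded.count("<image>")
+ #     512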
+
+ def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
+ """
+ Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
+ Args:
+ image_sizes (`list[list[int]]`, *optional*):
+ The input sizes formatted as (height, width) per each image.
+ Returns:
+ `MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
+ input modalities, along with other useful data.
+ """
+
+ vision_data = {}
+ if image_sizes is not None:
+ images_kwargs = AriaProcessorKwargs._defaults.get("images_kwargs", {})
+ images_kwargs.update(kwargs)
+
+ max_size = images_kwargs.get("max_image_size", None) or self.image_processor.max_image_size
+ num_image_patches = [
+ self.image_processor.get_number_of_image_patches(*image_size, images_kwargs)
+ for image_size in image_sizes
+ ]
+ num_image_tokens = [self.size_conversion[max_size] * num_patches for num_patches in num_image_patches]
+ vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
+
+ return MultiModalData(**vision_data)
+
+ @property
+ def model_input_names(self):
+ tokenizer_input_names = self.tokenizer.model_input_names
+ image_processor_input_names = self.image_processor.model_input_names
+
+ # Remove `num_crops`: it is popped and used only during processing. Build a new list instead of
+ # removing in place, otherwise `self.image_processor.model_input_names` would also be modified.
+ image_processor_input_names = [name for name in image_processor_input_names if name != "num_crops"]
+ return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
+
+
+__all__ = ["AriaProcessor"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/auto/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/auto/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..34a6ae1e5c2e4f042c141624bd2296587e9f811d
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/auto/__init__.py
@@ -0,0 +1,35 @@
+# Copyright 2020 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .auto_factory import *
+ from .configuration_auto import *
+ from .feature_extraction_auto import *
+ from .image_processing_auto import *
+ from .modeling_auto import *
+ from .modeling_flax_auto import *
+ from .modeling_tf_auto import *
+ from .processing_auto import *
+ from .tokenization_auto import *
+ from .video_processing_auto import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/venv/lib/python3.13/site-packages/transformers/models/auto/auto_factory.py b/venv/lib/python3.13/site-packages/transformers/models/auto/auto_factory.py
new file mode 100644
index 0000000000000000000000000000000000000000..a8781c8042a6f51aa2b68f03a847f6a6320ec9ba
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/auto/auto_factory.py
@@ -0,0 +1,882 @@
+# coding=utf-8
+# Copyright 2021 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Factory function to build auto-model classes."""
+
+import copy
+import importlib
+import json
+import os
+import warnings
+from collections import OrderedDict
+from collections.abc import Iterator
+from typing import Any, TypeVar, Union
+
+from ...configuration_utils import PretrainedConfig
+from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
+from ...utils import (
+ CONFIG_NAME,
+ cached_file,
+ copy_func,
+ extract_commit_hash,
+ find_adapter_config_file,
+ is_peft_available,
+ is_torch_available,
+ logging,
+ requires_backends,
+)
+from .configuration_auto import AutoConfig, model_type_to_module_name, replace_list_option_in_docstrings
+
+
+if is_torch_available():
+ from ...generation import GenerationMixin
+
+
+logger = logging.get_logger(__name__)
+
+_T = TypeVar("_T")
+ # Tokenizers depend on the packages installed; there is too much variance and no common base class or Protocol
+_LazyAutoMappingValue = tuple[Union[type[Any], None], Union[type[Any], None]]
+
+CLASS_DOCSTRING = """
+ This is a generic model class that will be instantiated as one of the model classes of the library when created
+ with the [`~BaseAutoModelClass.from_pretrained`] class method or the [`~BaseAutoModelClass.from_config`] class
+ method.
+
+ This class cannot be instantiated directly using `__init__()` (throws an error).
+"""
+
+FROM_CONFIG_DOCSTRING = """
+ Instantiates one of the model classes of the library from a configuration.
+
+ Note:
+ Loading a model from its configuration file does **not** load the model weights. It only affects the
+ model's configuration. Use [`~BaseAutoModelClass.from_pretrained`] to load the model weights.
+
+ Args:
+ config ([`PretrainedConfig`]):
+ The model class to instantiate is selected based on the configuration class:
+
+ List options
+ attn_implementation (`str`, *optional*):
+ The attention implementation to use in the model (if relevant). Can be any of `"eager"` (manual implementation of the attention), `"sdpa"` (using [`F.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html)), or `"flash_attention_2"` (using [Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention)). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual `"eager"` implementation.
+
+ Examples:
+
+ ```python
+ >>> from transformers import AutoConfig, BaseAutoModelClass
+
+ >>> # Download configuration from huggingface.co and cache.
+ >>> config = AutoConfig.from_pretrained("checkpoint_placeholder")
+ >>> model = BaseAutoModelClass.from_config(config)
+ ```
+"""
+
+FROM_PRETRAINED_TORCH_DOCSTRING = """
+ Instantiate one of the model classes of the library from a pretrained model.
+
+ The model class to instantiate is selected based on the `model_type` property of the config object (either
+ passed as an argument or loaded from `pretrained_model_name_or_path` if possible), or when it's missing, by
+ falling back to using pattern matching on `pretrained_model_name_or_path`:
+
+ List options
+
+ The model is set in evaluation mode by default using `model.eval()` (so for instance, dropout modules are
+ deactivated). To train the model, you should first set it back in training mode with `model.train()`
+
+ Args:
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
+ Can be either:
+
+ - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
+ - A path to a *directory* containing model weights saved using
+ [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
+ - A path or url to a *tensorflow index checkpoint file* (e.g., `./tf_model/model.ckpt.index`). In
+ this case, `from_tf` should be set to `True` and a configuration object should be provided as
+ `config` argument. This loading path is slower than converting the TensorFlow checkpoint in a
+ PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
+ model_args (additional positional arguments, *optional*):
+ Will be passed along to the underlying model `__init__()` method.
+ config ([`PretrainedConfig`], *optional*):
+ Configuration for the model to use instead of an automatically loaded configuration. Configuration can
+ be automatically loaded when:
+
+ - The model is a model provided by the library (loaded with the *model id* string of a pretrained
+ model).
+ - The model was saved using [`~PreTrainedModel.save_pretrained`] and is reloaded by supplying the
+ save directory.
+ - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a
+ configuration JSON file named *config.json* is found in the directory.
+ state_dict (*dict[str, torch.Tensor]*, *optional*):
+ A state dictionary to use instead of a state dictionary loaded from saved weights file.
+
+ This option can be used if you want to create a model from a pretrained configuration but load your own
+ weights. In this case though, you should check if using [`~PreTrainedModel.save_pretrained`] and
+ [`~PreTrainedModel.from_pretrained`] is not a simpler option.
+ cache_dir (`str` or `os.PathLike`, *optional*):
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the
+ standard cache should not be used.
+ from_tf (`bool`, *optional*, defaults to `False`):
+ Load the model weights from a TensorFlow checkpoint save file (see docstring of
+ `pretrained_model_name_or_path` argument).
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+ resume_download:
+ Deprecated and ignored. All downloads are now resumed by default when possible.
+ Will be removed in v5 of Transformers.
+ proxies (`dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ output_loading_info(`bool`, *optional*, defaults to `False`):
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
+ local_files_only(`bool`, *optional*, defaults to `False`):
+ Whether or not to only look at local files (e.g., not try downloading the model).
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+ identifier allowed by git.
+ trust_remote_code (`bool`, *optional*, defaults to `False`):
+ Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
+ should only be set to `True` for repositories you trust and in which you have read the code, as it will
+ execute code present on the Hub on your local machine.
+ code_revision (`str`, *optional*, defaults to `"main"`):
+ The specific revision to use for the code on the Hub, if the code lives in a different repository than
+ the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based
+ system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier
+ allowed by git.
+ kwargs (additional keyword arguments, *optional*):
+ Can be used to update the configuration object (after it is loaded) and initialize the model (e.g.,
+ `output_attentions=True`). Behaves differently depending on whether a `config` is provided or
+ automatically loaded:
+
+ - If a configuration is provided with `config`, `**kwargs` will be directly passed to the
+ underlying model's `__init__` method (we assume all relevant updates to the configuration have
+ already been done)
+ - If a configuration is not provided, `kwargs` will be first passed to the configuration class
+ initialization function ([`~PretrainedConfig.from_pretrained`]). Each key of `kwargs` that
+ corresponds to a configuration attribute will be used to override said attribute with the
+ supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute
+ will be passed to the underlying model's `__init__` function.
+
+ Examples:
+
+ ```python
+ >>> from transformers import AutoConfig, BaseAutoModelClass
+
+ >>> # Download model and configuration from huggingface.co and cache.
+ >>> model = BaseAutoModelClass.from_pretrained("checkpoint_placeholder")
+
+ >>> # Update configuration during loading
+ >>> model = BaseAutoModelClass.from_pretrained("checkpoint_placeholder", output_attentions=True)
+ >>> model.config.output_attentions
+ True
+
+ >>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
+ >>> config = AutoConfig.from_pretrained("./tf_model/shortcut_placeholder_tf_model_config.json")
+ >>> model = BaseAutoModelClass.from_pretrained(
+ ... "./tf_model/shortcut_placeholder_tf_checkpoint.ckpt.index", from_tf=True, config=config
+ ... )
+ ```
+"""
+
+FROM_PRETRAINED_TF_DOCSTRING = """
+ Instantiate one of the model classes of the library from a pretrained model.
+
+ The model class to instantiate is selected based on the `model_type` property of the config object (either
+ passed as an argument or loaded from `pretrained_model_name_or_path` if possible), or when it's missing, by
+ falling back to using pattern matching on `pretrained_model_name_or_path`:
+
+ List options
+
+ Args:
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
+ Can be either:
+
+ - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
+ - A path to a *directory* containing model weights saved using
+ [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
+ - A path or url to a *PyTorch state_dict save file* (e.g., `./pt_model/pytorch_model.bin`). In this
+ case, `from_pt` should be set to `True` and a configuration object should be provided as `config`
+ argument. This loading path is slower than converting the PyTorch model in a TensorFlow model
+ using the provided conversion scripts and loading the TensorFlow model afterwards.
+ model_args (additional positional arguments, *optional*):
+ Will be passed along to the underlying model `__init__()` method.
+ config ([`PretrainedConfig`], *optional*):
+ Configuration for the model to use instead of an automatically loaded configuration. Configuration can
+ be automatically loaded when:
+
+ - The model is a model provided by the library (loaded with the *model id* string of a pretrained
+ model).
+ - The model was saved using [`~PreTrainedModel.save_pretrained`] and is reloaded by supplying the
+ save directory.
+ - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a
+ configuration JSON file named *config.json* is found in the directory.
+ cache_dir (`str` or `os.PathLike`, *optional*):
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the
+ standard cache should not be used.
+ from_pt (`bool`, *optional*, defaults to `False`):
+ Load the model weights from a PyTorch checkpoint save file (see docstring of
+ `pretrained_model_name_or_path` argument).
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+ resume_download:
+ Deprecated and ignored. All downloads are now resumed by default when possible.
+ Will be removed in v5 of Transformers.
+ proxies (`dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ output_loading_info(`bool`, *optional*, defaults to `False`):
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
+ local_files_only(`bool`, *optional*, defaults to `False`):
+ Whether or not to only look at local files (e.g., not try downloading the model).
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+ identifier allowed by git.
+ trust_remote_code (`bool`, *optional*, defaults to `False`):
+ Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
+ should only be set to `True` for repositories you trust and in which you have read the code, as it will
+ execute code present on the Hub on your local machine.
+ code_revision (`str`, *optional*, defaults to `"main"`):
+ The specific revision to use for the code on the Hub, if the code lives in a different repository than
+ the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based
+ system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier
+ allowed by git.
+ kwargs (additional keyword arguments, *optional*):
+ Can be used to update the configuration object (after it is loaded) and initialize the model (e.g.,
+ `output_attentions=True`). Behaves differently depending on whether a `config` is provided or
+ automatically loaded:
+
+ - If a configuration is provided with `config`, `**kwargs` will be directly passed to the
+ underlying model's `__init__` method (we assume all relevant updates to the configuration have
+ already been done)
+ - If a configuration is not provided, `kwargs` will be first passed to the configuration class
+ initialization function ([`~PretrainedConfig.from_pretrained`]). Each key of `kwargs` that
+ corresponds to a configuration attribute will be used to override said attribute with the
+ supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute
+ will be passed to the underlying model's `__init__` function.
+
+ Examples:
+
+ ```python
+ >>> from transformers import AutoConfig, BaseAutoModelClass
+
+ >>> # Download model and configuration from huggingface.co and cache.
+ >>> model = BaseAutoModelClass.from_pretrained("checkpoint_placeholder")
+
+ >>> # Update configuration during loading
+ >>> model = BaseAutoModelClass.from_pretrained("checkpoint_placeholder", output_attentions=True)
+ >>> model.config.output_attentions
+ True
+
+ >>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
+ >>> config = AutoConfig.from_pretrained("./pt_model/shortcut_placeholder_pt_model_config.json")
+ >>> model = BaseAutoModelClass.from_pretrained(
+ ... "./pt_model/shortcut_placeholder_pytorch_model.bin", from_pt=True, config=config
+ ... )
+ ```
+"""
+
+FROM_PRETRAINED_FLAX_DOCSTRING = """
+ Instantiate one of the model classes of the library from a pretrained model.
+
+ The model class to instantiate is selected based on the `model_type` property of the config object (either
+ passed as an argument or loaded from `pretrained_model_name_or_path` if possible), or when it's missing, by
+ falling back to using pattern matching on `pretrained_model_name_or_path`:
+
+ List options
+
+ Args:
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
+ Can be either:
+
+ - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
+ - A path to a *directory* containing model weights saved using
+ [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
+ - A path or url to a *PyTorch state_dict save file* (e.g., `./pt_model/pytorch_model.bin`). In this
+ case, `from_pt` should be set to `True` and a configuration object should be provided as `config`
+ argument. This loading path is slower than converting the PyTorch model in a TensorFlow model
+ using the provided conversion scripts and loading the TensorFlow model afterwards.
+ model_args (additional positional arguments, *optional*):
+ Will be passed along to the underlying model `__init__()` method.
+ config ([`PretrainedConfig`], *optional*):
+ Configuration for the model to use instead of an automatically loaded configuration. Configuration can
+ be automatically loaded when:
+
+ - The model is a model provided by the library (loaded with the *model id* string of a pretrained
+ model).
+ - The model was saved using [`~PreTrainedModel.save_pretrained`] and is reloaded by supplying the
+ save directory.
+ - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a
+ configuration JSON file named *config.json* is found in the directory.
+ cache_dir (`str` or `os.PathLike`, *optional*):
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the
+ standard cache should not be used.
+ from_pt (`bool`, *optional*, defaults to `False`):
+ Load the model weights from a PyTorch checkpoint save file (see docstring of
+ `pretrained_model_name_or_path` argument).
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+ resume_download:
+ Deprecated and ignored. All downloads are now resumed by default when possible.
+ Will be removed in v5 of Transformers.
+ proxies (`dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ output_loading_info(`bool`, *optional*, defaults to `False`):
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
+ local_files_only(`bool`, *optional*, defaults to `False`):
+ Whether or not to only look at local files (e.g., not try downloading the model).
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+ identifier allowed by git.
+ trust_remote_code (`bool`, *optional*, defaults to `False`):
+ Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
+ should only be set to `True` for repositories you trust and in which you have read the code, as it will
+ execute code present on the Hub on your local machine.
+ code_revision (`str`, *optional*, defaults to `"main"`):
+ The specific revision to use for the code on the Hub, if the code lives in a different repository than
+ the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based
+ system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier
+ allowed by git.
+ kwargs (additional keyword arguments, *optional*):
+ Can be used to update the configuration object (after it is loaded) and initialize the model (e.g.,
+ `output_attentions=True`). Behaves differently depending on whether a `config` is provided or
+ automatically loaded:
+
+ - If a configuration is provided with `config`, `**kwargs` will be directly passed to the
+ underlying model's `__init__` method (we assume all relevant updates to the configuration have
+ already been done)
+ - If a configuration is not provided, `kwargs` will be first passed to the configuration class
+ initialization function ([`~PretrainedConfig.from_pretrained`]). Each key of `kwargs` that
+ corresponds to a configuration attribute will be used to override said attribute with the
+ supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute
+ will be passed to the underlying model's `__init__` function.
+
+ Examples:
+
+ ```python
+ >>> from transformers import AutoConfig, BaseAutoModelClass
+
+ >>> # Download model and configuration from huggingface.co and cache.
+ >>> model = BaseAutoModelClass.from_pretrained("checkpoint_placeholder")
+
+ >>> # Update configuration during loading
+ >>> model = BaseAutoModelClass.from_pretrained("checkpoint_placeholder", output_attentions=True)
+ >>> model.config.output_attentions
+ True
+
+ >>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
+ >>> config = AutoConfig.from_pretrained("./pt_model/shortcut_placeholder_pt_model_config.json")
+ >>> model = BaseAutoModelClass.from_pretrained(
+ ... "./pt_model/shortcut_placeholder_pytorch_model.bin", from_pt=True, config=config
+ ... )
+ ```
+"""
+
+
+def _get_model_class(config, model_mapping):
+ supported_models = model_mapping[type(config)]
+ if not isinstance(supported_models, (list, tuple)):
+ return supported_models
+
+ name_to_model = {model.__name__: model for model in supported_models}
+ architectures = getattr(config, "architectures", [])
+ for arch in architectures:
+ if arch in name_to_model:
+ return name_to_model[arch]
+ elif f"TF{arch}" in name_to_model:
+ return name_to_model[f"TF{arch}"]
+ elif f"Flax{arch}" in name_to_model:
+ return name_to_model[f"Flax{arch}"]
+
+ # If no architecture is set in the config, or none matches the supported models, the first element
+ # of the tuple is the default.
+ return supported_models[0]
+
+
+class _BaseAutoModelClass:
+ # Base class for auto models.
+ _model_mapping = None
+
+ def __init__(self, *args, **kwargs) -> None:
+ raise OSError(
+ f"{self.__class__.__name__} is designed to be instantiated "
+ f"using the `{self.__class__.__name__}.from_pretrained(pretrained_model_name_or_path)` or "
+ f"`{self.__class__.__name__}.from_config(config)` methods."
+ )
+
+ @classmethod
+ def from_config(cls, config, **kwargs):
+ trust_remote_code = kwargs.pop("trust_remote_code", None)
+ has_remote_code = hasattr(config, "auto_map") and cls.__name__ in config.auto_map
+ has_local_code = type(config) in cls._model_mapping
+ if has_remote_code:
+ class_ref = config.auto_map[cls.__name__]
+ if "--" in class_ref:
+ upstream_repo = class_ref.split("--")[0]
+ else:
+ upstream_repo = None
+ trust_remote_code = resolve_trust_remote_code(
+ trust_remote_code, config._name_or_path, has_local_code, has_remote_code, upstream_repo=upstream_repo
+ )
+
+ if has_remote_code and trust_remote_code:
+ if "--" in class_ref:
+ repo_id, class_ref = class_ref.split("--")
+ else:
+ repo_id = config.name_or_path
+ model_class = get_class_from_dynamic_module(class_ref, repo_id, **kwargs)
+ # This block handles the case where the user is loading a model with `trust_remote_code=True`
+ # but a library model exists with the same name. We don't want to override the autoclass
+ # mappings in this case, or all future loads of that model will be the remote code model.
+ if not has_local_code:
+ cls.register(config.__class__, model_class, exist_ok=True)
+ model_class.register_for_auto_class(auto_class=cls)
+ _ = kwargs.pop("code_revision", None)
+ model_class = add_generation_mixin_to_remote_model(model_class)
+ return model_class._from_config(config, **kwargs)
+ elif type(config) in cls._model_mapping:
+ model_class = _get_model_class(config, cls._model_mapping)
+ return model_class._from_config(config, **kwargs)
+
+ raise ValueError(
+ f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
+ f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping)}."
+ )
+
+ @classmethod
+ def _prepare_config_for_auto_class(cls, config: PretrainedConfig) -> PretrainedConfig:
+ """Additional autoclass-specific config post-loading manipulation. May be overridden in subclasses."""
+ return config
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike[str]], *model_args, **kwargs):
+ config = kwargs.pop("config", None)
+ trust_remote_code = kwargs.get("trust_remote_code")
+ kwargs["_from_auto"] = True
+ hub_kwargs_names = [
+ "cache_dir",
+ "force_download",
+ "local_files_only",
+ "proxies",
+ "resume_download",
+ "revision",
+ "subfolder",
+ "use_auth_token",
+ "token",
+ ]
+ hub_kwargs = {name: kwargs.pop(name) for name in hub_kwargs_names if name in kwargs}
+ code_revision = kwargs.pop("code_revision", None)
+ commit_hash = kwargs.pop("_commit_hash", None)
+ adapter_kwargs = kwargs.pop("adapter_kwargs", None)
+
+ token = hub_kwargs.pop("token", None)
+ use_auth_token = hub_kwargs.pop("use_auth_token", None)
+ if use_auth_token is not None:
+ warnings.warn(
+ "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
+ FutureWarning,
+ )
+ if token is not None:
+ raise ValueError(
+ "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
+ )
+ token = use_auth_token
+
+ if token is not None:
+ hub_kwargs["token"] = token
+
+ if commit_hash is None:
+ if not isinstance(config, PretrainedConfig):
+ # We make a call to the config file first (which may be absent) to get the commit hash as soon as possible
+ resolved_config_file = cached_file(
+ pretrained_model_name_or_path,
+ CONFIG_NAME,
+ _raise_exceptions_for_gated_repo=False,
+ _raise_exceptions_for_missing_entries=False,
+ _raise_exceptions_for_connection_errors=False,
+ **hub_kwargs,
+ )
+ commit_hash = extract_commit_hash(resolved_config_file, commit_hash)
+ else:
+ commit_hash = getattr(config, "_commit_hash", None)
+
+ if is_peft_available():
+ if adapter_kwargs is None:
+ adapter_kwargs = {}
+ if token is not None:
+ adapter_kwargs["token"] = token
+
+ maybe_adapter_path = find_adapter_config_file(
+ pretrained_model_name_or_path, _commit_hash=commit_hash, **adapter_kwargs
+ )
+
+ if maybe_adapter_path is not None:
+ with open(maybe_adapter_path, "r", encoding="utf-8") as f:
+ adapter_config = json.load(f)
+
+ adapter_kwargs["_adapter_model_path"] = pretrained_model_name_or_path
+ pretrained_model_name_or_path = adapter_config["base_model_name_or_path"]
+
+ if not isinstance(config, PretrainedConfig):
+ kwargs_orig = copy.deepcopy(kwargs)
+ # ensure not to pollute the config object with dtype="auto" - since it's
+ # meaningless in the context of the config object - torch.dtype values are acceptable
+ if kwargs.get("torch_dtype") == "auto":
+ _ = kwargs.pop("torch_dtype")
+ if kwargs.get("dtype") == "auto":
+ _ = kwargs.pop("dtype")
+ # to not overwrite the quantization_config if config has a quantization_config
+ if kwargs.get("quantization_config") is not None:
+ _ = kwargs.pop("quantization_config")
+
+ config, kwargs = AutoConfig.from_pretrained(
+ pretrained_model_name_or_path,
+ return_unused_kwargs=True,
+ code_revision=code_revision,
+ _commit_hash=commit_hash,
+ **hub_kwargs,
+ **kwargs,
+ )
+
+ # if torch_dtype=auto was passed here, ensure to pass it on
+ if kwargs_orig.get("torch_dtype", None) == "auto":
+ kwargs["torch_dtype"] = "auto"
+ if kwargs_orig.get("dtype", None) == "auto":
+ kwargs["dtype"] = "auto"
+ if kwargs_orig.get("quantization_config", None) is not None:
+ kwargs["quantization_config"] = kwargs_orig["quantization_config"]
+
+ has_remote_code = hasattr(config, "auto_map") and cls.__name__ in config.auto_map
+ has_local_code = type(config) in cls._model_mapping
+ upstream_repo = None
+ if has_remote_code:
+ class_ref = config.auto_map[cls.__name__]
+ if "--" in class_ref:
+ upstream_repo = class_ref.split("--")[0]
+ trust_remote_code = resolve_trust_remote_code(
+ trust_remote_code,
+ pretrained_model_name_or_path,
+ has_local_code,
+ has_remote_code,
+ upstream_repo=upstream_repo,
+ )
+ kwargs["trust_remote_code"] = trust_remote_code
+
+ # Set the adapter kwargs
+ kwargs["adapter_kwargs"] = adapter_kwargs
+
+ if has_remote_code and trust_remote_code:
+ model_class = get_class_from_dynamic_module(
+ class_ref, pretrained_model_name_or_path, code_revision=code_revision, **hub_kwargs, **kwargs
+ )
+ _ = hub_kwargs.pop("code_revision", None)
+ # This block handles the case where the user is loading a model with `trust_remote_code=True`
+ # but a library model exists with the same name. We don't want to override the autoclass
+ # mappings in this case, or all future loads of that model will be the remote code model.
+ if not has_local_code:
+ cls.register(config.__class__, model_class, exist_ok=True)
+ model_class.register_for_auto_class(auto_class=cls)
+ model_class = add_generation_mixin_to_remote_model(model_class)
+ return model_class.from_pretrained(
+ pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs
+ )
+ elif type(config) in cls._model_mapping:
+ model_class = _get_model_class(config, cls._model_mapping)
+ if model_class.config_class == config.sub_configs.get("text_config", None):
+ config = config.get_text_config()
+ return model_class.from_pretrained(
+ pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs
+ )
+ raise ValueError(
+ f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
+ f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping)}."
+ )
+
+ @classmethod
+ def register(cls, config_class, model_class, exist_ok=False) -> None:
+ """
+ Register a new model for this class.
+
+ Args:
+ config_class ([`PretrainedConfig`]):
+ The configuration corresponding to the model to register.
+ model_class ([`PreTrainedModel`]):
+ The model to register.
+ """
+ if hasattr(model_class, "config_class") and model_class.config_class.__name__ != config_class.__name__:
+ raise ValueError(
+ "The model class you are passing has a `config_class` attribute that is not consistent with the "
+ f"config class you passed (model has {model_class.config_class} and you passed {config_class}. Fix "
+ "one of those so they match!"
+ )
+ cls._model_mapping.register(config_class, model_class, exist_ok=exist_ok)
+
+
+class _BaseAutoBackboneClass(_BaseAutoModelClass):
+ # Base class for auto backbone models.
+ _model_mapping = None
+
+ @classmethod
+ def _load_timm_backbone_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+ requires_backends(cls, ["vision", "timm"])
+ from ...models.timm_backbone import TimmBackboneConfig
+
+ config = kwargs.pop("config", TimmBackboneConfig())
+
+ if kwargs.get("out_features") is not None:
+ raise ValueError("Cannot specify `out_features` for timm backbones")
+
+ if kwargs.get("output_loading_info", False):
+ raise ValueError("Cannot specify `output_loading_info=True` when loading from timm")
+
+ num_channels = kwargs.pop("num_channels", config.num_channels)
+ features_only = kwargs.pop("features_only", config.features_only)
+ use_pretrained_backbone = kwargs.pop("use_pretrained_backbone", config.use_pretrained_backbone)
+ out_indices = kwargs.pop("out_indices", config.out_indices)
+ config = TimmBackboneConfig(
+ backbone=pretrained_model_name_or_path,
+ num_channels=num_channels,
+ features_only=features_only,
+ use_pretrained_backbone=use_pretrained_backbone,
+ out_indices=out_indices,
+ )
+ return super().from_config(config, **kwargs)
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+ use_timm_backbone = kwargs.pop("use_timm_backbone", False)
+ if use_timm_backbone:
+ return cls._load_timm_backbone_from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
+
+ return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
+
+
+def insert_head_doc(docstring, head_doc: str = ""):
+ if len(head_doc) > 0:
+ return docstring.replace(
+ "one of the model classes of the library ",
+ f"one of the model classes of the library (with a {head_doc} head) ",
+ )
+ return docstring.replace(
+ "one of the model classes of the library ", "one of the base model classes of the library "
+ )
+
+
+def auto_class_update(cls, checkpoint_for_example: str = "google-bert/bert-base-cased", head_doc: str = ""):
+ # Create a new class with the right name from the base class
+ model_mapping = cls._model_mapping
+ name = cls.__name__
+ class_docstring = insert_head_doc(CLASS_DOCSTRING, head_doc=head_doc)
+ cls.__doc__ = class_docstring.replace("BaseAutoModelClass", name)
+
+ # Now we need to copy and re-register `from_config` and `from_pretrained` as class methods otherwise we can't
+ # have specific docstrings for them.
+ from_config = copy_func(_BaseAutoModelClass.from_config)
+ from_config_docstring = insert_head_doc(FROM_CONFIG_DOCSTRING, head_doc=head_doc)
+ from_config_docstring = from_config_docstring.replace("BaseAutoModelClass", name)
+ from_config_docstring = from_config_docstring.replace("checkpoint_placeholder", checkpoint_for_example)
+ from_config.__doc__ = from_config_docstring
+ from_config = replace_list_option_in_docstrings(model_mapping._model_mapping, use_model_types=False)(from_config)
+ cls.from_config = classmethod(from_config)
+
+ if name.startswith("TF"):
+ from_pretrained_docstring = FROM_PRETRAINED_TF_DOCSTRING
+ elif name.startswith("Flax"):
+ from_pretrained_docstring = FROM_PRETRAINED_FLAX_DOCSTRING
+ else:
+ from_pretrained_docstring = FROM_PRETRAINED_TORCH_DOCSTRING
+ from_pretrained = copy_func(_BaseAutoModelClass.from_pretrained)
+ from_pretrained_docstring = insert_head_doc(from_pretrained_docstring, head_doc=head_doc)
+ from_pretrained_docstring = from_pretrained_docstring.replace("BaseAutoModelClass", name)
+ from_pretrained_docstring = from_pretrained_docstring.replace("checkpoint_placeholder", checkpoint_for_example)
+ shortcut = checkpoint_for_example.split("/")[-1].split("-")[0]
+ from_pretrained_docstring = from_pretrained_docstring.replace("shortcut_placeholder", shortcut)
+ from_pretrained.__doc__ = from_pretrained_docstring
+ from_pretrained = replace_list_option_in_docstrings(model_mapping._model_mapping)(from_pretrained)
+ cls.from_pretrained = classmethod(from_pretrained)
+ return cls
+
+
+def get_values(model_mapping):
+ result = []
+ for model in model_mapping.values():
+ if isinstance(model, (list, tuple)):
+ result += list(model)
+ else:
+ result.append(model)
+
+ return result
+
+
+def getattribute_from_module(module, attr):
+ if attr is None:
+ return None
+ if isinstance(attr, tuple):
+ return tuple(getattribute_from_module(module, a) for a in attr)
+ if hasattr(module, attr):
+ return getattr(module, attr)
+ # Some of the mappings have entries model_type -> object of another model type. In that case we try to grab the
+ # object at the top level.
+ transformers_module = importlib.import_module("transformers")
+
+ if module != transformers_module:
+ try:
+ return getattribute_from_module(transformers_module, attr)
+ except ValueError:
+ raise ValueError(f"Could not find {attr} neither in {module} nor in {transformers_module}!")
+ else:
+ raise ValueError(f"Could not find {attr} in {transformers_module}!")
+
+
+def add_generation_mixin_to_remote_model(model_class):
+ """
+ Adds `GenerationMixin` to the inheritance of `model_class`, if `model_class` is a PyTorch model.
+
+ This function is used for backwards compatibility purposes: in v4.45, we've started a deprecation cycle to make
+ `PreTrainedModel` stop inheriting from `GenerationMixin`. Without this function, older models dynamically loaded
+ from the Hub may not have the `generate` method after we remove the inheritance.
+ """
+ # 1. If it is not a PT model (i.e. doesn't inherit Module), do nothing
+ if "torch.nn.modules.module.Module" not in str(model_class.__mro__):
+ return model_class
+
+ # 2. If it already **directly** inherits from GenerationMixin, do nothing
+ if "GenerationMixin" in str(model_class.__bases__):
+ return model_class
+
+ # 3. Prior to v4.45, we could detect whether a model was `generate`-compatible if it had its own `generate` and/or
+ # `prepare_inputs_for_generation` method.
+ has_custom_generate_in_class = hasattr(model_class, "generate") and "GenerationMixin" not in str(
+ getattr(model_class, "generate")
+ )
+ has_custom_prepare_inputs = hasattr(model_class, "prepare_inputs_for_generation") and "GenerationMixin" not in str(
+ getattr(model_class, "prepare_inputs_for_generation")
+ )
+ if has_custom_generate_in_class or has_custom_prepare_inputs:
+ model_class_with_generation_mixin = type(
+ model_class.__name__, (model_class, GenerationMixin), {**model_class.__dict__}
+ )
+ return model_class_with_generation_mixin
+ return model_class
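+
+ # A toy version (illustrative only, without the mixin itself) of the `type()` re-basing used above:
+ # a new class with the same name and namespace is created on top of the legacy class, which is how
+ # `GenerationMixin` gets appended to the bases of old remote-code models.
+ #
+ #     >>> class Base: ...
+ #     >>> class LegacyModel(Base):
+ #     ...     def prepare_inputs_for_generation(self):
+ #     ...         return {}
+ #     >>> Patched = type(LegacyModel.__name__, (LegacyModel,), {**LegacyModel.__dict__})
+ #     >>> Patched.__name__, Patched.__mro__[1] is LegacyModel
+ #     ('LegacyModel', True)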
+
+
+class _LazyAutoMapping(OrderedDict[type[PretrainedConfig], _LazyAutoMappingValue]):
+ """
+ " A mapping config to object (model or tokenizer for instance) that will load keys and values when it is accessed.
+
+ Args:
+ - config_mapping: The mapping from model type to config class
+ - model_mapping: The mapping from model type to model (or tokenizer) class
+ """
+
+ def __init__(self, config_mapping, model_mapping) -> None:
+ self._config_mapping = config_mapping
+ self._reverse_config_mapping = {v: k for k, v in config_mapping.items()}
+ self._model_mapping = model_mapping
+ self._model_mapping._model_mapping = self
+ self._extra_content = {}
+ self._modules = {}
+
+ def __len__(self) -> int:
+ common_keys = set(self._config_mapping.keys()).intersection(self._model_mapping.keys())
+ return len(common_keys) + len(self._extra_content)
+
+ def __getitem__(self, key: type[PretrainedConfig]) -> _LazyAutoMappingValue:
+ if key in self._extra_content:
+ return self._extra_content[key]
+ model_type = self._reverse_config_mapping[key.__name__]
+ if model_type in self._model_mapping:
+ model_name = self._model_mapping[model_type]
+ return self._load_attr_from_module(model_type, model_name)
+
+ # Maybe there were several model types associated with this config.
+ model_types = [k for k, v in self._config_mapping.items() if v == key.__name__]
+ for mtype in model_types:
+ if mtype in self._model_mapping:
+ model_name = self._model_mapping[mtype]
+ return self._load_attr_from_module(mtype, model_name)
+ raise KeyError(key)
+
+ def _load_attr_from_module(self, model_type, attr):
+ module_name = model_type_to_module_name(model_type)
+ if module_name not in self._modules:
+ self._modules[module_name] = importlib.import_module(f".{module_name}", "transformers.models")
+ return getattribute_from_module(self._modules[module_name], attr)
+
+ def keys(self) -> list[type[PretrainedConfig]]:
+ mapping_keys = [
+ self._load_attr_from_module(key, name)
+ for key, name in self._config_mapping.items()
+ if key in self._model_mapping
+ ]
+ return mapping_keys + list(self._extra_content.keys())
+
+ def get(self, key: type[PretrainedConfig], default: _T) -> Union[_LazyAutoMappingValue, _T]:
+ try:
+ return self.__getitem__(key)
+ except KeyError:
+ return default
+
+ def __bool__(self) -> bool:
+ return bool(self.keys())
+
+ def values(self) -> list[_LazyAutoMappingValue]:
+ mapping_values = [
+ self._load_attr_from_module(key, name)
+ for key, name in self._model_mapping.items()
+ if key in self._config_mapping
+ ]
+ return mapping_values + list(self._extra_content.values())
+
+ def items(self) -> list[tuple[type[PretrainedConfig], _LazyAutoMappingValue]]:
+ mapping_items = [
+ (
+ self._load_attr_from_module(key, self._config_mapping[key]),
+ self._load_attr_from_module(key, self._model_mapping[key]),
+ )
+ for key in self._model_mapping
+ if key in self._config_mapping
+ ]
+ return mapping_items + list(self._extra_content.items())
+
+ def __iter__(self) -> Iterator[type[PretrainedConfig]]:
+ return iter(self.keys())
+
+ def __contains__(self, item: type) -> bool:
+ if item in self._extra_content:
+ return True
+ if not hasattr(item, "__name__") or item.__name__ not in self._reverse_config_mapping:
+ return False
+ model_type = self._reverse_config_mapping[item.__name__]
+ return model_type in self._model_mapping
+
+ def register(self, key: type[PretrainedConfig], value: _LazyAutoMappingValue, exist_ok=False) -> None:
+ """
+ Register a new model in this mapping.
+ """
+ if hasattr(key, "__name__") and key.__name__ in self._reverse_config_mapping:
+ model_type = self._reverse_config_mapping[key.__name__]
+ if model_type in self._model_mapping and not exist_ok:
+ raise ValueError(f"'{key}' is already used by a Transformers model.")
+
+ self._extra_content[key] = value
+
+
+__all__ = ["get_values"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/auto/configuration_auto.py b/venv/lib/python3.13/site-packages/transformers/models/auto/configuration_auto.py
new file mode 100644
index 0000000000000000000000000000000000000000..f6a12e7cef986fb837abbfee0cc81b64b7148b50
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/auto/configuration_auto.py
@@ -0,0 +1,1404 @@
+# coding=utf-8
+# Copyright 2018 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Auto Config class."""
+
+import importlib
+import os
+import re
+import warnings
+from collections import OrderedDict
+from collections.abc import Callable, ItemsView, Iterator, KeysView, ValuesView
+from typing import Any, TypeVar, Union
+
+from ...configuration_utils import PretrainedConfig
+from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
+from ...utils import CONFIG_NAME, logging
+
+
+logger = logging.get_logger(__name__)
+
+
+_CallableT = TypeVar("_CallableT", bound=Callable[..., Any])
+
+
+CONFIG_MAPPING_NAMES = OrderedDict[str, str](
+ [
+ # Add configs here
+ ("aimv2", "Aimv2Config"),
+ ("aimv2_vision_model", "Aimv2VisionConfig"),
+ ("albert", "AlbertConfig"),
+ ("align", "AlignConfig"),
+ ("altclip", "AltCLIPConfig"),
+ ("apertus", "ApertusConfig"),
+ ("arcee", "ArceeConfig"),
+ ("aria", "AriaConfig"),
+ ("aria_text", "AriaTextConfig"),
+ ("audio-spectrogram-transformer", "ASTConfig"),
+ ("autoformer", "AutoformerConfig"),
+ ("aya_vision", "AyaVisionConfig"),
+ ("bamba", "BambaConfig"),
+ ("bark", "BarkConfig"),
+ ("bart", "BartConfig"),
+ ("beit", "BeitConfig"),
+ ("bert", "BertConfig"),
+ ("bert-generation", "BertGenerationConfig"),
+ ("big_bird", "BigBirdConfig"),
+ ("bigbird_pegasus", "BigBirdPegasusConfig"),
+ ("biogpt", "BioGptConfig"),
+ ("bit", "BitConfig"),
+ ("bitnet", "BitNetConfig"),
+ ("blenderbot", "BlenderbotConfig"),
+ ("blenderbot-small", "BlenderbotSmallConfig"),
+ ("blip", "BlipConfig"),
+ ("blip-2", "Blip2Config"),
+ ("blip_2_qformer", "Blip2QFormerConfig"),
+ ("bloom", "BloomConfig"),
+ ("blt", "BltConfig"),
+ ("bridgetower", "BridgeTowerConfig"),
+ ("bros", "BrosConfig"),
+ ("camembert", "CamembertConfig"),
+ ("canine", "CanineConfig"),
+ ("chameleon", "ChameleonConfig"),
+ ("chinese_clip", "ChineseCLIPConfig"),
+ ("chinese_clip_vision_model", "ChineseCLIPVisionConfig"),
+ ("clap", "ClapConfig"),
+ ("clip", "CLIPConfig"),
+ ("clip_text_model", "CLIPTextConfig"),
+ ("clip_vision_model", "CLIPVisionConfig"),
+ ("clipseg", "CLIPSegConfig"),
+ ("clvp", "ClvpConfig"),
+ ("code_llama", "LlamaConfig"),
+ ("codegen", "CodeGenConfig"),
+ ("cohere", "CohereConfig"),
+ ("cohere2", "Cohere2Config"),
+ ("cohere2_vision", "Cohere2VisionConfig"),
+ ("colpali", "ColPaliConfig"),
+ ("colqwen2", "ColQwen2Config"),
+ ("conditional_detr", "ConditionalDetrConfig"),
+ ("convbert", "ConvBertConfig"),
+ ("convnext", "ConvNextConfig"),
+ ("convnextv2", "ConvNextV2Config"),
+ ("cpmant", "CpmAntConfig"),
+ ("csm", "CsmConfig"),
+ ("ctrl", "CTRLConfig"),
+ ("cvt", "CvtConfig"),
+ ("d_fine", "DFineConfig"),
+ ("dab-detr", "DabDetrConfig"),
+ ("dac", "DacConfig"),
+ ("data2vec-audio", "Data2VecAudioConfig"),
+ ("data2vec-text", "Data2VecTextConfig"),
+ ("data2vec-vision", "Data2VecVisionConfig"),
+ ("dbrx", "DbrxConfig"),
+ ("deberta", "DebertaConfig"),
+ ("deberta-v2", "DebertaV2Config"),
+ ("decision_transformer", "DecisionTransformerConfig"),
+ ("deepseek_v2", "DeepseekV2Config"),
+ ("deepseek_v3", "DeepseekV3Config"),
+ ("deepseek_vl", "DeepseekVLConfig"),
+ ("deepseek_vl_hybrid", "DeepseekVLHybridConfig"),
+ ("deformable_detr", "DeformableDetrConfig"),
+ ("deit", "DeiTConfig"),
+ ("depth_anything", "DepthAnythingConfig"),
+ ("depth_pro", "DepthProConfig"),
+ ("deta", "DetaConfig"),
+ ("detr", "DetrConfig"),
+ ("dia", "DiaConfig"),
+ ("diffllama", "DiffLlamaConfig"),
+ ("dinat", "DinatConfig"),
+ ("dinov2", "Dinov2Config"),
+ ("dinov2_with_registers", "Dinov2WithRegistersConfig"),
+ ("dinov3_convnext", "DINOv3ConvNextConfig"),
+ ("dinov3_vit", "DINOv3ViTConfig"),
+ ("distilbert", "DistilBertConfig"),
+ ("doge", "DogeConfig"),
+ ("donut-swin", "DonutSwinConfig"),
+ ("dots1", "Dots1Config"),
+ ("dpr", "DPRConfig"),
+ ("dpt", "DPTConfig"),
+ ("edgetam", "EdgeTamConfig"),
+ ("edgetam_video", "EdgeTamVideoConfig"),
+ ("edgetam_vision_model", "EdgeTamVisionConfig"),
+ ("efficientformer", "EfficientFormerConfig"),
+ ("efficientloftr", "EfficientLoFTRConfig"),
+ ("efficientnet", "EfficientNetConfig"),
+ ("electra", "ElectraConfig"),
+ ("emu3", "Emu3Config"),
+ ("encodec", "EncodecConfig"),
+ ("encoder-decoder", "EncoderDecoderConfig"),
+ ("eomt", "EomtConfig"),
+ ("ernie", "ErnieConfig"),
+ ("ernie4_5", "Ernie4_5Config"),
+ ("ernie4_5_moe", "Ernie4_5_MoeConfig"),
+ ("ernie_m", "ErnieMConfig"),
+ ("esm", "EsmConfig"),
+ ("evolla", "EvollaConfig"),
+ ("exaone4", "Exaone4Config"),
+ ("falcon", "FalconConfig"),
+ ("falcon_h1", "FalconH1Config"),
+ ("falcon_mamba", "FalconMambaConfig"),
+ ("fastspeech2_conformer", "FastSpeech2ConformerConfig"),
+ ("fastspeech2_conformer_with_hifigan", "FastSpeech2ConformerWithHifiGanConfig"),
+ ("flaubert", "FlaubertConfig"),
+ ("flava", "FlavaConfig"),
+ ("flex_olmo", "FlexOlmoConfig"),
+ ("florence2", "Florence2Config"),
+ ("fnet", "FNetConfig"),
+ ("focalnet", "FocalNetConfig"),
+ ("fsmt", "FSMTConfig"),
+ ("funnel", "FunnelConfig"),
+ ("fuyu", "FuyuConfig"),
+ ("gemma", "GemmaConfig"),
+ ("gemma2", "Gemma2Config"),
+ ("gemma3", "Gemma3Config"),
+ ("gemma3_text", "Gemma3TextConfig"),
+ ("gemma3n", "Gemma3nConfig"),
+ ("gemma3n_audio", "Gemma3nAudioConfig"),
+ ("gemma3n_text", "Gemma3nTextConfig"),
+ ("gemma3n_vision", "Gemma3nVisionConfig"),
+ ("git", "GitConfig"),
+ ("glm", "GlmConfig"),
+ ("glm4", "Glm4Config"),
+ ("glm4_moe", "Glm4MoeConfig"),
+ ("glm4v", "Glm4vConfig"),
+ ("glm4v_moe", "Glm4vMoeConfig"),
+ ("glm4v_moe_text", "Glm4vMoeTextConfig"),
+ ("glm4v_text", "Glm4vTextConfig"),
+ ("glpn", "GLPNConfig"),
+ ("got_ocr2", "GotOcr2Config"),
+ ("gpt-sw3", "GPT2Config"),
+ ("gpt2", "GPT2Config"),
+ ("gpt_bigcode", "GPTBigCodeConfig"),
+ ("gpt_neo", "GPTNeoConfig"),
+ ("gpt_neox", "GPTNeoXConfig"),
+ ("gpt_neox_japanese", "GPTNeoXJapaneseConfig"),
+ ("gpt_oss", "GptOssConfig"),
+ ("gptj", "GPTJConfig"),
+ ("gptsan-japanese", "GPTSanJapaneseConfig"),
+ ("granite", "GraniteConfig"),
+ ("granite_speech", "GraniteSpeechConfig"),
+ ("granitemoe", "GraniteMoeConfig"),
+ ("granitemoehybrid", "GraniteMoeHybridConfig"),
+ ("granitemoeshared", "GraniteMoeSharedConfig"),
+ ("granitevision", "LlavaNextConfig"),
+ ("graphormer", "GraphormerConfig"),
+ ("grounding-dino", "GroundingDinoConfig"),
+ ("groupvit", "GroupViTConfig"),
+ ("helium", "HeliumConfig"),
+ ("hgnet_v2", "HGNetV2Config"),
+ ("hiera", "HieraConfig"),
+ ("hubert", "HubertConfig"),
+ ("hunyuan_v1_dense", "HunYuanDenseV1Config"),
+ ("hunyuan_v1_moe", "HunYuanMoEV1Config"),
+ ("ibert", "IBertConfig"),
+ ("idefics", "IdeficsConfig"),
+ ("idefics2", "Idefics2Config"),
+ ("idefics3", "Idefics3Config"),
+ ("idefics3_vision", "Idefics3VisionConfig"),
+ ("ijepa", "IJepaConfig"),
+ ("imagegpt", "ImageGPTConfig"),
+ ("informer", "InformerConfig"),
+ ("instructblip", "InstructBlipConfig"),
+ ("instructblipvideo", "InstructBlipVideoConfig"),
+ ("internvl", "InternVLConfig"),
+ ("internvl_vision", "InternVLVisionConfig"),
+ ("jamba", "JambaConfig"),
+ ("janus", "JanusConfig"),
+ ("jetmoe", "JetMoeConfig"),
+ ("jukebox", "JukeboxConfig"),
+ ("kosmos-2", "Kosmos2Config"),
+ ("kosmos-2.5", "Kosmos2_5Config"),
+ ("kyutai_speech_to_text", "KyutaiSpeechToTextConfig"),
+ ("layoutlm", "LayoutLMConfig"),
+ ("layoutlmv2", "LayoutLMv2Config"),
+ ("layoutlmv3", "LayoutLMv3Config"),
+ ("led", "LEDConfig"),
+ ("levit", "LevitConfig"),
+ ("lfm2", "Lfm2Config"),
+ ("lfm2_vl", "Lfm2VlConfig"),
+ ("lightglue", "LightGlueConfig"),
+ ("lilt", "LiltConfig"),
+ ("llama", "LlamaConfig"),
+ ("llama4", "Llama4Config"),
+ ("llama4_text", "Llama4TextConfig"),
+ ("llava", "LlavaConfig"),
+ ("llava_next", "LlavaNextConfig"),
+ ("llava_next_video", "LlavaNextVideoConfig"),
+ ("llava_onevision", "LlavaOnevisionConfig"),
+ ("longcat_flash", "LongcatFlashConfig"),
+ ("longformer", "LongformerConfig"),
+ ("longt5", "LongT5Config"),
+ ("luke", "LukeConfig"),
+ ("lxmert", "LxmertConfig"),
+ ("m2m_100", "M2M100Config"),
+ ("mamba", "MambaConfig"),
+ ("mamba2", "Mamba2Config"),
+ ("marian", "MarianConfig"),
+ ("markuplm", "MarkupLMConfig"),
+ ("mask2former", "Mask2FormerConfig"),
+ ("maskformer", "MaskFormerConfig"),
+ ("maskformer-swin", "MaskFormerSwinConfig"),
+ ("mbart", "MBartConfig"),
+ ("mctct", "MCTCTConfig"),
+ ("mega", "MegaConfig"),
+ ("megatron-bert", "MegatronBertConfig"),
+ ("metaclip_2", "MetaClip2Config"),
+ ("mgp-str", "MgpstrConfig"),
+ ("mimi", "MimiConfig"),
+ ("minimax", "MiniMaxConfig"),
+ ("ministral", "MinistralConfig"),
+ ("mistral", "MistralConfig"),
+ ("mistral3", "Mistral3Config"),
+ ("mixtral", "MixtralConfig"),
+ ("mlcd", "MLCDVisionConfig"),
+ ("mllama", "MllamaConfig"),
+ ("mm-grounding-dino", "MMGroundingDinoConfig"),
+ ("mobilebert", "MobileBertConfig"),
+ ("mobilenet_v1", "MobileNetV1Config"),
+ ("mobilenet_v2", "MobileNetV2Config"),
+ ("mobilevit", "MobileViTConfig"),
+ ("mobilevitv2", "MobileViTV2Config"),
+ ("modernbert", "ModernBertConfig"),
+ ("modernbert-decoder", "ModernBertDecoderConfig"),
+ ("moonshine", "MoonshineConfig"),
+ ("moshi", "MoshiConfig"),
+ ("mpnet", "MPNetConfig"),
+ ("mpt", "MptConfig"),
+ ("mra", "MraConfig"),
+ ("mt5", "MT5Config"),
+ ("musicgen", "MusicgenConfig"),
+ ("musicgen_melody", "MusicgenMelodyConfig"),
+ ("mvp", "MvpConfig"),
+ ("nat", "NatConfig"),
+ ("nemotron", "NemotronConfig"),
+ ("nezha", "NezhaConfig"),
+ ("nllb-moe", "NllbMoeConfig"),
+ ("nougat", "VisionEncoderDecoderConfig"),
+ ("nystromformer", "NystromformerConfig"),
+ ("olmo", "OlmoConfig"),
+ ("olmo2", "Olmo2Config"),
+ ("olmo3", "Olmo3Config"),
+ ("olmoe", "OlmoeConfig"),
+ ("omdet-turbo", "OmDetTurboConfig"),
+ ("oneformer", "OneFormerConfig"),
+ ("open-llama", "OpenLlamaConfig"),
+ ("openai-gpt", "OpenAIGPTConfig"),
+ ("opt", "OPTConfig"),
+ ("ovis2", "Ovis2Config"),
+ ("owlv2", "Owlv2Config"),
+ ("owlvit", "OwlViTConfig"),
+ ("paligemma", "PaliGemmaConfig"),
+ ("parakeet_ctc", "ParakeetCTCConfig"),
+ ("parakeet_encoder", "ParakeetEncoderConfig"),
+ ("patchtsmixer", "PatchTSMixerConfig"),
+ ("patchtst", "PatchTSTConfig"),
+ ("pegasus", "PegasusConfig"),
+ ("pegasus_x", "PegasusXConfig"),
+ ("perceiver", "PerceiverConfig"),
+ ("perception_encoder", "TimmWrapperConfig"),
+ ("perception_lm", "PerceptionLMConfig"),
+ ("persimmon", "PersimmonConfig"),
+ ("phi", "PhiConfig"),
+ ("phi3", "Phi3Config"),
+ ("phi4_multimodal", "Phi4MultimodalConfig"),
+ ("phimoe", "PhimoeConfig"),
+ ("pix2struct", "Pix2StructConfig"),
+ ("pixtral", "PixtralVisionConfig"),
+ ("plbart", "PLBartConfig"),
+ ("poolformer", "PoolFormerConfig"),
+ ("pop2piano", "Pop2PianoConfig"),
+ ("prompt_depth_anything", "PromptDepthAnythingConfig"),
+ ("prophetnet", "ProphetNetConfig"),
+ ("pvt", "PvtConfig"),
+ ("pvt_v2", "PvtV2Config"),
+ ("qdqbert", "QDQBertConfig"),
+ ("qwen2", "Qwen2Config"),
+ ("qwen2_5_omni", "Qwen2_5OmniConfig"),
+ ("qwen2_5_vl", "Qwen2_5_VLConfig"),
+ ("qwen2_5_vl_text", "Qwen2_5_VLTextConfig"),
+ ("qwen2_audio", "Qwen2AudioConfig"),
+ ("qwen2_audio_encoder", "Qwen2AudioEncoderConfig"),
+ ("qwen2_moe", "Qwen2MoeConfig"),
+ ("qwen2_vl", "Qwen2VLConfig"),
+ ("qwen2_vl_text", "Qwen2VLTextConfig"),
+ ("qwen3", "Qwen3Config"),
+ ("qwen3_moe", "Qwen3MoeConfig"),
+ ("qwen3_next", "Qwen3NextConfig"),
+ ("qwen3_omni_moe", "Qwen3OmniMoeConfig"),
+ ("qwen3_vl", "Qwen3VLConfig"),
+ ("qwen3_vl_moe", "Qwen3VLMoeConfig"),
+ ("qwen3_vl_moe_text", "Qwen3VLMoeTextConfig"),
+ ("qwen3_vl_text", "Qwen3VLTextConfig"),
+ ("rag", "RagConfig"),
+ ("realm", "RealmConfig"),
+ ("recurrent_gemma", "RecurrentGemmaConfig"),
+ ("reformer", "ReformerConfig"),
+ ("regnet", "RegNetConfig"),
+ ("rembert", "RemBertConfig"),
+ ("resnet", "ResNetConfig"),
+ ("retribert", "RetriBertConfig"),
+ ("roberta", "RobertaConfig"),
+ ("roberta-prelayernorm", "RobertaPreLayerNormConfig"),
+ ("roc_bert", "RoCBertConfig"),
+ ("roformer", "RoFormerConfig"),
+ ("rt_detr", "RTDetrConfig"),
+ ("rt_detr_resnet", "RTDetrResNetConfig"),
+ ("rt_detr_v2", "RTDetrV2Config"),
+ ("rwkv", "RwkvConfig"),
+ ("sam", "SamConfig"),
+ ("sam2", "Sam2Config"),
+ ("sam2_hiera_det_model", "Sam2HieraDetConfig"),
+ ("sam2_video", "Sam2VideoConfig"),
+ ("sam2_vision_model", "Sam2VisionConfig"),
+ ("sam_hq", "SamHQConfig"),
+ ("sam_hq_vision_model", "SamHQVisionConfig"),
+ ("sam_vision_model", "SamVisionConfig"),
+ ("seamless_m4t", "SeamlessM4TConfig"),
+ ("seamless_m4t_v2", "SeamlessM4Tv2Config"),
+ ("seed_oss", "SeedOssConfig"),
+ ("segformer", "SegformerConfig"),
+ ("seggpt", "SegGptConfig"),
+ ("sew", "SEWConfig"),
+ ("sew-d", "SEWDConfig"),
+ ("shieldgemma2", "ShieldGemma2Config"),
+ ("siglip", "SiglipConfig"),
+ ("siglip2", "Siglip2Config"),
+ ("siglip2_vision_model", "Siglip2VisionConfig"),
+ ("siglip_vision_model", "SiglipVisionConfig"),
+ ("smollm3", "SmolLM3Config"),
+ ("smolvlm", "SmolVLMConfig"),
+ ("smolvlm_vision", "SmolVLMVisionConfig"),
+ ("speech-encoder-decoder", "SpeechEncoderDecoderConfig"),
+ ("speech_to_text", "Speech2TextConfig"),
+ ("speech_to_text_2", "Speech2Text2Config"),
+ ("speecht5", "SpeechT5Config"),
+ ("splinter", "SplinterConfig"),
+ ("squeezebert", "SqueezeBertConfig"),
+ ("stablelm", "StableLmConfig"),
+ ("starcoder2", "Starcoder2Config"),
+ ("superglue", "SuperGlueConfig"),
+ ("superpoint", "SuperPointConfig"),
+ ("swiftformer", "SwiftFormerConfig"),
+ ("swin", "SwinConfig"),
+ ("swin2sr", "Swin2SRConfig"),
+ ("swinv2", "Swinv2Config"),
+ ("switch_transformers", "SwitchTransformersConfig"),
+ ("t5", "T5Config"),
+ ("t5gemma", "T5GemmaConfig"),
+ ("table-transformer", "TableTransformerConfig"),
+ ("tapas", "TapasConfig"),
+ ("textnet", "TextNetConfig"),
+ ("time_series_transformer", "TimeSeriesTransformerConfig"),
+ ("timesfm", "TimesFmConfig"),
+ ("timesformer", "TimesformerConfig"),
+ ("timm_backbone", "TimmBackboneConfig"),
+ ("timm_wrapper", "TimmWrapperConfig"),
+ ("trajectory_transformer", "TrajectoryTransformerConfig"),
+ ("transfo-xl", "TransfoXLConfig"),
+ ("trocr", "TrOCRConfig"),
+ ("tvlt", "TvltConfig"),
+ ("tvp", "TvpConfig"),
+ ("udop", "UdopConfig"),
+ ("umt5", "UMT5Config"),
+ ("unispeech", "UniSpeechConfig"),
+ ("unispeech-sat", "UniSpeechSatConfig"),
+ ("univnet", "UnivNetConfig"),
+ ("upernet", "UperNetConfig"),
+ ("van", "VanConfig"),
+ ("vaultgemma", "VaultGemmaConfig"),
+ ("video_llava", "VideoLlavaConfig"),
+ ("videomae", "VideoMAEConfig"),
+ ("vilt", "ViltConfig"),
+ ("vipllava", "VipLlavaConfig"),
+ ("vision-encoder-decoder", "VisionEncoderDecoderConfig"),
+ ("vision-text-dual-encoder", "VisionTextDualEncoderConfig"),
+ ("visual_bert", "VisualBertConfig"),
+ ("vit", "ViTConfig"),
+ ("vit_hybrid", "ViTHybridConfig"),
+ ("vit_mae", "ViTMAEConfig"),
+ ("vit_msn", "ViTMSNConfig"),
+ ("vitdet", "VitDetConfig"),
+ ("vitmatte", "VitMatteConfig"),
+ ("vitpose", "VitPoseConfig"),
+ ("vitpose_backbone", "VitPoseBackboneConfig"),
+ ("vits", "VitsConfig"),
+ ("vivit", "VivitConfig"),
+ ("vjepa2", "VJEPA2Config"),
+ ("voxtral", "VoxtralConfig"),
+ ("voxtral_encoder", "VoxtralEncoderConfig"),
+ ("wav2vec2", "Wav2Vec2Config"),
+ ("wav2vec2-bert", "Wav2Vec2BertConfig"),
+ ("wav2vec2-conformer", "Wav2Vec2ConformerConfig"),
+ ("wavlm", "WavLMConfig"),
+ ("whisper", "WhisperConfig"),
+ ("xclip", "XCLIPConfig"),
+ ("xcodec", "XcodecConfig"),
+ ("xglm", "XGLMConfig"),
+ ("xlm", "XLMConfig"),
+ ("xlm-prophetnet", "XLMProphetNetConfig"),
+ ("xlm-roberta", "XLMRobertaConfig"),
+ ("xlm-roberta-xl", "XLMRobertaXLConfig"),
+ ("xlnet", "XLNetConfig"),
+ ("xlstm", "xLSTMConfig"),
+ ("xmod", "XmodConfig"),
+ ("yolos", "YolosConfig"),
+ ("yoso", "YosoConfig"),
+ ("zamba", "ZambaConfig"),
+ ("zamba2", "Zamba2Config"),
+ ("zoedepth", "ZoeDepthConfig"),
+ ]
+)
+
+
+MODEL_NAMES_MAPPING = OrderedDict[str, str](
+ [
+ # Add full (and cased) model names here
+ ("aimv2", "AIMv2"),
+ ("aimv2_vision_model", "Aimv2VisionModel"),
+ ("albert", "ALBERT"),
+ ("align", "ALIGN"),
+ ("altclip", "AltCLIP"),
+ ("apertus", "Apertus"),
+ ("arcee", "Arcee"),
+ ("aria", "Aria"),
+ ("aria_text", "AriaText"),
+ ("audio-spectrogram-transformer", "Audio Spectrogram Transformer"),
+ ("autoformer", "Autoformer"),
+ ("aya_vision", "AyaVision"),
+ ("bamba", "Bamba"),
+ ("bark", "Bark"),
+ ("bart", "BART"),
+ ("barthez", "BARThez"),
+ ("bartpho", "BARTpho"),
+ ("beit", "BEiT"),
+ ("bert", "BERT"),
+ ("bert-generation", "Bert Generation"),
+ ("bert-japanese", "BertJapanese"),
+ ("bertweet", "BERTweet"),
+ ("big_bird", "BigBird"),
+ ("bigbird_pegasus", "BigBird-Pegasus"),
+ ("biogpt", "BioGpt"),
+ ("bit", "BiT"),
+ ("bitnet", "BitNet"),
+ ("blenderbot", "Blenderbot"),
+ ("blenderbot-small", "BlenderbotSmall"),
+ ("blip", "BLIP"),
+ ("blip-2", "BLIP-2"),
+ ("blip_2_qformer", "BLIP-2 QFormer"),
+ ("bloom", "BLOOM"),
+ ("blt", "Blt"),
+ ("bort", "BORT"),
+ ("bridgetower", "BridgeTower"),
+ ("bros", "BROS"),
+ ("byt5", "ByT5"),
+ ("camembert", "CamemBERT"),
+ ("canine", "CANINE"),
+ ("chameleon", "Chameleon"),
+ ("chinese_clip", "Chinese-CLIP"),
+ ("chinese_clip_vision_model", "ChineseCLIPVisionModel"),
+ ("clap", "CLAP"),
+ ("clip", "CLIP"),
+ ("clip_text_model", "CLIPTextModel"),
+ ("clip_vision_model", "CLIPVisionModel"),
+ ("clipseg", "CLIPSeg"),
+ ("clvp", "CLVP"),
+ ("code_llama", "CodeLlama"),
+ ("codegen", "CodeGen"),
+ ("cohere", "Cohere"),
+ ("cohere2", "Cohere2"),
+ ("cohere2_vision", "Cohere2Vision"),
+ ("colpali", "ColPali"),
+ ("colqwen2", "ColQwen2"),
+ ("conditional_detr", "Conditional DETR"),
+ ("convbert", "ConvBERT"),
+ ("convnext", "ConvNeXT"),
+ ("convnextv2", "ConvNeXTV2"),
+ ("cpm", "CPM"),
+ ("cpmant", "CPM-Ant"),
+ ("csm", "CSM"),
+ ("ctrl", "CTRL"),
+ ("cvt", "CvT"),
+ ("d_fine", "D-FINE"),
+ ("dab-detr", "DAB-DETR"),
+ ("dac", "DAC"),
+ ("data2vec-audio", "Data2VecAudio"),
+ ("data2vec-text", "Data2VecText"),
+ ("data2vec-vision", "Data2VecVision"),
+ ("dbrx", "DBRX"),
+ ("deberta", "DeBERTa"),
+ ("deberta-v2", "DeBERTa-v2"),
+ ("decision_transformer", "Decision Transformer"),
+ ("deepseek_v2", "DeepSeek-V2"),
+ ("deepseek_v3", "DeepSeek-V3"),
+ ("deepseek_vl", "DeepseekVL"),
+ ("deepseek_vl_hybrid", "DeepseekVLHybrid"),
+ ("deformable_detr", "Deformable DETR"),
+ ("deit", "DeiT"),
+ ("deplot", "DePlot"),
+ ("depth_anything", "Depth Anything"),
+ ("depth_anything_v2", "Depth Anything V2"),
+ ("depth_pro", "DepthPro"),
+ ("deta", "DETA"),
+ ("detr", "DETR"),
+ ("dia", "Dia"),
+ ("dialogpt", "DialoGPT"),
+ ("diffllama", "DiffLlama"),
+ ("dinat", "DiNAT"),
+ ("dinov2", "DINOv2"),
+ ("dinov2_with_registers", "DINOv2 with Registers"),
+ ("dinov3_convnext", "DINOv3 ConvNext"),
+ ("dinov3_vit", "DINOv3 ViT"),
+ ("distilbert", "DistilBERT"),
+ ("dit", "DiT"),
+ ("doge", "Doge"),
+ ("donut-swin", "DonutSwin"),
+ ("dots1", "dots1"),
+ ("dpr", "DPR"),
+ ("dpt", "DPT"),
+ ("edgetam", "EdgeTAM"),
+ ("edgetam_video", "EdgeTamVideo"),
+ ("edgetam_vision_model", "EdgeTamVisionModel"),
+ ("efficientformer", "EfficientFormer"),
+ ("efficientloftr", "EfficientLoFTR"),
+ ("efficientnet", "EfficientNet"),
+ ("electra", "ELECTRA"),
+ ("emu3", "Emu3"),
+ ("encodec", "EnCodec"),
+ ("encoder-decoder", "Encoder decoder"),
+ ("eomt", "EoMT"),
+ ("ernie", "ERNIE"),
+ ("ernie4_5", "Ernie4_5"),
+ ("ernie4_5_moe", "Ernie4_5_MoE"),
+ ("ernie_m", "ErnieM"),
+ ("esm", "ESM"),
+ ("evolla", "Evolla"),
+ ("exaone4", "EXAONE-4.0"),
+ ("falcon", "Falcon"),
+ ("falcon3", "Falcon3"),
+ ("falcon_h1", "FalconH1"),
+ ("falcon_mamba", "FalconMamba"),
+ ("fastspeech2_conformer", "FastSpeech2Conformer"),
+ ("fastspeech2_conformer_with_hifigan", "FastSpeech2ConformerWithHifiGan"),
+ ("flan-t5", "FLAN-T5"),
+ ("flan-ul2", "FLAN-UL2"),
+ ("flaubert", "FlauBERT"),
+ ("flava", "FLAVA"),
+ ("flex_olmo", "FlexOlmo"),
+ ("florence2", "Florence2"),
+ ("fnet", "FNet"),
+ ("focalnet", "FocalNet"),
+ ("fsmt", "FairSeq Machine-Translation"),
+ ("funnel", "Funnel Transformer"),
+ ("fuyu", "Fuyu"),
+ ("gemma", "Gemma"),
+ ("gemma2", "Gemma2"),
+ ("gemma3", "Gemma3ForConditionalGeneration"),
+ ("gemma3_text", "Gemma3ForCausalLM"),
+ ("gemma3n", "Gemma3nForConditionalGeneration"),
+ ("gemma3n_audio", "Gemma3nAudioEncoder"),
+ ("gemma3n_text", "Gemma3nForCausalLM"),
+ ("gemma3n_vision", "TimmWrapperModel"),
+ ("git", "GIT"),
+ ("glm", "GLM"),
+ ("glm4", "GLM4"),
+ ("glm4_moe", "Glm4MoE"),
+ ("glm4v", "GLM4V"),
+ ("glm4v_moe", "GLM4VMOE"),
+ ("glm4v_moe_text", "GLM4VMOE"),
+ ("glm4v_text", "GLM4V"),
+ ("glpn", "GLPN"),
+ ("got_ocr2", "GOT-OCR2"),
+ ("gpt-sw3", "GPT-Sw3"),
+ ("gpt2", "OpenAI GPT-2"),
+ ("gpt_bigcode", "GPTBigCode"),
+ ("gpt_neo", "GPT Neo"),
+ ("gpt_neox", "GPT NeoX"),
+ ("gpt_neox_japanese", "GPT NeoX Japanese"),
+ ("gpt_oss", "GptOss"),
+ ("gptj", "GPT-J"),
+ ("gptsan-japanese", "GPTSAN-japanese"),
+ ("granite", "Granite"),
+ ("granite_speech", "GraniteSpeech"),
+ ("granitemoe", "GraniteMoeMoe"),
+ ("granitemoehybrid", "GraniteMoeHybrid"),
+ ("granitemoeshared", "GraniteMoeSharedMoe"),
+ ("granitevision", "LLaVA-NeXT"),
+ ("graphormer", "Graphormer"),
+ ("grounding-dino", "Grounding DINO"),
+ ("groupvit", "GroupViT"),
+ ("helium", "Helium"),
+ ("herbert", "HerBERT"),
+ ("hgnet_v2", "HGNet-V2"),
+ ("hiera", "Hiera"),
+ ("hubert", "Hubert"),
+ ("hunyuan_v1_dense", "HunYuanDenseV1"),
+ ("hunyuan_v1_moe", "HunYuanMoeV1"),
+ ("ibert", "I-BERT"),
+ ("idefics", "IDEFICS"),
+ ("idefics2", "Idefics2"),
+ ("idefics3", "Idefics3"),
+ ("idefics3_vision", "Idefics3VisionTransformer"),
+ ("ijepa", "I-JEPA"),
+ ("imagegpt", "ImageGPT"),
+ ("informer", "Informer"),
+ ("instructblip", "InstructBLIP"),
+ ("instructblipvideo", "InstructBlipVideo"),
+ ("internvl", "InternVL"),
+ ("internvl_vision", "InternVLVision"),
+ ("jamba", "Jamba"),
+ ("janus", "Janus"),
+ ("jetmoe", "JetMoe"),
+ ("jukebox", "Jukebox"),
+ ("kosmos-2", "KOSMOS-2"),
+ ("kosmos-2.5", "KOSMOS-2.5"),
+ ("kyutai_speech_to_text", "KyutaiSpeechToText"),
+ ("layoutlm", "LayoutLM"),
+ ("layoutlmv2", "LayoutLMv2"),
+ ("layoutlmv3", "LayoutLMv3"),
+ ("layoutxlm", "LayoutXLM"),
+ ("led", "LED"),
+ ("levit", "LeViT"),
+ ("lfm2", "Lfm2"),
+ ("lfm2_vl", "Lfm2Vl"),
+ ("lightglue", "LightGlue"),
+ ("lilt", "LiLT"),
+ ("llama", "LLaMA"),
+ ("llama2", "Llama2"),
+ ("llama3", "Llama3"),
+ ("llama4", "Llama4"),
+ ("llama4_text", "Llama4ForCausalLM"),
+ ("llava", "LLaVa"),
+ ("llava_next", "LLaVA-NeXT"),
+ ("llava_next_video", "LLaVa-NeXT-Video"),
+ ("llava_onevision", "LLaVA-Onevision"),
+ ("longcat_flash", "LongCatFlash"),
+ ("longformer", "Longformer"),
+ ("longt5", "LongT5"),
+ ("luke", "LUKE"),
+ ("lxmert", "LXMERT"),
+ ("m2m_100", "M2M100"),
+ ("madlad-400", "MADLAD-400"),
+ ("mamba", "Mamba"),
+ ("mamba2", "mamba2"),
+ ("marian", "Marian"),
+ ("markuplm", "MarkupLM"),
+ ("mask2former", "Mask2Former"),
+ ("maskformer", "MaskFormer"),
+ ("maskformer-swin", "MaskFormerSwin"),
+ ("matcha", "MatCha"),
+ ("mbart", "mBART"),
+ ("mbart50", "mBART-50"),
+ ("mctct", "M-CTC-T"),
+ ("mega", "MEGA"),
+ ("megatron-bert", "Megatron-BERT"),
+ ("megatron_gpt2", "Megatron-GPT2"),
+ ("metaclip_2", "MetaCLIP 2"),
+ ("mgp-str", "MGP-STR"),
+ ("mimi", "Mimi"),
+ ("minimax", "MiniMax"),
+ ("ministral", "Ministral"),
+ ("mistral", "Mistral"),
+ ("mistral3", "Mistral3"),
+ ("mixtral", "Mixtral"),
+ ("mlcd", "MLCD"),
+ ("mllama", "Mllama"),
+ ("mluke", "mLUKE"),
+ ("mm-grounding-dino", "MM Grounding DINO"),
+ ("mms", "MMS"),
+ ("mobilebert", "MobileBERT"),
+ ("mobilenet_v1", "MobileNetV1"),
+ ("mobilenet_v2", "MobileNetV2"),
+ ("mobilevit", "MobileViT"),
+ ("mobilevitv2", "MobileViTV2"),
+ ("modernbert", "ModernBERT"),
+ ("modernbert-decoder", "ModernBertDecoder"),
+ ("moonshine", "Moonshine"),
+ ("moshi", "Moshi"),
+ ("mpnet", "MPNet"),
+ ("mpt", "MPT"),
+ ("mra", "MRA"),
+ ("mt5", "MT5"),
+ ("musicgen", "MusicGen"),
+ ("musicgen_melody", "MusicGen Melody"),
+ ("mvp", "MVP"),
+ ("myt5", "myt5"),
+ ("nat", "NAT"),
+ ("nemotron", "Nemotron"),
+ ("nezha", "Nezha"),
+ ("nllb", "NLLB"),
+ ("nllb-moe", "NLLB-MOE"),
+ ("nougat", "Nougat"),
+ ("nystromformer", "Nyströmformer"),
+ ("olmo", "OLMo"),
+ ("olmo2", "OLMo2"),
+ ("olmo3", "Olmo3"),
+ ("olmoe", "OLMoE"),
+ ("omdet-turbo", "OmDet-Turbo"),
+ ("oneformer", "OneFormer"),
+ ("open-llama", "OpenLlama"),
+ ("openai-gpt", "OpenAI GPT"),
+ ("opt", "OPT"),
+ ("ovis2", "Ovis2"),
+ ("owlv2", "OWLv2"),
+ ("owlvit", "OWL-ViT"),
+ ("paligemma", "PaliGemma"),
+ ("parakeet", "Parakeet"),
+ ("parakeet_ctc", "Parakeet"),
+ ("parakeet_encoder", "ParakeetEncoder"),
+ ("patchtsmixer", "PatchTSMixer"),
+ ("patchtst", "PatchTST"),
+ ("pegasus", "Pegasus"),
+ ("pegasus_x", "PEGASUS-X"),
+ ("perceiver", "Perceiver"),
+ ("perception_encoder", "PerceptionEncoder"),
+ ("perception_lm", "PerceptionLM"),
+ ("persimmon", "Persimmon"),
+ ("phi", "Phi"),
+ ("phi3", "Phi3"),
+ ("phi4_multimodal", "Phi4Multimodal"),
+ ("phimoe", "Phimoe"),
+ ("phobert", "PhoBERT"),
+ ("pix2struct", "Pix2Struct"),
+ ("pixtral", "Pixtral"),
+ ("plbart", "PLBart"),
+ ("poolformer", "PoolFormer"),
+ ("pop2piano", "Pop2Piano"),
+ ("prompt_depth_anything", "PromptDepthAnything"),
+ ("prophetnet", "ProphetNet"),
+ ("pvt", "PVT"),
+ ("pvt_v2", "PVTv2"),
+ ("qdqbert", "QDQBert"),
+ ("qwen2", "Qwen2"),
+ ("qwen2_5_omni", "Qwen2_5Omni"),
+ ("qwen2_5_vl", "Qwen2_5_VL"),
+ ("qwen2_5_vl_text", "Qwen2_5_VL"),
+ ("qwen2_audio", "Qwen2Audio"),
+ ("qwen2_audio_encoder", "Qwen2AudioEncoder"),
+ ("qwen2_moe", "Qwen2MoE"),
+ ("qwen2_vl", "Qwen2VL"),
+ ("qwen2_vl_text", "Qwen2VL"),
+ ("qwen3", "Qwen3"),
+ ("qwen3_moe", "Qwen3MoE"),
+ ("qwen3_next", "Qwen3Next"),
+ ("qwen3_omni_moe", "Qwen3OmniMoE"),
+ ("qwen3_vl", "Qwen3VL"),
+ ("qwen3_vl_moe", "Qwen3VLMoe"),
+ ("qwen3_vl_moe_text", "Qwen3VLMoe"),
+ ("qwen3_vl_text", "Qwen3VL"),
+ ("rag", "RAG"),
+ ("realm", "REALM"),
+ ("recurrent_gemma", "RecurrentGemma"),
+ ("reformer", "Reformer"),
+ ("regnet", "RegNet"),
+ ("rembert", "RemBERT"),
+ ("resnet", "ResNet"),
+ ("retribert", "RetriBERT"),
+ ("roberta", "RoBERTa"),
+ ("roberta-prelayernorm", "RoBERTa-PreLayerNorm"),
+ ("roc_bert", "RoCBert"),
+ ("roformer", "RoFormer"),
+ ("rt_detr", "RT-DETR"),
+ ("rt_detr_resnet", "RT-DETR-ResNet"),
+ ("rt_detr_v2", "RT-DETRv2"),
+ ("rwkv", "RWKV"),
+ ("sam", "SAM"),
+ ("sam2", "SAM2"),
+ ("sam2_hiera_det_model", "Sam2HieraDetModel"),
+ ("sam2_video", "Sam2VideoModel"),
+ ("sam2_vision_model", "Sam2VisionModel"),
+ ("sam_hq", "SAM-HQ"),
+ ("sam_hq_vision_model", "SamHQVisionModel"),
+ ("sam_vision_model", "SamVisionModel"),
+ ("seamless_m4t", "SeamlessM4T"),
+ ("seamless_m4t_v2", "SeamlessM4Tv2"),
+ ("seed_oss", "SeedOss"),
+ ("segformer", "SegFormer"),
+ ("seggpt", "SegGPT"),
+ ("sew", "SEW"),
+ ("sew-d", "SEW-D"),
+ ("shieldgemma2", "Shieldgemma2"),
+ ("siglip", "SigLIP"),
+ ("siglip2", "SigLIP2"),
+ ("siglip2_vision_model", "Siglip2VisionModel"),
+ ("siglip_vision_model", "SiglipVisionModel"),
+ ("smollm3", "SmolLM3"),
+ ("smolvlm", "SmolVLM"),
+ ("smolvlm_vision", "SmolVLMVisionTransformer"),
+ ("speech-encoder-decoder", "Speech Encoder decoder"),
+ ("speech_to_text", "Speech2Text"),
+ ("speech_to_text_2", "Speech2Text2"),
+ ("speecht5", "SpeechT5"),
+ ("splinter", "Splinter"),
+ ("squeezebert", "SqueezeBERT"),
+ ("stablelm", "StableLm"),
+ ("starcoder2", "Starcoder2"),
+ ("superglue", "SuperGlue"),
+ ("superpoint", "SuperPoint"),
+ ("swiftformer", "SwiftFormer"),
+ ("swin", "Swin Transformer"),
+ ("swin2sr", "Swin2SR"),
+ ("swinv2", "Swin Transformer V2"),
+ ("switch_transformers", "SwitchTransformers"),
+ ("t5", "T5"),
+ ("t5gemma", "T5Gemma"),
+ ("t5v1.1", "T5v1.1"),
+ ("table-transformer", "Table Transformer"),
+ ("tapas", "TAPAS"),
+ ("tapex", "TAPEX"),
+ ("textnet", "TextNet"),
+ ("time_series_transformer", "Time Series Transformer"),
+ ("timesfm", "TimesFm"),
+ ("timesformer", "TimeSformer"),
+ ("timm_backbone", "TimmBackbone"),
+ ("timm_wrapper", "TimmWrapperModel"),
+ ("trajectory_transformer", "Trajectory Transformer"),
+ ("transfo-xl", "Transformer-XL"),
+ ("trocr", "TrOCR"),
+ ("tvlt", "TVLT"),
+ ("tvp", "TVP"),
+ ("udop", "UDOP"),
+ ("ul2", "UL2"),
+ ("umt5", "UMT5"),
+ ("unispeech", "UniSpeech"),
+ ("unispeech-sat", "UniSpeechSat"),
+ ("univnet", "UnivNet"),
+ ("upernet", "UPerNet"),
+ ("van", "VAN"),
+ ("vaultgemma", "VaultGemma"),
+ ("video_llava", "VideoLlava"),
+ ("videomae", "VideoMAE"),
+ ("vilt", "ViLT"),
+ ("vipllava", "VipLlava"),
+ ("vision-encoder-decoder", "Vision Encoder decoder"),
+ ("vision-text-dual-encoder", "VisionTextDualEncoder"),
+ ("visual_bert", "VisualBERT"),
+ ("vit", "ViT"),
+ ("vit_hybrid", "ViT Hybrid"),
+ ("vit_mae", "ViTMAE"),
+ ("vit_msn", "ViTMSN"),
+ ("vitdet", "VitDet"),
+ ("vitmatte", "ViTMatte"),
+ ("vitpose", "ViTPose"),
+ ("vitpose_backbone", "ViTPoseBackbone"),
+ ("vits", "VITS"),
+ ("vivit", "ViViT"),
+ ("vjepa2", "VJEPA2Model"),
+ ("voxtral", "Voxtral"),
+ ("voxtral_encoder", "Voxtral Encoder"),
+ ("wav2vec2", "Wav2Vec2"),
+ ("wav2vec2-bert", "Wav2Vec2-BERT"),
+ ("wav2vec2-conformer", "Wav2Vec2-Conformer"),
+ ("wav2vec2_phoneme", "Wav2Vec2Phoneme"),
+ ("wavlm", "WavLM"),
+ ("whisper", "Whisper"),
+ ("xclip", "X-CLIP"),
+ ("xcodec", "X-CODEC"),
+ ("xglm", "XGLM"),
+ ("xlm", "XLM"),
+ ("xlm-prophetnet", "XLM-ProphetNet"),
+ ("xlm-roberta", "XLM-RoBERTa"),
+ ("xlm-roberta-xl", "XLM-RoBERTa-XL"),
+ ("xlm-v", "XLM-V"),
+ ("xlnet", "XLNet"),
+ ("xls_r", "XLS-R"),
+ ("xlsr_wav2vec2", "XLSR-Wav2Vec2"),
+ ("xlstm", "xLSTM"),
+ ("xmod", "X-MOD"),
+ ("yolos", "YOLOS"),
+ ("yoso", "YOSO"),
+ ("zamba", "Zamba"),
+ ("zamba2", "Zamba2"),
+ ("zoedepth", "ZoeDepth"),
+ ]
+)
+
+# This is tied to the processing `-` -> `_` in `model_type_to_module_name`. For example, instead of putting
+# `transfo-xl` (as in `CONFIG_MAPPING_NAMES`), we should use `transfo_xl`.
+DEPRECATED_MODELS = [
+ "bort",
+ "deta",
+ "efficientformer",
+ "ernie_m",
+ "gptsan_japanese",
+ "graphormer",
+ "jukebox",
+ "mctct",
+ "mega",
+ "mmbt",
+ "nat",
+ "nezha",
+ "open_llama",
+ "qdqbert",
+ "realm",
+ "retribert",
+ "speech_to_text_2",
+ "tapex",
+ "trajectory_transformer",
+ "transfo_xl",
+ "tvlt",
+ "van",
+ "vit_hybrid",
+ "xlm_prophetnet",
+]
+
+SPECIAL_MODEL_TYPE_TO_MODULE_NAME = OrderedDict[str, str](
+ [
+ ("openai-gpt", "openai"),
+ ("data2vec-audio", "data2vec"),
+ ("data2vec-text", "data2vec"),
+ ("data2vec-vision", "data2vec"),
+ ("donut-swin", "donut"),
+ ("kosmos-2", "kosmos2"),
+ ("kosmos-2.5", "kosmos2_5"),
+ ("maskformer-swin", "maskformer"),
+ ("xclip", "x_clip"),
+ ("clip_vision_model", "clip"),
+ ("qwen2_audio_encoder", "qwen2_audio"),
+ ("voxtral_encoder", "voxtral"),
+ ("clip_text_model", "clip"),
+ ("aria_text", "aria"),
+ ("gemma3_text", "gemma3"),
+ ("gemma3n_audio", "gemma3n"),
+ ("gemma3n_text", "gemma3n"),
+ ("gemma3n_vision", "gemma3n"),
+ ("glm4v_text", "glm4v"),
+ ("glm4v_moe_text", "glm4v_moe"),
+ ("idefics3_vision", "idefics3"),
+ ("siglip_vision_model", "siglip"),
+ ("siglip2_vision_model", "siglip2"),
+ ("aimv2_vision_model", "aimv2"),
+ ("smolvlm_vision", "smolvlm"),
+ ("chinese_clip_vision_model", "chinese_clip"),
+ ("rt_detr_resnet", "rt_detr"),
+ ("granitevision", "llava_next"),
+ ("internvl_vision", "internvl"),
+ ("qwen2_5_vl_text", "qwen2_5_vl"),
+ ("qwen2_vl_text", "qwen2_vl"),
+ ("qwen3_vl_text", "qwen3_vl"),
+ ("qwen3_vl_moe_text", "qwen3_vl_moe"),
+ ("sam_vision_model", "sam"),
+ ("sam2_vision_model", "sam2"),
+ ("edgetam_vision_model", "edgetam"),
+ ("sam2_hiera_det_model", "sam2"),
+ ("sam_hq_vision_model", "sam_hq"),
+ ("llama4_text", "llama4"),
+ ("blip_2_qformer", "blip_2"),
+ ("fastspeech2_conformer_with_hifigan", "fastspeech2_conformer"),
+ ("perception_encoder", "perception_lm"),
+ ("parakeet_encoder", "parakeet"),
+ ("parakeet_ctc", "parakeet"),
+ ]
+)
+
+
+def model_type_to_module_name(key) -> str:
+ """Converts a config key to the corresponding module."""
+ # Special treatment
+ if key in SPECIAL_MODEL_TYPE_TO_MODULE_NAME:
+ key = SPECIAL_MODEL_TYPE_TO_MODULE_NAME[key]
+
+ if key in DEPRECATED_MODELS:
+ key = f"deprecated.{key}"
+ return key
+
+ key = key.replace("-", "_")
+ if key in DEPRECATED_MODELS:
+ key = f"deprecated.{key}"
+
+ return key
+
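+# Illustrative behaviour, derived from the tables above:
+#   model_type_to_module_name("data2vec-audio") -> "data2vec"              (special-cased)
+#   model_type_to_module_name("blip-2")         -> "blip_2"
+#   model_type_to_module_name("transfo-xl")     -> "deprecated.transfo_xl"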
+
+def config_class_to_model_type(config) -> Union[str, None]:
+ """Converts a config class name to the corresponding model type"""
+ for key, cls in CONFIG_MAPPING_NAMES.items():
+ if cls == config:
+ return key
+ # if key not found check in extra content
+ for key, cls in CONFIG_MAPPING._extra_content.items():
+ if cls.__name__ == config:
+ return key
+ return None
+
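+# Illustrative sketch: config_class_to_model_type("AlbertConfig") returns "albert", while a class name
+# that is neither in `CONFIG_MAPPING_NAMES` nor registered as extra content returns None.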
+
+class _LazyConfigMapping(OrderedDict[str, type[PretrainedConfig]]):
+ """
+ A dictionary that lazily loads its values when they are requested.
+ """
+
+ def __init__(self, mapping) -> None:
+ self._mapping = mapping
+ self._extra_content = {}
+ self._modules = {}
+
+ def __getitem__(self, key: str) -> type[PretrainedConfig]:
+ if key in self._extra_content:
+ return self._extra_content[key]
+ if key not in self._mapping:
+ raise KeyError(key)
+ value = self._mapping[key]
+ module_name = model_type_to_module_name(key)
+ if module_name not in self._modules:
+ self._modules[module_name] = importlib.import_module(f".{module_name}", "transformers.models")
+ if hasattr(self._modules[module_name], value):
+ return getattr(self._modules[module_name], value)
+
+ # Some of the mappings have entries model_type -> config of another model type. In that case we try to grab the
+ # object at the top level.
+ transformers_module = importlib.import_module("transformers")
+ return getattr(transformers_module, value)
+
+ def keys(self) -> list[str]:
+ return list(self._mapping.keys()) + list(self._extra_content.keys())
+
+ def values(self) -> list[type[PretrainedConfig]]:
+ return [self[k] for k in self._mapping] + list(self._extra_content.values())
+
+ def items(self) -> list[tuple[str, type[PretrainedConfig]]]:
+ return [(k, self[k]) for k in self._mapping] + list(self._extra_content.items())
+
+ def __iter__(self) -> Iterator[str]:
+ return iter(list(self._mapping.keys()) + list(self._extra_content.keys()))
+
+ def __contains__(self, item: object) -> bool:
+ return item in self._mapping or item in self._extra_content
+
+ def register(self, key: str, value: type[PretrainedConfig], exist_ok=False) -> None:
+ """
+ Register a new configuration in this mapping.
+ """
+ if key in self._mapping and not exist_ok:
+ raise ValueError(f"'{key}' is already used by a Transformers config, pick another name.")
+ self._extra_content[key] = value
+
+
+CONFIG_MAPPING = _LazyConfigMapping(CONFIG_MAPPING_NAMES)
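+# For example, CONFIG_MAPPING["albert"] imports `transformers.models.albert` on first access and returns
+# the `AlbertConfig` class; configurations registered at runtime are served from `_extra_content` without
+# triggering an import.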
+
+
+class _LazyLoadAllMappings(OrderedDict[str, str]):
+ """
+ A mapping that loads all of its key/value pairs on first access (whether by indexing, requesting keys, values,
+ etc.).
+
+ Args:
+ mapping: The mapping to load.
+ """
+
+ def __init__(self, mapping):
+ self._mapping = mapping
+ self._initialized = False
+ self._data = {}
+
+ def _initialize(self):
+ if self._initialized:
+ return
+
+ for model_type, map_name in self._mapping.items():
+ module_name = model_type_to_module_name(model_type)
+ module = importlib.import_module(f".{module_name}", "transformers.models")
+ mapping = getattr(module, map_name)
+ self._data.update(mapping)
+
+ self._initialized = True
+
+ def __getitem__(self, key):
+ self._initialize()
+ return self._data[key]
+
+ def keys(self) -> KeysView[str]:
+ self._initialize()
+ return self._data.keys()
+
+ def values(self) -> ValuesView[str]:
+ self._initialize()
+ return self._data.values()
+
+ def items(self) -> ItemsView[str, str]:
+ self._initialize()
+ return self._data.items()
+
+ def __iter__(self) -> Iterator[str]:
+ self._initialize()
+ return iter(self._data)
+
+ def __contains__(self, item: object) -> bool:
+ self._initialize()
+ return item in self._data
+
+
+def _get_class_name(model_class: Union[str, list[str]]):
+ if isinstance(model_class, (list, tuple)):
+ return " or ".join([f"[`{c}`]" for c in model_class if c is not None])
+ return f"[`{model_class}`]"
+
+
+def _list_model_options(indent, config_to_class=None, use_model_types=True):
+ if config_to_class is None and not use_model_types:
+ raise ValueError("Using `use_model_types=False` requires a `config_to_class` dictionary.")
+ if use_model_types:
+ if config_to_class is None:
+ model_type_to_name = {model_type: f"[`{config}`]" for model_type, config in CONFIG_MAPPING_NAMES.items()}
+ else:
+ model_type_to_name = {
+ model_type: _get_class_name(model_class)
+ for model_type, model_class in config_to_class.items()
+ if model_type in MODEL_NAMES_MAPPING
+ }
+ lines = [
+ f"{indent}- **{model_type}** -- {model_type_to_name[model_type]} ({MODEL_NAMES_MAPPING[model_type]} model)"
+ for model_type in sorted(model_type_to_name.keys())
+ ]
+ else:
+ config_to_name = {
+ CONFIG_MAPPING_NAMES[config]: _get_class_name(clas)
+ for config, clas in config_to_class.items()
+ if config in CONFIG_MAPPING_NAMES
+ }
+ config_to_model_name = {
+ config: MODEL_NAMES_MAPPING[model_type] for model_type, config in CONFIG_MAPPING_NAMES.items()
+ }
+ lines = [
+ f"{indent}- [`{config_name}`] configuration class:"
+ f" {config_to_name[config_name]} ({config_to_model_name[config_name]} model)"
+ for config_name in sorted(config_to_name.keys())
+ ]
+ return "\n".join(lines)
+
+
+def replace_list_option_in_docstrings(
+ config_to_class=None, use_model_types: bool = True
+) -> Callable[[_CallableT], _CallableT]:
+ def docstring_decorator(fn):
+ docstrings = fn.__doc__
+ if docstrings is None:
+ # Docstrings can be missing, e.g. when Python runs with the -OO flag, which strips them.
+ return fn
+ lines = docstrings.split("\n")
+ i = 0
+ while i < len(lines) and re.search(r"^(\s*)List options\s*$", lines[i]) is None:
+ i += 1
+ if i < len(lines):
+ indent = re.search(r"^(\s*)List options\s*$", lines[i]).groups()[0]
+ if use_model_types:
+ indent = f"{indent} "
+ lines[i] = _list_model_options(indent, config_to_class=config_to_class, use_model_types=use_model_types)
+ docstrings = "\n".join(lines)
+ else:
+ raise ValueError(
+ f"The function {fn} should have an empty 'List options' in its docstring as placeholder, current"
+ f" docstring is:\n{docstrings}"
+ )
+ fn.__doc__ = docstrings
+ return fn
+
+ return docstring_decorator
+
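+# Illustrative use: when this decorator wraps a function whose docstring contains a bare "List options"
+# line (as `AutoConfig.from_pretrained` below does), that line is replaced with a generated bullet list
+# of model types and their associated classes.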
+
+class AutoConfig:
+ r"""
+ This is a generic configuration class that will be instantiated as one of the configuration classes of the library
+ when created with the [`~AutoConfig.from_pretrained`] class method.
+
+ This class cannot be instantiated directly using `__init__()` (throws an error).
+ """
+
+ def __init__(self) -> None:
+ raise OSError(
+ "AutoConfig is designed to be instantiated "
+ "using the `AutoConfig.from_pretrained(pretrained_model_name_or_path)` method."
+ )
+
+ @classmethod
+ def for_model(cls, model_type: str, *args, **kwargs) -> PretrainedConfig:
+ if model_type in CONFIG_MAPPING:
+ config_class = CONFIG_MAPPING[model_type]
+ return config_class(*args, **kwargs)
+ raise ValueError(
+ f"Unrecognized model identifier: {model_type}. Should contain one of {', '.join(CONFIG_MAPPING.keys())}"
+ )
+
+ @classmethod
+ @replace_list_option_in_docstrings()
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike[str]], **kwargs):
+ r"""
+ Instantiate one of the configuration classes of the library from a pretrained model configuration.
+
+ The configuration class to instantiate is selected based on the `model_type` property of the config object that
+ is loaded, or when it's missing, by falling back to using pattern matching on `pretrained_model_name_or_path`:
+
+ List options
+
+ Args:
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
+ Can be either:
+
+ - A string, the *model id* of a pretrained model configuration hosted inside a model repo on
+ huggingface.co.
+ - A path to a *directory* containing a configuration file saved using the
+ [`~PretrainedConfig.save_pretrained`] method, or the [`~PreTrainedModel.save_pretrained`] method,
+ e.g., `./my_model_directory/`.
+ - A path or url to a saved configuration JSON *file*, e.g.,
+ `./my_model_directory/configuration.json`.
+ cache_dir (`str` or `os.PathLike`, *optional*):
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the
+ standard cache should not be used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force (re-)downloading the model weights and configuration files, overriding the
+ cached versions if they exist.
+ resume_download:
+ Deprecated and ignored. All downloads are now resumed by default when possible.
+ Will be removed in v5 of Transformers.
+ proxies (`dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+ identifier allowed by git.
+ return_unused_kwargs (`bool`, *optional*, defaults to `False`):
+ If `False`, then this function returns just the final configuration object.
+
+ If `True`, then this function returns a `Tuple(config, unused_kwargs)` where *unused_kwargs* is a
+ dictionary consisting of the key/value pairs whose keys are not configuration attributes: i.e., the
+ part of `kwargs` which has not been used to update `config` and is otherwise ignored.
+ trust_remote_code (`bool`, *optional*, defaults to `False`):
+ Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
+ should only be set to `True` for repositories you trust and in which you have read the code, as it will
+ execute code present on the Hub on your local machine.
+ kwargs (additional keyword arguments, *optional*):
+ The values in kwargs of any keys which are configuration attributes will be used to override the loaded
+ values. Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled
+ by the `return_unused_kwargs` keyword parameter.
+
+ Examples:
+
+ ```python
+ >>> from transformers import AutoConfig
+
+ >>> # Download configuration from huggingface.co and cache.
+ >>> config = AutoConfig.from_pretrained("google-bert/bert-base-uncased")
+
+ >>> # Download configuration from huggingface.co (user-uploaded) and cache.
+ >>> config = AutoConfig.from_pretrained("dbmdz/bert-base-german-cased")
+
+ >>> # If configuration file is in a directory (e.g., was saved using *save_pretrained('./test/saved_model/')*).
+ >>> config = AutoConfig.from_pretrained("./test/bert_saved_model/")
+
+ >>> # Load a specific configuration file.
+ >>> config = AutoConfig.from_pretrained("./test/bert_saved_model/my_configuration.json")
+
+ >>> # Change some config attributes when loading a pretrained config.
+ >>> config = AutoConfig.from_pretrained("google-bert/bert-base-uncased", output_attentions=True, foo=False)
+ >>> config.output_attentions
+ True
+
+ >>> config, unused_kwargs = AutoConfig.from_pretrained(
+ ... "google-bert/bert-base-uncased", output_attentions=True, foo=False, return_unused_kwargs=True
+ ... )
+ >>> config.output_attentions
+ True
+
+ >>> unused_kwargs
+ {'foo': False}
+ ```
+ """
+ use_auth_token = kwargs.pop("use_auth_token", None)
+ if use_auth_token is not None:
+ warnings.warn(
+ "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
+ FutureWarning,
+ )
+ if kwargs.get("token") is not None:
+ raise ValueError(
+ "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
+ )
+ kwargs["token"] = use_auth_token
+
+ kwargs["_from_auto"] = True
+ kwargs["name_or_path"] = pretrained_model_name_or_path
+ trust_remote_code = kwargs.pop("trust_remote_code", None)
+ code_revision = kwargs.pop("code_revision", None)
+
+ config_dict, unused_kwargs = PretrainedConfig.get_config_dict(pretrained_model_name_or_path, **kwargs)
+ has_remote_code = "auto_map" in config_dict and "AutoConfig" in config_dict["auto_map"]
+ has_local_code = "model_type" in config_dict and config_dict["model_type"] in CONFIG_MAPPING
+ if has_remote_code:
+ class_ref = config_dict["auto_map"]["AutoConfig"]
+ if "--" in class_ref:
+ upstream_repo = class_ref.split("--")[0]
+ else:
+ upstream_repo = None
+ trust_remote_code = resolve_trust_remote_code(
+ trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code, upstream_repo
+ )
+
+ if has_remote_code and trust_remote_code:
+ config_class = get_class_from_dynamic_module(
+ class_ref, pretrained_model_name_or_path, code_revision=code_revision, **kwargs
+ )
+ config_class.register_for_auto_class()
+ return config_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
+ elif "model_type" in config_dict:
+ # Apply heuristic: if model_type is mistral but layer_types is present, treat as ministral
+ if config_dict["model_type"] == "mistral" and "layer_types" in config_dict:
+ logger.info(
+ "Detected mistral model with layer_types, treating as ministral for alternating attention compatibility. "
+ )
+ config_dict["model_type"] = "ministral"
+
+ try:
+ config_class = CONFIG_MAPPING[config_dict["model_type"]]
+ except KeyError:
+ raise ValueError(
+ f"The checkpoint you are trying to load has model type `{config_dict['model_type']}` "
+ "but Transformers does not recognize this architecture. This could be because of an "
+ "issue with the checkpoint, or because your version of Transformers is out of date.\n\n"
+ "You can update Transformers with the command `pip install --upgrade transformers`. If this "
+ "does not work, and the checkpoint is very new, then there may not be a release version "
+ "that supports this model yet. In this case, you can get the most up-to-date code by installing "
+ "Transformers from source with the command "
+ "`pip install git+https://github.com/huggingface/transformers.git`"
+ )
+ return config_class.from_dict(config_dict, **unused_kwargs)
+ else:
+ # Fallback: use pattern matching on the string.
+ # We go from longer names to shorter names to catch roberta before bert (for instance)
+ for pattern in sorted(CONFIG_MAPPING.keys(), key=len, reverse=True):
+ if pattern in str(pretrained_model_name_or_path):
+ return CONFIG_MAPPING[pattern].from_dict(config_dict, **unused_kwargs)
+
+ raise ValueError(
+ f"Unrecognized model in {pretrained_model_name_or_path}. "
+ f"Should have a `model_type` key in its {CONFIG_NAME}, or contain one of the following strings "
+ f"in its name: {', '.join(CONFIG_MAPPING.keys())}"
+ )
+
+ @staticmethod
+ def register(model_type, config, exist_ok=False) -> None:
+ """
+ Register a new configuration for this class.
+
+ Args:
+ model_type (`str`): The model type like "bert" or "gpt".
+ config ([`PretrainedConfig`]): The config to register.
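+
+ Example (illustrative sketch; `MyConfig` is a hypothetical custom configuration class):
+
+ ```python
+ >>> from transformers import AutoConfig, PretrainedConfig
+
+ >>> class MyConfig(PretrainedConfig):
+ ... model_type = "my-model"
+
+ >>> AutoConfig.register("my-model", MyConfig)
+ >>> config = AutoConfig.for_model("my-model")
+ ```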
+ """
+ if issubclass(config, PretrainedConfig) and config.model_type != model_type:
+ raise ValueError(
+ "The config you are passing has a `model_type` attribute that is not consistent with the model type "
+ f"you passed (config has {config.model_type} and you passed {model_type}. Fix one of those so they "
+ "match!"
+ )
+ CONFIG_MAPPING.register(model_type, config, exist_ok=exist_ok)
+
+
+__all__ = ["CONFIG_MAPPING", "MODEL_NAMES_MAPPING", "AutoConfig"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/auto/feature_extraction_auto.py b/venv/lib/python3.13/site-packages/transformers/models/auto/feature_extraction_auto.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d4c4f554d9dcbaceadff65ff84d5fbe818fa96a
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/auto/feature_extraction_auto.py
@@ -0,0 +1,422 @@
+# coding=utf-8
+# Copyright 2021 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""AutoFeatureExtractor class."""
+
+import importlib
+import json
+import os
+import warnings
+from collections import OrderedDict
+from typing import Optional, Union
+
+# Build the list of all feature extractors
+from ...configuration_utils import PretrainedConfig
+from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
+from ...feature_extraction_utils import FeatureExtractionMixin
+from ...utils import CONFIG_NAME, FEATURE_EXTRACTOR_NAME, cached_file, logging
+from .auto_factory import _LazyAutoMapping
+from .configuration_auto import (
+ CONFIG_MAPPING_NAMES,
+ AutoConfig,
+ model_type_to_module_name,
+ replace_list_option_in_docstrings,
+)
+
+
+logger = logging.get_logger(__name__)
+
+FEATURE_EXTRACTOR_MAPPING_NAMES = OrderedDict(
+ [
+ ("audio-spectrogram-transformer", "ASTFeatureExtractor"),
+ ("beit", "BeitFeatureExtractor"),
+ ("chinese_clip", "ChineseCLIPFeatureExtractor"),
+ ("clap", "ClapFeatureExtractor"),
+ ("clip", "CLIPFeatureExtractor"),
+ ("clipseg", "ViTFeatureExtractor"),
+ ("clvp", "ClvpFeatureExtractor"),
+ ("conditional_detr", "ConditionalDetrFeatureExtractor"),
+ ("convnext", "ConvNextFeatureExtractor"),
+ ("cvt", "ConvNextFeatureExtractor"),
+ ("dac", "DacFeatureExtractor"),
+ ("data2vec-audio", "Wav2Vec2FeatureExtractor"),
+ ("data2vec-vision", "BeitFeatureExtractor"),
+ ("deformable_detr", "DeformableDetrFeatureExtractor"),
+ ("deit", "DeiTFeatureExtractor"),
+ ("detr", "DetrFeatureExtractor"),
+ ("dia", "DiaFeatureExtractor"),
+ ("dinat", "ViTFeatureExtractor"),
+ ("donut-swin", "DonutFeatureExtractor"),
+ ("dpt", "DPTFeatureExtractor"),
+ ("encodec", "EncodecFeatureExtractor"),
+ ("flava", "FlavaFeatureExtractor"),
+ ("gemma3n", "Gemma3nAudioFeatureExtractor"),
+ ("glpn", "GLPNFeatureExtractor"),
+ ("granite_speech", "GraniteSpeechFeatureExtractor"),
+ ("groupvit", "CLIPFeatureExtractor"),
+ ("hubert", "Wav2Vec2FeatureExtractor"),
+ ("imagegpt", "ImageGPTFeatureExtractor"),
+ ("kyutai_speech_to_text", "KyutaiSpeechToTextFeatureExtractor"),
+ ("layoutlmv2", "LayoutLMv2FeatureExtractor"),
+ ("layoutlmv3", "LayoutLMv3FeatureExtractor"),
+ ("levit", "LevitFeatureExtractor"),
+ ("maskformer", "MaskFormerFeatureExtractor"),
+ ("mctct", "MCTCTFeatureExtractor"),
+ ("mimi", "EncodecFeatureExtractor"),
+ ("mobilenet_v1", "MobileNetV1FeatureExtractor"),
+ ("mobilenet_v2", "MobileNetV2FeatureExtractor"),
+ ("mobilevit", "MobileViTFeatureExtractor"),
+ ("moonshine", "Wav2Vec2FeatureExtractor"),
+ ("moshi", "EncodecFeatureExtractor"),
+ ("nat", "ViTFeatureExtractor"),
+ ("owlvit", "OwlViTFeatureExtractor"),
+ ("parakeet_ctc", "ParakeetFeatureExtractor"),
+ ("parakeet_encoder", "ParakeetFeatureExtractor"),
+ ("perceiver", "PerceiverFeatureExtractor"),
+ ("phi4_multimodal", "Phi4MultimodalFeatureExtractor"),
+ ("poolformer", "PoolFormerFeatureExtractor"),
+ ("pop2piano", "Pop2PianoFeatureExtractor"),
+ ("regnet", "ConvNextFeatureExtractor"),
+ ("resnet", "ConvNextFeatureExtractor"),
+ ("seamless_m4t", "SeamlessM4TFeatureExtractor"),
+ ("seamless_m4t_v2", "SeamlessM4TFeatureExtractor"),
+ ("segformer", "SegformerFeatureExtractor"),
+ ("sew", "Wav2Vec2FeatureExtractor"),
+ ("sew-d", "Wav2Vec2FeatureExtractor"),
+ ("speech_to_text", "Speech2TextFeatureExtractor"),
+ ("speecht5", "SpeechT5FeatureExtractor"),
+ ("swiftformer", "ViTFeatureExtractor"),
+ ("swin", "ViTFeatureExtractor"),
+ ("swinv2", "ViTFeatureExtractor"),
+ ("table-transformer", "DetrFeatureExtractor"),
+ ("timesformer", "VideoMAEFeatureExtractor"),
+ ("tvlt", "TvltFeatureExtractor"),
+ ("unispeech", "Wav2Vec2FeatureExtractor"),
+ ("unispeech-sat", "Wav2Vec2FeatureExtractor"),
+ ("univnet", "UnivNetFeatureExtractor"),
+ ("van", "ConvNextFeatureExtractor"),
+ ("videomae", "VideoMAEFeatureExtractor"),
+ ("vilt", "ViltFeatureExtractor"),
+ ("vit", "ViTFeatureExtractor"),
+ ("vit_mae", "ViTFeatureExtractor"),
+ ("vit_msn", "ViTFeatureExtractor"),
+ ("wav2vec2", "Wav2Vec2FeatureExtractor"),
+ ("wav2vec2-bert", "Wav2Vec2FeatureExtractor"),
+ ("wav2vec2-conformer", "Wav2Vec2FeatureExtractor"),
+ ("wavlm", "Wav2Vec2FeatureExtractor"),
+ ("whisper", "WhisperFeatureExtractor"),
+ ("xclip", "CLIPFeatureExtractor"),
+ ("xcodec", "DacFeatureExtractor"),
+ ("yolos", "YolosFeatureExtractor"),
+ ]
+)
+
+FEATURE_EXTRACTOR_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FEATURE_EXTRACTOR_MAPPING_NAMES)
+
+
+def feature_extractor_class_from_name(class_name: str):
+ for module_name, extractors in FEATURE_EXTRACTOR_MAPPING_NAMES.items():
+ if class_name in extractors:
+ module_name = model_type_to_module_name(module_name)
+
+ module = importlib.import_module(f".{module_name}", "transformers.models")
+ try:
+ return getattr(module, class_name)
+ except AttributeError:
+ continue
+
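+ # Then, check feature extractors that were registered dynamically with `AutoFeatureExtractor.register`.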
+ for extractor in FEATURE_EXTRACTOR_MAPPING._extra_content.values():
+ if getattr(extractor, "__name__", None) == class_name:
+ return extractor
+
+ # We did not find the class, but maybe it's because a dep is missing. In that case, the class will be in the main
+ # init and we return the proper dummy to get an appropriate error message.
+ main_module = importlib.import_module("transformers")
+ if hasattr(main_module, class_name):
+ return getattr(main_module, class_name)
+
+ return None
+
+
+def get_feature_extractor_config(
+ pretrained_model_name_or_path: Union[str, os.PathLike],
+ cache_dir: Optional[Union[str, os.PathLike]] = None,
+ force_download: bool = False,
+ resume_download: Optional[bool] = None,
+ proxies: Optional[dict[str, str]] = None,
+ token: Optional[Union[bool, str]] = None,
+ revision: Optional[str] = None,
+ local_files_only: bool = False,
+ **kwargs,
+):
+ """
+ Loads the feature extractor configuration from a pretrained model feature extractor configuration.
+
+ Args:
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
+ This can be either:
+
+ - a string, the *model id* of a pretrained model configuration hosted inside a model repo on
+ huggingface.co.
+ - a path to a *directory* containing a configuration file saved using the
+ [`~feature_extraction_utils.FeatureExtractionMixin.save_pretrained`] method, e.g., `./my_model_directory/`.
+
+ cache_dir (`str` or `os.PathLike`, *optional*):
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
+ cache should not be used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force to (re-)download the configuration files and override the cached versions if they
+ exist.
+ resume_download:
+ Deprecated and ignored. All downloads are now resumed by default when possible.
+ Will be removed in v5 of Transformers.
+ proxies (`dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
+ token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
+ when running `hf auth login` (stored in `~/.huggingface`).
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+ identifier allowed by git.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ If `True`, will only try to load the feature extractor configuration from local files.
+
+ <Tip>
+
+ Passing `token=True` is required when you want to use a private model.
+
+ </Tip>
+
+ Returns:
+ `Dict`: The configuration of the feature extractor.
+
+ Examples:
+
+ ```python
+ # Download configuration from huggingface.co and cache.
+ feature_extractor_config = get_feature_extractor_config("facebook/wav2vec2-base-960h")
+ # This model does not have a feature extractor config so the result will be an empty dict.
+ feature_extractor_config = get_feature_extractor_config("google-bert/bert-base-uncased")
+
+ # Save a pretrained feature extractor locally and you can reload its config
+ from transformers import AutoFeatureExtractor
+
+ feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")
+ feature_extractor.save_pretrained("feature-extractor-test")
+ feature_extractor_config = get_feature_extractor_config("feature-extractor-test")
+ ```"""
+ use_auth_token = kwargs.pop("use_auth_token", None)
+ if use_auth_token is not None:
+ warnings.warn(
+ "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
+ FutureWarning,
+ )
+ if token is not None:
+ raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
+ token = use_auth_token
+
+ resolved_config_file = cached_file(
+ pretrained_model_name_or_path,
+ FEATURE_EXTRACTOR_NAME,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ resume_download=resume_download,
+ proxies=proxies,
+ token=token,
+ revision=revision,
+ local_files_only=local_files_only,
+ _raise_exceptions_for_gated_repo=False,
+ _raise_exceptions_for_missing_entries=False,
+ _raise_exceptions_for_connection_errors=False,
+ )
+ if resolved_config_file is None:
+ logger.info(
+ "Could not locate the feature extractor configuration file, will try to use the model config instead."
+ )
+ return {}
+
+ with open(resolved_config_file, encoding="utf-8") as reader:
+ return json.load(reader)
+
+
+class AutoFeatureExtractor:
+ r"""
+ This is a generic feature extractor class that will be instantiated as one of the feature extractor classes of the
+ library when created with the [`AutoFeatureExtractor.from_pretrained`] class method.
+
+ This class cannot be instantiated directly using `__init__()` (throws an error).
+ """
+
+ def __init__(self):
+ raise OSError(
+ "AutoFeatureExtractor is designed to be instantiated "
+ "using the `AutoFeatureExtractor.from_pretrained(pretrained_model_name_or_path)` method."
+ )
+
+ @classmethod
+ @replace_list_option_in_docstrings(FEATURE_EXTRACTOR_MAPPING_NAMES)
+ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
+ r"""
+ Instantiate one of the feature extractor classes of the library from a pretrained model vocabulary.
+
+ The feature extractor class to instantiate is selected based on the `model_type` property of the config object
+ (either passed as an argument or loaded from `pretrained_model_name_or_path` if possible), or when it's
+ missing, by falling back to using pattern matching on `pretrained_model_name_or_path`:
+
+ List options
+
+ Params:
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
+ This can be either:
+
+ - a string, the *model id* of a pretrained feature_extractor hosted inside a model repo on
+ huggingface.co.
+ - a path to a *directory* containing a feature extractor file saved using the
+ [`~feature_extraction_utils.FeatureExtractionMixin.save_pretrained`] method, e.g.,
+ `./my_model_directory/`.
+ - a path or url to a saved feature extractor JSON *file*, e.g.,
+ `./my_model_directory/preprocessor_config.json`.
+ cache_dir (`str` or `os.PathLike`, *optional*):
+ Path to a directory in which a downloaded pretrained model feature extractor should be cached if the
+ standard cache should not be used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force to (re-)download the feature extractor files and override the cached versions
+ if they exist.
+ resume_download:
+ Deprecated and ignored. All downloads are now resumed by default when possible.
+ Will be removed in v5 of Transformers.
+ proxies (`dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
+ token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
+ when running `hf auth login` (stored in `~/.huggingface`).
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+ identifier allowed by git.
+ return_unused_kwargs (`bool`, *optional*, defaults to `False`):
+ If `False`, then this function returns just the final feature extractor object. If `True`, then this
+ function returns a `Tuple(feature_extractor, unused_kwargs)` where *unused_kwargs* is a dictionary
+ consisting of the key/value pairs whose keys are not feature extractor attributes: i.e., the part of
+ `kwargs` which has not been used to update `feature_extractor` and is otherwise ignored.
+ trust_remote_code (`bool`, *optional*, defaults to `False`):
+ Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
+ should only be set to `True` for repositories you trust and in which you have read the code, as it will
+ execute code present on the Hub on your local machine.
+ kwargs (`dict[str, Any]`, *optional*):
+ The values in kwargs of any keys which are feature extractor attributes will be used to override the
+ loaded values. Behavior concerning key/value pairs whose keys are *not* feature extractor attributes is
+ controlled by the `return_unused_kwargs` keyword parameter.
+
+ <Tip>
+
+ Passing `token=True` is required when you want to use a private model.
+
+ </Tip>
+
+ Examples:
+
+ ```python
+ >>> from transformers import AutoFeatureExtractor
+
+ >>> # Download feature extractor from huggingface.co and cache.
+ >>> feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")
+
+ >>> # If feature extractor files are in a directory (e.g. feature extractor was saved using *save_pretrained('./test/saved_model/')*)
+ >>> # feature_extractor = AutoFeatureExtractor.from_pretrained("./test/saved_model/")
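+
+ >>> # Keyword arguments that match feature extractor attributes override the loaded values
+ >>> # (illustrative; `do_normalize` is an attribute of Wav2Vec2FeatureExtractor)
+ >>> # feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h", do_normalize=False)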
+ ```"""
+ use_auth_token = kwargs.pop("use_auth_token", None)
+ if use_auth_token is not None:
+ warnings.warn(
+ "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
+ FutureWarning,
+ )
+ if kwargs.get("token") is not None:
+ raise ValueError(
+ "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
+ )
+ kwargs["token"] = use_auth_token
+
+ config = kwargs.pop("config", None)
+ trust_remote_code = kwargs.pop("trust_remote_code", None)
+ kwargs["_from_auto"] = True
+
+ config_dict, _ = FeatureExtractionMixin.get_feature_extractor_dict(pretrained_model_name_or_path, **kwargs)
+ feature_extractor_class = config_dict.get("feature_extractor_type", None)
+ feature_extractor_auto_map = None
+ if "AutoFeatureExtractor" in config_dict.get("auto_map", {}):
+ feature_extractor_auto_map = config_dict["auto_map"]["AutoFeatureExtractor"]
+
+ # If we don't find the feature extractor class in the feature extractor config, let's try the model config.
+ if feature_extractor_class is None and feature_extractor_auto_map is None:
+ if not isinstance(config, PretrainedConfig):
+ config = AutoConfig.from_pretrained(
+ pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
+ )
+ # It could be in `config.feature_extractor_type`
+ feature_extractor_class = getattr(config, "feature_extractor_type", None)
+ if hasattr(config, "auto_map") and "AutoFeatureExtractor" in config.auto_map:
+ feature_extractor_auto_map = config.auto_map["AutoFeatureExtractor"]
+
+ if feature_extractor_class is not None:
+ feature_extractor_class = feature_extractor_class_from_name(feature_extractor_class)
+
+ has_remote_code = feature_extractor_auto_map is not None
+ has_local_code = feature_extractor_class is not None or type(config) in FEATURE_EXTRACTOR_MAPPING
+ if has_remote_code:
+ if "--" in feature_extractor_auto_map:
+ upstream_repo = feature_extractor_auto_map.split("--")[0]
+ else:
+ upstream_repo = None
+ trust_remote_code = resolve_trust_remote_code(
+ trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code, upstream_repo
+ )
+
+ if has_remote_code and trust_remote_code:
+ feature_extractor_class = get_class_from_dynamic_module(
+ feature_extractor_auto_map, pretrained_model_name_or_path, **kwargs
+ )
+ _ = kwargs.pop("code_revision", None)
+ feature_extractor_class.register_for_auto_class()
+ return feature_extractor_class.from_dict(config_dict, **kwargs)
+ elif feature_extractor_class is not None:
+ return feature_extractor_class.from_dict(config_dict, **kwargs)
+ # Last try: we use the FEATURE_EXTRACTOR_MAPPING.
+ elif type(config) in FEATURE_EXTRACTOR_MAPPING:
+ feature_extractor_class = FEATURE_EXTRACTOR_MAPPING[type(config)]
+ return feature_extractor_class.from_dict(config_dict, **kwargs)
+
+ raise ValueError(
+ f"Unrecognized feature extractor in {pretrained_model_name_or_path}. Should have a "
+ f"`feature_extractor_type` key in its {FEATURE_EXTRACTOR_NAME} of {CONFIG_NAME}, or one of the following "
+ f"`model_type` keys in its {CONFIG_NAME}: {', '.join(c for c in FEATURE_EXTRACTOR_MAPPING_NAMES)}"
+ )
+
+ @staticmethod
+ def register(config_class, feature_extractor_class, exist_ok=False):
+ """
+ Register a new feature extractor for this class.
+
+ Args:
+ config_class ([`PretrainedConfig`]):
+ The configuration corresponding to the model to register.
+ feature_extractor_class ([`FeatureExtractorMixin`]): The feature extractor to register.
+ exist_ok (`bool`, *optional*, defaults to `False`):
+ If `True`, do not raise an error if `config_class` is already registered.
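+
+ Example (a minimal sketch of registering a custom feature extractor; `CustomConfig` and
+ `CustomFeatureExtractor` are illustrative names, not classes shipped with the library):
+
+ ```python
+ >>> from transformers import AutoConfig, AutoFeatureExtractor, FeatureExtractionMixin, PretrainedConfig
+
+ >>> # Hypothetical config/feature-extractor pair used only for illustration.
+ >>> class CustomConfig(PretrainedConfig):
+ ...     model_type = "custom-audio-model"
+
+ >>> class CustomFeatureExtractor(FeatureExtractionMixin):
+ ...     pass
+
+ >>> AutoConfig.register("custom-audio-model", CustomConfig)
+ >>> AutoFeatureExtractor.register(CustomConfig, CustomFeatureExtractor)
+ >>> # `AutoFeatureExtractor.from_pretrained` will now resolve checkpoints whose config is a
+ >>> # `CustomConfig` to `CustomFeatureExtractor`.
+ ```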
+ """
+ FEATURE_EXTRACTOR_MAPPING.register(config_class, feature_extractor_class, exist_ok=exist_ok)
+
+
+__all__ = ["FEATURE_EXTRACTOR_MAPPING", "AutoFeatureExtractor"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/auto/image_processing_auto.py b/venv/lib/python3.13/site-packages/transformers/models/auto/image_processing_auto.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b71712dfc7bbb8c3e98fe87464151a4a579f695
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/auto/image_processing_auto.py
@@ -0,0 +1,688 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""AutoImageProcessor class."""
+
+import importlib
+import json
+import os
+import warnings
+from collections import OrderedDict
+from typing import TYPE_CHECKING, Optional, Union
+
+# Build the list of all image processors
+from ...configuration_utils import PretrainedConfig
+from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
+from ...image_processing_utils import ImageProcessingMixin
+from ...image_processing_utils_fast import BaseImageProcessorFast
+from ...utils import (
+ CONFIG_NAME,
+ IMAGE_PROCESSOR_NAME,
+ cached_file,
+ is_timm_config_dict,
+ is_timm_local_checkpoint,
+ is_torchvision_available,
+ is_vision_available,
+ logging,
+)
+from ...utils.import_utils import requires
+from .auto_factory import _LazyAutoMapping
+from .configuration_auto import (
+ CONFIG_MAPPING_NAMES,
+ AutoConfig,
+ model_type_to_module_name,
+ replace_list_option_in_docstrings,
+)
+
+
+logger = logging.get_logger(__name__)
+
+
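+ # Slow image processor classes that are loaded as their fast counterpart by default, even when the
+ # checkpoint was saved with the slow class (see the `use_fast` handling in `from_pretrained` below).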
+FORCE_FAST_IMAGE_PROCESSOR = ["Qwen2VLImageProcessor"]
+
+
+if TYPE_CHECKING:
+ # This significantly improves completion suggestion performance when
+ # the transformers package is used with Microsoft's Pylance language server.
+ IMAGE_PROCESSOR_MAPPING_NAMES: OrderedDict[str, tuple[Optional[str], Optional[str]]] = OrderedDict()
+else:
+ IMAGE_PROCESSOR_MAPPING_NAMES = OrderedDict(
+ [
+ ("aimv2", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
+ ("aimv2_vision_model", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
+ ("align", ("EfficientNetImageProcessor", "EfficientNetImageProcessorFast")),
+ ("aria", ("AriaImageProcessor", None)),
+ ("beit", ("BeitImageProcessor", "BeitImageProcessorFast")),
+ ("bit", ("BitImageProcessor", "BitImageProcessorFast")),
+ ("blip", ("BlipImageProcessor", "BlipImageProcessorFast")),
+ ("blip-2", ("BlipImageProcessor", "BlipImageProcessorFast")),
+ ("bridgetower", ("BridgeTowerImageProcessor", "BridgeTowerImageProcessorFast")),
+ ("chameleon", ("ChameleonImageProcessor", "ChameleonImageProcessorFast")),
+ ("chinese_clip", ("ChineseCLIPImageProcessor", "ChineseCLIPImageProcessorFast")),
+ ("clip", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
+ ("clipseg", ("ViTImageProcessor", "ViTImageProcessorFast")),
+ ("cohere2_vision", (None, "Cohere2VisionImageProcessorFast")),
+ ("conditional_detr", ("ConditionalDetrImageProcessor", "ConditionalDetrImageProcessorFast")),
+ ("convnext", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")),
+ ("convnextv2", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")),
+ ("cvt", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")),
+ ("data2vec-vision", ("BeitImageProcessor", "BeitImageProcessorFast")),
+ ("deepseek_vl", ("DeepseekVLImageProcessor", "DeepseekVLImageProcessorFast")),
+ ("deepseek_vl_hybrid", ("DeepseekVLHybridImageProcessor", "DeepseekVLHybridImageProcessorFast")),
+ ("deformable_detr", ("DeformableDetrImageProcessor", "DeformableDetrImageProcessorFast")),
+ ("deit", ("DeiTImageProcessor", "DeiTImageProcessorFast")),
+ ("depth_anything", ("DPTImageProcessor", "DPTImageProcessorFast")),
+ ("depth_pro", ("DepthProImageProcessor", "DepthProImageProcessorFast")),
+ ("deta", ("DetaImageProcessor", None)),
+ ("detr", ("DetrImageProcessor", "DetrImageProcessorFast")),
+ ("dinat", ("ViTImageProcessor", "ViTImageProcessorFast")),
+ ("dinov2", ("BitImageProcessor", "BitImageProcessorFast")),
+ ("dinov3_vit", (None, "DINOv3ViTImageProcessorFast")),
+ ("donut-swin", ("DonutImageProcessor", "DonutImageProcessorFast")),
+ ("dpt", ("DPTImageProcessor", "DPTImageProcessorFast")),
+ ("edgetam", (None, "Sam2ImageProcessorFast")),
+ ("efficientformer", ("EfficientFormerImageProcessor", None)),
+ ("efficientloftr", ("EfficientLoFTRImageProcessor", "EfficientLoFTRImageProcessorFast")),
+ ("efficientnet", ("EfficientNetImageProcessor", "EfficientNetImageProcessorFast")),
+ ("eomt", ("EomtImageProcessor", "EomtImageProcessorFast")),
+ ("flava", ("FlavaImageProcessor", "FlavaImageProcessorFast")),
+ ("focalnet", ("BitImageProcessor", "BitImageProcessorFast")),
+ ("fuyu", ("FuyuImageProcessor", None)),
+ ("gemma3", ("Gemma3ImageProcessor", "Gemma3ImageProcessorFast")),
+ ("gemma3n", ("SiglipImageProcessor", "SiglipImageProcessorFast")),
+ ("git", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
+ ("glm4v", ("Glm4vImageProcessor", "Glm4vImageProcessorFast")),
+ ("glpn", ("GLPNImageProcessor", None)),
+ ("got_ocr2", ("GotOcr2ImageProcessor", "GotOcr2ImageProcessorFast")),
+ ("grounding-dino", ("GroundingDinoImageProcessor", "GroundingDinoImageProcessorFast")),
+ ("groupvit", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
+ ("hiera", ("BitImageProcessor", "BitImageProcessorFast")),
+ ("idefics", ("IdeficsImageProcessor", None)),
+ ("idefics2", ("Idefics2ImageProcessor", "Idefics2ImageProcessorFast")),
+ ("idefics3", ("Idefics3ImageProcessor", "Idefics3ImageProcessorFast")),
+ ("ijepa", ("ViTImageProcessor", "ViTImageProcessorFast")),
+ ("imagegpt", ("ImageGPTImageProcessor", "ImageGPTImageProcessorFast")),
+ ("instructblip", ("BlipImageProcessor", "BlipImageProcessorFast")),
+ ("instructblipvideo", ("InstructBlipVideoImageProcessor", None)),
+ ("janus", ("JanusImageProcessor", "JanusImageProcessorFast")),
+ ("kosmos-2", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
+ ("kosmos-2.5", ("Kosmos2_5ImageProcessor", "Kosmos2_5ImageProcessorFast")),
+ ("layoutlmv2", ("LayoutLMv2ImageProcessor", "LayoutLMv2ImageProcessorFast")),
+ ("layoutlmv3", ("LayoutLMv3ImageProcessor", "LayoutLMv3ImageProcessorFast")),
+ ("levit", ("LevitImageProcessor", "LevitImageProcessorFast")),
+ ("lfm2_vl", (None, "Lfm2VlImageProcessorFast")),
+ ("lightglue", ("LightGlueImageProcessor", None)),
+ ("llama4", ("Llama4ImageProcessor", "Llama4ImageProcessorFast")),
+ ("llava", ("LlavaImageProcessor", "LlavaImageProcessorFast")),
+ ("llava_next", ("LlavaNextImageProcessor", "LlavaNextImageProcessorFast")),
+ ("llava_next_video", ("LlavaNextVideoImageProcessor", None)),
+ ("llava_onevision", ("LlavaOnevisionImageProcessor", "LlavaOnevisionImageProcessorFast")),
+ ("mask2former", ("Mask2FormerImageProcessor", "Mask2FormerImageProcessorFast")),
+ ("maskformer", ("MaskFormerImageProcessor", "MaskFormerImageProcessorFast")),
+ ("metaclip_2", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
+ ("mgp-str", ("ViTImageProcessor", "ViTImageProcessorFast")),
+ ("mistral3", ("PixtralImageProcessor", "PixtralImageProcessorFast")),
+ ("mlcd", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
+ ("mllama", ("MllamaImageProcessor", None)),
+ ("mm-grounding-dino", ("GroundingDinoImageProcessor", "GroundingDinoImageProcessorFast")),
+ ("mobilenet_v1", ("MobileNetV1ImageProcessor", "MobileNetV1ImageProcessorFast")),
+ ("mobilenet_v2", ("MobileNetV2ImageProcessor", "MobileNetV2ImageProcessorFast")),
+ ("mobilevit", ("MobileViTImageProcessor", "MobileViTImageProcessorFast")),
+ ("mobilevitv2", ("MobileViTImageProcessor", "MobileViTImageProcessorFast")),
+ ("nat", ("ViTImageProcessor", "ViTImageProcessorFast")),
+ ("nougat", ("NougatImageProcessor", "NougatImageProcessorFast")),
+ ("oneformer", ("OneFormerImageProcessor", "OneFormerImageProcessorFast")),
+ ("ovis2", ("Ovis2ImageProcessor", "Ovis2ImageProcessorFast")),
+ ("owlv2", ("Owlv2ImageProcessor", "Owlv2ImageProcessorFast")),
+ ("owlvit", ("OwlViTImageProcessor", "OwlViTImageProcessorFast")),
+ ("paligemma", ("SiglipImageProcessor", "SiglipImageProcessorFast")),
+ ("perceiver", ("PerceiverImageProcessor", "PerceiverImageProcessorFast")),
+ ("perception_lm", (None, "PerceptionLMImageProcessorFast")),
+ ("phi4_multimodal", (None, "Phi4MultimodalImageProcessorFast")),
+ ("pix2struct", ("Pix2StructImageProcessor", None)),
+ ("pixtral", ("PixtralImageProcessor", "PixtralImageProcessorFast")),
+ ("poolformer", ("PoolFormerImageProcessor", "PoolFormerImageProcessorFast")),
+ ("prompt_depth_anything", ("PromptDepthAnythingImageProcessor", "PromptDepthAnythingImageProcessorFast")),
+ ("pvt", ("PvtImageProcessor", "PvtImageProcessorFast")),
+ ("pvt_v2", ("PvtImageProcessor", "PvtImageProcessorFast")),
+ ("qwen2_5_vl", ("Qwen2VLImageProcessor", "Qwen2VLImageProcessorFast")),
+ ("qwen2_vl", ("Qwen2VLImageProcessor", "Qwen2VLImageProcessorFast")),
+ ("qwen3_vl", ("Qwen2VLImageProcessor", "Qwen2VLImageProcessorFast")),
+ ("regnet", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")),
+ ("resnet", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")),
+ ("rt_detr", ("RTDetrImageProcessor", "RTDetrImageProcessorFast")),
+ ("sam", ("SamImageProcessor", "SamImageProcessorFast")),
+ ("sam2", (None, "Sam2ImageProcessorFast")),
+ ("sam_hq", ("SamImageProcessor", "SamImageProcessorFast")),
+ ("segformer", ("SegformerImageProcessor", "SegformerImageProcessorFast")),
+ ("seggpt", ("SegGptImageProcessor", None)),
+ ("shieldgemma2", ("Gemma3ImageProcessor", "Gemma3ImageProcessorFast")),
+ ("siglip", ("SiglipImageProcessor", "SiglipImageProcessorFast")),
+ ("siglip2", ("Siglip2ImageProcessor", "Siglip2ImageProcessorFast")),
+ ("smolvlm", ("SmolVLMImageProcessor", "SmolVLMImageProcessorFast")),
+ ("superglue", ("SuperGlueImageProcessor", None)),
+ ("superpoint", ("SuperPointImageProcessor", "SuperPointImageProcessorFast")),
+ ("swiftformer", ("ViTImageProcessor", "ViTImageProcessorFast")),
+ ("swin", ("ViTImageProcessor", "ViTImageProcessorFast")),
+ ("swin2sr", ("Swin2SRImageProcessor", "Swin2SRImageProcessorFast")),
+ ("swinv2", ("ViTImageProcessor", "ViTImageProcessorFast")),
+ ("table-transformer", ("DetrImageProcessor", "DetrImageProcessorFast")),
+ ("textnet", ("TextNetImageProcessor", "TextNetImageProcessorFast")),
+ ("timesformer", ("VideoMAEImageProcessor", None)),
+ ("timm_wrapper", ("TimmWrapperImageProcessor", None)),
+ ("tvlt", ("TvltImageProcessor", None)),
+ ("tvp", ("TvpImageProcessor", "TvpImageProcessorFast")),
+ ("udop", ("LayoutLMv3ImageProcessor", "LayoutLMv3ImageProcessorFast")),
+ ("upernet", ("SegformerImageProcessor", "SegformerImageProcessorFast")),
+ ("van", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")),
+ ("videomae", ("VideoMAEImageProcessor", None)),
+ ("vilt", ("ViltImageProcessor", "ViltImageProcessorFast")),
+ ("vipllava", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
+ ("vit", ("ViTImageProcessor", "ViTImageProcessorFast")),
+ ("vit_hybrid", ("ViTHybridImageProcessor", None)),
+ ("vit_mae", ("ViTImageProcessor", "ViTImageProcessorFast")),
+ ("vit_msn", ("ViTImageProcessor", "ViTImageProcessorFast")),
+ ("vitmatte", ("VitMatteImageProcessor", "VitMatteImageProcessorFast")),
+ ("xclip", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
+ ("yolos", ("YolosImageProcessor", "YolosImageProcessorFast")),
+ ("zoedepth", ("ZoeDepthImageProcessor", "ZoeDepthImageProcessorFast")),
+ ]
+ )
+
+# Override to None if the packages are not available
+for model_type, (slow_class, fast_class) in IMAGE_PROCESSOR_MAPPING_NAMES.items():
+ if not is_vision_available():
+ slow_class = None
+ if not is_torchvision_available():
+ fast_class = None
+
+ IMAGE_PROCESSOR_MAPPING_NAMES[model_type] = (slow_class, fast_class)
+
+IMAGE_PROCESSOR_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, IMAGE_PROCESSOR_MAPPING_NAMES)
+
+
+def get_image_processor_class_from_name(class_name: str):
+ if class_name == "BaseImageProcessorFast":
+ return BaseImageProcessorFast
+
+ for module_name, extractors in IMAGE_PROCESSOR_MAPPING_NAMES.items():
+ if class_name in extractors:
+ module_name = model_type_to_module_name(module_name)
+
+ module = importlib.import_module(f".{module_name}", "transformers.models")
+ try:
+ return getattr(module, class_name)
+ except AttributeError:
+ continue
+
+ for extractors in IMAGE_PROCESSOR_MAPPING._extra_content.values():
+ for extractor in extractors:
+ if getattr(extractor, "__name__", None) == class_name:
+ return extractor
+
+ # We did not find the class, but maybe it's because a dep is missing. In that case, the class will be in the main
+ # init and we return the proper dummy to get an appropriate error message.
+ main_module = importlib.import_module("transformers")
+ if hasattr(main_module, class_name):
+ return getattr(main_module, class_name)
+
+ return None
+
+
+def get_image_processor_config(
+ pretrained_model_name_or_path: Union[str, os.PathLike],
+ cache_dir: Optional[Union[str, os.PathLike]] = None,
+ force_download: bool = False,
+ resume_download: Optional[bool] = None,
+ proxies: Optional[dict[str, str]] = None,
+ token: Optional[Union[bool, str]] = None,
+ revision: Optional[str] = None,
+ local_files_only: bool = False,
+ **kwargs,
+):
+ """
+ Loads the image processor configuration from a pretrained model image processor configuration.
+
+ Args:
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
+ This can be either:
+
+ - a string, the *model id* of a pretrained model configuration hosted inside a model repo on
+ huggingface.co.
+ - a path to a *directory* containing a configuration file saved using the
+ [`~image_processing_utils.ImageProcessingMixin.save_pretrained`] method, e.g., `./my_model_directory/`.
+
+ cache_dir (`str` or `os.PathLike`, *optional*):
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
+ cache should not be used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force to (re-)download the configuration files and override the cached versions if they
+ exist.
+ resume_download:
+ Deprecated and ignored. All downloads are now resumed by default when possible.
+ Will be removed in v5 of Transformers.
+ proxies (`dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
+ token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
+ when running `hf auth login` (stored in `~/.huggingface`).
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+ identifier allowed by git.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ If `True`, will only try to load the image processor configuration from local files.
+
+ <Tip>
+
+ Passing `token=True` is required when you want to use a private model.
+
+ </Tip>
+
+ Returns:
+ `Dict`: The configuration of the image processor.
+
+ Examples:
+
+ ```python
+ # Download configuration from huggingface.co and cache.
+ image_processor_config = get_image_processor_config("google/vit-base-patch16-224-in21k")
+ # This model does not have an image processor config so the result will be an empty dict.
+ image_processor_config = get_image_processor_config("FacebookAI/xlm-roberta-base")
+
+ # Save a pretrained image processor locally and you can reload its config
+ from transformers import AutoImageProcessor
+
+ image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
+ image_processor.save_pretrained("image-processor-test")
+ image_processor_config = get_image_processor_config("image-processor-test")
+ ```"""
+ use_auth_token = kwargs.pop("use_auth_token", None)
+ if use_auth_token is not None:
+ warnings.warn(
+ "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
+ FutureWarning,
+ )
+ if token is not None:
+ raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
+ token = use_auth_token
+
+ resolved_config_file = cached_file(
+ pretrained_model_name_or_path,
+ IMAGE_PROCESSOR_NAME,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ resume_download=resume_download,
+ proxies=proxies,
+ token=token,
+ revision=revision,
+ local_files_only=local_files_only,
+ _raise_exceptions_for_gated_repo=False,
+ _raise_exceptions_for_missing_entries=False,
+ _raise_exceptions_for_connection_errors=False,
+ )
+ if resolved_config_file is None:
+ logger.info(
+ "Could not locate the image processor configuration file, will try to use the model config instead."
+ )
+ return {}
+
+ with open(resolved_config_file, encoding="utf-8") as reader:
+ return json.load(reader)
+
+
+def _warning_fast_image_processor_available(fast_class):
+ logger.warning(
+ f"Fast image processor class {fast_class} is available for this model. "
+ "Using slow image processor class. To use the fast image processor class set `use_fast=True`."
+ )
+
+
+@requires(backends=("vision",))
+class AutoImageProcessor:
+ r"""
+ This is a generic image processor class that will be instantiated as one of the image processor classes of the
+ library when created with the [`AutoImageProcessor.from_pretrained`] class method.
+
+ This class cannot be instantiated directly using `__init__()` (throws an error).
+ """
+
+ def __init__(self):
+ raise OSError(
+ "AutoImageProcessor is designed to be instantiated "
+ "using the `AutoImageProcessor.from_pretrained(pretrained_model_name_or_path)` method."
+ )
+
+ @classmethod
+ @replace_list_option_in_docstrings(IMAGE_PROCESSOR_MAPPING_NAMES)
+ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
+ r"""
+ Instantiate one of the image processor classes of the library from a pretrained model vocabulary.
+
+ The image processor class to instantiate is selected based on the `model_type` property of the config object
+ (either passed as an argument or loaded from `pretrained_model_name_or_path` if possible), or when it's
+ missing, by falling back to using pattern matching on `pretrained_model_name_or_path`:
+
+ List options
+
+ Params:
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
+ This can be either:
+
+ - a string, the *model id* of a pretrained image_processor hosted inside a model repo on
+ huggingface.co.
+ - a path to a *directory* containing an image processor file saved using the
+ [`~image_processing_utils.ImageProcessingMixin.save_pretrained`] method, e.g.,
+ `./my_model_directory/`.
+ - a path or url to a saved image processor JSON *file*, e.g.,
+ `./my_model_directory/preprocessor_config.json`.
+ cache_dir (`str` or `os.PathLike`, *optional*):
+ Path to a directory in which a downloaded pretrained model image processor should be cached if the
+ standard cache should not be used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force to (re-)download the image processor files and override the cached versions if
+ they exist.
+ resume_download:
+ Deprecated and ignored. All downloads are now resumed by default when possible.
+ Will be removed in v5 of Transformers.
+ proxies (`dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
+ token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
+ when running `hf auth login` (stored in `~/.huggingface`).
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+ identifier allowed by git.
+ use_fast (`bool`, *optional*, defaults to `False`):
+ Use a fast torchvision-based image processor if it is supported for a given model.
+ If a fast image processor is not available for a given model, a normal numpy-based image processor
+ is returned instead.
+ return_unused_kwargs (`bool`, *optional*, defaults to `False`):
+ If `False`, then this function returns just the final image processor object. If `True`, then this
+ function returns a `Tuple(image_processor, unused_kwargs)` where *unused_kwargs* is a dictionary
+ consisting of the key/value pairs whose keys are not image processor attributes: i.e., the part of
+ `kwargs` which has not been used to update `image_processor` and is otherwise ignored.
+ trust_remote_code (`bool`, *optional*, defaults to `False`):
+ Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
+ should only be set to `True` for repositories you trust and in which you have read the code, as it will
+ execute code present on the Hub on your local machine.
+ image_processor_filename (`str`, *optional*, defaults to `"preprocessor_config.json"`):
+ The name of the file in the model directory to use for the image processor config.
+ kwargs (`dict[str, Any]`, *optional*):
+ The values in kwargs of any keys which are image processor attributes will be used to override the
+ loaded values. Behavior concerning key/value pairs whose keys are *not* image processor attributes is
+ controlled by the `return_unused_kwargs` keyword parameter.
+
+ <Tip>
+
+ Passing `token=True` is required when you want to use a private model.
+
+ </Tip>
+
+ Examples:
+
+ ```python
+ >>> from transformers import AutoImageProcessor
+
+ >>> # Download image processor from huggingface.co and cache.
+ >>> image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
+
+ >>> # If image processor files are in a directory (e.g. image processor was saved using *save_pretrained('./test/saved_model/')*)
+ >>> # image_processor = AutoImageProcessor.from_pretrained("./test/saved_model/")
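+
+ >>> # A fast, torchvision-based image processor can be requested explicitly when one is available
+ >>> # (illustrative; requires `torchvision` to be installed)
+ >>> # image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k", use_fast=True)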
+ ```"""
+ use_auth_token = kwargs.pop("use_auth_token", None)
+ if use_auth_token is not None:
+ warnings.warn(
+ "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
+ FutureWarning,
+ )
+ if kwargs.get("token") is not None:
+ raise ValueError(
+ "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
+ )
+ kwargs["token"] = use_auth_token
+
+ config = kwargs.pop("config", None)
+ # TODO: @yoni, change in v4.48 (use_fast set to True by default)
+ use_fast = kwargs.pop("use_fast", None)
+ trust_remote_code = kwargs.pop("trust_remote_code", None)
+ kwargs["_from_auto"] = True
+
+ # Resolve the image processor config filename
+ if "image_processor_filename" in kwargs:
+ image_processor_filename = kwargs.pop("image_processor_filename")
+ elif is_timm_local_checkpoint(pretrained_model_name_or_path):
+ image_processor_filename = CONFIG_NAME
+ else:
+ image_processor_filename = IMAGE_PROCESSOR_NAME
+
+ # Load the image processor config
+ try:
+ # Main path for all transformers models and local TimmWrapper checkpoints
+ config_dict, _ = ImageProcessingMixin.get_image_processor_dict(
+ pretrained_model_name_or_path, image_processor_filename=image_processor_filename, **kwargs
+ )
+ except Exception as initial_exception:
+ # Fallback path for Hub TimmWrapper checkpoints. Timm models' image processing is saved in `config.json`
+ # instead of `preprocessor_config.json`. Because this is an Auto class and we don't have any information
+ # except the model name, the only way to check if a remote checkpoint is a timm model is to try to
+ # load `config.json` and if it fails with some error, we raise the initial exception.
+ try:
+ config_dict, _ = ImageProcessingMixin.get_image_processor_dict(
+ pretrained_model_name_or_path, image_processor_filename=CONFIG_NAME, **kwargs
+ )
+ except Exception:
+ raise initial_exception
+
+ # In case we have a config_dict, but it's not a timm config dict, we raise the initial exception,
+ # because only timm models have image processing in `config.json`.
+ if not is_timm_config_dict(config_dict):
+ raise initial_exception
+
+ image_processor_type = config_dict.get("image_processor_type", None)
+ image_processor_auto_map = None
+ if "AutoImageProcessor" in config_dict.get("auto_map", {}):
+ image_processor_auto_map = config_dict["auto_map"]["AutoImageProcessor"]
+
+ # If we still don't have the image processor class, check if we're loading from a previous feature extractor config
+ # and if so, infer the image processor class from there.
+ if image_processor_type is None and image_processor_auto_map is None:
+ feature_extractor_class = config_dict.pop("feature_extractor_type", None)
+ if feature_extractor_class is not None:
+ image_processor_type = feature_extractor_class.replace("FeatureExtractor", "ImageProcessor")
+ if "AutoFeatureExtractor" in config_dict.get("auto_map", {}):
+ feature_extractor_auto_map = config_dict["auto_map"]["AutoFeatureExtractor"]
+ image_processor_auto_map = feature_extractor_auto_map.replace("FeatureExtractor", "ImageProcessor")
+
+ # If we don't find the image processor class in the image processor config, let's try the model config.
+ if image_processor_type is None and image_processor_auto_map is None:
+ if not isinstance(config, PretrainedConfig):
+ config = AutoConfig.from_pretrained(
+ pretrained_model_name_or_path,
+ trust_remote_code=trust_remote_code,
+ **kwargs,
+ )
+ # It could be in `config.image_processor_type`
+ image_processor_type = getattr(config, "image_processor_type", None)
+ if hasattr(config, "auto_map") and "AutoImageProcessor" in config.auto_map:
+ image_processor_auto_map = config.auto_map["AutoImageProcessor"]
+
+ image_processor_class = None
+ # TODO: @yoni, change logic in v4.52 (when use_fast set to True by default)
+ if image_processor_type is not None:
+ # if use_fast is not set and the processor was saved with a fast processor, we use it, otherwise we use the slow processor.
+ if use_fast is None:
+ use_fast = image_processor_type.endswith("Fast")
+ if not use_fast and image_processor_type in FORCE_FAST_IMAGE_PROCESSOR and is_torchvision_available():
+ use_fast = True
+ logger.warning_once(
+ f"The image processor of type `{image_processor_type}` is now loaded as a fast processor by default, even if the model checkpoint was saved with a slow processor. "
+ "This is a breaking change and may produce slightly different outputs. To continue using the slow processor, instantiate this class with `use_fast=False`. "
+ "Note that this behavior will be extended to all models in a future release."
+ )
+ if not use_fast:
+ logger.warning_once(
+ "Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. "
+ "`use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. "
+ "This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`."
+ )
+ if use_fast and not image_processor_type.endswith("Fast"):
+ image_processor_type += "Fast"
+ if use_fast and not is_torchvision_available():
+ # check if there is a slow image processor class to fallback to
+ image_processor_class = get_image_processor_class_from_name(image_processor_type[:-4])
+ if image_processor_class is None:
+ raise ValueError(
+ f"`{image_processor_type}` requires `torchvision` to be installed. Please install `torchvision` and try again."
+ )
+ logger.warning_once(
+ "Using `use_fast=True` but `torchvision` is not available. Falling back to the slow image processor."
+ )
+ use_fast = False
+ if use_fast:
+ for image_processors in IMAGE_PROCESSOR_MAPPING_NAMES.values():
+ if image_processor_type in image_processors:
+ break
+ else:
+ image_processor_type = image_processor_type[:-4]
+ use_fast = False
+ logger.warning_once(
+ "`use_fast` is set to `True` but the image processor class does not have a fast version. "
+ " Falling back to the slow version."
+ )
+ image_processor_class = get_image_processor_class_from_name(image_processor_type)
+ else:
+ image_processor_type_slow = image_processor_type.removesuffix("Fast")
+ image_processor_class = get_image_processor_class_from_name(image_processor_type_slow)
+ if image_processor_class is None and image_processor_type.endswith("Fast"):
+ raise ValueError(
+ f"`{image_processor_type}` does not have a slow version. Please set `use_fast=True` when instantiating the processor."
+ )
+
+ has_remote_code = image_processor_auto_map is not None
+ has_local_code = image_processor_class is not None or type(config) in IMAGE_PROCESSOR_MAPPING
+ if has_remote_code:
+ if image_processor_auto_map is not None and not isinstance(image_processor_auto_map, tuple):
+ # In some configs, only the slow image processor class is stored
+ image_processor_auto_map = (image_processor_auto_map, None)
+ if use_fast and image_processor_auto_map[1] is not None:
+ class_ref = image_processor_auto_map[1]
+ else:
+ class_ref = image_processor_auto_map[0]
+ if "--" in class_ref:
+ upstream_repo = class_ref.split("--")[0]
+ else:
+ upstream_repo = None
+ trust_remote_code = resolve_trust_remote_code(
+ trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code, upstream_repo
+ )
+
+ if has_remote_code and trust_remote_code:
+ if not use_fast and image_processor_auto_map[1] is not None:
+ _warning_fast_image_processor_available(image_processor_auto_map[1])
+
+ image_processor_class = get_class_from_dynamic_module(class_ref, pretrained_model_name_or_path, **kwargs)
+ _ = kwargs.pop("code_revision", None)
+ image_processor_class.register_for_auto_class()
+ return image_processor_class.from_dict(config_dict, **kwargs)
+ elif image_processor_class is not None:
+ return image_processor_class.from_dict(config_dict, **kwargs)
+ # Last try: we use the IMAGE_PROCESSOR_MAPPING.
+ elif type(config) in IMAGE_PROCESSOR_MAPPING:
+ image_processor_tuple = IMAGE_PROCESSOR_MAPPING[type(config)]
+
+ image_processor_class_py, image_processor_class_fast = image_processor_tuple
+
+ if not use_fast and image_processor_class_fast is not None:
+ _warning_fast_image_processor_available(image_processor_class_fast)
+
+ if image_processor_class_fast and (use_fast or image_processor_class_py is None):
+ return image_processor_class_fast.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
+ else:
+ if image_processor_class_py is not None:
+ return image_processor_class_py.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
+ else:
+ raise ValueError(
+ "This image processor cannot be instantiated. Please make sure you have `Pillow` installed."
+ )
+ raise ValueError(
+ f"Unrecognized image processor in {pretrained_model_name_or_path}. Should have a "
+ f"`image_processor_type` key in its {IMAGE_PROCESSOR_NAME} of {CONFIG_NAME}, or one of the following "
+ f"`model_type` keys in its {CONFIG_NAME}: {', '.join(c for c in IMAGE_PROCESSOR_MAPPING_NAMES)}"
+ )
+
+ @staticmethod
+ def register(
+ config_class,
+ image_processor_class=None,
+ slow_image_processor_class=None,
+ fast_image_processor_class=None,
+ exist_ok=False,
+ ):
+ """
+ Register a new image processor for this class.
+
+ Args:
+ config_class ([`PretrainedConfig`]):
+ The configuration corresponding to the model to register.
+ image_processor_class ([`ImageProcessingMixin`], *optional*):
+ Deprecated; use `slow_image_processor_class` or `fast_image_processor_class` instead. The image
+ processor to register.
+ slow_image_processor_class ([`ImageProcessingMixin`], *optional*):
+ The slow (numpy-based) image processor to register.
+ fast_image_processor_class ([`BaseImageProcessorFast`], *optional*):
+ The fast (torchvision-based) image processor to register.
+ exist_ok (`bool`, *optional*, defaults to `False`):
+ If `True`, do not raise an error if `config_class` is already registered.
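+
+ Example (a minimal sketch of registering a custom image processor; `CustomConfig` and
+ `CustomImageProcessor` are illustrative names, not classes shipped with the library):
+
+ ```python
+ >>> from transformers import AutoConfig, AutoImageProcessor, ImageProcessingMixin, PretrainedConfig
+
+ >>> # Hypothetical config/image-processor pair used only for illustration.
+ >>> class CustomConfig(PretrainedConfig):
+ ...     model_type = "custom-vision-model"
+
+ >>> class CustomImageProcessor(ImageProcessingMixin):
+ ...     pass
+
+ >>> AutoConfig.register("custom-vision-model", CustomConfig)
+ >>> AutoImageProcessor.register(CustomConfig, slow_image_processor_class=CustomImageProcessor)
+ >>> # A torchvision-based processor could be registered the same way through `fast_image_processor_class`.
+ ```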
+ """
+ if image_processor_class is not None:
+ if slow_image_processor_class is not None:
+ raise ValueError("Cannot specify both image_processor_class and slow_image_processor_class")
+ warnings.warn(
+ "The image_processor_class argument is deprecated and will be removed in v4.42. Please use `slow_image_processor_class`, or `fast_image_processor_class` instead",
+ FutureWarning,
+ )
+ slow_image_processor_class = image_processor_class
+
+ if slow_image_processor_class is None and fast_image_processor_class is None:
+ raise ValueError("You need to specify either slow_image_processor_class or fast_image_processor_class")
+ if slow_image_processor_class is not None and issubclass(slow_image_processor_class, BaseImageProcessorFast):
+ raise ValueError("You passed a fast image processor in as the `slow_image_processor_class`.")
+ if fast_image_processor_class is not None and not issubclass(
+ fast_image_processor_class, BaseImageProcessorFast
+ ):
+ raise ValueError("The `fast_image_processor_class` should inherit from `BaseImageProcessorFast`.")
+
+ if (
+ slow_image_processor_class is not None
+ and fast_image_processor_class is not None
+ and issubclass(fast_image_processor_class, BaseImageProcessorFast)
+ and fast_image_processor_class.slow_image_processor_class != slow_image_processor_class
+ ):
+ raise ValueError(
+ "The fast processor class you are passing has a `slow_image_processor_class` attribute that is not "
+ "consistent with the slow processor class you passed (fast tokenizer has "
+ f"{fast_image_processor_class.slow_image_processor_class} and you passed {slow_image_processor_class}. Fix one of those "
+ "so they match!"
+ )
+
+ # Avoid resetting a set slow/fast image processor if we are passing just the other ones.
+ if config_class in IMAGE_PROCESSOR_MAPPING._extra_content:
+ existing_slow, existing_fast = IMAGE_PROCESSOR_MAPPING[config_class]
+ if slow_image_processor_class is None:
+ slow_image_processor_class = existing_slow
+ if fast_image_processor_class is None:
+ fast_image_processor_class = existing_fast
+
+ IMAGE_PROCESSOR_MAPPING.register(
+ config_class, (slow_image_processor_class, fast_image_processor_class), exist_ok=exist_ok
+ )
+
+
+__all__ = ["IMAGE_PROCESSOR_MAPPING", "AutoImageProcessor"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/auto/modeling_auto.py b/venv/lib/python3.13/site-packages/transformers/models/auto/modeling_auto.py
new file mode 100644
index 0000000000000000000000000000000000000000..298834bebe9303b90a325d48f4be6ab732bfe051
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/auto/modeling_auto.py
@@ -0,0 +1,2382 @@
+# coding=utf-8
+# Copyright 2018 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Auto Model class."""
+
+import os
+import warnings
+from collections import OrderedDict
+from typing import TYPE_CHECKING, Union
+
+from ...utils import logging
+from .auto_factory import (
+ _BaseAutoBackboneClass,
+ _BaseAutoModelClass,
+ _LazyAutoMapping,
+ auto_class_update,
+)
+from .configuration_auto import CONFIG_MAPPING_NAMES
+
+
+if TYPE_CHECKING:
+ from ...generation import GenerationMixin
+ from ...modeling_utils import PreTrainedModel
+
+ # class for better type annotations
+ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
+ pass
+
+
+logger = logging.get_logger(__name__)
+
+MODEL_MAPPING_NAMES = OrderedDict(
+ [
+ # Base model mapping
+ ("aimv2", "Aimv2Model"),
+ ("aimv2_vision_model", "Aimv2VisionModel"),
+ ("albert", "AlbertModel"),
+ ("align", "AlignModel"),
+ ("altclip", "AltCLIPModel"),
+ ("apertus", "ApertusModel"),
+ ("arcee", "ArceeModel"),
+ ("aria", "AriaModel"),
+ ("aria_text", "AriaTextModel"),
+ ("audio-spectrogram-transformer", "ASTModel"),
+ ("autoformer", "AutoformerModel"),
+ ("aya_vision", "AyaVisionModel"),
+ ("bamba", "BambaModel"),
+ ("bark", "BarkModel"),
+ ("bart", "BartModel"),
+ ("beit", "BeitModel"),
+ ("bert", "BertModel"),
+ ("bert-generation", "BertGenerationEncoder"),
+ ("big_bird", "BigBirdModel"),
+ ("bigbird_pegasus", "BigBirdPegasusModel"),
+ ("biogpt", "BioGptModel"),
+ ("bit", "BitModel"),
+ ("bitnet", "BitNetModel"),
+ ("blenderbot", "BlenderbotModel"),
+ ("blenderbot-small", "BlenderbotSmallModel"),
+ ("blip", "BlipModel"),
+ ("blip-2", "Blip2Model"),
+ ("blip_2_qformer", "Blip2QFormerModel"),
+ ("bloom", "BloomModel"),
+ ("blt", "BltModel"),
+ ("bridgetower", "BridgeTowerModel"),
+ ("bros", "BrosModel"),
+ ("camembert", "CamembertModel"),
+ ("canine", "CanineModel"),
+ ("chameleon", "ChameleonModel"),
+ ("chinese_clip", "ChineseCLIPModel"),
+ ("chinese_clip_vision_model", "ChineseCLIPVisionModel"),
+ ("clap", "ClapModel"),
+ ("clip", "CLIPModel"),
+ ("clip_text_model", "CLIPTextModel"),
+ ("clip_vision_model", "CLIPVisionModel"),
+ ("clipseg", "CLIPSegModel"),
+ ("clvp", "ClvpModelForConditionalGeneration"),
+ ("code_llama", "LlamaModel"),
+ ("codegen", "CodeGenModel"),
+ ("cohere", "CohereModel"),
+ ("cohere2", "Cohere2Model"),
+ ("cohere2_vision", "Cohere2VisionModel"),
+ ("conditional_detr", "ConditionalDetrModel"),
+ ("convbert", "ConvBertModel"),
+ ("convnext", "ConvNextModel"),
+ ("convnextv2", "ConvNextV2Model"),
+ ("cpmant", "CpmAntModel"),
+ ("csm", "CsmForConditionalGeneration"),
+ ("ctrl", "CTRLModel"),
+ ("cvt", "CvtModel"),
+ ("d_fine", "DFineModel"),
+ ("dab-detr", "DabDetrModel"),
+ ("dac", "DacModel"),
+ ("data2vec-audio", "Data2VecAudioModel"),
+ ("data2vec-text", "Data2VecTextModel"),
+ ("data2vec-vision", "Data2VecVisionModel"),
+ ("dbrx", "DbrxModel"),
+ ("deberta", "DebertaModel"),
+ ("deberta-v2", "DebertaV2Model"),
+ ("decision_transformer", "DecisionTransformerModel"),
+ ("deepseek_v2", "DeepseekV2Model"),
+ ("deepseek_v3", "DeepseekV3Model"),
+ ("deepseek_vl", "DeepseekVLModel"),
+ ("deepseek_vl_hybrid", "DeepseekVLHybridModel"),
+ ("deformable_detr", "DeformableDetrModel"),
+ ("deit", "DeiTModel"),
+ ("depth_pro", "DepthProModel"),
+ ("deta", "DetaModel"),
+ ("detr", "DetrModel"),
+ ("dia", "DiaModel"),
+ ("diffllama", "DiffLlamaModel"),
+ ("dinat", "DinatModel"),
+ ("dinov2", "Dinov2Model"),
+ ("dinov2_with_registers", "Dinov2WithRegistersModel"),
+ ("dinov3_convnext", "DINOv3ConvNextModel"),
+ ("dinov3_vit", "DINOv3ViTModel"),
+ ("distilbert", "DistilBertModel"),
+ ("doge", "DogeModel"),
+ ("donut-swin", "DonutSwinModel"),
+ ("dots1", "Dots1Model"),
+ ("dpr", "DPRQuestionEncoder"),
+ ("dpt", "DPTModel"),
+ ("edgetam", "EdgeTamModel"),
+ ("edgetam_video", "EdgeTamVideoModel"),
+ ("edgetam_vision_model", "EdgeTamVisionModel"),
+ ("efficientformer", "EfficientFormerModel"),
+ ("efficientloftr", "EfficientLoFTRModel"),
+ ("efficientnet", "EfficientNetModel"),
+ ("electra", "ElectraModel"),
+ ("emu3", "Emu3Model"),
+ ("encodec", "EncodecModel"),
+ ("ernie", "ErnieModel"),
+ ("ernie4_5", "Ernie4_5Model"),
+ ("ernie4_5_moe", "Ernie4_5_MoeModel"),
+ ("ernie_m", "ErnieMModel"),
+ ("esm", "EsmModel"),
+ ("evolla", "EvollaModel"),
+ ("exaone4", "Exaone4Model"),
+ ("falcon", "FalconModel"),
+ ("falcon_h1", "FalconH1Model"),
+ ("falcon_mamba", "FalconMambaModel"),
+ ("fastspeech2_conformer", "FastSpeech2ConformerModel"),
+ ("fastspeech2_conformer_with_hifigan", "FastSpeech2ConformerWithHifiGan"),
+ ("flaubert", "FlaubertModel"),
+ ("flava", "FlavaModel"),
+ ("flex_olmo", "FlexOlmoModel"),
+ ("florence2", "Florence2Model"),
+ ("fnet", "FNetModel"),
+ ("focalnet", "FocalNetModel"),
+ ("fsmt", "FSMTModel"),
+ ("funnel", ("FunnelModel", "FunnelBaseModel")),
+ ("fuyu", "FuyuModel"),
+ ("gemma", "GemmaModel"),
+ ("gemma2", "Gemma2Model"),
+ ("gemma3", "Gemma3Model"),
+ ("gemma3_text", "Gemma3TextModel"),
+ ("gemma3n", "Gemma3nModel"),
+ ("gemma3n_audio", "Gemma3nAudioEncoder"),
+ ("gemma3n_text", "Gemma3nTextModel"),
+ ("gemma3n_vision", "TimmWrapperModel"),
+ ("git", "GitModel"),
+ ("glm", "GlmModel"),
+ ("glm4", "Glm4Model"),
+ ("glm4_moe", "Glm4MoeModel"),
+ ("glm4v", "Glm4vModel"),
+ ("glm4v_moe", "Glm4vMoeModel"),
+ ("glm4v_moe_text", "Glm4vMoeTextModel"),
+ ("glm4v_text", "Glm4vTextModel"),
+ ("glpn", "GLPNModel"),
+ ("got_ocr2", "GotOcr2Model"),
+ ("gpt-sw3", "GPT2Model"),
+ ("gpt2", "GPT2Model"),
+ ("gpt_bigcode", "GPTBigCodeModel"),
+ ("gpt_neo", "GPTNeoModel"),
+ ("gpt_neox", "GPTNeoXModel"),
+ ("gpt_neox_japanese", "GPTNeoXJapaneseModel"),
+ ("gpt_oss", "GptOssModel"),
+ ("gptj", "GPTJModel"),
+ ("gptsan-japanese", "GPTSanJapaneseForConditionalGeneration"),
+ ("granite", "GraniteModel"),
+ ("granitemoe", "GraniteMoeModel"),
+ ("granitemoehybrid", "GraniteMoeHybridModel"),
+ ("granitemoeshared", "GraniteMoeSharedModel"),
+ ("graphormer", "GraphormerModel"),
+ ("grounding-dino", "GroundingDinoModel"),
+ ("groupvit", "GroupViTModel"),
+ ("helium", "HeliumModel"),
+ ("hgnet_v2", "HGNetV2Backbone"),
+ ("hiera", "HieraModel"),
+ ("hubert", "HubertModel"),
+ ("hunyuan_v1_dense", "HunYuanDenseV1Model"),
+ ("hunyuan_v1_moe", "HunYuanMoEV1Model"),
+ ("ibert", "IBertModel"),
+ ("idefics", "IdeficsModel"),
+ ("idefics2", "Idefics2Model"),
+ ("idefics3", "Idefics3Model"),
+ ("idefics3_vision", "Idefics3VisionTransformer"),
+ ("ijepa", "IJepaModel"),
+ ("imagegpt", "ImageGPTModel"),
+ ("informer", "InformerModel"),
+ ("instructblip", "InstructBlipModel"),
+ ("instructblipvideo", "InstructBlipVideoModel"),
+ ("internvl", "InternVLModel"),
+ ("internvl_vision", "InternVLVisionModel"),
+ ("jamba", "JambaModel"),
+ ("janus", "JanusModel"),
+ ("jetmoe", "JetMoeModel"),
+ ("jukebox", "JukeboxModel"),
+ ("kosmos-2", "Kosmos2Model"),
+ ("kosmos-2.5", "Kosmos2_5Model"),
+ ("kyutai_speech_to_text", "KyutaiSpeechToTextModel"),
+ ("layoutlm", "LayoutLMModel"),
+ ("layoutlmv2", "LayoutLMv2Model"),
+ ("layoutlmv3", "LayoutLMv3Model"),
+ ("led", "LEDModel"),
+ ("levit", "LevitModel"),
+ ("lfm2", "Lfm2Model"),
+ ("lfm2_vl", "Lfm2VlModel"),
+ ("lightglue", "LightGlueForKeypointMatching"),
+ ("lilt", "LiltModel"),
+ ("llama", "LlamaModel"),
+ ("llama4", "Llama4ForConditionalGeneration"),
+ ("llama4_text", "Llama4TextModel"),
+ ("llava", "LlavaModel"),
+ ("llava_next", "LlavaNextModel"),
+ ("llava_next_video", "LlavaNextVideoModel"),
+ ("llava_onevision", "LlavaOnevisionModel"),
+ ("longcat_flash", "LongcatFlashModel"),
+ ("longformer", "LongformerModel"),
+ ("longt5", "LongT5Model"),
+ ("luke", "LukeModel"),
+ ("lxmert", "LxmertModel"),
+ ("m2m_100", "M2M100Model"),
+ ("mamba", "MambaModel"),
+ ("mamba2", "Mamba2Model"),
+ ("marian", "MarianModel"),
+ ("markuplm", "MarkupLMModel"),
+ ("mask2former", "Mask2FormerModel"),
+ ("maskformer", "MaskFormerModel"),
+ ("maskformer-swin", "MaskFormerSwinModel"),
+ ("mbart", "MBartModel"),
+ ("mctct", "MCTCTModel"),
+ ("mega", "MegaModel"),
+ ("megatron-bert", "MegatronBertModel"),
+ ("metaclip_2", "MetaClip2Model"),
+ ("mgp-str", "MgpstrForSceneTextRecognition"),
+ ("mimi", "MimiModel"),
+ ("minimax", "MiniMaxModel"),
+ ("ministral", "MinistralModel"),
+ ("mistral", "MistralModel"),
+ ("mistral3", "Mistral3Model"),
+ ("mixtral", "MixtralModel"),
+ ("mlcd", "MLCDVisionModel"),
+ ("mllama", "MllamaModel"),
+ ("mm-grounding-dino", "MMGroundingDinoModel"),
+ ("mobilebert", "MobileBertModel"),
+ ("mobilenet_v1", "MobileNetV1Model"),
+ ("mobilenet_v2", "MobileNetV2Model"),
+ ("mobilevit", "MobileViTModel"),
+ ("mobilevitv2", "MobileViTV2Model"),
+ ("modernbert", "ModernBertModel"),
+ ("modernbert-decoder", "ModernBertDecoderModel"),
+ ("moonshine", "MoonshineModel"),
+ ("moshi", "MoshiModel"),
+ ("mpnet", "MPNetModel"),
+ ("mpt", "MptModel"),
+ ("mra", "MraModel"),
+ ("mt5", "MT5Model"),
+ ("musicgen", "MusicgenModel"),
+ ("musicgen_melody", "MusicgenMelodyModel"),
+ ("mvp", "MvpModel"),
+ ("nat", "NatModel"),
+ ("nemotron", "NemotronModel"),
+ ("nezha", "NezhaModel"),
+ ("nllb-moe", "NllbMoeModel"),
+ ("nystromformer", "NystromformerModel"),
+ ("olmo", "OlmoModel"),
+ ("olmo2", "Olmo2Model"),
+ ("olmo3", "Olmo3Model"),
+ ("olmoe", "OlmoeModel"),
+ ("omdet-turbo", "OmDetTurboForObjectDetection"),
+ ("oneformer", "OneFormerModel"),
+ ("open-llama", "OpenLlamaModel"),
+ ("openai-gpt", "OpenAIGPTModel"),
+ ("opt", "OPTModel"),
+ ("ovis2", "Ovis2Model"),
+ ("owlv2", "Owlv2Model"),
+ ("owlvit", "OwlViTModel"),
+ ("paligemma", "PaliGemmaModel"),
+ ("parakeet_ctc", "ParakeetForCTC"),
+ ("parakeet_encoder", "ParakeetEncoder"),
+ ("patchtsmixer", "PatchTSMixerModel"),
+ ("patchtst", "PatchTSTModel"),
+ ("pegasus", "PegasusModel"),
+ ("pegasus_x", "PegasusXModel"),
+ ("perceiver", "PerceiverModel"),
+ ("perception_encoder", "PerceptionEncoder"),
+ ("perception_lm", "PerceptionLMModel"),
+ ("persimmon", "PersimmonModel"),
+ ("phi", "PhiModel"),
+ ("phi3", "Phi3Model"),
+ ("phi4_multimodal", "Phi4MultimodalModel"),
+ ("phimoe", "PhimoeModel"),
+ ("pixtral", "PixtralVisionModel"),
+ ("plbart", "PLBartModel"),
+ ("poolformer", "PoolFormerModel"),
+ ("prophetnet", "ProphetNetModel"),
+ ("pvt", "PvtModel"),
+ ("pvt_v2", "PvtV2Model"),
+ ("qdqbert", "QDQBertModel"),
+ ("qwen2", "Qwen2Model"),
+ ("qwen2_5_vl", "Qwen2_5_VLModel"),
+ ("qwen2_5_vl_text", "Qwen2_5_VLTextModel"),
+ ("qwen2_audio_encoder", "Qwen2AudioEncoder"),
+ ("qwen2_moe", "Qwen2MoeModel"),
+ ("qwen2_vl", "Qwen2VLModel"),
+ ("qwen2_vl_text", "Qwen2VLTextModel"),
+ ("qwen3", "Qwen3Model"),
+ ("qwen3_moe", "Qwen3MoeModel"),
+ ("qwen3_next", "Qwen3NextModel"),
+ ("qwen3_vl", "Qwen3VLModel"),
+ ("qwen3_vl_moe", "Qwen3VLMoeModel"),
+ ("qwen3_vl_moe_text", "Qwen3VLMoeTextModel"),
+ ("qwen3_vl_text", "Qwen3VLTextModel"),
+ ("recurrent_gemma", "RecurrentGemmaModel"),
+ ("reformer", "ReformerModel"),
+ ("regnet", "RegNetModel"),
+ ("rembert", "RemBertModel"),
+ ("resnet", "ResNetModel"),
+ ("retribert", "RetriBertModel"),
+ ("roberta", "RobertaModel"),
+ ("roberta-prelayernorm", "RobertaPreLayerNormModel"),
+ ("roc_bert", "RoCBertModel"),
+ ("roformer", "RoFormerModel"),
+ ("rt_detr", "RTDetrModel"),
+ ("rt_detr_v2", "RTDetrV2Model"),
+ ("rwkv", "RwkvModel"),
+ ("sam", "SamModel"),
+ ("sam2", "Sam2Model"),
+ ("sam2_hiera_det_model", "Sam2HieraDetModel"),
+ ("sam2_video", "Sam2VideoModel"),
+ ("sam2_vision_model", "Sam2VisionModel"),
+ ("sam_hq", "SamHQModel"),
+ ("sam_hq_vision_model", "SamHQVisionModel"),
+ ("sam_vision_model", "SamVisionModel"),
+ ("seamless_m4t", "SeamlessM4TModel"),
+ ("seamless_m4t_v2", "SeamlessM4Tv2Model"),
+ ("seed_oss", "SeedOssModel"),
+ ("segformer", "SegformerModel"),
+ ("seggpt", "SegGptModel"),
+ ("sew", "SEWModel"),
+ ("sew-d", "SEWDModel"),
+ ("siglip", "SiglipModel"),
+ ("siglip2", "Siglip2Model"),
+ ("siglip2_vision_model", "Siglip2VisionModel"),
+ ("siglip_vision_model", "SiglipVisionModel"),
+ ("smollm3", "SmolLM3Model"),
+ ("smolvlm", "SmolVLMModel"),
+ ("smolvlm_vision", "SmolVLMVisionTransformer"),
+ ("speech_to_text", "Speech2TextModel"),
+ ("speecht5", "SpeechT5Model"),
+ ("splinter", "SplinterModel"),
+ ("squeezebert", "SqueezeBertModel"),
+ ("stablelm", "StableLmModel"),
+ ("starcoder2", "Starcoder2Model"),
+ ("swiftformer", "SwiftFormerModel"),
+ ("swin", "SwinModel"),
+ ("swin2sr", "Swin2SRModel"),
+ ("swinv2", "Swinv2Model"),
+ ("switch_transformers", "SwitchTransformersModel"),
+ ("t5", "T5Model"),
+ ("t5gemma", "T5GemmaModel"),
+ ("table-transformer", "TableTransformerModel"),
+ ("tapas", "TapasModel"),
+ ("textnet", "TextNetModel"),
+ ("time_series_transformer", "TimeSeriesTransformerModel"),
+ ("timesfm", "TimesFmModel"),
+ ("timesformer", "TimesformerModel"),
+ ("timm_backbone", "TimmBackbone"),
+ ("timm_wrapper", "TimmWrapperModel"),
+ ("trajectory_transformer", "TrajectoryTransformerModel"),
+ ("transfo-xl", "TransfoXLModel"),
+ ("tvlt", "TvltModel"),
+ ("tvp", "TvpModel"),
+ ("udop", "UdopModel"),
+ ("umt5", "UMT5Model"),
+ ("unispeech", "UniSpeechModel"),
+ ("unispeech-sat", "UniSpeechSatModel"),
+ ("univnet", "UnivNetModel"),
+ ("van", "VanModel"),
+ ("vaultgemma", "VaultGemmaModel"),
+ ("video_llava", "VideoLlavaModel"),
+ ("videomae", "VideoMAEModel"),
+ ("vilt", "ViltModel"),
+ ("vipllava", "VipLlavaModel"),
+ ("vision-text-dual-encoder", "VisionTextDualEncoderModel"),
+ ("visual_bert", "VisualBertModel"),
+ ("vit", "ViTModel"),
+ ("vit_hybrid", "ViTHybridModel"),
+ ("vit_mae", "ViTMAEModel"),
+ ("vit_msn", "ViTMSNModel"),
+ ("vitdet", "VitDetModel"),
+ ("vits", "VitsModel"),
+ ("vivit", "VivitModel"),
+ ("vjepa2", "VJEPA2Model"),
+ ("voxtral", "VoxtralForConditionalGeneration"),
+ ("voxtral_encoder", "VoxtralEncoder"),
+ ("wav2vec2", "Wav2Vec2Model"),
+ ("wav2vec2-bert", "Wav2Vec2BertModel"),
+ ("wav2vec2-conformer", "Wav2Vec2ConformerModel"),
+ ("wavlm", "WavLMModel"),
+ ("whisper", "WhisperModel"),
+ ("xclip", "XCLIPModel"),
+ ("xcodec", "XcodecModel"),
+ ("xglm", "XGLMModel"),
+ ("xlm", "XLMModel"),
+ ("xlm-prophetnet", "XLMProphetNetModel"),
+ ("xlm-roberta", "XLMRobertaModel"),
+ ("xlm-roberta-xl", "XLMRobertaXLModel"),
+ ("xlnet", "XLNetModel"),
+ ("xlstm", "xLSTMModel"),
+ ("xmod", "XmodModel"),
+ ("yolos", "YolosModel"),
+ ("yoso", "YosoModel"),
+ ("zamba", "ZambaModel"),
+ ("zamba2", "Zamba2Model"),
+ ]
+)
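+# Informal note: MODEL_MAPPING_NAMES pairs each model-type string with the *name* of its bare base-model
+# class; AutoModel resolves through this table, e.g. loading a checkpoint whose model type is "gpt2"
+# yields a GPT2Model. Only names are stored here, so no model module is imported until it is requested.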
+
+MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
+ [
+ # Model for pre-training mapping
+ ("albert", "AlbertForPreTraining"),
+ ("bart", "BartForConditionalGeneration"),
+ ("bert", "BertForPreTraining"),
+ ("big_bird", "BigBirdForPreTraining"),
+ ("bloom", "BloomForCausalLM"),
+ ("camembert", "CamembertForMaskedLM"),
+ ("colpali", "ColPaliForRetrieval"),
+ ("colqwen2", "ColQwen2ForRetrieval"),
+ ("ctrl", "CTRLLMHeadModel"),
+ ("data2vec-text", "Data2VecTextForMaskedLM"),
+ ("deberta", "DebertaForMaskedLM"),
+ ("deberta-v2", "DebertaV2ForMaskedLM"),
+ ("distilbert", "DistilBertForMaskedLM"),
+ ("electra", "ElectraForPreTraining"),
+ ("ernie", "ErnieForPreTraining"),
+ ("evolla", "EvollaForProteinText2Text"),
+ ("exaone4", "Exaone4ForCausalLM"),
+ ("falcon_mamba", "FalconMambaForCausalLM"),
+ ("flaubert", "FlaubertWithLMHeadModel"),
+ ("flava", "FlavaForPreTraining"),
+ ("florence2", "Florence2ForConditionalGeneration"),
+ ("fnet", "FNetForPreTraining"),
+ ("fsmt", "FSMTForConditionalGeneration"),
+ ("funnel", "FunnelForPreTraining"),
+ ("gemma3", "Gemma3ForConditionalGeneration"),
+ ("gpt-sw3", "GPT2LMHeadModel"),
+ ("gpt2", "GPT2LMHeadModel"),
+ ("gpt_bigcode", "GPTBigCodeForCausalLM"),
+ ("gptsan-japanese", "GPTSanJapaneseForConditionalGeneration"),
+ ("hiera", "HieraForPreTraining"),
+ ("ibert", "IBertForMaskedLM"),
+ ("idefics", "IdeficsForVisionText2Text"),
+ ("idefics2", "Idefics2ForConditionalGeneration"),
+ ("idefics3", "Idefics3ForConditionalGeneration"),
+ ("janus", "JanusForConditionalGeneration"),
+ ("layoutlm", "LayoutLMForMaskedLM"),
+ ("llava", "LlavaForConditionalGeneration"),
+ ("llava_next", "LlavaNextForConditionalGeneration"),
+ ("llava_next_video", "LlavaNextVideoForConditionalGeneration"),
+ ("llava_onevision", "LlavaOnevisionForConditionalGeneration"),
+ ("longformer", "LongformerForMaskedLM"),
+ ("luke", "LukeForMaskedLM"),
+ ("lxmert", "LxmertForPreTraining"),
+ ("mamba", "MambaForCausalLM"),
+ ("mamba2", "Mamba2ForCausalLM"),
+ ("mega", "MegaForMaskedLM"),
+ ("megatron-bert", "MegatronBertForPreTraining"),
+ ("mistral3", "Mistral3ForConditionalGeneration"),
+ ("mllama", "MllamaForConditionalGeneration"),
+ ("mobilebert", "MobileBertForPreTraining"),
+ ("mpnet", "MPNetForMaskedLM"),
+ ("mpt", "MptForCausalLM"),
+ ("mra", "MraForMaskedLM"),
+ ("mvp", "MvpForConditionalGeneration"),
+ ("nezha", "NezhaForPreTraining"),
+ ("nllb-moe", "NllbMoeForConditionalGeneration"),
+ ("openai-gpt", "OpenAIGPTLMHeadModel"),
+ ("paligemma", "PaliGemmaForConditionalGeneration"),
+ ("qwen2_audio", "Qwen2AudioForConditionalGeneration"),
+ ("retribert", "RetriBertModel"),
+ ("roberta", "RobertaForMaskedLM"),
+ ("roberta-prelayernorm", "RobertaPreLayerNormForMaskedLM"),
+ ("roc_bert", "RoCBertForPreTraining"),
+ ("rwkv", "RwkvForCausalLM"),
+ ("splinter", "SplinterForPreTraining"),
+ ("squeezebert", "SqueezeBertForMaskedLM"),
+ ("switch_transformers", "SwitchTransformersForConditionalGeneration"),
+ ("t5", "T5ForConditionalGeneration"),
+ ("t5gemma", "T5GemmaForConditionalGeneration"),
+ ("tapas", "TapasForMaskedLM"),
+ ("transfo-xl", "TransfoXLLMHeadModel"),
+ ("tvlt", "TvltForPreTraining"),
+ ("unispeech", "UniSpeechForPreTraining"),
+ ("unispeech-sat", "UniSpeechSatForPreTraining"),
+ ("video_llava", "VideoLlavaForConditionalGeneration"),
+ ("videomae", "VideoMAEForPreTraining"),
+ ("vipllava", "VipLlavaForConditionalGeneration"),
+ ("visual_bert", "VisualBertForPreTraining"),
+ ("vit_mae", "ViTMAEForPreTraining"),
+ ("voxtral", "VoxtralForConditionalGeneration"),
+ ("wav2vec2", "Wav2Vec2ForPreTraining"),
+ ("wav2vec2-conformer", "Wav2Vec2ConformerForPreTraining"),
+ ("xlm", "XLMWithLMHeadModel"),
+ ("xlm-roberta", "XLMRobertaForMaskedLM"),
+ ("xlm-roberta-xl", "XLMRobertaXLForMaskedLM"),
+ ("xlnet", "XLNetLMHeadModel"),
+ ("xlstm", "xLSTMForCausalLM"),
+ ("xmod", "XmodForMaskedLM"),
+ ]
+)
+
+MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
+ [
+ # Model with LM heads mapping
+ ("albert", "AlbertForMaskedLM"),
+ ("bart", "BartForConditionalGeneration"),
+ ("bert", "BertForMaskedLM"),
+ ("big_bird", "BigBirdForMaskedLM"),
+ ("bigbird_pegasus", "BigBirdPegasusForConditionalGeneration"),
+ ("blenderbot-small", "BlenderbotSmallForConditionalGeneration"),
+ ("bloom", "BloomForCausalLM"),
+ ("camembert", "CamembertForMaskedLM"),
+ ("codegen", "CodeGenForCausalLM"),
+ ("convbert", "ConvBertForMaskedLM"),
+ ("cpmant", "CpmAntForCausalLM"),
+ ("ctrl", "CTRLLMHeadModel"),
+ ("data2vec-text", "Data2VecTextForMaskedLM"),
+ ("deberta", "DebertaForMaskedLM"),
+ ("deberta-v2", "DebertaV2ForMaskedLM"),
+ ("dia", "DiaForConditionalGeneration"),
+ ("distilbert", "DistilBertForMaskedLM"),
+ ("electra", "ElectraForMaskedLM"),
+ ("encoder-decoder", "EncoderDecoderModel"),
+ ("ernie", "ErnieForMaskedLM"),
+ ("esm", "EsmForMaskedLM"),
+ ("exaone4", "Exaone4ForCausalLM"),
+ ("falcon_mamba", "FalconMambaForCausalLM"),
+ ("flaubert", "FlaubertWithLMHeadModel"),
+ ("fnet", "FNetForMaskedLM"),
+ ("fsmt", "FSMTForConditionalGeneration"),
+ ("funnel", "FunnelForMaskedLM"),
+ ("git", "GitForCausalLM"),
+ ("gpt-sw3", "GPT2LMHeadModel"),
+ ("gpt2", "GPT2LMHeadModel"),
+ ("gpt_bigcode", "GPTBigCodeForCausalLM"),
+ ("gpt_neo", "GPTNeoForCausalLM"),
+ ("gpt_neox", "GPTNeoXForCausalLM"),
+ ("gpt_neox_japanese", "GPTNeoXJapaneseForCausalLM"),
+ ("gptj", "GPTJForCausalLM"),
+ ("gptsan-japanese", "GPTSanJapaneseForConditionalGeneration"),
+ ("ibert", "IBertForMaskedLM"),
+ ("layoutlm", "LayoutLMForMaskedLM"),
+ ("led", "LEDForConditionalGeneration"),
+ ("longformer", "LongformerForMaskedLM"),
+ ("longt5", "LongT5ForConditionalGeneration"),
+ ("luke", "LukeForMaskedLM"),
+ ("m2m_100", "M2M100ForConditionalGeneration"),
+ ("mamba", "MambaForCausalLM"),
+ ("mamba2", "Mamba2ForCausalLM"),
+ ("marian", "MarianMTModel"),
+ ("mega", "MegaForMaskedLM"),
+ ("megatron-bert", "MegatronBertForCausalLM"),
+ ("mobilebert", "MobileBertForMaskedLM"),
+ ("moonshine", "MoonshineForConditionalGeneration"),
+ ("mpnet", "MPNetForMaskedLM"),
+ ("mpt", "MptForCausalLM"),
+ ("mra", "MraForMaskedLM"),
+ ("mvp", "MvpForConditionalGeneration"),
+ ("nezha", "NezhaForMaskedLM"),
+ ("nllb-moe", "NllbMoeForConditionalGeneration"),
+ ("nystromformer", "NystromformerForMaskedLM"),
+ ("openai-gpt", "OpenAIGPTLMHeadModel"),
+ ("pegasus_x", "PegasusXForConditionalGeneration"),
+ ("plbart", "PLBartForConditionalGeneration"),
+ ("pop2piano", "Pop2PianoForConditionalGeneration"),
+ ("qdqbert", "QDQBertForMaskedLM"),
+ ("reformer", "ReformerModelWithLMHead"),
+ ("rembert", "RemBertForMaskedLM"),
+ ("roberta", "RobertaForMaskedLM"),
+ ("roberta-prelayernorm", "RobertaPreLayerNormForMaskedLM"),
+ ("roc_bert", "RoCBertForMaskedLM"),
+ ("roformer", "RoFormerForMaskedLM"),
+ ("rwkv", "RwkvForCausalLM"),
+ ("speech_to_text", "Speech2TextForConditionalGeneration"),
+ ("squeezebert", "SqueezeBertForMaskedLM"),
+ ("switch_transformers", "SwitchTransformersForConditionalGeneration"),
+ ("t5", "T5ForConditionalGeneration"),
+ ("t5gemma", "T5GemmaForConditionalGeneration"),
+ ("tapas", "TapasForMaskedLM"),
+ ("transfo-xl", "TransfoXLLMHeadModel"),
+ ("wav2vec2", "Wav2Vec2ForMaskedLM"),
+ ("whisper", "WhisperForConditionalGeneration"),
+ ("xlm", "XLMWithLMHeadModel"),
+ ("xlm-roberta", "XLMRobertaForMaskedLM"),
+ ("xlm-roberta-xl", "XLMRobertaXLForMaskedLM"),
+ ("xlnet", "XLNetLMHeadModel"),
+ ("xmod", "XmodForMaskedLM"),
+ ("yoso", "YosoForMaskedLM"),
+ ]
+)
+
+MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
+ [
+ # Model for Causal LM mapping
+ ("apertus", "ApertusForCausalLM"),
+ ("arcee", "ArceeForCausalLM"),
+ ("aria_text", "AriaTextForCausalLM"),
+ ("bamba", "BambaForCausalLM"),
+ ("bart", "BartForCausalLM"),
+ ("bert", "BertLMHeadModel"),
+ ("bert-generation", "BertGenerationDecoder"),
+ ("big_bird", "BigBirdForCausalLM"),
+ ("bigbird_pegasus", "BigBirdPegasusForCausalLM"),
+ ("biogpt", "BioGptForCausalLM"),
+ ("bitnet", "BitNetForCausalLM"),
+ ("blenderbot", "BlenderbotForCausalLM"),
+ ("blenderbot-small", "BlenderbotSmallForCausalLM"),
+ ("bloom", "BloomForCausalLM"),
+ ("blt", "BltForCausalLM"),
+ ("camembert", "CamembertForCausalLM"),
+ ("code_llama", "LlamaForCausalLM"),
+ ("codegen", "CodeGenForCausalLM"),
+ ("cohere", "CohereForCausalLM"),
+ ("cohere2", "Cohere2ForCausalLM"),
+ ("cpmant", "CpmAntForCausalLM"),
+ ("ctrl", "CTRLLMHeadModel"),
+ ("data2vec-text", "Data2VecTextForCausalLM"),
+ ("dbrx", "DbrxForCausalLM"),
+ ("deepseek_v2", "DeepseekV2ForCausalLM"),
+ ("deepseek_v3", "DeepseekV3ForCausalLM"),
+ ("diffllama", "DiffLlamaForCausalLM"),
+ ("doge", "DogeForCausalLM"),
+ ("dots1", "Dots1ForCausalLM"),
+ ("electra", "ElectraForCausalLM"),
+ ("emu3", "Emu3ForCausalLM"),
+ ("ernie", "ErnieForCausalLM"),
+ ("ernie4_5", "Ernie4_5ForCausalLM"),
+ ("ernie4_5_moe", "Ernie4_5_MoeForCausalLM"),
+ ("exaone4", "Exaone4ForCausalLM"),
+ ("falcon", "FalconForCausalLM"),
+ ("falcon_h1", "FalconH1ForCausalLM"),
+ ("falcon_mamba", "FalconMambaForCausalLM"),
+ ("flex_olmo", "FlexOlmoForCausalLM"),
+ ("fuyu", "FuyuForCausalLM"),
+ ("gemma", "GemmaForCausalLM"),
+ ("gemma2", "Gemma2ForCausalLM"),
+ ("gemma3", "Gemma3ForConditionalGeneration"),
+ ("gemma3_text", "Gemma3ForCausalLM"),
+ ("gemma3n", "Gemma3nForConditionalGeneration"),
+ ("gemma3n_text", "Gemma3nForCausalLM"),
+ ("git", "GitForCausalLM"),
+ ("glm", "GlmForCausalLM"),
+ ("glm4", "Glm4ForCausalLM"),
+ ("glm4_moe", "Glm4MoeForCausalLM"),
+ ("got_ocr2", "GotOcr2ForConditionalGeneration"),
+ ("gpt-sw3", "GPT2LMHeadModel"),
+ ("gpt2", "GPT2LMHeadModel"),
+ ("gpt_bigcode", "GPTBigCodeForCausalLM"),
+ ("gpt_neo", "GPTNeoForCausalLM"),
+ ("gpt_neox", "GPTNeoXForCausalLM"),
+ ("gpt_neox_japanese", "GPTNeoXJapaneseForCausalLM"),
+ ("gpt_oss", "GptOssForCausalLM"),
+ ("gptj", "GPTJForCausalLM"),
+ ("granite", "GraniteForCausalLM"),
+ ("granitemoe", "GraniteMoeForCausalLM"),
+ ("granitemoehybrid", "GraniteMoeHybridForCausalLM"),
+ ("granitemoeshared", "GraniteMoeSharedForCausalLM"),
+ ("helium", "HeliumForCausalLM"),
+ ("hunyuan_v1_dense", "HunYuanDenseV1ForCausalLM"),
+ ("hunyuan_v1_moe", "HunYuanMoEV1ForCausalLM"),
+ ("jamba", "JambaForCausalLM"),
+ ("jetmoe", "JetMoeForCausalLM"),
+ ("lfm2", "Lfm2ForCausalLM"),
+ ("llama", "LlamaForCausalLM"),
+ ("llama4", "Llama4ForCausalLM"),
+ ("llama4_text", "Llama4ForCausalLM"),
+ ("longcat_flash", "LongcatFlashForCausalLM"),
+ ("mamba", "MambaForCausalLM"),
+ ("mamba2", "Mamba2ForCausalLM"),
+ ("marian", "MarianForCausalLM"),
+ ("mbart", "MBartForCausalLM"),
+ ("mega", "MegaForCausalLM"),
+ ("megatron-bert", "MegatronBertForCausalLM"),
+ ("minimax", "MiniMaxForCausalLM"),
+ ("ministral", "MinistralForCausalLM"),
+ ("mistral", "MistralForCausalLM"),
+ ("mixtral", "MixtralForCausalLM"),
+ ("mllama", "MllamaForCausalLM"),
+ ("modernbert-decoder", "ModernBertDecoderForCausalLM"),
+ ("moshi", "MoshiForCausalLM"),
+ ("mpt", "MptForCausalLM"),
+ ("musicgen", "MusicgenForCausalLM"),
+ ("musicgen_melody", "MusicgenMelodyForCausalLM"),
+ ("mvp", "MvpForCausalLM"),
+ ("nemotron", "NemotronForCausalLM"),
+ ("olmo", "OlmoForCausalLM"),
+ ("olmo2", "Olmo2ForCausalLM"),
+ ("olmo3", "Olmo3ForCausalLM"),
+ ("olmoe", "OlmoeForCausalLM"),
+ ("open-llama", "OpenLlamaForCausalLM"),
+ ("openai-gpt", "OpenAIGPTLMHeadModel"),
+ ("opt", "OPTForCausalLM"),
+ ("pegasus", "PegasusForCausalLM"),
+ ("persimmon", "PersimmonForCausalLM"),
+ ("phi", "PhiForCausalLM"),
+ ("phi3", "Phi3ForCausalLM"),
+ ("phi4_multimodal", "Phi4MultimodalForCausalLM"),
+ ("phimoe", "PhimoeForCausalLM"),
+ ("plbart", "PLBartForCausalLM"),
+ ("prophetnet", "ProphetNetForCausalLM"),
+ ("qdqbert", "QDQBertLMHeadModel"),
+ ("qwen2", "Qwen2ForCausalLM"),
+ ("qwen2_moe", "Qwen2MoeForCausalLM"),
+ ("qwen3", "Qwen3ForCausalLM"),
+ ("qwen3_moe", "Qwen3MoeForCausalLM"),
+ ("qwen3_next", "Qwen3NextForCausalLM"),
+ ("recurrent_gemma", "RecurrentGemmaForCausalLM"),
+ ("reformer", "ReformerModelWithLMHead"),
+ ("rembert", "RemBertForCausalLM"),
+ ("roberta", "RobertaForCausalLM"),
+ ("roberta-prelayernorm", "RobertaPreLayerNormForCausalLM"),
+ ("roc_bert", "RoCBertForCausalLM"),
+ ("roformer", "RoFormerForCausalLM"),
+ ("rwkv", "RwkvForCausalLM"),
+ ("seed_oss", "SeedOssForCausalLM"),
+ ("smollm3", "SmolLM3ForCausalLM"),
+ ("speech_to_text_2", "Speech2Text2ForCausalLM"),
+ ("stablelm", "StableLmForCausalLM"),
+ ("starcoder2", "Starcoder2ForCausalLM"),
+ ("transfo-xl", "TransfoXLLMHeadModel"),
+ ("trocr", "TrOCRForCausalLM"),
+ ("vaultgemma", "VaultGemmaForCausalLM"),
+ ("whisper", "WhisperForCausalLM"),
+ ("xglm", "XGLMForCausalLM"),
+ ("xlm", "XLMWithLMHeadModel"),
+ ("xlm-prophetnet", "XLMProphetNetForCausalLM"),
+ ("xlm-roberta", "XLMRobertaForCausalLM"),
+ ("xlm-roberta-xl", "XLMRobertaXLForCausalLM"),
+ ("xlnet", "XLNetLMHeadModel"),
+ ("xlstm", "xLSTMForCausalLM"),
+ ("xmod", "XmodForCausalLM"),
+ ("zamba", "ZambaForCausalLM"),
+ ("zamba2", "Zamba2ForCausalLM"),
+ ]
+)
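+# Informal example: this table backs AutoModelForCausalLM; e.g. AutoModelForCausalLM.from_pretrained("gpt2")
+# resolves the "gpt2" model type listed above to GPT2LMHeadModel.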
+
+MODEL_FOR_IMAGE_MAPPING_NAMES = OrderedDict(
+ [
+ # Model for Image mapping
+ ("aimv2_vision_model", "Aimv2VisionModel"),
+ ("beit", "BeitModel"),
+ ("bit", "BitModel"),
+ ("cohere2_vision", "Cohere2VisionModel"),
+ ("conditional_detr", "ConditionalDetrModel"),
+ ("convnext", "ConvNextModel"),
+ ("convnextv2", "ConvNextV2Model"),
+ ("dab-detr", "DabDetrModel"),
+ ("data2vec-vision", "Data2VecVisionModel"),
+ ("deformable_detr", "DeformableDetrModel"),
+ ("deit", "DeiTModel"),
+ ("depth_pro", "DepthProModel"),
+ ("deta", "DetaModel"),
+ ("detr", "DetrModel"),
+ ("dinat", "DinatModel"),
+ ("dinov2", "Dinov2Model"),
+ ("dinov2_with_registers", "Dinov2WithRegistersModel"),
+ ("dinov3_convnext", "DINOv3ConvNextModel"),
+ ("dinov3_vit", "DINOv3ViTModel"),
+ ("dpt", "DPTModel"),
+ ("efficientformer", "EfficientFormerModel"),
+ ("efficientnet", "EfficientNetModel"),
+ ("focalnet", "FocalNetModel"),
+ ("glpn", "GLPNModel"),
+ ("hiera", "HieraModel"),
+ ("ijepa", "IJepaModel"),
+ ("imagegpt", "ImageGPTModel"),
+ ("levit", "LevitModel"),
+ ("llama4", "Llama4VisionModel"),
+ ("mlcd", "MLCDVisionModel"),
+ ("mllama", "MllamaVisionModel"),
+ ("mobilenet_v1", "MobileNetV1Model"),
+ ("mobilenet_v2", "MobileNetV2Model"),
+ ("mobilevit", "MobileViTModel"),
+ ("mobilevitv2", "MobileViTV2Model"),
+ ("nat", "NatModel"),
+ ("poolformer", "PoolFormerModel"),
+ ("pvt", "PvtModel"),
+ ("regnet", "RegNetModel"),
+ ("resnet", "ResNetModel"),
+ ("segformer", "SegformerModel"),
+ ("siglip_vision_model", "SiglipVisionModel"),
+ ("swiftformer", "SwiftFormerModel"),
+ ("swin", "SwinModel"),
+ ("swin2sr", "Swin2SRModel"),
+ ("swinv2", "Swinv2Model"),
+ ("table-transformer", "TableTransformerModel"),
+ ("timesformer", "TimesformerModel"),
+ ("timm_backbone", "TimmBackbone"),
+ ("timm_wrapper", "TimmWrapperModel"),
+ ("van", "VanModel"),
+ ("videomae", "VideoMAEModel"),
+ ("vit", "ViTModel"),
+ ("vit_hybrid", "ViTHybridModel"),
+ ("vit_mae", "ViTMAEModel"),
+ ("vit_msn", "ViTMSNModel"),
+ ("vitdet", "VitDetModel"),
+ ("vivit", "VivitModel"),
+ ("yolos", "YolosModel"),
+ ]
+)
+
+MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES = OrderedDict(
+ [
+ ("deit", "DeiTForMaskedImageModeling"),
+ ("focalnet", "FocalNetForMaskedImageModeling"),
+ ("swin", "SwinForMaskedImageModeling"),
+ ("swinv2", "Swinv2ForMaskedImageModeling"),
+ ("vit", "ViTForMaskedImageModeling"),
+ ]
+)
+
+
+MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING_NAMES = OrderedDict(
+ # Model for Causal Image Modeling mapping
+ [
+ ("imagegpt", "ImageGPTForCausalImageModeling"),
+ ]
+)
+
+MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
+ [
+ # Model for Image Classification mapping
+ ("beit", "BeitForImageClassification"),
+ ("bit", "BitForImageClassification"),
+ ("clip", "CLIPForImageClassification"),
+ ("convnext", "ConvNextForImageClassification"),
+ ("convnextv2", "ConvNextV2ForImageClassification"),
+ ("cvt", "CvtForImageClassification"),
+ ("data2vec-vision", "Data2VecVisionForImageClassification"),
+ (
+ "deit",
+ ("DeiTForImageClassification", "DeiTForImageClassificationWithTeacher"),
+ ),
+ ("dinat", "DinatForImageClassification"),
+ ("dinov2", "Dinov2ForImageClassification"),
+ ("dinov2_with_registers", "Dinov2WithRegistersForImageClassification"),
+ ("donut-swin", "DonutSwinForImageClassification"),
+ (
+ "efficientformer",
+ (
+ "EfficientFormerForImageClassification",
+ "EfficientFormerForImageClassificationWithTeacher",
+ ),
+ ),
+ ("efficientnet", "EfficientNetForImageClassification"),
+ ("focalnet", "FocalNetForImageClassification"),
+ ("hgnet_v2", "HGNetV2ForImageClassification"),
+ ("hiera", "HieraForImageClassification"),
+ ("ijepa", "IJepaForImageClassification"),
+ ("imagegpt", "ImageGPTForImageClassification"),
+ (
+ "levit",
+ ("LevitForImageClassification", "LevitForImageClassificationWithTeacher"),
+ ),
+ ("metaclip_2", "MetaClip2ForImageClassification"),
+ ("mobilenet_v1", "MobileNetV1ForImageClassification"),
+ ("mobilenet_v2", "MobileNetV2ForImageClassification"),
+ ("mobilevit", "MobileViTForImageClassification"),
+ ("mobilevitv2", "MobileViTV2ForImageClassification"),
+ ("nat", "NatForImageClassification"),
+ (
+ "perceiver",
+ (
+ "PerceiverForImageClassificationLearned",
+ "PerceiverForImageClassificationFourier",
+ "PerceiverForImageClassificationConvProcessing",
+ ),
+ ),
+ ("poolformer", "PoolFormerForImageClassification"),
+ ("pvt", "PvtForImageClassification"),
+ ("pvt_v2", "PvtV2ForImageClassification"),
+ ("regnet", "RegNetForImageClassification"),
+ ("resnet", "ResNetForImageClassification"),
+ ("segformer", "SegformerForImageClassification"),
+ ("shieldgemma2", "ShieldGemma2ForImageClassification"),
+ ("siglip", "SiglipForImageClassification"),
+ ("siglip2", "Siglip2ForImageClassification"),
+ ("swiftformer", "SwiftFormerForImageClassification"),
+ ("swin", "SwinForImageClassification"),
+ ("swinv2", "Swinv2ForImageClassification"),
+ ("textnet", "TextNetForImageClassification"),
+ ("timm_wrapper", "TimmWrapperForImageClassification"),
+ ("van", "VanForImageClassification"),
+ ("vit", "ViTForImageClassification"),
+ ("vit_hybrid", "ViTHybridForImageClassification"),
+ ("vit_msn", "ViTMSNForImageClassification"),
+ ]
+)
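+# Informal note: a tuple value (as for "deit", "efficientformer", "levit" and "perceiver" above) registers
+# several candidate classes for a single model type instead of one class name.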
+
+MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES = OrderedDict(
+ [
+ # Do not add new models here; this mapping will be deprecated in the future.
+ # Model for Image Segmentation mapping
+ ("detr", "DetrForSegmentation"),
+ ]
+)
+
+MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES = OrderedDict(
+ [
+ # Model for Semantic Segmentation mapping
+ ("beit", "BeitForSemanticSegmentation"),
+ ("data2vec-vision", "Data2VecVisionForSemanticSegmentation"),
+ ("dpt", "DPTForSemanticSegmentation"),
+ ("mobilenet_v2", "MobileNetV2ForSemanticSegmentation"),
+ ("mobilevit", "MobileViTForSemanticSegmentation"),
+ ("mobilevitv2", "MobileViTV2ForSemanticSegmentation"),
+ ("segformer", "SegformerForSemanticSegmentation"),
+ ("upernet", "UperNetForSemanticSegmentation"),
+ ]
+)
+
+MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING_NAMES = OrderedDict(
+ [
+ # Model for Instance Segmentation mapping
+ # MaskFormerForInstanceSegmentation can be removed from this mapping in v5
+ ("maskformer", "MaskFormerForInstanceSegmentation"),
+ ]
+)
+
+MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING_NAMES = OrderedDict(
+ [
+ # Model for Universal Segmentation mapping
+ ("detr", "DetrForSegmentation"),
+ ("eomt", "EomtForUniversalSegmentation"),
+ ("mask2former", "Mask2FormerForUniversalSegmentation"),
+ ("maskformer", "MaskFormerForInstanceSegmentation"),
+ ("oneformer", "OneFormerForUniversalSegmentation"),
+ ]
+)
+
+MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
+ [
+ ("timesformer", "TimesformerForVideoClassification"),
+ ("videomae", "VideoMAEForVideoClassification"),
+ ("vivit", "VivitForVideoClassification"),
+ ("vjepa2", "VJEPA2ForVideoClassification"),
+ ]
+)
+
+MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict(
+ [
+ ("blip", "BlipForConditionalGeneration"),
+ ("blip-2", "Blip2ForConditionalGeneration"),
+ ("chameleon", "ChameleonForConditionalGeneration"),
+ ("git", "GitForCausalLM"),
+ ("idefics2", "Idefics2ForConditionalGeneration"),
+ ("idefics3", "Idefics3ForConditionalGeneration"),
+ ("instructblip", "InstructBlipForConditionalGeneration"),
+ ("instructblipvideo", "InstructBlipVideoForConditionalGeneration"),
+ ("kosmos-2", "Kosmos2ForConditionalGeneration"),
+ ("kosmos-2.5", "Kosmos2_5ForConditionalGeneration"),
+ ("llava", "LlavaForConditionalGeneration"),
+ ("llava_next", "LlavaNextForConditionalGeneration"),
+ ("llava_next_video", "LlavaNextVideoForConditionalGeneration"),
+ ("llava_onevision", "LlavaOnevisionForConditionalGeneration"),
+ ("mistral3", "Mistral3ForConditionalGeneration"),
+ ("mllama", "MllamaForConditionalGeneration"),
+ ("ovis2", "Ovis2ForConditionalGeneration"),
+ ("paligemma", "PaliGemmaForConditionalGeneration"),
+ ("pix2struct", "Pix2StructForConditionalGeneration"),
+ ("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"),
+ ("qwen2_vl", "Qwen2VLForConditionalGeneration"),
+ ("qwen3_vl", "Qwen3VLForConditionalGeneration"),
+ ("qwen3_vl_moe", "Qwen3VLMoeForConditionalGeneration"),
+ ("video_llava", "VideoLlavaForConditionalGeneration"),
+ ("vipllava", "VipLlavaForConditionalGeneration"),
+ ("vision-encoder-decoder", "VisionEncoderDecoderModel"),
+ ]
+)
+
+MODEL_FOR_RETRIEVAL_MAPPING_NAMES = OrderedDict(
+ [
+ ("colpali", "ColPaliForRetrieval"),
+ ]
+)
+
+MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES = OrderedDict(
+ [
+ ("aria", "AriaForConditionalGeneration"),
+ ("aya_vision", "AyaVisionForConditionalGeneration"),
+ ("blip", "BlipForConditionalGeneration"),
+ ("blip-2", "Blip2ForConditionalGeneration"),
+ ("chameleon", "ChameleonForConditionalGeneration"),
+ ("cohere2_vision", "Cohere2VisionForConditionalGeneration"),
+ ("deepseek_vl", "DeepseekVLForConditionalGeneration"),
+ ("deepseek_vl_hybrid", "DeepseekVLHybridForConditionalGeneration"),
+ ("emu3", "Emu3ForConditionalGeneration"),
+ ("evolla", "EvollaForProteinText2Text"),
+ ("florence2", "Florence2ForConditionalGeneration"),
+ ("fuyu", "FuyuForCausalLM"),
+ ("gemma3", "Gemma3ForConditionalGeneration"),
+ ("gemma3n", "Gemma3nForConditionalGeneration"),
+ ("git", "GitForCausalLM"),
+ ("glm4v", "Glm4vForConditionalGeneration"),
+ ("glm4v_moe", "Glm4vMoeForConditionalGeneration"),
+ ("got_ocr2", "GotOcr2ForConditionalGeneration"),
+ ("idefics", "IdeficsForVisionText2Text"),
+ ("idefics2", "Idefics2ForConditionalGeneration"),
+ ("idefics3", "Idefics3ForConditionalGeneration"),
+ ("instructblip", "InstructBlipForConditionalGeneration"),
+ ("internvl", "InternVLForConditionalGeneration"),
+ ("janus", "JanusForConditionalGeneration"),
+ ("kosmos-2", "Kosmos2ForConditionalGeneration"),
+ ("kosmos-2.5", "Kosmos2_5ForConditionalGeneration"),
+ ("lfm2_vl", "Lfm2VlForConditionalGeneration"),
+ ("llama4", "Llama4ForConditionalGeneration"),
+ ("llava", "LlavaForConditionalGeneration"),
+ ("llava_next", "LlavaNextForConditionalGeneration"),
+ ("llava_next_video", "LlavaNextVideoForConditionalGeneration"),
+ ("llava_onevision", "LlavaOnevisionForConditionalGeneration"),
+ ("mistral3", "Mistral3ForConditionalGeneration"),
+ ("mllama", "MllamaForConditionalGeneration"),
+ ("ovis2", "Ovis2ForConditionalGeneration"),
+ ("paligemma", "PaliGemmaForConditionalGeneration"),
+ ("perception_lm", "PerceptionLMForConditionalGeneration"),
+ ("pix2struct", "Pix2StructForConditionalGeneration"),
+ ("pixtral", "LlavaForConditionalGeneration"),
+ ("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"),
+ ("qwen2_vl", "Qwen2VLForConditionalGeneration"),
+ ("qwen3_vl", "Qwen3VLForConditionalGeneration"),
+ ("qwen3_vl_moe", "Qwen3VLMoeForConditionalGeneration"),
+ ("shieldgemma2", "Gemma3ForConditionalGeneration"),
+ ("smolvlm", "SmolVLMForConditionalGeneration"),
+ ("udop", "UdopForConditionalGeneration"),
+ ("vipllava", "VipLlavaForConditionalGeneration"),
+ ("vision-encoder-decoder", "VisionEncoderDecoderModel"),
+ ]
+)
+
+MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
+ [
+ # Model for Masked LM mapping
+ ("albert", "AlbertForMaskedLM"),
+ ("bart", "BartForConditionalGeneration"),
+ ("bert", "BertForMaskedLM"),
+ ("big_bird", "BigBirdForMaskedLM"),
+ ("camembert", "CamembertForMaskedLM"),
+ ("convbert", "ConvBertForMaskedLM"),
+ ("data2vec-text", "Data2VecTextForMaskedLM"),
+ ("deberta", "DebertaForMaskedLM"),
+ ("deberta-v2", "DebertaV2ForMaskedLM"),
+ ("distilbert", "DistilBertForMaskedLM"),
+ ("electra", "ElectraForMaskedLM"),
+ ("ernie", "ErnieForMaskedLM"),
+ ("esm", "EsmForMaskedLM"),
+ ("flaubert", "FlaubertWithLMHeadModel"),
+ ("fnet", "FNetForMaskedLM"),
+ ("funnel", "FunnelForMaskedLM"),
+ ("ibert", "IBertForMaskedLM"),
+ ("layoutlm", "LayoutLMForMaskedLM"),
+ ("longformer", "LongformerForMaskedLM"),
+ ("luke", "LukeForMaskedLM"),
+ ("mbart", "MBartForConditionalGeneration"),
+ ("mega", "MegaForMaskedLM"),
+ ("megatron-bert", "MegatronBertForMaskedLM"),
+ ("mobilebert", "MobileBertForMaskedLM"),
+ ("modernbert", "ModernBertForMaskedLM"),
+ ("mpnet", "MPNetForMaskedLM"),
+ ("mra", "MraForMaskedLM"),
+ ("mvp", "MvpForConditionalGeneration"),
+ ("nezha", "NezhaForMaskedLM"),
+ ("nystromformer", "NystromformerForMaskedLM"),
+ ("perceiver", "PerceiverForMaskedLM"),
+ ("qdqbert", "QDQBertForMaskedLM"),
+ ("reformer", "ReformerForMaskedLM"),
+ ("rembert", "RemBertForMaskedLM"),
+ ("roberta", "RobertaForMaskedLM"),
+ ("roberta-prelayernorm", "RobertaPreLayerNormForMaskedLM"),
+ ("roc_bert", "RoCBertForMaskedLM"),
+ ("roformer", "RoFormerForMaskedLM"),
+ ("squeezebert", "SqueezeBertForMaskedLM"),
+ ("tapas", "TapasForMaskedLM"),
+ ("wav2vec2", "Wav2Vec2ForMaskedLM"),
+ ("xlm", "XLMWithLMHeadModel"),
+ ("xlm-roberta", "XLMRobertaForMaskedLM"),
+ ("xlm-roberta-xl", "XLMRobertaXLForMaskedLM"),
+ ("xmod", "XmodForMaskedLM"),
+ ("yoso", "YosoForMaskedLM"),
+ ]
+)
+
+MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES = OrderedDict(
+ [
+ # Model for Object Detection mapping
+ ("conditional_detr", "ConditionalDetrForObjectDetection"),
+ ("d_fine", "DFineForObjectDetection"),
+ ("dab-detr", "DabDetrForObjectDetection"),
+ ("deformable_detr", "DeformableDetrForObjectDetection"),
+ ("deta", "DetaForObjectDetection"),
+ ("detr", "DetrForObjectDetection"),
+ ("rt_detr", "RTDetrForObjectDetection"),
+ ("rt_detr_v2", "RTDetrV2ForObjectDetection"),
+ ("table-transformer", "TableTransformerForObjectDetection"),
+ ("yolos", "YolosForObjectDetection"),
+ ]
+)
+
+MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES = OrderedDict(
+ [
+ # Model for Zero Shot Object Detection mapping
+ ("grounding-dino", "GroundingDinoForObjectDetection"),
+ ("mm-grounding-dino", "MMGroundingDinoForObjectDetection"),
+ ("omdet-turbo", "OmDetTurboForObjectDetection"),
+ ("owlv2", "Owlv2ForObjectDetection"),
+ ("owlvit", "OwlViTForObjectDetection"),
+ ]
+)
+
+MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES = OrderedDict(
+ [
+ # Model for depth estimation mapping
+ ("depth_anything", "DepthAnythingForDepthEstimation"),
+ ("depth_pro", "DepthProForDepthEstimation"),
+ ("dpt", "DPTForDepthEstimation"),
+ ("glpn", "GLPNForDepthEstimation"),
+ ("prompt_depth_anything", "PromptDepthAnythingForDepthEstimation"),
+ ("zoedepth", "ZoeDepthForDepthEstimation"),
+ ]
+)
+
+MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
+ [
+ # Model for Seq2Seq Causal LM mapping
+ ("bart", "BartForConditionalGeneration"),
+ ("bigbird_pegasus", "BigBirdPegasusForConditionalGeneration"),
+ ("blenderbot", "BlenderbotForConditionalGeneration"),
+ ("blenderbot-small", "BlenderbotSmallForConditionalGeneration"),
+ ("encoder-decoder", "EncoderDecoderModel"),
+ ("fsmt", "FSMTForConditionalGeneration"),
+ ("gptsan-japanese", "GPTSanJapaneseForConditionalGeneration"),
+ ("granite_speech", "GraniteSpeechForConditionalGeneration"),
+ ("led", "LEDForConditionalGeneration"),
+ ("longt5", "LongT5ForConditionalGeneration"),
+ ("m2m_100", "M2M100ForConditionalGeneration"),
+ ("marian", "MarianMTModel"),
+ ("mbart", "MBartForConditionalGeneration"),
+ ("mt5", "MT5ForConditionalGeneration"),
+ ("mvp", "MvpForConditionalGeneration"),
+ ("nllb-moe", "NllbMoeForConditionalGeneration"),
+ ("pegasus", "PegasusForConditionalGeneration"),
+ ("pegasus_x", "PegasusXForConditionalGeneration"),
+ ("plbart", "PLBartForConditionalGeneration"),
+ ("prophetnet", "ProphetNetForConditionalGeneration"),
+ ("qwen2_audio", "Qwen2AudioForConditionalGeneration"),
+ ("seamless_m4t", "SeamlessM4TForTextToText"),
+ ("seamless_m4t_v2", "SeamlessM4Tv2ForTextToText"),
+ ("switch_transformers", "SwitchTransformersForConditionalGeneration"),
+ ("t5", "T5ForConditionalGeneration"),
+ ("t5gemma", "T5GemmaForConditionalGeneration"),
+ ("umt5", "UMT5ForConditionalGeneration"),
+ ("voxtral", "VoxtralForConditionalGeneration"),
+ ("xlm-prophetnet", "XLMProphetNetForConditionalGeneration"),
+ ]
+)
+
+MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict(
+ [
+ ("dia", "DiaForConditionalGeneration"),
+ ("granite_speech", "GraniteSpeechForConditionalGeneration"),
+ ("kyutai_speech_to_text", "KyutaiSpeechToTextForConditionalGeneration"),
+ ("moonshine", "MoonshineForConditionalGeneration"),
+ ("pop2piano", "Pop2PianoForConditionalGeneration"),
+ ("seamless_m4t", "SeamlessM4TForSpeechToText"),
+ ("seamless_m4t_v2", "SeamlessM4Tv2ForSpeechToText"),
+ ("speech-encoder-decoder", "SpeechEncoderDecoderModel"),
+ ("speech_to_text", "Speech2TextForConditionalGeneration"),
+ ("speecht5", "SpeechT5ForSpeechToText"),
+ ("whisper", "WhisperForConditionalGeneration"),
+ ]
+)
+
+MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
+ [
+ # Model for Sequence Classification mapping
+ ("albert", "AlbertForSequenceClassification"),
+ ("arcee", "ArceeForSequenceClassification"),
+ ("bart", "BartForSequenceClassification"),
+ ("bert", "BertForSequenceClassification"),
+ ("big_bird", "BigBirdForSequenceClassification"),
+ ("bigbird_pegasus", "BigBirdPegasusForSequenceClassification"),
+ ("biogpt", "BioGptForSequenceClassification"),
+ ("bloom", "BloomForSequenceClassification"),
+ ("camembert", "CamembertForSequenceClassification"),
+ ("canine", "CanineForSequenceClassification"),
+ ("code_llama", "LlamaForSequenceClassification"),
+ ("convbert", "ConvBertForSequenceClassification"),
+ ("ctrl", "CTRLForSequenceClassification"),
+ ("data2vec-text", "Data2VecTextForSequenceClassification"),
+ ("deberta", "DebertaForSequenceClassification"),
+ ("deberta-v2", "DebertaV2ForSequenceClassification"),
+ ("deepseek_v2", "DeepseekV2ForSequenceClassification"),
+ ("deepseek_v3", "DeepseekV3ForSequenceClassification"),
+ ("diffllama", "DiffLlamaForSequenceClassification"),
+ ("distilbert", "DistilBertForSequenceClassification"),
+ ("doge", "DogeForSequenceClassification"),
+ ("electra", "ElectraForSequenceClassification"),
+ ("ernie", "ErnieForSequenceClassification"),
+ ("ernie_m", "ErnieMForSequenceClassification"),
+ ("esm", "EsmForSequenceClassification"),
+ ("exaone4", "Exaone4ForSequenceClassification"),
+ ("falcon", "FalconForSequenceClassification"),
+ ("flaubert", "FlaubertForSequenceClassification"),
+ ("fnet", "FNetForSequenceClassification"),
+ ("funnel", "FunnelForSequenceClassification"),
+ ("gemma", "GemmaForSequenceClassification"),
+ ("gemma2", "Gemma2ForSequenceClassification"),
+ ("gemma3", "Gemma3ForSequenceClassification"),
+ ("gemma3_text", "Gemma3TextForSequenceClassification"),
+ ("glm", "GlmForSequenceClassification"),
+ ("glm4", "Glm4ForSequenceClassification"),
+ ("gpt-sw3", "GPT2ForSequenceClassification"),
+ ("gpt2", "GPT2ForSequenceClassification"),
+ ("gpt_bigcode", "GPTBigCodeForSequenceClassification"),
+ ("gpt_neo", "GPTNeoForSequenceClassification"),
+ ("gpt_neox", "GPTNeoXForSequenceClassification"),
+ ("gpt_oss", "GptOssForSequenceClassification"),
+ ("gptj", "GPTJForSequenceClassification"),
+ ("helium", "HeliumForSequenceClassification"),
+ ("hunyuan_v1_dense", "HunYuanDenseV1ForSequenceClassification"),
+ ("hunyuan_v1_moe", "HunYuanMoEV1ForSequenceClassification"),
+ ("ibert", "IBertForSequenceClassification"),
+ ("jamba", "JambaForSequenceClassification"),
+ ("jetmoe", "JetMoeForSequenceClassification"),
+ ("layoutlm", "LayoutLMForSequenceClassification"),
+ ("layoutlmv2", "LayoutLMv2ForSequenceClassification"),
+ ("layoutlmv3", "LayoutLMv3ForSequenceClassification"),
+ ("led", "LEDForSequenceClassification"),
+ ("lilt", "LiltForSequenceClassification"),
+ ("llama", "LlamaForSequenceClassification"),
+ ("longformer", "LongformerForSequenceClassification"),
+ ("luke", "LukeForSequenceClassification"),
+ ("markuplm", "MarkupLMForSequenceClassification"),
+ ("mbart", "MBartForSequenceClassification"),
+ ("mega", "MegaForSequenceClassification"),
+ ("megatron-bert", "MegatronBertForSequenceClassification"),
+ ("minimax", "MiniMaxForSequenceClassification"),
+ ("ministral", "MinistralForSequenceClassification"),
+ ("mistral", "MistralForSequenceClassification"),
+ ("mixtral", "MixtralForSequenceClassification"),
+ ("mobilebert", "MobileBertForSequenceClassification"),
+ ("modernbert", "ModernBertForSequenceClassification"),
+ ("modernbert-decoder", "ModernBertDecoderForSequenceClassification"),
+ ("mpnet", "MPNetForSequenceClassification"),
+ ("mpt", "MptForSequenceClassification"),
+ ("mra", "MraForSequenceClassification"),
+ ("mt5", "MT5ForSequenceClassification"),
+ ("mvp", "MvpForSequenceClassification"),
+ ("nemotron", "NemotronForSequenceClassification"),
+ ("nezha", "NezhaForSequenceClassification"),
+ ("nystromformer", "NystromformerForSequenceClassification"),
+ ("open-llama", "OpenLlamaForSequenceClassification"),
+ ("openai-gpt", "OpenAIGPTForSequenceClassification"),
+ ("opt", "OPTForSequenceClassification"),
+ ("perceiver", "PerceiverForSequenceClassification"),
+ ("persimmon", "PersimmonForSequenceClassification"),
+ ("phi", "PhiForSequenceClassification"),
+ ("phi3", "Phi3ForSequenceClassification"),
+ ("phimoe", "PhimoeForSequenceClassification"),
+ ("plbart", "PLBartForSequenceClassification"),
+ ("qdqbert", "QDQBertForSequenceClassification"),
+ ("qwen2", "Qwen2ForSequenceClassification"),
+ ("qwen2_moe", "Qwen2MoeForSequenceClassification"),
+ ("qwen3", "Qwen3ForSequenceClassification"),
+ ("qwen3_moe", "Qwen3MoeForSequenceClassification"),
+ ("qwen3_next", "Qwen3NextForSequenceClassification"),
+ ("reformer", "ReformerForSequenceClassification"),
+ ("rembert", "RemBertForSequenceClassification"),
+ ("roberta", "RobertaForSequenceClassification"),
+ ("roberta-prelayernorm", "RobertaPreLayerNormForSequenceClassification"),
+ ("roc_bert", "RoCBertForSequenceClassification"),
+ ("roformer", "RoFormerForSequenceClassification"),
+ ("seed_oss", "SeedOssForSequenceClassification"),
+ ("smollm3", "SmolLM3ForSequenceClassification"),
+ ("squeezebert", "SqueezeBertForSequenceClassification"),
+ ("stablelm", "StableLmForSequenceClassification"),
+ ("starcoder2", "Starcoder2ForSequenceClassification"),
+ ("t5", "T5ForSequenceClassification"),
+ ("t5gemma", "T5GemmaForSequenceClassification"),
+ ("tapas", "TapasForSequenceClassification"),
+ ("transfo-xl", "TransfoXLForSequenceClassification"),
+ ("umt5", "UMT5ForSequenceClassification"),
+ ("xlm", "XLMForSequenceClassification"),
+ ("xlm-roberta", "XLMRobertaForSequenceClassification"),
+ ("xlm-roberta-xl", "XLMRobertaXLForSequenceClassification"),
+ ("xlnet", "XLNetForSequenceClassification"),
+ ("xmod", "XmodForSequenceClassification"),
+ ("yoso", "YosoForSequenceClassification"),
+ ("zamba", "ZambaForSequenceClassification"),
+ ("zamba2", "Zamba2ForSequenceClassification"),
+ ]
+)
+
+MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
+ [
+ # Model for Question Answering mapping
+ ("albert", "AlbertForQuestionAnswering"),
+ ("arcee", "ArceeForQuestionAnswering"),
+ ("bart", "BartForQuestionAnswering"),
+ ("bert", "BertForQuestionAnswering"),
+ ("big_bird", "BigBirdForQuestionAnswering"),
+ ("bigbird_pegasus", "BigBirdPegasusForQuestionAnswering"),
+ ("bloom", "BloomForQuestionAnswering"),
+ ("camembert", "CamembertForQuestionAnswering"),
+ ("canine", "CanineForQuestionAnswering"),
+ ("convbert", "ConvBertForQuestionAnswering"),
+ ("data2vec-text", "Data2VecTextForQuestionAnswering"),
+ ("deberta", "DebertaForQuestionAnswering"),
+ ("deberta-v2", "DebertaV2ForQuestionAnswering"),
+ ("diffllama", "DiffLlamaForQuestionAnswering"),
+ ("distilbert", "DistilBertForQuestionAnswering"),
+ ("electra", "ElectraForQuestionAnswering"),
+ ("ernie", "ErnieForQuestionAnswering"),
+ ("ernie_m", "ErnieMForQuestionAnswering"),
+ ("exaone4", "Exaone4ForQuestionAnswering"),
+ ("falcon", "FalconForQuestionAnswering"),
+ ("flaubert", "FlaubertForQuestionAnsweringSimple"),
+ ("fnet", "FNetForQuestionAnswering"),
+ ("funnel", "FunnelForQuestionAnswering"),
+ ("gpt2", "GPT2ForQuestionAnswering"),
+ ("gpt_neo", "GPTNeoForQuestionAnswering"),
+ ("gpt_neox", "GPTNeoXForQuestionAnswering"),
+ ("gptj", "GPTJForQuestionAnswering"),
+ ("ibert", "IBertForQuestionAnswering"),
+ ("layoutlmv2", "LayoutLMv2ForQuestionAnswering"),
+ ("layoutlmv3", "LayoutLMv3ForQuestionAnswering"),
+ ("led", "LEDForQuestionAnswering"),
+ ("lilt", "LiltForQuestionAnswering"),
+ ("llama", "LlamaForQuestionAnswering"),
+ ("longformer", "LongformerForQuestionAnswering"),
+ ("luke", "LukeForQuestionAnswering"),
+ ("lxmert", "LxmertForQuestionAnswering"),
+ ("markuplm", "MarkupLMForQuestionAnswering"),
+ ("mbart", "MBartForQuestionAnswering"),
+ ("mega", "MegaForQuestionAnswering"),
+ ("megatron-bert", "MegatronBertForQuestionAnswering"),
+ ("minimax", "MiniMaxForQuestionAnswering"),
+ ("ministral", "MinistralForQuestionAnswering"),
+ ("mistral", "MistralForQuestionAnswering"),
+ ("mixtral", "MixtralForQuestionAnswering"),
+ ("mobilebert", "MobileBertForQuestionAnswering"),
+ ("modernbert", "ModernBertForQuestionAnswering"),
+ ("mpnet", "MPNetForQuestionAnswering"),
+ ("mpt", "MptForQuestionAnswering"),
+ ("mra", "MraForQuestionAnswering"),
+ ("mt5", "MT5ForQuestionAnswering"),
+ ("mvp", "MvpForQuestionAnswering"),
+ ("nemotron", "NemotronForQuestionAnswering"),
+ ("nezha", "NezhaForQuestionAnswering"),
+ ("nystromformer", "NystromformerForQuestionAnswering"),
+ ("opt", "OPTForQuestionAnswering"),
+ ("qdqbert", "QDQBertForQuestionAnswering"),
+ ("qwen2", "Qwen2ForQuestionAnswering"),
+ ("qwen2_moe", "Qwen2MoeForQuestionAnswering"),
+ ("qwen3", "Qwen3ForQuestionAnswering"),
+ ("qwen3_moe", "Qwen3MoeForQuestionAnswering"),
+ ("qwen3_next", "Qwen3NextForQuestionAnswering"),
+ ("reformer", "ReformerForQuestionAnswering"),
+ ("rembert", "RemBertForQuestionAnswering"),
+ ("roberta", "RobertaForQuestionAnswering"),
+ ("roberta-prelayernorm", "RobertaPreLayerNormForQuestionAnswering"),
+ ("roc_bert", "RoCBertForQuestionAnswering"),
+ ("roformer", "RoFormerForQuestionAnswering"),
+ ("seed_oss", "SeedOssForQuestionAnswering"),
+ ("smollm3", "SmolLM3ForQuestionAnswering"),
+ ("splinter", "SplinterForQuestionAnswering"),
+ ("squeezebert", "SqueezeBertForQuestionAnswering"),
+ ("t5", "T5ForQuestionAnswering"),
+ ("umt5", "UMT5ForQuestionAnswering"),
+ ("xlm", "XLMForQuestionAnsweringSimple"),
+ ("xlm-roberta", "XLMRobertaForQuestionAnswering"),
+ ("xlm-roberta-xl", "XLMRobertaXLForQuestionAnswering"),
+ ("xlnet", "XLNetForQuestionAnsweringSimple"),
+ ("xmod", "XmodForQuestionAnswering"),
+ ("yoso", "YosoForQuestionAnswering"),
+ ]
+)
+
+MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
+ [
+ # Model for Table Question Answering mapping
+ ("tapas", "TapasForQuestionAnswering"),
+ ]
+)
+
+MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
+ [
+ ("blip", "BlipForQuestionAnswering"),
+ ("blip-2", "Blip2ForConditionalGeneration"),
+ ("vilt", "ViltForQuestionAnswering"),
+ ]
+)
+
+MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
+ [
+ ("layoutlm", "LayoutLMForQuestionAnswering"),
+ ("layoutlmv2", "LayoutLMv2ForQuestionAnswering"),
+ ("layoutlmv3", "LayoutLMv3ForQuestionAnswering"),
+ ]
+)
+
+MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
+ [
+ # Model for Token Classification mapping
+ ("albert", "AlbertForTokenClassification"),
+ ("apertus", "ApertusForTokenClassification"),
+ ("arcee", "ArceeForTokenClassification"),
+ ("bert", "BertForTokenClassification"),
+ ("big_bird", "BigBirdForTokenClassification"),
+ ("biogpt", "BioGptForTokenClassification"),
+ ("bloom", "BloomForTokenClassification"),
+ ("bros", "BrosForTokenClassification"),
+ ("camembert", "CamembertForTokenClassification"),
+ ("canine", "CanineForTokenClassification"),
+ ("convbert", "ConvBertForTokenClassification"),
+ ("data2vec-text", "Data2VecTextForTokenClassification"),
+ ("deberta", "DebertaForTokenClassification"),
+ ("deberta-v2", "DebertaV2ForTokenClassification"),
+ ("deepseek_v3", "DeepseekV3ForTokenClassification"),
+ ("diffllama", "DiffLlamaForTokenClassification"),
+ ("distilbert", "DistilBertForTokenClassification"),
+ ("electra", "ElectraForTokenClassification"),
+ ("ernie", "ErnieForTokenClassification"),
+ ("ernie_m", "ErnieMForTokenClassification"),
+ ("esm", "EsmForTokenClassification"),
+ ("exaone4", "Exaone4ForTokenClassification"),
+ ("falcon", "FalconForTokenClassification"),
+ ("flaubert", "FlaubertForTokenClassification"),
+ ("fnet", "FNetForTokenClassification"),
+ ("funnel", "FunnelForTokenClassification"),
+ ("gemma", "GemmaForTokenClassification"),
+ ("gemma2", "Gemma2ForTokenClassification"),
+ ("glm", "GlmForTokenClassification"),
+ ("glm4", "Glm4ForTokenClassification"),
+ ("gpt-sw3", "GPT2ForTokenClassification"),
+ ("gpt2", "GPT2ForTokenClassification"),
+ ("gpt_bigcode", "GPTBigCodeForTokenClassification"),
+ ("gpt_neo", "GPTNeoForTokenClassification"),
+ ("gpt_neox", "GPTNeoXForTokenClassification"),
+ ("gpt_oss", "GptOssForTokenClassification"),
+ ("helium", "HeliumForTokenClassification"),
+ ("ibert", "IBertForTokenClassification"),
+ ("layoutlm", "LayoutLMForTokenClassification"),
+ ("layoutlmv2", "LayoutLMv2ForTokenClassification"),
+ ("layoutlmv3", "LayoutLMv3ForTokenClassification"),
+ ("lilt", "LiltForTokenClassification"),
+ ("llama", "LlamaForTokenClassification"),
+ ("longformer", "LongformerForTokenClassification"),
+ ("luke", "LukeForTokenClassification"),
+ ("markuplm", "MarkupLMForTokenClassification"),
+ ("mega", "MegaForTokenClassification"),
+ ("megatron-bert", "MegatronBertForTokenClassification"),
+ ("minimax", "MiniMaxForTokenClassification"),
+ ("ministral", "MinistralForTokenClassification"),
+ ("mistral", "MistralForTokenClassification"),
+ ("mixtral", "MixtralForTokenClassification"),
+ ("mobilebert", "MobileBertForTokenClassification"),
+ ("modernbert", "ModernBertForTokenClassification"),
+ ("mpnet", "MPNetForTokenClassification"),
+ ("mpt", "MptForTokenClassification"),
+ ("mra", "MraForTokenClassification"),
+ ("mt5", "MT5ForTokenClassification"),
+ ("nemotron", "NemotronForTokenClassification"),
+ ("nezha", "NezhaForTokenClassification"),
+ ("nystromformer", "NystromformerForTokenClassification"),
+ ("persimmon", "PersimmonForTokenClassification"),
+ ("phi", "PhiForTokenClassification"),
+ ("phi3", "Phi3ForTokenClassification"),
+ ("qdqbert", "QDQBertForTokenClassification"),
+ ("qwen2", "Qwen2ForTokenClassification"),
+ ("qwen2_moe", "Qwen2MoeForTokenClassification"),
+ ("qwen3", "Qwen3ForTokenClassification"),
+ ("qwen3_moe", "Qwen3MoeForTokenClassification"),
+ ("qwen3_next", "Qwen3NextForTokenClassification"),
+ ("rembert", "RemBertForTokenClassification"),
+ ("roberta", "RobertaForTokenClassification"),
+ ("roberta-prelayernorm", "RobertaPreLayerNormForTokenClassification"),
+ ("roc_bert", "RoCBertForTokenClassification"),
+ ("roformer", "RoFormerForTokenClassification"),
+ ("seed_oss", "SeedOssForTokenClassification"),
+ ("smollm3", "SmolLM3ForTokenClassification"),
+ ("squeezebert", "SqueezeBertForTokenClassification"),
+ ("stablelm", "StableLmForTokenClassification"),
+ ("starcoder2", "Starcoder2ForTokenClassification"),
+ ("t5", "T5ForTokenClassification"),
+ ("t5gemma", "T5GemmaForTokenClassification"),
+ ("umt5", "UMT5ForTokenClassification"),
+ ("xlm", "XLMForTokenClassification"),
+ ("xlm-roberta", "XLMRobertaForTokenClassification"),
+ ("xlm-roberta-xl", "XLMRobertaXLForTokenClassification"),
+ ("xlnet", "XLNetForTokenClassification"),
+ ("xmod", "XmodForTokenClassification"),
+ ("yoso", "YosoForTokenClassification"),
+ ]
+)
+
+MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict(
+ [
+ # Model for Multiple Choice mapping
+ ("albert", "AlbertForMultipleChoice"),
+ ("bert", "BertForMultipleChoice"),
+ ("big_bird", "BigBirdForMultipleChoice"),
+ ("camembert", "CamembertForMultipleChoice"),
+ ("canine", "CanineForMultipleChoice"),
+ ("convbert", "ConvBertForMultipleChoice"),
+ ("data2vec-text", "Data2VecTextForMultipleChoice"),
+ ("deberta-v2", "DebertaV2ForMultipleChoice"),
+ ("distilbert", "DistilBertForMultipleChoice"),
+ ("electra", "ElectraForMultipleChoice"),
+ ("ernie", "ErnieForMultipleChoice"),
+ ("ernie_m", "ErnieMForMultipleChoice"),
+ ("flaubert", "FlaubertForMultipleChoice"),
+ ("fnet", "FNetForMultipleChoice"),
+ ("funnel", "FunnelForMultipleChoice"),
+ ("ibert", "IBertForMultipleChoice"),
+ ("longformer", "LongformerForMultipleChoice"),
+ ("luke", "LukeForMultipleChoice"),
+ ("mega", "MegaForMultipleChoice"),
+ ("megatron-bert", "MegatronBertForMultipleChoice"),
+ ("mobilebert", "MobileBertForMultipleChoice"),
+ ("modernbert", "ModernBertForMultipleChoice"),
+ ("mpnet", "MPNetForMultipleChoice"),
+ ("mra", "MraForMultipleChoice"),
+ ("nezha", "NezhaForMultipleChoice"),
+ ("nystromformer", "NystromformerForMultipleChoice"),
+ ("qdqbert", "QDQBertForMultipleChoice"),
+ ("rembert", "RemBertForMultipleChoice"),
+ ("roberta", "RobertaForMultipleChoice"),
+ ("roberta-prelayernorm", "RobertaPreLayerNormForMultipleChoice"),
+ ("roc_bert", "RoCBertForMultipleChoice"),
+ ("roformer", "RoFormerForMultipleChoice"),
+ ("squeezebert", "SqueezeBertForMultipleChoice"),
+ ("xlm", "XLMForMultipleChoice"),
+ ("xlm-roberta", "XLMRobertaForMultipleChoice"),
+ ("xlm-roberta-xl", "XLMRobertaXLForMultipleChoice"),
+ ("xlnet", "XLNetForMultipleChoice"),
+ ("xmod", "XmodForMultipleChoice"),
+ ("yoso", "YosoForMultipleChoice"),
+ ]
+)
+
+MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES = OrderedDict(
+ [
+ ("bert", "BertForNextSentencePrediction"),
+ ("ernie", "ErnieForNextSentencePrediction"),
+ ("fnet", "FNetForNextSentencePrediction"),
+ ("megatron-bert", "MegatronBertForNextSentencePrediction"),
+ ("mobilebert", "MobileBertForNextSentencePrediction"),
+ ("nezha", "NezhaForNextSentencePrediction"),
+ ("qdqbert", "QDQBertForNextSentencePrediction"),
+ ]
+)
+
+MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
+ [
+ # Model for Audio Classification mapping
+ ("audio-spectrogram-transformer", "ASTForAudioClassification"),
+ ("data2vec-audio", "Data2VecAudioForSequenceClassification"),
+ ("hubert", "HubertForSequenceClassification"),
+ ("sew", "SEWForSequenceClassification"),
+ ("sew-d", "SEWDForSequenceClassification"),
+ ("unispeech", "UniSpeechForSequenceClassification"),
+ ("unispeech-sat", "UniSpeechSatForSequenceClassification"),
+ ("wav2vec2", "Wav2Vec2ForSequenceClassification"),
+ ("wav2vec2-bert", "Wav2Vec2BertForSequenceClassification"),
+ ("wav2vec2-conformer", "Wav2Vec2ConformerForSequenceClassification"),
+ ("wavlm", "WavLMForSequenceClassification"),
+ ("whisper", "WhisperForAudioClassification"),
+ ]
+)
+
+MODEL_FOR_CTC_MAPPING_NAMES = OrderedDict(
+ [
+ # Model for Connectionist temporal classification (CTC) mapping
+ ("data2vec-audio", "Data2VecAudioForCTC"),
+ ("hubert", "HubertForCTC"),
+ ("mctct", "MCTCTForCTC"),
+ ("parakeet_ctc", "ParakeetForCTC"),
+ ("sew", "SEWForCTC"),
+ ("sew-d", "SEWDForCTC"),
+ ("unispeech", "UniSpeechForCTC"),
+ ("unispeech-sat", "UniSpeechSatForCTC"),
+ ("wav2vec2", "Wav2Vec2ForCTC"),
+ ("wav2vec2-bert", "Wav2Vec2BertForCTC"),
+ ("wav2vec2-conformer", "Wav2Vec2ConformerForCTC"),
+ ("wavlm", "WavLMForCTC"),
+ ]
+)
+
+MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
+ [
+ # Model for Audio Frame Classification mapping
+ ("data2vec-audio", "Data2VecAudioForAudioFrameClassification"),
+ ("unispeech-sat", "UniSpeechSatForAudioFrameClassification"),
+ ("wav2vec2", "Wav2Vec2ForAudioFrameClassification"),
+ ("wav2vec2-bert", "Wav2Vec2BertForAudioFrameClassification"),
+ ("wav2vec2-conformer", "Wav2Vec2ConformerForAudioFrameClassification"),
+ ("wavlm", "WavLMForAudioFrameClassification"),
+ ]
+)
+
+MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES = OrderedDict(
+ [
+ # Model for Audio XVector mapping
+ ("data2vec-audio", "Data2VecAudioForXVector"),
+ ("unispeech-sat", "UniSpeechSatForXVector"),
+ ("wav2vec2", "Wav2Vec2ForXVector"),
+ ("wav2vec2-bert", "Wav2Vec2BertForXVector"),
+ ("wav2vec2-conformer", "Wav2Vec2ConformerForXVector"),
+ ("wavlm", "WavLMForXVector"),
+ ]
+)
+
+MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES = OrderedDict(
+ [
+ # Model for Text-To-Spectrogram mapping
+ ("fastspeech2_conformer", "FastSpeech2ConformerModel"),
+ ("speecht5", "SpeechT5ForTextToSpeech"),
+ ]
+)
+
+MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES = OrderedDict(
+ [
+ # Model for Text-To-Waveform mapping
+ ("bark", "BarkModel"),
+ ("csm", "CsmForConditionalGeneration"),
+ ("fastspeech2_conformer", "FastSpeech2ConformerWithHifiGan"),
+ ("fastspeech2_conformer_with_hifigan", "FastSpeech2ConformerWithHifiGan"),
+ ("musicgen", "MusicgenForConditionalGeneration"),
+ ("musicgen_melody", "MusicgenMelodyForConditionalGeneration"),
+ ("qwen2_5_omni", "Qwen2_5OmniForConditionalGeneration"),
+ ("qwen3_omni_moe", "Qwen3OmniMoeForConditionalGeneration"),
+ ("seamless_m4t", "SeamlessM4TForTextToSpeech"),
+ ("seamless_m4t_v2", "SeamlessM4Tv2ForTextToSpeech"),
+ ("vits", "VitsModel"),
+ ]
+)
+
+MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
+ [
+ # Model for Zero Shot Image Classification mapping
+ ("align", "AlignModel"),
+ ("altclip", "AltCLIPModel"),
+ ("blip", "BlipModel"),
+ ("blip-2", "Blip2ForImageTextRetrieval"),
+ ("chinese_clip", "ChineseCLIPModel"),
+ ("clip", "CLIPModel"),
+ ("clipseg", "CLIPSegModel"),
+ ("metaclip_2", "MetaClip2Model"),
+ ("siglip", "SiglipModel"),
+ ("siglip2", "Siglip2Model"),
+ ]
+)
+
+MODEL_FOR_BACKBONE_MAPPING_NAMES = OrderedDict(
+ [
+ # Backbone mapping
+ ("beit", "BeitBackbone"),
+ ("bit", "BitBackbone"),
+ ("convnext", "ConvNextBackbone"),
+ ("convnextv2", "ConvNextV2Backbone"),
+ ("dinat", "DinatBackbone"),
+ ("dinov2", "Dinov2Backbone"),
+ ("dinov2_with_registers", "Dinov2WithRegistersBackbone"),
+ ("focalnet", "FocalNetBackbone"),
+ ("hgnet_v2", "HGNetV2Backbone"),
+ ("hiera", "HieraBackbone"),
+ ("maskformer-swin", "MaskFormerSwinBackbone"),
+ ("nat", "NatBackbone"),
+ ("pvt_v2", "PvtV2Backbone"),
+ ("resnet", "ResNetBackbone"),
+ ("rt_detr_resnet", "RTDetrResNetBackbone"),
+ ("swin", "SwinBackbone"),
+ ("swinv2", "Swinv2Backbone"),
+ ("textnet", "TextNetBackbone"),
+ ("timm_backbone", "TimmBackbone"),
+ ("vitdet", "VitDetBackbone"),
+ ("vitpose_backbone", "VitPoseBackbone"),
+ ]
+)
+
+MODEL_FOR_MASK_GENERATION_MAPPING_NAMES = OrderedDict(
+ [
+ ("edgetam", "EdgeTamModel"),
+ ("edgetam_video", "EdgeTamModel"),
+ ("sam", "SamModel"),
+ ("sam2", "Sam2Model"),
+ ("sam2_video", "Sam2Model"),
+ ("sam_hq", "SamHQModel"),
+ ]
+)
+
+
+MODEL_FOR_KEYPOINT_DETECTION_MAPPING_NAMES = OrderedDict(
+ [
+ ("superpoint", "SuperPointForKeypointDetection"),
+ ]
+)
+
+MODEL_FOR_KEYPOINT_MATCHING_MAPPING_NAMES = OrderedDict(
+ [
+ ("efficientloftr", "EfficientLoFTRForKeypointMatching"),
+ ("lightglue", "LightGlueForKeypointMatching"),
+ ("superglue", "SuperGlueForKeypointMatching"),
+ ]
+)
+
+MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES = OrderedDict(
+ [
+ ("albert", "AlbertModel"),
+ ("bert", "BertModel"),
+ ("big_bird", "BigBirdModel"),
+ ("clip_text_model", "CLIPTextModel"),
+ ("data2vec-text", "Data2VecTextModel"),
+ ("deberta", "DebertaModel"),
+ ("deberta-v2", "DebertaV2Model"),
+ ("distilbert", "DistilBertModel"),
+ ("electra", "ElectraModel"),
+ ("emu3", "Emu3TextModel"),
+ ("flaubert", "FlaubertModel"),
+ ("ibert", "IBertModel"),
+ ("llama4", "Llama4TextModel"),
+ ("longformer", "LongformerModel"),
+ ("mllama", "MllamaTextModel"),
+ ("mobilebert", "MobileBertModel"),
+ ("mt5", "MT5EncoderModel"),
+ ("nystromformer", "NystromformerModel"),
+ ("reformer", "ReformerModel"),
+ ("rembert", "RemBertModel"),
+ ("roberta", "RobertaModel"),
+ ("roberta-prelayernorm", "RobertaPreLayerNormModel"),
+ ("roc_bert", "RoCBertModel"),
+ ("roformer", "RoFormerModel"),
+ ("squeezebert", "SqueezeBertModel"),
+ ("t5", "T5EncoderModel"),
+ ("t5gemma", "T5GemmaEncoderModel"),
+ ("umt5", "UMT5EncoderModel"),
+ ("xlm", "XLMModel"),
+ ("xlm-roberta", "XLMRobertaModel"),
+ ("xlm-roberta-xl", "XLMRobertaXLModel"),
+ ]
+)
+
+MODEL_FOR_TIME_SERIES_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
+ [
+ ("patchtsmixer", "PatchTSMixerForTimeSeriesClassification"),
+ ("patchtst", "PatchTSTForClassification"),
+ ]
+)
+
+MODEL_FOR_TIME_SERIES_REGRESSION_MAPPING_NAMES = OrderedDict(
+ [
+ ("patchtsmixer", "PatchTSMixerForRegression"),
+ ("patchtst", "PatchTSTForRegression"),
+ ]
+)
+
+MODEL_FOR_TIME_SERIES_PREDICTION_MAPPING_NAMES = OrderedDict(
+ [
+ ("timesfm", "TimesFmModelForPrediction"),
+ ]
+)
+
+MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES = OrderedDict(
+ [
+ ("swin2sr", "Swin2SRForImageSuperResolution"),
+ ]
+)
+
+MODEL_FOR_AUDIO_TOKENIZATION_NAMES = OrderedDict(
+ [
+ ("dac", "DacModel"),
+ ]
+)
+
+MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_MAPPING_NAMES)
+MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_PRETRAINING_MAPPING_NAMES)
+MODEL_WITH_LM_HEAD_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_WITH_LM_HEAD_MAPPING_NAMES)
+MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_CAUSAL_LM_MAPPING_NAMES)
+MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING = _LazyAutoMapping(
+ CONFIG_MAPPING_NAMES, MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING_NAMES
+)
+MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
+ CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
+)
+MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
+ CONFIG_MAPPING_NAMES, MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES
+)
+MODEL_FOR_IMAGE_SEGMENTATION_MAPPING = _LazyAutoMapping(
+ CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES
+)
+MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING = _LazyAutoMapping(
+ CONFIG_MAPPING_NAMES, MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES
+)
+MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING = _LazyAutoMapping(
+ CONFIG_MAPPING_NAMES, MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING_NAMES
+)
+MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING = _LazyAutoMapping(
+ CONFIG_MAPPING_NAMES, MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING_NAMES
+)
+MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING = _LazyAutoMapping(
+ CONFIG_MAPPING_NAMES, MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES
+)
+MODEL_FOR_VISION_2_SEQ_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES)
+MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING = _LazyAutoMapping(
+ CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES
+)
+MODEL_FOR_RETRIEVAL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_RETRIEVAL_MAPPING_NAMES)
+MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
+ CONFIG_MAPPING_NAMES, MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING_NAMES
+)
+MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
+ CONFIG_MAPPING_NAMES, MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES
+)
+MODEL_FOR_MASKED_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_MASKED_LM_MAPPING_NAMES)
+MODEL_FOR_IMAGE_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_MAPPING_NAMES)
+MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING = _LazyAutoMapping(
+ CONFIG_MAPPING_NAMES, MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES
+)
+MODEL_FOR_OBJECT_DETECTION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES)
+MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING = _LazyAutoMapping(
+ CONFIG_MAPPING_NAMES, MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES
+)
+MODEL_FOR_DEPTH_ESTIMATION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES)
+MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _LazyAutoMapping(
+ CONFIG_MAPPING_NAMES, MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
+)
+MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
+ CONFIG_MAPPING_NAMES, MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
+)
+MODEL_FOR_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
+ CONFIG_MAPPING_NAMES, MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
+)
+MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
+ CONFIG_MAPPING_NAMES, MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES
+)
+MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = _LazyAutoMapping(
+ CONFIG_MAPPING_NAMES, MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES
+)
+MODEL_FOR_MULTIPLE_CHOICE_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES)
+MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = _LazyAutoMapping(
+ CONFIG_MAPPING_NAMES, MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES
+)
+MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING = _LazyAutoMapping(
+ CONFIG_MAPPING_NAMES, MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES
+)
+MODEL_FOR_CTC_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_CTC_MAPPING_NAMES)
+MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES)
+MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING = _LazyAutoMapping(
+ CONFIG_MAPPING_NAMES, MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES
+)
+MODEL_FOR_AUDIO_XVECTOR_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES)
+
+MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING = _LazyAutoMapping(
+ CONFIG_MAPPING_NAMES, MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES
+)
+
+MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES)
+
+MODEL_FOR_BACKBONE_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_BACKBONE_MAPPING_NAMES)
+
+MODEL_FOR_MASK_GENERATION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_MASK_GENERATION_MAPPING_NAMES)
+
+MODEL_FOR_KEYPOINT_DETECTION_MAPPING = _LazyAutoMapping(
+ CONFIG_MAPPING_NAMES, MODEL_FOR_KEYPOINT_DETECTION_MAPPING_NAMES
+)
+
+MODEL_FOR_KEYPOINT_MATCHING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_KEYPOINT_MATCHING_MAPPING_NAMES)
+
+MODEL_FOR_TEXT_ENCODING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES)
+
+MODEL_FOR_TIME_SERIES_CLASSIFICATION_MAPPING = _LazyAutoMapping(
+ CONFIG_MAPPING_NAMES, MODEL_FOR_TIME_SERIES_CLASSIFICATION_MAPPING_NAMES
+)
+
+MODEL_FOR_TIME_SERIES_REGRESSION_MAPPING = _LazyAutoMapping(
+ CONFIG_MAPPING_NAMES, MODEL_FOR_TIME_SERIES_REGRESSION_MAPPING_NAMES
+)
+
+MODEL_FOR_TIME_SERIES_PREDICTION_MAPPING = _LazyAutoMapping(
+ CONFIG_MAPPING_NAMES, MODEL_FOR_TIME_SERIES_PREDICTION_MAPPING_NAMES
+)
+
+MODEL_FOR_IMAGE_TO_IMAGE_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES)
+
+MODEL_FOR_AUDIO_TOKENIZATION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_AUDIO_TOKENIZATION_NAMES)
+
+
+class AutoModelForMaskGeneration(_BaseAutoModelClass):
+ _model_mapping = MODEL_FOR_MASK_GENERATION_MAPPING
+
+
+class AutoModelForKeypointDetection(_BaseAutoModelClass):
+ _model_mapping = MODEL_FOR_KEYPOINT_DETECTION_MAPPING
+
+
+class AutoModelForKeypointMatching(_BaseAutoModelClass):
+ _model_mapping = MODEL_FOR_KEYPOINT_MATCHING_MAPPING
+
+
+class AutoModelForTextEncoding(_BaseAutoModelClass):
+ _model_mapping = MODEL_FOR_TEXT_ENCODING_MAPPING
+
+
+class AutoModelForImageToImage(_BaseAutoModelClass):
+ _model_mapping = MODEL_FOR_IMAGE_TO_IMAGE_MAPPING
+
+
+class AutoModel(_BaseAutoModelClass):
+ _model_mapping = MODEL_MAPPING
+
+
+AutoModel = auto_class_update(AutoModel)
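+
+# Illustrative sketch (kept as a comment, not part of the module's executable code): each
+# MODEL_*_MAPPING above is a _LazyAutoMapping keyed by config class, so an Auto* class resolves
+# the concrete architecture from a checkpoint's config at load time. The checkpoint id below is
+# an assumed example.
+#
+#     from transformers import AutoConfig, AutoModel
+#
+#     config = AutoConfig.from_pretrained("google-bert/bert-base-uncased")
+#     model = AutoModel.from_pretrained("google-bert/bert-base-uncased")
+#     assert type(model).__name__ == "BertModel"  # selected via MODEL_MAPPING ("bert" -> BertModel)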
+
+
+class AutoModelForPreTraining(_BaseAutoModelClass):
+ _model_mapping = MODEL_FOR_PRETRAINING_MAPPING
+
+
+AutoModelForPreTraining = auto_class_update(AutoModelForPreTraining, head_doc="pretraining")
+
+
+# Private on purpose; the public class will add the deprecation warnings.
+class _AutoModelWithLMHead(_BaseAutoModelClass):
+ _model_mapping = MODEL_WITH_LM_HEAD_MAPPING
+
+
+_AutoModelWithLMHead = auto_class_update(_AutoModelWithLMHead, head_doc="language modeling")
+
+
+class AutoModelForCausalLM(_BaseAutoModelClass):
+ _model_mapping = MODEL_FOR_CAUSAL_LM_MAPPING
+
+ # Override to give a more precise return type hint
+ @classmethod
+ def from_pretrained(
+ cls: type["AutoModelForCausalLM"],
+ pretrained_model_name_or_path: Union[str, os.PathLike[str]],
+ *model_args,
+ **kwargs,
+ ) -> "_BaseModelWithGenerate":
+ return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
+
+
+AutoModelForCausalLM = auto_class_update(AutoModelForCausalLM, head_doc="causal language modeling")
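+
+# Hedged usage sketch (comment only): the from_pretrained override above merely tightens the
+# return type hint; dispatch still goes through MODEL_FOR_CAUSAL_LM_MAPPING. The checkpoint id is
+# an assumed example.
+#
+#     from transformers import AutoModelForCausalLM, AutoTokenizer
+#
+#     tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
+#     model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")  # -> GPT2LMHeadModel
+#     inputs = tokenizer("Hello", return_tensors="pt")
+#     output_ids = model.generate(**inputs, max_new_tokens=10)
+#     print(tokenizer.decode(output_ids[0], skip_special_tokens=True))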
+
+
+class AutoModelForMaskedLM(_BaseAutoModelClass):
+ _model_mapping = MODEL_FOR_MASKED_LM_MAPPING
+
+
+AutoModelForMaskedLM = auto_class_update(AutoModelForMaskedLM, head_doc="masked language modeling")
+
+
+class AutoModelForSeq2SeqLM(_BaseAutoModelClass):
+ _model_mapping = MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
+
+
+AutoModelForSeq2SeqLM = auto_class_update(
+ AutoModelForSeq2SeqLM,
+ head_doc="sequence-to-sequence language modeling",
+ checkpoint_for_example="google-t5/t5-base",
+)
+
+
+class AutoModelForSequenceClassification(_BaseAutoModelClass):
+ _model_mapping = MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
+
+
+AutoModelForSequenceClassification = auto_class_update(
+ AutoModelForSequenceClassification, head_doc="sequence classification"
+)
+
+
+class AutoModelForQuestionAnswering(_BaseAutoModelClass):
+ _model_mapping = MODEL_FOR_QUESTION_ANSWERING_MAPPING
+
+
+AutoModelForQuestionAnswering = auto_class_update(AutoModelForQuestionAnswering, head_doc="question answering")
+
+
+class AutoModelForTableQuestionAnswering(_BaseAutoModelClass):
+ _model_mapping = MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING
+
+
+AutoModelForTableQuestionAnswering = auto_class_update(
+ AutoModelForTableQuestionAnswering,
+ head_doc="table question answering",
+ checkpoint_for_example="google/tapas-base-finetuned-wtq",
+)
+
+
+class AutoModelForVisualQuestionAnswering(_BaseAutoModelClass):
+ _model_mapping = MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING
+
+
+AutoModelForVisualQuestionAnswering = auto_class_update(
+ AutoModelForVisualQuestionAnswering,
+ head_doc="visual question answering",
+ checkpoint_for_example="dandelin/vilt-b32-finetuned-vqa",
+)
+
+
+class AutoModelForDocumentQuestionAnswering(_BaseAutoModelClass):
+ _model_mapping = MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING
+
+
+AutoModelForDocumentQuestionAnswering = auto_class_update(
+ AutoModelForDocumentQuestionAnswering,
+ head_doc="document question answering",
+ checkpoint_for_example='impira/layoutlm-document-qa", revision="52e01b3',
+)
+
+
+class AutoModelForTokenClassification(_BaseAutoModelClass):
+ _model_mapping = MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
+
+
+AutoModelForTokenClassification = auto_class_update(AutoModelForTokenClassification, head_doc="token classification")
+
+
+class AutoModelForMultipleChoice(_BaseAutoModelClass):
+ _model_mapping = MODEL_FOR_MULTIPLE_CHOICE_MAPPING
+
+
+AutoModelForMultipleChoice = auto_class_update(AutoModelForMultipleChoice, head_doc="multiple choice")
+
+
+class AutoModelForNextSentencePrediction(_BaseAutoModelClass):
+ _model_mapping = MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING
+
+
+AutoModelForNextSentencePrediction = auto_class_update(
+ AutoModelForNextSentencePrediction, head_doc="next sentence prediction"
+)
+
+
+class AutoModelForImageClassification(_BaseAutoModelClass):
+ _model_mapping = MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
+
+
+AutoModelForImageClassification = auto_class_update(AutoModelForImageClassification, head_doc="image classification")
+
+
+class AutoModelForZeroShotImageClassification(_BaseAutoModelClass):
+ _model_mapping = MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING
+
+
+AutoModelForZeroShotImageClassification = auto_class_update(
+ AutoModelForZeroShotImageClassification, head_doc="zero-shot image classification"
+)
+
+
+class AutoModelForImageSegmentation(_BaseAutoModelClass):
+ _model_mapping = MODEL_FOR_IMAGE_SEGMENTATION_MAPPING
+
+
+AutoModelForImageSegmentation = auto_class_update(AutoModelForImageSegmentation, head_doc="image segmentation")
+
+
+class AutoModelForSemanticSegmentation(_BaseAutoModelClass):
+ _model_mapping = MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING
+
+
+AutoModelForSemanticSegmentation = auto_class_update(
+ AutoModelForSemanticSegmentation, head_doc="semantic segmentation"
+)
+
+
+class AutoModelForTimeSeriesPrediction(_BaseAutoModelClass):
+ _model_mapping = MODEL_FOR_TIME_SERIES_PREDICTION_MAPPING
+
+
+AutoModelForTimeSeriesPrediction = auto_class_update(
+ AutoModelForTimeSeriesPrediction, head_doc="time-series prediction"
+)
+
+
+class AutoModelForUniversalSegmentation(_BaseAutoModelClass):
+ _model_mapping = MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING
+
+
+AutoModelForUniversalSegmentation = auto_class_update(
+ AutoModelForUniversalSegmentation, head_doc="universal image segmentation"
+)
+
+
+class AutoModelForInstanceSegmentation(_BaseAutoModelClass):
+ _model_mapping = MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING
+
+
+AutoModelForInstanceSegmentation = auto_class_update(
+ AutoModelForInstanceSegmentation, head_doc="instance segmentation"
+)
+
+
+class AutoModelForObjectDetection(_BaseAutoModelClass):
+ _model_mapping = MODEL_FOR_OBJECT_DETECTION_MAPPING
+
+
+AutoModelForObjectDetection = auto_class_update(AutoModelForObjectDetection, head_doc="object detection")
+
+
+class AutoModelForZeroShotObjectDetection(_BaseAutoModelClass):
+ _model_mapping = MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING
+
+
+AutoModelForZeroShotObjectDetection = auto_class_update(
+ AutoModelForZeroShotObjectDetection, head_doc="zero-shot object detection"
+)
+
+
+class AutoModelForDepthEstimation(_BaseAutoModelClass):
+ _model_mapping = MODEL_FOR_DEPTH_ESTIMATION_MAPPING
+
+
+AutoModelForDepthEstimation = auto_class_update(AutoModelForDepthEstimation, head_doc="depth estimation")
+
+
+class AutoModelForVideoClassification(_BaseAutoModelClass):
+ _model_mapping = MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING
+
+
+AutoModelForVideoClassification = auto_class_update(AutoModelForVideoClassification, head_doc="video classification")
+
+
+# Private on purpose; the public class will add the deprecation warnings.
+class _AutoModelForVision2Seq(_BaseAutoModelClass):
+ _model_mapping = MODEL_FOR_VISION_2_SEQ_MAPPING
+
+
+_AutoModelForVision2Seq = auto_class_update(_AutoModelForVision2Seq, head_doc="vision-to-text modeling")
+
+
+class AutoModelForImageTextToText(_BaseAutoModelClass):
+ _model_mapping = MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING
+
+ # Override to give a more precise return type hint
+ @classmethod
+ def from_pretrained(
+ cls: type["AutoModelForImageTextToText"],
+ pretrained_model_name_or_path: Union[str, os.PathLike[str]],
+ *model_args,
+ **kwargs,
+ ) -> "_BaseModelWithGenerate":
+ return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
+
+
+AutoModelForImageTextToText = auto_class_update(AutoModelForImageTextToText, head_doc="image-text-to-text modeling")
+
+
+class AutoModelForAudioClassification(_BaseAutoModelClass):
+ _model_mapping = MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING
+
+
+AutoModelForAudioClassification = auto_class_update(AutoModelForAudioClassification, head_doc="audio classification")
+
+
+class AutoModelForCTC(_BaseAutoModelClass):
+ _model_mapping = MODEL_FOR_CTC_MAPPING
+
+
+AutoModelForCTC = auto_class_update(AutoModelForCTC, head_doc="connectionist temporal classification")
+
+
+class AutoModelForSpeechSeq2Seq(_BaseAutoModelClass):
+ _model_mapping = MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING
+
+
+AutoModelForSpeechSeq2Seq = auto_class_update(
+ AutoModelForSpeechSeq2Seq, head_doc="sequence-to-sequence speech-to-text modeling"
+)
+
+
+class AutoModelForAudioFrameClassification(_BaseAutoModelClass):
+ _model_mapping = MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING
+
+
+AutoModelForAudioFrameClassification = auto_class_update(
+ AutoModelForAudioFrameClassification, head_doc="audio frame (token) classification"
+)
+
+
+class AutoModelForAudioXVector(_BaseAutoModelClass):
+ _model_mapping = MODEL_FOR_AUDIO_XVECTOR_MAPPING
+
+
+class AutoModelForTextToSpectrogram(_BaseAutoModelClass):
+ _model_mapping = MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING
+
+
+class AutoModelForTextToWaveform(_BaseAutoModelClass):
+ _model_mapping = MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING
+
+
+class AutoBackbone(_BaseAutoBackboneClass):
+ _model_mapping = MODEL_FOR_BACKBONE_MAPPING
+
+
+AutoModelForAudioXVector = auto_class_update(AutoModelForAudioXVector, head_doc="audio retrieval via x-vector")
+
+
+class AutoModelForMaskedImageModeling(_BaseAutoModelClass):
+ _model_mapping = MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING
+
+
+AutoModelForMaskedImageModeling = auto_class_update(AutoModelForMaskedImageModeling, head_doc="masked image modeling")
+
+
+class AutoModelForAudioTokenization(_BaseAutoModelClass):
+ _model_mapping = MODEL_FOR_AUDIO_TOKENIZATION_MAPPING
+
+
+AutoModelForAudioTokenization = auto_class_update(
+ AutoModelForAudioTokenization, head_doc="audio tokenization through codebooks"
+)
+
+
+class AutoModelWithLMHead(_AutoModelWithLMHead):
+ @classmethod
+ def from_config(cls, config, **kwargs):
+ warnings.warn(
+ "The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. Please use "
+ "`AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and "
+ "`AutoModelForSeq2SeqLM` for encoder-decoder models.",
+ FutureWarning,
+ )
+ return super().from_config(config, **kwargs)
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+ warnings.warn(
+ "The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. Please use "
+ "`AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and "
+ "`AutoModelForSeq2SeqLM` for encoder-decoder models.",
+ FutureWarning,
+ )
+ return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
+
+
+class AutoModelForVision2Seq(_AutoModelForVision2Seq):
+ @classmethod
+ def from_config(cls, config, **kwargs):
+ warnings.warn(
+ "The class `AutoModelForVision2Seq` is deprecated and will be removed in v5.0. Please use "
+ "`AutoModelForImageTextToText` instead.",
+ FutureWarning,
+ )
+ return super().from_config(config, **kwargs)
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+ warnings.warn(
+ "The class `AutoModelForVision2Seq` is deprecated and will be removed in v5.0. Please use "
+ "`AutoModelForImageTextToText` instead.",
+ FutureWarning,
+ )
+ return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
+
+
+__all__ = [
+ "MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING",
+ "MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING",
+ "MODEL_FOR_AUDIO_TOKENIZATION_MAPPING",
+ "MODEL_FOR_AUDIO_XVECTOR_MAPPING",
+ "MODEL_FOR_BACKBONE_MAPPING",
+ "MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING",
+ "MODEL_FOR_CAUSAL_LM_MAPPING",
+ "MODEL_FOR_CTC_MAPPING",
+ "MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING",
+ "MODEL_FOR_DEPTH_ESTIMATION_MAPPING",
+ "MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING",
+ "MODEL_FOR_IMAGE_MAPPING",
+ "MODEL_FOR_IMAGE_SEGMENTATION_MAPPING",
+ "MODEL_FOR_IMAGE_TO_IMAGE_MAPPING",
+ "MODEL_FOR_KEYPOINT_DETECTION_MAPPING",
+ "MODEL_FOR_KEYPOINT_MATCHING_MAPPING",
+ "MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING",
+ "MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING",
+ "MODEL_FOR_MASKED_LM_MAPPING",
+ "MODEL_FOR_MASK_GENERATION_MAPPING",
+ "MODEL_FOR_MULTIPLE_CHOICE_MAPPING",
+ "MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING",
+ "MODEL_FOR_OBJECT_DETECTION_MAPPING",
+ "MODEL_FOR_PRETRAINING_MAPPING",
+ "MODEL_FOR_QUESTION_ANSWERING_MAPPING",
+ "MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING",
+ "MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING",
+ "MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING",
+ "MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING",
+ "MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING",
+ "MODEL_FOR_TEXT_ENCODING_MAPPING",
+ "MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING",
+ "MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING",
+ "MODEL_FOR_TIME_SERIES_PREDICTION_MAPPING",
+ "MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
+ "MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING",
+ "MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING",
+ "MODEL_FOR_VISION_2_SEQ_MAPPING",
+ "MODEL_FOR_RETRIEVAL_MAPPING",
+ "MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING",
+ "MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING",
+ "MODEL_MAPPING",
+ "MODEL_WITH_LM_HEAD_MAPPING",
+ "MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING",
+ "MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING",
+ "MODEL_FOR_TIME_SERIES_CLASSIFICATION_MAPPING",
+ "MODEL_FOR_TIME_SERIES_REGRESSION_MAPPING",
+ "AutoModel",
+ "AutoBackbone",
+ "AutoModelForAudioClassification",
+ "AutoModelForAudioFrameClassification",
+ "AutoModelForAudioTokenization",
+ "AutoModelForAudioXVector",
+ "AutoModelForCausalLM",
+ "AutoModelForCTC",
+ "AutoModelForDepthEstimation",
+ "AutoModelForImageClassification",
+ "AutoModelForImageSegmentation",
+ "AutoModelForImageToImage",
+ "AutoModelForInstanceSegmentation",
+ "AutoModelForKeypointDetection",
+ "AutoModelForKeypointMatching",
+ "AutoModelForMaskGeneration",
+ "AutoModelForTextEncoding",
+ "AutoModelForMaskedImageModeling",
+ "AutoModelForMaskedLM",
+ "AutoModelForMultipleChoice",
+ "AutoModelForNextSentencePrediction",
+ "AutoModelForObjectDetection",
+ "AutoModelForPreTraining",
+ "AutoModelForQuestionAnswering",
+ "AutoModelForSemanticSegmentation",
+ "AutoModelForSeq2SeqLM",
+ "AutoModelForSequenceClassification",
+ "AutoModelForSpeechSeq2Seq",
+ "AutoModelForTableQuestionAnswering",
+ "AutoModelForTextToSpectrogram",
+ "AutoModelForTextToWaveform",
+ "AutoModelForTimeSeriesPrediction",
+ "AutoModelForTokenClassification",
+ "AutoModelForUniversalSegmentation",
+ "AutoModelForVideoClassification",
+ "AutoModelForVision2Seq",
+ "AutoModelForVisualQuestionAnswering",
+ "AutoModelForDocumentQuestionAnswering",
+ "AutoModelWithLMHead",
+ "AutoModelForZeroShotImageClassification",
+ "AutoModelForZeroShotObjectDetection",
+ "AutoModelForImageTextToText",
+]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/auto/modeling_flax_auto.py b/venv/lib/python3.13/site-packages/transformers/models/auto/modeling_flax_auto.py
new file mode 100644
index 0000000000000000000000000000000000000000..0588d03cb6cdb43b94cc3fcd73b1791d1a5ee809
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/auto/modeling_flax_auto.py
@@ -0,0 +1,413 @@
+# coding=utf-8
+# Copyright 2018 The Google Flax Team Authors and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Auto Model class."""
+
+from collections import OrderedDict
+
+from ...utils import logging
+from .auto_factory import _BaseAutoModelClass, _LazyAutoMapping, auto_class_update
+from .configuration_auto import CONFIG_MAPPING_NAMES
+
+
+logger = logging.get_logger(__name__)
+
+
+FLAX_MODEL_MAPPING_NAMES = OrderedDict(
+ [
+ # Base model mapping
+ ("albert", "FlaxAlbertModel"),
+ ("bart", "FlaxBartModel"),
+ ("beit", "FlaxBeitModel"),
+ ("bert", "FlaxBertModel"),
+ ("big_bird", "FlaxBigBirdModel"),
+ ("blenderbot", "FlaxBlenderbotModel"),
+ ("blenderbot-small", "FlaxBlenderbotSmallModel"),
+ ("bloom", "FlaxBloomModel"),
+ ("clip", "FlaxCLIPModel"),
+ ("dinov2", "FlaxDinov2Model"),
+ ("distilbert", "FlaxDistilBertModel"),
+ ("electra", "FlaxElectraModel"),
+ ("gemma", "FlaxGemmaModel"),
+ ("gpt-sw3", "FlaxGPT2Model"),
+ ("gpt2", "FlaxGPT2Model"),
+ ("gpt_neo", "FlaxGPTNeoModel"),
+ ("gptj", "FlaxGPTJModel"),
+ ("llama", "FlaxLlamaModel"),
+ ("longt5", "FlaxLongT5Model"),
+ ("marian", "FlaxMarianModel"),
+ ("mbart", "FlaxMBartModel"),
+ ("mistral", "FlaxMistralModel"),
+ ("mt5", "FlaxMT5Model"),
+ ("opt", "FlaxOPTModel"),
+ ("pegasus", "FlaxPegasusModel"),
+ ("regnet", "FlaxRegNetModel"),
+ ("resnet", "FlaxResNetModel"),
+ ("roberta", "FlaxRobertaModel"),
+ ("roberta-prelayernorm", "FlaxRobertaPreLayerNormModel"),
+ ("roformer", "FlaxRoFormerModel"),
+ ("t5", "FlaxT5Model"),
+ ("vision-text-dual-encoder", "FlaxVisionTextDualEncoderModel"),
+ ("vit", "FlaxViTModel"),
+ ("wav2vec2", "FlaxWav2Vec2Model"),
+ ("whisper", "FlaxWhisperModel"),
+ ("xglm", "FlaxXGLMModel"),
+ ("xlm-roberta", "FlaxXLMRobertaModel"),
+ ]
+)
+
+FLAX_MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
+ [
+ # Model for pre-training mapping
+ ("albert", "FlaxAlbertForPreTraining"),
+ ("bart", "FlaxBartForConditionalGeneration"),
+ ("bert", "FlaxBertForPreTraining"),
+ ("big_bird", "FlaxBigBirdForPreTraining"),
+ ("electra", "FlaxElectraForPreTraining"),
+ ("longt5", "FlaxLongT5ForConditionalGeneration"),
+ ("mbart", "FlaxMBartForConditionalGeneration"),
+ ("mt5", "FlaxMT5ForConditionalGeneration"),
+ ("roberta", "FlaxRobertaForMaskedLM"),
+ ("roberta-prelayernorm", "FlaxRobertaPreLayerNormForMaskedLM"),
+ ("roformer", "FlaxRoFormerForMaskedLM"),
+ ("t5", "FlaxT5ForConditionalGeneration"),
+ ("wav2vec2", "FlaxWav2Vec2ForPreTraining"),
+ ("whisper", "FlaxWhisperForConditionalGeneration"),
+ ("xlm-roberta", "FlaxXLMRobertaForMaskedLM"),
+ ]
+)
+
+FLAX_MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
+ [
+ # Model for Masked LM mapping
+ ("albert", "FlaxAlbertForMaskedLM"),
+ ("bart", "FlaxBartForConditionalGeneration"),
+ ("bert", "FlaxBertForMaskedLM"),
+ ("big_bird", "FlaxBigBirdForMaskedLM"),
+ ("distilbert", "FlaxDistilBertForMaskedLM"),
+ ("electra", "FlaxElectraForMaskedLM"),
+ ("mbart", "FlaxMBartForConditionalGeneration"),
+ ("roberta", "FlaxRobertaForMaskedLM"),
+ ("roberta-prelayernorm", "FlaxRobertaPreLayerNormForMaskedLM"),
+ ("roformer", "FlaxRoFormerForMaskedLM"),
+ ("xlm-roberta", "FlaxXLMRobertaForMaskedLM"),
+ ]
+)
+
+FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
+ [
+ # Model for Seq2Seq Causal LM mapping
+ ("bart", "FlaxBartForConditionalGeneration"),
+ ("blenderbot", "FlaxBlenderbotForConditionalGeneration"),
+ ("blenderbot-small", "FlaxBlenderbotSmallForConditionalGeneration"),
+ ("encoder-decoder", "FlaxEncoderDecoderModel"),
+ ("longt5", "FlaxLongT5ForConditionalGeneration"),
+ ("marian", "FlaxMarianMTModel"),
+ ("mbart", "FlaxMBartForConditionalGeneration"),
+ ("mt5", "FlaxMT5ForConditionalGeneration"),
+ ("pegasus", "FlaxPegasusForConditionalGeneration"),
+ ("t5", "FlaxT5ForConditionalGeneration"),
+ ]
+)
+
+FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
+ [
+ # Model for Image Classification mapping
+ ("beit", "FlaxBeitForImageClassification"),
+ ("dinov2", "FlaxDinov2ForImageClassification"),
+ ("regnet", "FlaxRegNetForImageClassification"),
+ ("resnet", "FlaxResNetForImageClassification"),
+ ("vit", "FlaxViTForImageClassification"),
+ ]
+)
+
+FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict(
+ [
+ ("vision-encoder-decoder", "FlaxVisionEncoderDecoderModel"),
+ ]
+)
+
+FLAX_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
+ [
+ # Model for Causal LM mapping
+ ("bart", "FlaxBartForCausalLM"),
+ ("bert", "FlaxBertForCausalLM"),
+ ("big_bird", "FlaxBigBirdForCausalLM"),
+ ("bloom", "FlaxBloomForCausalLM"),
+ ("electra", "FlaxElectraForCausalLM"),
+ ("gemma", "FlaxGemmaForCausalLM"),
+ ("gpt-sw3", "FlaxGPT2LMHeadModel"),
+ ("gpt2", "FlaxGPT2LMHeadModel"),
+ ("gpt_neo", "FlaxGPTNeoForCausalLM"),
+ ("gptj", "FlaxGPTJForCausalLM"),
+ ("llama", "FlaxLlamaForCausalLM"),
+ ("mistral", "FlaxMistralForCausalLM"),
+ ("opt", "FlaxOPTForCausalLM"),
+ ("roberta", "FlaxRobertaForCausalLM"),
+ ("roberta-prelayernorm", "FlaxRobertaPreLayerNormForCausalLM"),
+ ("xglm", "FlaxXGLMForCausalLM"),
+ ("xlm-roberta", "FlaxXLMRobertaForCausalLM"),
+ ]
+)
+
+FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
+ [
+ # Model for Sequence Classification mapping
+ ("albert", "FlaxAlbertForSequenceClassification"),
+ ("bart", "FlaxBartForSequenceClassification"),
+ ("bert", "FlaxBertForSequenceClassification"),
+ ("big_bird", "FlaxBigBirdForSequenceClassification"),
+ ("distilbert", "FlaxDistilBertForSequenceClassification"),
+ ("electra", "FlaxElectraForSequenceClassification"),
+ ("mbart", "FlaxMBartForSequenceClassification"),
+ ("roberta", "FlaxRobertaForSequenceClassification"),
+ ("roberta-prelayernorm", "FlaxRobertaPreLayerNormForSequenceClassification"),
+ ("roformer", "FlaxRoFormerForSequenceClassification"),
+ ("xlm-roberta", "FlaxXLMRobertaForSequenceClassification"),
+ ]
+)
+
+FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
+ [
+ # Model for Question Answering mapping
+ ("albert", "FlaxAlbertForQuestionAnswering"),
+ ("bart", "FlaxBartForQuestionAnswering"),
+ ("bert", "FlaxBertForQuestionAnswering"),
+ ("big_bird", "FlaxBigBirdForQuestionAnswering"),
+ ("distilbert", "FlaxDistilBertForQuestionAnswering"),
+ ("electra", "FlaxElectraForQuestionAnswering"),
+ ("mbart", "FlaxMBartForQuestionAnswering"),
+ ("roberta", "FlaxRobertaForQuestionAnswering"),
+ ("roberta-prelayernorm", "FlaxRobertaPreLayerNormForQuestionAnswering"),
+ ("roformer", "FlaxRoFormerForQuestionAnswering"),
+ ("xlm-roberta", "FlaxXLMRobertaForQuestionAnswering"),
+ ]
+)
+
+FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
+ [
+ # Model for Token Classification mapping
+ ("albert", "FlaxAlbertForTokenClassification"),
+ ("bert", "FlaxBertForTokenClassification"),
+ ("big_bird", "FlaxBigBirdForTokenClassification"),
+ ("distilbert", "FlaxDistilBertForTokenClassification"),
+ ("electra", "FlaxElectraForTokenClassification"),
+ ("roberta", "FlaxRobertaForTokenClassification"),
+ ("roberta-prelayernorm", "FlaxRobertaPreLayerNormForTokenClassification"),
+ ("roformer", "FlaxRoFormerForTokenClassification"),
+ ("xlm-roberta", "FlaxXLMRobertaForTokenClassification"),
+ ]
+)
+
+FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict(
+ [
+ # Model for Multiple Choice mapping
+ ("albert", "FlaxAlbertForMultipleChoice"),
+ ("bert", "FlaxBertForMultipleChoice"),
+ ("big_bird", "FlaxBigBirdForMultipleChoice"),
+ ("distilbert", "FlaxDistilBertForMultipleChoice"),
+ ("electra", "FlaxElectraForMultipleChoice"),
+ ("roberta", "FlaxRobertaForMultipleChoice"),
+ ("roberta-prelayernorm", "FlaxRobertaPreLayerNormForMultipleChoice"),
+ ("roformer", "FlaxRoFormerForMultipleChoice"),
+ ("xlm-roberta", "FlaxXLMRobertaForMultipleChoice"),
+ ]
+)
+
+FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES = OrderedDict(
+ [
+ ("bert", "FlaxBertForNextSentencePrediction"),
+ ]
+)
+
+FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict(
+ [
+ ("speech-encoder-decoder", "FlaxSpeechEncoderDecoderModel"),
+ ("whisper", "FlaxWhisperForConditionalGeneration"),
+ ]
+)
+
+FLAX_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
+ [
+ ("whisper", "FlaxWhisperForAudioClassification"),
+ ]
+)
+
+FLAX_MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_MAPPING_NAMES)
+FLAX_MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_PRETRAINING_MAPPING_NAMES)
+FLAX_MODEL_FOR_MASKED_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_MASKED_LM_MAPPING_NAMES)
+FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _LazyAutoMapping(
+ CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
+)
+FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
+ CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
+)
+FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES)
+FLAX_MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES)
+FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
+ CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
+)
+FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
+ CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
+)
+FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = _LazyAutoMapping(
+ CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES
+)
+FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING = _LazyAutoMapping(
+ CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES
+)
+FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = _LazyAutoMapping(
+ CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES
+)
+FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING = _LazyAutoMapping(
+ CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES
+)
+FLAX_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING = _LazyAutoMapping(
+ CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES
+)
+
+
+class FlaxAutoModel(_BaseAutoModelClass):
+ _model_mapping = FLAX_MODEL_MAPPING
+
+
+FlaxAutoModel = auto_class_update(FlaxAutoModel)
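+
+# Minimal sketch (comment only), assuming the Flax extras are installed and the checkpoint ships
+# Flax weights: the same lazy-mapping dispatch applies here, with FLAX_MODEL_MAPPING resolving a
+# BERT checkpoint to FlaxBertModel. The checkpoint id is an assumed example.
+#
+#     from transformers import FlaxAutoModel
+#
+#     model = FlaxAutoModel.from_pretrained("google-bert/bert-base-uncased")  # -> FlaxBertModel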
+
+
+class FlaxAutoModelForPreTraining(_BaseAutoModelClass):
+ _model_mapping = FLAX_MODEL_FOR_PRETRAINING_MAPPING
+
+
+FlaxAutoModelForPreTraining = auto_class_update(FlaxAutoModelForPreTraining, head_doc="pretraining")
+
+
+class FlaxAutoModelForCausalLM(_BaseAutoModelClass):
+ _model_mapping = FLAX_MODEL_FOR_CAUSAL_LM_MAPPING
+
+
+FlaxAutoModelForCausalLM = auto_class_update(FlaxAutoModelForCausalLM, head_doc="causal language modeling")
+
+
+class FlaxAutoModelForMaskedLM(_BaseAutoModelClass):
+ _model_mapping = FLAX_MODEL_FOR_MASKED_LM_MAPPING
+
+
+FlaxAutoModelForMaskedLM = auto_class_update(FlaxAutoModelForMaskedLM, head_doc="masked language modeling")
+
+
+class FlaxAutoModelForSeq2SeqLM(_BaseAutoModelClass):
+ _model_mapping = FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
+
+
+FlaxAutoModelForSeq2SeqLM = auto_class_update(
+ FlaxAutoModelForSeq2SeqLM,
+ head_doc="sequence-to-sequence language modeling",
+ checkpoint_for_example="google-t5/t5-base",
+)
+
+
+class FlaxAutoModelForSequenceClassification(_BaseAutoModelClass):
+ _model_mapping = FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
+
+
+FlaxAutoModelForSequenceClassification = auto_class_update(
+ FlaxAutoModelForSequenceClassification, head_doc="sequence classification"
+)
+
+
+class FlaxAutoModelForQuestionAnswering(_BaseAutoModelClass):
+ _model_mapping = FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING
+
+
+FlaxAutoModelForQuestionAnswering = auto_class_update(FlaxAutoModelForQuestionAnswering, head_doc="question answering")
+
+
+class FlaxAutoModelForTokenClassification(_BaseAutoModelClass):
+ _model_mapping = FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
+
+
+FlaxAutoModelForTokenClassification = auto_class_update(
+ FlaxAutoModelForTokenClassification, head_doc="token classification"
+)
+
+
+class FlaxAutoModelForMultipleChoice(_BaseAutoModelClass):
+ _model_mapping = FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING
+
+
+FlaxAutoModelForMultipleChoice = auto_class_update(FlaxAutoModelForMultipleChoice, head_doc="multiple choice")
+
+
+class FlaxAutoModelForNextSentencePrediction(_BaseAutoModelClass):
+ _model_mapping = FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING
+
+
+FlaxAutoModelForNextSentencePrediction = auto_class_update(
+ FlaxAutoModelForNextSentencePrediction, head_doc="next sentence prediction"
+)
+
+
+class FlaxAutoModelForImageClassification(_BaseAutoModelClass):
+ _model_mapping = FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
+
+
+FlaxAutoModelForImageClassification = auto_class_update(
+ FlaxAutoModelForImageClassification, head_doc="image classification"
+)
+
+
+class FlaxAutoModelForVision2Seq(_BaseAutoModelClass):
+ _model_mapping = FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING
+
+
+FlaxAutoModelForVision2Seq = auto_class_update(FlaxAutoModelForVision2Seq, head_doc="vision-to-text modeling")
+
+
+class FlaxAutoModelForSpeechSeq2Seq(_BaseAutoModelClass):
+ _model_mapping = FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING
+
+
+FlaxAutoModelForSpeechSeq2Seq = auto_class_update(
+ FlaxAutoModelForSpeechSeq2Seq, head_doc="sequence-to-sequence speech-to-text modeling"
+)
+
+__all__ = [
+ "FLAX_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING",
+ "FLAX_MODEL_FOR_CAUSAL_LM_MAPPING",
+ "FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING",
+ "FLAX_MODEL_FOR_MASKED_LM_MAPPING",
+ "FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING",
+ "FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING",
+ "FLAX_MODEL_FOR_PRETRAINING_MAPPING",
+ "FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING",
+ "FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING",
+ "FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING",
+ "FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING",
+ "FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
+ "FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING",
+ "FLAX_MODEL_MAPPING",
+ "FlaxAutoModel",
+ "FlaxAutoModelForCausalLM",
+ "FlaxAutoModelForImageClassification",
+ "FlaxAutoModelForMaskedLM",
+ "FlaxAutoModelForMultipleChoice",
+ "FlaxAutoModelForNextSentencePrediction",
+ "FlaxAutoModelForPreTraining",
+ "FlaxAutoModelForQuestionAnswering",
+ "FlaxAutoModelForSeq2SeqLM",
+ "FlaxAutoModelForSequenceClassification",
+ "FlaxAutoModelForSpeechSeq2Seq",
+ "FlaxAutoModelForTokenClassification",
+ "FlaxAutoModelForVision2Seq",
+]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/auto/modeling_tf_auto.py b/venv/lib/python3.13/site-packages/transformers/models/auto/modeling_tf_auto.py
new file mode 100644
index 0000000000000000000000000000000000000000..cf39f4d7c9c40bd87a8e4c5e3037e2cbe3574a29
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/auto/modeling_tf_auto.py
@@ -0,0 +1,776 @@
+# coding=utf-8
+# Copyright 2018 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Auto Model class."""
+
+import warnings
+from collections import OrderedDict
+
+from ...utils import logging
+from .auto_factory import _BaseAutoModelClass, _LazyAutoMapping, auto_class_update
+from .configuration_auto import CONFIG_MAPPING_NAMES
+
+
+logger = logging.get_logger(__name__)
+
+
+TF_MODEL_MAPPING_NAMES = OrderedDict(
+ [
+ # Base model mapping
+ ("albert", "TFAlbertModel"),
+ ("bart", "TFBartModel"),
+ ("bert", "TFBertModel"),
+ ("blenderbot", "TFBlenderbotModel"),
+ ("blenderbot-small", "TFBlenderbotSmallModel"),
+ ("blip", "TFBlipModel"),
+ ("camembert", "TFCamembertModel"),
+ ("clip", "TFCLIPModel"),
+ ("convbert", "TFConvBertModel"),
+ ("convnext", "TFConvNextModel"),
+ ("convnextv2", "TFConvNextV2Model"),
+ ("ctrl", "TFCTRLModel"),
+ ("cvt", "TFCvtModel"),
+ ("data2vec-vision", "TFData2VecVisionModel"),
+ ("deberta", "TFDebertaModel"),
+ ("deberta-v2", "TFDebertaV2Model"),
+ ("deit", "TFDeiTModel"),
+ ("distilbert", "TFDistilBertModel"),
+ ("dpr", "TFDPRQuestionEncoder"),
+ ("efficientformer", "TFEfficientFormerModel"),
+ ("electra", "TFElectraModel"),
+ ("esm", "TFEsmModel"),
+ ("flaubert", "TFFlaubertModel"),
+ ("funnel", ("TFFunnelModel", "TFFunnelBaseModel")),
+ ("gpt-sw3", "TFGPT2Model"),
+ ("gpt2", "TFGPT2Model"),
+ ("gptj", "TFGPTJModel"),
+ ("groupvit", "TFGroupViTModel"),
+ ("hubert", "TFHubertModel"),
+ ("idefics", "TFIdeficsModel"),
+ ("layoutlm", "TFLayoutLMModel"),
+ ("layoutlmv3", "TFLayoutLMv3Model"),
+ ("led", "TFLEDModel"),
+ ("longformer", "TFLongformerModel"),
+ ("lxmert", "TFLxmertModel"),
+ ("marian", "TFMarianModel"),
+ ("mbart", "TFMBartModel"),
+ ("mistral", "TFMistralModel"),
+ ("mobilebert", "TFMobileBertModel"),
+ ("mobilevit", "TFMobileViTModel"),
+ ("mpnet", "TFMPNetModel"),
+ ("mt5", "TFMT5Model"),
+ ("openai-gpt", "TFOpenAIGPTModel"),
+ ("opt", "TFOPTModel"),
+ ("pegasus", "TFPegasusModel"),
+ ("regnet", "TFRegNetModel"),
+ ("rembert", "TFRemBertModel"),
+ ("resnet", "TFResNetModel"),
+ ("roberta", "TFRobertaModel"),
+ ("roberta-prelayernorm", "TFRobertaPreLayerNormModel"),
+ ("roformer", "TFRoFormerModel"),
+ ("sam", "TFSamModel"),
+ ("sam_vision_model", "TFSamVisionModel"),
+ ("segformer", "TFSegformerModel"),
+ ("speech_to_text", "TFSpeech2TextModel"),
+ ("swiftformer", "TFSwiftFormerModel"),
+ ("swin", "TFSwinModel"),
+ ("t5", "TFT5Model"),
+ ("tapas", "TFTapasModel"),
+ ("transfo-xl", "TFTransfoXLModel"),
+ ("vision-text-dual-encoder", "TFVisionTextDualEncoderModel"),
+ ("vit", "TFViTModel"),
+ ("vit_mae", "TFViTMAEModel"),
+ ("wav2vec2", "TFWav2Vec2Model"),
+ ("whisper", "TFWhisperModel"),
+ ("xglm", "TFXGLMModel"),
+ ("xlm", "TFXLMModel"),
+ ("xlm-roberta", "TFXLMRobertaModel"),
+ ("xlnet", "TFXLNetModel"),
+ ]
+)
+
+TF_MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
+ [
+ # Model for pre-training mapping
+ ("albert", "TFAlbertForPreTraining"),
+ ("bart", "TFBartForConditionalGeneration"),
+ ("bert", "TFBertForPreTraining"),
+ ("camembert", "TFCamembertForMaskedLM"),
+ ("ctrl", "TFCTRLLMHeadModel"),
+ ("distilbert", "TFDistilBertForMaskedLM"),
+ ("electra", "TFElectraForPreTraining"),
+ ("flaubert", "TFFlaubertWithLMHeadModel"),
+ ("funnel", "TFFunnelForPreTraining"),
+ ("gpt-sw3", "TFGPT2LMHeadModel"),
+ ("gpt2", "TFGPT2LMHeadModel"),
+ ("idefics", "TFIdeficsForVisionText2Text"),
+ ("layoutlm", "TFLayoutLMForMaskedLM"),
+ ("lxmert", "TFLxmertForPreTraining"),
+ ("mobilebert", "TFMobileBertForPreTraining"),
+ ("mpnet", "TFMPNetForMaskedLM"),
+ ("openai-gpt", "TFOpenAIGPTLMHeadModel"),
+ ("roberta", "TFRobertaForMaskedLM"),
+ ("roberta-prelayernorm", "TFRobertaPreLayerNormForMaskedLM"),
+ ("t5", "TFT5ForConditionalGeneration"),
+ ("tapas", "TFTapasForMaskedLM"),
+ ("transfo-xl", "TFTransfoXLLMHeadModel"),
+ ("vit_mae", "TFViTMAEForPreTraining"),
+ ("xlm", "TFXLMWithLMHeadModel"),
+ ("xlm-roberta", "TFXLMRobertaForMaskedLM"),
+ ("xlnet", "TFXLNetLMHeadModel"),
+ ]
+)
+
+TF_MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
+ [
+ # Model with LM heads mapping
+ ("albert", "TFAlbertForMaskedLM"),
+ ("bart", "TFBartForConditionalGeneration"),
+ ("bert", "TFBertForMaskedLM"),
+ ("camembert", "TFCamembertForMaskedLM"),
+ ("convbert", "TFConvBertForMaskedLM"),
+ ("ctrl", "TFCTRLLMHeadModel"),
+ ("distilbert", "TFDistilBertForMaskedLM"),
+ ("electra", "TFElectraForMaskedLM"),
+ ("esm", "TFEsmForMaskedLM"),
+ ("flaubert", "TFFlaubertWithLMHeadModel"),
+ ("funnel", "TFFunnelForMaskedLM"),
+ ("gpt-sw3", "TFGPT2LMHeadModel"),
+ ("gpt2", "TFGPT2LMHeadModel"),
+ ("gptj", "TFGPTJForCausalLM"),
+ ("layoutlm", "TFLayoutLMForMaskedLM"),
+ ("led", "TFLEDForConditionalGeneration"),
+ ("longformer", "TFLongformerForMaskedLM"),
+ ("marian", "TFMarianMTModel"),
+ ("mobilebert", "TFMobileBertForMaskedLM"),
+ ("mpnet", "TFMPNetForMaskedLM"),
+ ("openai-gpt", "TFOpenAIGPTLMHeadModel"),
+ ("rembert", "TFRemBertForMaskedLM"),
+ ("roberta", "TFRobertaForMaskedLM"),
+ ("roberta-prelayernorm", "TFRobertaPreLayerNormForMaskedLM"),
+ ("roformer", "TFRoFormerForMaskedLM"),
+ ("speech_to_text", "TFSpeech2TextForConditionalGeneration"),
+ ("t5", "TFT5ForConditionalGeneration"),
+ ("tapas", "TFTapasForMaskedLM"),
+ ("transfo-xl", "TFTransfoXLLMHeadModel"),
+ ("whisper", "TFWhisperForConditionalGeneration"),
+ ("xlm", "TFXLMWithLMHeadModel"),
+ ("xlm-roberta", "TFXLMRobertaForMaskedLM"),
+ ("xlnet", "TFXLNetLMHeadModel"),
+ ]
+)
+
+TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
+ [
+ # Model for Causal LM mapping
+ ("bert", "TFBertLMHeadModel"),
+ ("camembert", "TFCamembertForCausalLM"),
+ ("ctrl", "TFCTRLLMHeadModel"),
+ ("gpt-sw3", "TFGPT2LMHeadModel"),
+ ("gpt2", "TFGPT2LMHeadModel"),
+ ("gptj", "TFGPTJForCausalLM"),
+ ("mistral", "TFMistralForCausalLM"),
+ ("openai-gpt", "TFOpenAIGPTLMHeadModel"),
+ ("opt", "TFOPTForCausalLM"),
+ ("rembert", "TFRemBertForCausalLM"),
+ ("roberta", "TFRobertaForCausalLM"),
+ ("roberta-prelayernorm", "TFRobertaPreLayerNormForCausalLM"),
+ ("roformer", "TFRoFormerForCausalLM"),
+ ("transfo-xl", "TFTransfoXLLMHeadModel"),
+ ("xglm", "TFXGLMForCausalLM"),
+ ("xlm", "TFXLMWithLMHeadModel"),
+ ("xlm-roberta", "TFXLMRobertaForCausalLM"),
+ ("xlnet", "TFXLNetLMHeadModel"),
+ ]
+)
+
+TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES = OrderedDict(
+ [
+ ("deit", "TFDeiTForMaskedImageModeling"),
+ ("swin", "TFSwinForMaskedImageModeling"),
+ ]
+)
+
+TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
+ [
+ # Model for Image Classification mapping
+ ("convnext", "TFConvNextForImageClassification"),
+ ("convnextv2", "TFConvNextV2ForImageClassification"),
+ ("cvt", "TFCvtForImageClassification"),
+ ("data2vec-vision", "TFData2VecVisionForImageClassification"),
+ ("deit", ("TFDeiTForImageClassification", "TFDeiTForImageClassificationWithTeacher")),
+ (
+ "efficientformer",
+ ("TFEfficientFormerForImageClassification", "TFEfficientFormerForImageClassificationWithTeacher"),
+ ),
+ ("mobilevit", "TFMobileViTForImageClassification"),
+ ("regnet", "TFRegNetForImageClassification"),
+ ("resnet", "TFResNetForImageClassification"),
+ ("segformer", "TFSegformerForImageClassification"),
+ ("swiftformer", "TFSwiftFormerForImageClassification"),
+ ("swin", "TFSwinForImageClassification"),
+ ("vit", "TFViTForImageClassification"),
+ ]
+)
+
+
+TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
+ [
+ # Model for Zero Shot Image Classification mapping
+ ("blip", "TFBlipModel"),
+ ("clip", "TFCLIPModel"),
+ ]
+)
+
+
+TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES = OrderedDict(
+ [
+ # Model for Semantic Segmentation mapping
+ ("data2vec-vision", "TFData2VecVisionForSemanticSegmentation"),
+ ("mobilevit", "TFMobileViTForSemanticSegmentation"),
+ ("segformer", "TFSegformerForSemanticSegmentation"),
+ ]
+)
+
+TF_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict(
+ [
+ ("blip", "TFBlipForConditionalGeneration"),
+ ("vision-encoder-decoder", "TFVisionEncoderDecoderModel"),
+ ]
+)
+
+TF_MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
+ [
+ # Model for Masked LM mapping
+ ("albert", "TFAlbertForMaskedLM"),
+ ("bert", "TFBertForMaskedLM"),
+ ("camembert", "TFCamembertForMaskedLM"),
+ ("convbert", "TFConvBertForMaskedLM"),
+ ("deberta", "TFDebertaForMaskedLM"),
+ ("deberta-v2", "TFDebertaV2ForMaskedLM"),
+ ("distilbert", "TFDistilBertForMaskedLM"),
+ ("electra", "TFElectraForMaskedLM"),
+ ("esm", "TFEsmForMaskedLM"),
+ ("flaubert", "TFFlaubertWithLMHeadModel"),
+ ("funnel", "TFFunnelForMaskedLM"),
+ ("layoutlm", "TFLayoutLMForMaskedLM"),
+ ("longformer", "TFLongformerForMaskedLM"),
+ ("mobilebert", "TFMobileBertForMaskedLM"),
+ ("mpnet", "TFMPNetForMaskedLM"),
+ ("rembert", "TFRemBertForMaskedLM"),
+ ("roberta", "TFRobertaForMaskedLM"),
+ ("roberta-prelayernorm", "TFRobertaPreLayerNormForMaskedLM"),
+ ("roformer", "TFRoFormerForMaskedLM"),
+ ("tapas", "TFTapasForMaskedLM"),
+ ("xlm", "TFXLMWithLMHeadModel"),
+ ("xlm-roberta", "TFXLMRobertaForMaskedLM"),
+ ]
+)
+
+TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
+ [
+ # Model for Seq2Seq Causal LM mapping
+ ("bart", "TFBartForConditionalGeneration"),
+ ("blenderbot", "TFBlenderbotForConditionalGeneration"),
+ ("blenderbot-small", "TFBlenderbotSmallForConditionalGeneration"),
+ ("encoder-decoder", "TFEncoderDecoderModel"),
+ ("led", "TFLEDForConditionalGeneration"),
+ ("marian", "TFMarianMTModel"),
+ ("mbart", "TFMBartForConditionalGeneration"),
+ ("mt5", "TFMT5ForConditionalGeneration"),
+ ("pegasus", "TFPegasusForConditionalGeneration"),
+ ("t5", "TFT5ForConditionalGeneration"),
+ ]
+)
+
+TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict(
+ [
+ ("speech_to_text", "TFSpeech2TextForConditionalGeneration"),
+ ("whisper", "TFWhisperForConditionalGeneration"),
+ ]
+)
+
+TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
+ [
+ # Model for Sequence Classification mapping
+ ("albert", "TFAlbertForSequenceClassification"),
+ ("bart", "TFBartForSequenceClassification"),
+ ("bert", "TFBertForSequenceClassification"),
+ ("camembert", "TFCamembertForSequenceClassification"),
+ ("convbert", "TFConvBertForSequenceClassification"),
+ ("ctrl", "TFCTRLForSequenceClassification"),
+ ("deberta", "TFDebertaForSequenceClassification"),
+ ("deberta-v2", "TFDebertaV2ForSequenceClassification"),
+ ("distilbert", "TFDistilBertForSequenceClassification"),
+ ("electra", "TFElectraForSequenceClassification"),
+ ("esm", "TFEsmForSequenceClassification"),
+ ("flaubert", "TFFlaubertForSequenceClassification"),
+ ("funnel", "TFFunnelForSequenceClassification"),
+ ("gpt-sw3", "TFGPT2ForSequenceClassification"),
+ ("gpt2", "TFGPT2ForSequenceClassification"),
+ ("gptj", "TFGPTJForSequenceClassification"),
+ ("layoutlm", "TFLayoutLMForSequenceClassification"),
+ ("layoutlmv3", "TFLayoutLMv3ForSequenceClassification"),
+ ("longformer", "TFLongformerForSequenceClassification"),
+ ("mistral", "TFMistralForSequenceClassification"),
+ ("mobilebert", "TFMobileBertForSequenceClassification"),
+ ("mpnet", "TFMPNetForSequenceClassification"),
+ ("openai-gpt", "TFOpenAIGPTForSequenceClassification"),
+ ("rembert", "TFRemBertForSequenceClassification"),
+ ("roberta", "TFRobertaForSequenceClassification"),
+ ("roberta-prelayernorm", "TFRobertaPreLayerNormForSequenceClassification"),
+ ("roformer", "TFRoFormerForSequenceClassification"),
+ ("tapas", "TFTapasForSequenceClassification"),
+ ("transfo-xl", "TFTransfoXLForSequenceClassification"),
+ ("xlm", "TFXLMForSequenceClassification"),
+ ("xlm-roberta", "TFXLMRobertaForSequenceClassification"),
+ ("xlnet", "TFXLNetForSequenceClassification"),
+ ]
+)
+
+TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
+ [
+ # Model for Question Answering mapping
+ ("albert", "TFAlbertForQuestionAnswering"),
+ ("bert", "TFBertForQuestionAnswering"),
+ ("camembert", "TFCamembertForQuestionAnswering"),
+ ("convbert", "TFConvBertForQuestionAnswering"),
+ ("deberta", "TFDebertaForQuestionAnswering"),
+ ("deberta-v2", "TFDebertaV2ForQuestionAnswering"),
+ ("distilbert", "TFDistilBertForQuestionAnswering"),
+ ("electra", "TFElectraForQuestionAnswering"),
+ ("flaubert", "TFFlaubertForQuestionAnsweringSimple"),
+ ("funnel", "TFFunnelForQuestionAnswering"),
+ ("gptj", "TFGPTJForQuestionAnswering"),
+ ("layoutlmv3", "TFLayoutLMv3ForQuestionAnswering"),
+ ("longformer", "TFLongformerForQuestionAnswering"),
+ ("mobilebert", "TFMobileBertForQuestionAnswering"),
+ ("mpnet", "TFMPNetForQuestionAnswering"),
+ ("rembert", "TFRemBertForQuestionAnswering"),
+ ("roberta", "TFRobertaForQuestionAnswering"),
+ ("roberta-prelayernorm", "TFRobertaPreLayerNormForQuestionAnswering"),
+ ("roformer", "TFRoFormerForQuestionAnswering"),
+ ("xlm", "TFXLMForQuestionAnsweringSimple"),
+ ("xlm-roberta", "TFXLMRobertaForQuestionAnswering"),
+ ("xlnet", "TFXLNetForQuestionAnsweringSimple"),
+ ]
+)
+TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = OrderedDict([("wav2vec2", "TFWav2Vec2ForSequenceClassification")])
+
+TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
+ [
+ ("layoutlm", "TFLayoutLMForQuestionAnswering"),
+ ("layoutlmv3", "TFLayoutLMv3ForQuestionAnswering"),
+ ]
+)
+
+
+TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
+ [
+ # Model for Table Question Answering mapping
+ ("tapas", "TFTapasForQuestionAnswering"),
+ ]
+)
+
+TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
+ [
+ # Model for Token Classification mapping
+ ("albert", "TFAlbertForTokenClassification"),
+ ("bert", "TFBertForTokenClassification"),
+ ("camembert", "TFCamembertForTokenClassification"),
+ ("convbert", "TFConvBertForTokenClassification"),
+ ("deberta", "TFDebertaForTokenClassification"),
+ ("deberta-v2", "TFDebertaV2ForTokenClassification"),
+ ("distilbert", "TFDistilBertForTokenClassification"),
+ ("electra", "TFElectraForTokenClassification"),
+ ("esm", "TFEsmForTokenClassification"),
+ ("flaubert", "TFFlaubertForTokenClassification"),
+ ("funnel", "TFFunnelForTokenClassification"),
+ ("layoutlm", "TFLayoutLMForTokenClassification"),
+ ("layoutlmv3", "TFLayoutLMv3ForTokenClassification"),
+ ("longformer", "TFLongformerForTokenClassification"),
+ ("mobilebert", "TFMobileBertForTokenClassification"),
+ ("mpnet", "TFMPNetForTokenClassification"),
+ ("rembert", "TFRemBertForTokenClassification"),
+ ("roberta", "TFRobertaForTokenClassification"),
+ ("roberta-prelayernorm", "TFRobertaPreLayerNormForTokenClassification"),
+ ("roformer", "TFRoFormerForTokenClassification"),
+ ("xlm", "TFXLMForTokenClassification"),
+ ("xlm-roberta", "TFXLMRobertaForTokenClassification"),
+ ("xlnet", "TFXLNetForTokenClassification"),
+ ]
+)
+
+TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict(
+ [
+ # Model for Multiple Choice mapping
+ ("albert", "TFAlbertForMultipleChoice"),
+ ("bert", "TFBertForMultipleChoice"),
+ ("camembert", "TFCamembertForMultipleChoice"),
+ ("convbert", "TFConvBertForMultipleChoice"),
+ ("deberta-v2", "TFDebertaV2ForMultipleChoice"),
+ ("distilbert", "TFDistilBertForMultipleChoice"),
+ ("electra", "TFElectraForMultipleChoice"),
+ ("flaubert", "TFFlaubertForMultipleChoice"),
+ ("funnel", "TFFunnelForMultipleChoice"),
+ ("longformer", "TFLongformerForMultipleChoice"),
+ ("mobilebert", "TFMobileBertForMultipleChoice"),
+ ("mpnet", "TFMPNetForMultipleChoice"),
+ ("rembert", "TFRemBertForMultipleChoice"),
+ ("roberta", "TFRobertaForMultipleChoice"),
+ ("roberta-prelayernorm", "TFRobertaPreLayerNormForMultipleChoice"),
+ ("roformer", "TFRoFormerForMultipleChoice"),
+ ("xlm", "TFXLMForMultipleChoice"),
+ ("xlm-roberta", "TFXLMRobertaForMultipleChoice"),
+ ("xlnet", "TFXLNetForMultipleChoice"),
+ ]
+)
+
+TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES = OrderedDict(
+ [
+ ("bert", "TFBertForNextSentencePrediction"),
+ ("mobilebert", "TFMobileBertForNextSentencePrediction"),
+ ]
+)
+TF_MODEL_FOR_MASK_GENERATION_MAPPING_NAMES = OrderedDict(
+ [
+ ("sam", "TFSamModel"),
+ ]
+)
+TF_MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES = OrderedDict(
+ [
+ ("albert", "TFAlbertModel"),
+ ("bert", "TFBertModel"),
+ ("convbert", "TFConvBertModel"),
+ ("deberta", "TFDebertaModel"),
+ ("deberta-v2", "TFDebertaV2Model"),
+ ("distilbert", "TFDistilBertModel"),
+ ("electra", "TFElectraModel"),
+ ("flaubert", "TFFlaubertModel"),
+ ("longformer", "TFLongformerModel"),
+ ("mobilebert", "TFMobileBertModel"),
+ ("mt5", "TFMT5EncoderModel"),
+ ("rembert", "TFRemBertModel"),
+ ("roberta", "TFRobertaModel"),
+ ("roberta-prelayernorm", "TFRobertaPreLayerNormModel"),
+ ("roformer", "TFRoFormerModel"),
+ ("t5", "TFT5EncoderModel"),
+ ("xlm", "TFXLMModel"),
+ ("xlm-roberta", "TFXLMRobertaModel"),
+ ]
+)
+
+TF_MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_MAPPING_NAMES)
+TF_MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_PRETRAINING_MAPPING_NAMES)
+TF_MODEL_WITH_LM_HEAD_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_WITH_LM_HEAD_MAPPING_NAMES)
+TF_MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES)
+TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING = _LazyAutoMapping(
+ CONFIG_MAPPING_NAMES, TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES
+)
+TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
+ CONFIG_MAPPING_NAMES, TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
+)
+TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
+ CONFIG_MAPPING_NAMES, TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES
+)
+TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING = _LazyAutoMapping(
+ CONFIG_MAPPING_NAMES, TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES
+)
+TF_MODEL_FOR_VISION_2_SEQ_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES)
+TF_MODEL_FOR_MASKED_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_MASKED_LM_MAPPING_NAMES)
+TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _LazyAutoMapping(
+ CONFIG_MAPPING_NAMES, TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
+)
+TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
+ CONFIG_MAPPING_NAMES, TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
+)
+TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING = _LazyAutoMapping(
+ CONFIG_MAPPING_NAMES, TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES
+)
+TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
+ CONFIG_MAPPING_NAMES, TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
+)
+TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
+ CONFIG_MAPPING_NAMES, TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES
+)
+TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
+ CONFIG_MAPPING_NAMES, TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES
+)
+TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = _LazyAutoMapping(
+ CONFIG_MAPPING_NAMES, TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES
+)
+TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING = _LazyAutoMapping(
+ CONFIG_MAPPING_NAMES, TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES
+)
+TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = _LazyAutoMapping(
+ CONFIG_MAPPING_NAMES, TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES
+)
+TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING = _LazyAutoMapping(
+ CONFIG_MAPPING_NAMES, TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES
+)
+
+TF_MODEL_FOR_MASK_GENERATION_MAPPING = _LazyAutoMapping(
+ CONFIG_MAPPING_NAMES, TF_MODEL_FOR_MASK_GENERATION_MAPPING_NAMES
+)
+
+TF_MODEL_FOR_TEXT_ENCODING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES)
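+# Note: each _LazyAutoMapping above behaves roughly like a dict from config class to TF model class,
+# but it only imports the concrete model module when an entry is actually looked up, keeping the
+# import of this module cheap.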
+
+
+class TFAutoModelForMaskGeneration(_BaseAutoModelClass):
+ _model_mapping = TF_MODEL_FOR_MASK_GENERATION_MAPPING
+
+
+class TFAutoModelForTextEncoding(_BaseAutoModelClass):
+ _model_mapping = TF_MODEL_FOR_TEXT_ENCODING_MAPPING
+
+
+class TFAutoModel(_BaseAutoModelClass):
+ _model_mapping = TF_MODEL_MAPPING
+
+
+TFAutoModel = auto_class_update(TFAutoModel)
+
+
+class TFAutoModelForAudioClassification(_BaseAutoModelClass):
+ _model_mapping = TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING
+
+
+TFAutoModelForAudioClassification = auto_class_update(
+ TFAutoModelForAudioClassification, head_doc="audio classification"
+)
+
+
+class TFAutoModelForPreTraining(_BaseAutoModelClass):
+ _model_mapping = TF_MODEL_FOR_PRETRAINING_MAPPING
+
+
+TFAutoModelForPreTraining = auto_class_update(TFAutoModelForPreTraining, head_doc="pretraining")
+
+
+# Private on purpose, the public class will add the deprecation warnings.
+class _TFAutoModelWithLMHead(_BaseAutoModelClass):
+ _model_mapping = TF_MODEL_WITH_LM_HEAD_MAPPING
+
+
+_TFAutoModelWithLMHead = auto_class_update(_TFAutoModelWithLMHead, head_doc="language modeling")
+
+
+class TFAutoModelForCausalLM(_BaseAutoModelClass):
+ _model_mapping = TF_MODEL_FOR_CAUSAL_LM_MAPPING
+
+
+TFAutoModelForCausalLM = auto_class_update(TFAutoModelForCausalLM, head_doc="causal language modeling")
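+# Minimal usage sketch (assuming TensorFlow is installed and the checkpoint is reachable): the auto
+# class reads the checkpoint config's `model_type` and dispatches to the concrete TF class registered
+# in TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, e.g.
+#
+#     from transformers import TFAutoModelForCausalLM
+#     model = TFAutoModelForCausalLM.from_pretrained("openai-community/gpt2")  # resolves to TFGPT2LMHeadModel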
+
+
+class TFAutoModelForMaskedImageModeling(_BaseAutoModelClass):
+ _model_mapping = TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING
+
+
+TFAutoModelForMaskedImageModeling = auto_class_update(
+ TFAutoModelForMaskedImageModeling, head_doc="masked image modeling"
+)
+
+
+class TFAutoModelForImageClassification(_BaseAutoModelClass):
+ _model_mapping = TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
+
+
+TFAutoModelForImageClassification = auto_class_update(
+ TFAutoModelForImageClassification, head_doc="image classification"
+)
+
+
+class TFAutoModelForZeroShotImageClassification(_BaseAutoModelClass):
+ _model_mapping = TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING
+
+
+TFAutoModelForZeroShotImageClassification = auto_class_update(
+ TFAutoModelForZeroShotImageClassification, head_doc="zero-shot image classification"
+)
+
+
+class TFAutoModelForSemanticSegmentation(_BaseAutoModelClass):
+ _model_mapping = TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING
+
+
+TFAutoModelForSemanticSegmentation = auto_class_update(
+ TFAutoModelForSemanticSegmentation, head_doc="semantic segmentation"
+)
+
+
+class TFAutoModelForVision2Seq(_BaseAutoModelClass):
+ _model_mapping = TF_MODEL_FOR_VISION_2_SEQ_MAPPING
+
+
+TFAutoModelForVision2Seq = auto_class_update(TFAutoModelForVision2Seq, head_doc="vision-to-text modeling")
+
+
+class TFAutoModelForMaskedLM(_BaseAutoModelClass):
+ _model_mapping = TF_MODEL_FOR_MASKED_LM_MAPPING
+
+
+TFAutoModelForMaskedLM = auto_class_update(TFAutoModelForMaskedLM, head_doc="masked language modeling")
+
+
+class TFAutoModelForSeq2SeqLM(_BaseAutoModelClass):
+ _model_mapping = TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
+
+
+TFAutoModelForSeq2SeqLM = auto_class_update(
+ TFAutoModelForSeq2SeqLM,
+ head_doc="sequence-to-sequence language modeling",
+ checkpoint_for_example="google-t5/t5-base",
+)
+
+
+class TFAutoModelForSequenceClassification(_BaseAutoModelClass):
+ _model_mapping = TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
+
+
+TFAutoModelForSequenceClassification = auto_class_update(
+ TFAutoModelForSequenceClassification, head_doc="sequence classification"
+)
+
+
+class TFAutoModelForQuestionAnswering(_BaseAutoModelClass):
+ _model_mapping = TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING
+
+
+TFAutoModelForQuestionAnswering = auto_class_update(TFAutoModelForQuestionAnswering, head_doc="question answering")
+
+
+class TFAutoModelForDocumentQuestionAnswering(_BaseAutoModelClass):
+ _model_mapping = TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING
+
+
+TFAutoModelForDocumentQuestionAnswering = auto_class_update(
+ TFAutoModelForDocumentQuestionAnswering,
+ head_doc="document question answering",
+ checkpoint_for_example='impira/layoutlm-document-qa", revision="52e01b3',
+)
+
+
+class TFAutoModelForTableQuestionAnswering(_BaseAutoModelClass):
+ _model_mapping = TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING
+
+
+TFAutoModelForTableQuestionAnswering = auto_class_update(
+ TFAutoModelForTableQuestionAnswering,
+ head_doc="table question answering",
+ checkpoint_for_example="google/tapas-base-finetuned-wtq",
+)
+
+
+class TFAutoModelForTokenClassification(_BaseAutoModelClass):
+ _model_mapping = TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
+
+
+TFAutoModelForTokenClassification = auto_class_update(
+ TFAutoModelForTokenClassification, head_doc="token classification"
+)
+
+
+class TFAutoModelForMultipleChoice(_BaseAutoModelClass):
+ _model_mapping = TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING
+
+
+TFAutoModelForMultipleChoice = auto_class_update(TFAutoModelForMultipleChoice, head_doc="multiple choice")
+
+
+class TFAutoModelForNextSentencePrediction(_BaseAutoModelClass):
+ _model_mapping = TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING
+
+
+TFAutoModelForNextSentencePrediction = auto_class_update(
+ TFAutoModelForNextSentencePrediction, head_doc="next sentence prediction"
+)
+
+
+class TFAutoModelForSpeechSeq2Seq(_BaseAutoModelClass):
+ _model_mapping = TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING
+
+
+TFAutoModelForSpeechSeq2Seq = auto_class_update(
+ TFAutoModelForSpeechSeq2Seq, head_doc="sequence-to-sequence speech-to-text modeling"
+)
+
+
+class TFAutoModelWithLMHead(_TFAutoModelWithLMHead):
+ @classmethod
+ def from_config(cls, config):
+ warnings.warn(
+ "The class `TFAutoModelWithLMHead` is deprecated and will be removed in a future version. Please use"
+ " `TFAutoModelForCausalLM` for causal language models, `TFAutoModelForMaskedLM` for masked language models"
+ " and `TFAutoModelForSeq2SeqLM` for encoder-decoder models.",
+ FutureWarning,
+ )
+ return super().from_config(config)
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+ warnings.warn(
+ "The class `TFAutoModelWithLMHead` is deprecated and will be removed in a future version. Please use"
+ " `TFAutoModelForCausalLM` for causal language models, `TFAutoModelForMaskedLM` for masked language models"
+ " and `TFAutoModelForSeq2SeqLM` for encoder-decoder models.",
+ FutureWarning,
+ )
+ return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
+
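+# Migration sketch for the deprecated class above (checkpoint names are illustrative):
+#
+#     # before: model = TFAutoModelWithLMHead.from_pretrained("openai-community/gpt2")
+#     # after:  model = TFAutoModelForCausalLM.from_pretrained("openai-community/gpt2")          # decoder-only
+#     # after:  model = TFAutoModelForMaskedLM.from_pretrained("google-bert/bert-base-uncased")  # encoder-only
+#     # after:  model = TFAutoModelForSeq2SeqLM.from_pretrained("google-t5/t5-base")             # encoder-decoder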
+
+__all__ = [
+ "TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING",
+ "TF_MODEL_FOR_CAUSAL_LM_MAPPING",
+ "TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING",
+ "TF_MODEL_FOR_MASK_GENERATION_MAPPING",
+ "TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING",
+ "TF_MODEL_FOR_MASKED_LM_MAPPING",
+ "TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING",
+ "TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING",
+ "TF_MODEL_FOR_PRETRAINING_MAPPING",
+ "TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING",
+ "TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING",
+ "TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING",
+ "TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING",
+ "TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING",
+ "TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING",
+ "TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING",
+ "TF_MODEL_FOR_TEXT_ENCODING_MAPPING",
+ "TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
+ "TF_MODEL_FOR_VISION_2_SEQ_MAPPING",
+ "TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING",
+ "TF_MODEL_MAPPING",
+ "TF_MODEL_WITH_LM_HEAD_MAPPING",
+ "TFAutoModel",
+ "TFAutoModelForAudioClassification",
+ "TFAutoModelForCausalLM",
+ "TFAutoModelForImageClassification",
+ "TFAutoModelForMaskedImageModeling",
+ "TFAutoModelForMaskedLM",
+ "TFAutoModelForMaskGeneration",
+ "TFAutoModelForMultipleChoice",
+ "TFAutoModelForNextSentencePrediction",
+ "TFAutoModelForPreTraining",
+ "TFAutoModelForDocumentQuestionAnswering",
+ "TFAutoModelForQuestionAnswering",
+ "TFAutoModelForSemanticSegmentation",
+ "TFAutoModelForSeq2SeqLM",
+ "TFAutoModelForSequenceClassification",
+ "TFAutoModelForSpeechSeq2Seq",
+ "TFAutoModelForTableQuestionAnswering",
+ "TFAutoModelForTextEncoding",
+ "TFAutoModelForTokenClassification",
+ "TFAutoModelForVision2Seq",
+ "TFAutoModelForZeroShotImageClassification",
+ "TFAutoModelWithLMHead",
+]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/auto/processing_auto.py b/venv/lib/python3.13/site-packages/transformers/models/auto/processing_auto.py
new file mode 100644
index 0000000000000000000000000000000000000000..11862a5896b94be7d1d247f022026bf64088987d
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/auto/processing_auto.py
@@ -0,0 +1,443 @@
+# coding=utf-8
+# Copyright 2021 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""AutoProcessor class."""
+
+import importlib
+import inspect
+import json
+import warnings
+from collections import OrderedDict
+
+# Build the list of all processors
+from ...configuration_utils import PretrainedConfig
+from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
+from ...feature_extraction_utils import FeatureExtractionMixin
+from ...image_processing_utils import ImageProcessingMixin
+from ...processing_utils import ProcessorMixin
+from ...tokenization_utils import TOKENIZER_CONFIG_FILE
+from ...utils import FEATURE_EXTRACTOR_NAME, PROCESSOR_NAME, VIDEO_PROCESSOR_NAME, cached_file, logging
+from ...video_processing_utils import BaseVideoProcessor
+from .auto_factory import _LazyAutoMapping
+from .configuration_auto import (
+ CONFIG_MAPPING_NAMES,
+ AutoConfig,
+ model_type_to_module_name,
+ replace_list_option_in_docstrings,
+)
+from .feature_extraction_auto import AutoFeatureExtractor
+from .image_processing_auto import AutoImageProcessor
+from .tokenization_auto import AutoTokenizer
+
+
+logger = logging.get_logger(__name__)
+
+PROCESSOR_MAPPING_NAMES = OrderedDict(
+ [
+ ("aimv2", "CLIPProcessor"),
+ ("align", "AlignProcessor"),
+ ("altclip", "AltCLIPProcessor"),
+ ("aria", "AriaProcessor"),
+ ("aya_vision", "AyaVisionProcessor"),
+ ("bark", "BarkProcessor"),
+ ("blip", "BlipProcessor"),
+ ("blip-2", "Blip2Processor"),
+ ("bridgetower", "BridgeTowerProcessor"),
+ ("chameleon", "ChameleonProcessor"),
+ ("chinese_clip", "ChineseCLIPProcessor"),
+ ("clap", "ClapProcessor"),
+ ("clip", "CLIPProcessor"),
+ ("clipseg", "CLIPSegProcessor"),
+ ("clvp", "ClvpProcessor"),
+ ("cohere2_vision", "Cohere2VisionProcessor"),
+ ("colpali", "ColPaliProcessor"),
+ ("colqwen2", "ColQwen2Processor"),
+ ("deepseek_vl", "DeepseekVLProcessor"),
+ ("deepseek_vl_hybrid", "DeepseekVLHybridProcessor"),
+ ("dia", "DiaProcessor"),
+ ("edgetam", "Sam2Processor"),
+ ("emu3", "Emu3Processor"),
+ ("evolla", "EvollaProcessor"),
+ ("flava", "FlavaProcessor"),
+ ("florence2", "Florence2Processor"),
+ ("fuyu", "FuyuProcessor"),
+ ("gemma3", "Gemma3Processor"),
+ ("gemma3n", "Gemma3nProcessor"),
+ ("git", "GitProcessor"),
+ ("glm4v", "Glm4vProcessor"),
+ ("glm4v_moe", "Glm4vProcessor"),
+ ("got_ocr2", "GotOcr2Processor"),
+ ("granite_speech", "GraniteSpeechProcessor"),
+ ("grounding-dino", "GroundingDinoProcessor"),
+ ("groupvit", "CLIPProcessor"),
+ ("hubert", "Wav2Vec2Processor"),
+ ("idefics", "IdeficsProcessor"),
+ ("idefics2", "Idefics2Processor"),
+ ("idefics3", "Idefics3Processor"),
+ ("instructblip", "InstructBlipProcessor"),
+ ("instructblipvideo", "InstructBlipVideoProcessor"),
+ ("internvl", "InternVLProcessor"),
+ ("janus", "JanusProcessor"),
+ ("kosmos-2", "Kosmos2Processor"),
+ ("kosmos-2.5", "Kosmos2_5Processor"),
+ ("kyutai_speech_to_text", "KyutaiSpeechToTextProcessor"),
+ ("layoutlmv2", "LayoutLMv2Processor"),
+ ("layoutlmv3", "LayoutLMv3Processor"),
+ ("lfm2_vl", "Lfm2VlProcessor"),
+ ("llama4", "Llama4Processor"),
+ ("llava", "LlavaProcessor"),
+ ("llava_next", "LlavaNextProcessor"),
+ ("llava_next_video", "LlavaNextVideoProcessor"),
+ ("llava_onevision", "LlavaOnevisionProcessor"),
+ ("markuplm", "MarkupLMProcessor"),
+ ("mctct", "MCTCTProcessor"),
+ ("metaclip_2", "CLIPProcessor"),
+ ("mgp-str", "MgpstrProcessor"),
+ ("mistral3", "PixtralProcessor"),
+ ("mllama", "MllamaProcessor"),
+ ("mm-grounding-dino", "GroundingDinoProcessor"),
+ ("moonshine", "Wav2Vec2Processor"),
+ ("oneformer", "OneFormerProcessor"),
+ ("ovis2", "Ovis2Processor"),
+ ("owlv2", "Owlv2Processor"),
+ ("owlvit", "OwlViTProcessor"),
+ ("paligemma", "PaliGemmaProcessor"),
+ ("perception_lm", "PerceptionLMProcessor"),
+ ("phi4_multimodal", "Phi4MultimodalProcessor"),
+ ("pix2struct", "Pix2StructProcessor"),
+ ("pixtral", "PixtralProcessor"),
+ ("pop2piano", "Pop2PianoProcessor"),
+ ("qwen2_5_omni", "Qwen2_5OmniProcessor"),
+ ("qwen2_5_vl", "Qwen2_5_VLProcessor"),
+ ("qwen2_audio", "Qwen2AudioProcessor"),
+ ("qwen2_vl", "Qwen2VLProcessor"),
+ ("qwen3_omni_moe", "Qwen3OmniMoeProcessor"),
+ ("qwen3_vl", "Qwen3VLProcessor"),
+ ("qwen3_vl_moe", "Qwen3VLProcessor"),
+ ("sam", "SamProcessor"),
+ ("sam2", "Sam2Processor"),
+ ("sam_hq", "SamHQProcessor"),
+ ("seamless_m4t", "SeamlessM4TProcessor"),
+ ("sew", "Wav2Vec2Processor"),
+ ("sew-d", "Wav2Vec2Processor"),
+ ("shieldgemma2", "ShieldGemma2Processor"),
+ ("siglip", "SiglipProcessor"),
+ ("siglip2", "Siglip2Processor"),
+ ("smolvlm", "SmolVLMProcessor"),
+ ("speech_to_text", "Speech2TextProcessor"),
+ ("speech_to_text_2", "Speech2Text2Processor"),
+ ("speecht5", "SpeechT5Processor"),
+ ("trocr", "TrOCRProcessor"),
+ ("tvlt", "TvltProcessor"),
+ ("tvp", "TvpProcessor"),
+ ("udop", "UdopProcessor"),
+ ("unispeech", "Wav2Vec2Processor"),
+ ("unispeech-sat", "Wav2Vec2Processor"),
+ ("video_llava", "VideoLlavaProcessor"),
+ ("vilt", "ViltProcessor"),
+ ("vipllava", "LlavaProcessor"),
+ ("vision-text-dual-encoder", "VisionTextDualEncoderProcessor"),
+ ("voxtral", "VoxtralProcessor"),
+ ("wav2vec2", "Wav2Vec2Processor"),
+ ("wav2vec2-bert", "Wav2Vec2Processor"),
+ ("wav2vec2-conformer", "Wav2Vec2Processor"),
+ ("wavlm", "Wav2Vec2Processor"),
+ ("whisper", "WhisperProcessor"),
+ ("xclip", "XCLIPProcessor"),
+ ]
+)
+
+PROCESSOR_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, PROCESSOR_MAPPING_NAMES)
+
+
+def processor_class_from_name(class_name: str):
+ for module_name, processors in PROCESSOR_MAPPING_NAMES.items():
+ if class_name in processors:
+ module_name = model_type_to_module_name(module_name)
+
+ module = importlib.import_module(f".{module_name}", "transformers.models")
+ try:
+ return getattr(module, class_name)
+ except AttributeError:
+ continue
+
+ for processor in PROCESSOR_MAPPING._extra_content.values():
+ if getattr(processor, "__name__", None) == class_name:
+ return processor
+
+ # We did not find the class, but maybe it's because a dep is missing. In that case, the class will be in the main
+ # init and we return the proper dummy to get an appropriate error message.
+ main_module = importlib.import_module("transformers")
+ if hasattr(main_module, class_name):
+ return getattr(main_module, class_name)
+
+ return None
+
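+# Example of the lookup above (assuming the CLIP processing dependencies are installed):
+#     processor_class_from_name("CLIPProcessor")    # -> transformers.models.clip.CLIPProcessor
+#     processor_class_from_name("NoSuchProcessor")  # -> None
+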
+
+class AutoProcessor:
+ r"""
+ This is a generic processor class that will be instantiated as one of the processor classes of the library when
+ created with the [`AutoProcessor.from_pretrained`] class method.
+
+ This class cannot be instantiated directly using `__init__()` (throws an error).
+ """
+
+ def __init__(self):
+ raise OSError(
+ "AutoProcessor is designed to be instantiated "
+ "using the `AutoProcessor.from_pretrained(pretrained_model_name_or_path)` method."
+ )
+
+ @classmethod
+ @replace_list_option_in_docstrings(PROCESSOR_MAPPING_NAMES)
+ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
+ r"""
+ Instantiate one of the processor classes of the library from a pretrained model vocabulary.
+
+ The processor class to instantiate is selected based on the `model_type` property of the config object (either
+ passed as an argument or loaded from `pretrained_model_name_or_path` if possible):
+
+ List options
+
+ Params:
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
+ This can be either:
+
+ - a string, the *model id* of a pretrained feature_extractor hosted inside a model repo on
+ huggingface.co.
+ - a path to a *directory* containing processor files saved using the `save_pretrained()` method,
+ e.g., `./my_model_directory/`.
+ cache_dir (`str` or `os.PathLike`, *optional*):
+ Path to a directory in which a downloaded pretrained model feature extractor should be cached if the
+ standard cache should not be used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force (re-)downloading the feature extractor files and to override the cached
+ versions if they exist.
+ resume_download:
+ Deprecated and ignored. All downloads are now resumed by default when possible.
+ Will be removed in v5 of Transformers.
+ proxies (`dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
+ token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
+ when running `hf auth login` (stored in `~/.huggingface`).
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+ identifier allowed by git.
+ return_unused_kwargs (`bool`, *optional*, defaults to `False`):
+ If `False`, then this function returns just the final feature extractor object. If `True`, then this
+ function returns a `Tuple(feature_extractor, unused_kwargs)` where *unused_kwargs* is a dictionary
+ consisting of the key/value pairs whose keys are not feature extractor attributes: i.e., the part of
+ `kwargs` which has not been used to update `feature_extractor` and is otherwise ignored.
+ trust_remote_code (`bool`, *optional*, defaults to `False`):
+ Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
+ should only be set to `True` for repositories you trust and in which you have read the code, as it will
+ execute code present on the Hub on your local machine.
+ kwargs (`dict[str, Any]`, *optional*):
+ The values in kwargs of any keys which are feature extractor attributes will be used to override the
+ loaded values. Behavior concerning key/value pairs whose keys are *not* feature extractor attributes is
+ controlled by the `return_unused_kwargs` keyword parameter.
+
+ <Tip>
+
+ Passing `token=True` is required when you want to use a private model.
+
+ </Tip>
+
+ Examples:
+
+ ```python
+ >>> from transformers import AutoProcessor
+
+ >>> # Download processor from huggingface.co and cache.
+ >>> processor = AutoProcessor.from_pretrained("facebook/wav2vec2-base-960h")
+
+ >>> # If processor files are in a directory (e.g. processor was saved using *save_pretrained('./test/saved_model/')*)
+ >>> # processor = AutoProcessor.from_pretrained("./test/saved_model/")
+ ```"""
+ use_auth_token = kwargs.pop("use_auth_token", None)
+ if use_auth_token is not None:
+ warnings.warn(
+ "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
+ FutureWarning,
+ )
+ if kwargs.get("token") is not None:
+ raise ValueError(
+ "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
+ )
+ kwargs["token"] = use_auth_token
+
+ config = kwargs.pop("config", None)
+ trust_remote_code = kwargs.pop("trust_remote_code", None)
+ kwargs["_from_auto"] = True
+
+ processor_class = None
+ processor_auto_map = None
+
+ # First, let's see if we have a processor or preprocessor config.
+ # Filter the kwargs for `cached_file`.
+ cached_file_kwargs = {key: kwargs[key] for key in inspect.signature(cached_file).parameters if key in kwargs}
+ # We don't want to raise exceptions if one of these files is missing, gated, or unreachable.
+ cached_file_kwargs.update(
+ {
+ "_raise_exceptions_for_gated_repo": False,
+ "_raise_exceptions_for_missing_entries": False,
+ "_raise_exceptions_for_connection_errors": False,
+ }
+ )
+
+ # Let's start by checking whether the processor class is saved in a processor config
+ processor_config_file = cached_file(pretrained_model_name_or_path, PROCESSOR_NAME, **cached_file_kwargs)
+ if processor_config_file is not None:
+ config_dict, _ = ProcessorMixin.get_processor_dict(pretrained_model_name_or_path, **kwargs)
+ processor_class = config_dict.get("processor_class", None)
+ if "AutoProcessor" in config_dict.get("auto_map", {}):
+ processor_auto_map = config_dict["auto_map"]["AutoProcessor"]
+
+ if processor_class is None:
+ # If not found, let's check whether the processor class is saved in an image processor config
+ preprocessor_config_file = cached_file(
+ pretrained_model_name_or_path, FEATURE_EXTRACTOR_NAME, **cached_file_kwargs
+ )
+ if preprocessor_config_file is not None:
+ config_dict, _ = ImageProcessingMixin.get_image_processor_dict(pretrained_model_name_or_path, **kwargs)
+ processor_class = config_dict.get("processor_class", None)
+ if "AutoProcessor" in config_dict.get("auto_map", {}):
+ processor_auto_map = config_dict["auto_map"]["AutoProcessor"]
+
+ # Saved as video processor
+ if preprocessor_config_file is None:
+ preprocessor_config_file = cached_file(
+ pretrained_model_name_or_path, VIDEO_PROCESSOR_NAME, **cached_file_kwargs
+ )
+ if preprocessor_config_file is not None:
+ config_dict, _ = BaseVideoProcessor.get_video_processor_dict(
+ pretrained_model_name_or_path, **kwargs
+ )
+ processor_class = config_dict.get("processor_class", None)
+ if "AutoProcessor" in config_dict.get("auto_map", {}):
+ processor_auto_map = config_dict["auto_map"]["AutoProcessor"]
+
+ # Saved as feature extractor
+ if preprocessor_config_file is None:
+ preprocessor_config_file = cached_file(
+ pretrained_model_name_or_path, FEATURE_EXTRACTOR_NAME, **cached_file_kwargs
+ )
+ if preprocessor_config_file is not None and processor_class is None:
+ config_dict, _ = FeatureExtractionMixin.get_feature_extractor_dict(
+ pretrained_model_name_or_path, **kwargs
+ )
+ processor_class = config_dict.get("processor_class", None)
+ if "AutoProcessor" in config_dict.get("auto_map", {}):
+ processor_auto_map = config_dict["auto_map"]["AutoProcessor"]
+
+ if processor_class is None:
+ # Next, let's check whether the processor class is saved in a tokenizer
+ tokenizer_config_file = cached_file(
+ pretrained_model_name_or_path, TOKENIZER_CONFIG_FILE, **cached_file_kwargs
+ )
+ if tokenizer_config_file is not None:
+ with open(tokenizer_config_file, encoding="utf-8") as reader:
+ config_dict = json.load(reader)
+
+ processor_class = config_dict.get("processor_class", None)
+ if "AutoProcessor" in config_dict.get("auto_map", {}):
+ processor_auto_map = config_dict["auto_map"]["AutoProcessor"]
+
+ if processor_class is None:
+ # Otherwise, load config, if it can be loaded.
+ if not isinstance(config, PretrainedConfig):
+ config = AutoConfig.from_pretrained(
+ pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
+ )
+
+ # And check if the config contains the processor class.
+ processor_class = getattr(config, "processor_class", None)
+ if hasattr(config, "auto_map") and "AutoProcessor" in config.auto_map:
+ processor_auto_map = config.auto_map["AutoProcessor"]
+
+ if processor_class is not None:
+ processor_class = processor_class_from_name(processor_class)
+
+ has_remote_code = processor_auto_map is not None
+ has_local_code = processor_class is not None or type(config) in PROCESSOR_MAPPING
+ if has_remote_code:
+ if "--" in processor_auto_map:
+ upstream_repo = processor_auto_map.split("--")[0]
+ else:
+ upstream_repo = None
+ trust_remote_code = resolve_trust_remote_code(
+ trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code, upstream_repo
+ )
+
+ if has_remote_code and trust_remote_code:
+ processor_class = get_class_from_dynamic_module(
+ processor_auto_map, pretrained_model_name_or_path, **kwargs
+ )
+ _ = kwargs.pop("code_revision", None)
+ processor_class.register_for_auto_class()
+ return processor_class.from_pretrained(
+ pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
+ )
+ elif processor_class is not None:
+ return processor_class.from_pretrained(
+ pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
+ )
+ # Last try: we use the PROCESSOR_MAPPING.
+ elif type(config) in PROCESSOR_MAPPING:
+ return PROCESSOR_MAPPING[type(config)].from_pretrained(pretrained_model_name_or_path, **kwargs)
+
+ # At this stage, there doesn't seem to be a `Processor` class available for this model, so let's try a
+ # tokenizer.
+ try:
+ return AutoTokenizer.from_pretrained(
+ pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
+ )
+ except Exception:
+ try:
+ return AutoImageProcessor.from_pretrained(
+ pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
+ )
+ except Exception:
+ pass
+
+ try:
+ return AutoFeatureExtractor.from_pretrained(
+ pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
+ )
+ except Exception:
+ pass
+
+ raise ValueError(
+ f"Unrecognized processing class in {pretrained_model_name_or_path}. Can't instantiate a processor, a "
+ "tokenizer, an image processor or a feature extractor for this model. Make sure the repository contains "
+ "the files of at least one of those processing classes."
+ )
+
+ @staticmethod
+ def register(config_class, processor_class, exist_ok=False):
+ """
+ Register a new processor for this class.
+
+ Args:
+ config_class ([`PretrainedConfig`]):
+ The configuration corresponding to the model to register.
+ processor_class ([`ProcessorMixin`]): The processor to register.
+ """
+ PROCESSOR_MAPPING.register(config_class, processor_class, exist_ok=exist_ok)
+
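+# Usage sketch for the hook above (MyConfig and MyProcessor are hypothetical user-defined classes):
+#
+#     AutoProcessor.register(MyConfig, MyProcessor)
+#     processor = AutoProcessor.from_pretrained("path/to/checkpoint-using-my-config")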
+
+__all__ = ["PROCESSOR_MAPPING", "AutoProcessor"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/auto/tokenization_auto.py b/venv/lib/python3.13/site-packages/transformers/models/auto/tokenization_auto.py
new file mode 100644
index 0000000000000000000000000000000000000000..163aba1cb12753ec7f9e487347b2a604df35fb28
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/auto/tokenization_auto.py
@@ -0,0 +1,1235 @@
+# coding=utf-8
+# Copyright 2018 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Auto Tokenizer class."""
+
+import importlib
+import json
+import os
+import warnings
+from collections import OrderedDict
+from typing import Any, Optional, Union
+
+from transformers.utils.import_utils import is_mistral_common_available
+
+from ...configuration_utils import PretrainedConfig
+from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
+from ...modeling_gguf_pytorch_utils import load_gguf_checkpoint
+from ...tokenization_utils import PreTrainedTokenizer
+from ...tokenization_utils_base import TOKENIZER_CONFIG_FILE
+from ...utils import (
+ cached_file,
+ extract_commit_hash,
+ is_g2p_en_available,
+ is_sentencepiece_available,
+ is_tokenizers_available,
+ logging,
+)
+from ..encoder_decoder import EncoderDecoderConfig
+from .auto_factory import _LazyAutoMapping
+from .configuration_auto import (
+ CONFIG_MAPPING_NAMES,
+ AutoConfig,
+ config_class_to_model_type,
+ model_type_to_module_name,
+ replace_list_option_in_docstrings,
+)
+
+
+if is_tokenizers_available():
+ from ...tokenization_utils_fast import PreTrainedTokenizerFast
+else:
+ PreTrainedTokenizerFast = None
+
+
+logger = logging.get_logger(__name__)
+
+# Explicit rather than inferred generics significantly improve completion suggestion performance for language servers.
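+# Each value below is a (slow_tokenizer_class_name, fast_tokenizer_class_name) pair; an entry is None
+# when no such tokenizer exists or its backend (sentencepiece / tokenizers) is not installed.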
+TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]](
+ [
+ (
+ "aimv2",
+ (
+ "CLIPTokenizer",
+ "CLIPTokenizerFast" if is_tokenizers_available() else None,
+ ),
+ ),
+ (
+ "albert",
+ (
+ "AlbertTokenizer" if is_sentencepiece_available() else None,
+ "AlbertTokenizerFast" if is_tokenizers_available() else None,
+ ),
+ ),
+ ("align", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
+ ("arcee", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
+ ("aria", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
+ ("aya_vision", (None, "CohereTokenizerFast" if is_tokenizers_available() else None)),
+ ("bark", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
+ ("bart", ("BartTokenizer", "BartTokenizerFast")),
+ (
+ "barthez",
+ (
+ "BarthezTokenizer" if is_sentencepiece_available() else None,
+ "BarthezTokenizerFast" if is_tokenizers_available() else None,
+ ),
+ ),
+ ("bartpho", ("BartphoTokenizer", None)),
+ ("bert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
+ ("bert-generation", ("BertGenerationTokenizer" if is_sentencepiece_available() else None, None)),
+ ("bert-japanese", ("BertJapaneseTokenizer", None)),
+ ("bertweet", ("BertweetTokenizer", None)),
+ (
+ "big_bird",
+ (
+ "BigBirdTokenizer" if is_sentencepiece_available() else None,
+ "BigBirdTokenizerFast" if is_tokenizers_available() else None,
+ ),
+ ),
+ ("bigbird_pegasus", ("PegasusTokenizer", "PegasusTokenizerFast" if is_tokenizers_available() else None)),
+ ("biogpt", ("BioGptTokenizer", None)),
+ ("bitnet", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
+ ("blenderbot", ("BlenderbotTokenizer", "BlenderbotTokenizerFast")),
+ ("blenderbot-small", ("BlenderbotSmallTokenizer", None)),
+ ("blip", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
+ ("blip-2", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
+ ("bloom", (None, "BloomTokenizerFast" if is_tokenizers_available() else None)),
+ ("blt", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
+ ("bridgetower", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
+ ("bros", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
+ ("byt5", ("ByT5Tokenizer", None)),
+ (
+ "camembert",
+ (
+ "CamembertTokenizer" if is_sentencepiece_available() else None,
+ "CamembertTokenizerFast" if is_tokenizers_available() else None,
+ ),
+ ),
+ ("canine", ("CanineTokenizer", None)),
+ (
+ "chameleon",
+ (
+ "LlamaTokenizer" if is_sentencepiece_available() else None,
+ "LlamaTokenizerFast" if is_tokenizers_available() else None,
+ ),
+ ),
+ ("chinese_clip", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
+ (
+ "clap",
+ (
+ "RobertaTokenizer",
+ "RobertaTokenizerFast" if is_tokenizers_available() else None,
+ ),
+ ),
+ (
+ "clip",
+ (
+ "CLIPTokenizer",
+ "CLIPTokenizerFast" if is_tokenizers_available() else None,
+ ),
+ ),
+ (
+ "clipseg",
+ (
+ "CLIPTokenizer",
+ "CLIPTokenizerFast" if is_tokenizers_available() else None,
+ ),
+ ),
+ ("clvp", ("ClvpTokenizer", None)),
+ (
+ "code_llama",
+ (
+ "CodeLlamaTokenizer" if is_sentencepiece_available() else None,
+ "CodeLlamaTokenizerFast" if is_tokenizers_available() else None,
+ ),
+ ),
+ ("codegen", ("CodeGenTokenizer", "CodeGenTokenizerFast" if is_tokenizers_available() else None)),
+ ("cohere", (None, "CohereTokenizerFast" if is_tokenizers_available() else None)),
+ ("cohere2", (None, "CohereTokenizerFast" if is_tokenizers_available() else None)),
+ ("colpali", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
+ ("colqwen2", ("Qwen2Tokenizer", "Qwen2TokenizerFast" if is_tokenizers_available() else None)),
+ ("convbert", ("ConvBertTokenizer", "ConvBertTokenizerFast" if is_tokenizers_available() else None)),
+ (
+ "cpm",
+ (
+ "CpmTokenizer" if is_sentencepiece_available() else None,
+ "CpmTokenizerFast" if is_tokenizers_available() else None,
+ ),
+ ),
+ ("cpmant", ("CpmAntTokenizer", None)),
+ ("csm", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
+ ("ctrl", ("CTRLTokenizer", None)),
+ ("data2vec-audio", ("Wav2Vec2CTCTokenizer", None)),
+ ("data2vec-text", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
+ ("dbrx", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
+ ("deberta", ("DebertaTokenizer", "DebertaTokenizerFast" if is_tokenizers_available() else None)),
+ (
+ "deberta-v2",
+ (
+ "DebertaV2Tokenizer" if is_sentencepiece_available() else None,
+ "DebertaV2TokenizerFast" if is_tokenizers_available() else None,
+ ),
+ ),
+ (
+ "deepseek_v2",
+ (
+ "LlamaTokenizer" if is_sentencepiece_available() else None,
+ "LlamaTokenizerFast" if is_tokenizers_available() else None,
+ ),
+ ),
+ (
+ "deepseek_v3",
+ (
+ "LlamaTokenizer" if is_sentencepiece_available() else None,
+ "LlamaTokenizerFast" if is_tokenizers_available() else None,
+ ),
+ ),
+ (
+ "deepseek_vl",
+ (
+ "LlamaTokenizer" if is_sentencepiece_available() else None,
+ "LlamaTokenizerFast" if is_tokenizers_available() else None,
+ ),
+ ),
+ (
+ "deepseek_vl_hybrid",
+ (
+ "LlamaTokenizer" if is_sentencepiece_available() else None,
+ "LlamaTokenizerFast" if is_tokenizers_available() else None,
+ ),
+ ),
+ ("dia", ("DiaTokenizer", None)),
+ (
+ "diffllama",
+ (
+ "LlamaTokenizer" if is_sentencepiece_available() else None,
+ "LlamaTokenizerFast" if is_tokenizers_available() else None,
+ ),
+ ),
+ ("distilbert", ("DistilBertTokenizer", "DistilBertTokenizerFast" if is_tokenizers_available() else None)),
+ (
+ "dpr",
+ (
+ "DPRQuestionEncoderTokenizer",
+ "DPRQuestionEncoderTokenizerFast" if is_tokenizers_available() else None,
+ ),
+ ),
+ ("electra", ("ElectraTokenizer", "ElectraTokenizerFast" if is_tokenizers_available() else None)),
+ ("emu3", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
+ ("ernie", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
+ ("ernie4_5", (None, "LlamaTokenizerFast" if is_tokenizers_available() else None)),
+ ("ernie4_5_moe", (None, "LlamaTokenizerFast" if is_tokenizers_available() else None)),
+ ("ernie_m", ("ErnieMTokenizer" if is_sentencepiece_available() else None, None)),
+ ("esm", ("EsmTokenizer", None)),
+ (
+ "exaone4",
+ (
+ "GPT2Tokenizer" if is_tokenizers_available() else None,
+ "GPT2TokenizerFast" if is_tokenizers_available() else None,
+ ),
+ ),
+ ("falcon", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
+ ("falcon_mamba", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
+ (
+ "fastspeech2_conformer",
+ ("FastSpeech2ConformerTokenizer" if is_g2p_en_available() else None, None),
+ ),
+ ("flaubert", ("FlaubertTokenizer", None)),
+ ("flex_olmo", (None, "GPT2TokenizerFast" if is_tokenizers_available() else None)),
+ ("fnet", ("FNetTokenizer", "FNetTokenizerFast" if is_tokenizers_available() else None)),
+ ("fsmt", ("FSMTTokenizer", None)),
+ ("funnel", ("FunnelTokenizer", "FunnelTokenizerFast" if is_tokenizers_available() else None)),
+ (
+ "gemma",
+ (
+ "GemmaTokenizer" if is_sentencepiece_available() else None,
+ "GemmaTokenizerFast" if is_tokenizers_available() else None,
+ ),
+ ),
+ (
+ "gemma2",
+ (
+ "GemmaTokenizer" if is_sentencepiece_available() else None,
+ "GemmaTokenizerFast" if is_tokenizers_available() else None,
+ ),
+ ),
+ (
+ "gemma3",
+ (
+ "GemmaTokenizer" if is_sentencepiece_available() else None,
+ "GemmaTokenizerFast" if is_tokenizers_available() else None,
+ ),
+ ),
+ (
+ "gemma3_text",
+ (
+ "GemmaTokenizer" if is_sentencepiece_available() else None,
+ "GemmaTokenizerFast" if is_tokenizers_available() else None,
+ ),
+ ),
+ (
+ "gemma3n",
+ (
+ "GemmaTokenizer" if is_sentencepiece_available() else None,
+ "GemmaTokenizerFast" if is_tokenizers_available() else None,
+ ),
+ ),
+ (
+ "gemma3n_text",
+ (
+ "GemmaTokenizer" if is_sentencepiece_available() else None,
+ "GemmaTokenizerFast" if is_tokenizers_available() else None,
+ ),
+ ),
+ ("git", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
+ ("glm", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
+ ("glm4", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
+ ("glm4_moe", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
+ ("glm4v", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
+ ("glm4v_moe", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
+ ("gpt-sw3", ("GPTSw3Tokenizer" if is_sentencepiece_available() else None, None)),
+ ("gpt2", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
+ ("gpt_bigcode", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
+ ("gpt_neo", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
+ ("gpt_neox", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
+ ("gpt_neox_japanese", ("GPTNeoXJapaneseTokenizer", None)),
+ ("gpt_oss", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
+ ("gptj", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
+ ("gptsan-japanese", ("GPTSanJapaneseTokenizer", None)),
+ ("granite", ("GPT2Tokenizer", None)),
+ ("granitemoe", ("GPT2Tokenizer", None)),
+ ("granitemoehybrid", ("GPT2Tokenizer", None)),
+ ("granitemoeshared", ("GPT2Tokenizer", None)),
+ ("grounding-dino", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
+ ("groupvit", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)),
+ ("helium", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
+ ("herbert", ("HerbertTokenizer", "HerbertTokenizerFast" if is_tokenizers_available() else None)),
+ ("hubert", ("Wav2Vec2CTCTokenizer", None)),
+ ("ibert", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
+ ("idefics", (None, "LlamaTokenizerFast" if is_tokenizers_available() else None)),
+ ("idefics2", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
+ ("idefics3", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
+ ("instructblip", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
+ ("instructblipvideo", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
+ ("internvl", ("Qwen2Tokenizer", "Qwen2TokenizerFast" if is_tokenizers_available() else None)),
+ (
+ "jamba",
+ (
+ "LlamaTokenizer" if is_sentencepiece_available() else None,
+ "LlamaTokenizerFast" if is_tokenizers_available() else None,
+ ),
+ ),
+ ("janus", (None, "LlamaTokenizerFast" if is_tokenizers_available() else None)),
+ (
+ "jetmoe",
+ (
+ "LlamaTokenizer" if is_sentencepiece_available() else None,
+ "LlamaTokenizerFast" if is_tokenizers_available() else None,
+ ),
+ ),
+ ("jukebox", ("JukeboxTokenizer", None)),
+ (
+ "kosmos-2",
+ (
+ "XLMRobertaTokenizer" if is_sentencepiece_available() else None,
+ "XLMRobertaTokenizerFast" if is_tokenizers_available() else None,
+ ),
+ ),
+ ("kosmos-2.5", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
+ ("layoutlm", ("LayoutLMTokenizer", "LayoutLMTokenizerFast" if is_tokenizers_available() else None)),
+ ("layoutlmv2", ("LayoutLMv2Tokenizer", "LayoutLMv2TokenizerFast" if is_tokenizers_available() else None)),
+ ("layoutlmv3", ("LayoutLMv3Tokenizer", "LayoutLMv3TokenizerFast" if is_tokenizers_available() else None)),
+ ("layoutxlm", ("LayoutXLMTokenizer", "LayoutXLMTokenizerFast" if is_tokenizers_available() else None)),
+ ("led", ("LEDTokenizer", "LEDTokenizerFast" if is_tokenizers_available() else None)),
+ ("lilt", ("LayoutLMv3Tokenizer", "LayoutLMv3TokenizerFast" if is_tokenizers_available() else None)),
+ (
+ "llama",
+ (
+ "LlamaTokenizer" if is_sentencepiece_available() else None,
+ "LlamaTokenizerFast" if is_tokenizers_available() else None,
+ ),
+ ),
+ (
+ "llama4",
+ (
+ "LlamaTokenizer" if is_sentencepiece_available() else None,
+ "LlamaTokenizerFast" if is_tokenizers_available() else None,
+ ),
+ ),
+ (
+ "llama4_text",
+ (
+ "LlamaTokenizer" if is_sentencepiece_available() else None,
+ "LlamaTokenizerFast" if is_tokenizers_available() else None,
+ ),
+ ),
+ ("llava", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
+ ("llava_next", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
+ ("llava_next_video", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
+ ("llava_onevision", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
+ ("longformer", ("LongformerTokenizer", "LongformerTokenizerFast" if is_tokenizers_available() else None)),
+ (
+ "longt5",
+ (
+ "T5Tokenizer" if is_sentencepiece_available() else None,
+ "T5TokenizerFast" if is_tokenizers_available() else None,
+ ),
+ ),
+ ("luke", ("LukeTokenizer", None)),
+ ("lxmert", ("LxmertTokenizer", "LxmertTokenizerFast" if is_tokenizers_available() else None)),
+ ("m2m_100", ("M2M100Tokenizer" if is_sentencepiece_available() else None, None)),
+ ("mamba", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
+ ("mamba2", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
+ ("marian", ("MarianTokenizer" if is_sentencepiece_available() else None, None)),
+ (
+ "mbart",
+ (
+ "MBartTokenizer" if is_sentencepiece_available() else None,
+ "MBartTokenizerFast" if is_tokenizers_available() else None,
+ ),
+ ),
+ (
+ "mbart50",
+ (
+ "MBart50Tokenizer" if is_sentencepiece_available() else None,
+ "MBart50TokenizerFast" if is_tokenizers_available() else None,
+ ),
+ ),
+ ("mega", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
+ ("megatron-bert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
+ (
+ "metaclip_2",
+ (
+ "XLMRobertaTokenizer",
+ "XLMRobertaTokenizerFast" if is_tokenizers_available() else None,
+ ),
+ ),
+ ("mgp-str", ("MgpstrTokenizer", None)),
+ (
+ "minimax",
+ (
+ "GPT2Tokenizer" if is_sentencepiece_available() else None,
+ "GPT2TokenizerFast" if is_tokenizers_available() else None,
+ ),
+ ),
+ (
+ "ministral",
+ (
+ "MistralCommonTokenizer"
+ if is_mistral_common_available()
+ else ("LlamaTokenizer" if is_sentencepiece_available() else None),
+ "LlamaTokenizerFast" if is_tokenizers_available() and not is_mistral_common_available() else None,
+ ),
+ ),
+ (
+ "mistral",
+ (
+ "MistralCommonTokenizer"
+ if is_mistral_common_available()
+ else ("LlamaTokenizer" if is_sentencepiece_available() else None),
+ "LlamaTokenizerFast" if is_tokenizers_available() and not is_mistral_common_available() else None,
+ ),
+ ),
+ (
+ "mistral3",
+ (
+ "MistralCommonTokenizer"
+ if is_mistral_common_available()
+ else ("LlamaTokenizer" if is_sentencepiece_available() else None),
+ "LlamaTokenizerFast" if is_tokenizers_available() and not is_mistral_common_available() else None,
+ ),
+ ),
+ (
+ "mixtral",
+ (
+ "MistralCommonTokenizer"
+ if is_mistral_common_available()
+ else ("LlamaTokenizer" if is_sentencepiece_available() else None),
+ "LlamaTokenizerFast" if is_tokenizers_available() and not is_mistral_common_available() else None,
+ ),
+ ),
+ ("mllama", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
+ ("mluke", ("MLukeTokenizer" if is_sentencepiece_available() else None, None)),
+ ("mm-grounding-dino", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
+ ("mobilebert", ("MobileBertTokenizer", "MobileBertTokenizerFast" if is_tokenizers_available() else None)),
+ ("modernbert", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
+ ("moonshine", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
+ ("moshi", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
+ ("mpnet", ("MPNetTokenizer", "MPNetTokenizerFast" if is_tokenizers_available() else None)),
+ ("mpt", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
+ ("mra", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
+ (
+ "mt5",
+ (
+ "MT5Tokenizer" if is_sentencepiece_available() else None,
+ "MT5TokenizerFast" if is_tokenizers_available() else None,
+ ),
+ ),
+ ("musicgen", ("T5Tokenizer", "T5TokenizerFast" if is_tokenizers_available() else None)),
+ ("musicgen_melody", ("T5Tokenizer", "T5TokenizerFast" if is_tokenizers_available() else None)),
+ ("mvp", ("MvpTokenizer", "MvpTokenizerFast" if is_tokenizers_available() else None)),
+ ("myt5", ("MyT5Tokenizer", None)),
+ ("nemotron", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
+ ("nezha", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
+ (
+ "nllb",
+ (
+ "NllbTokenizer" if is_sentencepiece_available() else None,
+ "NllbTokenizerFast" if is_tokenizers_available() else None,
+ ),
+ ),
+ (
+ "nllb-moe",
+ (
+ "NllbTokenizer" if is_sentencepiece_available() else None,
+ "NllbTokenizerFast" if is_tokenizers_available() else None,
+ ),
+ ),
+ (
+ "nystromformer",
+ (
+ "AlbertTokenizer" if is_sentencepiece_available() else None,
+ "AlbertTokenizerFast" if is_tokenizers_available() else None,
+ ),
+ ),
+ ("olmo", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
+ ("olmo2", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
+ ("olmo3", (None, "GPT2TokenizerFast" if is_tokenizers_available() else None)),
+ ("olmoe", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
+ (
+ "omdet-turbo",
+ ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None),
+ ),
+ ("oneformer", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)),
+ (
+ "openai-gpt",
+ ("OpenAIGPTTokenizer", "OpenAIGPTTokenizerFast" if is_tokenizers_available() else None),
+ ),
+ ("opt", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
+ ("owlv2", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)),
+ ("owlvit", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)),
+ ("paligemma", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
+ ("parakeet", ("ParakeetCTCTokenizer", None)),
+ (
+ "pegasus",
+ (
+ "PegasusTokenizer" if is_sentencepiece_available() else None,
+ "PegasusTokenizerFast" if is_tokenizers_available() else None,
+ ),
+ ),
+ (
+ "pegasus_x",
+ (
+ "PegasusTokenizer" if is_sentencepiece_available() else None,
+ "PegasusTokenizerFast" if is_tokenizers_available() else None,
+ ),
+ ),
+ (
+ "perceiver",
+ (
+ "PerceiverTokenizer",
+ None,
+ ),
+ ),
+ (
+ "persimmon",
+ (
+ "LlamaTokenizer" if is_sentencepiece_available() else None,
+ "LlamaTokenizerFast" if is_tokenizers_available() else None,
+ ),
+ ),
+ ("phi", ("CodeGenTokenizer", "CodeGenTokenizerFast" if is_tokenizers_available() else None)),
+ ("phi3", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
+ ("phimoe", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
+ ("phobert", ("PhobertTokenizer", None)),
+ ("pix2struct", ("T5Tokenizer", "T5TokenizerFast" if is_tokenizers_available() else None)),
+ (
+ "pixtral",
+ (
+ None,
+ "MistralCommonTokenizer"
+ if is_mistral_common_available()
+ else ("PreTrainedTokenizerFast" if is_tokenizers_available() else None),
+ ),
+ ),
+ ("plbart", ("PLBartTokenizer" if is_sentencepiece_available() else None, None)),
+ ("prophetnet", ("ProphetNetTokenizer", None)),
+ ("qdqbert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
+ (
+ "qwen2",
+ (
+ "Qwen2Tokenizer",
+ "Qwen2TokenizerFast" if is_tokenizers_available() else None,
+ ),
+ ),
+ ("qwen2_5_omni", ("Qwen2Tokenizer", "Qwen2TokenizerFast" if is_tokenizers_available() else None)),
+ ("qwen2_5_vl", ("Qwen2Tokenizer", "Qwen2TokenizerFast" if is_tokenizers_available() else None)),
+ ("qwen2_audio", ("Qwen2Tokenizer", "Qwen2TokenizerFast" if is_tokenizers_available() else None)),
+ (
+ "qwen2_moe",
+ (
+ "Qwen2Tokenizer",
+ "Qwen2TokenizerFast" if is_tokenizers_available() else None,
+ ),
+ ),
+ ("qwen2_vl", ("Qwen2Tokenizer", "Qwen2TokenizerFast" if is_tokenizers_available() else None)),
+ (
+ "qwen3",
+ (
+ "Qwen2Tokenizer",
+ "Qwen2TokenizerFast" if is_tokenizers_available() else None,
+ ),
+ ),
+ (
+ "qwen3_moe",
+ (
+ "Qwen2Tokenizer",
+ "Qwen2TokenizerFast" if is_tokenizers_available() else None,
+ ),
+ ),
+ (
+ "qwen3_next",
+ (
+ "Qwen2Tokenizer",
+ "Qwen2TokenizerFast" if is_tokenizers_available() else None,
+ ),
+ ),
+ ("qwen3_omni_moe", ("Qwen2Tokenizer", "Qwen2TokenizerFast" if is_tokenizers_available() else None)),
+ ("qwen3_vl", ("Qwen2Tokenizer", "Qwen2TokenizerFast" if is_tokenizers_available() else None)),
+ ("qwen3_vl_moe", ("Qwen2Tokenizer", "Qwen2TokenizerFast" if is_tokenizers_available() else None)),
+ ("rag", ("RagTokenizer", None)),
+ ("realm", ("RealmTokenizer", "RealmTokenizerFast" if is_tokenizers_available() else None)),
+ (
+ "recurrent_gemma",
+ (
+ "GemmaTokenizer" if is_sentencepiece_available() else None,
+ "GemmaTokenizerFast" if is_tokenizers_available() else None,
+ ),
+ ),
+ (
+ "reformer",
+ (
+ "ReformerTokenizer" if is_sentencepiece_available() else None,
+ "ReformerTokenizerFast" if is_tokenizers_available() else None,
+ ),
+ ),
+ (
+ "rembert",
+ (
+ "RemBertTokenizer" if is_sentencepiece_available() else None,
+ "RemBertTokenizerFast" if is_tokenizers_available() else None,
+ ),
+ ),
+ ("retribert", ("RetriBertTokenizer", "RetriBertTokenizerFast" if is_tokenizers_available() else None)),
+ ("roberta", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
+ (
+ "roberta-prelayernorm",
+ ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None),
+ ),
+ ("roc_bert", ("RoCBertTokenizer", None)),
+ ("roformer", ("RoFormerTokenizer", "RoFormerTokenizerFast" if is_tokenizers_available() else None)),
+ ("rwkv", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
+ (
+ "seamless_m4t",
+ (
+ "SeamlessM4TTokenizer" if is_sentencepiece_available() else None,
+ "SeamlessM4TTokenizerFast" if is_tokenizers_available() else None,
+ ),
+ ),
+ (
+ "seamless_m4t_v2",
+ (
+ "SeamlessM4TTokenizer" if is_sentencepiece_available() else None,
+ "SeamlessM4TTokenizerFast" if is_tokenizers_available() else None,
+ ),
+ ),
+ (
+ "shieldgemma2",
+ (
+ "GemmaTokenizer" if is_sentencepiece_available() else None,
+ "GemmaTokenizerFast" if is_tokenizers_available() else None,
+ ),
+ ),
+ ("siglip", ("SiglipTokenizer" if is_sentencepiece_available() else None, None)),
+ (
+ "siglip2",
+ (
+ "GemmaTokenizer" if is_sentencepiece_available() else None,
+ "GemmaTokenizerFast" if is_tokenizers_available() else None,
+ ),
+ ),
+ ("smollm3", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
+ ("speech_to_text", ("Speech2TextTokenizer" if is_sentencepiece_available() else None, None)),
+ ("speech_to_text_2", ("Speech2Text2Tokenizer", None)),
+ ("speecht5", ("SpeechT5Tokenizer" if is_sentencepiece_available() else None, None)),
+ ("splinter", ("SplinterTokenizer", "SplinterTokenizerFast")),
+ (
+ "squeezebert",
+ ("SqueezeBertTokenizer", "SqueezeBertTokenizerFast" if is_tokenizers_available() else None),
+ ),
+ ("stablelm", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
+ ("starcoder2", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
+ (
+ "switch_transformers",
+ (
+ "T5Tokenizer" if is_sentencepiece_available() else None,
+ "T5TokenizerFast" if is_tokenizers_available() else None,
+ ),
+ ),
+ (
+ "t5",
+ (
+ "T5Tokenizer" if is_sentencepiece_available() else None,
+ "T5TokenizerFast" if is_tokenizers_available() else None,
+ ),
+ ),
+ (
+ "t5gemma",
+ (
+ "GemmaTokenizer" if is_sentencepiece_available() else None,
+ "GemmaTokenizerFast" if is_tokenizers_available() else None,
+ ),
+ ),
+ ("tapas", ("TapasTokenizer", None)),
+ ("tapex", ("TapexTokenizer", None)),
+ ("transfo-xl", ("TransfoXLTokenizer", None)),
+ ("tvp", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
+ (
+ "udop",
+ (
+ "UdopTokenizer" if is_sentencepiece_available() else None,
+ "UdopTokenizerFast" if is_tokenizers_available() else None,
+ ),
+ ),
+ (
+ "umt5",
+ (
+ "T5Tokenizer" if is_sentencepiece_available() else None,
+ "T5TokenizerFast" if is_tokenizers_available() else None,
+ ),
+ ),
+ ("video_llava", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
+ ("vilt", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
+ ("vipllava", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
+ ("visual_bert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
+ ("vits", ("VitsTokenizer", None)),
+ (
+ "voxtral",
+ (
+ "MistralCommonTokenizer" if is_mistral_common_available() else None,
+ "PreTrainedTokenizerFast" if is_tokenizers_available() and not is_mistral_common_available() else None,
+ ),
+ ),
+ ("wav2vec2", ("Wav2Vec2CTCTokenizer", None)),
+ ("wav2vec2-bert", ("Wav2Vec2CTCTokenizer", None)),
+ ("wav2vec2-conformer", ("Wav2Vec2CTCTokenizer", None)),
+ ("wav2vec2_phoneme", ("Wav2Vec2PhonemeCTCTokenizer", None)),
+ ("whisper", ("WhisperTokenizer", "WhisperTokenizerFast" if is_tokenizers_available() else None)),
+ ("xclip", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)),
+ (
+ "xglm",
+ (
+ "XGLMTokenizer" if is_sentencepiece_available() else None,
+ "XGLMTokenizerFast" if is_tokenizers_available() else None,
+ ),
+ ),
+ ("xlm", ("XLMTokenizer", None)),
+ ("xlm-prophetnet", ("XLMProphetNetTokenizer" if is_sentencepiece_available() else None, None)),
+ (
+ "xlm-roberta",
+ (
+ "XLMRobertaTokenizer" if is_sentencepiece_available() else None,
+ "XLMRobertaTokenizerFast" if is_tokenizers_available() else None,
+ ),
+ ),
+ (
+ "xlm-roberta-xl",
+ (
+ "XLMRobertaTokenizer" if is_sentencepiece_available() else None,
+ "XLMRobertaTokenizerFast" if is_tokenizers_available() else None,
+ ),
+ ),
+ (
+ "xlnet",
+ (
+ "XLNetTokenizer" if is_sentencepiece_available() else None,
+ "XLNetTokenizerFast" if is_tokenizers_available() else None,
+ ),
+ ),
+ ("xlstm", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
+ (
+ "xmod",
+ (
+ "XLMRobertaTokenizer" if is_sentencepiece_available() else None,
+ "XLMRobertaTokenizerFast" if is_tokenizers_available() else None,
+ ),
+ ),
+ (
+ "yoso",
+ (
+ "AlbertTokenizer" if is_sentencepiece_available() else None,
+ "AlbertTokenizerFast" if is_tokenizers_available() else None,
+ ),
+ ),
+ (
+ "zamba",
+ (
+ "LlamaTokenizer" if is_sentencepiece_available() else None,
+ "LlamaTokenizerFast" if is_tokenizers_available() else None,
+ ),
+ ),
+ (
+ "zamba2",
+ (
+ "LlamaTokenizer" if is_sentencepiece_available() else None,
+ "LlamaTokenizerFast" if is_tokenizers_available() else None,
+ ),
+ ),
+ ]
+)
+
+TOKENIZER_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TOKENIZER_MAPPING_NAMES)
+
+CONFIG_TO_TYPE = {v: k for k, v in CONFIG_MAPPING_NAMES.items()}
+
+
+def tokenizer_class_from_name(class_name: str) -> Union[type[Any], None]:
+ if class_name == "PreTrainedTokenizerFast":
+ return PreTrainedTokenizerFast
+
+ for module_name, tokenizers in TOKENIZER_MAPPING_NAMES.items():
+ if class_name in tokenizers:
+ module_name = model_type_to_module_name(module_name)
+ if module_name in ["mistral", "mixtral", "ministral"] and class_name == "MistralCommonTokenizer":
+ module = importlib.import_module(".tokenization_mistral_common", "transformers")
+ else:
+ module = importlib.import_module(f".{module_name}", "transformers.models")
+ try:
+ return getattr(module, class_name)
+ except AttributeError:
+ continue
+
+ for tokenizers in TOKENIZER_MAPPING._extra_content.values():
+ for tokenizer in tokenizers:
+ if getattr(tokenizer, "__name__", None) == class_name:
+ return tokenizer
+
+    # We did not find the class, but maybe it's because a dep is missing. In that case, the class will be in the main
+ # init and we return the proper dummy to get an appropriate error message.
+ main_module = importlib.import_module("transformers")
+ if hasattr(main_module, class_name):
+ return getattr(main_module, class_name)
+
+ return None
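+
+# A hedged sketch of the lookup behavior above (illustrative, not an API guarantee): names fall
+# through the static mapping, then any extra registered tokenizers, then the main `transformers`
+# init (which may hold a dummy object when a dependency is missing), and otherwise resolve to
+# `None`. For example, with `tokenizers` installed one would expect:
+#   tokenizer_class_from_name("BertTokenizerFast")  # -> transformers.BertTokenizerFast
+#   tokenizer_class_from_name("NoSuchTokenizer")    # -> None (hypothetical name)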
+
+
+def get_tokenizer_config(
+ pretrained_model_name_or_path: Union[str, os.PathLike[str]],
+ cache_dir: Optional[Union[str, os.PathLike[str]]] = None,
+ force_download: bool = False,
+ resume_download: Optional[bool] = None,
+ proxies: Optional[dict[str, str]] = None,
+ token: Optional[Union[bool, str]] = None,
+ revision: Optional[str] = None,
+ local_files_only: bool = False,
+ subfolder: str = "",
+ **kwargs,
+) -> dict[str, Any]:
+ """
+ Loads the tokenizer configuration from a pretrained model tokenizer configuration.
+
+ Args:
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
+ This can be either:
+
+ - a string, the *model id* of a pretrained model configuration hosted inside a model repo on
+ huggingface.co.
+ - a path to a *directory* containing a configuration file saved using the
+ [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.
+
+ cache_dir (`str` or `os.PathLike`, *optional*):
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
+ cache should not be used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force to (re-)download the configuration files and override the cached versions if they
+ exist.
+ resume_download:
+ Deprecated and ignored. All downloads are now resumed by default when possible.
+ Will be removed in v5 of Transformers.
+ proxies (`dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+            'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
+ when running `hf auth login` (stored in `~/.huggingface`).
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+ identifier allowed by git.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ If `True`, will only try to load the tokenizer configuration from local files.
+ subfolder (`str`, *optional*, defaults to `""`):
+ In case the tokenizer config is located inside a subfolder of the model repo on huggingface.co, you can
+ specify the folder name here.
+
+ Passing `token=True` is required when you want to use a private model.
+
+ Returns:
+ `dict`: The configuration of the tokenizer.
+
+ Examples:
+
+ ```python
+ # Download configuration from huggingface.co and cache.
+ tokenizer_config = get_tokenizer_config("google-bert/bert-base-uncased")
+ # This model does not have a tokenizer config so the result will be an empty dict.
+ tokenizer_config = get_tokenizer_config("FacebookAI/xlm-roberta-base")
+
+ # Save a pretrained tokenizer locally and you can reload its config
+ from transformers import AutoTokenizer
+
+ tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-cased")
+ tokenizer.save_pretrained("tokenizer-test")
+ tokenizer_config = get_tokenizer_config("tokenizer-test")
+ ```"""
+ use_auth_token = kwargs.pop("use_auth_token", None)
+ if use_auth_token is not None:
+ warnings.warn(
+ "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
+ FutureWarning,
+ )
+ if token is not None:
+ raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
+ token = use_auth_token
+
+ commit_hash = kwargs.get("_commit_hash")
+ resolved_config_file = cached_file(
+ pretrained_model_name_or_path,
+ TOKENIZER_CONFIG_FILE,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ resume_download=resume_download,
+ proxies=proxies,
+ token=token,
+ revision=revision,
+ local_files_only=local_files_only,
+ subfolder=subfolder,
+ _raise_exceptions_for_gated_repo=False,
+ _raise_exceptions_for_missing_entries=False,
+ _raise_exceptions_for_connection_errors=False,
+ _commit_hash=commit_hash,
+ )
+ if resolved_config_file is None:
+ logger.info("Could not locate the tokenizer configuration file, will try to use the model config instead.")
+ return {}
+ commit_hash = extract_commit_hash(resolved_config_file, commit_hash)
+
+ with open(resolved_config_file, encoding="utf-8") as reader:
+ result = json.load(reader)
+ result["_commit_hash"] = commit_hash
+ return result
+
+
+class AutoTokenizer:
+ r"""
+ This is a generic tokenizer class that will be instantiated as one of the tokenizer classes of the library when
+ created with the [`AutoTokenizer.from_pretrained`] class method.
+
+ This class cannot be instantiated directly using `__init__()` (throws an error).
+ """
+
+ def __init__(self):
+ raise OSError(
+ "AutoTokenizer is designed to be instantiated "
+ "using the `AutoTokenizer.from_pretrained(pretrained_model_name_or_path)` method."
+ )
+
+ @classmethod
+ @replace_list_option_in_docstrings(TOKENIZER_MAPPING_NAMES)
+ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
+ r"""
+ Instantiate one of the tokenizer classes of the library from a pretrained model vocabulary.
+
+ The tokenizer class to instantiate is selected based on the `model_type` property of the config object (either
+ passed as an argument or loaded from `pretrained_model_name_or_path` if possible), or when it's missing, by
+ falling back to using pattern matching on `pretrained_model_name_or_path`:
+
+ List options
+
+ Params:
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
+ Can be either:
+
+ - A string, the *model id* of a predefined tokenizer hosted inside a model repo on huggingface.co.
+ - A path to a *directory* containing vocabulary files required by the tokenizer, for instance saved
+ using the [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.
+ - A path or url to a single saved vocabulary file if and only if the tokenizer only requires a
+ single vocabulary file (like Bert or XLNet), e.g.: `./my_model_directory/vocab.txt`. (Not
+ applicable to all derived classes)
+ inputs (additional positional arguments, *optional*):
+ Will be passed along to the Tokenizer `__init__()` method.
+            config ([`PretrainedConfig`], *optional*):
+ The configuration object used to determine the tokenizer class to instantiate.
+ cache_dir (`str` or `os.PathLike`, *optional*):
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the
+ standard cache should not be used.
+ force_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to force the (re-)download of the model weights and configuration files and override the
+ cached versions if they exist.
+ resume_download:
+ Deprecated and ignored. All downloads are now resumed by default when possible.
+ Will be removed in v5 of Transformers.
+ proxies (`dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+ identifier allowed by git.
+ subfolder (`str`, *optional*):
+ In case the relevant files are located inside a subfolder of the model repo on huggingface.co (e.g. for
+ facebook/rag-token-base), specify it here.
+ use_fast (`bool`, *optional*, defaults to `True`):
+ Use a [fast Rust-based tokenizer](https://huggingface.co/docs/tokenizers/index) if it is supported for
+ a given model. If a fast tokenizer is not available for a given model, a normal Python-based tokenizer
+ is returned instead.
+ tokenizer_type (`str`, *optional*):
+ Tokenizer type to be loaded.
+ trust_remote_code (`bool`, *optional*, defaults to `False`):
+ Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
+ should only be set to `True` for repositories you trust and in which you have read the code, as it will
+ execute code present on the Hub on your local machine.
+ kwargs (additional keyword arguments, *optional*):
+ Will be passed to the Tokenizer `__init__()` method. Can be used to set special tokens like
+ `bos_token`, `eos_token`, `unk_token`, `sep_token`, `pad_token`, `cls_token`, `mask_token`,
+ `additional_special_tokens`. See parameters in the `__init__()` for more details.
+
+ Examples:
+
+ ```python
+ >>> from transformers import AutoTokenizer
+
+ >>> # Download vocabulary from huggingface.co and cache.
+ >>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
+
+ >>> # Download vocabulary from huggingface.co (user-uploaded) and cache.
+ >>> tokenizer = AutoTokenizer.from_pretrained("dbmdz/bert-base-german-cased")
+
+ >>> # If vocabulary files are in a directory (e.g. tokenizer was saved using *save_pretrained('./test/saved_model/')*)
+ >>> # tokenizer = AutoTokenizer.from_pretrained("./test/bert_saved_model/")
+
+ >>> # Download vocabulary from huggingface.co and define model-specific arguments
+ >>> tokenizer = AutoTokenizer.from_pretrained("FacebookAI/roberta-base", add_prefix_space=True)
+ ```"""
+ use_auth_token = kwargs.pop("use_auth_token", None)
+ if use_auth_token is not None:
+ warnings.warn(
+ "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
+ FutureWarning,
+ )
+ if kwargs.get("token") is not None:
+ raise ValueError(
+ "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
+ )
+ kwargs["token"] = use_auth_token
+
+ config = kwargs.pop("config", None)
+ kwargs["_from_auto"] = True
+
+ use_fast = kwargs.pop("use_fast", True)
+ tokenizer_type = kwargs.pop("tokenizer_type", None)
+ trust_remote_code = kwargs.pop("trust_remote_code", None)
+ gguf_file = kwargs.get("gguf_file")
+
+ # First, let's see whether the tokenizer_type is passed so that we can leverage it
+ if tokenizer_type is not None:
+ tokenizer_class = None
+ tokenizer_class_tuple = TOKENIZER_MAPPING_NAMES.get(tokenizer_type, None)
+
+ if tokenizer_class_tuple is None:
+ raise ValueError(
+ f"Passed `tokenizer_type` {tokenizer_type} does not exist. `tokenizer_type` should be one of "
+ f"{', '.join(c for c in TOKENIZER_MAPPING_NAMES)}."
+ )
+
+ tokenizer_class_name, tokenizer_fast_class_name = tokenizer_class_tuple
+
+ if use_fast:
+ if tokenizer_fast_class_name is not None:
+ tokenizer_class = tokenizer_class_from_name(tokenizer_fast_class_name)
+ else:
+ logger.warning(
+ "`use_fast` is set to `True` but the tokenizer class does not have a fast version. "
+                        "Falling back to the slow version."
+ )
+ if tokenizer_class is None:
+ tokenizer_class = tokenizer_class_from_name(tokenizer_class_name)
+
+ if tokenizer_class is None:
+ raise ValueError(f"Tokenizer class {tokenizer_class_name} is not currently imported.")
+
+ return tokenizer_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
+
+ # Next, let's try to use the tokenizer_config file to get the tokenizer class.
+ tokenizer_config = get_tokenizer_config(pretrained_model_name_or_path, **kwargs)
+ if "_commit_hash" in tokenizer_config:
+ kwargs["_commit_hash"] = tokenizer_config["_commit_hash"]
+ config_tokenizer_class = tokenizer_config.get("tokenizer_class")
+ tokenizer_auto_map = None
+ if "auto_map" in tokenizer_config:
+ if isinstance(tokenizer_config["auto_map"], (tuple, list)):
+ # Legacy format for dynamic tokenizers
+ tokenizer_auto_map = tokenizer_config["auto_map"]
+ else:
+ tokenizer_auto_map = tokenizer_config["auto_map"].get("AutoTokenizer", None)
+
+ # If that did not work, let's try to use the config.
+ if config_tokenizer_class is None:
+ if not isinstance(config, PretrainedConfig):
+ if gguf_file:
+ gguf_path = cached_file(pretrained_model_name_or_path, gguf_file, **kwargs)
+ config_dict = load_gguf_checkpoint(gguf_path, return_tensors=False)["config"]
+ config = AutoConfig.for_model(**config_dict)
+ else:
+ config = AutoConfig.from_pretrained(
+ pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
+ )
+ config_tokenizer_class = config.tokenizer_class
+ if hasattr(config, "auto_map") and "AutoTokenizer" in config.auto_map:
+ tokenizer_auto_map = config.auto_map["AutoTokenizer"]
+
+ has_remote_code = tokenizer_auto_map is not None
+ has_local_code = type(config) in TOKENIZER_MAPPING or (
+ config_tokenizer_class is not None
+ and (
+ tokenizer_class_from_name(config_tokenizer_class) is not None
+ or tokenizer_class_from_name(config_tokenizer_class + "Fast") is not None
+ )
+ )
+ if has_remote_code:
+ if use_fast and tokenizer_auto_map[1] is not None:
+ class_ref = tokenizer_auto_map[1]
+ else:
+ class_ref = tokenizer_auto_map[0]
+ if "--" in class_ref:
+ upstream_repo = class_ref.split("--")[0]
+ else:
+ upstream_repo = None
+ trust_remote_code = resolve_trust_remote_code(
+ trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code, upstream_repo
+ )
+
+ if has_remote_code and trust_remote_code:
+ tokenizer_class = get_class_from_dynamic_module(class_ref, pretrained_model_name_or_path, **kwargs)
+ _ = kwargs.pop("code_revision", None)
+ tokenizer_class.register_for_auto_class()
+ return tokenizer_class.from_pretrained(
+ pretrained_model_name_or_path, *inputs, trust_remote_code=trust_remote_code, **kwargs
+ )
+ elif config_tokenizer_class is not None:
+ tokenizer_class = None
+ if use_fast and not config_tokenizer_class.endswith("Fast"):
+ tokenizer_class_candidate = f"{config_tokenizer_class}Fast"
+ tokenizer_class = tokenizer_class_from_name(tokenizer_class_candidate)
+ if tokenizer_class is None:
+ tokenizer_class_candidate = config_tokenizer_class
+ tokenizer_class = tokenizer_class_from_name(tokenizer_class_candidate)
+ if tokenizer_class is None:
+ raise ValueError(
+ f"Tokenizer class {tokenizer_class_candidate} does not exist or is not currently imported."
+ )
+ return tokenizer_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
+
+ # Otherwise we have to be creative.
+        # If the model is an encoder-decoder, the encoder tokenizer class is used by default
+ if isinstance(config, EncoderDecoderConfig):
+ if type(config.decoder) is not type(config.encoder):
+ logger.warning(
+ f"The encoder model config class: {config.encoder.__class__} is different from the decoder model "
+ f"config class: {config.decoder.__class__}. It is not recommended to use the "
+ "`AutoTokenizer.from_pretrained()` method in this case. Please use the encoder and decoder "
+ "specific tokenizer classes."
+ )
+ config = config.encoder
+
+ model_type = config_class_to_model_type(type(config).__name__)
+ if model_type is not None:
+ tokenizer_class_py, tokenizer_class_fast = TOKENIZER_MAPPING[type(config)]
+
+ if tokenizer_class_fast and (use_fast or tokenizer_class_py is None):
+ return tokenizer_class_fast.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
+ else:
+ if tokenizer_class_py is not None:
+ return tokenizer_class_py.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
+ else:
+ raise ValueError(
+ "This tokenizer cannot be instantiated. Please make sure you have `sentencepiece` installed "
+ "in order to use this tokenizer."
+ )
+
+ raise ValueError(
+ f"Unrecognized configuration class {config.__class__} to build an AutoTokenizer.\n"
+ f"Model type should be one of {', '.join(c.__name__ for c in TOKENIZER_MAPPING)}."
+ )
+
+ @staticmethod
+ def register(config_class, slow_tokenizer_class=None, fast_tokenizer_class=None, exist_ok=False):
+ """
+ Register a new tokenizer in this mapping.
+
+
+ Args:
+ config_class ([`PretrainedConfig`]):
+ The configuration corresponding to the model to register.
+            slow_tokenizer_class ([`PreTrainedTokenizer`], *optional*):
+ The slow tokenizer to register.
+            fast_tokenizer_class ([`PreTrainedTokenizerFast`], *optional*):
+ The fast tokenizer to register.
+ """
+ if slow_tokenizer_class is None and fast_tokenizer_class is None:
+            raise ValueError("You need to pass either a `slow_tokenizer_class` or a `fast_tokenizer_class`.")
+ if slow_tokenizer_class is not None and issubclass(slow_tokenizer_class, PreTrainedTokenizerFast):
+ raise ValueError("You passed a fast tokenizer in the `slow_tokenizer_class`.")
+ if fast_tokenizer_class is not None and issubclass(fast_tokenizer_class, PreTrainedTokenizer):
+ raise ValueError("You passed a slow tokenizer in the `fast_tokenizer_class`.")
+
+ if (
+ slow_tokenizer_class is not None
+ and fast_tokenizer_class is not None
+ and issubclass(fast_tokenizer_class, PreTrainedTokenizerFast)
+ and fast_tokenizer_class.slow_tokenizer_class != slow_tokenizer_class
+ ):
+ raise ValueError(
+ "The fast tokenizer class you are passing has a `slow_tokenizer_class` attribute that is not "
+ "consistent with the slow tokenizer class you passed (fast tokenizer has "
+                f"{fast_tokenizer_class.slow_tokenizer_class} and you passed {slow_tokenizer_class}). Fix one of those "
+ "so they match!"
+ )
+
+ # Avoid resetting a set slow/fast tokenizer if we are passing just the other ones.
+ if config_class in TOKENIZER_MAPPING._extra_content:
+ existing_slow, existing_fast = TOKENIZER_MAPPING[config_class]
+ if slow_tokenizer_class is None:
+ slow_tokenizer_class = existing_slow
+ if fast_tokenizer_class is None:
+ fast_tokenizer_class = existing_fast
+
+ TOKENIZER_MAPPING.register(config_class, (slow_tokenizer_class, fast_tokenizer_class), exist_ok=exist_ok)
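+
+    # Hedged usage sketch (`MyModelConfig` and `MyModelTokenizer` are hypothetical names, not
+    # classes shipped with the library): a custom config/tokenizer pair can be wired into the
+    # auto classes so that `AutoTokenizer.from_pretrained` resolves it by config type:
+    #
+    #     AutoConfig.register("my-model", MyModelConfig)
+    #     AutoTokenizer.register(MyModelConfig, slow_tokenizer_class=MyModelTokenizer)
+    #     tokenizer = AutoTokenizer.from_pretrained("path/to/my-model")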
+
+
+__all__ = ["TOKENIZER_MAPPING", "AutoTokenizer"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/auto/video_processing_auto.py b/venv/lib/python3.13/site-packages/transformers/models/auto/video_processing_auto.py
new file mode 100644
index 0000000000000000000000000000000000000000..84bbc8e6fdb10ea5e0a72caec2135825ff95dc20
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/auto/video_processing_auto.py
@@ -0,0 +1,393 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""AutoVideoProcessor class."""
+
+import importlib
+import json
+import os
+import warnings
+from collections import OrderedDict
+from typing import TYPE_CHECKING, Optional, Union
+
+# Build the list of all video processors
+from ...configuration_utils import PretrainedConfig
+from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
+from ...utils import CONFIG_NAME, VIDEO_PROCESSOR_NAME, cached_file, is_torchvision_available, logging
+from ...utils.import_utils import requires
+from ...video_processing_utils import BaseVideoProcessor
+from .auto_factory import _LazyAutoMapping
+from .configuration_auto import (
+ CONFIG_MAPPING_NAMES,
+ AutoConfig,
+ model_type_to_module_name,
+ replace_list_option_in_docstrings,
+)
+
+
+logger = logging.get_logger(__name__)
+
+
+if TYPE_CHECKING:
+ # This significantly improves completion suggestion performance when
+ # the transformers package is used with Microsoft's Pylance language server.
+ VIDEO_PROCESSOR_MAPPING_NAMES: OrderedDict[str, tuple[Optional[str], Optional[str]]] = OrderedDict()
+else:
+ VIDEO_PROCESSOR_MAPPING_NAMES = OrderedDict(
+ [
+ ("glm4v", "Glm4vVideoProcessor"),
+ ("instructblip", "InstructBlipVideoVideoProcessor"),
+ ("instructblipvideo", "InstructBlipVideoVideoProcessor"),
+ ("internvl", "InternVLVideoProcessor"),
+ ("llava_next_video", "LlavaNextVideoVideoProcessor"),
+ ("llava_onevision", "LlavaOnevisionVideoProcessor"),
+ ("perception_lm", "PerceptionLMVideoProcessor"),
+ ("qwen2_5_omni", "Qwen2VLVideoProcessor"),
+ ("qwen2_5_vl", "Qwen2VLVideoProcessor"),
+ ("qwen2_vl", "Qwen2VLVideoProcessor"),
+ ("qwen3_omni_moe", "Qwen2VLVideoProcessor"),
+ ("qwen3_vl", "Qwen3VLVideoProcessor"),
+ ("qwen3_vl_moe", "Qwen3VLVideoProcessor"),
+ ("sam2_video", "Sam2VideoVideoProcessor"),
+ ("smolvlm", "SmolVLMVideoProcessor"),
+ ("video_llava", "VideoLlavaVideoProcessor"),
+ ("vjepa2", "VJEPA2VideoProcessor"),
+ ]
+ )
+
+for model_type, video_processors in VIDEO_PROCESSOR_MAPPING_NAMES.items():
+ fast_video_processor_class = video_processors
+
+    # If torchvision is not available, we set it to None
+ if not is_torchvision_available():
+ fast_video_processor_class = None
+
+ VIDEO_PROCESSOR_MAPPING_NAMES[model_type] = fast_video_processor_class
+
+VIDEO_PROCESSOR_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, VIDEO_PROCESSOR_MAPPING_NAMES)
+
+
+def video_processor_class_from_name(class_name: str):
+ for module_name, extractors in VIDEO_PROCESSOR_MAPPING_NAMES.items():
+ if class_name in extractors:
+ module_name = model_type_to_module_name(module_name)
+
+ module = importlib.import_module(f".{module_name}", "transformers.models")
+ try:
+ return getattr(module, class_name)
+ except AttributeError:
+ continue
+
+ for extractor in VIDEO_PROCESSOR_MAPPING._extra_content.values():
+ if getattr(extractor, "__name__", None) == class_name:
+ return extractor
+
+ # We did not find the class, but maybe it's because a dep is missing. In that case, the class will be in the main
+ # init and we return the proper dummy to get an appropriate error message.
+ main_module = importlib.import_module("transformers")
+ if hasattr(main_module, class_name):
+ return getattr(main_module, class_name)
+
+ return None
+
+
+def get_video_processor_config(
+ pretrained_model_name_or_path: Union[str, os.PathLike],
+ cache_dir: Optional[Union[str, os.PathLike]] = None,
+ force_download: bool = False,
+ resume_download: Optional[bool] = None,
+ proxies: Optional[dict[str, str]] = None,
+ token: Optional[Union[bool, str]] = None,
+ revision: Optional[str] = None,
+ local_files_only: bool = False,
+ **kwargs,
+):
+ """
+ Loads the video processor configuration from a pretrained model video processor configuration.
+
+ Args:
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
+ This can be either:
+
+ - a string, the *model id* of a pretrained model configuration hosted inside a model repo on
+ huggingface.co.
+ - a path to a *directory* containing a configuration file saved using the
+ [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.
+
+ cache_dir (`str` or `os.PathLike`, *optional*):
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
+ cache should not be used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force to (re-)download the configuration files and override the cached versions if they
+ exist.
+ resume_download:
+ Deprecated and ignored. All downloads are now resumed by default when possible.
+ Will be removed in v5 of Transformers.
+ proxies (`dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+            'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
+ when running `hf auth login` (stored in `~/.huggingface`).
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+ identifier allowed by git.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ If `True`, will only try to load the video processor configuration from local files.
+
+ Passing `token=True` is required when you want to use a private model.
+
+ Returns:
+ `Dict`: The configuration of the video processor.
+
+ Examples:
+
+ ```python
+ # Download configuration from huggingface.co and cache.
+ video_processor_config = get_video_processor_config("llava-hf/llava-onevision-qwen2-0.5b-ov-hf")
+ # This model does not have a video processor config so the result will be an empty dict.
+ video_processor_config = get_video_processor_config("FacebookAI/xlm-roberta-base")
+
+ # Save a pretrained video processor locally and you can reload its config
+ from transformers import AutoVideoProcessor
+
+ video_processor = AutoVideoProcessor.from_pretrained("llava-hf/llava-onevision-qwen2-0.5b-ov-hf")
+ video_processor.save_pretrained("video-processor-test")
+    video_processor_config = get_video_processor_config("video-processor-test")
+ ```"""
+ use_auth_token = kwargs.pop("use_auth_token", None)
+ if use_auth_token is not None:
+ warnings.warn(
+ "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
+ FutureWarning,
+ )
+ if token is not None:
+ raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
+ token = use_auth_token
+
+ resolved_config_file = cached_file(
+ pretrained_model_name_or_path,
+ VIDEO_PROCESSOR_NAME,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ resume_download=resume_download,
+ proxies=proxies,
+ token=token,
+ revision=revision,
+ local_files_only=local_files_only,
+ )
+ if resolved_config_file is None:
+ logger.info(
+ "Could not locate the video processor configuration file, will try to use the model config instead."
+ )
+ return {}
+
+ with open(resolved_config_file, encoding="utf-8") as reader:
+ return json.load(reader)
+
+
+@requires(backends=("vision", "torchvision"))
+class AutoVideoProcessor:
+ r"""
+ This is a generic video processor class that will be instantiated as one of the video processor classes of the
+ library when created with the [`AutoVideoProcessor.from_pretrained`] class method.
+
+ This class cannot be instantiated directly using `__init__()` (throws an error).
+ """
+
+ def __init__(self):
+ raise OSError(
+ "AutoVideoProcessor is designed to be instantiated "
+ "using the `AutoVideoProcessor.from_pretrained(pretrained_model_name_or_path)` method."
+ )
+
+ @classmethod
+ @replace_list_option_in_docstrings(VIDEO_PROCESSOR_MAPPING_NAMES)
+ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
+ r"""
+ Instantiate one of the video processor classes of the library from a pretrained model vocabulary.
+
+ The video processor class to instantiate is selected based on the `model_type` property of the config object
+ (either passed as an argument or loaded from `pretrained_model_name_or_path` if possible), or when it's
+ missing, by falling back to using pattern matching on `pretrained_model_name_or_path`:
+
+ List options
+
+ Params:
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
+ This can be either:
+
+ - a string, the *model id* of a pretrained video_processor hosted inside a model repo on
+ huggingface.co.
+ - a path to a *directory* containing a video processor file saved using the
+ [`~video_processing_utils.BaseVideoProcessor.save_pretrained`] method, e.g.,
+ `./my_model_directory/`.
+ - a path or url to a saved video processor JSON *file*, e.g.,
+ `./my_model_directory/preprocessor_config.json`.
+ cache_dir (`str` or `os.PathLike`, *optional*):
+ Path to a directory in which a downloaded pretrained model video processor should be cached if the
+ standard cache should not be used.
+ force_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to force a (re-)download of the video processor files and override the cached versions if
+ they exist.
+ resume_download:
+ Deprecated and ignored. All downloads are now resumed by default when possible.
+ Will be removed in v5 of Transformers.
+ proxies (`dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
+ when running `hf auth login` (stored in `~/.huggingface`).
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+ identifier allowed by git.
+ return_unused_kwargs (`bool`, *optional*, defaults to `False`):
+ If `False`, then this function returns just the final video processor object. If `True`, then this
+                function returns a `Tuple(video_processor, unused_kwargs)` where *unused_kwargs* is a dictionary
+ consisting of the key/value pairs whose keys are not video processor attributes: i.e., the part of
+ `kwargs` which has not been used to update `video_processor` and is otherwise ignored.
+ trust_remote_code (`bool`, *optional*, defaults to `False`):
+ Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
+ should only be set to `True` for repositories you trust and in which you have read the code, as it will
+ execute code present on the Hub on your local machine.
+ kwargs (`dict[str, Any]`, *optional*):
+ The values in kwargs of any keys which are video processor attributes will be used to override the
+ loaded values. Behavior concerning key/value pairs whose keys are *not* video processor attributes is
+ controlled by the `return_unused_kwargs` keyword parameter.
+
+ Passing `token=True` is required when you want to use a private model.
+
+ Examples:
+
+ ```python
+ >>> from transformers import AutoVideoProcessor
+
+ >>> # Download video processor from huggingface.co and cache.
+ >>> video_processor = AutoVideoProcessor.from_pretrained("llava-hf/llava-onevision-qwen2-0.5b-ov-hf")
+
+ >>> # If video processor files are in a directory (e.g. video processor was saved using *save_pretrained('./test/saved_model/')*)
+ >>> # video_processor = AutoVideoProcessor.from_pretrained("./test/saved_model/")
+ ```"""
+ use_auth_token = kwargs.pop("use_auth_token", None)
+ if use_auth_token is not None:
+ warnings.warn(
+ "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
+ FutureWarning,
+ )
+ if kwargs.get("token") is not None:
+ raise ValueError(
+ "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
+ )
+ kwargs["token"] = use_auth_token
+
+ config = kwargs.pop("config", None)
+ trust_remote_code = kwargs.pop("trust_remote_code", None)
+ kwargs["_from_auto"] = True
+
+ config_dict, _ = BaseVideoProcessor.get_video_processor_dict(pretrained_model_name_or_path, **kwargs)
+ video_processor_class = config_dict.get("video_processor_type", None)
+ video_processor_auto_map = None
+ if "AutoVideoProcessor" in config_dict.get("auto_map", {}):
+ video_processor_auto_map = config_dict["auto_map"]["AutoVideoProcessor"]
+
+ # If we still don't have the video processor class, check if we're loading from a previous image processor config
+ # and if so, infer the video processor class from there.
+ if video_processor_class is None and video_processor_auto_map is None:
+ image_processor_class = config_dict.pop("image_processor_type", None)
+ if image_processor_class is not None:
+ video_processor_class_inferred = image_processor_class.replace("ImageProcessor", "VideoProcessor")
+
+ # Some models have different image processors, e.g. InternVL uses GotOCRImageProcessor
+ # We cannot use GotOCRVideoProcessor when falling back for BC and should try to infer from config later on
+ if video_processor_class_inferred in VIDEO_PROCESSOR_MAPPING_NAMES.values():
+ video_processor_class = video_processor_class_inferred
+ if "AutoImageProcessor" in config_dict.get("auto_map", {}):
+ image_processor_auto_map = config_dict["auto_map"]["AutoImageProcessor"]
+ video_processor_auto_map = image_processor_auto_map.replace("ImageProcessor", "VideoProcessor")
+
+ # If we don't find the video processor class in the video processor config, let's try the model config.
+ if video_processor_class is None and video_processor_auto_map is None:
+ if not isinstance(config, PretrainedConfig):
+ config = AutoConfig.from_pretrained(
+ pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
+ )
+            # It could be in `config.video_processor_type`
+ video_processor_class = getattr(config, "video_processor_type", None)
+ if hasattr(config, "auto_map") and "AutoVideoProcessor" in config.auto_map:
+ video_processor_auto_map = config.auto_map["AutoVideoProcessor"]
+
+ if video_processor_class is not None:
+ video_processor_class = video_processor_class_from_name(video_processor_class)
+
+ has_remote_code = video_processor_auto_map is not None
+ has_local_code = video_processor_class is not None or type(config) in VIDEO_PROCESSOR_MAPPING
+ trust_remote_code = resolve_trust_remote_code(
+ trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code
+ )
+
+ if has_remote_code and trust_remote_code:
+ class_ref = video_processor_auto_map
+ video_processor_class = get_class_from_dynamic_module(class_ref, pretrained_model_name_or_path, **kwargs)
+ _ = kwargs.pop("code_revision", None)
+ video_processor_class.register_for_auto_class()
+ return video_processor_class.from_dict(config_dict, **kwargs)
+ elif video_processor_class is not None:
+ return video_processor_class.from_dict(config_dict, **kwargs)
+ # Last try: we use the VIDEO_PROCESSOR_MAPPING.
+ elif type(config) in VIDEO_PROCESSOR_MAPPING:
+ video_processor_class = VIDEO_PROCESSOR_MAPPING[type(config)]
+
+ if video_processor_class is not None:
+ return video_processor_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
+ else:
+ raise ValueError(
+ "This video processor cannot be instantiated. Please make sure you have `torchvision` installed."
+ )
+
+ raise ValueError(
+ f"Unrecognized video processor in {pretrained_model_name_or_path}. Should have a "
+ f"`video_processor_type` key in its {VIDEO_PROCESSOR_NAME} of {CONFIG_NAME}, or one of the following "
+ f"`model_type` keys in its {CONFIG_NAME}: {', '.join(c for c in VIDEO_PROCESSOR_MAPPING_NAMES)}"
+ )
+
+ @staticmethod
+ def register(
+ config_class,
+ video_processor_class,
+ exist_ok=False,
+ ):
+ """
+ Register a new video processor for this class.
+
+ Args:
+ config_class ([`PretrainedConfig`]):
+ The configuration corresponding to the model to register.
+ video_processor_class ([`BaseVideoProcessor`]):
+ The video processor to register.
+ """
+ VIDEO_PROCESSOR_MAPPING.register(config_class, video_processor_class, exist_ok=exist_ok)
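+
+    # Hedged usage sketch (`MyVideoConfig` and `MyVideoProcessor` are hypothetical names): unlike
+    # `AutoTokenizer.register`, which takes a (slow, fast) pair, this mapping is registered with a
+    # single video processor class:
+    #
+    #     AutoVideoProcessor.register(MyVideoConfig, MyVideoProcessor)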
+
+
+__all__ = ["VIDEO_PROCESSOR_MAPPING", "AutoVideoProcessor"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/aya_vision/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/aya_vision/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f8be47cb228b19f02b87d195747e78c2a87de752
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/aya_vision/__init__.py
@@ -0,0 +1,28 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_aya_vision import *
+ from .modeling_aya_vision import *
+ from .processing_aya_vision import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/venv/lib/python3.13/site-packages/transformers/models/aya_vision/configuration_aya_vision.py b/venv/lib/python3.13/site-packages/transformers/models/aya_vision/configuration_aya_vision.py
new file mode 100644
index 0000000000000000000000000000000000000000..a8c1965ec463f5854656f3b5064b23f0d981cacd
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/aya_vision/configuration_aya_vision.py
@@ -0,0 +1,110 @@
+# coding=utf-8
+# Copyright 2025 Cohere team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""AyaVision model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+from ..auto import CONFIG_MAPPING, AutoConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+class AyaVisionConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`AyaVisionForConditionalGeneration`]. It is used to instantiate an
+ AyaVision model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of AyaVision.
+ e.g. [CohereForAI/aya-vision-8b](https://huggingface.co/CohereForAI/aya-vision-8b)
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ vision_config (`Union[AutoConfig, dict]`, *optional*, defaults to `SiglipVisionConfig`):
+ The config object or dictionary of the vision backbone.
+ text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `Cohere2Config`):
+ The config object or dictionary of the text backbone.
+ vision_feature_select_strategy (`str`, *optional*, defaults to `"full"`):
+ The feature selection strategy used to select the vision feature from the vision backbone.
+ Can be one of `"default"` or `"full"`. If `"default"`, the CLS token is removed from the vision features.
+ If `"full"`, the full vision features are used.
+ vision_feature_layer (`int`, *optional*, defaults to -1):
+ The index of the layer to select the vision feature.
+ downsample_factor (`int`, *optional*, defaults to 2):
+ The downsample factor to apply to the vision features.
+ adapter_layer_norm_eps (`float`, *optional*, defaults to 1e-06):
+ The epsilon value used for layer normalization in the adapter.
+ image_token_index (`int`, *optional*, defaults to 255036):
+ The image token index to encode the image prompt.
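+
+    Example (a minimal instantiation sketch using the defaults described above):
+
+    ```python
+    >>> from transformers import AyaVisionConfig, AyaVisionForConditionalGeneration
+
+    >>> # Initializing a config with the default SigLIP vision and Cohere2 text backbones
+    >>> configuration = AyaVisionConfig()
+
+    >>> # Initializing a model (with random weights) from that configuration
+    >>> model = AyaVisionForConditionalGeneration(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```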
+ """
+
+ model_type = "aya_vision"
+ attribute_map = {
+ "image_token_id": "image_token_index",
+ }
+ sub_configs = {"text_config": AutoConfig, "vision_config": AutoConfig}
+
+ def __init__(
+ self,
+ vision_config=None,
+ text_config=None,
+ vision_feature_select_strategy="full",
+ vision_feature_layer=-1,
+ downsample_factor=2,
+ adapter_layer_norm_eps=1e-6,
+ image_token_index=255036,
+ **kwargs,
+ ):
+ self.image_token_index = image_token_index
+ self.downsample_factor = downsample_factor
+ self.adapter_layer_norm_eps = adapter_layer_norm_eps
+ if vision_feature_select_strategy not in ["default", "full"]:
+            raise ValueError(
+                "vision_feature_select_strategy should be one of 'default', 'full'. "
+                f"Got: {vision_feature_select_strategy}"
+            )
+
+ self.vision_feature_select_strategy = vision_feature_select_strategy
+ self.vision_feature_layer = vision_feature_layer
+
+ if isinstance(vision_config, dict):
+ vision_config["model_type"] = vision_config.get("model_type", "siglip_vision_model")
+ vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config)
+ elif vision_config is None:
+ vision_config = CONFIG_MAPPING["siglip_vision_model"](
+ hidden_size=1152,
+ intermediate_size=4304,
+ patch_size=14,
+ image_size=384,
+ num_hidden_layers=26,
+ num_attention_heads=14,
+ vision_use_head=False,
+ )
+
+ self.vision_config = vision_config
+
+ if isinstance(text_config, dict):
+ text_config["model_type"] = text_config.get("model_type", "cohere2")
+ text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
+ elif text_config is None:
+ text_config = CONFIG_MAPPING["cohere2"]()
+
+ self.text_config = text_config
+
+ super().__init__(**kwargs)
+
+
+__all__ = ["AyaVisionConfig"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/aya_vision/modeling_aya_vision.py b/venv/lib/python3.13/site-packages/transformers/models/aya_vision/modeling_aya_vision.py
new file mode 100644
index 0000000000000000000000000000000000000000..bccbea0264b766811a43e0cfe05f1987536ad12b
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/aya_vision/modeling_aya_vision.py
@@ -0,0 +1,518 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/aya_vision/modular_aya_vision.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_aya_vision.py file directly. One of our CI checks enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2025 the Cohere Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import Optional, Union
+
+import torch
+from torch import nn
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache
+from ...generation import GenerationMixin
+from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput
+from ...modeling_utils import PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
+from ...utils.generic import check_model_inputs
+from ..auto import AutoModel
+from .configuration_aya_vision import AyaVisionConfig
+
+
+class AyaVisionMultiModalProjector(nn.Module):
+ def __init__(self, config: AyaVisionConfig):
+ super().__init__()
+ self.config = config
+ self.downsample_factor = config.downsample_factor
+ self.alignment_intermediate_size = getattr(
+ config, "alignment_intermediate_size", config.text_config.hidden_size
+ )
+ self.layernorm = nn.LayerNorm(
+ config.vision_config.hidden_size * (config.downsample_factor**2), eps=config.adapter_layer_norm_eps
+ )
+
+ self.linear_1 = nn.Linear(
+ config.vision_config.hidden_size * (config.downsample_factor**2),
+ self.alignment_intermediate_size,
+ bias=True,
+ )
+
+ self.act = ACT2FN["silu"] # SwiGLU uses SiLU activation
+ # For SwiGLU, project down to half size since we split intermediate dim
+ self.linear_2 = nn.Linear(self.alignment_intermediate_size // 2, config.text_config.hidden_size, bias=True)
+
+ def forward(self, image_features):
+ image_features = self.pixel_shuffle(image_features)
+ image_features = self.layernorm(image_features)
+ hidden_states = self.linear_1(image_features)
+
+ # Split along last dimension and apply SwiGLU
+ x, gate = hidden_states.chunk(2, dim=-1)
+ hidden_states = self.act(gate) * x
+
+ hidden_states = self.linear_2(hidden_states)
+ return hidden_states
+
+ def pixel_shuffle(self, image_features): # B, S, D
+ batch_size, seq_length, feature_dim = image_features.shape
+ height = width = int(seq_length**0.5)
+ image_features = image_features.reshape(image_features.shape[0], width, height, -1)
+ channels = image_features.shape[-1]
+ image_features = image_features.reshape(
+ batch_size, width, int(height / self.downsample_factor), int(channels * self.downsample_factor)
+ )
+ image_features = image_features.permute(0, 2, 1, 3)
+ image_features = image_features.reshape(
+ batch_size, int(height / self.downsample_factor), int(width / self.downsample_factor), -1
+ )
+ image_features = image_features.permute(0, 2, 1, 3)
+ return image_features
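+
+    # Shape-trace sketch for `pixel_shuffle` (illustrative values, assuming a square grid of
+    # vision tokens with hidden_size=1152 and downsample_factor=2):
+    #   input    (B, 576, 1152)       # 576 = 24 * 24 patches
+    #   reshape  (B, 24, 24, 1152)
+    #   reshape  (B, 24, 12, 2304)    # fold the factor into the height/channel dims
+    #   permute  (B, 12, 24, 2304)
+    #   reshape  (B, 12, 12, 4608)    # 4608 = 1152 * 2**2, the LayerNorm input size
+    #   permute  (B, 12, 12, 4608)    # final: (B, H/2, W/2, D * 4)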
+
+
+@auto_docstring
+class AyaVisionPreTrainedModel(PreTrainedModel):
+ config: AyaVisionConfig
+ base_model_prefix = ""
+ supports_gradient_checkpointing = True
+ _skip_keys_device_placement = "past_key_values"
+
+ _supports_flash_attn = True
+ _supports_sdpa = True
+ _can_compile_fullgraph = False
+ _supports_flex_attn = True
+ _supports_attention_backend = True
+ _can_record_outputs = {
+ "hidden_states": "DecoderLayer",
+ "attentions": "Attention",
+ }
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for AyaVision causal language model (or autoregressive) outputs.
+ """
+)
+class AyaVisionCausalLMOutputWithPast(ModelOutput):
+ r"""
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Language modeling loss (for next-token prediction).
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ image_hidden_states (`torch.FloatTensor`, *optional*):
+ A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
+ image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: Optional[torch.FloatTensor] = None
+ past_key_values: Optional[Cache] = None
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
+ attentions: Optional[tuple[torch.FloatTensor]] = None
+ image_hidden_states: Optional[torch.FloatTensor] = None
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for AyaVision outputs, with hidden states and attentions.
+ """
+)
+class AyaVisionModelOutputWithPast(BaseModelOutputWithPast):
+ r"""
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ image_hidden_states (`torch.FloatTensor`, *optional*):
+ A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
+ image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
+ """
+
+ image_hidden_states: Optional[torch.FloatTensor] = None
+
+
+@auto_docstring(
+ custom_intro="""
+ The AyaVision model which consists of a vision backbone and a language model, without a language modeling head.
+ """
+)
+class AyaVisionModel(AyaVisionPreTrainedModel):
+ _checkpoint_conversion_mapping = {"language_model.model": "language_model"}
+
+ def __init__(self, config: AyaVisionConfig):
+ super().__init__(config)
+ self.vision_tower = AutoModel.from_config(config.vision_config)
+
+ self.multi_modal_projector = AyaVisionMultiModalProjector(config)
+ self.language_model = AutoModel.from_config(config.text_config)
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.language_model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.language_model.set_input_embeddings(value)
+
+ def set_decoder(self, decoder):
+ self.language_model = decoder
+
+ def get_decoder(self):
+ return self.language_model
+
+ def get_image_features(
+ self,
+ pixel_values: torch.FloatTensor,
+ vision_feature_layer: Optional[Union[int, list[int]]] = None,
+ vision_feature_select_strategy: Optional[str] = None,
+ **kwargs,
+ ):
+ """
+ Obtains image last hidden states from the vision tower and apply multimodal projection.
+
+ Args:
+            pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`):
+ The tensors corresponding to the input images.
+ vision_feature_layer (`Union[int, list[int]]`, *optional*):
+ The index of the layer to select the vision feature. If multiple indices are provided,
+ the vision feature of the corresponding indices will be concatenated to form the
+ vision features.
+ vision_feature_select_strategy (`str`, *optional*):
+ The feature selection strategy used to select the vision feature from the vision backbone.
+ Can be one of `"default"` or `"full"`
+ Returns:
+            image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`.
+ """
+ vision_feature_layer = (
+ vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
+ )
+ vision_feature_select_strategy = (
+ vision_feature_select_strategy
+ if vision_feature_select_strategy is not None
+ else self.config.vision_feature_select_strategy
+ )
+
+ if vision_feature_select_strategy not in ["default", "full"]:
+            raise ValueError(f"Unexpected vision feature select strategy: {vision_feature_select_strategy}")
+
+ kwargs = {k: v for k, v in kwargs.items() if v is not None}
+        # This is not memory efficient at all: `output_hidden_states=True` will save all the hidden states.
+ image_outputs = self.vision_tower(pixel_values, output_hidden_states=True, **kwargs)
+
+ # If we have one vision feature layer, return the corresponding hidden states,
+ # otherwise, select the hidden states of each feature layer and concatenate them
+ if isinstance(vision_feature_layer, int):
+ selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
+ if vision_feature_select_strategy == "default":
+ selected_image_feature = selected_image_feature[:, 1:]
+ else:
+ hs_pool = [image_outputs.hidden_states[layer_idx] for layer_idx in vision_feature_layer]
+            # For the "default" strategy, crop the CLS token from each hidden state in the pool
+ if vision_feature_select_strategy == "default":
+ hs_pool = [hs[:, 1:] for hs in hs_pool]
+ selected_image_feature = torch.cat(hs_pool, dim=-1)
+
+ image_features = self.multi_modal_projector(selected_image_feature)
+ return image_features
+
+ def get_placeholder_mask(
+ self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor
+ ):
+ """
+ Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
+ equal to the length of multimodal features. If the lengths are different, an error is raised.
+ """
+ if input_ids is None:
+ special_image_mask = inputs_embeds == self.get_input_embeddings()(
+ torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ special_image_mask = special_image_mask.all(-1)
+ else:
+ special_image_mask = input_ids == self.config.image_token_id
+
+ n_image_tokens = special_image_mask.sum()
+ special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+ n_image_features = image_features.shape[0] * image_features.shape[1]
+ if inputs_embeds[special_image_mask].numel() != image_features.numel():
+ raise ValueError(
+ f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+ )
+ return special_image_mask
+
+ @check_model_inputs()
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ vision_feature_layer: Optional[Union[int, list[int]]] = None,
+ vision_feature_select_strategy: Optional[str] = None,
+ use_cache: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> Union[tuple, AyaVisionModelOutputWithPast]:
+ vision_feature_layer = (
+ vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
+ )
+ vision_feature_select_strategy = (
+ vision_feature_select_strategy
+ if vision_feature_select_strategy is not None
+ else self.config.vision_feature_select_strategy
+ )
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if inputs_embeds is None:
+ inputs_embeds = self.get_input_embeddings()(input_ids)
+
+ if pixel_values is not None:
+ image_features = self.get_image_features(
+ pixel_values=pixel_values,
+ vision_feature_layer=vision_feature_layer,
+ vision_feature_select_strategy=vision_feature_select_strategy,
+ )
+ image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+ special_image_mask = self.get_placeholder_mask(
+ input_ids, inputs_embeds=inputs_embeds, image_features=image_features
+ )
+ inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
+
+ outputs = self.language_model(
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ return AyaVisionModelOutputWithPast(
+ last_hidden_state=outputs.last_hidden_state,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ image_hidden_states=image_features if pixel_values is not None else None,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ The AYA_VISION model which consists of a vision backbone and a language model.
+ """
+)
+class AyaVisionForConditionalGeneration(AyaVisionPreTrainedModel, GenerationMixin):
+ _checkpoint_conversion_mapping = {
+ "^language_model.model": "model.language_model",
+ "^vision_tower": "model.vision_tower",
+ "^multi_modal_projector": "model.multi_modal_projector",
+ "^language_model.lm_head": "lm_head",
+ }
+ _tied_weights_keys = ["lm_head.weight"]
+
+ def __init__(self, config: AyaVisionConfig):
+ super().__init__(config)
+ self.model = AyaVisionModel(config)
+ self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.model.set_input_embeddings(value)
+
+ def get_output_embeddings(self) -> nn.Module:
+ return self.lm_head
+
+ def set_decoder(self, decoder):
+ self.model.set_decoder(decoder)
+
+ def get_decoder(self):
+ return self.model.get_decoder()
+
+ def get_image_features(
+ self,
+ pixel_values: torch.FloatTensor,
+ vision_feature_layer: Optional[Union[int, list[int]]] = None,
+ vision_feature_select_strategy: Optional[str] = None,
+ **kwargs,
+ ):
+ return self.model.get_image_features(
+ pixel_values=pixel_values,
+ vision_feature_layer=vision_feature_layer,
+ vision_feature_select_strategy=vision_feature_select_strategy,
+ **kwargs,
+ )
+
+ # Make modules available through conditional class for BC
+ @property
+ def language_model(self):
+ return self.model.language_model
+
+ @property
+ def vision_tower(self):
+ return self.model.vision_tower
+
+ @property
+ def multi_modal_projector(self):
+ return self.model.multi_modal_projector
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ vision_feature_layer: Optional[Union[int, list[int]]] = None,
+ vision_feature_select_strategy: Optional[str] = None,
+ labels: Optional[torch.LongTensor] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ image_sizes: Optional[torch.Tensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> Union[tuple, AyaVisionCausalLMOutputWithPast]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoProcessor, AyaVisionForConditionalGeneration
+ >>> import torch
+
+ >>> torch_device = "cuda:0"
+ >>> processor = AutoProcessor.from_pretrained("CohereForAI/aya-vision-8b", use_fast=True)
+ >>> model = AyaVisionForConditionalGeneration.from_pretrained("CohereForAI/aya-vision-8b", device_map=torch_device)
+
+ >>> messages = [
+ ... {
+ ... "role": "user",
+ ... "content": [
+ ... {
+ ... "type": "image",
+ ... "url": "https://pbs.twimg.com/media/Fx7YvfQWYAIp6rZ?format=jpg&name=medium",
+ ... },
+ ... {"type": "text", "text": "चित्र में लिखा पाठ क्या कहता है?"},
+ ... ],
+ ... }
+ ... ]
+
+ >>> inputs = processor.apply_chat_template(
+ ... messages, padding=True, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt", device=torch_device
+ ... ).to(model.device)
+
+ >>> gen_tokens = model.generate(**inputs, max_new_tokens=300, do_sample=True, temperature=0.3)
+ >>> processor.tokenizer.decode(gen_tokens[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
+ ```"""
+ vision_feature_layer = (
+ vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
+ )
+ vision_feature_select_strategy = (
+ vision_feature_select_strategy
+ if vision_feature_select_strategy is not None
+ else self.config.vision_feature_select_strategy
+ )
+
+ outputs = self.model(
+ input_ids=input_ids,
+ pixel_values=pixel_values,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ vision_feature_layer=vision_feature_layer,
+ vision_feature_select_strategy=vision_feature_select_strategy,
+ cache_position=cache_position,
+ image_sizes=image_sizes,
+ **kwargs,
+ )
+
+ hidden_states = outputs[0]
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(
+ logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
+ )
+
+ return AyaVisionCausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ image_hidden_states=outputs.image_hidden_states,
+ )
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids,
+ past_key_values=None,
+ inputs_embeds=None,
+ pixel_values=None,
+ attention_mask=None,
+ cache_position=None,
+ logits_to_keep=None,
+ **kwargs,
+ ):
+ # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
+
+ model_inputs = super().prepare_inputs_for_generation(
+ input_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ cache_position=cache_position,
+ logits_to_keep=logits_to_keep,
+ **kwargs,
+ )
+
+ if cache_position[0] == 0:
+            # Pixel values are only forwarded on the prefill step (cache_position[0] == 0):
+            # during cached decoding the input ids no longer contain the special image token,
+            # so pixel values are not passed to the model again.
+ model_inputs["pixel_values"] = pixel_values
+
+ return model_inputs
+
+
+__all__ = ["AyaVisionForConditionalGeneration", "AyaVisionPreTrainedModel", "AyaVisionModel"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/aya_vision/modular_aya_vision.py b/venv/lib/python3.13/site-packages/transformers/models/aya_vision/modular_aya_vision.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b34aa5617bfedf92fa9082992784b6861a45940
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/aya_vision/modular_aya_vision.py
@@ -0,0 +1,297 @@
+# coding=utf-8
+# Copyright 2025 the Cohere Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch AyaVision model."""
+
+from typing import Optional, Union
+
+import torch
+from torch import nn
+
+from transformers.models.llava.modeling_llava import (
+ LlavaCausalLMOutputWithPast,
+ LlavaForConditionalGeneration,
+ LlavaModel,
+ LlavaModelOutputWithPast,
+ LlavaPreTrainedModel,
+ TransformersKwargs,
+)
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache
+from ...processing_utils import Unpack
+from ...utils import auto_docstring, logging
+from ...utils.generic import check_model_inputs
+from .configuration_aya_vision import AyaVisionConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+class AyaVisionMultiModalProjector(nn.Module):
+ def __init__(self, config: AyaVisionConfig):
+ super().__init__()
+ self.config = config
+ self.downsample_factor = config.downsample_factor
+ self.alignment_intermediate_size = getattr(
+ config, "alignment_intermediate_size", config.text_config.hidden_size
+ )
+ self.layernorm = nn.LayerNorm(
+ config.vision_config.hidden_size * (config.downsample_factor**2), eps=config.adapter_layer_norm_eps
+ )
+
+ self.linear_1 = nn.Linear(
+ config.vision_config.hidden_size * (config.downsample_factor**2),
+ self.alignment_intermediate_size,
+ bias=True,
+ )
+
+ self.act = ACT2FN["silu"] # SwiGLU uses SiLU activation
+ # For SwiGLU, project down to half size since we split intermediate dim
+ self.linear_2 = nn.Linear(self.alignment_intermediate_size // 2, config.text_config.hidden_size, bias=True)
+
+ def forward(self, image_features):
+ image_features = self.pixel_shuffle(image_features)
+ image_features = self.layernorm(image_features)
+ hidden_states = self.linear_1(image_features)
+
+ # Split along last dimension and apply SwiGLU
+ x, gate = hidden_states.chunk(2, dim=-1)
+ hidden_states = self.act(gate) * x
+
+ hidden_states = self.linear_2(hidden_states)
+ return hidden_states
+
+ def pixel_shuffle(self, image_features): # B, S, D
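+        # Fold each downsample_factor x downsample_factor block of neighbouring patches into the
+        # channel dimension, reducing the number of image tokens by downsample_factor ** 2.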
+ batch_size, seq_length, feature_dim = image_features.shape
+ height = width = int(seq_length**0.5)
+ image_features = image_features.reshape(image_features.shape[0], width, height, -1)
+ channels = image_features.shape[-1]
+ image_features = image_features.reshape(
+ batch_size, width, int(height / self.downsample_factor), int(channels * self.downsample_factor)
+ )
+ image_features = image_features.permute(0, 2, 1, 3)
+ image_features = image_features.reshape(
+ batch_size, int(height / self.downsample_factor), int(width / self.downsample_factor), -1
+ )
+ image_features = image_features.permute(0, 2, 1, 3)
+ return image_features
+
+
+class AyaVisionPreTrainedModel(LlavaPreTrainedModel):
+ _can_compile_fullgraph = False
+ _can_record_outputs = {
+ "hidden_states": "DecoderLayer",
+ "attentions": "Attention",
+ }
+
+
+class AyaVisionCausalLMOutputWithPast(LlavaCausalLMOutputWithPast):
+ pass
+
+
+class AyaVisionModelOutputWithPast(LlavaModelOutputWithPast):
+ pass
+
+
+class AyaVisionModel(LlavaModel):
+ # Unlike LLaVA, the model doesn't have to deal with Pixtral-style image states
+ def get_image_features(
+ self,
+ pixel_values: torch.FloatTensor,
+ vision_feature_layer: Optional[Union[int, list[int]]] = None,
+ vision_feature_select_strategy: Optional[str] = None,
+ **kwargs,
+ ):
+ """
+ Obtains image last hidden states from the vision tower and apply multimodal projection.
+
+ Args:
+            pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`):
+ The tensors corresponding to the input images.
+ vision_feature_layer (`Union[int, list[int]]`, *optional*):
+ The index of the layer to select the vision feature. If multiple indices are provided,
+ the vision feature of the corresponding indices will be concatenated to form the
+ vision features.
+ vision_feature_select_strategy (`str`, *optional*):
+ The feature selection strategy used to select the vision feature from the vision backbone.
+ Can be one of `"default"` or `"full"`
+ Returns:
+            image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`.
+ """
+ vision_feature_layer = (
+ vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
+ )
+ vision_feature_select_strategy = (
+ vision_feature_select_strategy
+ if vision_feature_select_strategy is not None
+ else self.config.vision_feature_select_strategy
+ )
+
+ if vision_feature_select_strategy not in ["default", "full"]:
+ raise ValueError(f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}")
+
+ kwargs = {k: v for k, v in kwargs.items() if v is not None}
+        # This is not memory efficient: output_hidden_states=True keeps every hidden state from the vision tower.
+ image_outputs = self.vision_tower(pixel_values, output_hidden_states=True, **kwargs)
+
+ # If we have one vision feature layer, return the corresponding hidden states,
+ # otherwise, select the hidden states of each feature layer and concatenate them
+ if isinstance(vision_feature_layer, int):
+ selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
+ if vision_feature_select_strategy == "default":
+ selected_image_feature = selected_image_feature[:, 1:]
+ else:
+ hs_pool = [image_outputs.hidden_states[layer_idx] for layer_idx in vision_feature_layer]
+            # For the "default" strategy, crop the CLS token from each hidden state in the pool
+ if vision_feature_select_strategy == "default":
+ hs_pool = [hs[:, 1:] for hs in hs_pool]
+ selected_image_feature = torch.cat(hs_pool, dim=-1)
+
+ image_features = self.multi_modal_projector(selected_image_feature)
+ return image_features
+
+ @check_model_inputs()
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ vision_feature_layer: Optional[Union[int, list[int]]] = None,
+ vision_feature_select_strategy: Optional[str] = None,
+ use_cache: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> Union[tuple, AyaVisionModelOutputWithPast]:
+ vision_feature_layer = (
+ vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
+ )
+ vision_feature_select_strategy = (
+ vision_feature_select_strategy
+ if vision_feature_select_strategy is not None
+ else self.config.vision_feature_select_strategy
+ )
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if inputs_embeds is None:
+ inputs_embeds = self.get_input_embeddings()(input_ids)
+
+ if pixel_values is not None:
+ image_features = self.get_image_features(
+ pixel_values=pixel_values,
+ vision_feature_layer=vision_feature_layer,
+ vision_feature_select_strategy=vision_feature_select_strategy,
+ )
+ image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+ special_image_mask = self.get_placeholder_mask(
+ input_ids, inputs_embeds=inputs_embeds, image_features=image_features
+ )
+ inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
+
+ outputs = self.language_model(
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ return AyaVisionModelOutputWithPast(
+ last_hidden_state=outputs.last_hidden_state,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ image_hidden_states=image_features if pixel_values is not None else None,
+ )
+
+
+class AyaVisionForConditionalGeneration(LlavaForConditionalGeneration):
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ vision_feature_layer: Optional[Union[int, list[int]]] = None,
+ vision_feature_select_strategy: Optional[str] = None,
+ labels: Optional[torch.LongTensor] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ image_sizes: Optional[torch.Tensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> Union[tuple, AyaVisionCausalLMOutputWithPast]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoProcessor, AyaVisionForConditionalGeneration
+ >>> import torch
+
+ >>> torch_device = "cuda:0"
+ >>> processor = AutoProcessor.from_pretrained("CohereForAI/aya-vision-8b", use_fast=True)
+ >>> model = AyaVisionForConditionalGeneration.from_pretrained("CohereForAI/aya-vision-8b", device_map=torch_device)
+
+ >>> messages = [
+ ... {
+ ... "role": "user",
+ ... "content": [
+ ... {
+ ... "type": "image",
+ ... "url": "https://pbs.twimg.com/media/Fx7YvfQWYAIp6rZ?format=jpg&name=medium",
+ ... },
+ ... {"type": "text", "text": "चित्र में लिखा पाठ क्या कहता है?"},
+ ... ],
+ ... }
+ ... ]
+
+ >>> inputs = processor.apply_chat_template(
+ ... messages, padding=True, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt", device=torch_device
+ ... ).to(model.device)
+
+ >>> gen_tokens = model.generate(**inputs, max_new_tokens=300, do_sample=True, temperature=0.3)
+ >>> processor.tokenizer.decode(gen_tokens[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
+ ```"""
+ super().forward(
+ input_ids=input_ids,
+ pixel_values=pixel_values,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ vision_feature_layer=vision_feature_layer,
+ vision_feature_select_strategy=vision_feature_select_strategy,
+ labels=labels,
+ cache_position=cache_position,
+ logits_to_keep=logits_to_keep,
+ image_sizes=image_sizes,
+ **kwargs,
+ )
+
+
+__all__ = ["AyaVisionForConditionalGeneration", "AyaVisionPreTrainedModel", "AyaVisionModel"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/aya_vision/processing_aya_vision.py b/venv/lib/python3.13/site-packages/transformers/models/aya_vision/processing_aya_vision.py
new file mode 100644
index 0000000000000000000000000000000000000000..7045c967046da163fa2d7daef04f5c6e6bebc794
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/aya_vision/processing_aya_vision.py
@@ -0,0 +1,257 @@
+# coding=utf-8
+# Copyright 2025 HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Union
+
+import numpy as np
+
+from ...image_processing_utils import BatchFeature
+from ...image_utils import ImageInput, make_flat_list_of_images
+from ...processing_utils import ImagesKwargs, MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack
+from ...tokenization_utils_base import PreTokenizedInput, TextInput
+
+
+class AyaVisionImagesKwargs(ImagesKwargs, total=False):
+ crop_to_patches: Optional[bool]
+ min_patches: Optional[int]
+ max_patches: Optional[int]
+
+
+class AyaVisionProcessorKwargs(ProcessingKwargs, total=False):
+ images_kwargs: AyaVisionImagesKwargs
+ _defaults = {
+ "text_kwargs": {
+ "padding_side": "left",
+ "padding": True,
+ "return_mm_token_type_ids": False,
+ },
+ "images_kwargs": {
+ "crop_to_patches": True,
+ },
+ }
+
+
+class AyaVisionProcessor(ProcessorMixin):
+ r"""
+    Constructs an AyaVision processor which wraps an [`AutoImageProcessor`] and
+ [`PretrainedTokenizerFast`] tokenizer into a single processor that inherits both the image processor and
+ tokenizer functionalities. See the [`~AyaVisionProcessor.__call__`] and [`~AyaVisionProcessor.decode`] for more information.
+ Args:
+ image_processor ([`AutoImageProcessor`], *optional*):
+ The image processor is a required input.
+ tokenizer ([`PreTrainedTokenizer`, `PreTrainedTokenizerFast`], *optional*):
+ The tokenizer is a required input.
+ patch_size (`int`, *optional*, defaults to 28):
+ The size of image patches for tokenization.
+ img_size (`int`, *optional*, defaults to 364):
+ The size of the image to be tokenized. This should correspond to the size given to the image processor.
+        image_token (`str`, *optional*, defaults to `"<image>"`):
+ The token to be used to represent an image in the text.
+ downsample_factor (`int`, *optional*, defaults to 1):
+ The factor by which to scale the patch size.
+ start_of_img_token (`str`, *optional*, defaults to `"<|START_OF_IMG|>"`):
+ The token to be used to represent the start of an image in the text.
+ end_of_img_token (`str`, *optional*, defaults to `"<|END_OF_IMG|>"`):
+ The token to be used to represent the end of an image in the text.
+ img_patch_token (`str`, *optional*, defaults to `"<|IMG_PATCH|>"`):
+ The token to be used to represent an image patch in the text.
+ img_line_break_token (`str`, *optional*, defaults to `"<|IMG_LINE_BREAK|>"`):
+ The token to be used to represent a line break in the text.
+ tile_token (`str`, *optional*, defaults to `"TILE"`):
+            The token prefix used to mark each image tile in the text.
+ tile_global_token (`str`, *optional*, defaults to `"TILE_GLOBAL"`):
+ The token to be used to represent the cover image in the text.
+ chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
+ in a chat into a tokenizable string.
+ """
+
+ attributes = ["image_processor", "tokenizer"]
+ image_processor_class = "AutoImageProcessor"
+ tokenizer_class = "AutoTokenizer"
+
+ def __init__(
+ self,
+ image_processor=None,
+ tokenizer=None,
+ patch_size: int = 28,
+ img_size: int = 364,
+ image_token="", # set the default and let users change if they have peculiar special tokens in rare cases
+ downsample_factor: int = 1,
+ start_of_img_token="<|START_OF_IMG|>",
+ end_of_img_token="<|END_OF_IMG|>",
+ img_patch_token="<|IMG_PATCH|>",
+ img_line_break_token="<|IMG_LINE_BREAK|>",
+ tile_token="TILE",
+ tile_global_token="TILE_GLOBAL",
+ chat_template=None,
+ **kwargs,
+ ):
+ super().__init__(image_processor, tokenizer, chat_template=chat_template)
+
+ self.image_token = image_token
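+        # The projector's pixel shuffle merges downsample_factor x downsample_factor patches into one
+        # token, so the effective patch size seen by the text sequence is scaled accordingly.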
+ self.patch_size = patch_size * downsample_factor
+ self.img_size = img_size
+
+ self.start_of_img_token = start_of_img_token
+ self.end_of_img_token = end_of_img_token
+ self.img_patch_token = img_patch_token
+ self.img_line_break_token = img_line_break_token
+ self.tile_token = tile_token
+ self.tile_global_token = tile_global_token
+ self.image_token_id = tokenizer.convert_tokens_to_ids(self.img_patch_token)
+ self.image_ids = tokenizer.convert_tokens_to_ids(
+ [img_patch_token, tile_token, tile_global_token, start_of_img_token, end_of_img_token]
+ )
+
+ def _prompt_split_image(self, num_patches):
+ """
+ Create a structured string representation of image tokens
+
+ Args:
+ num_patches: Number of patches in the image
+
+ Returns:
+ String with appropriate image tokens
+ """
+
+ img_patches_per_tile = (self.img_size // self.patch_size) ** 2
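+        # Each tile is encoded as (img_size / patch_size) ** 2 patch placeholder tokens; local tiles are
+        # prefixed with TILE_<idx> markers, the global tile with TILE_GLOBAL, all wrapped in BOI/EOI tokens.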
+ img_string = f"{self.start_of_img_token}"
+ if num_patches > 1:
+ for idx in range(1, num_patches):
+ img_string += f"{self.tile_token}_{idx}" + f"{self.img_patch_token}" * img_patches_per_tile
+
+ img_string += f"{self.tile_global_token}" + f"{self.img_patch_token}" * img_patches_per_tile
+ img_string += f"{self.end_of_img_token}"
+ return img_string
+
+ def __call__(
+ self,
+ images: Optional[ImageInput] = None,
+ text: Optional[Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]]] = None,
+ audio=None,
+ videos=None,
+ **kwargs: Unpack[AyaVisionProcessorKwargs],
+ ) -> BatchFeature:
+ """
+        Main method to prepare one or several sequence(s) and image(s) for the model. This method forwards the `text`
+ and `kwargs` arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizerFast.__call__`] to encode the text.
+ To prepare the vision inputs, this method forwards the `images` and `kwargs` arguments to
+ GotOcr2ImageProcessor's [`~GotOcr2ImageProcessor.__call__`] if `images` is not `None`.
+
+ Args:
+ images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`, `list[torch.Tensor]`):
+ The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
+ tensor. Both channels-first and channels-last formats are supported.
+ text (`str`, `list[str]`, `list[list[str]]`):
+ The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
+ (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
+ `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
+ If set, will return tensors of a particular framework. Acceptable values are:
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
+ - `'np'`: Return NumPy `np.ndarray` objects.
+ - `'jax'`: Return JAX `jnp.ndarray` objects.
+
+ Returns:
+ [`BatchFeature`]: A [`BatchFeature`] with the following fields:
+
+ - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
+ - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
+ `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
+ `None`).
+ - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
+ """
+ if text is None:
+ raise ValueError("You have to specify text.")
+
+ output_kwargs = self._merge_kwargs(
+ AyaVisionProcessorKwargs,
+ tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+ **kwargs,
+ )
+
+ if not isinstance(text, (list, tuple)):
+ text = [text]
+
+ # Process images
+ image_inputs = {}
+ if images is not None:
+ images = self.image_processor.fetch_images(images)
+ images = make_flat_list_of_images(images)
+ image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"])
+ num_patches = image_inputs.pop("num_patches")
+ image_index = 0
+ processed_text = []
+ for prompt in text:
+ new_prompt = prompt
+ while "" in new_prompt:
+ # Replace the image placeholder with structured image tokens
+ image_tokens = self._prompt_split_image(num_patches[image_index])
+ new_prompt = new_prompt.replace("", image_tokens, 1)
+ image_index += 1
+ processed_text.append(new_prompt)
+
+ if image_index != len(images):
+ raise ValueError("Number of image placeholders in the prompt does not match the number of images.")
+
+ text = processed_text
+
+ return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
+ return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False)
+ text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"], return_tensors=None)
+
+ if return_mm_token_type_ids:
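+            # Token type 1 marks positions whose ids belong to the image special tokens
+            # (patch, tile and image boundary markers).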
+ array_ids = np.array(text_inputs["input_ids"])
+ mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
+ mm_token_type_ids[np.isin(array_ids, self.image_ids)] = 1
+ text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()
+
+ return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors)
+
+ def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
+ """
+ Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
+
+ Args:
+ image_sizes (`list[list[int]]`, *optional*):
+ The input sizes formatted as (height, width) per each image.
+
+ Returns:
+ `MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
+ input modalities, along with other useful data.
+ """
+
+ vision_data = {}
+ if image_sizes is not None:
+            # Copy the defaults so the shared class-level dict is not mutated in place
+            images_kwargs = dict(AyaVisionProcessorKwargs._defaults.get("images_kwargs", {}))
+            images_kwargs.update(kwargs)
+
+ num_image_patches = [
+ self.image_processor.get_number_of_image_patches(*image_size, images_kwargs)
+ for image_size in image_sizes
+ ]
+
+ token_per_patch = (self.img_size // self.patch_size) ** 2
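+            # The global tile contributes `token_per_patch` patch tokens plus 3 markers (BOI, TILE_GLOBAL, EOI);
+            # every additional tile contributes its patch tokens plus one TILE_<idx> marker.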
+ num_image_tokens = [
+ token_per_patch + 3 + sum(token_per_patch + 1 for _ in range(1, num_patches))
+ for num_patches in num_image_patches
+ ] # Add +3 and +1 for BOI/EOI and image tile tokens
+ vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
+
+ return MultiModalData(**vision_data)
+
+
+__all__ = ["AyaVisionProcessor"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/barthez/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/barthez/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..323fe2fe8af9823d4478957b2f94b078ec39b7f3
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/barthez/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .tokenization_barthez import *
+ from .tokenization_barthez_fast import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/venv/lib/python3.13/site-packages/transformers/models/barthez/tokenization_barthez.py b/venv/lib/python3.13/site-packages/transformers/models/barthez/tokenization_barthez.py
new file mode 100644
index 0000000000000000000000000000000000000000..bc583e0cd5dc455fbd91841de386e6bf54ad02b9
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/barthez/tokenization_barthez.py
@@ -0,0 +1,291 @@
+# coding=utf-8
+# Copyright 2020 Ecole Polytechnique and the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License
+"""Tokenization classes for the BARThez model."""
+
+import os
+from shutil import copyfile
+from typing import Any, Optional
+
+import sentencepiece as spm
+
+from ...tokenization_utils import AddedToken, PreTrainedTokenizer
+from ...utils import logging
+from ...utils.import_utils import requires
+
+
+logger = logging.get_logger(__name__)
+
+VOCAB_FILES_NAMES = {"vocab_file": "sentencepiece.bpe.model"}
+
+
+SPIECE_UNDERLINE = "▁"
+
+# TODO this class is useless. This is the most standard sentencepiece model. Let's find which one is closest and nuke this.
+
+
+@requires(backends=("sentencepiece",))
+class BarthezTokenizer(PreTrainedTokenizer):
+ """
+ Adapted from [`CamembertTokenizer`] and [`BartTokenizer`]. Construct a BARThez tokenizer. Based on
+ [SentencePiece](https://github.com/google/sentencepiece).
+
+ This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
+ this superclass for more information regarding those methods.
+
+ Args:
+ vocab_file (`str`):
+ [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that
+ contains the vocabulary necessary to instantiate a tokenizer.
+        bos_token (`str`, *optional*, defaults to `"<s>"`):
+            The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier token.
+
+            <Tip>
+
+            When building a sequence using special tokens, this is not the token that is used for the beginning of
+            sequence. The token used is the `cls_token`.
+
+            </Tip>
+
+        eos_token (`str`, *optional*, defaults to `"</s>"`):
+            The end of sequence token.
+
+            <Tip>
+
+            When building a sequence using special tokens, this is not the token that is used for the end of sequence.
+            The token used is the `sep_token`.
+
+            </Tip>
+
+        sep_token (`str`, *optional*, defaults to `"</s>"`):
+ The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
+ sequence classification or for a text and a question for question answering. It is also used as the last
+ token of a sequence built with special tokens.
+        cls_token (`str`, *optional*, defaults to `"<s>"`):
+ The classifier token which is used when doing sequence classification (classification of the whole sequence
+ instead of per-token classification). It is the first token of the sequence when built with special tokens.
+        unk_token (`str`, *optional*, defaults to `"<unk>"`):
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+ token instead.
+        pad_token (`str`, *optional*, defaults to `"<pad>"`):
+ The token used for padding, for example when batching sequences of different lengths.
+        mask_token (`str`, *optional*, defaults to `"<mask>"`):
+ The token used for masking values. This is the token used when training this model with masked language
+ modeling. This is the token which the model will try to predict.
+ sp_model_kwargs (`dict`, *optional*):
+ Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
+ SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
+ to set:
+
+ - `enable_sampling`: Enable subword regularization.
+ - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.
+
+ - `nbest_size = {0,1}`: No sampling is performed.
+ - `nbest_size > 1`: samples from the nbest_size results.
+ - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice)
+ using forward-filtering-and-backward-sampling algorithm.
+
+ - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
+ BPE-dropout.
+
+ Attributes:
+ sp_model (`SentencePieceProcessor`):
+ The *SentencePiece* processor that is used for every conversion (string, tokens and IDs).
+ """
+
+ vocab_files_names = VOCAB_FILES_NAMES
+ model_input_names = ["input_ids", "attention_mask"]
+
+ def __init__(
+ self,
+ vocab_file,
+ bos_token="",
+ eos_token="",
+ sep_token="",
+ cls_token="",
+ unk_token="",
+ pad_token="",
+ mask_token="",
+ sp_model_kwargs: Optional[dict[str, Any]] = None,
+ **kwargs,
+ ) -> None:
+        # Mask token behaves like a normal word, i.e. includes the space before it. Will have normalized=False by default this way
+ mask_token = AddedToken(mask_token, lstrip=True, special=True) if isinstance(mask_token, str) else mask_token
+
+ self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
+
+ self.vocab_file = vocab_file
+ self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+ self.sp_model.Load(str(vocab_file))
+ super().__init__(
+ bos_token=bos_token,
+ eos_token=eos_token,
+ unk_token=unk_token,
+ sep_token=sep_token,
+ cls_token=cls_token,
+ pad_token=pad_token,
+ mask_token=mask_token,
+ sp_model_kwargs=self.sp_model_kwargs,
+ **kwargs,
+ )
+
+ def build_inputs_with_special_tokens(
+ self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
+ ) -> list[int]:
+ """
+ Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
+ adding special tokens. A BARThez sequence has the following format:
+
+        - single sequence: `<s> X </s>`
+        - pair of sequences: `<s> A </s></s> B </s>`
+
+ Args:
+ token_ids_0 (`list[int]`):
+ List of IDs to which the special tokens will be added.
+ token_ids_1 (`list[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+
+ Returns:
+ `list[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
+ """
+
+ if token_ids_1 is None:
+ return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
+ cls = [self.cls_token_id]
+ sep = [self.sep_token_id]
+ return cls + token_ids_0 + sep + sep + token_ids_1 + sep
+
+ def get_special_tokens_mask(
+ self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None, already_has_special_tokens: bool = False
+ ) -> list[int]:
+ """
+ Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
+ special tokens using the tokenizer `prepare_for_model` method.
+
+ Args:
+ token_ids_0 (`list[int]`):
+ List of IDs.
+ token_ids_1 (`list[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+ already_has_special_tokens (`bool`, *optional*, defaults to `False`):
+ Whether or not the token list is already formatted with special tokens for the model.
+
+ Returns:
+ `list[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+ """
+ if already_has_special_tokens:
+ return super().get_special_tokens_mask(
+ token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
+ )
+
+ if token_ids_1 is None:
+ return [1] + ([0] * len(token_ids_0)) + [1]
+ return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1]
+
+ def create_token_type_ids_from_sequences(
+ self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
+ ) -> list[int]:
+ """
+ Create a mask from the two sequences passed to be used in a sequence-pair classification task.
+
+ Args:
+ token_ids_0 (`list[int]`):
+ List of IDs.
+ token_ids_1 (`list[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+
+ Returns:
+ `list[int]`: List of zeros.
+ """
+ sep = [self.sep_token_id]
+ cls = [self.cls_token_id]
+
+ if token_ids_1 is None:
+ return len(cls + token_ids_0 + sep) * [0]
+ return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]
+
+ @property
+ def vocab_size(self):
+ return len(self.sp_model)
+
+ def get_vocab(self):
+ vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
+ vocab.update(self.added_tokens_encoder)
+ return vocab
+
+ def _tokenize(self, text: str) -> list[str]:
+ return self.sp_model.encode(text, out_type=str)
+
+ def _convert_token_to_id(self, token):
+ """Converts a token (str) in an id using the vocab."""
+ return self.sp_model.PieceToId(token)
+
+ def _convert_id_to_token(self, index):
+ """Converts an index (integer) in a token (str) using the vocab."""
+ return self.sp_model.IdToPiece(index)
+
+ # Copied from transformers.models.albert.tokenization_albert.AlbertTokenizer.convert_tokens_to_string
+ def convert_tokens_to_string(self, tokens):
+ """Converts a sequence of tokens (string) in a single string."""
+ current_sub_tokens = []
+ out_string = ""
+ prev_is_special = False
+ for token in tokens:
+ # make sure that special tokens are not decoded using sentencepiece model
+ if token in self.all_special_tokens:
+ if not prev_is_special:
+ out_string += " "
+ out_string += self.sp_model.decode(current_sub_tokens) + token
+ prev_is_special = True
+ current_sub_tokens = []
+ else:
+ current_sub_tokens.append(token)
+ prev_is_special = False
+ out_string += self.sp_model.decode(current_sub_tokens)
+ return out_string.strip()
+
+ def __getstate__(self):
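+        # The SentencePiece processor cannot be pickled; drop it here and reload it from the vocab file in __setstate__.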
+ state = self.__dict__.copy()
+ state["sp_model"] = None
+ return state
+
+ def __setstate__(self, d):
+ self.__dict__ = d
+
+ # for backward compatibility
+ if not hasattr(self, "sp_model_kwargs"):
+ self.sp_model_kwargs = {}
+
+ self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+ self.sp_model.Load(self.vocab_file)
+
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
+ if not os.path.isdir(save_directory):
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
+ return
+ out_vocab_file = os.path.join(
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+ )
+
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
+ copyfile(self.vocab_file, out_vocab_file)
+ elif not os.path.isfile(self.vocab_file):
+ with open(out_vocab_file, "wb") as fi:
+ content_spiece_model = self.sp_model.serialized_model_proto()
+ fi.write(content_spiece_model)
+
+ return (out_vocab_file,)
+
+
+__all__ = ["BarthezTokenizer"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/barthez/tokenization_barthez_fast.py b/venv/lib/python3.13/site-packages/transformers/models/barthez/tokenization_barthez_fast.py
new file mode 100644
index 0000000000000000000000000000000000000000..64050ca8848f57c25d272e0bc31a3f878040e14c
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/barthez/tokenization_barthez_fast.py
@@ -0,0 +1,193 @@
+# coding=utf-8
+# Copyright 2020 Ecole Polytechnique and the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License
+"""Tokenization classes for the BARThez model."""
+
+import os
+from shutil import copyfile
+from typing import Optional
+
+from ...tokenization_utils import AddedToken
+from ...tokenization_utils_fast import PreTrainedTokenizerFast
+from ...utils import is_sentencepiece_available, logging
+
+
+if is_sentencepiece_available():
+ from .tokenization_barthez import BarthezTokenizer
+else:
+ BarthezTokenizer = None
+
+logger = logging.get_logger(__name__)
+
+VOCAB_FILES_NAMES = {"vocab_file": "sentencepiece.bpe.model", "tokenizer_file": "tokenizer.json"}
+
+
+SPIECE_UNDERLINE = "▁"
+
+
+class BarthezTokenizerFast(PreTrainedTokenizerFast):
+ """
+ Adapted from [`CamembertTokenizer`] and [`BartTokenizer`]. Construct a "fast" BARThez tokenizer. Based on
+ [SentencePiece](https://github.com/google/sentencepiece).
+
+ This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
+ refer to this superclass for more information regarding those methods.
+
+ Args:
+ vocab_file (`str`):
+ [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that
+ contains the vocabulary necessary to instantiate a tokenizer.
+        bos_token (`str`, *optional*, defaults to `"<s>"`):
+            The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier token.
+
+            <Tip>
+
+            When building a sequence using special tokens, this is not the token that is used for the beginning of
+            sequence. The token used is the `cls_token`.
+
+            </Tip>
+
+        eos_token (`str`, *optional*, defaults to `"</s>"`):
+            The end of sequence token.
+
+            <Tip>
+
+            When building a sequence using special tokens, this is not the token that is used for the end of sequence.
+            The token used is the `sep_token`.
+
+            </Tip>
+
+        sep_token (`str`, *optional*, defaults to `"</s>"`):
+ The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
+ sequence classification or for a text and a question for question answering. It is also used as the last
+ token of a sequence built with special tokens.
+        cls_token (`str`, *optional*, defaults to `"<s>"`):
+ The classifier token which is used when doing sequence classification (classification of the whole sequence
+ instead of per-token classification). It is the first token of the sequence when built with special tokens.
+        unk_token (`str`, *optional*, defaults to `"<unk>"`):
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+ token instead.
+        pad_token (`str`, *optional*, defaults to `"<pad>"`):
+ The token used for padding, for example when batching sequences of different lengths.
+        mask_token (`str`, *optional*, defaults to `"<mask>"`):
+ The token used for masking values. This is the token used when training this model with masked language
+ modeling. This is the token which the model will try to predict.
+        additional_special_tokens (`list[str]`, *optional*, defaults to `["<s>NOTUSED", "</s>NOTUSED"]`):
+ Additional special tokens used by the tokenizer.
+ """
+
+ vocab_files_names = VOCAB_FILES_NAMES
+ model_input_names = ["input_ids", "attention_mask"]
+ slow_tokenizer_class = BarthezTokenizer
+
+ def __init__(
+ self,
+ vocab_file=None,
+ tokenizer_file=None,
+ bos_token="",
+ eos_token="",
+ sep_token="",
+ cls_token="",
+ unk_token="",
+ pad_token="",
+ mask_token="",
+ **kwargs,
+ ):
+        # Mask token behaves like a normal word, i.e. includes the space before it
+ mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token
+
+ super().__init__(
+ vocab_file,
+ tokenizer_file=tokenizer_file,
+ bos_token=bos_token,
+ eos_token=eos_token,
+ unk_token=unk_token,
+ sep_token=sep_token,
+ cls_token=cls_token,
+ pad_token=pad_token,
+ mask_token=mask_token,
+ **kwargs,
+ )
+
+ self.vocab_file = vocab_file
+
+ def build_inputs_with_special_tokens(
+ self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
+ ) -> list[int]:
+ """
+ Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
+ adding special tokens. A BARThez sequence has the following format:
+
+        - single sequence: `<s> X </s>`
+        - pair of sequences: `<s> A </s></s> B </s>`
+
+ Args:
+ token_ids_0 (`list[int]`):
+ List of IDs to which the special tokens will be added.
+ token_ids_1 (`list[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+
+ Returns:
+ `list[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
+ """
+
+ if token_ids_1 is None:
+ return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
+ cls = [self.cls_token_id]
+ sep = [self.sep_token_id]
+ return cls + token_ids_0 + sep + sep + token_ids_1 + sep
+
+ def create_token_type_ids_from_sequences(
+ self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
+ ) -> list[int]:
+ """
+ Create a mask from the two sequences passed to be used in a sequence-pair classification task.
+
+ Args:
+ token_ids_0 (`list[int]`):
+ List of IDs.
+ token_ids_1 (`list[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+
+ Returns:
+ `list[int]`: List of zeros.
+ """
+ sep = [self.sep_token_id]
+ cls = [self.cls_token_id]
+
+ if token_ids_1 is None:
+ return len(cls + token_ids_0 + sep) * [0]
+ return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]
+
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
+ if not self.can_save_slow_tokenizer:
+ raise ValueError(
+ "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
+ "tokenizer."
+ )
+
+ if not os.path.isdir(save_directory):
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
+ return
+ out_vocab_file = os.path.join(
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+ )
+
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+ copyfile(self.vocab_file, out_vocab_file)
+
+ return (out_vocab_file,)
+
+
+__all__ = ["BarthezTokenizerFast"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/bert_japanese/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/bert_japanese/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f5296087db1d007eab946f795d0c9c8fa4bdaafe
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/bert_japanese/__init__.py
@@ -0,0 +1,26 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .tokenization_bert_japanese import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/venv/lib/python3.13/site-packages/transformers/models/bert_japanese/tokenization_bert_japanese.py b/venv/lib/python3.13/site-packages/transformers/models/bert_japanese/tokenization_bert_japanese.py
new file mode 100644
index 0000000000000000000000000000000000000000..cacacd87574a28527c5dad2c2c5391c3babc9b8b
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/bert_japanese/tokenization_bert_japanese.py
@@ -0,0 +1,952 @@
+# coding=utf-8
+# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenization classes."""
+
+import collections
+import copy
+import os
+import unicodedata
+from typing import Any, Optional
+
+from ...tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace
+from ...utils import is_sentencepiece_available, is_sudachi_projection_available, logging
+
+
+if is_sentencepiece_available():
+ import sentencepiece as spm
+else:
+ spm = None
+
+logger = logging.get_logger(__name__)
+
+VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "spm_file": "spiece.model"}
+
+SPIECE_UNDERLINE = "▁"
+
+
+# Copied from transformers.models.bert.tokenization_bert.load_vocab
+def load_vocab(vocab_file):
+ """Loads a vocabulary file into a dictionary."""
+ vocab = collections.OrderedDict()
+ with open(vocab_file, "r", encoding="utf-8") as reader:
+ tokens = reader.readlines()
+ for index, token in enumerate(tokens):
+ token = token.rstrip("\n")
+ vocab[token] = index
+ return vocab
+
+
+# Copied from transformers.models.bert.tokenization_bert.whitespace_tokenize
+def whitespace_tokenize(text):
+ """Runs basic whitespace cleaning and splitting on a piece of text."""
+ text = text.strip()
+ if not text:
+ return []
+ tokens = text.split()
+ return tokens
+
+
+class BertJapaneseTokenizer(PreTrainedTokenizer):
+ r"""
+ Construct a BERT tokenizer for Japanese text.
+
+ This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer
+    to this superclass for more information regarding those methods.
+
+ Args:
+ vocab_file (`str`):
+ Path to a one-wordpiece-per-line vocabulary file.
+ spm_file (`str`, *optional*):
+ Path to [SentencePiece](https://github.com/google/sentencepiece) file (generally has a .spm or .model
+ extension) that contains the vocabulary.
+ do_lower_case (`bool`, *optional*, defaults to `True`):
+ Whether to lower case the input. Only has an effect when do_basic_tokenize=True.
+ do_word_tokenize (`bool`, *optional*, defaults to `True`):
+ Whether to do word tokenization.
+ do_subword_tokenize (`bool`, *optional*, defaults to `True`):
+ Whether to do subword tokenization.
+ word_tokenizer_type (`str`, *optional*, defaults to `"basic"`):
+ Type of word tokenizer. Choose from ["basic", "mecab", "sudachi", "jumanpp"].
+ subword_tokenizer_type (`str`, *optional*, defaults to `"wordpiece"`):
+ Type of subword tokenizer. Choose from ["wordpiece", "character", "sentencepiece",].
+ mecab_kwargs (`dict`, *optional*):
+ Dictionary passed to the `MecabTokenizer` constructor.
+ sudachi_kwargs (`dict`, *optional*):
+ Dictionary passed to the `SudachiTokenizer` constructor.
+ jumanpp_kwargs (`dict`, *optional*):
+ Dictionary passed to the `JumanppTokenizer` constructor.
+ """
+
+ vocab_files_names = VOCAB_FILES_NAMES
+
+ def __init__(
+ self,
+ vocab_file,
+ spm_file=None,
+ do_lower_case=False,
+ do_word_tokenize=True,
+ do_subword_tokenize=True,
+ word_tokenizer_type="basic",
+ subword_tokenizer_type="wordpiece",
+ never_split=None,
+ unk_token="[UNK]",
+ sep_token="[SEP]",
+ pad_token="[PAD]",
+ cls_token="[CLS]",
+ mask_token="[MASK]",
+ mecab_kwargs=None,
+ sudachi_kwargs=None,
+ jumanpp_kwargs=None,
+ **kwargs,
+ ):
+ if subword_tokenizer_type == "sentencepiece":
+ if not os.path.isfile(spm_file):
+ raise ValueError(
+ f"Can't find a vocabulary file at path '{spm_file}'. To load the vocabulary from a Google"
+ " pretrained model use `tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
+ )
+ self.spm_file = spm_file
+ else:
+ if not os.path.isfile(vocab_file):
+ raise ValueError(
+ f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google"
+ " pretrained model use `tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
+ )
+ self.vocab = load_vocab(vocab_file)
+ self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
+
+ self.do_word_tokenize = do_word_tokenize
+ self.word_tokenizer_type = word_tokenizer_type
+ self.lower_case = do_lower_case
+ self.never_split = never_split
+ self.mecab_kwargs = copy.deepcopy(mecab_kwargs)
+ self.sudachi_kwargs = copy.deepcopy(sudachi_kwargs)
+ self.jumanpp_kwargs = copy.deepcopy(jumanpp_kwargs)
+ if do_word_tokenize:
+ if word_tokenizer_type == "basic":
+ self.word_tokenizer = BasicTokenizer(
+ do_lower_case=do_lower_case, never_split=never_split, tokenize_chinese_chars=False
+ )
+ elif word_tokenizer_type == "mecab":
+ self.word_tokenizer = MecabTokenizer(
+ do_lower_case=do_lower_case, never_split=never_split, **(mecab_kwargs or {})
+ )
+ elif word_tokenizer_type == "sudachi":
+ self.word_tokenizer = SudachiTokenizer(
+ do_lower_case=do_lower_case, never_split=never_split, **(sudachi_kwargs or {})
+ )
+ elif word_tokenizer_type == "jumanpp":
+ self.word_tokenizer = JumanppTokenizer(
+ do_lower_case=do_lower_case, never_split=never_split, **(jumanpp_kwargs or {})
+ )
+ else:
+ raise ValueError(f"Invalid word_tokenizer_type '{word_tokenizer_type}' is specified.")
+
+ self.do_subword_tokenize = do_subword_tokenize
+ self.subword_tokenizer_type = subword_tokenizer_type
+ if do_subword_tokenize:
+ if subword_tokenizer_type == "wordpiece":
+ self.subword_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=str(unk_token))
+ elif subword_tokenizer_type == "character":
+ self.subword_tokenizer = CharacterTokenizer(vocab=self.vocab, unk_token=str(unk_token))
+ elif subword_tokenizer_type == "sentencepiece":
+ self.subword_tokenizer = SentencepieceTokenizer(vocab=self.spm_file, unk_token=str(unk_token))
+ else:
+ raise ValueError(f"Invalid subword_tokenizer_type '{subword_tokenizer_type}' is specified.")
+ super().__init__(
+ spm_file=spm_file,
+ unk_token=unk_token,
+ sep_token=sep_token,
+ pad_token=pad_token,
+ cls_token=cls_token,
+ mask_token=mask_token,
+ do_lower_case=do_lower_case,
+ do_word_tokenize=do_word_tokenize,
+ do_subword_tokenize=do_subword_tokenize,
+ word_tokenizer_type=word_tokenizer_type,
+ subword_tokenizer_type=subword_tokenizer_type,
+ never_split=never_split,
+ mecab_kwargs=mecab_kwargs,
+ sudachi_kwargs=sudachi_kwargs,
+ jumanpp_kwargs=jumanpp_kwargs,
+ **kwargs,
+ )
+
+ @property
+ def do_lower_case(self):
+ return self.lower_case
+
+ def __getstate__(self):
+ state = dict(self.__dict__)
+ if self.word_tokenizer_type in ["mecab", "sudachi", "jumanpp"]:
+ del state["word_tokenizer"]
+ return state
+
+ def __setstate__(self, state):
+ self.__dict__ = state
+ if self.word_tokenizer_type == "mecab":
+ self.word_tokenizer = MecabTokenizer(
+ do_lower_case=self.do_lower_case, never_split=self.never_split, **(self.mecab_kwargs or {})
+ )
+ elif self.word_tokenizer_type == "sudachi":
+ self.word_tokenizer = SudachiTokenizer(
+ do_lower_case=self.do_lower_case, never_split=self.never_split, **(self.sudachi_kwargs or {})
+ )
+ elif self.word_tokenizer_type == "jumanpp":
+ self.word_tokenizer = JumanppTokenizer(
+ do_lower_case=self.do_lower_case, never_split=self.never_split, **(self.jumanpp_kwargs or {})
+ )
+
+ def _tokenize(self, text):
+ if self.do_word_tokenize:
+ tokens = self.word_tokenizer.tokenize(text, never_split=self.all_special_tokens)
+ else:
+ tokens = [text]
+
+ if self.do_subword_tokenize:
+ split_tokens = [sub_token for token in tokens for sub_token in self.subword_tokenizer.tokenize(token)]
+ else:
+ split_tokens = tokens
+
+ return split_tokens
+
+ @property
+ def vocab_size(self):
+ if self.subword_tokenizer_type == "sentencepiece":
+ return len(self.subword_tokenizer.sp_model)
+ return len(self.vocab)
+
+ def get_vocab(self):
+ if self.subword_tokenizer_type == "sentencepiece":
+ vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
+ vocab.update(self.added_tokens_encoder)
+ return vocab
+ return dict(self.vocab, **self.added_tokens_encoder)
+
+ def _convert_token_to_id(self, token):
+ """Converts a token (str) in an id using the vocab."""
+ if self.subword_tokenizer_type == "sentencepiece":
+ return self.subword_tokenizer.sp_model.PieceToId(token)
+ return self.vocab.get(token, self.vocab.get(self.unk_token))
+
+ def _convert_id_to_token(self, index):
+ """Converts an index (integer) in a token (str) using the vocab."""
+ if self.subword_tokenizer_type == "sentencepiece":
+ return self.subword_tokenizer.sp_model.IdToPiece(index)
+ return self.ids_to_tokens.get(index, self.unk_token)
+
+ def convert_tokens_to_string(self, tokens):
+ """Converts a sequence of tokens (string) in a single string."""
+ if self.subword_tokenizer_type == "sentencepiece":
+ return self.subword_tokenizer.sp_model.decode(tokens)
+ out_string = " ".join(tokens).replace(" ##", "").strip()
+ return out_string
+
+ # Copied from transformers.models.bert.tokenization_bert.BertTokenizer.build_inputs_with_special_tokens
+ def build_inputs_with_special_tokens(
+ self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
+ ) -> list[int]:
+ """
+ Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
+ adding special tokens. A BERT sequence has the following format:
+
+ - single sequence: `[CLS] X [SEP]`
+ - pair of sequences: `[CLS] A [SEP] B [SEP]`
+
+ Args:
+ token_ids_0 (`List[int]`):
+ List of IDs to which the special tokens will be added.
+ token_ids_1 (`List[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+
+ Returns:
+ `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
+ """
+ if token_ids_1 is None:
+ return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
+ cls = [self.cls_token_id]
+ sep = [self.sep_token_id]
+ return cls + token_ids_0 + sep + token_ids_1 + sep
+
+ # Copied from transformers.models.bert.tokenization_bert.BertTokenizer.get_special_tokens_mask
+ def get_special_tokens_mask(
+ self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None, already_has_special_tokens: bool = False
+ ) -> list[int]:
+ """
+ Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
+ special tokens using the tokenizer `prepare_for_model` method.
+
+ Args:
+ token_ids_0 (`List[int]`):
+ List of IDs.
+ token_ids_1 (`List[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+ already_has_special_tokens (`bool`, *optional*, defaults to `False`):
+ Whether or not the token list is already formatted with special tokens for the model.
+
+ Returns:
+ `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+ """
+
+ if already_has_special_tokens:
+ return super().get_special_tokens_mask(
+ token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
+ )
+
+ if token_ids_1 is not None:
+ return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
+ return [1] + ([0] * len(token_ids_0)) + [1]
+
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
+ if os.path.isdir(save_directory):
+ if self.subword_tokenizer_type == "sentencepiece":
+ vocab_file = os.path.join(
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["spm_file"]
+ )
+ else:
+ vocab_file = os.path.join(
+ save_directory,
+ (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"],
+ )
+ else:
+ vocab_file = (filename_prefix + "-" if filename_prefix else "") + save_directory
+
+ if self.subword_tokenizer_type == "sentencepiece":
+ with open(vocab_file, "wb") as writer:
+ content_spiece_model = self.subword_tokenizer.sp_model.serialized_model_proto()
+ writer.write(content_spiece_model)
+ else:
+ with open(vocab_file, "w", encoding="utf-8") as writer:
+ index = 0
+ for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
+ if index != token_index:
+ logger.warning(
+ f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive."
+ " Please check that the vocabulary is not corrupted!"
+ )
+ index = token_index
+ writer.write(token + "\n")
+ index += 1
+ return (vocab_file,)
+
+
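+# Illustrative sketch only (not part of the original file): it mirrors the two-stage
+# pipeline implemented in `BertJapaneseTokenizer._tokenize`, which first splits text into
+# words and then splits each word into subword units. `MecabTokenizer` and
+# `WordpieceTokenizer` are defined below in this file; the vocabulary dict is assumed to
+# have been loaded already (e.g. with `load_vocab`), and nothing here runs at import time.
+def _demo_word_then_subword(text, vocab):
+    """Tokenize `text` with MeCab, then split each word into WordPiece tokens."""
+    word_tokenizer = MecabTokenizer(do_lower_case=False)  # needs fugashi + unidic_lite
+    subword_tokenizer = WordpieceTokenizer(vocab=vocab, unk_token="[UNK]")
+    words = word_tokenizer.tokenize(text)
+    return [sub_token for word in words for sub_token in subword_tokenizer.tokenize(word)]
+
+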
+class MecabTokenizer:
+ """Runs basic tokenization with MeCab morphological parser."""
+
+ def __init__(
+ self,
+ do_lower_case=False,
+ never_split=None,
+ normalize_text=True,
+ mecab_dic: Optional[str] = "unidic_lite",
+ mecab_option: Optional[str] = None,
+ ):
+ """
+ Constructs a MecabTokenizer.
+
+ Args:
+ **do_lower_case**: (*optional*) boolean (default False)
+ Whether to lowercase the input.
+ **never_split**: (*optional*) list of str
+ Kept for backward compatibility purposes. Now implemented directly at the base class level (see
+ [`PreTrainedTokenizer.tokenize`]). List of tokens not to split.
+ **normalize_text**: (*optional*) boolean (default True)
+ Whether to apply unicode normalization to text before tokenization.
+ **mecab_dic**: (*optional*) string (default "unidic_lite")
+ Name of dictionary to be used for MeCab initialization. If you are using a system-installed dictionary,
+ set this option to `None` and modify *mecab_option*.
+ **mecab_option**: (*optional*) string
+ String passed to the MeCab constructor.
+ """
+ self.do_lower_case = do_lower_case
+ self.never_split = never_split if never_split is not None else []
+ self.normalize_text = normalize_text
+
+ try:
+ import fugashi
+ except ModuleNotFoundError as error:
+ raise error.__class__(
+ "You need to install fugashi to use MecabTokenizer. "
+ "See https://pypi.org/project/fugashi/ for installation."
+ )
+
+ mecab_option = mecab_option or ""
+
+ if mecab_dic is not None:
+ if mecab_dic == "ipadic":
+ try:
+ import ipadic
+ except ModuleNotFoundError as error:
+ raise error.__class__(
+ "The ipadic dictionary is not installed. "
+ "See https://github.com/polm/ipadic-py for installation."
+ )
+
+ dic_dir = ipadic.DICDIR
+
+ elif mecab_dic == "unidic_lite":
+ try:
+ import unidic_lite
+ except ModuleNotFoundError as error:
+ raise error.__class__(
+ "The unidic_lite dictionary is not installed. "
+ "See https://github.com/polm/unidic-lite for installation."
+ )
+
+ dic_dir = unidic_lite.DICDIR
+
+ elif mecab_dic == "unidic":
+ try:
+ import unidic
+ except ModuleNotFoundError as error:
+ raise error.__class__(
+ "The unidic dictionary is not installed. "
+ "See https://github.com/polm/unidic-py for installation."
+ )
+
+ dic_dir = unidic.DICDIR
+ if not os.path.isdir(dic_dir):
+ raise RuntimeError(
+ "The unidic dictionary itself is not found. "
+ "See https://github.com/polm/unidic-py for installation."
+ )
+
+ else:
+ raise ValueError("Invalid mecab_dic is specified.")
+
+ mecabrc = os.path.join(dic_dir, "mecabrc")
+ mecab_option = f'-d "{dic_dir}" -r "{mecabrc}" ' + mecab_option
+
+ self.mecab = fugashi.GenericTagger(mecab_option)
+
+ def tokenize(self, text, never_split=None, **kwargs):
+ """Tokenizes a piece of text."""
+ if self.normalize_text:
+ text = unicodedata.normalize("NFKC", text)
+
+ never_split = self.never_split + (never_split if never_split is not None else [])
+ tokens = []
+
+ for word in self.mecab(text):
+ token = word.surface
+
+ if self.do_lower_case and token not in never_split:
+ token = token.lower()
+
+ tokens.append(token)
+
+ return tokens
+
+
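+# Hedged usage sketch (illustrative only, never called in this module; it assumes fugashi
+# and the chosen dictionary package are installed): it shows how `mecab_dic` selects one of
+# the bundled dictionaries, while a system-installed dictionary can be used instead by
+# passing `mecab_dic=None` together with an explicit `mecab_option` string.
+def _demo_mecab_with_dictionary(text, mecab_dic="unidic_lite"):
+    """Tokenize `text` with MeCab using one of the bundled dictionaries."""
+    tokenizer = MecabTokenizer(do_lower_case=False, mecab_dic=mecab_dic)
+    return tokenizer.tokenize(text)
+
+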
+class SudachiTokenizer:
+ """Runs basic tokenization with Sudachi morphological parser."""
+
+ def __init__(
+ self,
+ do_lower_case=False,
+ never_split=None,
+ normalize_text=True,
+ trim_whitespace=False,
+ sudachi_split_mode="A",
+ sudachi_config_path=None,
+ sudachi_resource_dir=None,
+ sudachi_dict_type="core",
+ sudachi_projection=None,
+ ):
+ """
+ Constructs a SudachiTokenizer.
+
+ Args:
+ **do_lower_case**: (*optional*) boolean (default False)
+ Whether to lowercase the input.
+ **never_split**: (*optional*) list of str
+ Kept for backward compatibility purposes. Now implemented directly at the base class level (see
+ [`PreTrainedTokenizer.tokenize`]). List of tokens not to split.
+ **normalize_text**: (*optional*) boolean (default True)
+ Whether to apply unicode normalization to text before tokenization.
+ **trim_whitespace**: (*optional*) boolean (default False)
+ Whether to trim all whitespace, tab, newline from tokens.
+ **sudachi_split_mode**: (*optional*) string
+ Split mode of sudachi, choose from `["A", "B", "C"]`.
+ **sudachi_config_path**: (*optional*) string
+ **sudachi_resource_dir**: (*optional*) string
+ **sudachi_dict_type**: (*optional*) string
+ dict type of sudachi, choose from `["small", "core", "full"]`.
+ **sudachi_projection**: (*optional*) string
+ Word projection mode of sudachi, choose from `["surface", "normalized", "reading", "dictionary", "dictionary_and_surface", "normalized_and_surface", "normalized_nouns"]`.
+ """
+
+ self.do_lower_case = do_lower_case
+ self.never_split = never_split if never_split is not None else []
+ self.normalize_text = normalize_text
+ self.trim_whitespace = trim_whitespace
+
+ try:
+ from sudachipy import dictionary, tokenizer
+ except ImportError:
+ raise ImportError(
+ "You need to install sudachipy to use SudachiTokenizer. "
+ "See https://github.com/WorksApplications/SudachiPy for installation."
+ )
+
+ if sudachi_split_mode == "A":
+ self.split_mode = tokenizer.Tokenizer.SplitMode.A
+ elif sudachi_split_mode == "B":
+ self.split_mode = tokenizer.Tokenizer.SplitMode.B
+ elif sudachi_split_mode == "C":
+ self.split_mode = tokenizer.Tokenizer.SplitMode.C
+ else:
+ raise ValueError("Invalid sudachi_split_mode is specified.")
+
+ self.projection = sudachi_projection
+
+ sudachi_dictionary = dictionary.Dictionary(
+ config_path=sudachi_config_path, resource_dir=sudachi_resource_dir, dict=sudachi_dict_type
+ )
+ if is_sudachi_projection_available():
+ self.sudachi = sudachi_dictionary.create(self.split_mode, projection=self.projection)
+ elif self.projection is not None:
+ raise ImportError("You need to install sudachipy>=0.6.8 to specify `projection` field in sudachi_kwargs.")
+ else:
+ self.sudachi = sudachi_dictionary.create(self.split_mode)
+
+ def tokenize(self, text, never_split=None, **kwargs):
+ """Tokenizes a piece of text."""
+ if self.normalize_text:
+ text = unicodedata.normalize("NFKC", text)
+
+ never_split = self.never_split + (never_split if never_split is not None else [])
+ tokens = []
+
+ for word in self.sudachi.tokenize(text):
+ token = word.surface()
+
+ if self.do_lower_case and token not in never_split:
+ token = token.lower()
+
+ if self.trim_whitespace:
+ if token.strip() == "":
+ continue
+ else:
+ token = token.strip()
+
+ tokens.append(token)
+
+ return tokens
+
+
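+# Hedged sketch (illustrative only, never called here; it assumes sudachipy and a Sudachi
+# dictionary such as sudachidict_core are installed): split mode "A" yields the shortest
+# units, "C" the longest, roughly named-entity-sized, units, and "B" sits in between.
+def _demo_sudachi_split_modes(text):
+    """Return the tokenization of `text` under each of the three Sudachi split modes."""
+    return {
+        mode: SudachiTokenizer(sudachi_split_mode=mode, sudachi_dict_type="core").tokenize(text)
+        for mode in ("A", "B", "C")
+    }
+
+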
+class JumanppTokenizer:
+ """Runs basic tokenization with jumanpp morphological parser."""
+
+ def __init__(
+ self,
+ do_lower_case=False,
+ never_split=None,
+ normalize_text=True,
+ trim_whitespace=False,
+ ):
+ """
+ Constructs a JumanppTokenizer.
+
+ Args:
+ **do_lower_case**: (*optional*) boolean (default False)
+ Whether to lowercase the input.
+ **never_split**: (*optional*) list of str
+ Kept for backward compatibility purposes. Now implemented directly at the base class level (see
+ [`PreTrainedTokenizer.tokenize`]). List of tokens not to split.
+ **normalize_text**: (*optional*) boolean (default True)
+ Whether to apply unicode normalization to text before tokenization.
+ **trim_whitespace**: (*optional*) boolean (default False)
+ Whether to trim all whitespace, tab, newline from tokens.
+ """
+
+ self.do_lower_case = do_lower_case
+ self.never_split = never_split if never_split is not None else []
+ self.normalize_text = normalize_text
+ self.trim_whitespace = trim_whitespace
+
+ try:
+ import rhoknp
+ except ImportError:
+ raise ImportError(
+ "You need to install rhoknp to use JumanppTokenizer. "
+ "See https://github.com/ku-nlp/rhoknp for installation."
+ )
+
+ self.juman = rhoknp.Jumanpp()
+
+ def tokenize(self, text, never_split=None, **kwargs):
+ """Tokenizes a piece of text."""
+ if self.normalize_text:
+ text = unicodedata.normalize("NFKC", text)
+
+ text = text.strip()
+
+ never_split = self.never_split + (never_split if never_split is not None else [])
+ tokens = []
+
+ for mrph in self.juman.apply_to_sentence(text).morphemes:
+ token = mrph.text
+
+ if self.do_lower_case and token not in never_split:
+ token = token.lower()
+
+ if self.trim_whitespace:
+ if token.strip() == "":
+ continue
+ else:
+ token = token.strip()
+
+ tokens.append(token)
+
+ return tokens
+
+
+class CharacterTokenizer:
+ """Runs Character tokenization."""
+
+ def __init__(self, vocab, unk_token, normalize_text=True):
+ """
+ Constructs a CharacterTokenizer.
+
+ Args:
+ **vocab**:
+ Vocabulary object.
+ **unk_token**: str
+ A special symbol for out-of-vocabulary token.
+ **normalize_text**: (`optional`) boolean (default True)
+ Whether to apply unicode normalization to text before tokenization.
+ """
+ self.vocab = vocab
+ self.unk_token = unk_token
+ self.normalize_text = normalize_text
+
+ def tokenize(self, text):
+ """
+ Tokenizes a piece of text into characters.
+
+ For example, `input = "apple"` will return as output `["a", "p", "p", "l", "e"]`.
+
+ Args:
+ text: A single token or whitespace separated tokens.
+ This should have already been passed through *BasicTokenizer*.
+
+ Returns:
+ A list of characters.
+ """
+ if self.normalize_text:
+ text = unicodedata.normalize("NFKC", text)
+
+ output_tokens = []
+ for char in text:
+ if char not in self.vocab:
+ output_tokens.append(self.unk_token)
+ continue
+
+ output_tokens.append(char)
+
+ return output_tokens
+
+
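+# Minimal sketch of the character-level fallback (illustrative only, never called here):
+# every character missing from the vocabulary is mapped to the unknown token, so the
+# output always has exactly one token per input character. The "[UNK]" symbol below is the
+# same default used by `BertJapaneseTokenizer`.
+def _demo_character_tokenizer(text, vocab):
+    """Split `text` into single characters, replacing out-of-vocabulary ones with "[UNK]"."""
+    tokenizer = CharacterTokenizer(vocab=vocab, unk_token="[UNK]")
+    return tokenizer.tokenize(text)
+
+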
+# Copied from transformers.models.bert.tokenization_bert.BasicTokenizer
+class BasicTokenizer:
+ """
+ Constructs a BasicTokenizer that will run basic tokenization (punctuation splitting, lower casing, etc.).
+
+ Args:
+ do_lower_case (`bool`, *optional*, defaults to `True`):
+ Whether or not to lowercase the input when tokenizing.
+ never_split (`Iterable`, *optional*):
+ Collection of tokens which will never be split during tokenization. Only has an effect when
+ `do_basic_tokenize=True`
+ tokenize_chinese_chars (`bool`, *optional*, defaults to `True`):
+ Whether or not to tokenize Chinese characters.
+
+ This should likely be deactivated for Japanese (see this
+ [issue](https://github.com/huggingface/transformers/issues/328)).
+ strip_accents (`bool`, *optional*):
+ Whether or not to strip all accents. If this option is not specified, then it will be determined by the
+ value for `lowercase` (as in the original BERT).
+ do_split_on_punc (`bool`, *optional*, defaults to `True`):
+ In some instances we want to skip the basic punctuation splitting so that later tokenization can capture
+ the full context of the words, such as contractions.
+ """
+
+ def __init__(
+ self,
+ do_lower_case=True,
+ never_split=None,
+ tokenize_chinese_chars=True,
+ strip_accents=None,
+ do_split_on_punc=True,
+ ):
+ if never_split is None:
+ never_split = []
+ self.do_lower_case = do_lower_case
+ self.never_split = set(never_split)
+ self.tokenize_chinese_chars = tokenize_chinese_chars
+ self.strip_accents = strip_accents
+ self.do_split_on_punc = do_split_on_punc
+
+ def tokenize(self, text, never_split=None):
+ """
+ Basic Tokenization of a piece of text. For sub-word tokenization, see WordPieceTokenizer.
+
+ Args:
+ never_split (`List[str]`, *optional*):
+ Kept for backward compatibility purposes. Now implemented directly at the base class level (see
+ [`PreTrainedTokenizer.tokenize`]). List of tokens not to split.
+ """
+ # union() returns a new set by concatenating the two sets.
+ never_split = self.never_split.union(set(never_split)) if never_split else self.never_split
+ text = self._clean_text(text)
+
+ # This was added on November 1st, 2018 for the multilingual and Chinese
+ # models. This is also applied to the English models now, but it doesn't
+ # matter since the English models were not trained on any Chinese data
+ # and generally don't have any Chinese data in them (there are Chinese
+ # characters in the vocabulary because Wikipedia does have some Chinese
+ # words in the English Wikipedia.).
+ if self.tokenize_chinese_chars:
+ text = self._tokenize_chinese_chars(text)
+ # prevents treating the same character with different unicode codepoints as different characters
+ unicode_normalized_text = unicodedata.normalize("NFC", text)
+ orig_tokens = whitespace_tokenize(unicode_normalized_text)
+ split_tokens = []
+ for token in orig_tokens:
+ if token not in never_split:
+ if self.do_lower_case:
+ token = token.lower()
+ if self.strip_accents is not False:
+ token = self._run_strip_accents(token)
+ elif self.strip_accents:
+ token = self._run_strip_accents(token)
+ split_tokens.extend(self._run_split_on_punc(token, never_split))
+
+ output_tokens = whitespace_tokenize(" ".join(split_tokens))
+ return output_tokens
+
+ def _run_strip_accents(self, text):
+ """Strips accents from a piece of text."""
+ text = unicodedata.normalize("NFD", text)
+ output = []
+ for char in text:
+ cat = unicodedata.category(char)
+ if cat == "Mn":
+ continue
+ output.append(char)
+ return "".join(output)
+
+ def _run_split_on_punc(self, text, never_split=None):
+ """Splits punctuation on a piece of text."""
+ if not self.do_split_on_punc or (never_split is not None and text in never_split):
+ return [text]
+ chars = list(text)
+ i = 0
+ start_new_word = True
+ output = []
+ while i < len(chars):
+ char = chars[i]
+ if _is_punctuation(char):
+ output.append([char])
+ start_new_word = True
+ else:
+ if start_new_word:
+ output.append([])
+ start_new_word = False
+ output[-1].append(char)
+ i += 1
+
+ return ["".join(x) for x in output]
+
+ def _tokenize_chinese_chars(self, text):
+ """Adds whitespace around any CJK character."""
+ output = []
+ for char in text:
+ cp = ord(char)
+ if self._is_chinese_char(cp):
+ output.append(" ")
+ output.append(char)
+ output.append(" ")
+ else:
+ output.append(char)
+ return "".join(output)
+
+ def _is_chinese_char(self, cp):
+ """Checks whether CP is the codepoint of a CJK character."""
+ # This defines a "chinese character" as anything in the CJK Unicode block:
+ # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
+ #
+ # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
+ # despite its name. The modern Korean Hangul alphabet is a different block,
+ # as is Japanese Hiragana and Katakana. Those alphabets are used to write
+ # space-separated words, so they are not treated specially and handled
+ # like all of the other languages.
+ if (
+ (cp >= 0x4E00 and cp <= 0x9FFF)
+ or (cp >= 0x3400 and cp <= 0x4DBF)
+ or (cp >= 0x20000 and cp <= 0x2A6DF)
+ or (cp >= 0x2A700 and cp <= 0x2B73F)
+ or (cp >= 0x2B740 and cp <= 0x2B81F)
+ or (cp >= 0x2B820 and cp <= 0x2CEAF)
+ or (cp >= 0xF900 and cp <= 0xFAFF)
+ or (cp >= 0x2F800 and cp <= 0x2FA1F)
+ ):
+ return True
+
+ return False
+
+ def _clean_text(self, text):
+ """Performs invalid character removal and whitespace cleanup on text."""
+ output = []
+ for char in text:
+ cp = ord(char)
+ if cp == 0 or cp == 0xFFFD or _is_control(char):
+ continue
+ if _is_whitespace(char):
+ output.append(" ")
+ else:
+ output.append(char)
+ return "".join(output)
+
+
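+# Hedged worked example (illustrative only, never called here): with the defaults above,
+# punctuation is split into standalone tokens and the text is lower cased, so a call such
+# as BasicTokenizer().tokenize("Hello, world!") yields ["hello", ",", "world", "!"], while
+# tokens listed in `never_split` are passed through unchanged.
+def _demo_basic_tokenizer(text, never_split=None):
+    """Run plain rule-based tokenization without any subword splitting."""
+    tokenizer = BasicTokenizer(do_lower_case=True, tokenize_chinese_chars=False)
+    return tokenizer.tokenize(text, never_split=never_split)
+
+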
+# Copied from transformers.models.bert.tokenization_bert.WordpieceTokenizer
+class WordpieceTokenizer:
+ """Runs WordPiece tokenization."""
+
+ def __init__(self, vocab, unk_token, max_input_chars_per_word=100):
+ self.vocab = vocab
+ self.unk_token = unk_token
+ self.max_input_chars_per_word = max_input_chars_per_word
+
+ def tokenize(self, text):
+ """
+ Tokenizes a piece of text into its word pieces. This uses a greedy longest-match-first algorithm to perform
+ tokenization using the given vocabulary.
+
+ For example, `input = "unaffable"` will return as output `["un", "##aff", "##able"]`.
+
+ Args:
+ text: A single token or whitespace separated tokens. This should have
+ already been passed through *BasicTokenizer*.
+
+ Returns:
+ A list of wordpiece tokens.
+ """
+
+ output_tokens = []
+ for token in whitespace_tokenize(text):
+ chars = list(token)
+ if len(chars) > self.max_input_chars_per_word:
+ output_tokens.append(self.unk_token)
+ continue
+
+ is_bad = False
+ start = 0
+ sub_tokens = []
+ while start < len(chars):
+ end = len(chars)
+ cur_substr = None
+ while start < end:
+ substr = "".join(chars[start:end])
+ if start > 0:
+ substr = "##" + substr
+ if substr in self.vocab:
+ cur_substr = substr
+ break
+ end -= 1
+ if cur_substr is None:
+ is_bad = True
+ break
+ sub_tokens.append(cur_substr)
+ start = end
+
+ if is_bad:
+ output_tokens.append(self.unk_token)
+ else:
+ output_tokens.extend(sub_tokens)
+ return output_tokens
+
+
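+# Hedged worked example of the greedy longest-match-first strategy above (illustrative
+# only, never called here): assuming the vocabulary contains "un", "##aff" and "##able",
+# the word "unaffable" is split into ["un", "##aff", "##able"]; if no prefix of a word is
+# found in the vocabulary, the whole word collapses to the unknown token.
+def _demo_wordpiece(vocab):
+    """Return the WordPiece split of "unaffable" under the given vocabulary."""
+    tokenizer = WordpieceTokenizer(vocab=vocab, unk_token="[UNK]")
+    return tokenizer.tokenize("unaffable")
+
+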
+class SentencepieceTokenizer:
+ """
+ Runs sentencepiece tokenization. Based on transformers.models.albert.tokenization_albert.AlbertTokenizer.
+ """
+
+ def __init__(
+ self,
+ vocab,
+ unk_token,
+ do_lower_case=False,
+ remove_space=True,
+ keep_accents=True,
+ sp_model_kwargs: Optional[dict[str, Any]] = None,
+ ):
+ self.vocab = vocab
+ self.unk_token = unk_token
+ self.do_lower_case = do_lower_case
+ self.remove_space = remove_space
+ self.keep_accents = keep_accents
+
+ self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
+ self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+ self.sp_model.Load(self.vocab)
+
+ def preprocess_text(self, inputs):
+ if self.remove_space:
+ outputs = " ".join(inputs.strip().split())
+ else:
+ outputs = inputs
+ outputs = outputs.replace("``", '"').replace("''", '"')
+
+ if not self.keep_accents:
+ outputs = unicodedata.normalize("NFKD", outputs)
+ outputs = "".join([c for c in outputs if not unicodedata.combining(c)])
+ if self.do_lower_case:
+ outputs = outputs.lower()
+
+ return outputs
+
+ def tokenize(self, text):
+ """
+ Tokenizes text by sentencepiece. Based on [SentencePiece](https://github.com/google/sentencepiece).
+ Tokenization needs the given vocabulary.
+
+ Args:
+ text: A string needs to be tokenized.
+
+ Returns:
+ A list of sentencepiece tokens.
+ """
+ text = self.preprocess_text(text)
+ pieces = self.sp_model.encode(text, out_type=str)
+ new_pieces = []
+ for piece in pieces:
+ if len(piece) > 1 and piece[-1] == "," and piece[-2].isdigit():
+ cur_pieces = self.sp_model.EncodeAsPieces(piece[:-1].replace(SPIECE_UNDERLINE, ""))
+ if piece[0] != SPIECE_UNDERLINE and cur_pieces[0][0] == SPIECE_UNDERLINE:
+ if len(cur_pieces[0]) == 1:
+ cur_pieces = cur_pieces[1:]
+ else:
+ cur_pieces[0] = cur_pieces[0][1:]
+ cur_pieces.append(piece[-1])
+ new_pieces.extend(cur_pieces)
+ else:
+ new_pieces.append(piece)
+
+ return new_pieces
+
+
+__all__ = ["BertJapaneseTokenizer", "CharacterTokenizer", "MecabTokenizer"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/bertweet/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/bertweet/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..432622f1595d1a0d8bb1b3c9b9774b7d1e387d3e
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/bertweet/__init__.py
@@ -0,0 +1,26 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .tokenization_bertweet import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/venv/lib/python3.13/site-packages/transformers/models/bertweet/tokenization_bertweet.py b/venv/lib/python3.13/site-packages/transformers/models/bertweet/tokenization_bertweet.py
new file mode 100644
index 0000000000000000000000000000000000000000..3ce1a3182bf9d5b3f960b5a211544612ab3129c3
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/bertweet/tokenization_bertweet.py
@@ -0,0 +1,769 @@
+# coding=utf-8
+# Copyright (c) 2020, VinAI Research and the HuggingFace Inc. team.
+# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenization classes for BERTweet"""
+
+import html
+import os
+import re
+from shutil import copyfile
+from typing import Optional
+
+import regex
+
+from ...tokenization_utils import PreTrainedTokenizer
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+VOCAB_FILES_NAMES = {
+ "vocab_file": "vocab.txt",
+ "merges_file": "bpe.codes",
+}
+
+
+def get_pairs(word):
+ """
+ Return set of symbol pairs in a word.
+
+ Word is represented as tuple of symbols (symbols being variable-length strings).
+ """
+ pairs = set()
+ prev_char = word[0]
+ for char in word[1:]:
+ pairs.add((prev_char, char))
+ prev_char = char
+
+ pairs = set(pairs)
+ return pairs
+
+
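+# Hedged worked example (illustrative only, never called here): for the word
+# ("l", "o", "w"), `get_pairs` returns {("l", "o"), ("o", "w")}; the BPE loop in
+# `BertweetTokenizer.bpe` below repeatedly merges the pair with the lowest rank in
+# `bpe_ranks` until no known pair remains.
+def _demo_get_pairs():
+    """Return the adjacent symbol pairs of a small example word."""
+    return get_pairs(("l", "o", "w"))
+
+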
+class BertweetTokenizer(PreTrainedTokenizer):
+ """
+ Constructs a BERTweet tokenizer, using Byte-Pair-Encoding.
+
+ This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
+ this superclass for more information regarding those methods.
+
+ Args:
+ vocab_file (`str`):
+ Path to the vocabulary file.
+ merges_file (`str`):
+ Path to the merges file.
+ normalization (`bool`, *optional*, defaults to `False`):
+ Whether or not to apply a normalization preprocess.
+ bos_token (`str`, *optional*, defaults to `"<s>"`):
+ The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token.
+
+ <Tip>
+
+ When building a sequence using special tokens, this is not the token that is used for the beginning of
+ sequence. The token used is the `cls_token`.
+
+ </Tip>
+
+ eos_token (`str`, *optional*, defaults to `"</s>"`):
+ The end of sequence token.
+
+ <Tip>
+
+ When building a sequence using special tokens, this is not the token that is used for the end of sequence.
+ The token used is the `sep_token`.
+
+ </Tip>
+
+ sep_token (`str`, *optional*, defaults to `"</s>"`):
+ The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
+ sequence classification or for a text and a question for question answering. It is also used as the last
+ token of a sequence built with special tokens.
+ cls_token (`str`, *optional*, defaults to `"<s>"`):
+ The classifier token which is used when doing sequence classification (classification of the whole sequence
+ instead of per-token classification). It is the first token of the sequence when built with special tokens.
+ unk_token (`str`, *optional*, defaults to `"<unk>"`):
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+ token instead.
+ pad_token (`str`, *optional*, defaults to `"<pad>"`):
+ The token used for padding, for example when batching sequences of different lengths.
+ mask_token (`str`, *optional*, defaults to `"<mask>"`):
+ The token used for masking values. This is the token used when training this model with masked language
+ modeling. This is the token which the model will try to predict.
+ """
+
+ vocab_files_names = VOCAB_FILES_NAMES
+
+ def __init__(
+ self,
+ vocab_file,
+ merges_file,
+ normalization=False,
+ bos_token="<s>",
+ eos_token="</s>",
+ sep_token="</s>",
+ cls_token="<s>",
+ unk_token="<unk>",
+ pad_token="<pad>",
+ mask_token="<mask>",
+ **kwargs,
+ ):
+ try:
+ from emoji import demojize
+
+ self.demojizer = demojize
+ except ImportError:
+ logger.warning(
+ "emoji is not installed, thus not converting emoticons or emojis into text. Install emoji: pip3"
+ " install emoji==0.6.0"
+ )
+ self.demojizer = None
+
+ self.vocab_file = vocab_file
+ self.merges_file = merges_file
+
+ self.encoder = {}
+ self.encoder[str(bos_token)] = 0
+ self.encoder[str(pad_token)] = 1
+ self.encoder[str(eos_token)] = 2
+ self.encoder[str(unk_token)] = 3
+
+ self.add_from_file(vocab_file)
+
+ self.decoder = {v: k for k, v in self.encoder.items()}
+
+ with open(merges_file, encoding="utf-8") as merges_handle:
+ merges = merges_handle.read().split("\n")[:-1]
+ merges = [tuple(merge.split()[:-1]) for merge in merges]
+ self.bpe_ranks = dict(zip(merges, range(len(merges))))
+ self.cache = {}
+
+ self.normalization = normalization
+ self.tweetPreprocessor = TweetTokenizer()
+ self.special_puncts = {"’": "'", "…": "..."}
+
+ super().__init__(
+ normalization=normalization,
+ bos_token=bos_token,
+ eos_token=eos_token,
+ sep_token=sep_token,
+ cls_token=cls_token,
+ unk_token=unk_token,
+ pad_token=pad_token,
+ mask_token=mask_token,
+ **kwargs,
+ )
+
+ def build_inputs_with_special_tokens(
+ self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
+ ) -> list[int]:
+ """
+ Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
+ adding special tokens. A BERTweet sequence has the following format:
+
+ - single sequence: `<s> X </s>`
+ - pair of sequences: `<s> A </s></s> B </s>`
+
+ Args:
+ token_ids_0 (`list[int]`):
+ List of IDs to which the special tokens will be added.
+ token_ids_1 (`list[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+
+ Returns:
+ `list[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
+ """
+
+ if token_ids_1 is None:
+ return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
+ cls = [self.cls_token_id]
+ sep = [self.sep_token_id]
+ return cls + token_ids_0 + sep + sep + token_ids_1 + sep
+
+ def get_special_tokens_mask(
+ self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None, already_has_special_tokens: bool = False
+ ) -> list[int]:
+ """
+ Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
+ special tokens using the tokenizer `prepare_for_model` method.
+
+ Args:
+ token_ids_0 (`list[int]`):
+ List of IDs.
+ token_ids_1 (`list[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+ already_has_special_tokens (`bool`, *optional*, defaults to `False`):
+ Whether or not the token list is already formatted with special tokens for the model.
+
+ Returns:
+ `list[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+ """
+
+ if already_has_special_tokens:
+ return super().get_special_tokens_mask(
+ token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
+ )
+
+ if token_ids_1 is None:
+ return [1] + ([0] * len(token_ids_0)) + [1]
+ return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1]
+
+ def create_token_type_ids_from_sequences(
+ self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
+ ) -> list[int]:
+ """
+ Create a mask from the two sequences passed to be used in a sequence-pair classification task. BERTweet does
+ not make use of token type ids, therefore a list of zeros is returned.
+
+ Args:
+ token_ids_0 (`list[int]`):
+ List of IDs.
+ token_ids_1 (`list[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+
+ Returns:
+ `list[int]`: List of zeros.
+ """
+
+ sep = [self.sep_token_id]
+ cls = [self.cls_token_id]
+
+ if token_ids_1 is None:
+ return len(cls + token_ids_0 + sep) * [0]
+ return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]
+
+ @property
+ def vocab_size(self):
+ return len(self.encoder)
+
+ def get_vocab(self):
+ return dict(self.encoder, **self.added_tokens_encoder)
+
+ def bpe(self, token):
+ if token in self.cache:
+ return self.cache[token]
+ word = tuple(token)
+ word = tuple(list(word[:-1]) + [word[-1] + "</w>"])
+ pairs = get_pairs(word)
+
+ if not pairs:
+ return token
+
+ while True:
+ bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
+ if bigram not in self.bpe_ranks:
+ break
+ first, second = bigram
+ new_word = []
+ i = 0
+ while i < len(word):
+ try:
+ j = word.index(first, i)
+ except ValueError:
+ new_word.extend(word[i:])
+ break
+ else:
+ new_word.extend(word[i:j])
+ i = j
+
+ if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
+ new_word.append(first + second)
+ i += 2
+ else:
+ new_word.append(word[i])
+ i += 1
+ new_word = tuple(new_word)
+ word = new_word
+ if len(word) == 1:
+ break
+ else:
+ pairs = get_pairs(word)
+ word = "@@ ".join(word)
+ word = word[:-4]
+ self.cache[token] = word
+ return word
+
+ def _tokenize(self, text):
+ """Tokenize a string."""
+ if self.normalization: # Perform Tweet normalization before performing BPE
+ text = self.normalizeTweet(text)
+
+ split_tokens = []
+ words = re.findall(r"\S+\n?", text)
+ for token in words:
+ split_tokens.extend(list(self.bpe(token).split(" ")))
+ return split_tokens
+
+ def normalizeTweet(self, tweet):
+ """
+ Normalize a raw Tweet
+ """
+ for punct in self.special_puncts:
+ tweet = tweet.replace(punct, self.special_puncts[punct])
+
+ tokens = self.tweetPreprocessor.tokenize(tweet)
+ normTweet = " ".join([self.normalizeToken(token) for token in tokens])
+
+ normTweet = (
+ normTweet.replace("cannot ", "can not ")
+ .replace("n't ", " n't ")
+ .replace("n 't ", " n't ")
+ .replace("ca n't", "can't")
+ .replace("ai n't", "ain't")
+ )
+ normTweet = (
+ normTweet.replace("'m ", " 'm ")
+ .replace("'re ", " 're ")
+ .replace("'s ", " 's ")
+ .replace("'ll ", " 'll ")
+ .replace("'d ", " 'd ")
+ .replace("'ve ", " 've ")
+ )
+ normTweet = (
+ normTweet.replace(" p . m .", " p.m.")
+ .replace(" p . m ", " p.m ")
+ .replace(" a . m .", " a.m.")
+ .replace(" a . m ", " a.m ")
+ )
+
+ return " ".join(normTweet.split())
+
+ def normalizeToken(self, token):
+ """
+ Normalize tokens in a Tweet
+ """
+ lowercased_token = token.lower()
+ if token.startswith("@"):
+ return "@USER"
+ elif lowercased_token.startswith("http") or lowercased_token.startswith("www"):
+ return "HTTPURL"
+ elif len(token) == 1:
+ if token in self.special_puncts:
+ return self.special_puncts[token]
+ if self.demojizer is not None:
+ return self.demojizer(token)
+ else:
+ return token
+ else:
+ return token
+
+ def _convert_token_to_id(self, token):
+ """Converts a token (str) in an id using the vocab."""
+ return self.encoder.get(token, self.encoder.get(self.unk_token))
+
+ def _convert_id_to_token(self, index):
+ """Converts an index (integer) in a token (str) using the vocab."""
+ return self.decoder.get(index, self.unk_token)
+
+ def convert_tokens_to_string(self, tokens):
+ """Converts a sequence of tokens (string) in a single string."""
+ out_string = " ".join(tokens).replace("@@ ", "").strip()
+ return out_string
+
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
+ if not os.path.isdir(save_directory):
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
+ return
+ out_vocab_file = os.path.join(
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+ )
+ out_merge_file = os.path.join(
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"]
+ )
+
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
+ copyfile(self.vocab_file, out_vocab_file)
+ elif not os.path.isfile(self.vocab_file):
+ with open(out_vocab_file, "wb") as fi:
+ content_spiece_model = self.sp_model.serialized_model_proto()
+ fi.write(content_spiece_model)
+
+ if os.path.abspath(self.merges_file) != os.path.abspath(out_merge_file):
+ copyfile(self.merges_file, out_merge_file)
+
+ return out_vocab_file, out_merge_file
+
+ # def decode(self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True):
+ # filtered_tokens = ' '.join(self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens))
+ # tokens_generated_so_far = re.sub('(@@ )', '', string=filtered_tokens)
+ # tokens_generated_so_far = re.sub('(@@ ?$)', '', string=tokens_generated_so_far)
+ # return ''.join(tokens_generated_so_far)
+
+ def add_from_file(self, f):
+ """
+ Loads a pre-existing dictionary from a text file and adds its symbols to this instance.
+ """
+ if isinstance(f, str):
+ try:
+ with open(f, "r", encoding="utf-8") as fd:
+ self.add_from_file(fd)
+ except FileNotFoundError as fnfe:
+ raise fnfe
+ except UnicodeError:
+ raise Exception(f"Incorrect encoding detected in {f}, please rebuild the dataset")
+ return
+
+ lines = f.readlines()
+ for lineTmp in lines:
+ line = lineTmp.strip()
+ idx = line.rfind(" ")
+ if idx == -1:
+ raise ValueError("Incorrect dictionary format, expected '<token> <cnt>'")
+ word = line[:idx]
+ self.encoder[word] = len(self.encoder)
+
+
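+# Hedged sketch (illustrative only, never called here): it mirrors the normalization rules
+# of `BertweetTokenizer.normalizeToken` without requiring vocabulary/merges files. User
+# handles are collapsed to "@USER" and URLs to "HTTPURL", which is the Tweet normalization
+# described for BERTweet.
+def _demo_normalize_token(token):
+    """Map a single token to its normalized BERTweet form."""
+    lowercased = token.lower()
+    if token.startswith("@"):
+        return "@USER"
+    if lowercased.startswith("http") or lowercased.startswith("www"):
+        return "HTTPURL"
+    return token
+
+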
+# Natural Language Toolkit: Twitter Tokenizer
+#
+# Copyright (C) 2001-2020 NLTK Project
+# Author: Christopher Potts
+# Ewan Klein (modifications)
+# Pierpaolo Pantone <> (modifications)
+# URL: http://nltk.org/
+# For license information, see LICENSE.TXT
+#
+
+
+"""
+Twitter-aware tokenizer, designed to be flexible and easy to adapt to new domains and tasks. The basic logic is this:
+
+1. The tuple regex_strings defines a list of regular expression strings.
+
+2. The regex_strings strings are put, in order, into a compiled regular expression object called word_re.
+
+3. The tokenization is done by word_re.findall(s), where s is the user-supplied string, inside the tokenize() method of
+ the class Tokenizer.
+
+4. When instantiating Tokenizer objects, there is a single option: preserve_case. By default, it is set to True. If it
+ is set to False, then the tokenizer will lowercase everything except for emoticons.
+
+"""
+
+
+######################################################################
+#
+# import regex # https://github.com/nltk/nltk/issues/2409
+# import html
+#
+######################################################################
+# The following strings are components in the regular expression
+# that is used for tokenizing. It's important that phone_number
+# appears first in the final regex (since it can contain whitespace).
+# It also could matter that tags comes after emoticons, due to the
+# possibility of having text like
+#
+# <:| and some text >:)
+#
+# Most importantly, the final element should always be last, since it
+# does a last ditch whitespace-based tokenization of whatever is left.
+
+# ToDo: Update with http://en.wikipedia.org/wiki/List_of_emoticons ?
+
+# This particular element is used in a couple ways, so we define it
+# with a name:
+# docstyle-ignore
+EMOTICONS = r"""
+ (?:
+ [<>]?
+ [:;=8] # eyes
+ [\-o\*\']? # optional nose
+ [\)\]\(\[dDpP/\:\}\{@\|\\] # mouth
+ |
+ [\)\]\(\[dDpP/\:\}\{@\|\\] # mouth
+ [\-o\*\']? # optional nose
+ [:;=8] # eyes
+ [<>]?
+ |
+ <3 # heart
+ )"""
+
+# URL pattern due to John Gruber, modified by Tom Winzig. See
+# https://gist.github.com/winzig/8894715
+# docstyle-ignore
+URLS = r""" # Capture 1: entire matched URL
+ (?:
+ https?: # URL protocol and colon
+ (?:
+ /{1,3} # 1-3 slashes
+ | # or
+ [a-z0-9%] # Single letter or digit or '%'
+ # (Trying not to match e.g. "URI::Escape")
+ )
+ | # or
+ # looks like domain name followed by a slash:
+ [a-z0-9.\-]+[.]
+ (?:[a-z]{2,13})
+ /
+ )
+ (?: # One or more:
+ [^\s()<>{}\[\]]+ # Run of non-space, non-()<>{}[]
+ | # or
+ \([^\s()]*?\([^\s()]+\)[^\s()]*?\) # balanced parens, one level deep: (...(...)...)
+ |
+ \([^\s]+?\) # balanced parens, non-recursive: (...)
+ )+
+ (?: # End with:
+ \([^\s()]*?\([^\s()]+\)[^\s()]*?\) # balanced parens, one level deep: (...(...)...)
+ |
+ \([^\s]+?\) # balanced parens, non-recursive: (...)
+ | # or
+ [^\s`!()\[\]{};:'".,<>?«»“”‘’] # not a space or one of these punct chars
+ )
+ | # OR, the following to match naked domains:
+ (?:
+ (?<!@) # not preceded by a @, avoid matching foo@_gmail.com_
+ [a-z0-9]+
+ (?:[.\-][a-z0-9]+)*
+ [.]
+ (?:[a-z]{2,13})
+ \b
+ /?
+ (?!@) # not succeeded by a @, avoid matching "foo.na" in "foo.na@example.com"
+ )
+"""
+
+# docstyle-ignore
+# The components of the tokenizer:
+REGEXPS = (
+ URLS,
+ # Phone numbers:
+ r"""
+ (?:
+ (?: # (international)
+ \+?[01]
+ [ *\-.\)]*
+ )?
+ (?: # (area code)
+ [\(]?
+ \d{3}
+ [ *\-.\)]*
+ )?
+ \d{3} # exchange
+ [ *\-.\)]*
+ \d{4} # base
+ )""",
+ # ASCII Emoticons
+ EMOTICONS,
+ # HTML tags:
+ r"""<[^>\s]+>""",
+ # ASCII Arrows
+ r"""[\-]+>|<[\-]+""",
+ # Twitter username:
+ r"""(?:@[\w_]+)""",
+ # Twitter hashtags:
+ r"""(?:\#+[\w_]+[\w\'_\-]*[\w_]+)""",
+ # email addresses
+ r"""[\w.+-]+@[\w-]+\.(?:[\w-]\.?)+[\w-]""",
+ # docstyle-ignore
+ # Remaining word types:
+ r"""
+ (?:[^\W\d_](?:[^\W\d_]|['\-_])+[^\W\d_]) # Words with apostrophes or dashes.
+ |
+ (?:[+\-]?\d+[,/.:-]\d+[+\-]?) # Numbers, including fractions, decimals.
+ |
+ (?:[\w_]+) # Words without apostrophes or dashes.
+ |
+ (?:\.(?:\s*\.){1,}) # Ellipsis dots.
+ |
+ (?:\S) # Everything else that isn't whitespace.
+ """,
+)
+
+######################################################################
+# This is the core tokenizing regex:
+
+WORD_RE = regex.compile(r"""(%s)""" % "|".join(REGEXPS), regex.VERBOSE | regex.I | regex.UNICODE)
+
+# WORD_RE performs poorly on these patterns:
+HANG_RE = regex.compile(r"([^a-zA-Z0-9])\1{3,}")
+
+# The emoticon string gets its own regex so that we can preserve case for
+# them as needed:
+EMOTICON_RE = regex.compile(EMOTICONS, regex.VERBOSE | regex.I | regex.UNICODE)
+
+# These are for regularizing HTML entities to Unicode:
+ENT_RE = regex.compile(r"&(#?(x?))([^&;\s]+);")
+
+
+######################################################################
+# Functions for converting html entities
+######################################################################
+
+
+def _str_to_unicode(text, encoding=None, errors="strict"):
+ if encoding is None:
+ encoding = "utf-8"
+ if isinstance(text, bytes):
+ return text.decode(encoding, errors)
+ return text
+
+
+def _replace_html_entities(text, keep=(), remove_illegal=True, encoding="utf-8"):
+ """
+ Remove entities from text by converting them to their corresponding unicode character.
+
+ Args:
+ text:
+ A unicode string or a byte string encoded in the given *encoding* (which defaults to 'utf-8').
+ keep (list):
+ List of entity names which should not be replaced. This supports both numeric entities (`&#nnnn;` and
+ `&#hhhh;`) and named entities (such as `&nbsp;` or `&gt;`).
+ remove_illegal (bool):
+ If `True`, entities that can't be converted are removed. Otherwise, entities that can't be converted are
+ kept "as is".
+
+ Returns: A unicode string with the entities removed.
+
+ See https://github.com/scrapy/w3lib/blob/master/w3lib/html.py
+
+ Examples:
+
+ ```python
+ >>> from nltk.tokenize.casual import _replace_html_entities
+
+ >>> _replace_html_entities(b"Price: &pound;100")
+ 'Price: \\xa3100'
+
+ >>> print(_replace_html_entities(b"Price: &pound;100"))
+ Price: £100
+ ```"""
+
+ def _convert_entity(match):
+ entity_body = match.group(3)
+ if match.group(1):
+ try:
+ if match.group(2):
+ number = int(entity_body, 16)
+ else:
+ number = int(entity_body, 10)
+ # Numeric character references in the 80-9F range are typically
+ # interpreted by browsers as representing the characters mapped
+ # to bytes 80-9F in the Windows-1252 encoding. For more info
+ # see: https://en.wikipedia.org/wiki/ISO/IEC_8859-1#Similar_character_sets
+ if 0x80 <= number <= 0x9F:
+ return bytes((number,)).decode("cp1252")
+ except ValueError:
+ number = None
+ else:
+ if entity_body in keep:
+ return match.group(0)
+ else:
+ number = html.entities.name2codepoint.get(entity_body)
+ if number is not None:
+ try:
+ return chr(number)
+ except (ValueError, OverflowError):
+ pass
+
+ return "" if remove_illegal else match.group(0)
+
+ return ENT_RE.sub(_convert_entity, _str_to_unicode(text, encoding))
+
+
+######################################################################
+
+
+class TweetTokenizer:
+ r"""
+ Examples:
+
+ ```python
+ >>> # Tokenizer for tweets.
+ >>> from nltk.tokenize import TweetTokenizer
+
+ >>> tknzr = TweetTokenizer()
+ >>> s0 = "This is a cooool #dummysmiley: :-) :-P <3 and some arrows < > -> <--"
+ >>> tknzr.tokenize(s0)
+ ['This', 'is', 'a', 'cooool', '#dummysmiley', ':', ':-)', ':-P', '<3', 'and', 'some', 'arrows', '<', '>', '->', '<--']
+
+ >>> # Examples using the *strip_handles* and *reduce_len* parameters:
+ >>> tknzr = TweetTokenizer(strip_handles=True, reduce_len=True)
+ >>> s1 = "@remy: This is waaaaayyyy too much for you!!!!!!"
+ >>> tknzr.tokenize(s1)
+ [':', 'This', 'is', 'waaayyy', 'too', 'much', 'for', 'you', '!', '!', '!']
+ ```"""
+
+ def __init__(self, preserve_case=True, reduce_len=False, strip_handles=False):
+ self.preserve_case = preserve_case
+ self.reduce_len = reduce_len
+ self.strip_handles = strip_handles
+
+ def tokenize(self, text):
+ """
+ Args:
+ text: str
+
+ Returns: list(str) A tokenized list of strings; concatenating this list returns the original string if
+ `preserve_case=False`
+ """
+ # Fix HTML character entities:
+ text = _replace_html_entities(text)
+ # Remove username handles
+ if self.strip_handles:
+ text = remove_handles(text)
+ # Normalize word lengthening
+ if self.reduce_len:
+ text = reduce_lengthening(text)
+ # Shorten problematic sequences of characters
+ safe_text = HANG_RE.sub(r"\1\1\1", text)
+ # Tokenize:
+ words = WORD_RE.findall(safe_text)
+ # Possibly alter the case, but avoid changing emoticons like :D into :d:
+ if not self.preserve_case:
+ words = [x if EMOTICON_RE.search(x) else x.lower() for x in words]
+ return words
+
+
+######################################################################
+# Normalization Functions
+######################################################################
+
+
+def reduce_lengthening(text):
+ """
+ Replace repeated character sequences of length 3 or greater with sequences of length 3.
+ """
+ pattern = regex.compile(r"(.)\1{2,}")
+ return pattern.sub(r"\1\1\1", text)
+
+
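+# Hedged worked example (illustrative only, never called here): any run of the same
+# character longer than three is capped at three, e.g. "waaaaayyyy" becomes "waaayyy" and
+# "!!!!!!" becomes "!!!", matching the sample output in the `TweetTokenizer` docstring above.
+def _demo_reduce_lengthening():
+    """Show how repeated characters are capped at three occurrences."""
+    return reduce_lengthening("waaaaayyyy too much!!!!!!")
+
+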
+def remove_handles(text):
+ """
+ Remove Twitter username handles from text.
+ """
+ pattern = regex.compile(
+ r"(?>> from transformers import BioGptModel, BioGptConfig
+
+ >>> # Initializing a BioGPT microsoft/biogpt style configuration
+ >>> configuration = BioGptConfig()
+
+ >>> # Initializing a model from the microsoft/biogpt style configuration
+ >>> model = BioGptModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "biogpt"
+
+ def __init__(
+ self,
+ vocab_size=42384,
+ hidden_size=1024,
+ num_hidden_layers=24,
+ num_attention_heads=16,
+ intermediate_size=4096,
+ hidden_act="gelu",
+ hidden_dropout_prob=0.1,
+ attention_probs_dropout_prob=0.1,
+ max_position_embeddings=1024,
+ initializer_range=0.02,
+ layer_norm_eps=1e-12,
+ scale_embedding=True,
+ use_cache=True,
+ layerdrop=0.0,
+ activation_dropout=0.0,
+ pad_token_id=1,
+ bos_token_id=0,
+ eos_token_id=2,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.intermediate_size = intermediate_size
+ self.hidden_act = hidden_act
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.initializer_range = initializer_range
+ self.layer_norm_eps = layer_norm_eps
+ self.scale_embedding = scale_embedding
+ self.use_cache = use_cache
+ self.layerdrop = layerdrop
+ self.activation_dropout = activation_dropout
+ super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
+
+
+__all__ = ["BioGptConfig"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/biogpt/modeling_biogpt.py b/venv/lib/python3.13/site-packages/transformers/models/biogpt/modeling_biogpt.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b9937420025f86dc74c73b12060d28d7beec627
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/biogpt/modeling_biogpt.py
@@ -0,0 +1,967 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/biogpt/modular_biogpt.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_biogpt.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2022 The HuggingFace Team and Microsoft Research AI4Science All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import Callable, Optional, Union
+
+import torch
+import torch.nn as nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
+from ...generation import GenerationMixin
+from ...modeling_attn_mask_utils import AttentionMaskConverter
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import (
+ BaseModelOutputWithPastAndCrossAttentions,
+ CausalLMOutputWithCrossAttentions,
+ SequenceClassifierOutputWithPast,
+ TokenClassifierOutput,
+)
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import TransformersKwargs, auto_docstring, is_torch_flex_attn_available, logging
+from ...utils.deprecation import deprecate_kwarg
+from .configuration_biogpt import BioGptConfig
+
+
+if is_torch_flex_attn_available():
+ from ...integrations.flex_attention import BlockMask, make_flex_block_causal_mask
+
+
+logger = logging.get_logger(__name__)
+
+
+class BioGptLearnedPositionalEmbedding(nn.Embedding):
+ """
+ This module learns positional embeddings up to a fixed maximum size.
+ """
+
+ def __init__(self, num_embeddings: int, embedding_dim: int):
+ # BIOGPT is set up so that if padding_idx is specified then offset the embedding ids by 2
+ # and adjust num_embeddings appropriately. Other models don't have this hack
+ self.offset = 2
+ super().__init__(num_embeddings + self.offset, embedding_dim)
+
+ def forward(
+ self,
+ attention_mask: torch.LongTensor,
+ past_key_values_length: int = 0,
+ position_ids: Optional[torch.LongTensor] = None,
+ ):
+ """`input_ids_shape` is expected to be [bsz x seqlen]."""
+
+ if position_ids is None:
+ position_ids = torch.cumsum(attention_mask, dim=1)
+ position_ids = (position_ids * attention_mask - 1).long()
+ # cut positions if `past_key_values_length` is > 0
+ position_ids = position_ids[:, past_key_values_length:]
+
+ return super().forward(position_ids + self.offset)
+
+
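+# Hedged worked example (illustrative only, never called here): for an attention mask of
+# [1, 1, 1, 0] the cumulative sum gives [1, 2, 3, 3]; multiplying by the mask and
+# subtracting 1 yields position ids [0, 1, 2, -1], and the embedding above is then looked
+# up at position_id + 2 because of `self.offset`, so padded positions land on index 1.
+def _demo_biogpt_position_ids(attention_mask: torch.LongTensor) -> torch.LongTensor:
+    """Recompute position ids exactly as BioGptLearnedPositionalEmbedding.forward does."""
+    position_ids = torch.cumsum(attention_mask, dim=1)
+    return (position_ids * attention_mask - 1).long()
+
+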
+class BioGptScaledWordEmbedding(nn.Embedding):
+ """
+ This module overrides nn.Embedding's forward by multiplying the embeddings with the embedding scale.
+ """
+
+ def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0):
+ super().__init__(num_embeddings, embedding_dim, padding_idx)
+ self.embed_scale = embed_scale
+
+ def forward(self, input_ids: torch.Tensor):
+ return super().forward(input_ids) * self.embed_scale
+
+
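+# Hedged note (illustrative only, never called here): when `scale_embedding=True` in the
+# config, the model builds this embedding with `embed_scale = math.sqrt(hidden_size)`, so
+# token embeddings are multiplied by sqrt(d_model) before positional embeddings are added.
+def _demo_scaled_embedding(vocab_size: int, hidden_size: int, padding_idx: int) -> "BioGptScaledWordEmbedding":
+    """Build a scaled word embedding the way the model does when scale_embedding is set."""
+    return BioGptScaledWordEmbedding(
+        vocab_size, hidden_size, padding_idx, embed_scale=math.sqrt(hidden_size)
+    )
+
+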
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: Optional[float] = None,
+ dropout: float = 0.0,
+ head_mask: Optional[torch.Tensor] = None,
+ **kwargs,
+):
+ if scaling is None:
+ scaling = query.size(-1) ** -0.5
+
+ attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
+ if attention_mask is not None:
+ attn_weights = attn_weights + attention_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+ if head_mask is not None:
+ attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)
+
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
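+# Hedged usage sketch (illustrative only, never called here): the function above expects
+# query/key/value of shape (batch, num_heads, seq_len, head_dim) and returns the attention
+# output transposed back to (batch, seq_len, num_heads, head_dim) plus the attention weights.
+def _demo_eager_attention(batch: int = 1, heads: int = 2, seq: int = 4, head_dim: int = 8):
+    """Run eager attention on random tensors to show the expected shapes."""
+    dummy_module = nn.Module()  # only `.training` is read by the dropout call above
+    q = torch.randn(batch, heads, seq, head_dim)
+    k = torch.randn(batch, heads, seq, head_dim)
+    v = torch.randn(batch, heads, seq, head_dim)
+    attn_output, attn_weights = eager_attention_forward(dummy_module, q, k, v, attention_mask=None)
+    # shapes: (batch, seq, heads, head_dim) and (batch, heads, seq, seq)
+    return attn_output.shape, attn_weights.shape
+
+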
+class BioGptAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(
+ self,
+ embed_dim: int,
+ num_heads: int,
+ dropout: float = 0.0,
+ is_decoder: bool = False,
+ bias: bool = True,
+ is_causal: bool = False,
+ config: Optional[BioGptConfig] = None,
+ layer_idx: Optional[int] = None,
+ ):
+ super().__init__()
+ self.embed_dim = embed_dim
+ self.num_heads = num_heads
+ self.dropout = dropout
+ self.head_dim = embed_dim // num_heads
+ self.config = config
+
+ if (self.head_dim * num_heads) != self.embed_dim:
+ raise ValueError(
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
+ f" and `num_heads`: {num_heads})."
+ )
+ self.scaling = self.head_dim**-0.5
+ self.is_decoder = is_decoder
+ self.is_causal = is_causal
+ self.layer_idx = layer_idx
+ if layer_idx is None and self.is_decoder:
+ logger.warning_once(
+ f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and "
+ "will lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
+ "when creating this class."
+ )
+
+ self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+ self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+ self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ key_value_states: Optional[torch.Tensor] = None,
+ past_key_values: Optional[Cache] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ layer_head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ cache_position: Optional[torch.Tensor] = None,
+ # TODO: we need a refactor so that the different attention modules can get their specific kwargs
+ # At the moment, encoder, decoder, and encoder-decoder attention kwargs are mixed together here
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
+ """Input shape: Batch x Time x Channel"""
+
+ # if key_value_states are provided this layer is used as a cross-attention layer
+ # for the decoder
+ is_cross_attention = key_value_states is not None
+
+ # determine input shapes
+ bsz, tgt_len = hidden_states.shape[:-1]
+ src_len = key_value_states.shape[1] if is_cross_attention else tgt_len
+
+ q_input_shape = (bsz, tgt_len, -1, self.head_dim)
+ kv_input_shape = (bsz, src_len, -1, self.head_dim)
+
+ # get query proj
+ query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2)
+
+ is_updated = False
+ if past_key_values is not None:
+ if isinstance(past_key_values, EncoderDecoderCache):
+ is_updated = past_key_values.is_updated.get(self.layer_idx)
+ if is_cross_attention:
+ # after the first generated id, we can subsequently re-use all key/value_states from cache
+ curr_past_key_value = past_key_values.cross_attention_cache
+ else:
+ curr_past_key_value = past_key_values.self_attention_cache
+ else:
+ curr_past_key_value = past_key_values
+
+ current_states = key_value_states if is_cross_attention else hidden_states
+ if is_cross_attention and past_key_values is not None and is_updated:
+ # reuse k,v, cross_attentions
+ key_states = curr_past_key_value.layers[self.layer_idx].keys
+ value_states = curr_past_key_value.layers[self.layer_idx].values
+ else:
+ key_states = self.k_proj(current_states)
+ value_states = self.v_proj(current_states)
+ key_states = key_states.view(*kv_input_shape).transpose(1, 2)
+ value_states = value_states.view(*kv_input_shape).transpose(1, 2)
+
+ if past_key_values is not None:
+ # save all key/value_states to cache to be re-used for fast auto-regressive generation
+ cache_position = cache_position if not is_cross_attention else None
+ key_states, value_states = curr_past_key_value.update(
+ key_states, value_states, self.layer_idx, {"cache_position": cache_position}
+ )
+ # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
+ if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache):
+ past_key_values.is_updated[self.layer_idx] = True
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.dropout,
+ scaling=self.scaling,
+ output_attentions=output_attentions,
+ head_mask=layer_head_mask,
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
+ attn_output = self.out_proj(attn_output)
+
+ return attn_output, attn_weights
+
+
+class BioGptDecoderLayer(GradientCheckpointingLayer):
+ def __init__(self, config: BioGptConfig, layer_idx: Optional[int] = None):
+ super().__init__()
+ self.embed_dim = config.hidden_size
+
+ self.self_attn = BioGptAttention(
+ embed_dim=self.embed_dim,
+ num_heads=config.num_attention_heads,
+ dropout=config.attention_probs_dropout_prob,
+ is_decoder=True,
+ is_causal=True,
+ config=config,
+ layer_idx=layer_idx,
+ )
+ self.dropout = config.hidden_dropout_prob
+ self.activation_fn = ACT2FN[config.hidden_act]
+ self.activation_dropout = config.activation_dropout
+
+ self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
+
+ self.fc1 = nn.Linear(self.embed_dim, config.intermediate_size)
+ self.fc2 = nn.Linear(config.intermediate_size, self.embed_dim)
+ self.final_layer_norm = nn.LayerNorm(self.embed_dim)
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ layer_head_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[Cache] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = True,
+ position_ids: Optional[torch.LongTensor] = None,
+ cache_position: Optional[torch.Tensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`): attention mask of size
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+ layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
+ `(encoder_attention_heads,)`.
+ past_key_values (`Cache`): cached past key and value projection states
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+ Indices depicting the position of the input sequence tokens in the sequence. It is used to update the
+ cache in the correct position and to infer the complete sequence length.
+ """
+ residual = hidden_states
+
+ hidden_states = self.self_attn_layer_norm(hidden_states)
+
+ # Self Attention
+ hidden_states, self_attn_weights = self.self_attn(
+ hidden_states=hidden_states,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ layer_head_mask=layer_head_mask,
+ output_attentions=output_attentions,
+ position_ids=position_ids,
+ cache_position=cache_position,
+ **kwargs,
+ )
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+ hidden_states = residual + hidden_states
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.final_layer_norm(hidden_states)
+ hidden_states = self.fc1(hidden_states)
+ hidden_states = self.activation_fn(hidden_states)
+ hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
+ hidden_states = self.fc2(hidden_states)
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+ hidden_states = residual + hidden_states
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights,)
+
+ return outputs
+
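+ # The layer above uses a pre-LayerNorm residual layout: LayerNorm -> self-attention -> dropout ->
+ # residual add, followed by LayerNorm -> fc1 -> activation (+ activation dropout) -> fc2 -> dropout ->
+ # residual add, matching the two blocks in `forward`.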
+
+@auto_docstring
+class BioGptPreTrainedModel(PreTrainedModel):
+ config: BioGptConfig
+ base_model_prefix = "biogpt"
+ supports_gradient_checkpointing = True
+ _supports_flash_attn = True
+ _supports_sdpa = True
+ _supports_flex_attn = True
+
+ _can_compile_fullgraph = True
+
+ # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_causal_mask
+ def _update_causal_mask(
+ self,
+ attention_mask: Optional[Union[torch.Tensor, "BlockMask"]],
+ input_tensor: torch.Tensor,
+ cache_position: torch.Tensor,
+ past_key_values: Cache,
+ ):
+ if self.config._attn_implementation == "flex_attention":
+ if isinstance(attention_mask, torch.Tensor):
+ attention_mask = make_flex_block_causal_mask(attention_mask)
+ # Other attention flavors support in-built causal (when `mask is None`)
+ # while we need to create our specific block mask regardless
+ elif attention_mask is None:
+ attention_mask = make_flex_block_causal_mask(
+ torch.ones(
+ size=(input_tensor.shape[0], input_tensor.shape[1]),
+ device=input_tensor.device,
+ )
+ )
+ return attention_mask
+
+ if self.config._attn_implementation == "flash_attention_2":
+ if attention_mask is not None and (attention_mask == 0.0).any():
+ return attention_mask
+ return None
+
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
+ # to infer the attention mask.
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
+
+ # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
+ if self.config._attn_implementation == "sdpa" and not using_compilable_cache:
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
+ attention_mask,
+ inputs_embeds=input_tensor,
+ past_key_values_length=past_seen_tokens,
+ is_training=self.training,
+ ):
+ return None
+
+ dtype = input_tensor.dtype
+ sequence_length = input_tensor.shape[1]
+ if using_compilable_cache:
+ target_length = past_key_values.get_max_cache_shape()
+ else:
+ target_length = (
+ attention_mask.shape[-1]
+ if isinstance(attention_mask, torch.Tensor)
+ else past_seen_tokens + sequence_length + 1
+ )
+
+ # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
+ causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
+ attention_mask,
+ sequence_length=sequence_length,
+ target_length=target_length,
+ dtype=dtype,
+ cache_position=cache_position,
+ batch_size=input_tensor.shape[0],
+ )
+
+ if (
+ self.config._attn_implementation == "sdpa"
+ and attention_mask is not None
+ and attention_mask.device.type in ["cuda", "xpu", "npu"]
+ ):
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+ # Details: https://github.com/pytorch/pytorch/issues/110213
+ min_dtype = torch.finfo(dtype).min
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
+
+ return causal_mask
+
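+ # Summary of the dispatch above: flex_attention gets a BlockMask, flash_attention_2 gets the raw 2D
+ # mask (or None when nothing is padded), and sdpa/eager get the 4D additive mask built by the helper
+ # below, with fully-masked rows re-opened for SDPA's memory-efficient path.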
+ @staticmethod
+ # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
+ def _prepare_4d_causal_attention_mask_with_cache_position(
+ attention_mask: torch.Tensor,
+ sequence_length: int,
+ target_length: int,
+ dtype: torch.dtype,
+ cache_position: torch.Tensor,
+ batch_size: int,
+ **kwargs,
+ ):
+ """
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
+
+ Args:
+ attention_mask (`torch.Tensor`):
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
+ `(batch_size, 1, query_length, key_value_length)`.
+ sequence_length (`int`):
+ The sequence length being processed.
+ target_length (`int`):
+ The target length: when generating with static cache, the mask should be as long as the static cache,
+ to account for the 0 padding (the part of the cache that is not filled yet).
+ dtype (`torch.dtype`):
+ The dtype to use for the 4D attention mask.
+ cache_position (`torch.Tensor`):
+ Indices depicting the position of the input sequence tokens in the sequence.
+ batch_size (`torch.Tensor`):
+ Batch size.
+ """
+ if attention_mask is not None and attention_mask.dim() == 4:
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+ causal_mask = attention_mask
+ else:
+ min_dtype = torch.finfo(dtype).min
+ causal_mask = torch.full(
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
+ )
+ if sequence_length != 1:
+ causal_mask = torch.triu(causal_mask, diagonal=1)
+ causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+ if attention_mask is not None:
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
+ mask_length = attention_mask.shape[-1]
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
+ causal_mask.device
+ )
+ padding_mask = padding_mask == 0
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+ padding_mask, min_dtype
+ )
+
+ return causal_mask
+
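+ # Worked example (illustrative): when decoding a single new token with 3 tokens already cached and
+ # target_length=4, cache_position is [3], so `torch.arange(4) > 3` is all False and the single mask row
+ # stays fully open; during prefill (sequence_length > 1) the `torch.triu` call masks the strictly upper
+ # triangle so each query can only attend to itself and earlier positions.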
+
+@auto_docstring
+class BioGptModel(BioGptPreTrainedModel):
+ def __init__(self, config: BioGptConfig):
+ super().__init__(config)
+ self.config = config
+ self.layerdrop = config.layerdrop
+ self.dropout = config.hidden_dropout_prob
+ self.embed_dim = config.hidden_size
+ self.padding_idx = config.pad_token_id
+ embed_scale = math.sqrt(config.hidden_size) if config.scale_embedding else 1.0
+
+ self.embed_tokens = BioGptScaledWordEmbedding(
+ config.vocab_size, self.embed_dim, self.padding_idx, embed_scale=embed_scale
+ )
+ self.embed_positions = BioGptLearnedPositionalEmbedding(config.max_position_embeddings, self.embed_dim)
+
+ self.layers = nn.ModuleList([BioGptDecoderLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)])
+ self.layer_norm = nn.LayerNorm(self.embed_dim)
+
+ self.gradient_checkpointing = False
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ use_cache: Optional[bool] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.Tensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # retrieve input_ids and inputs_embeds
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+ elif input_ids is not None:
+ input = input_ids
+ input_shape = input.shape
+ input_ids = input_ids.view(-1, input_shape[-1])
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ input = inputs_embeds[:, :, -1]
+ else:
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input)
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing`. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ # initialize past_key_values
+ if use_cache and past_key_values is None:
+ past_key_values = DynamicCache(config=self.config)
+ if use_cache and isinstance(past_key_values, tuple):
+ logger.warning_once(
+ "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
+ "You should pass an instance of `DynamicCache` instead, e.g. "
+ "`past_key_values=DynamicCache.from_legacy_cache(past_key_values)`."
+ )
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+
+ batch_size, seq_length = inputs_embeds.size()[:-1]
+ past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
+ if cache_position is None:
+ cache_position = torch.arange(
+ past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device
+ )
+
+ if attention_mask is None:
+ # required mask seq length can be calculated via length of past cache
+ mask_seq_length = past_key_values_length + seq_length
+ attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
+
+ self_attn_cache = past_key_values
+
+ causal_mask = self._update_causal_mask(
+ attention_mask,
+ inputs_embeds,
+ cache_position,
+ self_attn_cache,
+ )
+
+ # embed positions
+ if position_ids is None:
+ # position_ids = cache_position.unsqueeze(0)
+ position_ids = torch.cumsum(attention_mask, dim=1)
+ position_ids = (position_ids * attention_mask - 1).long()
+ # cut positions if `past_seen_tokens` is > 0
+ position_ids = position_ids[:, past_key_values_length:]
+
+ positions = self.embed_positions(attention_mask, past_key_values_length, position_ids=position_ids)
+ hidden_states = inputs_embeds + positions
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ all_cross_attentions = None
+
+ for idx, decoder_layer in enumerate(self.layers):
+ # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+ if self.training:
+ dropout_probability = torch.rand([])
+ if dropout_probability < self.layerdrop:
+ continue
+
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=causal_mask,
+ layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+ past_key_values=past_key_values,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ position_ids=position_ids,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ hidden_states = self.layer_norm(hidden_states)
+
+ if not return_dict:
+ return tuple(
+ v
+ for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns, all_cross_attentions]
+ if v is not None
+ )
+ return BaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ cross_attentions=all_cross_attentions,
+ )
+
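+ # Minimal usage sketch (assuming the publicly available `microsoft/biogpt` checkpoint):
+ #     from transformers import AutoTokenizer, BioGptModel
+ #     tokenizer = AutoTokenizer.from_pretrained("microsoft/biogpt")
+ #     model = BioGptModel.from_pretrained("microsoft/biogpt")
+ #     inputs = tokenizer("Bicalutamide is an antiandrogen", return_tensors="pt")
+ #     last_hidden = model(**inputs).last_hidden_state  # (batch_size, seq_len, hidden_size)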
+
+@auto_docstring(
+ custom_intro="""
+ BioGPT Model with a `language modeling` head on top for CLM fine-tuning.
+ """
+)
+class BioGptForCausalLM(BioGptPreTrainedModel, GenerationMixin):
+ _tied_weights_keys = ["output_projection.weight"]
+
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.biogpt = BioGptModel(config)
+ self.output_projection = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_output_embeddings(self):
+ return self.output_projection
+
+ def set_output_embeddings(self, new_embeddings):
+ self.output_projection = new_embeddings
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.Tensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> Union[tuple, CausalLMOutputWithCrossAttentions]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
+ `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100`
+ are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`.
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.biogpt(
+ input_ids,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ position_ids=position_ids,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ sequence_output = outputs[0]
+ prediction_scores = self.output_projection(sequence_output)
+
+ lm_loss = None
+ if labels is not None:
+ lm_loss = self.loss_function(
+ prediction_scores,
+ labels,
+ vocab_size=self.config.vocab_size,
+ **kwargs,
+ )
+
+ if not return_dict:
+ output = (prediction_scores,) + outputs[1:]
+ return ((lm_loss,) + output) if lm_loss is not None else output
+
+ return CausalLMOutputWithCrossAttentions(
+ loss=lm_loss,
+ logits=prediction_scores,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ cross_attentions=outputs.cross_attentions,
+ )
+
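+ # Generation sketch (illustrative; `generate` comes from GenerationMixin, shown with its default
+ # decoding settings and the assumed `microsoft/biogpt` checkpoint):
+ #     from transformers import AutoTokenizer, BioGptForCausalLM
+ #     tokenizer = AutoTokenizer.from_pretrained("microsoft/biogpt")
+ #     model = BioGptForCausalLM.from_pretrained("microsoft/biogpt")
+ #     inputs = tokenizer("COVID-19 is", return_tensors="pt")
+ #     output_ids = model.generate(**inputs, max_new_tokens=20)
+ #     print(tokenizer.decode(output_ids[0], skip_special_tokens=True))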
+
+@auto_docstring
+class BioGptForTokenClassification(BioGptPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+
+ self.biogpt = BioGptModel(config)
+ if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None:
+ classifier_dropout = config.classifier_dropout
+ else:
+ classifier_dropout = config.hidden_dropout_prob
+ self.dropout = nn.Dropout(classifier_dropout)
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.Tensor] = None,
+ ) -> Union[tuple, TokenClassifierOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ transformer_outputs = self.biogpt(
+ input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ position_ids=position_ids,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ cache_position=cache_position,
+ )
+
+ hidden_states = transformer_outputs[0]
+ hidden_states = self.dropout(hidden_states)
+ logits = self.classifier(hidden_states)
+
+ loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ # Only keep active parts of the loss
+ if attention_mask is not None:
+ active_loss = attention_mask.view(-1) == 1
+ active_logits = logits.view(-1, self.num_labels)
+ active_labels = torch.where(
+ active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
+ )
+ loss = loss_fct(active_logits, active_labels)
+ else:
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+ if not return_dict:
+ output = (logits,) + transformer_outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return TokenClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ The BioGpt Model transformer with a sequence classification head on top (linear layer).
+
+ [`BioGptForSequenceClassification`] uses the last token in order to do the classification, as other causal models
+ (e.g. GPT-2) do.
+
+ Since it does classification on the last token, it is required to know the position of the last token. If a
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
+ each row of the batch).
+ """
+)
+class BioGptForSequenceClassification(BioGptPreTrainedModel):
+ def __init__(self, config: BioGptConfig):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.biogpt = BioGptModel(config)
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.Tensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ ) -> Union[tuple, SequenceClassifierOutputWithPast]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ transformer_outputs = self.biogpt(
+ input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ position_ids=position_ids,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ cache_position=cache_position,
+ )
+ hidden_states = transformer_outputs[0]
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.score(hidden_states[:, slice_indices, :])
+
+ if input_ids is not None:
+ batch_size, sequence_length = input_ids.shape[:2]
+ else:
+ batch_size, sequence_length = inputs_embeds.shape[:2]
+
+ if self.config.pad_token_id is None:
+ sequence_length = -1
+ else:
+ if input_ids is not None:
+ sequence_length = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
+ else:
+ sequence_length = -1
+ logger.warning_once(
+ f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+ "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+ )
+
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_length]
+
+ loss = None
+ if labels is not None:
+ if self.config.problem_type is None:
+ if self.num_labels == 1:
+ self.config.problem_type = "regression"
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+ self.config.problem_type = "single_label_classification"
+ else:
+ self.config.problem_type = "multi_label_classification"
+
+ if self.config.problem_type == "regression":
+ loss_fct = MSELoss()
+ if self.num_labels == 1:
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+ else:
+ loss = loss_fct(pooled_logits, labels)
+ elif self.config.problem_type == "single_label_classification":
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
+ elif self.config.problem_type == "multi_label_classification":
+ loss_fct = BCEWithLogitsLoss()
+ loss = loss_fct(pooled_logits, labels)
+ if not return_dict:
+ output = (pooled_logits,) + transformer_outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return SequenceClassifierOutputWithPast(
+ loss=loss,
+ logits=pooled_logits,
+ past_key_values=transformer_outputs.past_key_values,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
+ )
+
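+ # Pooling sketch (illustrative): with `pad_token_id` set and input_ids = [[5, 6, PAD, PAD], [7, 8, 9, PAD]],
+ # `sequence_length` above becomes [1, 2] (the index of the last non-padding token per row), so
+ # `pooled_logits` keeps exactly one classifier output per example.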
+ def get_input_embeddings(self):
+ return self.biogpt.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.biogpt.embed_tokens = value
+
+
+__all__ = [
+ "BioGptForCausalLM",
+ "BioGptForTokenClassification",
+ "BioGptForSequenceClassification",
+ "BioGptModel",
+ "BioGptPreTrainedModel",
+]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/biogpt/modular_biogpt.py b/venv/lib/python3.13/site-packages/transformers/models/biogpt/modular_biogpt.py
new file mode 100644
index 0000000000000000000000000000000000000000..8d95b2a2d051a960715fa10edc9783f380a53696
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/biogpt/modular_biogpt.py
@@ -0,0 +1,789 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Team and Microsoft Research AI4Science All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch BioGPT model."""
+
+import math
+from typing import Optional, Union
+
+import torch
+import torch.nn as nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache
+from ...generation import GenerationMixin
+from ...modeling_attn_mask_utils import (
+ AttentionMaskConverter,
+)
+from ...modeling_outputs import (
+ BaseModelOutputWithPastAndCrossAttentions,
+ CausalLMOutputWithCrossAttentions,
+ SequenceClassifierOutputWithPast,
+ TokenClassifierOutput,
+)
+from ...modeling_utils import PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import (
+ TransformersKwargs,
+ auto_docstring,
+ is_torch_flex_attn_available,
+ logger,
+)
+from ...utils.deprecation import deprecate_kwarg
+from ..bart.modeling_bart import (
+ BartAttention,
+ BartDecoderLayer,
+ BartScaledWordEmbedding,
+)
+from ..opt.modeling_opt import OPTLearnedPositionalEmbedding
+from .configuration_biogpt import BioGptConfig
+
+
+if is_torch_flex_attn_available():
+ from ...integrations.flex_attention import BlockMask, make_flex_block_causal_mask
+
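+ # Note: this modular file is the hand-written source that the modular converter turns into the generated
+ # `modeling_biogpt.py` shown above; subclasses that simply `pass` (or thinly override a parent) pull the
+ # corresponding Bart/OPT implementations into the generated file.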
+
+class BioGptLearnedPositionalEmbedding(OPTLearnedPositionalEmbedding):
+ def forward(
+ self,
+ attention_mask: torch.LongTensor,
+ past_key_values_length: int = 0,
+ position_ids: Optional[torch.LongTensor] = None,
+ ):
+ """`input_ids_shape` is expected to be [bsz x seqlen]."""
+ super().forward(attention_mask, past_key_values_length, position_ids)
+
+
+class BioGptScaledWordEmbedding(BartScaledWordEmbedding):
+ pass
+
+
+class BioGptAttention(BartAttention):
+ pass
+
+
+class BioGptDecoderLayer(BartDecoderLayer):
+ def __init__(self, config: BioGptConfig, layer_idx: Optional[int] = None):
+ super().__init__(config)
+ self.embed_dim = config.hidden_size
+
+ self.self_attn = BioGptAttention(
+ embed_dim=self.embed_dim,
+ num_heads=config.num_attention_heads,
+ dropout=config.attention_probs_dropout_prob,
+ is_decoder=True,
+ is_causal=True,
+ config=config,
+ layer_idx=layer_idx,
+ )
+ self.dropout = config.hidden_dropout_prob
+ self.activation_fn = ACT2FN[config.hidden_act]
+
+ self.fc1 = nn.Linear(self.embed_dim, config.intermediate_size)
+ self.fc2 = nn.Linear(config.intermediate_size, self.embed_dim)
+
+ del self.encoder_attn
+ del self.encoder_attn_layer_norm
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ layer_head_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[Cache] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = True,
+ position_ids: Optional[torch.LongTensor] = None,
+ cache_position: Optional[torch.Tensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`): attention mask of size
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+ layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
+ `(encoder_attention_heads,)`.
+ past_key_values (`Cache`): cached past key and value projection states
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+ Indices depicting the position of the input sequence tokens in the sequence. It is used to update the
+ cache in the correct position and to infer the complete sequence length.
+ """
+ residual = hidden_states
+
+ hidden_states = self.self_attn_layer_norm(hidden_states)
+
+ # Self Attention
+ hidden_states, self_attn_weights = self.self_attn(
+ hidden_states=hidden_states,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ layer_head_mask=layer_head_mask,
+ output_attentions=output_attentions,
+ position_ids=position_ids,
+ cache_position=cache_position,
+ **kwargs,
+ )
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+ hidden_states = residual + hidden_states
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.final_layer_norm(hidden_states)
+ hidden_states = self.fc1(hidden_states)
+ hidden_states = self.activation_fn(hidden_states)
+ hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
+ hidden_states = self.fc2(hidden_states)
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+ hidden_states = residual + hidden_states
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights,)
+
+ return outputs
+
+
+@auto_docstring
+class BioGptPreTrainedModel(PreTrainedModel):
+ config: BioGptConfig
+ base_model_prefix = "biogpt"
+ supports_gradient_checkpointing = True
+ _supports_flash_attn = True
+ _supports_sdpa = True
+ _supports_flex_attn = True
+
+ _can_compile_fullgraph = True
+
+ # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_causal_mask
+ def _update_causal_mask(
+ self,
+ attention_mask: Optional[Union[torch.Tensor, "BlockMask"]],
+ input_tensor: torch.Tensor,
+ cache_position: torch.Tensor,
+ past_key_values: Cache,
+ ):
+ if self.config._attn_implementation == "flex_attention":
+ if isinstance(attention_mask, torch.Tensor):
+ attention_mask = make_flex_block_causal_mask(attention_mask)
+ # Other attention flavors support in-built causal (when `mask is None`)
+ # while we need to create our specific block mask regardless
+ elif attention_mask is None:
+ attention_mask = make_flex_block_causal_mask(
+ torch.ones(
+ size=(input_tensor.shape[0], input_tensor.shape[1]),
+ device=input_tensor.device,
+ )
+ )
+ return attention_mask
+
+ if self.config._attn_implementation == "flash_attention_2":
+ if attention_mask is not None and (attention_mask == 0.0).any():
+ return attention_mask
+ return None
+
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
+ # to infer the attention mask.
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
+
+ # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
+ if self.config._attn_implementation == "sdpa" and not using_compilable_cache:
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
+ attention_mask,
+ inputs_embeds=input_tensor,
+ past_key_values_length=past_seen_tokens,
+ is_training=self.training,
+ ):
+ return None
+
+ dtype = input_tensor.dtype
+ sequence_length = input_tensor.shape[1]
+ if using_compilable_cache:
+ target_length = past_key_values.get_max_cache_shape()
+ else:
+ target_length = (
+ attention_mask.shape[-1]
+ if isinstance(attention_mask, torch.Tensor)
+ else past_seen_tokens + sequence_length + 1
+ )
+
+ # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
+ causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
+ attention_mask,
+ sequence_length=sequence_length,
+ target_length=target_length,
+ dtype=dtype,
+ cache_position=cache_position,
+ batch_size=input_tensor.shape[0],
+ )
+
+ if (
+ self.config._attn_implementation == "sdpa"
+ and attention_mask is not None
+ and attention_mask.device.type in ["cuda", "xpu", "npu"]
+ ):
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+ # Details: https://github.com/pytorch/pytorch/issues/110213
+ min_dtype = torch.finfo(dtype).min
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
+
+ return causal_mask
+
+ @staticmethod
+ # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
+ def _prepare_4d_causal_attention_mask_with_cache_position(
+ attention_mask: torch.Tensor,
+ sequence_length: int,
+ target_length: int,
+ dtype: torch.dtype,
+ cache_position: torch.Tensor,
+ batch_size: int,
+ **kwargs,
+ ):
+ """
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
+
+ Args:
+ attention_mask (`torch.Tensor`):
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
+ `(batch_size, 1, query_length, key_value_length)`.
+ sequence_length (`int`):
+ The sequence length being processed.
+ target_length (`int`):
+ The target length: when generating with static cache, the mask should be as long as the static cache,
+ to account for the 0 padding (the part of the cache that is not filled yet).
+ dtype (`torch.dtype`):
+ The dtype to use for the 4D attention mask.
+ cache_position (`torch.Tensor`):
+ Indices depicting the position of the input sequence tokens in the sequence.
+ batch_size (`torch.Tensor`):
+ Batch size.
+ """
+ if attention_mask is not None and attention_mask.dim() == 4:
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+ causal_mask = attention_mask
+ else:
+ min_dtype = torch.finfo(dtype).min
+ causal_mask = torch.full(
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
+ )
+ if sequence_length != 1:
+ causal_mask = torch.triu(causal_mask, diagonal=1)
+ causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+ if attention_mask is not None:
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
+ mask_length = attention_mask.shape[-1]
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
+ causal_mask.device
+ )
+ padding_mask = padding_mask == 0
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+ padding_mask, min_dtype
+ )
+
+ return causal_mask
+
+
+@auto_docstring
+class BioGptModel(BioGptPreTrainedModel):
+ def __init__(self, config: BioGptConfig):
+ super().__init__(config)
+ self.config = config
+ self.layerdrop = config.layerdrop
+ self.dropout = config.hidden_dropout_prob
+ self.embed_dim = config.hidden_size
+ self.padding_idx = config.pad_token_id
+ embed_scale = math.sqrt(config.hidden_size) if config.scale_embedding else 1.0
+
+ self.embed_tokens = BioGptScaledWordEmbedding(
+ config.vocab_size, self.embed_dim, self.padding_idx, embed_scale=embed_scale
+ )
+ self.embed_positions = BioGptLearnedPositionalEmbedding(config.max_position_embeddings, self.embed_dim)
+
+ self.layers = nn.ModuleList([BioGptDecoderLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)])
+ self.layer_norm = nn.LayerNorm(self.embed_dim)
+
+ self.gradient_checkpointing = False
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ use_cache: Optional[bool] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.Tensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # retrieve input_ids and inputs_embeds
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+ elif input_ids is not None:
+ input = input_ids
+ input_shape = input.shape
+ input_ids = input_ids.view(-1, input_shape[-1])
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ input = inputs_embeds[:, :, -1]
+ else:
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input)
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing`. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ # initialize past_key_values
+ if use_cache and past_key_values is None:
+ past_key_values = DynamicCache(config=self.config)
+ if use_cache and isinstance(past_key_values, tuple):
+ logger.warning_once(
+ "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
+ "You should pass an instance of `DynamicCache` instead, e.g. "
+ "`past_key_values=DynamicCache.from_legacy_cache(past_key_values)`."
+ )
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+
+ batch_size, seq_length = inputs_embeds.size()[:-1]
+ past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
+ if cache_position is None:
+ cache_position = torch.arange(
+ past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device
+ )
+
+ if attention_mask is None:
+ # required mask seq length can be calculated via length of past cache
+ mask_seq_length = past_key_values_length + seq_length
+ attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
+
+ self_attn_cache = past_key_values
+
+ causal_mask = self._update_causal_mask(
+ attention_mask,
+ inputs_embeds,
+ cache_position,
+ self_attn_cache,
+ )
+
+ # embed positions
+ if position_ids is None:
+ # position_ids = cache_position.unsqueeze(0)
+ position_ids = torch.cumsum(attention_mask, dim=1)
+ position_ids = (position_ids * attention_mask - 1).long()
+ # cut positions if `past_seen_tokens` is > 0
+ position_ids = position_ids[:, past_key_values_length:]
+
+ positions = self.embed_positions(attention_mask, past_key_values_length, position_ids=position_ids)
+ hidden_states = inputs_embeds + positions
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ all_cross_attentions = None
+
+ for idx, decoder_layer in enumerate(self.layers):
+ # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+ if self.training:
+ dropout_probability = torch.rand([])
+ if dropout_probability < self.layerdrop:
+ continue
+
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=causal_mask,
+ layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+ past_key_values=past_key_values,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ position_ids=position_ids,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ hidden_states = self.layer_norm(hidden_states)
+
+ if not return_dict:
+ return tuple(
+ v
+ for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns, all_cross_attentions]
+ if v is not None
+ )
+ return BaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ cross_attentions=all_cross_attentions,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ BioGPT Model with a `language modeling` head on top for CLM fine-tuning.
+ """
+)
+class BioGptForCausalLM(BioGptPreTrainedModel, GenerationMixin):
+ _tied_weights_keys = ["output_projection.weight"]
+
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.biogpt = BioGptModel(config)
+ self.output_projection = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_output_embeddings(self):
+ return self.output_projection
+
+ def set_output_embeddings(self, new_embeddings):
+ self.output_projection = new_embeddings
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.Tensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> Union[tuple, CausalLMOutputWithCrossAttentions]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
+ `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100`
+ are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`.
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.biogpt(
+ input_ids,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ position_ids=position_ids,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ sequence_output = outputs[0]
+ prediction_scores = self.output_projection(sequence_output)
+
+ lm_loss = None
+ if labels is not None:
+ lm_loss = self.loss_function(
+ prediction_scores,
+ labels,
+ vocab_size=self.config.vocab_size,
+ **kwargs,
+ )
+
+ if not return_dict:
+ output = (prediction_scores,) + outputs[1:]
+ return ((lm_loss,) + output) if lm_loss is not None else output
+
+ return CausalLMOutputWithCrossAttentions(
+ loss=lm_loss,
+ logits=prediction_scores,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ cross_attentions=outputs.cross_attentions,
+ )
+
+
+@auto_docstring
+class BioGptForTokenClassification(BioGptPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+
+ self.biogpt = BioGptModel(config)
+ if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None:
+ classifier_dropout = config.classifier_dropout
+ else:
+ classifier_dropout = config.hidden_dropout_prob
+ self.dropout = nn.Dropout(classifier_dropout)
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.Tensor] = None,
+ ) -> Union[tuple, TokenClassifierOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ transformer_outputs = self.biogpt(
+ input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ position_ids=position_ids,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ cache_position=cache_position,
+ )
+
+ hidden_states = transformer_outputs[0]
+ hidden_states = self.dropout(hidden_states)
+ logits = self.classifier(hidden_states)
+
+ loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ # Only keep active parts of the loss
+ if attention_mask is not None:
+ active_loss = attention_mask.view(-1) == 1
+ active_logits = logits.view(-1, self.num_labels)
+ active_labels = torch.where(
+ active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
+ )
+ loss = loss_fct(active_logits, active_labels)
+ else:
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+ if not return_dict:
+ output = (logits,) + transformer_outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return TokenClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ The BioGpt Model transformer with a sequence classification head on top (linear layer).
+
+ [`BioGptForSequenceClassification`] uses the last token in order to do the classification, as other causal models
+ (e.g. GPT-2) do.
+
+ Since it does classification on the last token, it is required to know the position of the last token. If a
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
+ each row of the batch).
+ """
+)
+class BioGptForSequenceClassification(BioGptPreTrainedModel):
+ def __init__(self, config: BioGptConfig):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.biogpt = BioGptModel(config)
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.Tensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ ) -> Union[tuple, SequenceClassifierOutputWithPast]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ transformer_outputs = self.biogpt(
+ input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ position_ids=position_ids,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ cache_position=cache_position,
+ )
+ hidden_states = transformer_outputs[0]
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.score(hidden_states[:, slice_indices, :])
+
+ if input_ids is not None:
+ batch_size, sequence_length = input_ids.shape[:2]
+ else:
+ batch_size, sequence_length = inputs_embeds.shape[:2]
+
+ if self.config.pad_token_id is None:
+ sequence_length = -1
+ else:
+ if input_ids is not None:
+ sequence_length = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
+ else:
+ sequence_length = -1
+ logger.warning_once(
+ f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+ "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+ )
+
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_length]
+
+ loss = None
+ if labels is not None:
+ if self.config.problem_type is None:
+ if self.num_labels == 1:
+ self.config.problem_type = "regression"
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+ self.config.problem_type = "single_label_classification"
+ else:
+ self.config.problem_type = "multi_label_classification"
+
+ if self.config.problem_type == "regression":
+ loss_fct = MSELoss()
+ if self.num_labels == 1:
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+ else:
+ loss = loss_fct(pooled_logits, labels)
+ elif self.config.problem_type == "single_label_classification":
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
+ elif self.config.problem_type == "multi_label_classification":
+ loss_fct = BCEWithLogitsLoss()
+ loss = loss_fct(pooled_logits, labels)
+ if not return_dict:
+ output = (pooled_logits,) + transformer_outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return SequenceClassifierOutputWithPast(
+ loss=loss,
+ logits=pooled_logits,
+ past_key_values=transformer_outputs.past_key_values,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
+ )
+
+ def get_input_embeddings(self):
+ return self.biogpt.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.biogpt.embed_tokens = value
+
+
+__all__ = [
+ "BioGptForCausalLM",
+ "BioGptForTokenClassification",
+ "BioGptForSequenceClassification",
+ "BioGptModel",
+ "BioGptPreTrainedModel",
+]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/biogpt/tokenization_biogpt.py b/venv/lib/python3.13/site-packages/transformers/models/biogpt/tokenization_biogpt.py
new file mode 100644
index 0000000000000000000000000000000000000000..f84403ca7ddc65586528da07e4efc9f8621a0e19
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/biogpt/tokenization_biogpt.py
@@ -0,0 +1,331 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Team and Microsoft Research AI4Science. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenization classes for BioGPT."""
+
+import json
+import os
+from typing import Optional
+
+from ...tokenization_utils import PreTrainedTokenizer
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+VOCAB_FILES_NAMES = {
+ "vocab_file": "vocab.json",
+ "merges_file": "merges.txt",
+}
+
+
+def get_pairs(word):
+ """
+ Return set of symbol pairs in a word. word is represented as tuple of symbols (symbols being variable-length
+ strings)
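+
+ For example, `get_pairs(("h", "e", "l", "lo</w>"))` returns `{("h", "e"), ("e", "l"), ("l", "lo</w>")}`.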
+ """
+ pairs = set()
+ prev_char = word[0]
+ for char in word[1:]:
+ pairs.add((prev_char, char))
+ prev_char = char
+ return pairs
+
+
+class BioGptTokenizer(PreTrainedTokenizer):
+ """
+ Construct an FAIRSEQ Transformer tokenizer. Moses tokenization followed by Byte-Pair Encoding.
+
+ This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
+ this superclass for more information regarding those methods.
+
+ Args:
+ vocab_file (`str`):
+ Path to the vocabulary file.
+ merges_file (`str`):
+ Merges file.
+ unk_token (`str`, *optional*, defaults to `"<unk>"`):
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+ token instead.
+ bos_token (`str`, *optional*, defaults to `"<s>"`):
+ The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier token.
+
+ <Tip>
+
+ When building a sequence using special tokens, this is not the token that is used for the beginning of
+ sequence. The token used is the `cls_token`.
+
+ </Tip>
+
+ eos_token (`str`, *optional*, defaults to `"</s>"`):
+ The end of sequence token.
+
+ <Tip>
+
+ When building a sequence using special tokens, this is not the token that is used for the end of sequence.
+ The token used is the `sep_token`.
+
+ </Tip>
+
+ sep_token (`str`, *optional*, defaults to `"</s>"`):
+ The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
+ sequence classification or for a text and a question for question answering. It is also used as the last
+ token of a sequence built with special tokens.
+ pad_token (`str`, *optional*, defaults to `"<pad>"`):
+ The token used for padding, for example when batching sequences of different lengths.
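+
+ Example (a minimal usage sketch; the `microsoft/biogpt` checkpoint is assumed to provide the expected
+ `vocab.json`/`merges.txt` files):
+
+ ```python
+ >>> from transformers import BioGptTokenizer
+
+ >>> tokenizer = BioGptTokenizer.from_pretrained("microsoft/biogpt")
+ >>> inputs = tokenizer("Biomedical text mining.", return_tensors="pt")
+ >>> sorted(inputs.keys())
+ ['attention_mask', 'input_ids']
+ ```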
+ """
+
+ vocab_files_names = VOCAB_FILES_NAMES
+ model_input_names = ["input_ids", "attention_mask"]
+
+ def __init__(
+ self,
+ vocab_file,
+ merges_file,
+ unk_token="",
+ bos_token="",
+ eos_token="",
+ sep_token="",
+ pad_token="",
+ **kwargs,
+ ):
+ try:
+ import sacremoses
+ except ImportError:
+ raise ImportError(
+ "You need to install sacremoses to use BioGptTokenizer. "
+ "See https://pypi.org/project/sacremoses/ for installation."
+ )
+
+ self.lang = "en"
+ self.sm = sacremoses
+ # cache of sm.MosesTokenizer instance
+ self.cache_moses_tokenizer = {}
+ self.cache_moses_detokenizer = {}
+
+ """ Initialisation"""
+ with open(vocab_file, encoding="utf-8") as vocab_handle:
+ self.encoder = json.load(vocab_handle)
+ self.decoder = {v: k for k, v in self.encoder.items()}
+ with open(merges_file, encoding="utf-8") as merges_handle:
+ merges = merges_handle.read().split("\n")[:-1]
+ merges = [tuple(merge.split()[:2]) for merge in merges]
+ self.bpe_ranks = dict(zip(merges, range(len(merges))))
+ self.cache = {}
+
+ super().__init__(
+ bos_token=bos_token,
+ eos_token=eos_token,
+ sep_token=sep_token,
+ unk_token=unk_token,
+ pad_token=pad_token,
+ **kwargs,
+ )
+
+ @property
+ def vocab_size(self):
+ """Returns vocab size"""
+ return len(self.encoder)
+
+ def get_vocab(self):
+ return dict(self.encoder, **self.added_tokens_encoder)
+
+ def moses_tokenize(self, text, lang):
+ if lang not in self.cache_moses_tokenizer:
+ moses_tokenizer = self.sm.MosesTokenizer(lang=lang)
+ self.cache_moses_tokenizer[lang] = moses_tokenizer
+ return self.cache_moses_tokenizer[lang].tokenize(
+ text, aggressive_dash_splits=True, return_str=False, escape=True
+ )
+
+ def moses_detokenize(self, tokens, lang):
+ if lang not in self.cache_moses_detokenizer:
+ moses_detokenizer = self.sm.MosesDetokenizer(lang=lang)
+ self.cache_moses_detokenizer[lang] = moses_detokenizer
+ return self.cache_moses_detokenizer[lang].detokenize(tokens)
+
+ def bpe(self, token):
+ word = tuple(token[:-1]) + (token[-1] + "</w>",)
+ if token in self.cache:
+ return self.cache[token]
+ pairs = get_pairs(word)
+
+ if not pairs:
+ return token + "</w>"
+
+ while True:
+ bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
+ if bigram not in self.bpe_ranks:
+ break
+ first, second = bigram
+ new_word = []
+ i = 0
+ while i < len(word):
+ try:
+ j = word.index(first, i)
+ except ValueError:
+ new_word.extend(word[i:])
+ break
+ else:
+ new_word.extend(word[i:j])
+ i = j
+
+ if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
+ new_word.append(first + second)
+ i += 2
+ else:
+ new_word.append(word[i])
+ i += 1
+ new_word = tuple(new_word)
+ word = new_word
+ if len(word) == 1:
+ break
+ else:
+ pairs = get_pairs(word)
+ word = " ".join(word)
+ if word == "\n  </w>":
+ word = "\n</w>"
+ self.cache[token] = word
+ return word
+
+ def _tokenize(self, text, bypass_tokenizer=False):
+ """Returns a tokenized string."""
+ if bypass_tokenizer:
+ text = text.split()
+ else:
+ text = self.moses_tokenize(text, self.lang)
+
+ split_tokens = []
+ for token in text:
+ if token:
+ split_tokens.extend(list(self.bpe(token).split(" ")))
+
+ return split_tokens
+
+ def _convert_token_to_id(self, token):
+ """Converts a token (str) in an id using the vocab."""
+ return self.encoder.get(token, self.encoder.get(self.unk_token))
+
+ def _convert_id_to_token(self, index):
+ """Converts an index (integer) in a token (str) using the vocab."""
+ return self.decoder.get(index, self.unk_token)
+
+ def convert_tokens_to_string(self, tokens):
+ """Converts a sequence of tokens (string) in a single string."""
+ # remove BPE
+ tokens = [t.replace(" ", "").replace("</w>", " ") for t in tokens]
+ tokens = "".join(tokens).split()
+ # detokenize
+ text = self.moses_detokenize(tokens, self.lang)
+ return text
+
+ def build_inputs_with_special_tokens(
+ self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
+ ) -> list[int]:
+ """
+ Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating and
+ adding special tokens. A BioGPT sequence has the following format:
+
+ - single sequence: `</s> X`
+ - pair of sequences: `</s> A </s> B`
+
+ Args:
+ token_ids_0 (`List[int]`):
+ List of IDs to which the special tokens will be added.
+ token_ids_1 (`List[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+
+ Returns:
+ `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
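+
+ For example, assuming `sep_token_id == 2` (the actual id depends on the loaded vocabulary),
+ `build_inputs_with_special_tokens([5, 6])` returns `[2, 5, 6]` and
+ `build_inputs_with_special_tokens([5, 6], [7])` returns `[2, 5, 6, 2, 7]`.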
+ """
+ if token_ids_1 is None:
+ return [self.sep_token_id] + token_ids_0
+ sep = [self.sep_token_id]
+ return sep + token_ids_0 + sep + token_ids_1
+
+ def get_special_tokens_mask(
+ self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None, already_has_special_tokens: bool = False
+ ) -> list[int]:
+ """
+ Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
+ special tokens using the tokenizer `prepare_for_model` method.
+
+ Args:
+ token_ids_0 (`List[int]`):
+ List of IDs.
+ token_ids_1 (`List[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+ already_has_special_tokens (`bool`, *optional*, defaults to `False`):
+ Whether or not the token list is already formatted with special tokens for the model.
+
+ Returns:
+ `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+ """
+ if already_has_special_tokens:
+ return super().get_special_tokens_mask(
+ token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
+ )
+ # no bos used in fairseq
+ if token_ids_1 is not None:
+ return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1))
+ return [1] + ([0] * len(token_ids_0))
+
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
+ if not os.path.isdir(save_directory):
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
+ return
+ vocab_file = os.path.join(
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+ )
+ merge_file = os.path.join(
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"]
+ )
+
+ with open(vocab_file, "w", encoding="utf-8") as f:
+ f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
+
+ index = 0
+ with open(merge_file, "w", encoding="utf-8") as writer:
+ for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
+ if index != token_index:
+ logger.warning(
+ f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive."
+ " Please check that the tokenizer is not corrupted!"
+ )
+ index = token_index
+ writer.write(" ".join(bpe_tokens) + "\n")
+ index += 1
+
+ return vocab_file, merge_file
+
+ def __getstate__(self):
+ state = self.__dict__.copy()
+ state["sm"] = None
+ return state
+
+ def __setstate__(self, d):
+ self.__dict__ = d
+
+ try:
+ import sacremoses
+ except ImportError:
+ raise ImportError(
+ "You need to install sacremoses to use XLMTokenizer. "
+ "See https://pypi.org/project/sacremoses/ for installation."
+ )
+
+ self.sm = sacremoses
+
+
+__all__ = ["BioGptTokenizer"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/bit/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/bit/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..edfeb4dbe75bb53c011719d6c550b245ac814b28
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/bit/__init__.py
@@ -0,0 +1,29 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_bit import *
+ from .image_processing_bit import *
+ from .image_processing_bit_fast import *
+ from .modeling_bit import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/venv/lib/python3.13/site-packages/transformers/models/bit/configuration_bit.py b/venv/lib/python3.13/site-packages/transformers/models/bit/configuration_bit.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b1f24fa0688fe8a62748e8fac58cf62a2b176d0
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/bit/configuration_bit.py
@@ -0,0 +1,136 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""BiT model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices
+
+
+logger = logging.get_logger(__name__)
+
+
+class BitConfig(BackboneConfigMixin, PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`BitModel`]. It is used to instantiate a BiT
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+ defaults will yield a similar configuration to that of the BiT
+ [google/bit-50](https://huggingface.co/google/bit-50) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ num_channels (`int`, *optional*, defaults to 3):
+ The number of input channels.
+ embedding_size (`int`, *optional*, defaults to 64):
+ Dimensionality (hidden size) for the embedding layer.
+ hidden_sizes (`list[int]`, *optional*, defaults to `[256, 512, 1024, 2048]`):
+ Dimensionality (hidden size) at each stage.
+ depths (`list[int]`, *optional*, defaults to `[3, 4, 6, 3]`):
+ Depth (number of layers) for each stage.
+ layer_type (`str`, *optional*, defaults to `"preactivation"`):
+ The layer to use, it can be either `"preactivation"` or `"bottleneck"`.
+ hidden_act (`str`, *optional*, defaults to `"relu"`):
+ The non-linear activation function in each block. If string, `"gelu"`, `"relu"`, `"selu"` and `"gelu_new"`
+ are supported.
+ global_padding (`str`, *optional*):
+ Padding strategy to use for the convolutional layers. Can be either `"valid"`, `"same"`, or `None`.
+ num_groups (`int`, *optional*, defaults to 32):
+ Number of groups used for the `BitGroupNormActivation` layers.
+ drop_path_rate (`float`, *optional*, defaults to 0.0):
+ The drop path rate for the stochastic depth.
+ embedding_dynamic_padding (`bool`, *optional*, defaults to `False`):
+ Whether or not to make use of dynamic padding for the embedding layer.
+ output_stride (`int`, *optional*, defaults to 32):
+ The output stride of the model.
+ width_factor (`int`, *optional*, defaults to 1):
+ The width factor for the model.
+ out_features (`list[str]`, *optional*):
+ If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
+ (depending on how many stages the model has). If unset and `out_indices` is set, will default to the
+ corresponding stages. If unset and `out_indices` is unset, will default to the last stage. Must be in the
+ same order as defined in the `stage_names` attribute.
+ out_indices (`list[int]`, *optional*):
+ If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how
+ many stages the model has). If unset and `out_features` is set, will default to the corresponding stages.
+ If unset and `out_features` is unset, will default to the last stage. Must be in the
+ same order as defined in the `stage_names` attribute.
+
+ Example:
+ ```python
+ >>> from transformers import BitConfig, BitModel
+
+ >>> # Initializing a BiT bit-50 style configuration
+ >>> configuration = BitConfig()
+
+ >>> # Initializing a model (with random weights) from the bit-50 style configuration
+ >>> model = BitModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```
+ """
+
+ model_type = "bit"
+ layer_types = ["preactivation", "bottleneck"]
+ supported_padding = ["SAME", "VALID"]
+
+ def __init__(
+ self,
+ num_channels=3,
+ embedding_size=64,
+ hidden_sizes=[256, 512, 1024, 2048],
+ depths=[3, 4, 6, 3],
+ layer_type="preactivation",
+ hidden_act="relu",
+ global_padding=None,
+ num_groups=32,
+ drop_path_rate=0.0,
+ embedding_dynamic_padding=False,
+ output_stride=32,
+ width_factor=1,
+ out_features=None,
+ out_indices=None,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ if layer_type not in self.layer_types:
+ raise ValueError(f"layer_type={layer_type} is not one of {','.join(self.layer_types)}")
+ if global_padding is not None:
+ if global_padding.upper() in self.supported_padding:
+ global_padding = global_padding.upper()
+ else:
+ raise ValueError(f"Padding strategy {global_padding} not supported")
+ self.num_channels = num_channels
+ self.embedding_size = embedding_size
+ self.hidden_sizes = hidden_sizes
+ self.depths = depths
+ self.layer_type = layer_type
+ self.hidden_act = hidden_act
+ self.global_padding = global_padding
+ self.num_groups = num_groups
+ self.drop_path_rate = drop_path_rate
+ self.embedding_dynamic_padding = embedding_dynamic_padding
+ self.output_stride = output_stride
+ self.width_factor = width_factor
+
+ self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(depths) + 1)]
+ self._out_features, self._out_indices = get_aligned_output_features_output_indices(
+ out_features=out_features, out_indices=out_indices, stage_names=self.stage_names
+ )
+
+
+__all__ = ["BitConfig"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/bit/image_processing_bit.py b/venv/lib/python3.13/site-packages/transformers/models/bit/image_processing_bit.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d32752edca86aa5105edbbc7dc566d71f79ab8f
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/bit/image_processing_bit.py
@@ -0,0 +1,324 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Image processor class for BiT."""
+
+from typing import Optional, Union
+
+import numpy as np
+
+from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
+from ...image_transforms import (
+ convert_to_rgb,
+ get_resize_output_image_size,
+ resize,
+ to_channel_dimension_format,
+)
+from ...image_utils import (
+ OPENAI_CLIP_MEAN,
+ OPENAI_CLIP_STD,
+ ChannelDimension,
+ ImageInput,
+ PILImageResampling,
+ infer_channel_dimension_format,
+ is_scaled_image,
+ make_flat_list_of_images,
+ to_numpy_array,
+ valid_images,
+ validate_preprocess_arguments,
+)
+from ...utils import TensorType, filter_out_non_signature_kwargs, is_vision_available, logging
+
+
+logger = logging.get_logger(__name__)
+
+
+if is_vision_available():
+ import PIL
+
+
+class BitImageProcessor(BaseImageProcessor):
+ r"""
+ Constructs a BiT image processor.
+
+ Args:
+ do_resize (`bool`, *optional*, defaults to `True`):
+ Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by
+ `do_resize` in the `preprocess` method.
+ size (`dict[str, int]`, *optional*, defaults to `{"shortest_edge": 224}`):
+ Size of the image after resizing. The shortest edge of the image is resized to size["shortest_edge"], with
+ the longest edge resized to keep the input aspect ratio. Can be overridden by `size` in the `preprocess`
+ method.
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
+ Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method.
+ do_center_crop (`bool`, *optional*, defaults to `True`):
+ Whether to center crop the image to the specified `crop_size`. Can be overridden by `do_center_crop` in the
+ `preprocess` method.
+ crop_size (`dict[str, int]`, *optional*, defaults to `{"height": 224, "width": 224}`):
+ Size of the output image after applying `center_crop`. Can be overridden by `crop_size` in the `preprocess`
+ method.
+ do_rescale (`bool`, *optional*, defaults to `True`):
+ Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
+ the `preprocess` method.
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
+ Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
+ method.
+ do_normalize (`bool`, *optional*, defaults to `True`):
+ Whether to normalize the image. Can be overridden by `do_normalize` in the `preprocess` method.
+ image_mean (`float` or `list[float]`, *optional*, defaults to `OPENAI_CLIP_MEAN`):
+ Mean to use if normalizing the image. This is a float or list of floats the length of the number of
+ channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
+ image_std (`float` or `list[float]`, *optional*, defaults to `OPENAI_CLIP_STD`):
+ Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
+ number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
+ do_convert_rgb (`bool`, *optional*, defaults to `True`):
+ Whether to convert the image to RGB.
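+
+ Example (a minimal sketch with a random image; with the default resize and center crop both targeting 224, the
+ returned `pixel_values` have shape `(1, 3, 224, 224)`):
+
+ ```python
+ >>> import numpy as np
+ >>> from transformers import BitImageProcessor
+
+ >>> image_processor = BitImageProcessor()
+ >>> image = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)
+ >>> inputs = image_processor(images=image, return_tensors="pt")
+ >>> list(inputs["pixel_values"].shape)
+ [1, 3, 224, 224]
+ ```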
+ """
+
+ model_input_names = ["pixel_values"]
+
+ def __init__(
+ self,
+ do_resize: bool = True,
+ size: Optional[dict[str, int]] = None,
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
+ do_center_crop: bool = True,
+ crop_size: Optional[dict[str, int]] = None,
+ do_rescale: bool = True,
+ rescale_factor: Union[int, float] = 1 / 255,
+ do_normalize: bool = True,
+ image_mean: Optional[Union[float, list[float]]] = None,
+ image_std: Optional[Union[float, list[float]]] = None,
+ do_convert_rgb: bool = True,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ size = size if size is not None else {"shortest_edge": 224}
+ size = get_size_dict(size, default_to_square=False)
+ crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
+ crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size")
+
+ self.do_resize = do_resize
+ self.size = size
+ self.resample = resample
+ self.do_center_crop = do_center_crop
+ self.crop_size = crop_size
+ self.do_rescale = do_rescale
+ self.rescale_factor = rescale_factor
+ self.do_normalize = do_normalize
+ self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
+ self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
+ self.do_convert_rgb = do_convert_rgb
+
+ # Copied from transformers.models.clip.image_processing_clip.CLIPImageProcessor.resize
+ def resize(
+ self,
+ image: np.ndarray,
+ size: dict[str, int],
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs,
+ ) -> np.ndarray:
+ """
+ Resize an image. The shortest edge of the image is resized to size["shortest_edge"], with the longest edge
+ resized to keep the input aspect ratio.
+
+ Args:
+ image (`np.ndarray`):
+ Image to resize.
+ size (`dict[str, int]`):
+ Size of the output image.
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
+ Resampling filter to use when resizing the image.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format of the input image. If not provided, it will be inferred.
+ """
+ default_to_square = True
+ if "shortest_edge" in size:
+ size = size["shortest_edge"]
+ default_to_square = False
+ elif "height" in size and "width" in size:
+ size = (size["height"], size["width"])
+ else:
+ raise ValueError("Size must contain either 'shortest_edge' or 'height' and 'width'.")
+
+ output_size = get_resize_output_image_size(
+ image,
+ size=size,
+ default_to_square=default_to_square,
+ input_data_format=input_data_format,
+ )
+ return resize(
+ image,
+ size=output_size,
+ resample=resample,
+ data_format=data_format,
+ input_data_format=input_data_format,
+ **kwargs,
+ )
+
+ @filter_out_non_signature_kwargs()
+ def preprocess(
+ self,
+ images: ImageInput,
+ do_resize: Optional[bool] = None,
+ size: Optional[dict[str, int]] = None,
+ resample: Optional[PILImageResampling] = None,
+ do_center_crop: Optional[bool] = None,
+ crop_size: Optional[int] = None,
+ do_rescale: Optional[bool] = None,
+ rescale_factor: Optional[float] = None,
+ do_normalize: Optional[bool] = None,
+ image_mean: Optional[Union[float, list[float]]] = None,
+ image_std: Optional[Union[float, list[float]]] = None,
+ do_convert_rgb: Optional[bool] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ ) -> PIL.Image.Image:
+ """
+ Preprocess an image or batch of images.
+
+ Args:
+ images (`ImageInput`):
+ Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
+ passing in images with pixel values between 0 and 1, set `do_rescale=False`.
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+ Whether to resize the image.
+ size (`dict[str, int]`, *optional*, defaults to `self.size`):
+ Size of the image after resizing. Shortest edge of the image is resized to size["shortest_edge"], with
+ the longest edge resized to keep the input aspect ratio.
+ resample (`int`, *optional*, defaults to `self.resample`):
+ Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
+ has an effect if `do_resize` is set to `True`.
+ do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`):
+ Whether to center crop the image.
+ crop_size (`dict[str, int]`, *optional*, defaults to `self.crop_size`):
+ Size of the center crop. Only has an effect if `do_center_crop` is set to `True`.
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+ Whether to rescale the image.
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
+ Whether to normalize the image.
+ image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`):
+ Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
+ image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`):
+ Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
+ `True`.
+ do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
+ Whether to convert the image to RGB.
+ return_tensors (`str` or `TensorType`, *optional*):
+ The type of tensors to return. Can be one of:
+ - Unset: Return a list of `np.ndarray`.
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
+ The channel dimension format for the output image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - Unset: Use the channel dimension format of the input image.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
+ from the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+ """
+ do_resize = do_resize if do_resize is not None else self.do_resize
+ size = size if size is not None else self.size
+ size = get_size_dict(size, param_name="size", default_to_square=False)
+ resample = resample if resample is not None else self.resample
+ do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
+ crop_size = crop_size if crop_size is not None else self.crop_size
+ crop_size = get_size_dict(crop_size, param_name="crop_size", default_to_square=True)
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
+ image_mean = image_mean if image_mean is not None else self.image_mean
+ image_std = image_std if image_std is not None else self.image_std
+ do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
+
+ images = make_flat_list_of_images(images)
+
+ if not valid_images(images):
+ raise ValueError(
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
+ "torch.Tensor, tf.Tensor or jax.ndarray."
+ )
+
+ validate_preprocess_arguments(
+ do_rescale=do_rescale,
+ rescale_factor=rescale_factor,
+ do_normalize=do_normalize,
+ image_mean=image_mean,
+ image_std=image_std,
+ do_center_crop=do_center_crop,
+ crop_size=crop_size,
+ do_resize=do_resize,
+ size=size,
+ resample=resample,
+ )
+
+ # PIL RGBA images are converted to RGB
+ if do_convert_rgb:
+ images = [convert_to_rgb(image) for image in images]
+
+ # All transformations expect numpy arrays.
+ images = [to_numpy_array(image) for image in images]
+
+ if do_rescale and is_scaled_image(images[0]):
+ logger.warning_once(
+ "It looks like you are trying to rescale already rescaled images. If the input"
+ " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
+ )
+
+ if input_data_format is None:
+ # We assume that all images have the same channel dimension format.
+ input_data_format = infer_channel_dimension_format(images[0])
+
+ all_images = []
+ for image in images:
+ if do_resize:
+ image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
+
+ if do_center_crop:
+ image = self.center_crop(image=image, size=crop_size, input_data_format=input_data_format)
+
+ if do_rescale:
+ image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
+
+ if do_normalize:
+ image = self.normalize(
+ image=image, mean=image_mean, std=image_std, input_data_format=input_data_format
+ )
+
+ all_images.append(image)
+
+ images = [
+ to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
+ for image in all_images
+ ]
+
+ data = {"pixel_values": images}
+ return BatchFeature(data=data, tensor_type=return_tensors)
+
+
+__all__ = ["BitImageProcessor"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/bit/image_processing_bit_fast.py b/venv/lib/python3.13/site-packages/transformers/models/bit/image_processing_bit_fast.py
new file mode 100644
index 0000000000000000000000000000000000000000..46c0e94f3d190b0b5157aa3e2a3ba7c38f3f41cc
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/bit/image_processing_bit_fast.py
@@ -0,0 +1,38 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Fast Image processor class for BiT."""
+
+from ...image_processing_utils_fast import BaseImageProcessorFast
+from ...image_utils import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD, PILImageResampling
+from ...utils import auto_docstring
+
+
+@auto_docstring
+class BitImageProcessorFast(BaseImageProcessorFast):
+ resample = PILImageResampling.BICUBIC
+ image_mean = OPENAI_CLIP_MEAN
+ image_std = OPENAI_CLIP_STD
+ size = {"shortest_edge": 224}
+ default_to_square = False
+ crop_size = {"height": 224, "width": 224}
+ rescale_factor = 1 / 255
+ do_resize = True
+ do_center_crop = True
+ do_rescale = True
+ do_normalize = True
+ do_convert_rgb = True
+
+
+__all__ = ["BitImageProcessorFast"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/bit/modeling_bit.py b/venv/lib/python3.13/site-packages/transformers/models/bit/modeling_bit.py
new file mode 100644
index 0000000000000000000000000000000000000000..1e491f06eae6c8bb77de28c372aeebfc7795966f
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/bit/modeling_bit.py
@@ -0,0 +1,821 @@
+# coding=utf-8
+# Copyright 2022 Google AI and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch BiT model. Also supports backbone for ViT hybrid."""
+
+import collections
+import math
+from typing import Optional
+
+import numpy as np
+import torch
+from torch import Tensor, nn
+
+from ...activations import ACT2FN
+from ...modeling_outputs import (
+ BackboneOutput,
+ BaseModelOutputWithNoAttention,
+ BaseModelOutputWithPoolingAndNoAttention,
+ ImageClassifierOutputWithNoAttention,
+)
+from ...modeling_utils import PreTrainedModel
+from ...utils import auto_docstring, logging
+from ...utils.backbone_utils import BackboneMixin
+from .configuration_bit import BitConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+def get_padding_value(padding=None, kernel_size=7, stride=1, dilation=1) -> tuple[tuple, bool]:
+ r"""
+ Utility function to get the tuple padding value given the kernel_size and padding.
+
+ Args:
+ padding (Union[`str`, `int`], *optional*):
+ Padding value, can be either `"same"` or `"valid"`. If a different value is provided, the default padding from
+ PyTorch is used.
+ kernel_size (`int`, *optional*, defaults to 7):
+ Kernel size of the convolution layers.
+ stride (`int`, *optional*, defaults to 1):
+ Stride value of the convolution layers.
+ dilation (`int`, *optional*, defaults to 1):
+ Dilation value of the convolution layers.
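+
+ Example (static vs. dynamic resolution of `"same"` padding, as implemented below):
+
+ ```python
+ >>> get_padding_value("same", kernel_size=3, stride=1)
+ (1, False)
+ >>> get_padding_value("same", kernel_size=3, stride=2)
+ (0, True)
+ ```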
+ """
+ dynamic = False
+ if padding is None:
+ padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
+ return padding, dynamic
+
+ if isinstance(padding, str):
+ # for any string padding, the padding will be calculated for you, one of three ways
+ padding = padding.lower()
+ if padding == "same":
+ # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact
+ if stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0:
+ # static case, no extra overhead
+ padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
+ else:
+ # dynamic 'SAME' padding, has runtime/GPU memory overhead
+ padding = 0
+ dynamic = True
+ elif padding == "valid":
+ # 'VALID' padding, same as padding=0
+ padding = 0
+ else:
+ # Default to PyTorch style 'same'-ish symmetric padding
+ padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
+ return padding, dynamic
+
+
+class WeightStandardizedConv2d(nn.Conv2d):
+ """Conv2d with Weight Standardization. Includes TensorFlow compatible SAME padding. Used for ViT Hybrid model.
+
+ Paper: [Micro-Batch Training with Batch-Channel Normalization and Weight
+ Standardization](https://huggingface.co/papers/1903.10520v2)
+ """
+
+ def __init__(
+ self,
+ in_channel,
+ out_channels,
+ kernel_size,
+ stride=1,
+ padding="SAME",
+ dilation=1,
+ groups=1,
+ bias=False,
+ eps=1e-6,
+ ):
+ padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, dilation=dilation)
+ super().__init__(
+ in_channel,
+ out_channels,
+ kernel_size,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,
+ groups=groups,
+ bias=bias,
+ )
+ if is_dynamic:
+ self.pad = DynamicPad2d(kernel_size, stride, dilation)
+ else:
+ self.pad = None
+ self.eps = eps
+
+ def forward(self, hidden_state):
+ if self.pad is not None:
+ hidden_state = self.pad(hidden_state)
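+ # Weight standardization: give each output channel of the kernel zero mean and unit variance by running
+ # batch_norm over the flattened weight (training=True, momentum=0.0, so no running statistics are kept)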
+ weight = nn.functional.batch_norm(
+ self.weight.reshape(1, self.out_channels, -1), None, None, training=True, momentum=0.0, eps=self.eps
+ ).reshape_as(self.weight)
+ hidden_state = nn.functional.conv2d(
+ hidden_state, weight, self.bias, self.stride, self.padding, self.dilation, self.groups
+ )
+ return hidden_state
+
+
+class BitGroupNormActivation(nn.GroupNorm):
+ r"""
+ A module that combines group normalization with an activation function.
+ """
+
+ def __init__(self, config, num_channels, eps=1e-5, affine=True, apply_activation=True):
+ super().__init__(config.num_groups, num_channels, eps=eps, affine=affine)
+ if apply_activation:
+ self.activation = ACT2FN[config.hidden_act]
+ else:
+ self.activation = nn.Identity()
+
+ def forward(self, hidden_state):
+ hidden_state = nn.functional.group_norm(hidden_state, self.num_groups, self.weight, self.bias, self.eps)
+ hidden_state = self.activation(hidden_state)
+ return hidden_state
+
+
+class DynamicPad2d(nn.Module):
+ r"""
+ A module that wraps dynamic padding of any input, given the parameters of the convolutional layer and the input
+ hidden states.
+ """
+
+ def __init__(self, kernel_size, stride, dilation, value=0):
+ super().__init__()
+ # Safety checkers
+ if isinstance(kernel_size, int):
+ kernel_size = (kernel_size, kernel_size)
+
+ if isinstance(stride, int):
+ stride = (stride, stride)
+
+ if isinstance(dilation, int):
+ dilation = (dilation, dilation)
+
+ self.kernel_size = kernel_size
+ self.stride = stride
+ self.dilation = dilation
+ self.value = value
+
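+ # Total padding needed along one dimension so that the output size equals ceil(x / stride) (TF 'SAME' semantics)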
+ def compute_padding(x, kernel_size, stride, dilation):
+ return max((math.ceil(x / stride) - 1) * stride + (kernel_size - 1) * dilation + 1 - x, 0)
+
+ self.compute_padding = compute_padding
+
+ def forward(self, input):
+ # Get width and height
+ input_height, input_width = input.size()[-2:]
+
+ # Compute the padding values
+ padding_height = self.compute_padding(input_height, self.kernel_size[0], self.stride[0], self.dilation[0])
+ padding_width = self.compute_padding(input_width, self.kernel_size[1], self.stride[1], self.dilation[1])
+
+ # apply pad
+ if padding_height > 0 or padding_width > 0:
+ input = nn.functional.pad(
+ input,
+ [
+ padding_width // 2,
+ padding_width - padding_width // 2,
+ padding_height // 2,
+ padding_height - padding_height // 2,
+ ],
+ value=self.value,
+ )
+ return input
+
+
+class BitMaxPool2d(nn.MaxPool2d):
+ """Tensorflow like 'SAME' wrapper for 2D max pooling"""
+
+ def __init__(
+ self,
+ kernel_size: int,
+ stride=None,
+ dilation=1,
+ ceil_mode=False,
+ padding=(0, 0),
+ padding_value=0,
+ use_dynamic_padding=True,
+ ):
+ kernel_size = kernel_size if isinstance(kernel_size, collections.abc.Iterable) else (kernel_size, kernel_size)
+ stride = stride if isinstance(stride, collections.abc.Iterable) else (stride, stride)
+ dilation = dilation if isinstance(dilation, collections.abc.Iterable) else (dilation, dilation)
+ super().__init__(kernel_size, stride, padding, dilation, ceil_mode)
+ if use_dynamic_padding:
+ self.pad = DynamicPad2d(kernel_size, stride, dilation, padding_value)
+ else:
+ self.pad = nn.Identity()
+
+ def forward(self, hidden_states):
+ hidden_states = self.pad(hidden_states)
+ return nn.functional.max_pool2d(
+ hidden_states, self.kernel_size, self.stride, self.padding, self.dilation, self.ceil_mode
+ )
+
+
+class BitEmbeddings(nn.Module):
+ """
+ BiT Embeddings (stem) composed of a single aggressive convolution.
+ """
+
+ def __init__(self, config: BitConfig):
+ super().__init__()
+
+ self.convolution = WeightStandardizedConv2d(
+ config.num_channels,
+ config.embedding_size,
+ kernel_size=7,
+ stride=2,
+ eps=1e-8,
+ padding=config.global_padding,
+ )
+
+ self.pooler = BitMaxPool2d(kernel_size=3, stride=2, use_dynamic_padding=config.embedding_dynamic_padding)
+
+ # Use the same padding strategy as convolutional layers
+ if config.global_padding is not None and config.global_padding.upper() == "SAME":
+ self.pad = nn.Identity()
+ else:
+ self.pad = nn.ConstantPad2d(padding=(1, 1, 1, 1), value=0.0)
+
+ if config.layer_type != "preactivation":
+ self.norm = BitGroupNormActivation(config, num_channels=config.embedding_size)
+ else:
+ self.norm = nn.Identity()
+
+ self.num_channels = config.num_channels
+
+ def forward(self, pixel_values: Tensor) -> Tensor:
+ num_channels = pixel_values.shape[1]
+ if num_channels != self.num_channels:
+ raise ValueError(
+ "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
+ )
+
+ embedding = self.convolution(pixel_values)
+
+ embedding = self.pad(embedding)
+
+ embedding = self.norm(embedding)
+
+ embedding = self.pooler(embedding)
+
+ return embedding
+
+
+# Copied from transformers.models.convnext.modeling_convnext.drop_path
+def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
+ """
+ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+
+ Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
+ however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
+ layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
+ argument.
+ """
+ if drop_prob == 0.0 or not training:
+ return input
+ keep_prob = 1 - drop_prob
+ shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
+ random_tensor.floor_() # binarize
+ output = input.div(keep_prob) * random_tensor
+ return output
+
+
+# Copied from transformers.models.beit.modeling_beit.BeitDropPath with Beit->Bit
+class BitDropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+ def __init__(self, drop_prob: Optional[float] = None) -> None:
+ super().__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ return drop_path(hidden_states, self.drop_prob, self.training)
+
+ def extra_repr(self) -> str:
+ return f"p={self.drop_prob}"
+
+
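+# Round `value` to the nearest multiple of `divisor` without dropping below ~90% of the original value,
+# e.g. make_div(100) -> 104 and make_div(30) -> 32 with the default divisor of 8.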
+def make_div(value, divisor=8):
+ min_value = divisor
+ new_value = max(min_value, int(value + divisor / 2) // divisor * divisor)
+ if new_value < 0.9 * value:
+ new_value += divisor
+ return new_value
+
+
+class BitPreActivationBottleneckLayer(nn.Module):
+ """Pre-activation (v2) bottleneck block.
+ Follows the implementation of "Identity Mappings in Deep Residual Networks":
+ https://github.com/KaimingHe/resnet-1k-layers/blob/master/resnet-pre-act.lua
+
+ Except it puts the stride on 3x3 conv when available.
+ """
+
+ def __init__(
+ self,
+ config,
+ in_channels,
+ out_channels=None,
+ bottle_ratio=0.25,
+ stride=1,
+ dilation=1,
+ first_dilation=None,
+ groups=1,
+ drop_path_rate=0.0,
+ is_first_layer=False,
+ ):
+ super().__init__()
+
+ first_dilation = first_dilation or dilation
+
+ out_channels = out_channels or in_channels
+ mid_channels = make_div(out_channels * bottle_ratio)
+
+ if is_first_layer:
+ self.downsample = BitDownsampleConv(
+ config,
+ in_channels,
+ out_channels,
+ stride=stride,
+ preact=True,
+ )
+ else:
+ self.downsample = None
+
+ self.norm1 = BitGroupNormActivation(config, in_channels)
+ self.conv1 = WeightStandardizedConv2d(in_channels, mid_channels, 1, eps=1e-8, padding=config.global_padding)
+
+ self.norm2 = BitGroupNormActivation(config, num_channels=mid_channels)
+ self.conv2 = WeightStandardizedConv2d(
+ mid_channels, mid_channels, 3, stride=stride, groups=groups, eps=1e-8, padding=config.global_padding
+ )
+
+ self.norm3 = BitGroupNormActivation(config, mid_channels)
+ self.conv3 = WeightStandardizedConv2d(mid_channels, out_channels, 1, eps=1e-8, padding=config.global_padding)
+
+ self.drop_path = BitDropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()
+
+ def forward(self, hidden_states):
+ hidden_states_preact = self.norm1(hidden_states)
+
+ # shortcut branch
+ shortcut = hidden_states
+ if self.downsample is not None:
+ shortcut = self.downsample(hidden_states_preact)
+
+ # residual branch
+ hidden_states = self.conv1(hidden_states_preact)
+ hidden_states = self.conv2(self.norm2(hidden_states))
+ hidden_states = self.conv3(self.norm3(hidden_states))
+ hidden_states = self.drop_path(hidden_states)
+ return hidden_states + shortcut
+
+
+class BitBottleneckLayer(nn.Module):
+ """Non Pre-activation bottleneck block, equivalent to V1.5/V1b bottleneck. Used for ViT Hybrid."""
+
+ def __init__(
+ self,
+ config,
+ in_channels,
+ out_channels=None,
+ bottle_ratio=0.25,
+ stride=1,
+ dilation=1,
+ first_dilation=None,
+ groups=1,
+ drop_path_rate=0.0,
+ is_first_layer=False,
+ ):
+ super().__init__()
+ first_dilation = first_dilation or dilation
+
+ out_channels = out_channels or in_channels
+ mid_chs = make_div(out_channels * bottle_ratio)
+
+ if is_first_layer:
+ self.downsample = BitDownsampleConv(
+ config,
+ in_channels,
+ out_channels,
+ stride=stride,
+ preact=False,
+ )
+ else:
+ self.downsample = None
+
+ self.conv1 = WeightStandardizedConv2d(in_channels, mid_chs, 1, eps=1e-8, padding=config.global_padding)
+ self.norm1 = BitGroupNormActivation(config, num_channels=mid_chs)
+ self.conv2 = WeightStandardizedConv2d(
+ mid_chs,
+ mid_chs,
+ 3,
+ stride=stride,
+ dilation=first_dilation,
+ groups=groups,
+ eps=1e-8,
+ padding=config.global_padding,
+ )
+ self.norm2 = BitGroupNormActivation(config, num_channels=mid_chs)
+ self.conv3 = WeightStandardizedConv2d(mid_chs, out_channels, 1, eps=1e-8, padding=config.global_padding)
+ self.norm3 = BitGroupNormActivation(config, num_channels=out_channels, apply_activation=False)
+ self.drop_path = BitDropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()
+
+ self.activation = ACT2FN[config.hidden_act]
+
+ def forward(self, hidden_states):
+ # shortcut branch
+ shortcut = hidden_states
+ if self.downsample is not None:
+ shortcut = self.downsample(hidden_states)
+
+ # residual
+ hidden_states = self.conv1(hidden_states)
+ hidden_states = self.norm1(hidden_states)
+
+ hidden_states = self.conv2(hidden_states)
+ hidden_states = self.norm2(hidden_states)
+
+ hidden_states = self.conv3(hidden_states)
+ hidden_states = self.norm3(hidden_states)
+
+ hidden_states = self.drop_path(hidden_states)
+ hidden_states = self.activation(hidden_states + shortcut)
+ return hidden_states
+
+
+class BitDownsampleConv(nn.Module):
+ def __init__(
+ self,
+ config,
+ in_channels,
+ out_channels,
+ stride=1,
+ preact=True,
+ ):
+ super().__init__()
+ self.conv = WeightStandardizedConv2d(
+ in_channels, out_channels, 1, stride=stride, eps=1e-8, padding=config.global_padding
+ )
+ self.norm = (
+ nn.Identity()
+ if preact
+ else BitGroupNormActivation(config, num_channels=out_channels, apply_activation=False)
+ )
+
+ def forward(self, x):
+ return self.norm(self.conv(x))
+
+
+class BitStage(nn.Module):
+ """
+ A ResNet v2 stage composed of stacked layers.
+ """
+
+ def __init__(
+ self,
+ config,
+ in_channels,
+ out_channels,
+ stride,
+ dilation,
+ depth,
+ bottle_ratio=0.25,
+ layer_dropout=None,
+ ):
+ super().__init__()
+
+ first_dilation = 1 if dilation in (1, 2) else 2
+
+ # Get the layer type
+ if config.layer_type == "bottleneck":
+ layer_cls = BitBottleneckLayer
+ else:
+ layer_cls = BitPreActivationBottleneckLayer
+
+ prev_chs = in_channels
+ self.layers = nn.Sequential()
+ for layer_idx in range(depth):
+ # Get the current hyper-parameters
+ stride, drop_path_rate, is_first_layer = self._get_updated_hyperparameters(
+ layer_idx, stride, layer_dropout
+ )
+
+ self.layers.add_module(
+ str(layer_idx),
+ layer_cls(
+ config,
+ prev_chs,
+ out_channels,
+ stride=stride,
+ dilation=dilation,
+ bottle_ratio=bottle_ratio,
+ first_dilation=first_dilation,
+ drop_path_rate=drop_path_rate,
+ is_first_layer=is_first_layer,
+ ),
+ )
+ prev_chs = out_channels
+ first_dilation = dilation
+
+ def _get_updated_hyperparameters(self, layer_idx, stride, layer_dropout):
+ r"""
+ Get the new hyper-parameters with respect to the previous ones and the index of the current layer.
+ """
+ if layer_dropout:
+ drop_path_rate = layer_dropout[layer_idx]
+ else:
+ drop_path_rate = 0.0
+
+ if layer_idx != 0:
+ stride = 1
+
+ is_first_layer = layer_idx == 0
+
+ return stride, drop_path_rate, is_first_layer
+
+ def forward(self, input: Tensor) -> Tensor:
+ hidden_state = input
+ for _, layer in enumerate(self.layers):
+ hidden_state = layer(hidden_state)
+ return hidden_state
+
+
+class BitEncoder(nn.Module):
+ def __init__(self, config: BitConfig):
+ super().__init__()
+ self.stages = nn.ModuleList([])
+
+ prev_chs = config.embedding_size
+
+ # These need to stay hardcoded
+ current_stride = 4
+ dilation = 1
+
+ layer_dropouts = [
+ x.tolist()
+ for x in torch.Tensor(np.linspace(0, config.drop_path_rate, sum(config.depths))).split(config.depths)
+ ]
+
+ for stage_idx, (current_depth, current_hidden_size, layer_dropout) in enumerate(
+ zip(config.depths, config.hidden_sizes, layer_dropouts)
+ ):
+ # Get the updated hyper params
+ out_channels, stride, dilation = self._get_updated_hyperparameters(
+ stage_idx, current_stride, current_hidden_size, dilation, config
+ )
+
+ stage = BitStage(
+ config,
+ prev_chs,
+ out_channels,
+ stride=stride,
+ dilation=dilation,
+ depth=current_depth,
+ layer_dropout=layer_dropout,
+ )
+
+ prev_chs = out_channels
+ current_stride *= stride
+
+ self.stages.add_module(str(stage_idx), stage)
+
+ def _get_updated_hyperparameters(self, stage_idx, current_stride, current_hidden_size, dilation, config):
+ out_channels = make_div(current_hidden_size * config.width_factor)
+ stride = 1 if stage_idx == 0 else 2
+ if current_stride >= config.output_stride:
+ dilation *= stride
+ stride = 1
+ return out_channels, stride, dilation
+
+ def forward(
+ self, hidden_state: Tensor, output_hidden_states: bool = False, return_dict: bool = True
+ ) -> BaseModelOutputWithNoAttention:
+ hidden_states = () if output_hidden_states else None
+
+ for stage_module in self.stages:
+ if output_hidden_states:
+ hidden_states = hidden_states + (hidden_state,)
+
+ hidden_state = stage_module(hidden_state)
+
+ if output_hidden_states:
+ hidden_states = hidden_states + (hidden_state,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_state, hidden_states] if v is not None)
+
+ return BaseModelOutputWithNoAttention(
+ last_hidden_state=hidden_state,
+ hidden_states=hidden_states,
+ )
+
+
+@auto_docstring
+class BitPreTrainedModel(PreTrainedModel):
+ config: BitConfig
+ base_model_prefix = "bit"
+ main_input_name = "pixel_values"
+ _no_split_modules = ["BitEmbeddings"]
+
+ def _init_weights(self, module):
+ if isinstance(module, nn.Conv2d):
+ nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
+ # copied from the `reset_parameters` method of `class Linear(Module)` in `torch`.
+ elif isinstance(module, nn.Linear):
+ nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5))
+ if module.bias is not None:
+ fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight)
+ bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
+ nn.init.uniform_(module.bias, -bound, bound)
+ elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
+ nn.init.constant_(module.weight, 1)
+ nn.init.constant_(module.bias, 0)
+
+
+@auto_docstring
+class BitModel(BitPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.config = config
+
+ self.embedder = BitEmbeddings(config)
+
+ self.encoder = BitEncoder(config)
+ self.norm = (
+ BitGroupNormActivation(config, num_channels=config.hidden_sizes[-1])
+ if config.layer_type == "preactivation"
+ else nn.Identity()
+ )
+
+ self.pooler = nn.AdaptiveAvgPool2d((1, 1))
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self, pixel_values: Tensor, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None
+ ) -> BaseModelOutputWithPoolingAndNoAttention:
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ embedding_output = self.embedder(pixel_values)
+
+ encoder_outputs = self.encoder(
+ embedding_output, output_hidden_states=output_hidden_states, return_dict=return_dict
+ )
+
+ last_hidden_state = encoder_outputs[0]
+
+ last_hidden_state = self.norm(last_hidden_state)
+
+ pooled_output = self.pooler(last_hidden_state)
+
+ if not return_dict:
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
+
+ return BaseModelOutputWithPoolingAndNoAttention(
+ last_hidden_state=last_hidden_state,
+ pooler_output=pooled_output,
+ hidden_states=encoder_outputs.hidden_states,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ BiT Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for
+ ImageNet.
+ """
+)
+class BitForImageClassification(BitPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.bit = BitModel(config)
+ # classification head
+ self.classifier = nn.Sequential(
+ nn.Flatten(),
+ nn.Linear(config.hidden_sizes[-1], config.num_labels) if config.num_labels > 0 else nn.Identity(),
+ )
+ # initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> ImageClassifierOutputWithNoAttention:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.bit(pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict)
+
+ pooled_output = outputs.pooler_output if return_dict else outputs[1]
+
+ logits = self.classifier(pooled_output)
+
+ loss = None
+
+ if labels is not None:
+ loss = self.loss_function(labels, logits, self.config)
+
+ if not return_dict:
+ output = (logits,) + outputs[2:]
+ return (loss,) + output if loss is not None else output
+
+ return ImageClassifierOutputWithNoAttention(loss=loss, logits=logits, hidden_states=outputs.hidden_states)
+
+
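+# Minimal usage sketch (illustration only, not part of the modeling code): image classification with
+# `BitForImageClassification`, assuming hub access and the `google/bit-50` checkpoint (fine-tuned on ImageNet-1k).
+def _example_bit_image_classification():
+    import requests
+    from PIL import Image
+
+    from transformers import AutoImageProcessor
+
+    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+    image = Image.open(requests.get(url, stream=True).raw)
+
+    processor = AutoImageProcessor.from_pretrained("google/bit-50")
+    model = BitForImageClassification.from_pretrained("google/bit-50")
+
+    inputs = processor(image, return_tensors="pt")
+    with torch.no_grad():
+        logits = model(**inputs).logits
+    return model.config.id2label[logits.argmax(-1).item()]
+
+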
+@auto_docstring(
+ custom_intro="""
+ BiT backbone, to be used with frameworks like DETR and MaskFormer.
+ """
+)
+class BitBackbone(BitPreTrainedModel, BackboneMixin):
+ has_attentions = False
+
+ def __init__(self, config):
+ super().__init__(config)
+ super()._init_backbone(config)
+
+ self.bit = BitModel(config)
+ self.num_features = [config.embedding_size] + config.hidden_sizes
+
+ # initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self, pixel_values: Tensor, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None
+ ) -> BackboneOutput:
+ r"""
+ Examples:
+
+ ```python
+ >>> from transformers import AutoImageProcessor, AutoBackbone
+ >>> import torch
+ >>> from PIL import Image
+ >>> import requests
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> processor = AutoImageProcessor.from_pretrained("google/bit-50")
+ >>> model = AutoBackbone.from_pretrained("google/bit-50")
+
+ >>> inputs = processor(image, return_tensors="pt")
+ >>> outputs = model(**inputs)
+ ```"""
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+
+ outputs = self.bit(pixel_values, output_hidden_states=True, return_dict=True)
+
+ hidden_states = outputs.hidden_states
+
+ feature_maps = ()
+ for idx, stage in enumerate(self.stage_names):
+ if stage in self.out_features:
+ feature_maps += (hidden_states[idx],)
+
+ if not return_dict:
+ output = (feature_maps,)
+ if output_hidden_states:
+ output += (outputs.hidden_states,)
+ return output
+
+ return BackboneOutput(
+ feature_maps=feature_maps,
+ hidden_states=outputs.hidden_states if output_hidden_states else None,
+ attentions=None,
+ )
+
+
+__all__ = ["BitForImageClassification", "BitModel", "BitPreTrainedModel", "BitBackbone"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/bloom/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/bloom/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..72d1d6e6ca4724235ce46978e5c11e84583d4033
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/bloom/__init__.py
@@ -0,0 +1,29 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_bloom import *
+ from .modeling_bloom import *
+ from .modeling_flax_bloom import *
+ from .tokenization_bloom_fast import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/venv/lib/python3.13/site-packages/transformers/models/bloom/configuration_bloom.py b/venv/lib/python3.13/site-packages/transformers/models/bloom/configuration_bloom.py
new file mode 100644
index 0000000000000000000000000000000000000000..74748c11304111955f5a5ef038a13256d02d1837
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/bloom/configuration_bloom.py
@@ -0,0 +1,238 @@
+# coding=utf-8
+# Copyright 2022 the Big Science Workshop and HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Bloom configuration"""
+
+from collections import OrderedDict
+from collections.abc import Mapping
+from typing import TYPE_CHECKING, Any, Optional
+
+from packaging import version
+
+
+if TYPE_CHECKING:
+ from ... import PreTrainedTokenizer, TensorType
+
+from ...configuration_utils import PretrainedConfig
+from ...onnx import OnnxConfigWithPast, PatchingSpec
+from ...utils import is_torch_available, logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class BloomConfig(PretrainedConfig):
+ """
+ This is the configuration class to store the configuration of a [`BloomModel`]. It is used to instantiate a Bloom
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+ defaults will yield a similar configuration to the Bloom architecture
+ [bigscience/bloom](https://huggingface.co/bigscience/bloom).
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 250880):
+ Vocabulary size of the Bloom model. Defines the maximum number of different tokens that can be represented
+ by the `input_ids` passed when calling [`BloomModel`]. Check [this
+ discussion](https://huggingface.co/bigscience/bloom/discussions/120#633d28389addb8530b406c2a) on how the
+ `vocab_size` has been defined.
+ hidden_size (`int`, *optional*, defaults to 64):
+ Dimensionality of the embeddings and hidden states.
+ n_layer (`int`, *optional*, defaults to 2):
+ Number of hidden layers in the Transformer encoder.
+ n_head (`int`, *optional*, defaults to 8):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ layer_norm_epsilon (`float`, *optional*, defaults to 1e-5):
+ The epsilon to use in the layer normalization layers.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ apply_residual_connection_post_layernorm (`bool`, *optional*, defaults to `False`):
+ If enabled, use the layer norm of the hidden states as the residual in the transformer blocks.
+ hidden_dropout (`float`, *optional*, defaults to 0.0):
+ Dropout rate applied to the attention and MLP outputs before they are added back to the residual.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ Dropout rate applied to the attention probabilities.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models).
+ pretraining_tp (`int`, *optional*, defaults to `1`):
+ Experimental feature. Tensor parallelism rank used during pretraining with Megatron. Please refer to [this
+ document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is
+ necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
+ issue](https://github.com/pytorch/pytorch/issues/76232). Note also that this is enabled only when
+ `slow_but_exact=True`.
+ slow_but_exact (`bool`, *optional*, defaults to `False`):
+ Experimental feature. Whether to use a slow but exact implementation of the attention mechanism. While
+ merging the TP rank tensors, due to slicing operations the results may be slightly different between the
+ model trained on Megatron and our model. Please refer to [this
+ issue](https://github.com/pytorch/pytorch/issues/76232). A solution to obtain more accurate results is to
+ enable this feature. Enabling it slows down inference. This will probably be resolved in the future once
+ the main model has been fine-tuned with TP_rank=1.
+
+ Example:
+
+ ```python
+ >>> from transformers import BloomConfig, BloomModel
+
+ >>> # Initializing a Bloom configuration
+ >>> configuration = BloomConfig()
+
+ >>> # Initializing a model (with random weights) from the configuration
+ >>> model = BloomModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "bloom"
+ keys_to_ignore_at_inference = ["past_key_values"]
+ attribute_map = {
+ "num_hidden_layers": "n_layer",
+ "num_attention_heads": "n_head",
+ }
+
+ def __init__(
+ self,
+ vocab_size=250880,
+ hidden_size=64,
+ n_layer=2,
+ n_head=8,
+ layer_norm_epsilon=1e-5,
+ initializer_range=0.02,
+ use_cache=True,
+ bos_token_id=1,
+ eos_token_id=2,
+ apply_residual_connection_post_layernorm=False,
+ hidden_dropout=0.0,
+ attention_dropout=0.0,
+ pretraining_tp=1, # TP rank used when training with megatron
+ slow_but_exact=False,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ # Backward compatibility with n_embed kwarg
+ n_embed = kwargs.pop("n_embed", None)
+ self.hidden_size = hidden_size if n_embed is None else n_embed
+ self.n_layer = n_layer
+ self.n_head = n_head
+ self.layer_norm_epsilon = layer_norm_epsilon
+ self.initializer_range = initializer_range
+ self.use_cache = use_cache
+ self.pretraining_tp = pretraining_tp
+ self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
+ self.hidden_dropout = hidden_dropout
+ self.attention_dropout = attention_dropout
+
+ self.bos_token_id = bos_token_id
+ self.eos_token_id = eos_token_id
+ self.slow_but_exact = slow_but_exact
+
+ super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
+
+
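+# Minimal sketch (illustration only): `n_embed` is accepted as a legacy alias for `hidden_size`,
+# handled by the `kwargs.pop("n_embed", None)` branch in `BloomConfig.__init__` above.
+def _example_bloom_config_n_embed_alias():
+    legacy = BloomConfig(n_embed=128)
+    assert legacy.hidden_size == 128
+    assert legacy.hidden_size == BloomConfig(hidden_size=128).hidden_size
+    return legacy
+
+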
+class BloomOnnxConfig(OnnxConfigWithPast):
+ torch_onnx_minimum_version = version.parse("1.12")
+
+ def __init__(
+ self,
+ config: PretrainedConfig,
+ task: str = "default",
+ patching_specs: Optional[list[PatchingSpec]] = None,
+ use_past: bool = False,
+ ):
+ super().__init__(config, task=task, patching_specs=patching_specs, use_past=use_past)
+ if not getattr(self._config, "pad_token_id", None):
+ # TODO: how to do that better?
+ self._config.pad_token_id = 0
+
+ @property
+ def inputs(self) -> Mapping[str, Mapping[int, str]]:
+ common_inputs = OrderedDict({"input_ids": {0: "batch", 1: "sequence"}})
+ if self.use_past:
+ # BLOOM stores values on dynamic axis 2. For more details see: https://github.com/huggingface/transformers/pull/18344
+ self.fill_with_past_key_values_(common_inputs, direction="inputs", inverted_values_shape=True)
+ common_inputs["attention_mask"] = {0: "batch", 1: "past_sequence + sequence"}
+ else:
+ common_inputs["attention_mask"] = {0: "batch", 1: "sequence"}
+
+ return common_inputs
+
+ @property
+ def num_layers(self) -> int:
+ return self._config.n_layer
+
+ @property
+ def num_attention_heads(self) -> int:
+ return self._config.n_head
+
+ @property
+ def atol_for_validation(self) -> float:
+ return 1e-3
+
+ def generate_dummy_inputs(
+ self,
+ tokenizer: "PreTrainedTokenizer",
+ batch_size: int = -1,
+ seq_length: int = -1,
+ is_pair: bool = False,
+ framework: Optional["TensorType"] = None,
+ ) -> Mapping[str, Any]:
+ common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
+ tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
+ )
+
+ # We need to order the inputs in the way they appear in the forward()
+ ordered_inputs = OrderedDict({"input_ids": common_inputs["input_ids"]})
+
+ # Need to add the past_keys
+ if self.use_past:
+ if not is_torch_available():
+ raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
+ else:
+ import torch
+
+ batch, seqlen = common_inputs["input_ids"].shape
+ # Not using the same length for past_key_values
+ past_key_values_length = seqlen + 2
+ head_dim = self._config.hidden_size // self.num_attention_heads
+ past_key_shape = (
+ batch * self.num_attention_heads,
+ head_dim,
+ past_key_values_length,
+ )
+ past_value_shape = (
+ batch * self.num_attention_heads,
+ past_key_values_length,
+ head_dim,
+ )
+ ordered_inputs["past_key_values"] = [
+ (torch.zeros(past_key_shape), torch.zeros(past_value_shape)) for _ in range(self.num_layers)
+ ]
+
+ ordered_inputs["attention_mask"] = common_inputs["attention_mask"]
+ if self.use_past:
+ mask_dtype = ordered_inputs["attention_mask"].dtype
+ ordered_inputs["attention_mask"] = torch.cat(
+ [ordered_inputs["attention_mask"], torch.ones(batch, past_key_values_length, dtype=mask_dtype)], dim=1
+ )
+
+ return ordered_inputs
+
+ @property
+ def default_onnx_opset(self) -> int:
+ return 13
+
+
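+# Minimal sketch (illustration only): the past key/value shapes produced by
+# `BloomOnnxConfig.generate_dummy_inputs` above, for a hypothetical hidden_size=64, n_head=8,
+# a batch of 2 and a 3-token prompt (the dummy past length is `seq_length + 2`).
+def _example_bloom_onnx_past_shapes(batch: int = 2, seqlen: int = 3, hidden_size: int = 64, n_head: int = 8):
+    head_dim = hidden_size // n_head
+    past_length = seqlen + 2
+    past_key_shape = (batch * n_head, head_dim, past_length)  # keys are stored transposed
+    past_value_shape = (batch * n_head, past_length, head_dim)
+    return past_key_shape, past_value_shape  # ((16, 8, 5), (16, 5, 8))
+
+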
+__all__ = ["BloomConfig", "BloomOnnxConfig"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/bloom/modeling_bloom.py b/venv/lib/python3.13/site-packages/transformers/models/bloom/modeling_bloom.py
new file mode 100644
index 0000000000000000000000000000000000000000..6fde63e03b4dd33a8625a73a1a74d96780aebc86
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/bloom/modeling_bloom.py
@@ -0,0 +1,1252 @@
+# coding=utf-8
+# Copyright 2022 HuggingFace Inc. team and BigScience workshop.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch BLOOM model."""
+
+import math
+import warnings
+from typing import Optional, Union
+
+import torch
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss
+from torch.nn import functional as F
+
+from ...cache_utils import Cache, DynamicCache, StaticCache
+from ...generation import GenerationMixin
+from ...modeling_attn_mask_utils import AttentionMaskConverter
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import (
+ BaseModelOutputWithPastAndCrossAttentions,
+ CausalLMOutputWithCrossAttentions,
+ QuestionAnsweringModelOutput,
+ SequenceClassifierOutputWithPast,
+ TokenClassifierOutput,
+)
+from ...modeling_utils import PreTrainedModel
+from ...utils import (
+ auto_docstring,
+ is_torch_flex_attn_available,
+ logging,
+)
+from .configuration_bloom import BloomConfig
+
+
+if is_torch_flex_attn_available():
+ from torch.nn.attention.flex_attention import BlockMask
+
+ from ...integrations.flex_attention import make_flex_block_causal_mask
+
+
+logger = logging.get_logger(__name__)
+
+
+def build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
+ """
+ Link to paper: https://huggingface.co/papers/2108.12409 Alibi tensor is not causal as the original paper mentions, it
+ relies on a translation invariance of softmax for quick implementation: with l being a tensor, and a fixed value
+ `softmax(l+a) = softmax(l)`. Based on
+ https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742
+ TODO @thomasw21 this doesn't work as nicely due to the masking strategy, and so masking varies slightly.
+
+ Args:
+ Returns tensor shaped (batch_size * num_heads, 1, max_seq_len)
+ attention_mask (`torch.Tensor`):
+ Token-wise attention mask, this should be of shape (batch_size, max_seq_len).
+ num_heads (`int`):
+ number of heads
+ dtype (`torch.dtype`, *optional*, default=`torch.bfloat16`):
+ dtype of the output tensor
+ """
+ batch_size, seq_length = attention_mask.shape
+ closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
+ base = torch.tensor(
+ 2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32
+ )
+ powers = torch.arange(1, 1 + closest_power_of_2, device=attention_mask.device, dtype=torch.int32)
+ slopes = torch.pow(base, powers)
+
+ if closest_power_of_2 != num_heads:
+ extra_base = torch.tensor(
+ 2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32
+ )
+ num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
+ extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=attention_mask.device, dtype=torch.int32)
+ slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
+
+ # Note: alibi will be added to the attention bias that will be applied to the query, key product of attention
+ # => therefore alibi will have to be of shape (batch_size, num_heads, query_length, key_length)
+ # => here we set (batch_size=1, num_heads=num_heads, query_length=1, key_length=max_length)
+ # => the query_length dimension will then be broadcasted correctly
+ # This is more or less identical to T5's relative position bias:
+ # https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527
+ arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :]
+ alibi = slopes[..., None] * arange_tensor
+ return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype)
+
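+
+# Minimal sketch (illustration only): shape and structure of the ALiBi bias built by
+# `build_alibi_tensor` above — one slope per head, growing linearly with the token position.
+def _example_alibi_shapes():
+    attention_mask = torch.ones(2, 5)  # batch_size=2, seq_length=5, no padding
+    alibi = build_alibi_tensor(attention_mask, num_heads=4, dtype=torch.float32)
+    assert alibi.shape == (2 * 4, 1, 5)  # (batch_size * num_heads, 1, seq_length)
+    # within one head the bias is slope * [0, 1, 2, 3, 4]
+    return alibi
+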
+
+def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor:
+ """
+ Dropout add function
+
+ Args:
+ x (`torch.tensor`):
+ input tensor
+ residual (`torch.tensor`):
+ residual tensor
+ prob (`float`):
+ dropout probability
+ training (`bool`):
+ training mode
+ """
+ out = F.dropout(x, p=prob, training=training)
+ out = residual + out
+ return out
+
+
+def bloom_gelu_forward(x: torch.Tensor) -> torch.Tensor:
+ """
+ Custom bias GELU function. Adapted from Megatron-DeepSpeed code. Here we use a simple implementation (for
+ inference) to make the model jittable.
+
+ Args:
+ x (`torch.tensor`):
+ input hidden states
+ """
+ return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
+
+
+def bloom_gelu_back(g: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
+ """
+ Gradient of the tanh approximation of gelu. The gradient of the exact gelu is: 0.5 * (1. + torch.erf(x * 0.70710678)) +
+ 0.3989423 * x * torch.exp(-0.5 * x * x)
+
+ Args:
+ g (`torch.tensor`):
+ gradient output tensor
+ x (`torch.tensor`):
+ input tensor
+ """
+ x = x[0] # x is a tuple of 1 element, needs to unpack it first
+ tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
+ # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
+ ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
+ return ff * g
+
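+
+# Minimal sketch (illustration only): the hand-written tanh GELU above should match PyTorch's built-in
+# tanh approximation, `F.gelu(x, approximate="tanh")` (assumes torch >= 1.12 for the `approximate` kwarg).
+def _example_bloom_gelu_matches_torch():
+    x = torch.randn(4, 8)
+    torch.testing.assert_close(bloom_gelu_forward(x), F.gelu(x, approximate="tanh"), rtol=1e-4, atol=1e-4)
+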
+
+class GeLUFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, input: torch.Tensor) -> torch.Tensor:
+ ctx.save_for_backward(input)
+ return bloom_gelu_forward(input)
+
+ @staticmethod
+ def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
+ input = ctx.saved_tensors
+ tmp = bloom_gelu_back(grad_output, input)
+ return tmp
+
+
+class BloomGelu(nn.Module):
+ """
+ BloomBiasGelu wrapper that uses the simple forward function in inference mode to keep the model
+ torchscriptable, and the autograd function in training mode to get accurate gradients. Partly copied from
+ Megatron-DeepSpeed code and adapted for our needs.
+
+ See here why autograd functions are not torchscriptable: https://github.com/pytorch/pytorch/issues/22329
+ """
+
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ if self.training:
+ return GeLUFunction.apply(x)
+ else:
+ return bloom_gelu_forward(x)
+
+
+class BloomAttention(nn.Module):
+ def __init__(self, config: BloomConfig, layer_idx: Optional[int] = None):
+ super().__init__()
+
+ self.pretraining_tp = config.pretraining_tp
+ self.slow_but_exact = config.slow_but_exact
+
+ self.hidden_size = config.hidden_size
+ self.num_heads = config.n_head
+ self.head_dim = self.hidden_size // self.num_heads
+ self.split_size = self.hidden_size
+ self.hidden_dropout = config.hidden_dropout
+
+ if self.head_dim * self.num_heads != self.hidden_size:
+ raise ValueError(
+ f"`hidden_size` must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`:"
+ f" {self.num_heads})."
+ )
+
+ # Layer-wise attention scaling
+ self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
+ self.beta = 1.0
+ self.layer_idx = layer_idx
+ if layer_idx is None:
+ logger.warning_once(
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
+ "when creating this class."
+ )
+
+ self.query_key_value = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=True)
+ self.dense = nn.Linear(self.hidden_size, self.hidden_size)
+ self.attention_dropout = nn.Dropout(config.attention_dropout)
+
+ def _reshape(self, fused_qkv: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """
+ Splits the last dimension into (num_heads, head_dim) and reshapes to (bs, heads, len, dim) without making
+ any copies; the results share the same memory storage as `fused_qkv`.
+
+ Args:
+ fused_qkv (`torch.tensor`): [batch_size, seq_length, num_heads * 3 * head_dim]
+
+ Returns:
+ query: [batch_size, num_heads, seq_length, head_dim]
+ key: [batch_size, num_heads, seq_length, head_dim]
+ value: [batch_size, num_heads, seq_length, head_dim]
+ """
+ batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
+ fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads, 3, self.head_dim)
+ query_layer = fused_qkv[..., 0, :].transpose(1, 2)
+ key_layer = fused_qkv[..., 1, :].transpose(1, 2)
+ value_layer = fused_qkv[..., 2, :].transpose(1, 2)
+ return query_layer, key_layer, value_layer
+
+ def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Merge heads together over the last dimension
+
+ Args:
+ x (`torch.tensor`): [batch_size * num_heads, seq_length, head_dim]
+
+ Returns:
+ torch.tensor: [batch_size, seq_length, num_heads * head_dim]
+ """
+ # What we want to achieve is:
+ # batch_size * num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads * head_dim
+ batch_size_and_num_heads, seq_length, _ = x.shape
+ batch_size = batch_size_and_num_heads // self.num_heads
+
+ # First view to decompose the batch size
+ # batch_size * num_heads, seq_length, head_dim -> batch_size, num_heads, seq_length, head_dim
+ x = x.view(batch_size, self.num_heads, seq_length, self.head_dim)
+
+ # batch_size, num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads, head_dim
+ x = x.permute(0, 2, 1, 3)
+
+ # batch_size, seq_length, num_heads, head_dim -> batch_size, seq_length, num_heads * head_dim
+ return x.reshape(batch_size, seq_length, self.num_heads * self.head_dim)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ residual: torch.Tensor,
+ alibi: torch.Tensor,
+ attention_mask: torch.Tensor,
+ layer_past: Optional[Cache] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ use_cache: bool = False,
+ output_attentions: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ ):
+ batch_size, q_length, _ = hidden_states.shape
+ fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
+ # 3 x [batch_size, num_heads, seq_length, head_dim]
+ query_layer, key_layer, value_layer = self._reshape(fused_qkv)
+
+ if layer_past is not None:
+ cache_kwargs = {"cache_position": cache_position}
+ key_layer, value_layer = layer_past.update(key_layer, value_layer, self.layer_idx, cache_kwargs)
+
+ # reshape qkv for further computations
+ query_layer = query_layer.reshape(batch_size * self.num_heads, -1, self.head_dim)
+ key_layer = key_layer.reshape(batch_size * self.num_heads, -1, self.head_dim).transpose(-1, -2)
+ value_layer = value_layer.reshape(batch_size * self.num_heads, -1, self.head_dim)
+
+ # [batch_size * num_heads, q_length, kv_length]
+ attention_scores = alibi.baddbmm(
+ batch1=query_layer,
+ batch2=key_layer,
+ beta=self.beta,
+ alpha=self.inv_norm_factor,
+ )
+
+ # change view to [batch_size, num_heads, q_length, kv_length]
+ attn_weights = attention_scores.view(batch_size, self.num_heads, q_length, -1)
+ if attention_mask is not None: # no matter the length, we just slice it
+ causal_mask = attention_mask[:, :, :, : key_layer.shape[-1]]
+ attn_weights = attn_weights + causal_mask
+
+ # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype
+ attention_probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_layer.dtype)
+
+ # [batch_size, num_heads, q_length, kv_length]
+ attention_probs = self.attention_dropout(attention_probs)
+
+ if head_mask is not None:
+ attention_probs = attention_probs * head_mask
+
+ # change view [batch_size x num_heads, q_length, kv_length]
+ attention_probs_reshaped = attention_probs.view(batch_size * self.num_heads, q_length, -1)
+
+ # matmul: [batch_size * num_heads, q_length, head_dim]
+ context_layer = torch.bmm(attention_probs_reshaped, value_layer)
+
+ # change view [batch_size, q_length, num_heads * head_dim]
+ context_layer = self._merge_heads(context_layer)
+
+ # aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232
+ if self.pretraining_tp > 1 and self.slow_but_exact:
+ slices = self.hidden_size / self.pretraining_tp
+ output_tensor = torch.zeros_like(context_layer)
+ for i in range(self.pretraining_tp):
+ output_tensor = output_tensor + F.linear(
+ context_layer[:, :, int(i * slices) : int((i + 1) * slices)],
+ self.dense.weight[:, int(i * slices) : int((i + 1) * slices)],
+ )
+ else:
+ output_tensor = self.dense(context_layer)
+
+ output_tensor = dropout_add(output_tensor, residual, self.hidden_dropout, self.training)
+ return output_tensor, attention_probs
+
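+
+# Minimal sketch (illustration only): how the fused QKV projection of `BloomAttention` above is split
+# into per-head query/key/value tensors, using a hypothetical tiny config (hidden_size=8, n_head=2).
+def _example_bloom_qkv_split():
+    config = BloomConfig(hidden_size=8, n_head=2)
+    attn = BloomAttention(config, layer_idx=0)
+    hidden_states = torch.randn(3, 5, 8)  # (batch_size, seq_length, hidden_size)
+    fused_qkv = attn.query_key_value(hidden_states)  # (3, 5, 3 * hidden_size)
+    query, key, value = attn._reshape(fused_qkv)
+    assert query.shape == key.shape == value.shape == (3, 2, 5, 4)  # (batch, heads, seq, head_dim)
+    return query, key, value
+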
+
+class BloomMLP(nn.Module):
+ def __init__(self, config: BloomConfig):
+ super().__init__()
+ hidden_size = config.hidden_size
+
+ self.pretraining_tp = config.pretraining_tp
+ self.slow_but_exact = config.slow_but_exact
+ self.dense_h_to_4h = nn.Linear(hidden_size, 4 * hidden_size)
+ self.gelu_impl = BloomGelu()
+ self.dense_4h_to_h = nn.Linear(4 * hidden_size, hidden_size)
+ self.hidden_dropout = config.hidden_dropout
+
+ def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.gelu_impl(self.dense_h_to_4h(hidden_states))
+
+ if self.pretraining_tp > 1 and self.slow_but_exact:
+ intermediate_output = torch.zeros_like(residual)
+ slices = self.dense_4h_to_h.weight.shape[-1] / self.pretraining_tp
+ for i in range(self.pretraining_tp):
+ intermediate_output = intermediate_output + F.linear(
+ hidden_states[:, :, int(i * slices) : int((i + 1) * slices)],
+ self.dense_4h_to_h.weight[:, int(i * slices) : int((i + 1) * slices)],
+ )
+ else:
+ intermediate_output = self.dense_4h_to_h(hidden_states)
+
+ output = dropout_add(intermediate_output, residual, self.hidden_dropout, self.training)
+
+ return output
+
+
+class BloomBlock(GradientCheckpointingLayer):
+ def __init__(self, config: BloomConfig, layer_idx: Optional[int] = None):
+ super().__init__()
+ hidden_size = config.hidden_size
+
+ self.input_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+ self.num_heads = config.n_head
+ self.self_attention = BloomAttention(config, layer_idx)
+ self.post_attention_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+
+ self.mlp = BloomMLP(config)
+
+ self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm
+ self.hidden_dropout = config.hidden_dropout
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ alibi: torch.Tensor,
+ attention_mask: torch.Tensor,
+ layer_past: Optional[Cache] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ use_cache: bool = False,
+ output_attentions: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ ):
+ # hidden_states: [batch_size, seq_length, hidden_size]
+
+ # Layer norm at the beginning of the transformer layer.
+ layernorm_output = self.input_layernorm(hidden_states)
+
+ # Layer norm post the self attention.
+ if self.apply_residual_connection_post_layernorm:
+ residual = layernorm_output
+ else:
+ residual = hidden_states
+
+ # Self attention.
+ attention_output, attn_weights = self.self_attention(
+ layernorm_output,
+ residual,
+ layer_past=layer_past,
+ attention_mask=attention_mask,
+ alibi=alibi,
+ head_mask=head_mask,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ cache_position=cache_position,
+ )
+
+ layernorm_output = self.post_attention_layernorm(attention_output)
+
+ # Get residual
+ if self.apply_residual_connection_post_layernorm:
+ residual = layernorm_output
+ else:
+ residual = attention_output
+
+ # MLP.
+ output = self.mlp(layernorm_output, residual)
+
+ return output, attn_weights # hidden_states, attentions
+
+
+@auto_docstring
+class BloomPreTrainedModel(PreTrainedModel):
+ config: BloomConfig
+ base_model_prefix = "transformer"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["BloomBlock"]
+ _skip_keys_device_placement = "past_key_values"
+
+ _can_compile_fullgraph = True
+
+ def __init__(self, *inputs, **kwargs):
+ super().__init__(*inputs, **kwargs)
+
+ def _init_weights(self, module: nn.Module):
+ """Initialize the weights."""
+ if isinstance(module, nn.Linear):
+ # Slightly different from the TF version which uses truncated_normal for initialization
+ # cf https://github.com/pytorch/pytorch/pull/5617
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+ elif isinstance(module, LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+
+
+@auto_docstring
+class BloomModel(BloomPreTrainedModel):
+ def __init__(self, config: BloomConfig):
+ super().__init__(config)
+
+ self.embed_dim = config.hidden_size
+ self.num_heads = config.n_head
+
+ # Embedding + LN Embedding
+ self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim)
+ self.word_embeddings_layernorm = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
+
+ # Transformer blocks
+ self.h = nn.ModuleList([BloomBlock(config, layer_idx=i) for i in range(config.num_hidden_layers)])
+
+ # Final Layer Norm
+ self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
+
+ self.gradient_checkpointing = False
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def build_alibi_tensor(self, attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
+ return build_alibi_tensor(attention_mask, num_heads, dtype)
+
+ def get_input_embeddings(self):
+ return self.word_embeddings
+
+ def set_input_embeddings(self, new_embeddings: torch.Tensor):
+ self.word_embeddings = new_embeddings
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Union[Cache, tuple[tuple[torch.Tensor, torch.Tensor], ...]]] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **deprecated_arguments,
+ ) -> Union[tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
+ r"""
+ input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
+ `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values.get_seq_length()`
+ (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary.
+
+ If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
+ `input_ids`.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ """
+ if deprecated_arguments.pop("position_ids", False) is not False:
+ # `position_ids` could have been `torch.Tensor` or `None`, so defaulting the pop to `False` allows us to detect whether users explicitly passed `None`
+ warnings.warn(
+ "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
+ " passing `position_ids`.",
+ FutureWarning,
+ )
+ if len(deprecated_arguments) > 0:
+ raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if self.gradient_checkpointing and self.training and use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ if inputs_embeds is None:
+ inputs_embeds = self.word_embeddings(input_ids)
+
+ if use_cache and past_key_values is None:
+ past_key_values = DynamicCache(config=self.config)
+
+ batch_size, seq_length, _ = inputs_embeds.shape
+ past_length = past_key_values.get_seq_length() if past_key_values is not None else 0
+ seq_length_with_past = seq_length + past_length
+ if cache_position is None:
+ cache_position = torch.arange(past_length, past_length + seq_length, device=inputs_embeds.device)
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape batch_size x num_heads x N x N
+ # head_mask has shape n_layer x batch x num_heads x N x N
+ head_mask = self.get_head_mask(head_mask, self.config.n_layer)
+ hidden_states = self.word_embeddings_layernorm(inputs_embeds)
+
+ all_self_attentions = () if output_attentions else None
+ all_hidden_states = () if output_hidden_states else None
+
+ # Compute alibi tensor: check build_alibi_tensor documentation
+ if attention_mask is None:
+ attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
+ else:
+ attention_mask = attention_mask.to(hidden_states.device)
+
+ alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
+ causal_mask = self._update_causal_mask(
+ attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+ )
+
+ for i, block in enumerate(self.h):
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ outputs = block(
+ hidden_states,
+ layer_past=past_key_values,
+ attention_mask=causal_mask,
+ head_mask=head_mask[i],
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ alibi=alibi,
+ cache_position=cache_position,
+ )
+
+ hidden_states = outputs[0]
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (outputs[1],)
+
+ # Add last hidden state
+ hidden_states = self.ln_f(hidden_states)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(
+ v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attentions] if v is not None
+ )
+
+ return BaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ )
+
+ # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
+ def _update_causal_mask(
+ self,
+ attention_mask: Union[torch.Tensor, "BlockMask"],
+ input_tensor: torch.Tensor,
+ cache_position: torch.Tensor,
+ past_key_values: Cache,
+ output_attentions: bool = False,
+ ):
+ if self.config._attn_implementation == "flash_attention_2":
+ if attention_mask is not None and (attention_mask == 0.0).any():
+ return attention_mask
+ return None
+ if self.config._attn_implementation == "flex_attention":
+ if isinstance(attention_mask, torch.Tensor):
+ attention_mask = make_flex_block_causal_mask(attention_mask)
+ return attention_mask
+
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
+ # to infer the attention mask.
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
+
+ # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
+ if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
+ attention_mask,
+ inputs_embeds=input_tensor,
+ past_key_values_length=past_seen_tokens,
+ is_training=self.training,
+ ):
+ return None
+
+ dtype = input_tensor.dtype
+ sequence_length = input_tensor.shape[1]
+ if using_compilable_cache:
+ target_length = past_key_values.get_max_cache_shape()
+ else:
+ target_length = (
+ attention_mask.shape[-1]
+ if isinstance(attention_mask, torch.Tensor)
+ else past_seen_tokens + sequence_length + 1
+ )
+
+ # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
+ causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
+ attention_mask,
+ sequence_length=sequence_length,
+ target_length=target_length,
+ dtype=dtype,
+ cache_position=cache_position,
+ batch_size=input_tensor.shape[0],
+ )
+
+ if (
+ self.config._attn_implementation == "sdpa"
+ and attention_mask is not None
+ and attention_mask.device.type in ["cuda", "xpu", "npu"]
+ and not output_attentions
+ ):
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+ # Details: https://github.com/pytorch/pytorch/issues/110213
+ min_dtype = torch.finfo(dtype).min
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
+
+ return causal_mask
+
+ @staticmethod
+ # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
+ def _prepare_4d_causal_attention_mask_with_cache_position(
+ attention_mask: torch.Tensor,
+ sequence_length: int,
+ target_length: int,
+ dtype: torch.dtype,
+ cache_position: torch.Tensor,
+ batch_size: int,
+ **kwargs,
+ ):
+ """
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
+
+ Args:
+ attention_mask (`torch.Tensor`):
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
+ `(batch_size, 1, query_length, key_value_length)`.
+ sequence_length (`int`):
+ The sequence length being processed.
+ target_length (`int`):
+ The target length: when generating with static cache, the mask should be as long as the static cache,
+ to account for the 0 padding, the part of the cache that is not filled yet.
+ dtype (`torch.dtype`):
+ The dtype to use for the 4D attention mask.
+ cache_position (`torch.Tensor`):
+ Indices depicting the position of the input sequence tokens in the sequence.
+ batch_size (`torch.Tensor`):
+ Batch size.
+ """
+ if attention_mask is not None and attention_mask.dim() == 4:
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+ causal_mask = attention_mask
+ else:
+ min_dtype = torch.finfo(dtype).min
+ causal_mask = torch.full(
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
+ )
+ if sequence_length != 1:
+ causal_mask = torch.triu(causal_mask, diagonal=1)
+ causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+ if attention_mask is not None:
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
+ mask_length = attention_mask.shape[-1]
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
+ causal_mask.device
+ )
+ padding_mask = padding_mask == 0
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+ padding_mask, min_dtype
+ )
+
+ return causal_mask
+
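+
+# Minimal sketch (illustration only): the 4D mask built by the static helper above for one sequence of
+# 4 tokens whose last position is padding. Future and padded positions receive the most negative value
+# representable in the chosen dtype; every other position is 0.
+def _example_bloom_4d_causal_mask():
+    attention_mask = torch.tensor([[1, 1, 1, 0]])  # batch_size=1, right padding on the last token
+    mask = BloomModel._prepare_4d_causal_attention_mask_with_cache_position(
+        attention_mask,
+        sequence_length=4,
+        target_length=4,
+        dtype=torch.float32,
+        cache_position=torch.arange(4),
+        batch_size=1,
+    )
+    assert mask.shape == (1, 1, 4, 4)  # (batch_size, 1, query_length, key_value_length)
+    return mask
+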
+
+@auto_docstring(
+ custom_intro="""
+ The Bloom Model transformer with a language modeling head on top (linear layer with weights tied to the input
+ embeddings).
+ """
+)
+class BloomForCausalLM(BloomPreTrainedModel, GenerationMixin):
+ _tied_weights_keys = ["lm_head.weight"]
+
+ def __init__(self, config: BloomConfig):
+ super().__init__(config)
+ self.transformer = BloomModel(config)
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def set_output_embeddings(self, new_embeddings: torch.Tensor):
+ self.lm_head = new_embeddings
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids,
+ past_key_values=None,
+ attention_mask=None,
+ inputs_embeds=None,
+ cache_position=None,
+ use_cache=True,
+ **kwargs,
+ ):
+ # Overwritten because of the fixed-shape attention mask creation
+
+ # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
+ # Exception 1: when passing input_embeds, input_ids may be missing entries
+ # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
+ # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want a dummy token in that case
+ # (we can't check exception 3 while compiling).
+ # Exception 4: if `inputs_embeds` are passed, slice them through `cache_position` to keep only the unprocessed
+ # tokens and generate the first token for each sequence. Later, the generated input ids are used for continuation.
+ if past_key_values is not None:
+ if inputs_embeds is not None and input_ids.shape[1] == 0: # Exception 4
+ inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :]
+ elif (
+ inputs_embeds is not None # Exception 1
+ or cache_position[-1] >= input_ids.shape[1] # Exception 3
+ ):
+ input_ids = input_ids[:, -cache_position.shape[0] :]
+ elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
+ input_ids = input_ids[:, cache_position]
+
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+ if inputs_embeds is not None and len(cache_position) == inputs_embeds.shape[1]:
+ model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
+ else:
+ # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead"`, as otherwise
+ # the input `input_ids` would have varying strides during decoding. Here, simply using `.contiguous()` is not sufficient as in
+ # the batch size = 1 case, `input_ids` is already contiguous but with varying stride, which retriggers a capture.
+ model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
+
+ # This part differs from other models because BLOOM needs a 2D mask to construct alibi tensor
+ # The only difference is the usage of 2D instead of 4D mask, but the shape will be static
+ if isinstance(past_key_values, StaticCache) and attention_mask is not None:
+ target_length = past_key_values.get_max_cache_shape()
+ batch_size, seq_length = attention_mask.shape
+ diff = target_length - seq_length
+
+ new_attn_mask = torch.zeros(batch_size, diff, device=attention_mask.device, dtype=attention_mask.dtype)
+ attention_mask = torch.cat(
+ [attention_mask, new_attn_mask],
+ dim=-1,
+ )
+
+ model_inputs.update(
+ {
+ "cache_position": cache_position,
+ "past_key_values": past_key_values,
+ "use_cache": use_cache,
+ "attention_mask": attention_mask,
+ }
+ )
+
+ # Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
+ for key, value in kwargs.items():
+ if key not in model_inputs:
+ model_inputs[key] = value
+
+ return model_inputs
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Union[Cache, tuple[tuple[torch.Tensor, torch.Tensor], ...]]] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **deprecated_arguments,
+ ) -> Union[tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
+ r"""
+ input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
+ `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values.get_seq_length()`
+ (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary.
+
+ If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
+ `input_ids`.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
+ `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
+ are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
+ """
+ # Bloom has deprecated kwargs, so we need to pop num_items_in_batch explicitly
+ num_items_in_batch = deprecated_arguments.pop("num_items_in_batch", None)
+ if deprecated_arguments.pop("position_ids", False) is not False:
+ # `position_ids` could have been `torch.Tensor` or `None`, so defaulting the pop to `False` allows us to detect whether users explicitly passed `None`
+ warnings.warn(
+ "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
+ " passing `position_ids`.",
+ FutureWarning,
+ )
+ if len(deprecated_arguments) > 0:
+ raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ transformer_outputs = self.transformer(
+ input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ cache_position=cache_position,
+ )
+ hidden_states = transformer_outputs[0]
+
+ lm_logits = self.lm_head(hidden_states)
+
+ loss = None
+ if labels is not None:
+ # move labels to correct device to enable model parallelism
+ labels = labels.to(lm_logits.device)
+ # Flatten the tokens
+ loss = self.loss_function(
+ lm_logits,
+ labels,
+ vocab_size=self.config.vocab_size,
+ num_items_in_batch=num_items_in_batch,
+ )
+
+ if not return_dict:
+ output = (lm_logits,) + transformer_outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return CausalLMOutputWithCrossAttentions(
+ loss=loss,
+ logits=lm_logits,
+ past_key_values=transformer_outputs.past_key_values,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
+ )
+
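+
+# Minimal usage sketch (illustration only, not part of the modeling code): greedy generation with
+# `BloomForCausalLM`, assuming hub access and the `bigscience/bloom-560m` checkpoint.
+def _example_bloom_generation():
+    from transformers import AutoTokenizer
+
+    tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
+    model = BloomForCausalLM.from_pretrained("bigscience/bloom-560m")
+
+    inputs = tokenizer("Hello, my name is", return_tensors="pt")
+    output_ids = model.generate(**inputs, max_new_tokens=20, do_sample=False)
+    return tokenizer.decode(output_ids[0], skip_special_tokens=True)
+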
+
+@auto_docstring(
+ custom_intro="""
+ The Bloom Model transformer with a sequence classification head on top (linear layer).
+
+ [`BloomForSequenceClassification`] uses the last token in order to do the classification, as other causal models
+ (e.g. GPT-1) do.
+
+ Since it does classification on the last token, it needs to know the position of the last token. If a
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
+ each row of the batch).
+ """
+)
+class BloomForSequenceClassification(BloomPreTrainedModel):
+ def __init__(self, config: BloomConfig):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.transformer = BloomModel(config)
+ self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Union[Cache, tuple[tuple[torch.Tensor, torch.Tensor], ...]]] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ **deprecated_arguments,
+ ) -> Union[tuple[torch.Tensor], SequenceClassifierOutputWithPast]:
+ r"""
+ input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
+ `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values.get_seq_length()`
+ (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary.
+
+ If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
+ `input_ids`.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+ if deprecated_arguments.pop("position_ids", False) is not False:
+ # `position_ids` could have been `torch.Tensor` or `None`, so defaulting the pop to `False` allows us to detect whether users explicitly passed `None`
+ warnings.warn(
+ "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
+ " passing `position_ids`.",
+ FutureWarning,
+ )
+ if len(deprecated_arguments) > 0:
+ raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ transformer_outputs = self.transformer(
+ input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = transformer_outputs[0]
+ logits = self.score(hidden_states)
+
+ if input_ids is not None:
+ batch_size = input_ids.shape[0]
+ else:
+ batch_size = inputs_embeds.shape[0]
+
+ if self.config.pad_token_id is None and batch_size != 1:
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
+ if self.config.pad_token_id is None:
+ last_non_pad_token = -1
+ elif input_ids is not None:
+ # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
+ non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
+ token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
+ last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
+ else:
+ last_non_pad_token = -1
+ logger.warning_once(
+ f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+ "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+ )
+
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
+
+ loss = None
+ if labels is not None:
+ if self.config.problem_type is None:
+ if self.num_labels == 1:
+ self.config.problem_type = "regression"
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+ self.config.problem_type = "single_label_classification"
+ else:
+ self.config.problem_type = "multi_label_classification"
+
+ if self.config.problem_type == "regression":
+ loss_fct = MSELoss()
+ if self.num_labels == 1:
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+ else:
+ loss = loss_fct(pooled_logits, labels)
+ elif self.config.problem_type == "single_label_classification":
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(pooled_logits, labels)
+ elif self.config.problem_type == "multi_label_classification":
+ loss_fct = BCEWithLogitsLoss()
+ loss = loss_fct(pooled_logits, labels)
+ if not return_dict:
+ output = (pooled_logits,) + transformer_outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return SequenceClassifierOutputWithPast(
+ loss=loss,
+ logits=pooled_logits,
+ past_key_values=transformer_outputs.past_key_values,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
+ )
+
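+
+# Minimal sketch (illustration only): the arithmetic used above to locate the rightmost non-padding
+# token per row, which works for both left- and right-padded batches (pad_token_id=0 is hypothetical).
+def _example_last_non_pad_token(pad_token_id: int = 0):
+    input_ids = torch.tensor(
+        [
+            [5, 6, 7, pad_token_id],  # right padding -> last real token at index 2
+            [pad_token_id, 5, 6, 7],  # left padding  -> last real token at index 3
+        ]
+    )
+    non_pad_mask = (input_ids != pad_token_id).to(torch.int32)
+    token_indices = torch.arange(input_ids.shape[-1], dtype=torch.int32)
+    last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
+    assert last_non_pad_token.tolist() == [2, 3]
+    return last_non_pad_token
+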
+
+@auto_docstring
+class BloomForTokenClassification(BloomPreTrainedModel):
+ def __init__(self, config: BloomConfig):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+
+ self.transformer = BloomModel(config)
+ if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None:
+ classifier_dropout = config.classifier_dropout
+ elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None:
+ classifier_dropout = config.hidden_dropout
+ else:
+ classifier_dropout = 0.1
+ self.dropout = nn.Dropout(classifier_dropout)
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Union[Cache, tuple[tuple[torch.Tensor, torch.Tensor], ...]]] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ **deprecated_arguments,
+ ) -> Union[tuple[torch.Tensor], TokenClassifierOutput]:
+ r"""
+ input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
+ `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values.get_seq_length()`
+ (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary.
+
+ If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
+ `input_ids`.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
+ """
+ if deprecated_arguments.pop("position_ids", False) is not False:
+ # `position_ids` could have been `torch.Tensor` or `None`, so defaulting `pop` to `False` lets us detect whether users explicitly passed `None`
+ warnings.warn(
+ "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
+ " passing `position_ids`.",
+ FutureWarning,
+ )
+ if len(deprecated_arguments) > 0:
+ raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ transformer_outputs = self.transformer(
+ input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = transformer_outputs[0]
+ hidden_states = self.dropout(hidden_states)
+ logits = self.classifier(hidden_states)
+
+ loss = None
+ if labels is not None:
+ # move labels to correct device to enable model parallelism
+ labels = labels.to(logits.device)
+ batch_size, seq_length = labels.shape
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(
+ logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length)
+ )
+
+ if not return_dict:
+ output = (logits,) + transformer_outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return TokenClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
+ )
+
+
+@auto_docstring
+class BloomForQuestionAnswering(BloomPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.transformer = BloomModel(config)
+ self.qa_outputs = nn.Linear(config.hidden_size, 2)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ start_positions: Optional[torch.LongTensor] = None,
+ end_positions: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, QuestionAnsweringModelOutput]:
+ r"""
+ input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
+ `input_ids_length` = `sequence_length`. Indices of input sequence tokens in the vocabulary.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.transformer(
+ input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output = outputs[0]
+
+ logits = self.qa_outputs(sequence_output)
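+ # each token gets two logits; split them into per-token start and end logits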
+ start_logits, end_logits = logits.split(1, dim=-1)
+ start_logits = start_logits.squeeze(-1).contiguous()
+ end_logits = end_logits.squeeze(-1).contiguous()
+
+ total_loss = None
+ if start_positions is not None and end_positions is not None:
+ # If we are on multi-GPU, the start/end positions may carry an extra dimension; squeeze it
+ if len(start_positions.size()) > 1:
+ start_positions = start_positions.squeeze(-1)
+ if len(end_positions.size()) > 1:
+ end_positions = end_positions.squeeze(-1)
+ # sometimes the start/end positions are outside our model inputs; we ignore these terms
+ ignored_index = start_logits.size(1)
+ start_positions = start_positions.clamp(0, ignored_index)
+ end_positions = end_positions.clamp(0, ignored_index)
+
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
+ start_loss = loss_fct(start_logits, start_positions)
+ end_loss = loss_fct(end_logits, end_positions)
+ total_loss = (start_loss + end_loss) / 2
+
+ if not return_dict:
+ output = (start_logits, end_logits) + outputs[2:]
+ return ((total_loss,) + output) if total_loss is not None else output
+
+ return QuestionAnsweringModelOutput(
+ loss=total_loss,
+ start_logits=start_logits,
+ end_logits=end_logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+__all__ = [
+ "BloomForCausalLM",
+ "BloomModel",
+ "BloomPreTrainedModel",
+ "BloomForSequenceClassification",
+ "BloomForTokenClassification",
+ "BloomForQuestionAnswering",
+]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/bloom/modeling_flax_bloom.py b/venv/lib/python3.13/site-packages/transformers/models/bloom/modeling_flax_bloom.py
new file mode 100644
index 0000000000000000000000000000000000000000..c7bb1cc9c9a5b098f9c03d6313818e744e61951e
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/bloom/modeling_flax_bloom.py
@@ -0,0 +1,737 @@
+# coding=utf-8
+# Copyright 2023 HuggingFace Inc. Team and Bigscience Workshop. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Flax BLOOM model."""
+
+import math
+from functools import partial
+from typing import Optional
+
+import flax.linen as nn
+import jax
+import jax.numpy as jnp
+from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
+from flax.linen import combine_masks, dot_product_attention_weights, make_causal_mask
+from flax.linen.activation import tanh
+from flax.traverse_util import flatten_dict, unflatten_dict
+from jax import lax
+
+from ...modeling_flax_outputs import (
+ FlaxBaseModelOutput,
+ FlaxBaseModelOutputWithPastAndCrossAttentions,
+ FlaxCausalLMOutput,
+)
+from ...modeling_flax_utils import FlaxPreTrainedModel, append_call_sample_docstring
+from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging
+from .configuration_bloom import BloomConfig
+
+
+logger = logging.get_logger(__name__)
+
+_CHECKPOINT_FOR_DOC = "bigscience/bloom"
+_CONFIG_FOR_DOC = "BloomConfig"
+
+
+BLOOM_START_DOCSTRING = r"""
+
+ This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads,
+ etc.)
+
+ This model is also a Flax Linen
+ [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a
+ regular Flax Module and refer to the Flax documentation for all matters related to general usage and behavior.
+
+ Finally, this model supports inherent JAX features such as:
+
+ - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
+ - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
+ - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
+ - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
+
+ Parameters:
+ config ([`BloomConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
+ dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
+ The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
+ `jax.numpy.bfloat16` (on TPUs).
+
+ This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
+ specified all the computation will be performed with the given `dtype`.
+
+ **Note that this only specifies the dtype of the computation and does not influence the dtype of model
+ parameters.**
+
+ If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
+ [`~FlaxPreTrainedModel.to_bf16`].
+"""
+
+BLOOM_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`numpy.ndarray` of shape `(batch_size, input_ids_length)`):
+ `input_ids_length` = `sequence_length`. Indices of input sequence tokens in the vocabulary.
+
+ Indices can be obtained using [`BloomTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ past_key_values (`dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`):
+ Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast
+ auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+def build_alibi_tensor(attention_mask: jnp.ndarray, num_heads: int, dtype: Optional[jnp.dtype] = jnp.float32):
+ """
+ Flax implementation of the BLOOM ALiBi tensor. Unlike in the original paper, the BLOOM ALiBi tensor is not causal;
+ it relies on the translation invariance of softmax for a simpler implementation: for a tensor `l` and a fixed value
+ `a`, `softmax(l + a) = softmax(l)`. Based on
+ https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742
+ Link to paper: https://huggingface.co/papers/2108.12409
+
+ Args:
+ attention_mask (`jnp.ndarray`):
+ Token-wise attention mask, this should be of shape `(batch_size, max_seq_len)`.
+ num_heads (`int`):
+ Number of attention heads.
+ dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`):
+ The data type (dtype) of the output tensor.
+
+ Returns: ALiBi tensor of shape `(batch_size, num_heads, 1, max_seq_len)`.
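+
+ Example (illustrative sketch only; the shapes follow the implementation below):
+
+ ```python
+ >>> import jax.numpy as jnp
+
+ >>> attention_mask = jnp.ones((2, 5), dtype=jnp.int32)
+ >>> alibi = build_alibi_tensor(attention_mask, num_heads=4)
+ >>> alibi.shape
+ (2, 4, 1, 5)
+ ```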
+ """
+ batch_size, seq_length = attention_mask.shape
+ closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
+ base = jnp.array(2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), dtype=jnp.float32)
+ powers = jnp.arange(1, 1 + closest_power_of_2, dtype=jnp.float32)
+ slopes = jax.lax.pow(base, powers)
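+ # the per-head slopes form a geometric sequence base**1, base**2, ... with base = 2**(-8 / closest_power_of_2),
+ # as in the ALiBi paper; extra heads (when num_heads is not a power of 2) get interleaved slopes below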
+
+ if closest_power_of_2 != num_heads:
+ extra_base = jnp.array(2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), dtype=jnp.float32)
+ num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
+ extra_powers = jnp.arange(1, 1 + 2 * num_remaining_heads, 2, dtype=jnp.float32)
+ slopes = jnp.concatenate([slopes, jax.lax.pow(extra_base, extra_powers)], axis=0)
+
+ # Note: the ALiBi tensor is added to the attention bias that is applied to the query-key product of attention,
+ # so ALiBi has to be of shape (batch_size, num_heads, query_length, key_length).
+ # Here we build a tensor of shape (batch_size, num_heads, query_length=1, key_length=max_length)
+ # so that the query_length dimension is then broadcast correctly.
+ # This is more or less identical to T5's relative position bias:
+ # https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527
+ arange_tensor = ((attention_mask.cumsum(axis=-1) - 1) * attention_mask)[:, None, :]
+ alibi = slopes[..., None] * arange_tensor
+ alibi = jnp.expand_dims(alibi, axis=2)
+ return jnp.asarray(alibi, dtype)
+
+
+class FlaxBloomAttention(nn.Module):
+ config: BloomConfig
+ dtype: jnp.dtype = jnp.float32
+
+ def setup(self):
+ self.hidden_size = self.config.hidden_size
+ self.num_heads = self.config.n_head
+ self.head_dim = self.hidden_size // self.num_heads
+ self.attention_softmax_in_fp32 = self.dtype is not jnp.float32
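+ # when the module runs in reduced precision, the attention softmax is computed in fp32 for numerical stability
+ # and cast back afterwards (see `__call__`)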
+
+ if self.head_dim * self.num_heads != self.hidden_size:
+ raise ValueError(
+ f"`hidden_size` must be divisible by `num_heads` (got `hidden_size`: {self.hidden_size} and "
+ f"`num_heads`: {self.num_heads})."
+ )
+
+ dense = partial(
+ nn.Dense,
+ dtype=self.dtype,
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
+ )
+
+ self.query_key_value = dense(self.hidden_size * 3)
+ self.dense = dense(self.hidden_size)
+ self.resid_dropout = nn.Dropout(rate=self.config.hidden_dropout)
+
+ def _split_heads(self, hidden_states):
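+ # the fused qkv projection is reshaped to (..., num_heads, 3 * head_dim); query, key and value are then
+ # split off the last axis in `__call__`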
+ return hidden_states.reshape(hidden_states.shape[:-1] + (self.num_heads, self.head_dim * 3))
+
+ def _merge_heads(self, hidden_states):
+ return hidden_states.reshape(hidden_states.shape[:2] + (self.hidden_size,))
+
+ @nn.compact
+ # Copied from transformers.models.gptj.modeling_flax_gptj.FlaxGPTJAttention._concatenate_to_cache
+ def _concatenate_to_cache(self, key, value, query, attention_mask):
+ """
+ This function takes projected key, value states from a single input token and concatenates the states to cached
+ states from previous steps. This function is slightly adapted from the official Flax repository:
+ https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
+ """
+ # detect if we're initializing by absence of existing cache data.
+ is_initialized = self.has_variable("cache", "cached_key")
+ cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
+ cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
+ cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
+
+ if is_initialized:
+ *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
+ # update key, value caches with our new 1d spatial slices
+ cur_index = cache_index.value
+ indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
+ key = lax.dynamic_update_slice(cached_key.value, key, indices)
+ value = lax.dynamic_update_slice(cached_value.value, value, indices)
+ cached_key.value = key
+ cached_value.value = value
+ num_updated_cache_vectors = query.shape[1]
+ cache_index.value = cache_index.value + num_updated_cache_vectors
+ # causal mask for cached decoder self-attention: our single query position should only attend to those key
+ # positions that have already been generated and cached, not the remaining zero elements.
+ pad_mask = jnp.broadcast_to(
+ jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
+ tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
+ )
+ attention_mask = combine_masks(pad_mask, attention_mask)
+ return key, value, attention_mask
+
+ def __call__(
+ self,
+ hidden_states,
+ residual,
+ alibi,
+ attention_mask=None,
+ deterministic: bool = True,
+ init_cache: bool = False,
+ output_attentions: bool = False,
+ ):
+ batch_size, seq_length = hidden_states.shape[:2]
+
+ # proj q, k, v
+ fused_qkv = self.query_key_value(hidden_states)
+ fused_qkv = self._split_heads(fused_qkv)
+ query, key, value = jnp.split(fused_qkv, 3, axis=-1)
+
+ causal_attention_mask = make_causal_mask(attention_mask, dtype="bool")
+
+ # for fast decoding causal attention mask should be shifted
+ causal_attention_mask_shift = (
+ self.variables["cache"]["cache_index"] if self.has_variable("cache", "cached_key") else 0
+ )
+
+ # fast decoding for generate requires special attention_mask
+ if self.has_variable("cache", "cached_key"):
+ max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
+ causal_attention_mask = jax.lax.dynamic_slice(
+ causal_attention_mask,
+ (0, 0, causal_attention_mask_shift, 0),
+ (1, 1, seq_length, max_decoder_length),
+ )
+
+ # broadcast causal attention mask & attention mask to fit for merge
+ causal_attention_mask = jnp.broadcast_to(
+ causal_attention_mask, (batch_size,) + causal_attention_mask.shape[1:]
+ )
+ attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_attention_mask.shape)
+ attention_mask = combine_masks(attention_mask, causal_attention_mask)
+
+ dropout_rng = None
+ if not deterministic and self.config.attention_dropout > 0.0:
+ dropout_rng = self.make_rng("dropout")
+
+ # During fast autoregressive decoding, we feed one position at a time,
+ # and cache the keys and values step by step.
+ if self.has_variable("cache", "cached_key") or init_cache:
+ key, value, attention_mask = self._concatenate_to_cache(key, value, query, attention_mask)
+
+ # transform boolean mask into float mask
+ mask_value = jnp.finfo(self.dtype).min
+ attention_bias = lax.select(
+ attention_mask > 0,
+ jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
+ jnp.full(attention_mask.shape, mask_value).astype(self.dtype),
+ )
+
+ attention_bias = attention_bias + alibi
+
+ # Cast in fp32 if the original dtype is different from fp32
+ attention_dtype = jnp.float32 if self.attention_softmax_in_fp32 else self.dtype
+
+ attn_weights = dot_product_attention_weights(
+ query,
+ key,
+ bias=attention_bias,
+ dropout_rng=dropout_rng,
+ dropout_rate=self.config.attention_dropout,
+ deterministic=deterministic,
+ dtype=attention_dtype,
+ )
+
+ # Cast back in the original dtype if the native dtype is not fp32
+ if self.attention_softmax_in_fp32:
+ attn_weights = attn_weights.astype(self.dtype)
+
+ attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value)
+ attn_output = self._merge_heads(attn_output)
+ attn_output = self.dense(attn_output)
+ attn_output = self.resid_dropout(attn_output, deterministic=deterministic)
+
+ attn_output = attn_output + residual
+
+ outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
+ return outputs
+
+
+class BloomGELU(nn.Module):
+ def setup(self):
+ self.dtype = jnp.float32
+
+ def __call__(self, x):
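+ # tanh approximation of GELU: 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3))),
+ # where 0.79788456 is approximately sqrt(2 / pi)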
+ return x * 0.5 * (1.0 + tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
+
+
+class FlaxBloomMLP(nn.Module):
+ config: BloomConfig
+ dtype: jnp.dtype = jnp.float32
+
+ def setup(self):
+ hidden_size = self.config.hidden_size
+
+ kernel_init = jax.nn.initializers.normal(self.config.initializer_range)
+
+ self.dense_h_to_4h = nn.Dense(4 * hidden_size, dtype=self.dtype, kernel_init=kernel_init)
+ self.dense_4h_to_h = nn.Dense(hidden_size, dtype=self.dtype, kernel_init=kernel_init)
+ self.hidden_dropout = nn.Dropout(self.config.hidden_dropout)
+ self.act = BloomGELU()
+
+ def __call__(self, hidden_states, residual, deterministic: bool = True):
+ hidden_states = self.dense_h_to_4h(hidden_states)
+ hidden_states = self.act(hidden_states)
+
+ intermediate_output = self.dense_4h_to_h(hidden_states)
+
+ intermediate_output = intermediate_output + residual
+ hidden_states = self.hidden_dropout(intermediate_output, deterministic=deterministic)
+
+ return hidden_states
+
+
+class FlaxBloomBlock(nn.Module):
+ config: BloomConfig
+ dtype: jnp.dtype = jnp.float32
+
+ def setup(self):
+ self.input_layernorm = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype)
+
+ self.self_attention = FlaxBloomAttention(self.config, dtype=self.dtype)
+ self.post_attention_layernorm = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype)
+
+ self.mlp = FlaxBloomMLP(self.config, dtype=self.dtype)
+
+ self.apply_residual_connection_post_layernorm = self.config.apply_residual_connection_post_layernorm
+ self.hidden_dropout = self.config.hidden_dropout
+
+ def __call__(
+ self,
+ hidden_states,
+ alibi,
+ attention_mask=None,
+ deterministic: bool = True,
+ init_cache: bool = False,
+ output_attentions: bool = False,
+ ):
+ layernorm_output = self.input_layernorm(hidden_states)
+
+ # layer norm before saving residual if config calls for it
+ if self.apply_residual_connection_post_layernorm:
+ residual = layernorm_output
+ else:
+ residual = hidden_states
+
+ # self-attention
+ attn_outputs = self.self_attention(
+ layernorm_output,
+ residual=residual,
+ alibi=alibi,
+ attention_mask=attention_mask,
+ deterministic=deterministic,
+ init_cache=init_cache,
+ output_attentions=output_attentions,
+ )
+
+ attention_output = attn_outputs[0]
+
+ outputs = attn_outputs[1:]
+
+ post_layernorm = self.post_attention_layernorm(attention_output)
+
+ # set residual based on config
+ if self.apply_residual_connection_post_layernorm:
+ residual = post_layernorm
+ else:
+ residual = attention_output
+
+ output = self.mlp(post_layernorm, residual, deterministic=deterministic)
+
+ outputs = (output,) + outputs
+
+ return outputs
+
+
+class FlaxBloomPreTrainedModel(FlaxPreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = BloomConfig
+ base_model_prefix = "transformer"
+ module_class: nn.Module = None
+
+ def __init__(
+ self,
+ config: BloomConfig,
+ input_shape: tuple = (1, 1),
+ seed: int = 0,
+ dtype: jnp.dtype = jnp.float32,
+ _do_init: bool = True,
+ **kwargs,
+ ):
+ module = self.module_class(config=config, dtype=dtype, **kwargs)
+ super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
+
+ def init_weights(self, rng: jax.random.PRNGKey, input_shape: tuple, params: FrozenDict = None) -> FrozenDict:
+ # init input tensors
+ input_ids = jnp.zeros(input_shape, dtype="i4")
+ attention_mask = jnp.ones_like(input_ids)
+ params_rng, dropout_rng = jax.random.split(rng)
+ rngs = {"params": params_rng, "dropout": dropout_rng}
+
+ random_params = self.module.init(rngs, input_ids, attention_mask, return_dict=False)["params"]
+
+ if params is not None:
+ random_params = flatten_dict(unfreeze(random_params))
+ params = flatten_dict(unfreeze(params))
+ for missing_key in self._missing_keys:
+ params[missing_key] = random_params[missing_key]
+ self._missing_keys = set()
+ return freeze(unflatten_dict(params))
+ else:
+ return random_params
+
+ def init_cache(self, batch_size, max_length):
+ r"""
+ Args:
+ batch_size (`int`):
+ batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
+ max_length (`int`):
+ maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
+ cache.
+ """
+ # init input variables to retrieve cache
+ input_ids = jnp.ones((batch_size, max_length), dtype="i4")
+ attention_mask = jnp.ones_like(input_ids)
+
+ init_variables = self.module.init(
+ jax.random.PRNGKey(0), input_ids, attention_mask, return_dict=False, init_cache=True
+ )
+ return unfreeze(init_variables["cache"])
+
+ @add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING)
+ def __call__(
+ self,
+ input_ids,
+ attention_mask=None,
+ past_key_values: Optional[dict] = None,
+ params: Optional[dict] = None,
+ dropout_rng: jax.random.PRNGKey = None,
+ train: bool = False,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ):
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ batch_size, sequence_length = input_ids.shape
+
+ if attention_mask is None:
+ attention_mask = jnp.ones((batch_size, sequence_length))
+
+ # Handle any PRNG if needed
+ rngs = {}
+ if dropout_rng is not None:
+ rngs["dropout"] = dropout_rng
+
+ inputs = {"params": params or self.params}
+
+ # If past_key_values are passed, the cache is already initialized and a private flag `init_cache` has to be passed
+ # down to ensure the cache is used. The cache must also be marked as mutable so that it can be changed by the
+ # FlaxBloomAttention module.
+ if past_key_values:
+ inputs["cache"] = past_key_values
+ mutable = ["cache"]
+ else:
+ mutable = False
+
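+ # the positional arguments below follow `FlaxBloomModule.__call__`: `deterministic=not train` and `init_cache=False`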
+ outputs = self.module.apply(
+ inputs,
+ jnp.array(input_ids, dtype="i4"),
+ jnp.array(attention_mask, dtype="i4"),
+ not train,
+ False,
+ output_attentions,
+ output_hidden_states,
+ return_dict,
+ rngs=rngs,
+ mutable=mutable,
+ )
+
+ # add updated cache to model output
+ if past_key_values is not None and return_dict:
+ outputs, past_key_values = outputs
+ outputs["past_key_values"] = unfreeze(past_key_values["cache"])
+ return outputs
+ elif past_key_values is not None and not return_dict:
+ outputs, past_key_values = outputs
+ outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:]
+
+ return outputs
+
+
+class FlaxBloomBlockCollection(nn.Module):
+ config: BloomConfig
+ dtype: jnp.dtype = jnp.float32
+
+ def setup(self):
+ self.layers = [
+ FlaxBloomBlock(self.config, name=str(layer_number), dtype=self.dtype)
+ for layer_number in range(self.config.num_hidden_layers)
+ ]
+
+ def __call__(
+ self,
+ hidden_states,
+ alibi,
+ attention_mask=None,
+ deterministic: bool = True,
+ init_cache: bool = False,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ ):
+ all_attentions = () if output_attentions else None
+ all_hidden_states = () if output_hidden_states else None
+
+ for layer_number in range(self.config.num_hidden_layers):
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ layer_outputs = self.layers[layer_number](
+ hidden_states,
+ alibi=alibi,
+ attention_mask=attention_mask,
+ deterministic=deterministic,
+ init_cache=init_cache,
+ output_attentions=output_attentions,
+ )
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_attentions += (layer_outputs[1],)
+
+ # this contains possible `None` values - `FlaxBloomModule` will filter them out
+ outputs = (hidden_states, all_hidden_states, all_attentions)
+
+ return outputs
+
+
+class FlaxBloomModule(nn.Module):
+ config: BloomConfig
+ dtype: jnp.dtype = jnp.float32
+
+ def setup(self):
+ self.embed_dim = self.config.hidden_size
+
+ # word embeddings (no positional embedding layer)
+ self.word_embeddings = nn.Embed(
+ self.config.vocab_size,
+ self.embed_dim,
+ embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
+ dtype=self.dtype,
+ )
+
+ # post-embedding layernorm
+ self.word_embeddings_layernorm = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype)
+
+ # transformer layers
+ self.h = FlaxBloomBlockCollection(self.config, dtype=self.dtype)
+
+ # final layernorm
+ self.ln_f = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype)
+
+ def __call__(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ deterministic=True,
+ init_cache: bool = False,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ return_dict: bool = True,
+ ):
+ inputs_embeds = self.word_embeddings(input_ids)
+ # do post-embedding layernorm
+ hidden_states = self.word_embeddings_layernorm(inputs_embeds)
+
+ # build alibi depending on `attention_mask`
+ alibi = build_alibi_tensor(attention_mask, self.config.n_head, dtype=hidden_states.dtype)
+
+ outputs = self.h(
+ hidden_states,
+ alibi=alibi,
+ attention_mask=attention_mask,
+ deterministic=deterministic,
+ init_cache=init_cache,
+ output_hidden_states=output_hidden_states,
+ output_attentions=output_attentions,
+ )
+
+ hidden_states = outputs[0]
+ hidden_states = self.ln_f(hidden_states)
+
+ if output_hidden_states:
+ all_hidden_states = outputs[1] + (hidden_states,)
+ outputs = (hidden_states, all_hidden_states) + outputs[2:]
+ else:
+ outputs = (hidden_states,) + outputs[1:]
+
+ if not return_dict:
+ return tuple(v for v in [outputs[0], outputs[-1]] if v is not None)
+
+ return FlaxBaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ hidden_states=outputs[1],
+ attentions=outputs[-1],
+ )
+
+
+@add_start_docstrings(
+ "The bare Bloom Model transformer outputting raw hidden-states without any specific head on top.",
+ BLOOM_START_DOCSTRING,
+)
+# Copied from transformers.models.gpt_neo.modeling_flax_gpt_neo.FlaxGPTNeoModel with GPTNeo->Bloom
+class FlaxBloomModel(FlaxBloomPreTrainedModel):
+ module_class = FlaxBloomModule
+
+
+append_call_sample_docstring(FlaxBloomModel, _CHECKPOINT_FOR_DOC, FlaxBaseModelOutput, _CONFIG_FOR_DOC)
+
+
+class FlaxBloomForCausalLMModule(nn.Module):
+ config: BloomConfig
+ dtype: jnp.dtype = jnp.float32
+
+ def setup(self):
+ self.transformer = FlaxBloomModule(self.config, dtype=self.dtype)
+ self.lm_head = nn.Dense(
+ self.config.vocab_size,
+ use_bias=False,
+ dtype=self.dtype,
+ kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
+ )
+
+ def __call__(
+ self,
+ input_ids,
+ attention_mask,
+ deterministic: bool = True,
+ init_cache: bool = False,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ return_dict: bool = True,
+ ):
+ outputs = self.transformer(
+ input_ids,
+ attention_mask=attention_mask,
+ deterministic=deterministic,
+ init_cache=init_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = outputs[0]
+
+ if self.config.tie_word_embeddings:
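+ # when embeddings are tied, reuse the transposed word-embedding matrix as the LM head kernel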
+ shared_kernel = self.transformer.variables["params"]["word_embeddings"]["embedding"].T
+ lm_logits = self.lm_head.apply({"params": {"kernel": shared_kernel}}, hidden_states)
+ else:
+ lm_logits = self.lm_head(hidden_states)
+
+ if not return_dict:
+ return (lm_logits,) + outputs[1:]
+
+ return FlaxCausalLMOutput(logits=lm_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
+
+
+@add_start_docstrings(
+ """
+ The Bloom Model transformer with a language modeling head on top (linear layer with weights tied to the input
+ embeddings).
+ """,
+ BLOOM_START_DOCSTRING,
+)
+class FlaxBloomForCausalLM(FlaxBloomPreTrainedModel):
+ module_class = FlaxBloomForCausalLMModule
+
+ def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
+ # initializing the cache
+ batch_size, seq_length = input_ids.shape
+
+ past_key_values = self.init_cache(batch_size, max_length)
+ # Note that usually one would have to put 0's in the attention_mask for
+ # x > input_ids.shape[-1] and x < cache_length. But since Bloom uses a causal mask,
+ # those positions are masked anyway. Thus, we can create a single static attention_mask here,
+ # which is more efficient for compilation
+ extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
+ if attention_mask is not None:
+ extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))
+
+ return {
+ "past_key_values": past_key_values,
+ "attention_mask": extended_attention_mask,
+ }
+
+ def update_inputs_for_generation(self, model_outputs, model_kwargs):
+ model_kwargs["past_key_values"] = model_outputs.past_key_values
+ return model_kwargs
+
+
+append_call_sample_docstring(FlaxBloomForCausalLM, _CHECKPOINT_FOR_DOC, FlaxCausalLMOutput, _CONFIG_FOR_DOC)
+
+
+__all__ = ["FlaxBloomForCausalLM", "FlaxBloomModel", "FlaxBloomPreTrainedModel"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/bloom/tokenization_bloom_fast.py b/venv/lib/python3.13/site-packages/transformers/models/bloom/tokenization_bloom_fast.py
new file mode 100644
index 0000000000000000000000000000000000000000..b7a9f7449a4e9a5c2306ba4853ff5c9d99c00a20
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/bloom/tokenization_bloom_fast.py
@@ -0,0 +1,152 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenization classes for Bloom."""
+
+import pickle
+from typing import Optional
+
+from ...tokenization_utils_base import BatchEncoding
+from ...tokenization_utils_fast import PreTrainedTokenizerFast
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+VOCAB_FILES_NAMES = {"tokenizer_file": "tokenizer.json"}
+
+
+class BloomTokenizerFast(PreTrainedTokenizerFast):
+ """
+ Construct a "fast" Bloom tokenizer (backed by HuggingFace's *tokenizers* library). Based on byte-level
+ Byte-Pair-Encoding.
+
+ This tokenizer has been trained to treat spaces like parts of the tokens (a bit like sentencepiece), so a word will
+ be encoded differently depending on whether it is at the beginning of the sentence (without a space) or not:
+
+ ```python
+ >>> from transformers import BloomTokenizerFast
+
+ >>> tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom")
+ >>> tokenizer("Hello world")["input_ids"]
+ [59414, 8876]
+
+ >>> tokenizer(" Hello world")["input_ids"]
+ [86153, 8876]
+ ```
+
+ You can get around that behavior by passing `add_prefix_space=True` when instantiating this tokenizer, but since
+ the model was not pretrained this way, it might yield a decrease in performance.
+
+
+
+ When used with `is_split_into_words=True`, this tokenizer needs to be instantiated with `add_prefix_space=True`.
+
+
+
+ This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
+ refer to this superclass for more information regarding those methods.
+
+ Args:
+ vocab_file (`str`):
+ Path to the vocabulary file.
+ merges_file (`str`):
+ Path to the merges file.
+ errors (`str`, *optional*, defaults to `"replace"`):
+ Paradigm to follow when decoding bytes to UTF-8. See
+ [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.
+ unk_token (`str`, *optional*, defaults to `<unk>`):
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+ token instead.
+ bos_token (`str`, *optional*, defaults to `<s>`):
+ The beginning of sequence token.
+ eos_token (`str`, *optional*, defaults to `</s>`):
+ The end of sequence token.
+ add_prefix_space (`bool`, *optional*, defaults to `False`):
+ Whether or not to add an initial space to the input. This allows the leading word to be treated just like any
+ other word. (The Bloom tokenizer detects beginnings of words by the preceding space.)
+ trim_offsets (`bool`, *optional*, defaults to `True`):
+ Whether or not the post-processing step should trim offsets to avoid including whitespaces.
+ """
+
+ vocab_files_names = VOCAB_FILES_NAMES
+ model_input_names = ["input_ids", "attention_mask"]
+ slow_tokenizer_class = None
+ # No `max_model_input_sizes` as BLOOM uses ALiBi positional embeddings
+
+ def __init__(
+ self,
+ vocab_file=None,
+ merges_file=None,
+ tokenizer_file=None,
+ unk_token="",
+ bos_token="",
+ eos_token="",
+ pad_token="",
+ add_prefix_space=False,
+ clean_up_tokenization_spaces=False,
+ **kwargs,
+ ):
+ super().__init__(
+ vocab_file=vocab_file,
+ merges_file=merges_file,
+ tokenizer_file=tokenizer_file,
+ unk_token=unk_token,
+ bos_token=bos_token,
+ eos_token=eos_token,
+ pad_token=pad_token,
+ add_prefix_space=add_prefix_space,
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+ **kwargs,
+ )
+ # TODO @ArthurZucker this can only work one way for now, to update later-on. Tests should also properly
+ # check this as they were green before.
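+ # the pickled state of the backend pre-tokenizer/decoder embeds their serialized JSON, so the byte-level
+ # replace below flips `add_prefix_space` in place before the objects are re-loaded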
+ pre_tok_state = pickle.dumps(self.backend_tokenizer.pre_tokenizer)
+ decoder_state = pickle.dumps(self.backend_tokenizer.decoder)
+
+ if add_prefix_space:
+ pre_tok_state = pre_tok_state.replace(b'"add_prefix_space":false', b'"add_prefix_space": true')
+ decoder_state = decoder_state.replace(b'"add_prefix_space":false', b'"add_prefix_space": true')
+ self.backend_tokenizer.pre_tokenizer = pickle.loads(pre_tok_state)
+ self.backend_tokenizer.decoder = pickle.loads(decoder_state)
+
+ self.add_prefix_space = add_prefix_space
+
+ def _batch_encode_plus(self, *args, **kwargs) -> BatchEncoding:
+ is_split_into_words = kwargs.get("is_split_into_words", False)
+ if not (self.add_prefix_space or not is_split_into_words):
+ raise Exception(
+ f"You need to instantiate {self.__class__.__name__} with add_prefix_space=True to use it with"
+ " pretokenized inputs."
+ )
+
+ return super()._batch_encode_plus(*args, **kwargs)
+
+ def _encode_plus(self, *args, **kwargs) -> BatchEncoding:
+ is_split_into_words = kwargs.get("is_split_into_words", False)
+
+ if not (self.add_prefix_space or not is_split_into_words):
+ raise Exception(
+ f"You need to instantiate {self.__class__.__name__} with add_prefix_space=True to use it with"
+ " pretokenized inputs."
+ )
+
+ return super()._encode_plus(*args, **kwargs)
+
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
+ files = self._tokenizer.model.save(save_directory, name=filename_prefix)
+ return tuple(files)
+
+
+__all__ = ["BloomTokenizerFast"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/bridgetower/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/bridgetower/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ca84a320fdc4aa1f4cf99aef78154849514c8a6
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/bridgetower/__init__.py
@@ -0,0 +1,30 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_bridgetower import *
+ from .image_processing_bridgetower import *
+ from .image_processing_bridgetower_fast import *
+ from .modeling_bridgetower import *
+ from .processing_bridgetower import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/venv/lib/python3.13/site-packages/transformers/models/bridgetower/configuration_bridgetower.py b/venv/lib/python3.13/site-packages/transformers/models/bridgetower/configuration_bridgetower.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c84b0a294dafb35c991654c6e7df79c3fe58452
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/bridgetower/configuration_bridgetower.py
@@ -0,0 +1,308 @@
+# coding=utf-8
+# Copyright 2023 The Intel Labs Team Authors, The Microsoft Research Team Authors and HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""BridgeTower model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class BridgeTowerVisionConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the vision configuration of a [`BridgeTowerModel`]. Instantiating a
+ configuration with the defaults will yield a similar configuration to that of the bridgetower-base
+ [BridgeTower/bridgetower-base](https://huggingface.co/BridgeTower/bridgetower-base/) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ hidden_size (`int`, *optional*, defaults to 768):
+ Dimensionality of the encoder layers and the pooler layer.
+ num_hidden_layers (`int`, *optional*, defaults to 12):
+ Number of hidden layers in the visual encoder model.
+ patch_size (`int`, *optional*, defaults to 16):
+ The size (resolution) of each patch.
+ image_size (`int`, *optional*, defaults to 288):
+ The size (resolution) of each image.
+ initializer_factor (`float`, *optional*, defaults to 1):
+ A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
+ testing).
+ layer_norm_eps (`float`, *optional*, defaults to 1e-05):
+ The epsilon used by the layer normalization layers.
+ stop_gradient (`bool`, *optional*, defaults to `False`):
+ Whether to stop gradient for training.
+ share_layernorm (`bool`, *optional*, defaults to `True`):
+ Whether LayerNorm layers are shared.
+ remove_last_layer (`bool`, *optional*, defaults to `False`):
+ Whether to remove the last layer from the vision encoder.
+
+
+ Example:
+
+ ```python
+ >>> from transformers import BridgeTowerVisionConfig
+
+ >>> # Initializing a BridgeTower BridgeTower/bridgetower-base style configuration for the vision model
+ >>> configuration = BridgeTowerVisionConfig()
+
+ >>> # Accessing the configuration
+ >>> configuration
+ ```"""
+
+ model_type = "bridgetower_vision_model"
+ base_config_key = "vision_config"
+
+ def __init__(
+ self,
+ hidden_size=768,
+ num_hidden_layers=12,
+ num_channels=3,
+ patch_size=16,
+ image_size=288,
+ initializer_factor=1,
+ layer_norm_eps=1e-05,
+ stop_gradient=False,
+ share_layernorm=True,
+ remove_last_layer=False,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_channels = num_channels
+ self.patch_size = patch_size
+ self.image_size = image_size
+ self.initializer_factor = initializer_factor
+ self.layer_norm_eps = layer_norm_eps
+ self.stop_gradient = stop_gradient
+ self.share_layernorm = share_layernorm
+ self.remove_last_layer = remove_last_layer
+
+
+class BridgeTowerTextConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the text configuration of a [`BridgeTowerModel`]. The default values here
+ are copied from RoBERTa. Instantiating a configuration with the defaults will yield a similar configuration to that
+ of the bridgetower-base [BridgeTower/bridgetower-base](https://huggingface.co/BridgeTower/bridgetower-base/)
+ architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 50265):
+ Vocabulary size of the text part of the model. Defines the number of different tokens that can be
+ represented by the `inputs_ids` passed when calling [`BridgeTowerModel`].
+ hidden_size (`int`, *optional*, defaults to 768):
+ Dimensionality of the encoder layers and the pooler layer.
+ num_hidden_layers (`int`, *optional*, defaults to 12):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 12):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ intermediate_size (`int`, *optional*, defaults to 3072):
+ Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
+ hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
+ The dropout ratio for the attention probabilities.
+ max_position_embeddings (`int`, *optional*, defaults to 514):
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
+ just in case (e.g., 512 or 1024 or 2048).
+ type_vocab_size (`int`, *optional*, defaults to 1):
+ The vocabulary size of the `token_type_ids`.
+ initializer_factor (`float`, *optional*, defaults to 1):
+ A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
+ testing).
+ layer_norm_eps (`float`, *optional*, defaults to 1e-05):
+ The epsilon used by the layer normalization layers.
+ position_embedding_type (`str`, *optional*, defaults to `"absolute"`):
+ Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For
+ positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to
+ [Self-Attention with Relative Position Representations (Shaw et al.)](https://huggingface.co/papers/1803.02155).
+ For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models
+ with Better Relative Position Embeddings (Huang et al.)](https://huggingface.co/papers/2009.13658).
+ is_decoder (`bool`, *optional*, defaults to `False`):
+ Whether the model is used as a decoder or not. If `False`, the model is used as an encoder.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+
+ Example:
+
+ ```python
+ >>> from transformers import BridgeTowerTextConfig
+
+ >>> # Initializing a BridgeTower BridgeTower/bridgetower-base style configuration for the text model
+ >>> configuration = BridgeTowerTextConfig()
+
+ >>> # Accessing the configuration
+ >>> configuration
+ ```"""
+
+ model_type = "bridgetower_text_model"
+ base_config_key = "text_config"
+
+ def __init__(
+ self,
+ vocab_size=50265,
+ hidden_size=768,
+ num_hidden_layers=12,
+ num_attention_heads=12,
+ initializer_factor=1,
+ intermediate_size=3072,
+ hidden_act="gelu",
+ hidden_dropout_prob=0.1,
+ attention_probs_dropout_prob=0.1,
+ max_position_embeddings=514,
+ type_vocab_size=1,
+ layer_norm_eps=1e-05,
+ pad_token_id=1,
+ bos_token_id=0,
+ eos_token_id=2,
+ position_embedding_type="absolute",
+ use_cache=True,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.hidden_act = hidden_act
+ self.initializer_factor = initializer_factor
+ self.intermediate_size = intermediate_size
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.max_position_embeddings = max_position_embeddings
+ self.type_vocab_size = type_vocab_size
+ self.layer_norm_eps = layer_norm_eps
+ self.position_embedding_type = position_embedding_type
+ self.use_cache = use_cache
+ self.pad_token_id = pad_token_id
+ self.bos_token_id = bos_token_id
+ self.eos_token_id = eos_token_id
+
+
+class BridgeTowerConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`BridgeTowerModel`]. It is used to instantiate a
+ BridgeTower model according to the specified arguments, defining the model architecture. Instantiating a
+ configuration with the defaults will yield a similar configuration to that of the bridgetower-base
+ [BridgeTower/bridgetower-base](https://huggingface.co/BridgeTower/bridgetower-base/) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ share_cross_modal_transformer_layers (`bool`, *optional*, defaults to `True`):
+ Whether cross modal transformer layers are shared.
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the encoder and pooler.
+ hidden_size (`int`, *optional*, defaults to 768):
+ Dimensionality of the encoder layers and the pooler layer.
+ initializer_factor (`float`, *optional*, defaults to 1):
+ A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
+ testing).
+ layer_norm_eps (`float`, *optional*, defaults to 1e-05):
+ The epsilon used by the layer normalization layers.
+ share_link_tower_layers (`bool`, *optional*, defaults to `False`):
+ Whether the bridge/link tower layers are shared.
+ link_tower_type (`str`, *optional*, defaults to `"add"`):
+ Type of the bridge/link layer.
+ num_attention_heads (`int`, *optional*, defaults to 12):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ num_hidden_layers (`int`, *optional*, defaults to 6):
+ Number of hidden layers in the Transformer encoder.
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether to tie input and output embeddings.
+ init_layernorm_from_vision_encoder (`bool`, *optional*, defaults to `False`):
+ Whether to init LayerNorm from the vision encoder.
+ text_config (`dict`, *optional*):
+ Dictionary of configuration options used to initialize [`BridgeTowerTextConfig`].
+ vision_config (`dict`, *optional*):
+ Dictionary of configuration options used to initialize [`BridgeTowerVisionConfig`].
+
+ Example:
+
+ ```python
+ >>> from transformers import BridgeTowerModel, BridgeTowerConfig
+
+ >>> # Initializing a BridgeTower BridgeTower/bridgetower-base style configuration
+ >>> configuration = BridgeTowerConfig()
+
+ >>> # Initializing a model from the BridgeTower/bridgetower-base style configuration
+ >>> model = BridgeTowerModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
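+
+ >>> # Sub-configurations can also be passed as plain dicts (illustrative values shown)
+ >>> configuration = BridgeTowerConfig(text_config={"vocab_size": 50265}, vision_config={"image_size": 288})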
+ ```"""
+
+ model_type = "bridgetower"
+ sub_configs = {"text_config": BridgeTowerTextConfig, "vision_config": BridgeTowerVisionConfig}
+
+ def __init__(
+ self,
+ share_cross_modal_transformer_layers=True,
+ hidden_act="gelu",
+ hidden_size=768,
+ initializer_factor=1,
+ layer_norm_eps=1e-05,
+ share_link_tower_layers=False,
+ link_tower_type="add",
+ num_attention_heads=12,
+ num_hidden_layers=6,
+ tie_word_embeddings=False,
+ init_layernorm_from_vision_encoder=False,
+ text_config=None,
+ vision_config=None,
+ **kwargs,
+ ):
+ # TODO: remove this once the Hub files are updated.
+ _ = kwargs.pop("text_config_dict", None)
+ _ = kwargs.pop("vision_config_dict", None)
+
+ super().__init__(**kwargs)
+ self.share_cross_modal_transformer_layers = share_cross_modal_transformer_layers
+ self.hidden_act = hidden_act
+ self.hidden_size = hidden_size
+ self.initializer_factor = initializer_factor
+ self.layer_norm_eps = layer_norm_eps
+ self.share_link_tower_layers = share_link_tower_layers
+ self.link_tower_type = link_tower_type
+ self.num_attention_heads = num_attention_heads
+ self.num_hidden_layers = num_hidden_layers
+ self.tie_word_embeddings = tie_word_embeddings
+ self.init_layernorm_from_vision_encoder = init_layernorm_from_vision_encoder
+
+ if text_config is None:
+ text_config = {}
+ logger.info("`text_config` is `None`. Initializing the `BridgeTowerTextConfig` with default values.")
+
+ if vision_config is None:
+ vision_config = {}
+ logger.info("`vision_config` is `None`. Initializing the `BridgeTowerVisionConfig` with default values.")
+
+ self.text_config = BridgeTowerTextConfig(**text_config)
+ self.vision_config = BridgeTowerVisionConfig(**vision_config)
+
+
+__all__ = ["BridgeTowerConfig", "BridgeTowerTextConfig", "BridgeTowerVisionConfig"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/bridgetower/image_processing_bridgetower.py b/venv/lib/python3.13/site-packages/transformers/models/bridgetower/image_processing_bridgetower.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb39ed0975610f65f2b44725ffbbbb43a0c18f3f
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/bridgetower/image_processing_bridgetower.py
@@ -0,0 +1,541 @@
+# coding=utf-8
+# Copyright 2023 The Intel Labs Team Authors, The Microsoft Research Team Authors and HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Image processor class for BridgeTower."""
+
+from collections.abc import Iterable
+from typing import Any, Optional, Union
+
+import numpy as np
+
+from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
+from ...image_transforms import PaddingMode, center_crop, pad, resize, to_channel_dimension_format
+from ...image_utils import (
+ OPENAI_CLIP_MEAN,
+ OPENAI_CLIP_STD,
+ ChannelDimension,
+ ImageInput,
+ PILImageResampling,
+ get_image_size,
+ infer_channel_dimension_format,
+ is_scaled_image,
+ make_flat_list_of_images,
+ to_numpy_array,
+ valid_images,
+ validate_preprocess_arguments,
+)
+from ...utils import TensorType, filter_out_non_signature_kwargs, is_vision_available, logging
+
+
+if is_vision_available():
+ import PIL
+
+logger = logging.get_logger(__name__)
+
+
+# Copied from transformers.models.vilt.image_processing_vilt.max_across_indices
+def max_across_indices(values: Iterable[Any]) -> list[Any]:
+ """
+ Return the maximum value across all indices of an iterable of values.
+ """
+ return [max(values_i) for values_i in zip(*values)]
+
+
+# Copied from transformers.models.vilt.image_processing_vilt.make_pixel_mask
+def make_pixel_mask(
+ image: np.ndarray, output_size: tuple[int, int], input_data_format: Optional[Union[str, ChannelDimension]] = None
+) -> np.ndarray:
+ """
+ Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding.
+
+ Args:
+ image (`np.ndarray`):
+ Image to make the pixel mask for.
+ output_size (`tuple[int, int]`):
+ Output size of the mask.
+ """
+ input_height, input_width = get_image_size(image, channel_dim=input_data_format)
+ mask = np.zeros(output_size, dtype=np.int64)
+ mask[:input_height, :input_width] = 1
+ return mask
+
+
+# Copied from transformers.models.vilt.image_processing_vilt.get_max_height_width
+def get_max_height_width(
+ images: list[np.ndarray], input_data_format: Optional[Union[str, ChannelDimension]] = None
+) -> list[int]:
+ """
+ Get the maximum height and width across all images in a batch.
+ """
+ if input_data_format is None:
+ input_data_format = infer_channel_dimension_format(images[0])
+
+ if input_data_format == ChannelDimension.FIRST:
+ _, max_height, max_width = max_across_indices([img.shape for img in images])
+ elif input_data_format == ChannelDimension.LAST:
+ max_height, max_width, _ = max_across_indices([img.shape for img in images])
+ else:
+ raise ValueError(f"Invalid channel dimension format: {input_data_format}")
+ return (max_height, max_width)
+
+
+# Copied from transformers.models.vilt.image_processing_vilt.get_resize_output_image_size
+def get_resize_output_image_size(
+ input_image: np.ndarray,
+ shorter: int = 800,
+ longer: int = 1333,
+ size_divisor: int = 32,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+) -> tuple[int, int]:
+ input_height, input_width = get_image_size(input_image, input_data_format)
+ min_size, max_size = shorter, longer
+
+ scale = min_size / min(input_height, input_width)
+
+ if input_height < input_width:
+ new_height = min_size
+ new_width = scale * input_width
+ else:
+ new_height = scale * input_height
+ new_width = min_size
+
+ if max(new_height, new_width) > max_size:
+ scale = max_size / max(new_height, new_width)
+ new_height = scale * new_height
+ new_width = scale * new_width
+
+ new_height, new_width = int(new_height + 0.5), int(new_width + 0.5)
+ new_height = new_height // size_divisor * size_divisor
+ new_width = new_width // size_divisor * size_divisor
+
+ return new_height, new_width
+
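As a worked example of the sizing rule above (an editorial sketch; the deep module import path is an assumption about a local install with vision extras): with `shortest_edge=288` the longer side is capped at `int(1333 / 800 * 288)` and both sides are floored to multiples of `size_divisor`.

```python
import numpy as np

from transformers.models.bridgetower.image_processing_bridgetower import get_resize_output_image_size

image = np.zeros((480, 640, 3), dtype=np.uint8)  # height 480, width 640, channels last
print(get_resize_output_image_size(image, shorter=288, longer=int(1333 / 800 * 288), size_divisor=32))
# (288, 384): scale = 288 / 480, then each side floored to a multiple of 32
```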
+
+class BridgeTowerImageProcessor(BaseImageProcessor):
+ r"""
+ Constructs a BridgeTower image processor.
+
+ Args:
+ do_resize (`bool`, *optional*, defaults to `True`):
+ Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the
+ `do_resize` parameter in the `preprocess` method.
+ size (`dict[str, int]`, *optional*, defaults to `{'shortest_edge': 288}`):
+ Resize the shorter side of the input to `size["shortest_edge"]`. The longer side will be limited to under
+ `int((1333 / 800) * size["shortest_edge"])` while preserving the aspect ratio. Only has an effect if
+ `do_resize` is set to `True`. Can be overridden by the `size` parameter in the `preprocess` method.
+ size_divisor (`int`, *optional*, defaults to 32):
+ The size by which to make sure both the height and width can be divided. Only has an effect if `do_resize`
+ is set to `True`. Can be overridden by the `size_divisor` parameter in the `preprocess` method.
+ resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
+ Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`. Can be
+ overridden by the `resample` parameter in the `preprocess` method.
+ do_rescale (`bool`, *optional*, defaults to `True`):
+ Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale`
+ parameter in the `preprocess` method.
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
+ Scale factor to use if rescaling the image. Only has an effect if `do_rescale` is set to `True`. Can be
+ overridden by the `rescale_factor` parameter in the `preprocess` method.
+ do_normalize (`bool`, *optional*, defaults to `True`):
+ Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
+ method.
+ image_mean (`float` or `list[float]`, *optional*, defaults to `OPENAI_CLIP_MEAN`):
+ Mean to use if normalizing the image. This is a float or list of floats the length of the number of
+ channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
+ image_std (`float` or `list[float]`, *optional*, defaults to `OPENAI_CLIP_STD`):
+ Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
+ number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
+ do_center_crop (`bool`, *optional*, defaults to `True`):
+ Whether to center crop the image. Can be overridden by the `do_center_crop` parameter in the `preprocess`
+ method.
+ crop_size (`dict[str, int]`, *optional*):
+ Desired output size when applying center-cropping. Only has an effect if `do_center_crop` is set to `True`.
+ Can be overridden by the `crop_size` parameter in the `preprocess` method. If unset, defaults to `size`.
+ do_pad (`bool`, *optional*, defaults to `True`):
+ Whether to pad the image to the `(max_height, max_width)` of the images in the batch. Can be overridden by
+ the `do_pad` parameter in the `preprocess` method.
+ """
+
+ model_input_names = ["pixel_values", "pixel_mask"]
+
+ def __init__(
+ self,
+ do_resize: bool = True,
+ size: Optional[dict[str, int]] = None,
+ size_divisor: int = 32,
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
+ do_rescale: bool = True,
+ rescale_factor: Union[int, float] = 1 / 255,
+ do_normalize: bool = True,
+ image_mean: Optional[Union[float, list[float]]] = None,
+ image_std: Optional[Union[float, list[float]]] = None,
+ do_center_crop: bool = True,
+ crop_size: Optional[dict[str, int]] = None,
+ do_pad: bool = True,
+ **kwargs,
+ ) -> None:
+ if "pad_and_return_pixel_mask" in kwargs:
+ do_pad = kwargs.pop("pad_and_return_pixel_mask")
+
+ super().__init__(**kwargs)
+ size = size if size is not None else {"shortest_edge": 288}
+ size = get_size_dict(size, default_to_square=False)
+
+ self.do_resize = do_resize
+ self.size = size
+ self.size_divisor = size_divisor
+ self.resample = resample
+ self.do_rescale = do_rescale
+ self.rescale_factor = rescale_factor
+ self.do_normalize = do_normalize
+ self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
+ self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
+ self.do_pad = do_pad
+ self.do_center_crop = do_center_crop
+ self.crop_size = crop_size
+
+ # Copied from transformers.models.vilt.image_processing_vilt.ViltImageProcessor.resize
+ def resize(
+ self,
+ image: np.ndarray,
+ size: dict[str, int],
+ size_divisor: int = 32,
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs,
+ ) -> np.ndarray:
+ """
+ Resize an image.
+
+ Resizes the shorter side of the image to `size["shortest_edge"]` while preserving the aspect ratio. If the
+ longer side is larger than the max size `int(size["shortest_edge"] * 1333 / 800)`, the longer side is then
+ resized to the max size while preserving the aspect ratio.
+
+ Args:
+ image (`np.ndarray`):
+ Image to resize.
+ size (`dict[str, int]`):
+ Controls the size of the output image. Should be of the form `{"shortest_edge": int}`.
+ size_divisor (`int`, *optional*, defaults to 32):
+ The image is resized to a size that is a multiple of this value.
+ resample (`PILImageResampling` filter, *optional*, defaults to `PILImageResampling.BICUBIC`):
+ Resampling filter to use when resizing the image.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
+ input_data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format of the input image. If not provided, it will be inferred.
+ """
+ size = get_size_dict(size, default_to_square=False)
+ if "shortest_edge" not in size:
+ raise ValueError(f"The `size` dictionary must contain the key `shortest_edge`. Got {size.keys()}")
+ shorter = size["shortest_edge"]
+ longer = int(1333 / 800 * shorter)
+ output_size = get_resize_output_image_size(
+ image, shorter=shorter, longer=longer, size_divisor=size_divisor, input_data_format=input_data_format
+ )
+ return resize(
+ image,
+ size=output_size,
+ resample=resample,
+ data_format=data_format,
+ input_data_format=input_data_format,
+ **kwargs,
+ )
+
+ def center_crop(
+ self,
+ image: np.ndarray,
+ size: dict[str, int],
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs,
+ ) -> np.ndarray:
+ """
+ Center crop an image to `(size["height"], size["width"])`. If the input size is smaller than `crop_size` along
+ any edge, the image is padded with 0's and then center cropped.
+
+ Args:
+ image (`np.ndarray`):
+ Image to center crop.
+ size (`dict[str, int]`):
+ Size of the output image. Must contain the key `"shortest_edge"`; the image is cropped to a square of
+ that side length.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format of the input image. If not provided, it will be inferred from the input
+ image.
+ """
+ output_size = size["shortest_edge"]
+ return center_crop(
+ image,
+ size=(output_size, output_size),
+ data_format=data_format,
+ input_data_format=input_data_format,
+ **kwargs,
+ )
+
+ # Copied from transformers.models.vilt.image_processing_vilt.ViltImageProcessor._pad_image
+ def _pad_image(
+ self,
+ image: np.ndarray,
+ output_size: tuple[int, int],
+ constant_values: Union[float, Iterable[float]] = 0,
+ data_format: Optional[ChannelDimension] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ ) -> np.ndarray:
+ """
+ Pad an image with zeros to the given size.
+ """
+ input_height, input_width = get_image_size(image, channel_dim=input_data_format)
+ output_height, output_width = output_size
+
+ pad_bottom = output_height - input_height
+ pad_right = output_width - input_width
+ padding = ((0, pad_bottom), (0, pad_right))
+ padded_image = pad(
+ image,
+ padding,
+ mode=PaddingMode.CONSTANT,
+ constant_values=constant_values,
+ data_format=data_format,
+ input_data_format=input_data_format,
+ )
+ return padded_image
+
+ # Copied from transformers.models.vilt.image_processing_vilt.ViltImageProcessor.pad
+ def pad(
+ self,
+ images: list[np.ndarray],
+ constant_values: Union[float, Iterable[float]] = 0,
+ return_pixel_mask: bool = True,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ data_format: Optional[ChannelDimension] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ ) -> BatchFeature:
+ """
+ Pads a batch of images to the bottom and right of the image with zeros to the size of largest height and width
+ in the batch and optionally returns their corresponding pixel mask.
+
+ Args:
+ images (`list[np.ndarray]`):
+ Batch of images to pad.
+ constant_values (`float` or `Iterable[float]`, *optional*):
+ The value to use for the padding if `mode` is `"constant"`.
+ return_pixel_mask (`bool`, *optional*, defaults to `True`):
+ Whether to return a pixel mask.
+ return_tensors (`str` or `TensorType`, *optional*):
+ The type of tensors to return. Can be one of:
+ - Unset: Return a list of `np.ndarray`.
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format of the input image. If not provided, it will be inferred.
+ """
+ pad_size = get_max_height_width(images, input_data_format=input_data_format)
+
+ padded_images = [
+ self._pad_image(
+ image,
+ pad_size,
+ constant_values=constant_values,
+ data_format=data_format,
+ input_data_format=input_data_format,
+ )
+ for image in images
+ ]
+ data = {"pixel_values": padded_images}
+
+ if return_pixel_mask:
+ masks = [
+ make_pixel_mask(image=image, output_size=pad_size, input_data_format=input_data_format)
+ for image in images
+ ]
+ data["pixel_mask"] = masks
+
+ return BatchFeature(data=data, tensor_type=return_tensors)
+
+ @filter_out_non_signature_kwargs()
+ def preprocess(
+ self,
+ images: ImageInput,
+ do_resize: Optional[bool] = None,
+ size: Optional[dict[str, int]] = None,
+ size_divisor: Optional[int] = None,
+ resample: Optional[PILImageResampling] = None,
+ do_rescale: Optional[bool] = None,
+ rescale_factor: Optional[float] = None,
+ do_normalize: Optional[bool] = None,
+ image_mean: Optional[Union[float, list[float]]] = None,
+ image_std: Optional[Union[float, list[float]]] = None,
+ do_pad: Optional[bool] = None,
+ do_center_crop: Optional[bool] = None,
+ crop_size: Optional[dict[str, int]] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ data_format: ChannelDimension = ChannelDimension.FIRST,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ ) -> PIL.Image.Image:
+ """
+ Preprocess an image or batch of images.
+
+ Args:
+ images (`ImageInput`):
+ Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
+ passing in images with pixel values between 0 and 1, set `do_rescale=False`.
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+ Whether to resize the image.
+ size (`dict[str, int]`, *optional*, defaults to `self.size`):
+ Controls the size of the image after `resize`. The shortest edge of the image is resized to
+ `size["shortest_edge"]` whilst preserving the aspect ratio. If the longest edge of this resized image
+ is > `int(size["shortest_edge"] * (1333 / 800))`, then the image is resized again to make the longest
+ edge equal to `int(size["shortest_edge"] * (1333 / 800))`.
+ size_divisor (`int`, *optional*, defaults to `self.size_divisor`):
+ The image is resized to a size that is a multiple of this value.
+ resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
+ Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`.
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+ Whether to rescale the image values between [0 - 1].
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
+ Whether to normalize the image.
+ image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`):
+ Image mean to normalize the image by if `do_normalize` is set to `True`.
+ image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`):
+ Image standard deviation to normalize the image by if `do_normalize` is set to `True`.
+ do_pad (`bool`, *optional*, defaults to `self.do_pad`):
+ Whether to pad the image to the (max_height, max_width) in the batch. If `True`, a pixel mask is also
+ created and returned.
+ do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`):
+ Whether to center crop the image. If the input size is smaller than `crop_size` along any edge, the
+ image is padded with 0's and then center cropped.
+ crop_size (`dict[str, int]`, *optional*, defaults to `self.crop_size`):
+ Size of the image after center crop. If one edge of the image is smaller than `crop_size`, it will be
+ padded with zeros and then cropped.
+ return_tensors (`str` or `TensorType`, *optional*):
+ The type of tensors to return. Can be one of:
+ - Unset: Return a list of `np.ndarray`.
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
+ The channel dimension format for the output image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - Unset: Use the channel dimension format of the input image.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
+ from the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+ """
+ do_resize = do_resize if do_resize is not None else self.do_resize
+ size_divisor = size_divisor if size_divisor is not None else self.size_divisor
+ resample = resample if resample is not None else self.resample
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
+ image_mean = image_mean if image_mean is not None else self.image_mean
+ image_std = image_std if image_std is not None else self.image_std
+ do_pad = do_pad if do_pad is not None else self.do_pad
+ do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
+ # For backwards compatibility. Initial version of this processor was cropping to the "size" argument, which
+ # it should default to if crop_size is undefined.
+ crop_size = (
+ crop_size if crop_size is not None else (self.crop_size if self.crop_size is not None else self.size)
+ )
+
+ size = size if size is not None else self.size
+ size = get_size_dict(size, default_to_square=False)
+ images = self.fetch_images(images)
+ images = make_flat_list_of_images(images)
+
+ if not valid_images(images):
+ raise ValueError(
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
+ "torch.Tensor, tf.Tensor or jax.ndarray."
+ )
+ # Here, crop_size is used only if it is set, else size will be used.
+ validate_preprocess_arguments(
+ do_rescale=do_rescale,
+ rescale_factor=rescale_factor,
+ do_normalize=do_normalize,
+ image_mean=image_mean,
+ image_std=image_std,
+ do_center_crop=do_center_crop,
+ crop_size=crop_size,
+ do_resize=do_resize,
+ size=size,
+ resample=resample,
+ )
+ # All transformations expect numpy arrays.
+ images = [to_numpy_array(image) for image in images]
+
+ if do_rescale and is_scaled_image(images[0]):
+ logger.warning_once(
+ "It looks like you are trying to rescale already rescaled images. If the input"
+ " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
+ )
+
+ if do_resize:
+ images = [
+ self.resize(
+ image=image,
+ size=size,
+ size_divisor=size_divisor,
+ resample=resample,
+ input_data_format=input_data_format,
+ )
+ for image in images
+ ]
+
+ if do_center_crop:
+ images = [
+ self.center_crop(image=image, size=crop_size, input_data_format=input_data_format) for image in images
+ ]
+
+ if do_rescale:
+ images = [
+ self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
+ for image in images
+ ]
+
+ if do_normalize:
+ images = [
+ self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
+ for image in images
+ ]
+
+ images = [
+ to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
+ ]
+
+ if do_pad:
+ encoded_outputs = self.pad(
+ images, return_pixel_mask=True, return_tensors=return_tensors, input_data_format=data_format
+ )
+ else:
+ encoded_outputs = BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors)
+
+ return encoded_outputs
+
+
+__all__ = ["BridgeTowerImageProcessor"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/bridgetower/image_processing_bridgetower_fast.py b/venv/lib/python3.13/site-packages/transformers/models/bridgetower/image_processing_bridgetower_fast.py
new file mode 100644
index 0000000000000000000000000000000000000000..5be6f9f6c54b7bf6e973b9102179b63cbfe353d8
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/bridgetower/image_processing_bridgetower_fast.py
@@ -0,0 +1,280 @@
+# coding=utf-8
+# Copyright 2025 The Intel Labs Team Authors, The Microsoft Research Team Authors and HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Fast Image processor class for BridgeTower."""
+
+from collections.abc import Iterable
+from typing import Optional, Union
+
+import torch
+from torchvision.transforms.v2 import functional as F
+
+from ...image_processing_utils_fast import (
+ BaseImageProcessorFast,
+ BatchFeature,
+ DefaultFastImageProcessorKwargs,
+ ImageInput,
+ SizeDict,
+ TensorType,
+ Unpack,
+ group_images_by_shape,
+ reorder_images,
+)
+from ...image_utils import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD, PILImageResampling
+from ...utils import auto_docstring
+
+
+def make_pixel_mask(
+ image: "torch.Tensor",
+ output_size: tuple[int, int],
+) -> "torch.Tensor":
+ """
+ Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding.
+
+ Args:
+ image (`torch.Tensor`):
+ Batch of images to make the pixel mask for.
+ output_size (`tuple[int, int]`):
+ Output size of the mask.
+ """
+ input_height, input_width = image.shape[-2:]
+ batch_size = image.size(0)
+ mask = torch.zeros((batch_size, *output_size), dtype=torch.long)
+ # Mark the valid (un-padded) region across the height/width dimensions for every image in the batch.
+ mask[:, :input_height, :input_width] = 1
+ return mask
+
+
+def get_resize_output_image_size(
+ input_image: "torch.Tensor",
+ shorter: int = 800,
+ longer: int = 1333,
+ size_divisor: int = 32,
+) -> tuple[int, int]:
+ input_height, input_width = input_image.shape[-2:]
+ min_size, max_size = shorter, longer
+
+ scale = min_size / min(input_height, input_width)
+
+ if input_height < input_width:
+ new_height = min_size
+ new_width = scale * input_width
+ else:
+ new_height = scale * input_height
+ new_width = min_size
+
+ if max(new_height, new_width) > max_size:
+ scale = max_size / max(new_height, new_width)
+ new_height = scale * new_height
+ new_width = scale * new_width
+
+ new_height, new_width = int(new_height + 0.5), int(new_width + 0.5)
+ new_height = new_height // size_divisor * size_divisor
+ new_width = new_width // size_divisor * size_divisor
+
+ return new_height, new_width
+
+
+class BridgeTowerFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
+ """
+ Args:
+ size_divisor (`int`, *optional*, defaults to 32):
+ The size by which to make sure both the height and width can be divided. Only has an effect if `do_resize`
+ is set to `True`. Can be overridden by the `size_divisor` parameter in the `preprocess` method.
+ """
+
+ size_divisor: Optional[int]
+
+
+@auto_docstring
+class BridgeTowerImageProcessorFast(BaseImageProcessorFast):
+ resample = PILImageResampling.BICUBIC
+ image_mean = OPENAI_CLIP_MEAN
+ image_std = OPENAI_CLIP_STD
+ size = {"shortest_edge": 288}
+ default_to_square = False
+ crop_size = {"shortest_edge": 288}
+ do_resize = True
+ do_center_crop = True
+ do_rescale = True
+ do_normalize = True
+ do_pad = True
+ size_divisor = 32
+ valid_kwargs = BridgeTowerFastImageProcessorKwargs
+ model_input_names = ["pixel_values", "pixel_mask"]
+
+ def __init__(self, **kwargs: Unpack[BridgeTowerFastImageProcessorKwargs]):
+ super().__init__(**kwargs)
+
+ @auto_docstring
+ def preprocess(self, images: ImageInput, **kwargs: Unpack[BridgeTowerFastImageProcessorKwargs]) -> BatchFeature:
+ return super().preprocess(images, **kwargs)
+
+ def resize(
+ self,
+ image: "torch.Tensor",
+ size: SizeDict,
+ size_divisor: int = 32,
+ interpolation: Optional["F.InterpolationMode"] = None,
+ antialias: bool = True,
+ **kwargs,
+ ) -> "torch.Tensor":
+ """
+ Resize an image.
+
+ Resizes the shorter side of the image to `size["shortest_edge"]` while preserving the aspect ratio. If the
+ longer side is larger than the max size `int(size["shortest_edge"] * 1333 / 800)`, the longer side is then
+ resized to the max size while preserving the aspect ratio.
+
+ Args:
+ image (`torch.Tensor`):
+ Image to resize.
+ size (`SizeDict`):
+ Size dictionary with a `shortest_edge` entry; the shorter side of the image is resized to this value.
+ size_divisor (`int`, *optional*, defaults to 32):
+ The image is resized to a size that is a multiple of this value.
+ interpolation (`InterpolationMode`, *optional*, defaults to `InterpolationMode.BILINEAR`):
+ `InterpolationMode` filter to use when resizing the image, e.g. `InterpolationMode.BICUBIC`.
+
+ Returns:
+ `torch.Tensor`: The resized image.
+ """
+ interpolation = interpolation if interpolation is not None else F.InterpolationMode.BILINEAR
+ if not size.shortest_edge:
+ raise ValueError(f"The `size` dictionary must contain the key `shortest_edge`. Got {size.keys()}")
+ shorter = size.shortest_edge
+ longer = int(1333 / 800 * shorter)
+ output_height, output_width = get_resize_output_image_size(
+ image,
+ shorter=shorter,
+ longer=longer,
+ size_divisor=size_divisor,
+ )
+ return super().resize(
+ image=image,
+ size=SizeDict(height=output_height, width=output_width),
+ interpolation=interpolation,
+ antialias=antialias,
+ )
+
+ def center_crop(
+ self,
+ image: "torch.Tensor",
+ size: dict[str, int],
+ **kwargs,
+ ) -> "torch.Tensor":
+ """
+ Center crop an image to `(size["height"], size["width"])`. If the input size is smaller than `crop_size` along
+ any edge, the image is padded with 0's and then center cropped.
+
+ Args:
+ image (`torch.Tensor`):
+ Image to center crop.
+ size (`dict[str, int]`):
+ Size of the output image. Must provide the key `"shortest_edge"`; the image is cropped to a square of
+ that side length.
+ """
+ output_size = size.shortest_edge
+ return F.center_crop(
+ image,
+ output_size=(output_size, output_size),
+ **kwargs,
+ )
+
+ def _pad_image(
+ self,
+ image: "torch.Tensor",
+ output_size: tuple[int, int],
+ constant_values: Union[float, Iterable[float]] = 0,
+ ) -> "torch.Tensor":
+ """
+ Pad an image with zeros to the given size.
+ """
+ input_height, input_width = image.shape[-2:]
+ output_height, output_width = output_size
+
+ pad_bottom = output_height - input_height
+ pad_right = output_width - input_width
+ padding = (0, 0, pad_right, pad_bottom)
+ padded_image = F.pad(
+ image,
+ padding,
+ fill=constant_values,
+ )
+ return padded_image
+
+ def _preprocess(
+ self,
+ images: list["torch.Tensor"],
+ do_resize: bool,
+ size: SizeDict,
+ size_divisor: Optional[int],
+ interpolation: Optional["F.InterpolationMode"],
+ do_pad: bool,
+ do_center_crop: bool,
+ crop_size: SizeDict,
+ do_rescale: bool,
+ rescale_factor: float,
+ do_normalize: bool,
+ image_mean: Optional[Union[float, list[float]]],
+ image_std: Optional[Union[float, list[float]]],
+ disable_grouping: Optional[bool],
+ return_tensors: Optional[Union[str, TensorType]],
+ **kwargs,
+ ) -> BatchFeature:
+ # Group images by size for batched resizing
+ grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
+ resized_images_grouped = {}
+ for shape, stacked_images in grouped_images.items():
+ if do_resize:
+ stacked_images = self.resize(
+ image=stacked_images, size=size, size_divisor=size_divisor, interpolation=interpolation
+ )
+ resized_images_grouped[shape] = stacked_images
+ resized_images = reorder_images(resized_images_grouped, grouped_images_index)
+
+ # Group images by size for further processing
+ # Needed in case do_resize is False, or resize returns images with different sizes
+ grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
+ processed_images_grouped = {}
+ for shape, stacked_images in grouped_images.items():
+ if do_center_crop:
+ stacked_images = self.center_crop(stacked_images, crop_size)
+ # Fused rescale and normalize
+ stacked_images = self.rescale_and_normalize(
+ stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
+ )
+ processed_images_grouped[shape] = stacked_images
+
+ processed_images = reorder_images(processed_images_grouped, grouped_images_index)
+
+ data = {}
+ if do_pad:
+ processed_images, processed_masks = self.pad(
+ processed_images, return_mask=True, disable_grouping=disable_grouping
+ )
+ processed_masks = torch.stack(processed_masks, dim=0) if return_tensors else processed_masks
+ data["pixel_mask"] = processed_masks
+
+ processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
+ data["pixel_values"] = processed_images
+
+ return BatchFeature(data=data, tensor_type=return_tensors)
+
+ def to_dict(self):
+ encoder_dict = super().to_dict()
+ encoder_dict.pop("_valid_processor_keys", None)
+ encoder_dict.pop("crop_size", None)
+ return encoder_dict
+
+
+__all__ = ["BridgeTowerImageProcessorFast"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/bridgetower/modeling_bridgetower.py b/venv/lib/python3.13/site-packages/transformers/models/bridgetower/modeling_bridgetower.py
new file mode 100644
index 0000000000000000000000000000000000000000..59c5be00c3169996c5f3490491dc262022817dc3
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/bridgetower/modeling_bridgetower.py
@@ -0,0 +1,1875 @@
+# coding=utf-8
+# Copyright 2023 The Intel Labs Team Authors, The Microsoft Research Team Authors and HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch BridgeTower Model"""
+
+import math
+from collections import OrderedDict
+from dataclasses import dataclass
+from typing import Optional, Union
+
+import torch
+from torch import nn
+from torch.nn import CrossEntropyLoss
+
+from ...activations import ACT2FN, QuickGELUActivation
+from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import (
+ BaseModelOutputWithPastAndCrossAttentions,
+ BaseModelOutputWithPoolingAndCrossAttentions,
+ MaskedLMOutput,
+ ModelOutput,
+ SequenceClassifierOutput,
+)
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from ...utils import auto_docstring, logging, torch_int
+from ...utils.deprecation import deprecate_kwarg
+from .configuration_bridgetower import BridgeTowerConfig, BridgeTowerTextConfig, BridgeTowerVisionConfig
+
+
+logger = logging.get_logger(__name__)
+
+_TOKENIZER_FOR_DOC = "RobertaTokenizer"
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Output type of [`BridgeTowerModel`].
+ """
+)
+class BridgeTowerModelOutput(ModelOutput):
+ r"""
+ text_features (`torch.FloatTensor` of shape `(batch_size, text_sequence_length, hidden_size)`):
+ Sequence of hidden-states at the text output of the last layer of the model.
+ image_features (`torch.FloatTensor` of shape `(batch_size, image_sequence_length, hidden_size)`):
+ Sequence of hidden-states at the image output of the last layer of the model.
+ pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size x 2)`):
+ Concatenation of last layer hidden-state of the first token of the text and image sequence (classification
+ token), respectively, after further processing through layers used for auxiliary pretraining tasks.
+ """
+
+ text_features: Optional[torch.FloatTensor] = None
+ image_features: Optional[torch.FloatTensor] = None
+ pooler_output: Optional[torch.FloatTensor] = None
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
+ attentions: Optional[tuple[torch.FloatTensor]] = None
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Output type of [`BridgeTowerForContrastiveLearning`].
+ """
+)
+class BridgeTowerContrastiveOutput(ModelOutput):
+ r"""
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
+ Image-text contrastive loss.
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ text_embeds (`torch.FloatTensor`, *optional*, returned when model is initialized with `with_projection=True`):
+ The text embeddings obtained by applying the projection layer to the pooler_output.
+ image_embeds (`torch.FloatTensor`, *optional*, returned when model is initialized with `with_projection=True`):
+ The image embeddings obtained by applying the projection layer to the pooler_output.
+ cross_embeds (`torch.FloatTensor`, *optional*, returned when model is initialized with `with_projection=True`):
+ The text-image cross-modal embeddings obtained by applying the projection layer to the pooler_output.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: Optional[torch.FloatTensor] = None
+ text_embeds: Optional[tuple[torch.FloatTensor]] = None
+ image_embeds: Optional[tuple[torch.FloatTensor]] = None
+ cross_embeds: Optional[tuple[torch.FloatTensor]] = None
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
+ attentions: Optional[tuple[torch.FloatTensor]] = None
+
+
+class BridgeTowerResidualAttention(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+
+ self.attn = nn.MultiheadAttention(config.hidden_size, config.hidden_size // 64)
+ self.ln_1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.mlp = nn.ModuleDict(
+ OrderedDict(
+ [
+ ("c_fc", nn.Linear(config.hidden_size, config.hidden_size * 4)),
+ ("gelu", QuickGELUActivation()),
+ ("c_proj", nn.Linear(config.hidden_size * 4, config.hidden_size)),
+ ]
+ )
+ )
+ self.ln_2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.attn_mask = None
+
+ def attention(self, hidden_state: torch.Tensor, attention_mask: torch.Tensor):
+ if attention_mask is not None:
+ attention_mask = attention_mask.to(dtype=torch.bool, device=hidden_state.device)
+ self.attn_mask = (
+ self.attn_mask.to(dtype=hidden_state.dtype, device=hidden_state.device)
+ if self.attn_mask is not None
+ else None
+ )
+ return self.attn(
+ hidden_state,
+ hidden_state,
+ hidden_state,
+ need_weights=False,
+ attn_mask=self.attn_mask,
+ key_padding_mask=attention_mask,
+ )[0]
+
+ def forward(self, hidden_state: torch.Tensor, attention_mask: Optional[torch.Tensor] = None):
+ residual_state = hidden_state + self.attention(self.ln_1(hidden_state), attention_mask)
+ hidden_state = self.ln_2(residual_state)
+ for layer in self.mlp.values():
+ hidden_state = layer(hidden_state)
+ hidden_state = residual_state + hidden_state
+ return hidden_state
+
+
+class BridgeTowerTransformer(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+ self.num_hidden_layers = config.num_hidden_layers
+ if config.remove_last_layer:
+ self.resblocks = nn.ModuleList(
+ [BridgeTowerResidualAttention(config) for _ in range(self.num_hidden_layers - 1)]
+ )
+ else:
+ self.resblocks = nn.ModuleList(
+ [BridgeTowerResidualAttention(config) for _ in range(self.num_hidden_layers)]
+ )
+ self.stop_gradient = config.stop_gradient
+
+ def forward(self, hidden_state: torch.Tensor, attention_mask: Optional[torch.Tensor] = None):
+ hidden_states = []
+ for block in self.resblocks:
+ hidden_state = block(hidden_state, attention_mask)
+ if self.stop_gradient:
+ hidden_states.append(hidden_state.detach())
+ else:
+ hidden_states.append(hidden_state)
+ return hidden_states
+
+
+# Copied from transformers.models.clip.modeling_clip.CLIPVisionEmbeddings with CLIP->BridgeTower
+class BridgeTowerVisionEmbeddings(nn.Module):
+ def __init__(self, config: BridgeTowerVisionConfig):
+ super().__init__()
+ self.config = config
+ self.embed_dim = config.hidden_size
+ self.image_size = config.image_size
+ self.patch_size = config.patch_size
+
+ self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))
+
+ self.patch_embedding = nn.Conv2d(
+ in_channels=config.num_channels,
+ out_channels=self.embed_dim,
+ kernel_size=self.patch_size,
+ stride=self.patch_size,
+ bias=False,
+ )
+
+ self.num_patches = (self.image_size // self.patch_size) ** 2
+ self.num_positions = self.num_patches + 1
+ self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
+ self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
+
+ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
+ """
+ This method allows interpolating the pre-trained position encodings so that the model can be used on higher
+ resolution images. This method is also adapted to support torch.jit tracing.
+
+ Adapted from:
+ - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
+ - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
+ """
+
+ num_patches = embeddings.shape[1] - 1
+ position_embedding = self.position_embedding.weight.unsqueeze(0)
+ num_positions = position_embedding.shape[1] - 1
+
+ # always interpolate when tracing to ensure the exported model works for dynamic input shapes
+ if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
+ return self.position_embedding(self.position_ids)
+
+ class_pos_embed = position_embedding[:, :1]
+ patch_pos_embed = position_embedding[:, 1:]
+
+ dim = embeddings.shape[-1]
+
+ new_height = height // self.patch_size
+ new_width = width // self.patch_size
+
+ sqrt_num_positions = torch_int(num_positions**0.5)
+ patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
+ patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
+
+ patch_pos_embed = nn.functional.interpolate(
+ patch_pos_embed,
+ size=(new_height, new_width),
+ mode="bicubic",
+ align_corners=False,
+ )
+
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+
+ return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
+
+ def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor:
+ batch_size, _, height, width = pixel_values.shape
+ if not interpolate_pos_encoding and (height != self.image_size or width != self.image_size):
+ raise ValueError(
+ f"Input image size ({height}*{width}) doesn't match model ({self.image_size}*{self.image_size})."
+ )
+ target_dtype = self.patch_embedding.weight.dtype
+ patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid]
+ patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
+
+ class_embeds = self.class_embedding.expand(batch_size, 1, -1)
+ embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
+ if interpolate_pos_encoding:
+ embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
+ else:
+ embeddings = embeddings + self.position_embedding(self.position_ids)
+ return embeddings
+
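The interpolation step above boils down to reshaping the flat patch position embeddings into a 2D grid, resizing the grid bicubically, and flattening it back. A standalone sketch with assumed toy dimensions:

```python
import torch
from torch import nn

# Toy dimensions: a 7x7 grid of patch position embeddings, interpolated to a 10x12 grid.
num_positions, dim = 49, 64
patch_pos_embed = torch.randn(1, num_positions, dim)
grid = int(num_positions**0.5)
patch_pos_embed = patch_pos_embed.reshape(1, grid, grid, dim).permute(0, 3, 1, 2)  # (1, dim, 7, 7)
resized = nn.functional.interpolate(patch_pos_embed, size=(10, 12), mode="bicubic", align_corners=False)
print(resized.permute(0, 2, 3, 1).reshape(1, -1, dim).shape)  # torch.Size([1, 120, 64])
```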
+
+class BridgeTowerVisionTransformer(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+
+ self.embeddings = BridgeTowerVisionEmbeddings(config)
+ self.ln_pre = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.transformer = BridgeTowerTransformer(config)
+ self.ln_post = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.share_layernorm = config.share_layernorm
+ if not config.share_layernorm:
+ self.ln_separate = nn.ModuleList(
+ [nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) for _ in range(config.num_hidden_layers)]
+ )
+
+ def forward(
+ self,
+ pixel_values: torch.Tensor,
+ attention_mask,
+ interpolate_pos_encoding: bool = False,
+ ):
+ hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding)
+ hidden_states = self.ln_pre(hidden_states)
+ # NLD -> LND
+ hidden_states = hidden_states.permute(1, 0, 2)
+
+ hidden_states = self.transformer(hidden_states, attention_mask)
+ # shape = [num_hidden_layers, hidden_size, *, grid ** 2]
+ hidden_states = torch.stack(hidden_states, dim=0)
+ # shape = [num_hidden_layers, *, hidden_size, grid ** 2]
+ hidden_states = hidden_states.permute(0, 2, 1, 3)
+ if self.share_layernorm:
+ hidden_states = self.ln_post(hidden_states)
+ else:
+ hidden_states_stack = []
+ for hidden_states, ln in zip(hidden_states, self.ln_separate):
+ hidden_states = ln(hidden_states)
+ hidden_states_stack.append(hidden_states)
+ # shape = [num_hidden_layers, *, hidden_size, grid ** 2]
+ hidden_states = torch.stack(hidden_states_stack, dim=0)
+ return hidden_states
+
+ def forward_pre(
+ self,
+ pixel_values: torch.Tensor,
+ interpolate_pos_encoding: bool = False,
+ ):
+ hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
+ hidden_states = self.ln_pre(hidden_states)
+ # NLD -> LND
+ hidden_states = hidden_states.permute(1, 0, 2)
+ return hidden_states
+
+ def forward_post(self, hidden_state: torch.Tensor):
+ visual_output_post = hidden_state.permute(1, 0, 2)
+ visual_output_post = self.ln_post(visual_output_post)
+ return visual_output_post
+
+
+class BridgeTowerLinkTower(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.link_tower_type = config.link_tower_type
+ self.hidden_size = config.hidden_size
+ if config.link_tower_type in ["add", "scaled_add", "interpolate"]:
+ if config.link_tower_type == "scaled_add":
+ self.scaled_factor = nn.Parameter(torch.tensor(1.0))
+ elif config.link_tower_type == "interpolate":
+ self.beta = nn.Parameter(torch.tensor(0.5))
+ self.LayerNorm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps)
+ else:
+ raise NotImplementedError(f"link_tower_type {config.link_tower_type} is not implemented")
+
+ def forward(self, hidden_states, cross_modal_hidden_states, attention_mask):
+ if self.link_tower_type == "add":
+ return self.LayerNorm(hidden_states + cross_modal_hidden_states)
+ elif self.link_tower_type == "scaled_add":
+ return self.LayerNorm(hidden_states * self.scaled_factor + cross_modal_hidden_states)
+ elif self.link_tower_type == "interpolate":
+ return self.LayerNorm(hidden_states * (1 - self.beta) + cross_modal_hidden_states * self.beta)
+ else:
+ raise NotImplementedError(f"link_tower_type {self.link_tower_type} is not implemented")
+
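The three fusion rules implemented above reduce to elementwise combinations followed by `LayerNorm`. A standalone sketch with assumed toy shapes (the scale and beta are learnable parameters in the module, shown here as fixed tensors):

```python
import torch
from torch import nn

hidden_size = 768
hidden_states = torch.randn(2, 5, hidden_size)              # uni-modal tower output
cross_modal_hidden_states = torch.randn(2, 5, hidden_size)  # cross-modal tower output
layer_norm = nn.LayerNorm(hidden_size)

add = layer_norm(hidden_states + cross_modal_hidden_states)                              # "add"
scale = torch.tensor(1.0)
scaled_add = layer_norm(hidden_states * scale + cross_modal_hidden_states)               # "scaled_add"
beta = torch.tensor(0.5)
interpolate = layer_norm(hidden_states * (1 - beta) + cross_modal_hidden_states * beta)  # "interpolate"
print(add.shape, scaled_add.shape, interpolate.shape)
```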
+
+# Copied from transformers.models.bert.modeling_bert.BertSelfOutput with Bert->BridgeTower
+class BridgeTowerSelfOutput(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
+ return hidden_states
+
+
+# Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->BridgeTower
+class BridgeTowerIntermediate(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
+ if isinstance(config.hidden_act, str):
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.intermediate_act_fn = config.hidden_act
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.intermediate_act_fn(hidden_states)
+ return hidden_states
+
+
+# Copied from transformers.models.bert.modeling_bert.BertOutput with Bert->BridgeTower
+class BridgeTowerOutput(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
+ return hidden_states
+
+
+# Copied from transformers.models.bert.modeling_bert.BertPooler with Bert->BridgeTower
+class BridgeTowerPooler(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.activation = nn.Tanh()
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ # We "pool" the model by simply taking the hidden state corresponding
+ # to the first token.
+ first_token_tensor = hidden_states[:, 0]
+ pooled_output = self.dense(first_token_tensor)
+ pooled_output = self.activation(pooled_output)
+ return pooled_output
+
+
+# Copied from transformers.models.roberta.modeling_roberta.RobertaSelfAttention with Roberta->BridgeTower
+class BridgeTowerSelfAttention(nn.Module):
+ def __init__(self, config, position_embedding_type=None, layer_idx=None):
+ super().__init__()
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
+ raise ValueError(
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
+ f"heads ({config.num_attention_heads})"
+ )
+
+ self.num_attention_heads = config.num_attention_heads
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
+
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+ self.position_embedding_type = position_embedding_type or getattr(
+ config, "position_embedding_type", "absolute"
+ )
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
+ self.max_position_embeddings = config.max_position_embeddings
+ self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
+
+ self.is_decoder = config.is_decoder
+ self.layer_idx = layer_idx
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ output_attentions: Optional[bool] = False,
+ cache_position: Optional[torch.Tensor] = None,
+ ) -> tuple[torch.Tensor]:
+ batch_size, seq_length, _ = hidden_states.shape
+ query_layer = self.query(hidden_states)
+ query_layer = query_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(
+ 1, 2
+ )
+
+ is_updated = False
+ is_cross_attention = encoder_hidden_states is not None
+ if past_key_values is not None:
+ if isinstance(past_key_values, EncoderDecoderCache):
+ is_updated = past_key_values.is_updated.get(self.layer_idx)
+ if is_cross_attention:
+ # after the first generated id, we can subsequently re-use all key/value_layer from cache
+ curr_past_key_value = past_key_values.cross_attention_cache
+ else:
+ curr_past_key_value = past_key_values.self_attention_cache
+ else:
+ curr_past_key_value = past_key_values
+
+ current_states = encoder_hidden_states if is_cross_attention else hidden_states
+ if is_cross_attention and past_key_values is not None and is_updated:
+ # reuse k,v, cross_attentions
+ key_layer = curr_past_key_value.layers[self.layer_idx].keys
+ value_layer = curr_past_key_value.layers[self.layer_idx].values
+ else:
+ key_layer = self.key(current_states)
+ key_layer = key_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(
+ 1, 2
+ )
+ value_layer = self.value(current_states)
+ value_layer = value_layer.view(
+ batch_size, -1, self.num_attention_heads, self.attention_head_size
+ ).transpose(1, 2)
+
+ if past_key_values is not None:
+ # save all key/value_layer to cache to be re-used for fast auto-regressive generation
+ cache_position = cache_position if not is_cross_attention else None
+ key_layer, value_layer = curr_past_key_value.update(
+ key_layer, value_layer, self.layer_idx, {"cache_position": cache_position}
+ )
+ # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
+ if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache):
+ past_key_values.is_updated[self.layer_idx] = True
+
+ # Take the dot product between "query" and "key" to get the raw attention scores.
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
+ query_length, key_length = query_layer.shape[2], key_layer.shape[2]
+ if past_key_values is not None:
+ position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
+ -1, 1
+ )
+ else:
+ position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
+ position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
+ distance = position_ids_l - position_ids_r
+
+ positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
+ positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
+
+ if self.position_embedding_type == "relative_key":
+ relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
+ attention_scores = attention_scores + relative_position_scores
+ elif self.position_embedding_type == "relative_key_query":
+ relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
+ relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
+ attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
+
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+ if attention_mask is not None:
+ # Apply the attention mask (precomputed for all layers in the BridgeTowerModel forward() function)
+ attention_scores = attention_scores + attention_mask
+
+ # Normalize the attention scores to probabilities.
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
+
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ attention_probs = self.dropout(attention_probs)
+
+ # Mask heads if we want to
+ if head_mask is not None:
+ attention_probs = attention_probs * head_mask
+
+ context_layer = torch.matmul(attention_probs, value_layer)
+
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+ context_layer = context_layer.view(new_context_layer_shape)
+
+ return context_layer, attention_probs
+
+
+BRIDGE_TOWER_SELF_ATTENTION_CLASSES = {
+ "eager": BridgeTowerSelfAttention,
+}
+
+
+# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->BridgeTower,BERT->BRIDGE_TOWER
+class BridgeTowerAttention(nn.Module):
+ def __init__(self, config, position_embedding_type=None, layer_idx=None):
+ super().__init__()
+ self.self = BRIDGE_TOWER_SELF_ATTENTION_CLASSES[config._attn_implementation](
+ config,
+ position_embedding_type=position_embedding_type,
+ layer_idx=layer_idx,
+ )
+ self.output = BridgeTowerSelfOutput(config)
+ self.pruned_heads = set()
+
+ def prune_heads(self, heads):
+ if len(heads) == 0:
+ return
+ heads, index = find_pruneable_heads_and_indices(
+ heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
+ )
+
+ # Prune linear layers
+ self.self.query = prune_linear_layer(self.self.query, index)
+ self.self.key = prune_linear_layer(self.self.key, index)
+ self.self.value = prune_linear_layer(self.self.value, index)
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
+
+ # Update hyper params and store pruned heads
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
+ self.pruned_heads = self.pruned_heads.union(heads)
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ output_attentions: Optional[bool] = False,
+ cache_position: Optional[torch.Tensor] = None,
+ ) -> tuple[torch.Tensor]:
+ self_outputs = self.self(
+ hidden_states,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ past_key_values=past_key_values,
+ output_attentions=output_attentions,
+ cache_position=cache_position,
+ )
+ attention_output = self.output(self_outputs[0], hidden_states)
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
+ return outputs
+
+
+class BridgeTowerBertCrossLayer(nn.Module):
+ def __init__(self, config, layer_idx=None):
+ super().__init__()
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
+ self.seq_len_dim = 1
+ self.attention = BridgeTowerAttention(config, layer_idx=layer_idx)
+ self.is_decoder = config.is_decoder
+ self.add_cross_attention = config.add_cross_attention
+ self.crossattention = BridgeTowerAttention(config, layer_idx=layer_idx)
+ self.intermediate = BridgeTowerIntermediate(config)
+ self.output = BridgeTowerOutput(config)
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states,
+ encoder_hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ encoder_attention_mask=None,
+ past_key_values=None,
+ output_attentions=False,
+ cache_position=None,
+ ):
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
+ self_attention_outputs = self.attention(
+ hidden_states,
+ attention_mask=attention_mask,
+ head_mask=None,
+ output_attentions=output_attentions,
+ past_key_values=None,
+ )
+ attention_output = self_attention_outputs[0]
+
+ # if decoder, the last output is tuple of self-attn cache
+ # add self attentions if we output attention weights
+ outputs = self_attention_outputs[1:]
+
+ cross_attention_outputs = self.crossattention(
+ attention_output,
+ attention_mask=encoder_attention_mask,
+ head_mask=head_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ past_key_values=past_key_values,
+ output_attentions=output_attentions,
+ cache_position=cache_position,
+ )
+ attention_output = cross_attention_outputs[0]
+ # add cross attentions if we output attention weights
+ outputs = outputs + cross_attention_outputs[1:]
+
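+ # apply_chunking_to_forward splits `attention_output` along the sequence dimension into chunks of
+ # `chunk_size_feed_forward` and applies `feed_forward_chunk` to each chunk to reduce peak memory; with the
+ # default chunk size of 0 this is equivalent to calling `self.feed_forward_chunk(attention_output)` directly.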
+ layer_output = apply_chunking_to_forward(
+ self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
+ )
+ outputs = (layer_output,) + outputs
+
+ return outputs
+
+ def feed_forward_chunk(self, attention_output):
+ intermediate_output = self.intermediate(attention_output)
+ layer_output = self.output(intermediate_output, attention_output)
+ return layer_output
+
+
+class BridgeTowerTextLayer(GradientCheckpointingLayer):
+ def __init__(self, config, layer_idx=None):
+ super().__init__()
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
+ self.seq_len_dim = 1
+ self.attention = BridgeTowerAttention(config, layer_idx=layer_idx)
+ self.is_decoder = config.is_decoder
+ self.add_cross_attention = config.add_cross_attention
+ if self.add_cross_attention:
+ if not self.is_decoder:
+ raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
+ self.crossattention = BridgeTowerAttention(config, position_embedding_type="absolute", layer_idx=layer_idx)
+ self.intermediate = BridgeTowerIntermediate(config)
+ self.output = BridgeTowerOutput(config)
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ output_attentions: Optional[bool] = False,
+ cache_position: Optional[torch.Tensor] = None,
+ ) -> tuple[torch.Tensor]:
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
+ self_attention_outputs = self.attention(
+ hidden_states,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ past_key_values=past_key_values,
+ cache_position=cache_position,
+ )
+ attention_output = self_attention_outputs[0]
+
+ # if decoder, the last output is tuple of self-attn cache
+ if self.is_decoder:
+ outputs = self_attention_outputs[1:-1]
+ else:
+ outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
+
+ if self.is_decoder and encoder_hidden_states is not None:
+ if not hasattr(self, "crossattention"):
+ raise ValueError(
+ f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
+ " by setting `config.add_cross_attention=True`"
+ )
+
+ cross_attention_outputs = self.crossattention(
+ attention_output,
+ attention_mask=encoder_attention_mask,
+ head_mask=head_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ past_key_values=past_key_values,
+ output_attentions=output_attentions,
+ cache_position=cache_position,
+ )
+ attention_output = cross_attention_outputs[0]
+ outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
+
+ layer_output = apply_chunking_to_forward(
+ self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
+ )
+ return (layer_output,) + outputs
+
+ def feed_forward_chunk(self, attention_output):
+ intermediate_output = self.intermediate(attention_output)
+ layer_output = self.output(intermediate_output, attention_output)
+ return layer_output
+
+
+# Copied from transformers.models.roberta.modeling_roberta.RobertaEncoder with Roberta->BridgeTowerText
+class BridgeTowerTextEncoder(nn.Module):
+ def __init__(self, config, layer_idx=None):
+ super().__init__()
+ self.config = config
+ self.layer = nn.ModuleList(
+ [BridgeTowerTextLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)]
+ )
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = False,
+ output_hidden_states: Optional[bool] = False,
+ return_dict: Optional[bool] = True,
+ cache_position: Optional[torch.Tensor] = None,
+ ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attentions = () if output_attentions else None
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ if use_cache and self.config.is_decoder and past_key_values is None:
+ past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
+
+ if use_cache and self.config.is_decoder and isinstance(past_key_values, tuple):
+ logger.warning_once(
+ "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
+ "You should pass an instance of `EncoderDecoderCache` instead, e.g. "
+ "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`."
+ )
+ past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)
+
+ for i, layer_module in enumerate(self.layer):
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ layer_head_mask = head_mask[i] if head_mask is not None else None
+
+ layer_outputs = layer_module(
+ hidden_states,
+ attention_mask,
+ layer_head_mask,
+ encoder_hidden_states, # as a positional argument for gradient checkpointing
+ encoder_attention_mask=encoder_attention_mask,
+ past_key_values=past_key_values,
+ output_attentions=output_attentions,
+ cache_position=cache_position,
+ )
+
+ hidden_states = layer_outputs[0]
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
+ if self.config.add_cross_attention:
+ all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(
+ v
+ for v in [
+ hidden_states,
+ past_key_values,
+ all_hidden_states,
+ all_self_attentions,
+ all_cross_attentions,
+ ]
+ if v is not None
+ )
+ return BaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ cross_attentions=all_cross_attentions,
+ )
+
+
+# Copied from transformers.models.roberta.modeling_roberta.RobertaEmbeddings with Roberta->BridgeTowerText
+class BridgeTowerTextEmbeddings(nn.Module):
+ """
+ Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
+ """
+
+ # Copied from transformers.models.bert.modeling_bert.BertEmbeddings.__init__
+ def __init__(self, config):
+ super().__init__()
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
+
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
+ # any TensorFlow checkpoint file
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
+ self.register_buffer(
+ "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
+ )
+ self.register_buffer(
+ "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
+ )
+
+ # End copy
+ self.padding_idx = config.pad_token_id
+ self.position_embeddings = nn.Embedding(
+ config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
+ )
+
+ def forward(
+ self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
+ ):
+ if position_ids is None:
+ if input_ids is not None:
+ # Create the position ids from the input token ids. Any padded tokens remain padded.
+ position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length)
+ else:
+ position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
+
+ if input_ids is not None:
+ input_shape = input_ids.size()
+ else:
+ input_shape = inputs_embeds.size()[:-1]
+
+ seq_length = input_shape[1]
+
+ # When token_type_ids is not passed, default to the registered all-zeros buffer from the constructor (cropped
+ # to the sequence length). The registered buffer lets users trace the model without passing token_type_ids and
+ # solves issue #5664.
+ if token_type_ids is None:
+ if hasattr(self, "token_type_ids"):
+ buffered_token_type_ids = self.token_type_ids[:, :seq_length]
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
+ token_type_ids = buffered_token_type_ids_expanded
+ else:
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
+
+ if inputs_embeds is None:
+ inputs_embeds = self.word_embeddings(input_ids)
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
+
+ embeddings = inputs_embeds + token_type_embeddings
+ if self.position_embedding_type == "absolute":
+ position_embeddings = self.position_embeddings(position_ids)
+ embeddings += position_embeddings
+ embeddings = self.LayerNorm(embeddings)
+ embeddings = self.dropout(embeddings)
+ return embeddings
+
+ def create_position_ids_from_inputs_embeds(self, inputs_embeds):
+ """
+ We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.
+
+ Args:
+ inputs_embeds: torch.Tensor
+
+ Returns: torch.Tensor
+ """
+ input_shape = inputs_embeds.size()[:-1]
+ sequence_length = input_shape[1]
+
+ position_ids = torch.arange(
+ self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
+ )
+ return position_ids.unsqueeze(0).expand(input_shape)
+
+
+# Copied from transformers.models.roberta.modeling_roberta.create_position_ids_from_input_ids
+def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
+ """
+ Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
+ are ignored. This is modified from fairseq's `utils.make_positions`.
+
+ Args:
+ input_ids: torch.Tensor
+ padding_idx: int
+ past_key_values_length: int, *optional*, defaults to 0
+
+ Returns: torch.Tensor
+ """
+ # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
+ mask = input_ids.ne(padding_idx).int()
+ incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask
+ return incremental_indices.long() + padding_idx
+
+
+@auto_docstring
+class BridgeTowerPreTrainedModel(PreTrainedModel):
+ config: BridgeTowerConfig
+ base_model_prefix = "bridgetower"
+ supports_gradient_checkpointing = False
+ _no_split_modules = ["BridgeTowerSelfAttention", "BridgeTowerResidualAttention"]
+ _skip_keys_device_placement = "past_key_values"
+
+ def _init_weights(self, module: nn.Module):
+ std = self.config.initializer_factor
+ if isinstance(module, BridgeTowerVisionTransformer):
+ proj_std = (self.config.hidden_size**-0.5) * ((2 * self.config.num_hidden_layers) ** -0.5)
+ attn_std = self.config.hidden_size**-0.5
+ fc_std = (2 * self.config.hidden_size) ** -0.5
+ for block in module.transformer.resblocks:
+ nn.init.normal_(block.attn.in_proj_weight, std=attn_std * std)
+ block.attn.in_proj_bias.data.zero_()
+ nn.init.normal_(block.attn.out_proj.weight, std=proj_std * std)
+ nn.init.normal_(block.mlp.c_fc.weight, std=fc_std * std)
+ nn.init.normal_(block.mlp.c_proj.weight, std=proj_std * std)
+
+ nn.init.normal_(module.embeddings.class_embedding, std=attn_std * std)
+ nn.init.normal_(module.embeddings.position_embedding.weight, std=attn_std * std)
+ elif isinstance(module, (nn.Linear, nn.Conv2d, nn.Embedding)):
+ module.weight.data.normal_(mean=0.0, std=0.05 * std)
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+ elif isinstance(module, BridgeTowerForContrastiveLearning):
+ module.logit_scale.data.fill_(self.config.logit_scale_init_value)
+
+ if isinstance(module, (nn.Linear, BridgeTowerMLMHead)) and module.bias is not None:
+ module.bias.data.zero_()
+
+
+class BridgeTowerVisionModel(BridgeTowerPreTrainedModel):
+ config: BridgeTowerVisionConfig
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.visual = BridgeTowerVisionTransformer(config)
+
+ @property
+ def dtype(self):
+ return self.visual.embeddings.patch_embedding.weight.dtype
+
+ def forward(self, image, image_mask=None, interpolate_pos_encoding=False):
+ return self.visual(image.type(self.dtype), image_mask, interpolate_pos_encoding)
+
+
+@auto_docstring(
+ custom_intro="""
+ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
+ cross-attention is added between the self-attention layers, following the architecture described in *Attention is
+ all you need*_ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz
+ Kaiser and Illia Polosukhin.
+
+ To behave as a decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
+ to `True`. To be used in a Seq2Seq model, the model needs to be initialized with both `is_decoder` argument and
+ `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
+
+ .. _*Attention is all you need*: https://huggingface.co/papers/1706.03762
+ """
+)
+class BridgeTowerTextModel(BridgeTowerPreTrainedModel):
+ config: BridgeTowerTextConfig
+
+ def __init__(self, config, add_pooling_layer=True):
+ r"""
+ add_pooling_layer (bool, *optional*, defaults to `True`):
+ Whether to add a pooling layer
+ """
+ super().__init__(config)
+ self.config = config
+
+ self.embeddings = BridgeTowerTextEmbeddings(config)
+ self.encoder = BridgeTowerTextEncoder(config)
+
+ self.pooler = BridgeTowerPooler(config) if add_pooling_layer else None
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embeddings.word_embeddings
+
+ def set_input_embeddings(self, value):
+ self.embeddings.word_embeddings = value
+
+ def _prune_heads(self, heads_to_prune):
+ """
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
+ class PreTrainedModel
+ """
+ for layer, heads in heads_to_prune.items():
+ self.encoder.layer[layer].attention.prune_heads(heads)
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[Cache] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.Tensor] = None,
+ ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if self.config.is_decoder:
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ else:
+ use_cache = False
+
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
+ input_shape = input_ids.size()
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ batch_size, seq_length = input_shape
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+
+ past_key_values_length = 0
+ if past_key_values is not None:
+ past_key_values_length = (
+ past_key_values[0][0].shape[-2]
+ if not isinstance(past_key_values, Cache)
+ else past_key_values.get_seq_length()
+ )
+
+ if attention_mask is None:
+ attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
+
+ if token_type_ids is None:
+ if hasattr(self.embeddings, "token_type_ids"):
+ buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
+ token_type_ids = buffered_token_type_ids_expanded
+ else:
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
+
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+ # ourselves in which case we just need to make it broadcastable to all heads.
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
+
+ # If a 2D or 3D attention mask is provided for the cross-attention
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
+ if self.config.is_decoder and encoder_hidden_states is not None:
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
+ if encoder_attention_mask is None:
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+ else:
+ encoder_extended_attention_mask = None
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+ embedding_output = self.embeddings(
+ input_ids=input_ids,
+ position_ids=position_ids,
+ token_type_ids=token_type_ids,
+ inputs_embeds=inputs_embeds,
+ past_key_values_length=past_key_values_length,
+ )
+ encoder_outputs = self.encoder(
+ embedding_output,
+ attention_mask=extended_attention_mask,
+ head_mask=head_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_extended_attention_mask,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ cache_position=cache_position,
+ )
+ sequence_output = encoder_outputs[0]
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
+
+ if not return_dict:
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
+
+ return BaseModelOutputWithPoolingAndCrossAttentions(
+ last_hidden_state=sequence_output,
+ pooler_output=pooled_output,
+ past_key_values=encoder_outputs.past_key_values,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ cross_attentions=encoder_outputs.cross_attentions,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ The bare BridgeTower Model transformer outputting BridgeTowerModelOutput object without any specific head on top.
+ """
+)
+class BridgeTowerModel(BridgeTowerPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.config = config
+ vision_config = config.vision_config
+ text_config = config.text_config
+
+ if config.share_cross_modal_transformer_layers:
+ self.cross_modal_text_transform = nn.Linear(text_config.hidden_size, config.hidden_size)
+ self.cross_modal_image_transform = nn.Linear(vision_config.hidden_size, config.hidden_size)
+ else:
+ self.cross_modal_text_transform = nn.ModuleList(
+ [nn.Linear(text_config.hidden_size, config.hidden_size) for _ in range(config.num_hidden_layers)]
+ )
+ self.cross_modal_image_transform = nn.ModuleList(
+ [nn.Linear(vision_config.hidden_size, config.hidden_size) for _ in range(config.num_hidden_layers)]
+ )
+
+ self.token_type_embeddings = nn.Embedding(2, config.hidden_size)
+
+ self.vision_model = BridgeTowerVisionModel(vision_config)
+
+ self.text_model = BridgeTowerTextModel(text_config)
+
+ if not vision_config.share_layernorm and config.init_layernorm_from_vision_encoder:
+ for ln in self.vision_model.visual.cross_modal_ln_separate:
+ ln.weight.data = self.vision_model.visual.ln_post.weight.data
+ ln.bias.data = self.vision_model.visual.ln_post.bias.data
+
+ self.cross_modal_image_layers = nn.ModuleList(
+ [BridgeTowerBertCrossLayer(text_config, layer_idx=i) for i in range(config.num_hidden_layers)]
+ )
+ self.cross_modal_text_layers = nn.ModuleList(
+ [BridgeTowerBertCrossLayer(text_config, layer_idx=i) for i in range(config.num_hidden_layers)]
+ )
+
+ # Class token => Linear => Tanh
+ self.cross_modal_image_pooler = BridgeTowerPooler(config)
+ self.cross_modal_text_pooler = BridgeTowerPooler(config)
+
+ # Initialize BridgeTower Components
+ self.cross_modal_text_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.cross_modal_image_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+ if config.share_link_tower_layers:
+ self.cross_modal_text_link_tower = BridgeTowerLinkTower(config)
+ self.cross_modal_image_link_tower = BridgeTowerLinkTower(config)
+ else:
+ self.cross_modal_text_link_tower = nn.ModuleList(
+ [BridgeTowerLinkTower(config) for _ in range(config.num_hidden_layers - 1)]
+ )
+ self.cross_modal_image_link_tower = nn.ModuleList(
+ [BridgeTowerLinkTower(config) for _ in range(config.num_hidden_layers - 1)]
+ )
+
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.text_model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.text_model.set_input_embeddings(value)
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ pixel_mask: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ image_embeds: Optional[torch.FloatTensor] = None,
+ image_token_type_idx: Optional[int] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ labels: Optional[torch.LongTensor] = None,
+ interpolate_pos_encoding: bool = False,
+ ) -> Union[tuple[torch.Tensor], BridgeTowerModelOutput]:
+ r"""
+ image_embeds (`torch.FloatTensor` of shape `(batch_size, num_patches, hidden_size)`, *optional*):
+ Optionally, instead of passing `pixel_values`, you can choose to directly pass an embedded representation.
+ This is useful if you want more control over how to convert `pixel_values` into patch embeddings.
+ image_token_type_idx (`int`, *optional*):
+ - The token type ids for images.
+ output_hidden_states (`bool`, *optional*):
+ If set to `True`, hidden states are returned as a tuple `(hidden_states_text, hidden_states_image,
+ hidden_states_cross_modal)` containing the hidden states of the text, image, and cross-modal components
+ respectively. `hidden_states_text` and `hidden_states_image` are lists of tensors with the unimodal hidden
+ states, and `hidden_states_cross_modal` is a list of tuples containing `cross_modal_text_hidden_states` and
+ `cross_modal_image_hidden_states` of each bridge layer.
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels are currently not supported.
+
+ Examples:
+
+ ```python
+ >>> from transformers import BridgeTowerProcessor, BridgeTowerModel
+ >>> from PIL import Image
+ >>> import requests
+
+ >>> # prepare image and text
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+ >>> text = "hello world"
+ >>> processor = BridgeTowerProcessor.from_pretrained("BridgeTower/bridgetower-base")
+ >>> model = BridgeTowerModel.from_pretrained("BridgeTower/bridgetower-base")
+
+ >>> inputs = processor(image, text, return_tensors="pt")
+ >>> outputs = model(**inputs)
+ >>> outputs.keys()
+ odict_keys(['text_features', 'image_features', 'pooler_output'])
+ ```"""
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ all_hidden_states_text = () if output_hidden_states else None
+ all_hidden_states_image = () if output_hidden_states else None
+ all_hidden_states_cross = () if output_hidden_states else None
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attentions = () if output_attentions else None
+
+ if inputs_embeds is not None and input_ids is None:
+ raise NotImplementedError(
+ "BridgeTowerModel does not use `inputs_embeds`. Make sure to pass in `input_ids` instead."
+ )
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ image_token_type_idx = image_token_type_idx if image_token_type_idx else 1
+ input_shape = input_ids.size()
+ text_embeds = self.text_model.embeddings(input_ids=input_ids)
+
+ if output_hidden_states:
+ all_hidden_states_text += (text_embeds,)
+
+ if attention_mask is None:
+ attention_mask = torch.ones(input_shape, dtype=torch.long, device=input_ids.device)
+ extend_text_masks = self.text_model.get_extended_attention_mask(attention_mask, input_shape).to(
+ input_ids.device
+ )
+
+ # The split_index determines how many layers of the uni-modal encoder are applied before the cross-modal encoder
+ split_index = len(self.text_model.encoder.layer) - self.config.num_hidden_layers + 1
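+ # Illustrative arithmetic (assuming the base checkpoint's configuration): a 12-layer text encoder with
+ # config.num_hidden_layers = 6 cross-modal layers gives split_index = 12 - 6 + 1 = 7, so the first 7 text
+ # layers run uni-modally and the remaining layers are interleaved with the cross-modal encoder below.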
+
+ # Run the first 'split_index' layers of the textual encoder
+ for layer in self.text_model.encoder.layer[:split_index]:
+ text_embeds = layer(text_embeds, extend_text_masks)[0]
+
+ if output_hidden_states:
+ all_hidden_states_text += (text_embeds,)
+
+ if image_embeds is None:
+ image_embeds = self.vision_model.visual.forward_pre(
+ pixel_values.type(self.vision_model.dtype), interpolate_pos_encoding=interpolate_pos_encoding
+ )
+ else:
+ # Permute as BridgeTowerResidualAttention has batch_first=True
+ image_embeds = image_embeds.permute(1, 0, 2)
+
+ if output_hidden_states:
+ all_hidden_states_image += (image_embeds,)
+
+ # Run the first 'split_index' layers of the visual encoder
+ for block in self.vision_model.visual.transformer.resblocks[:split_index]:
+ image_embeds = block(image_embeds)
+ if output_hidden_states:
+ all_hidden_states_image += (image_embeds,)
+
+ image_embeds_with_ln = self.vision_model.visual.forward_post(image_embeds.type(self.vision_model.dtype))
+
+ # first layer is a special case because we don't have the output from the cross-encoder yet
+ cross_modal_text = self.cross_modal_text_transform(text_embeds)
+
+ text_token_type_embeddings = self.token_type_embeddings(
+ torch.zeros(1, dtype=torch.long, device=input_ids.device)
+ ).expand_as(cross_modal_text)
+
+ cross_modal_text = self.cross_modal_text_layernorm(cross_modal_text + text_token_type_embeddings)
+
+ image_embeds_with_ln = self.cross_modal_image_transform(image_embeds_with_ln)
+ image_token_type_embeddings = self.token_type_embeddings(
+ torch.full((1,), image_token_type_idx, dtype=torch.long, device=input_ids.device)
+ ).expand_as(image_embeds_with_ln)
+
+ image_embeds_with_ln = image_embeds_with_ln + image_token_type_embeddings
+ cross_modal_image = self.cross_modal_image_layernorm(image_embeds_with_ln)
+
+ pixel_mask = torch.ones(
+ (cross_modal_image.size(0), cross_modal_image.size(1)),
+ dtype=torch.long,
+ device=input_ids.device,
+ )
+ extend_image_masks = self.text_model.get_extended_attention_mask(pixel_mask, pixel_mask.size()).to(
+ input_ids.device
+ )
+
+ layer_outputs_text = self.cross_modal_text_layers[0](
+ cross_modal_text,
+ cross_modal_image,
+ attention_mask=extend_text_masks,
+ encoder_attention_mask=extend_image_masks,
+ output_attentions=output_attentions,
+ )
+ cross_text_features = layer_outputs_text[0]
+
+ layer_outputs_image = self.cross_modal_image_layers[0](
+ cross_modal_image,
+ cross_modal_text,
+ attention_mask=extend_image_masks,
+ encoder_attention_mask=extend_text_masks,
+ output_attentions=output_attentions,
+ )
+ cross_image_features = layer_outputs_image[0]
+
+ if output_hidden_states:
+ all_hidden_states_cross += ((cross_text_features, cross_image_features),)
+
+ if output_attentions:
+ all_self_attentions += ((layer_outputs_text[1], layer_outputs_image[1]),)
+
+ link_layer_index = 0
+
+ # Each of the top 6 layers of the visual and textual encoders ([split_index:]) is connected to each layer of
+ # the cross-modal encoder via bridge layers, which brings bottom-up alignment and fusion to the cross-modal encoder.
+ for i in range(split_index, len(self.text_model.encoder.layer)):
+ text_embeds = self.text_model.encoder.layer[i](text_embeds, extend_text_masks)[0]
+ image_embeds = self.vision_model.visual.transformer.resblocks[i](image_embeds).type(
+ self.vision_model.dtype
+ )
+ image_embeds_with_ln = (
+ self.cross_modal_image_transform(self.vision_model.visual.forward_post(image_embeds))
+ + image_token_type_embeddings
+ )
+
+ text_link_tower = self.cross_modal_text_link_tower[link_layer_index]
+ image_link_tower = self.cross_modal_image_link_tower[link_layer_index]
+
+ # Bridge layers for textual and visual encoders
+ cross_text_features_ = text_link_tower(
+ self.cross_modal_text_transform(text_embeds) + text_token_type_embeddings,
+ cross_text_features,
+ extend_text_masks,
+ )
+ cross_image_features_ = image_link_tower(image_embeds_with_ln, cross_image_features, extend_image_masks)
+
+ # Cross-modal encoder via bridge layers of textual and visual encoders
+ layer_outputs_text = self.cross_modal_text_layers[link_layer_index + 1](
+ cross_text_features_,
+ cross_image_features_,
+ attention_mask=extend_text_masks,
+ encoder_attention_mask=extend_image_masks,
+ output_attentions=output_attentions,
+ )
+ cross_text_features = layer_outputs_text[0]
+
+ layer_outputs_image = self.cross_modal_image_layers[link_layer_index + 1](
+ cross_image_features_,
+ cross_text_features_,
+ attention_mask=extend_image_masks,
+ encoder_attention_mask=extend_text_masks,
+ output_attentions=output_attentions,
+ )
+ cross_image_features = layer_outputs_image[0]
+
+ link_layer_index += 1
+
+ if output_hidden_states:
+ all_hidden_states_text += (text_embeds,)
+ all_hidden_states_image += (image_embeds,)
+ all_hidden_states_cross += ((cross_text_features, cross_image_features),)
+
+ if output_attentions:
+ all_self_attentions += ((layer_outputs_text[1], layer_outputs_image[1]),)
+
+ # Concatenate the cls token of the text and image features to get the final representation
+ text_features, image_features = cross_text_features, cross_image_features
+ cls_features = self.get_cls_features(text_features, image_features)
+
+ if output_hidden_states:
+ all_hidden_states = (all_hidden_states_text, all_hidden_states_image, all_hidden_states_cross)
+
+ if not return_dict:
+ return tuple(
+ v
+ for v in [text_features, image_features, cls_features, all_hidden_states, all_self_attentions]
+ if v is not None
+ )
+
+ return BridgeTowerModelOutput(
+ text_features=text_features,
+ image_features=image_features,
+ pooler_output=cls_features,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ )
+
+ def get_cls_features(self, text_features, image_features):
+ cls_features_text = self.cross_modal_text_pooler(text_features)
+ cls_features_image = self.cross_modal_image_pooler(image_features)
+ return torch.cat([cls_features_text, cls_features_image], dim=-1)
+
+
+# Copied from transformers.models.vilt.modeling_vilt.ViltPredictionHeadTransform with Vilt->BridgeTower
+class BridgeTowerPredictionHeadTransform(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ if isinstance(config.hidden_act, str):
+ self.transform_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.transform_act_fn = config.hidden_act
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+ def forward(self, hidden_states):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.transform_act_fn(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states)
+ return hidden_states
+
+
+class BridgeTowerMLMHead(nn.Module):
+ def __init__(self, config, weight=None):
+ super().__init__()
+ self.config = config
+ self.transform = BridgeTowerPredictionHeadTransform(config)
+ self.decoder = nn.Linear(config.hidden_size, config.text_config.vocab_size, bias=False)
+ self.bias = nn.Parameter(torch.zeros(config.text_config.vocab_size))
+ if weight is not None:
+ self.decoder.weight = weight
+
+ def forward(self, x):
+ mlm_score = self.transform(x)
+ mlm_score = self.decoder(mlm_score) + self.bias
+ return mlm_score
+
+
+class BridgeTowerITMHead(nn.Module):
+ def __init__(self, hidden_size):
+ super().__init__()
+ self.fc = nn.Linear(hidden_size, 2)
+
+ def forward(self, x):
+ itm_score = self.fc(x)
+ return itm_score
+
+
+@auto_docstring(
+ custom_intro="""
+ BridgeTower Model with a language modeling head on top as done during pretraining.
+ """
+)
+class BridgeTowerForMaskedLM(BridgeTowerPreTrainedModel):
+ _tied_weights_keys = ["mlm_score.decoder.weight"]
+
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.bridgetower = BridgeTowerModel(config)
+ self.mlm_score = BridgeTowerMLMHead(config)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_output_embeddings(self):
+ return self.mlm_score.decoder
+
+ def set_output_embeddings(self, new_embeddings):
+ self.mlm_score.decoder = new_embeddings
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ pixel_mask: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ image_embeds: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ labels: Optional[torch.LongTensor] = None,
+ ) -> Union[MaskedLMOutput, tuple[torch.FloatTensor]]:
+ r"""
+ image_embeds (`torch.FloatTensor` of shape `(batch_size, num_patches, hidden_size)`, *optional*):
+ Optionally, instead of passing `pixel_values`, you can choose to directly pass an embedded representation.
+ This is useful if you want more control over how to convert `pixel_values` into patch embeddings.
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
+ config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
+ loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
+
+ Examples:
+
+ ```python
+ >>> from transformers import BridgeTowerProcessor, BridgeTowerForMaskedLM
+ >>> from PIL import Image
+ >>> import requests
+
+ >>> url = "http://images.cocodataset.org/val2017/000000360943.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
+ >>> text = "a <mask> looking out of the window"
+
+ >>> processor = BridgeTowerProcessor.from_pretrained("BridgeTower/bridgetower-base-itm-mlm")
+ >>> model = BridgeTowerForMaskedLM.from_pretrained("BridgeTower/bridgetower-base-itm-mlm")
+
+ >>> # prepare inputs
+ >>> encoding = processor(image, text, return_tensors="pt")
+
+ >>> # forward pass
+ >>> outputs = model(**encoding)
+
+ >>> results = processor.decode(outputs.logits.argmax(dim=-1).squeeze(0).tolist())
+
+ >>> print(results)
+ .a cat looking out of the window.
+ ```"""
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ outputs = self.bridgetower(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ pixel_values=pixel_values,
+ pixel_mask=pixel_mask,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ image_embeds=image_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ mlm_logits = self.mlm_score(outputs.text_features if return_dict else outputs[0])
+ masked_lm_loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss() # -100 index = padding token
+
+ labels = labels.to(mlm_logits.device)
+ masked_lm_loss = loss_fct(mlm_logits.view(-1, self.config.text_config.vocab_size), labels.view(-1))
+
+ if not return_dict:
+ output = tuple(mlm_logits)
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
+
+ return MaskedLMOutput(
+ loss=masked_lm_loss,
+ logits=mlm_logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ BridgeTower Model transformer with a classifier head on top (a linear layer on top of the final hidden state of the
+ [CLS] token) for image-to-text matching.
+ """
+)
+class BridgeTowerForImageAndTextRetrieval(BridgeTowerPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.bridgetower = BridgeTowerModel(config)
+
+ self.itm_score = BridgeTowerITMHead(config.hidden_size * 2)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ pixel_mask: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ image_embeds: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ labels: Optional[torch.LongTensor] = None,
+ ) -> Union[SequenceClassifierOutput, tuple[torch.FloatTensor]]:
+ r"""
+ image_embeds (`torch.FloatTensor` of shape `(batch_size, num_patches, hidden_size)`, *optional*):
+ Optionally, instead of passing `pixel_values`, you can choose to directly pass an embedded representation.
+ This is useful if you want more control over how to convert `pixel_values` into patch embeddings.
+ labels (`torch.LongTensor` of shape `(batch_size, 1)`, *optional*):
+ Labels for computing the image-text matching loss. 0 means the pairs don't match and 1 means they match.
+ The pairs with 0 will be skipped for calculation.
+
+ Examples:
+
+ ```python
+ >>> from transformers import BridgeTowerProcessor, BridgeTowerForImageAndTextRetrieval
+ >>> import requests
+ >>> from PIL import Image
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+ >>> texts = ["An image of two cats chilling on a couch", "A football player scoring a goal"]
+
+ >>> processor = BridgeTowerProcessor.from_pretrained("BridgeTower/bridgetower-base-itm-mlm")
+ >>> model = BridgeTowerForImageAndTextRetrieval.from_pretrained("BridgeTower/bridgetower-base-itm-mlm")
+
+ >>> # forward pass
+ >>> scores = dict()
+ >>> for text in texts:
+ ... # prepare inputs
+ ... encoding = processor(image, text, return_tensors="pt")
+ ... outputs = model(**encoding)
+ ... scores[text] = outputs.logits[0, 1].item()
+ ```"""
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.bridgetower(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ pixel_values=pixel_values,
+ pixel_mask=pixel_mask,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ image_embeds=image_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ pooler_output = outputs.pooler_output if return_dict else outputs[2]
+
+ logits = self.itm_score(pooler_output)
+
+ itm_loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+
+ labels = labels.to(logits.device)
+ itm_loss = loss_fct(logits, labels)
+
+ if not return_dict:
+ output = tuple(logits)
+ return ((itm_loss,) + output) if itm_loss is not None else output
+
+ return SequenceClassifierOutput(
+ loss=itm_loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+class BridgeTowerContrastiveHead(nn.Module):
+ def __init__(self, hidden_size, embed_size):
+ super().__init__()
+ self.fc = nn.Linear(hidden_size, embed_size)
+
+ def forward(self, x):
+ x = self.fc(x)
+ return x
+
+
+@auto_docstring(
+ custom_intro="""
+ BridgeTower Model with an image-text contrastive head on top computing image-text contrastive loss.
+ """
+)
+class BridgeTowerForContrastiveLearning(BridgeTowerPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.bridgetower = BridgeTowerModel(config)
+
+ self.itc_text_head = BridgeTowerContrastiveHead(config.hidden_size, config.contrastive_hidden_size)
+ self.itc_image_head = BridgeTowerContrastiveHead(config.hidden_size, config.contrastive_hidden_size)
+ self.itc_cross_modal_head = BridgeTowerContrastiveHead(config.hidden_size * 2, config.contrastive_hidden_size)
+
+ self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value))
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ pixel_mask: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ image_embeds: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = True,
+ return_dict: Optional[bool] = None,
+ return_loss: Optional[bool] = None,
+ ) -> Union[BridgeTowerContrastiveOutput, tuple[torch.FloatTensor]]:
+ r"""
+ image_embeds (`torch.FloatTensor` of shape `(batch_size, num_patches, hidden_size)`, *optional*):
+ Optionally, instead of passing `pixel_values`, you can choose to directly pass an embedded representation.
+ This is useful if you want more control over how to convert `pixel_values` into patch embeddings.
+ return_loss (`bool`, *optional*):
+ Whether or not to return the contrastive loss.
+
+ Examples:
+
+ ```python
+ >>> from transformers import BridgeTowerProcessor, BridgeTowerForContrastiveLearning
+ >>> import requests
+ >>> from PIL import Image
+ >>> import torch
+
+ >>> image_urls = [
+ ... "https://farm4.staticflickr.com/3395/3428278415_81c3e27f15_z.jpg",
+ ... "http://images.cocodataset.org/val2017/000000039769.jpg",
+ ... ]
+ >>> texts = ["two dogs in a car", "two cats sleeping on a couch"]
+ >>> images = [Image.open(requests.get(url, stream=True).raw) for url in image_urls]
+
+ >>> processor = BridgeTowerProcessor.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-itc")
+ >>> model = BridgeTowerForContrastiveLearning.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-itc")
+
+ >>> inputs = processor(images, texts, padding=True, return_tensors="pt")
+ >>> loss = model(**inputs, return_loss=True).loss
+
+ >>> inputs = processor(images, texts[::-1], padding=True, return_tensors="pt")
+ >>> loss_swapped = model(**inputs, return_loss=True).loss
+
+ >>> print("Loss", round(loss.item(), 4))
+ Loss 0.0019
+
+ >>> print("Loss with swapped images", round(loss_swapped.item(), 4))
+ Loss with swapped images 2.126
+ ```"""
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.bridgetower(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ pixel_values=pixel_values,
+ pixel_mask=pixel_mask,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ image_embeds=image_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=True,
+ return_dict=return_dict,
+ )
+
+ pooler_output = outputs.pooler_output if return_dict else outputs[2]
+ hidden_states_txt, hidden_states_img, hidden_states_cross_modal = (
+ outputs.hidden_states if return_dict else outputs[3]
+ )
+
+ text_embeds = hidden_states_txt[-1]
+ image_embeds = hidden_states_img[-1]
+
+ image_embeds_with_ln = self.bridgetower.vision_model.visual.forward_post(image_embeds)
+ image_token_type_embeddings = self.bridgetower.token_type_embeddings(
+ torch.full((1,), 1, dtype=torch.long, device=self.bridgetower.token_type_embeddings.weight.device)
+ ).expand_as(image_embeds_with_ln)
+
+ image_embeds = self.bridgetower.cross_modal_image_transform(image_embeds_with_ln) + image_token_type_embeddings
+
+ # normalized features
+ text_embeds = nn.functional.normalize(self.itc_text_head(text_embeds[:, 0, :]), dim=-1, p=2)
+ image_embeds = nn.functional.normalize(self.itc_image_head(image_embeds[:, 0, :]), dim=-1, p=2).to(
+ device=text_embeds.device
+ )
+ cross_embeds = nn.functional.normalize(self.itc_cross_modal_head(pooler_output), dim=-1, p=2).to(
+ device=text_embeds.device
+ )
+
+ logits = torch.stack([text_embeds, image_embeds, cross_embeds], dim=-2)
+
+ logit_scale = self.logit_scale.exp().to(device=text_embeds.device)
+ logits_text_to_image = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
+ logits_text_to_cross = torch.matmul(text_embeds, cross_embeds.t()) * logit_scale
+ logits_image_to_cross = torch.matmul(image_embeds, cross_embeds.t()) * logit_scale
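+ # Each of the three similarity matrices above is [batch_size, batch_size] with matching pairs on the
+ # diagonal, so the contrastive targets below are simply torch.arange(batch_size).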
+
+ itc_loss = None
+
+ if return_loss:
+ labels = torch.arange(len(logits), device=logits.device)
+ text_to_image_loss = nn.functional.cross_entropy(logits_text_to_image, labels)
+ text_to_cross_loss = nn.functional.cross_entropy(logits_text_to_cross, labels)
+ image_to_cross_loss = nn.functional.cross_entropy(logits_image_to_cross, labels)
+ itc_loss = (text_to_image_loss + text_to_cross_loss + image_to_cross_loss) / 3.0
+
+ if not return_dict:
+ output = (logits, text_embeds, image_embeds, cross_embeds) + outputs[3:]
+ return ((itc_loss,) + output) if itc_loss is not None else output
+
+ return BridgeTowerContrastiveOutput(
+ loss=itc_loss,
+ logits=logits,
+ text_embeds=text_embeds,
+ image_embeds=image_embeds,
+ cross_embeds=cross_embeds,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+__all__ = [
+ "BridgeTowerForContrastiveLearning",
+ "BridgeTowerForImageAndTextRetrieval",
+ "BridgeTowerForMaskedLM",
+ "BridgeTowerModel",
+ "BridgeTowerPreTrainedModel",
+]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/bridgetower/processing_bridgetower.py b/venv/lib/python3.13/site-packages/transformers/models/bridgetower/processing_bridgetower.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d7059c4c5a5dfef87ca20117d60a4b6e9fb5f72
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/bridgetower/processing_bridgetower.py
@@ -0,0 +1,73 @@
+# coding=utf-8
+# Copyright 2023 The Intel Labs Team Authors, The Microsoft Research Team Authors and HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Processor class for BridgeTower.
+"""
+
+from typing import Optional
+
+from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin
+
+
+class BridgeTowerImagesKwargs(ImagesKwargs):
+ size_divisor: Optional[int]
+
+
+class BridgeTowerProcessorKwargs(ProcessingKwargs, total=False):
+ images_kwargs: BridgeTowerImagesKwargs
+ _defaults = {
+ "text_kwargs": {
+ "add_special_tokens": True,
+ "padding": False,
+ "stride": 0,
+ "return_overflowing_tokens": False,
+ "return_special_tokens_mask": False,
+ "return_offsets_mapping": False,
+ "return_length": False,
+ "verbose": True,
+ },
+ "images_kwargs": {
+ "do_normalize": True,
+ "do_center_crop": True,
+ },
+ }
+
+
+class BridgeTowerProcessor(ProcessorMixin):
+ r"""
+ Constructs a BridgeTower processor which wraps a Roberta tokenizer and BridgeTower image processor into a single
+ processor.
+
+ [`BridgeTowerProcessor`] offers all the functionalities of [`BridgeTowerImageProcessor`] and
+ [`RobertaTokenizerFast`]. See the docstring of [`~BridgeTowerProcessor.__call__`] and
+ [`~BridgeTowerProcessor.decode`] for more information.
+
+ Args:
+ image_processor (`BridgeTowerImageProcessor`):
+ An instance of [`BridgeTowerImageProcessor`]. The image processor is a required input.
+ tokenizer (`RobertaTokenizerFast`):
+ An instance of [`RobertaTokenizerFast`]. The tokenizer is a required input.
+ """
+
+ attributes = ["image_processor", "tokenizer"]
+ image_processor_class = "BridgeTowerImageProcessor"
+ tokenizer_class = ("RobertaTokenizer", "RobertaTokenizerFast")
+ valid_processor_kwargs = BridgeTowerProcessorKwargs
+
+ def __init__(self, image_processor, tokenizer):
+ super().__init__(image_processor, tokenizer)
+
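+ # Minimal usage sketch (mirrors the examples in the model docstrings; the checkpoint name is the public
+ # base checkpoint, other arguments are illustrative):
+ #
+ # processor = BridgeTowerProcessor.from_pretrained("BridgeTower/bridgetower-base")
+ # encoding = processor(image, "hello world", return_tensors="pt")
+ # # `encoding` combines the tokenizer outputs (input_ids, attention_mask) with the image processor
+ # # outputs (pixel_values, ...), ready to be passed to a BridgeTower model via `model(**encoding)`.
+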
+
+__all__ = ["BridgeTowerProcessor"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/byt5/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/byt5/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb726942b0f16105f8a5a5f7c661485951d7ccc7
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/byt5/__init__.py
@@ -0,0 +1,26 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .tokenization_byt5 import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/venv/lib/python3.13/site-packages/transformers/models/byt5/tokenization_byt5.py b/venv/lib/python3.13/site-packages/transformers/models/byt5/tokenization_byt5.py
new file mode 100644
index 0000000000000000000000000000000000000000..2a9804db1014a963fb2054083f2db2782a41016d
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/byt5/tokenization_byt5.py
@@ -0,0 +1,236 @@
+# coding=utf-8
+# Copyright 2021 T5 Authors and HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenization class for model ByT5."""
+
+import warnings
+from typing import Optional
+
+from ...tokenization_utils import AddedToken, PreTrainedTokenizer
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class ByT5Tokenizer(PreTrainedTokenizer):
+ """
+ Construct a ByT5 tokenizer. ByT5 simply uses raw bytes utf-8 encoding.
+
+ This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
+ this superclass for more information regarding those methods.
+
+ Args:
+ eos_token (`str`, *optional*, defaults to `"</s>"`):
+ The end of sequence token.
+
+
+
+ When building a sequence using special tokens, this is not the token that is used for the end of sequence.
+ The token used is the `sep_token`.
+
+
+
+ unk_token (`str`, *optional*, defaults to `"<unk>"`):
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+ token instead.
+ pad_token (`str`, *optional*, defaults to `"<pad>"`):
+ The token used for padding, for example when batching sequences of different lengths.
+ extra_ids (`int`, *optional*, defaults to 125):
+ Add a number of extra ids added to the end of the vocabulary for use as sentinels. These tokens are
+ accessible as "<extra_id_{%d}>" where "{%d}" is a number between 0 and extra_ids-1. Extra tokens are
+ indexed from the end of the vocabulary up to beginning ("<extra_id_0>" is the last token in the vocabulary
+ like in ByT5 preprocessing see
+ [here](https://github.com/google-research/text-to-text-transfer-transformer/blob/9fd7b14a769417be33bc6c850f9598764913c833/t5/data/preprocessors.py#L2117)).
+ additional_special_tokens (`list[str]`, *optional*):
+ Additional special tokens used by the tokenizer.
+ """
+
+ model_input_names = ["input_ids", "attention_mask"]
+
+ def __init__(
+ self,
+ eos_token="",
+ unk_token="",
+ pad_token="",
+ extra_ids=125,
+ additional_special_tokens=None,
+ **kwargs,
+ ) -> None:
+ # Add extra_ids to the special token list
+ if extra_ids > 0 and additional_special_tokens is None:
+ additional_special_tokens = [f"" for i in range(extra_ids)]
+ elif extra_ids > 0 and additional_special_tokens is not None and len(additional_special_tokens) > 0:
+ # Check that we have the right number of extra_id special tokens
+ extra_tokens = len(set(filter(lambda x: bool("extra_id" in str(x)), additional_special_tokens)))
+ if extra_tokens != extra_ids:
+ raise ValueError(
+ f"Both extra_ids ({extra_ids}) and additional_special_tokens ({additional_special_tokens}) are"
+ " provided to ByT5Tokenizer. In this case the additional_special_tokens must include the"
+ " extra_ids tokens"
+ )
+
+ pad_token = AddedToken(pad_token, lstrip=True, rstrip=True) if isinstance(pad_token, str) else pad_token
+ # We force left and right stripping for backward compatibility. The ByT5 tests depend on this.
+ eos_token = AddedToken(eos_token, lstrip=True, rstrip=True) if isinstance(eos_token, str) else eos_token
+ unk_token = AddedToken(unk_token, lstrip=True, rstrip=True) if isinstance(unk_token, str) else unk_token
+ # unk token needs to be in the vocab with correct index
+ self._added_tokens_decoder = {0: pad_token, 1: eos_token, 2: unk_token}
+ self.offset = len(self._added_tokens_decoder)
+ self._utf_vocab_size = 2**8 # utf is 8 bits
+ super().__init__(
+ eos_token=eos_token,
+ unk_token=unk_token,
+ pad_token=pad_token,
+ extra_ids=0,
+ additional_special_tokens=additional_special_tokens,  # TODO: extra ids are not used :sweat_smile:
+ **kwargs,
+ )
+
+ @property
+ def vocab_size(self):
+ return self._utf_vocab_size
+
+ def get_vocab(self):
+ vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size + self.offset)}
+ vocab.update(self.added_tokens_encoder)
+ return vocab
+
+ def get_special_tokens_mask(
+ self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None, already_has_special_tokens: bool = False
+ ) -> list[int]:
+ """
+ Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
+ special tokens using the tokenizer `prepare_for_model` method.
+
+ Args:
+ token_ids_0 (`list[int]`):
+ List of IDs.
+ token_ids_1 (`list[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+ already_has_special_tokens (`bool`, *optional*, defaults to `False`):
+ Whether or not the token list is already formatted with special tokens for the model.
+
+ Returns:
+ `list[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
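+
+ Example (a sketch; the input is a byte-level id sequence without special tokens, and the trailing 1 marks
+ the `</s>` that would be appended):
+
+ ```python
+ >>> from transformers import ByT5Tokenizer
+
+ >>> tokenizer = ByT5Tokenizer()
+ >>> tokenizer.get_special_tokens_mask([107, 108])
+ [0, 0, 1]
+ ```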
+ """
+ if already_has_special_tokens:
+ return super().get_special_tokens_mask(
+ token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
+ )
+
+ # normal case: some special tokens
+ if token_ids_1 is None:
+ return ([0] * len(token_ids_0)) + [1]
+ return ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
+
+ def _add_eos_if_not_present(self, token_ids: list[int]) -> list[int]:
+ """Do not add eos again if user already added it."""
+ if len(token_ids) > 0 and token_ids[-1] == self.eos_token_id:
+ warnings.warn(
+ f"This sequence already has {self.eos_token}. In future versions this behavior may lead to duplicated"
+ " eos tokens being added."
+ )
+ return token_ids
+ else:
+ return token_ids + [self.eos_token_id]
+
+ def create_token_type_ids_from_sequences(
+ self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
+ ) -> list[int]:
+ """
+ Create a mask from the two sequences passed to be used in a sequence-pair classification task. ByT5 does not
+ make use of token type ids, therefore a list of zeros is returned.
+
+ Args:
+ token_ids_0 (`list[int]`):
+ List of IDs.
+ token_ids_1 (`list[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+
+ Returns:
+ `list[int]`: List of zeros.
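+
+ Example (a sketch; note the length includes the eos token that would be appended, even though it is not
+ part of the input ids):
+
+ ```python
+ >>> from transformers import ByT5Tokenizer
+
+ >>> tokenizer = ByT5Tokenizer()
+ >>> tokenizer.create_token_type_ids_from_sequences([107, 108])
+ [0, 0, 0]
+ ```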
+ """
+ eos = [self.eos_token_id]
+
+ if token_ids_1 is None:
+ return len(token_ids_0 + eos) * [0]
+ return len(token_ids_0 + eos + token_ids_1 + eos) * [0]
+
+ def build_inputs_with_special_tokens(
+ self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
+ ) -> list[int]:
+ """
+ Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating and
+ adding special tokens. A sequence has the following format:
+
+ - single sequence: `X </s>`
+ - pair of sequences: `A </s> B </s>`
+
+ Args:
+ token_ids_0 (`list[int]`):
+ List of IDs to which the special tokens will be added.
+ token_ids_1 (`list[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+
+ Returns:
+ `list[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
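+
+ Example (a sketch using byte-level ids; with the defaults above, `</s>` has id 1 and is appended to each segment):
+
+ ```python
+ >>> from transformers import ByT5Tokenizer
+
+ >>> tokenizer = ByT5Tokenizer()
+ >>> tokenizer.build_inputs_with_special_tokens([107, 108], [109, 110])
+ [107, 108, 1, 109, 110, 1]
+ ```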
+ """
+ token_ids_0 = self._add_eos_if_not_present(token_ids_0)
+ if token_ids_1 is None:
+ return token_ids_0
+ else:
+ token_ids_1 = self._add_eos_if_not_present(token_ids_1)
+ return token_ids_0 + token_ids_1
+
+ def _tokenize(self, text: str) -> list[str]:
+ """Take as input a string and return a list of strings (tokens) for words/sub-words"""
+ tokens = [chr(i) for i in text.encode("utf-8")]
+ return tokens
+
+ def _convert_token_to_id(self, token):
+ """Converts a token (str) in an id using the vocab."""
+
+ if len(token) != 1:
+ token_id = None
+ else:
+ token_id = ord(token) + self.offset
+
+ return token_id
+
+ def _convert_id_to_token(self, index):
+ """Converts an index (integer) in a token (str) using the vocab."""
+ token = chr(index - self.offset)
+ return token
+
+ def convert_tokens_to_string(self, tokens):
+ """Converts a sequence of tokens (string) in a single string."""
+ bstring = b""
+ for token in tokens:
+ if token in self.added_tokens_decoder:
+ tok_string = self.added_tokens_decoder[token].encode("utf-8")
+ elif token in self.added_tokens_encoder:
+ tok_string = token.encode("utf-8")
+ else:
+ tok_string = bytes([ord(token)])
+ bstring += tok_string
+ string = bstring.decode("utf-8", errors="ignore")
+ return string
+
+ # ByT5Tokenizer has no vocab file
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
+ return ()
+
+
+__all__ = ["ByT5Tokenizer"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/chameleon/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/chameleon/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6ad11a90a24bc4e8c9fd744bca6297e5388fd52e
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/chameleon/__init__.py
@@ -0,0 +1,30 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_chameleon import *
+ from .image_processing_chameleon import *
+ from .image_processing_chameleon_fast import *
+ from .modeling_chameleon import *
+ from .processing_chameleon import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/venv/lib/python3.13/site-packages/transformers/models/chameleon/configuration_chameleon.py b/venv/lib/python3.13/site-packages/transformers/models/chameleon/configuration_chameleon.py
new file mode 100644
index 0000000000000000000000000000000000000000..34436a5288c8187d893d7ca4775b812b4c4d7961
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/chameleon/configuration_chameleon.py
@@ -0,0 +1,282 @@
+# coding=utf-8
+# Copyright 2024 Meta Inc. and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""chameleon model configuration"""
+
+from typing import Optional
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class ChameleonVQVAEConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`ChameleonVQModel`]. It is used to instantiate a
+ `ChameleonVQModel` according to the specified arguments, defining the model architecture. Instantiating a
+ configuration with the defaults will yield a similar configuration to the VQModel of the
+ [meta/chameleon-7B](https://huggingface.co/meta/chameleon-7B).
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ embed_dim (`int`, *optional*, defaults to 256):
+ Dimensionality of each embedding vector.
+ num_embeddings (`int`, *optional*, defaults to 8192):
+ Number of codebook embeddings.
+ double_latent (`bool`, *optional*, defaults to `False`):
+ Whether to use double z channels.
+ latent_channels (`int`, *optional*, defaults to 256):
+ Number of channels for the latent space.
+ resolution (`int`, *optional*, defaults to 512):
+ Resolution of the input images.
+ in_channels (`int`, *optional*, defaults to 3):
+ Number of input channels.
+ base_channels (`int`, *optional*, defaults to 128):
+ Base channel count.
+ channel_multiplier (`list[int]`, *optional*, defaults to `[1, 1, 2, 2, 4]`):
+ Channel multipliers for each resolution.
+ num_res_blocks (`int`, *optional*, defaults to 2):
+ Number of residual blocks.
+ attn_resolutions (`list[int]`, *optional*):
+ Resolutions to apply attention.
+ dropout (`float`, *optional*, defaults to 0.0):
+ Dropout rate.
+ attn_type (`str`, *optional*, defaults to `"vanilla"`):
+ Attention type used in VQ-GAN encoder. Can be "vanilla" or None.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
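+
+ Example (a minimal sketch mirroring the `ChameleonConfig` example below; all values shown are the defaults
+ listed above):
+
+ ```python
+ >>> from transformers import ChameleonVQVAEConfig
+
+ >>> # Initializing a VQ-VAE configuration with default values
+ >>> vq_config = ChameleonVQVAEConfig()
+ >>> vq_config.num_embeddings
+ 8192
+ ```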
+ """
+
+ model_type = "chameleon_vqgan"
+ base_config_key = "vq_config"
+
+ def __init__(
+ self,
+ embed_dim: int = 256,
+ num_embeddings: int = 8192,
+ double_latent: bool = False,
+ latent_channels: int = 256,
+ resolution: int = 512,
+ in_channels: int = 3,
+ base_channels: int = 128,
+ channel_multiplier: list[int] = [1, 1, 2, 2, 4],
+ num_res_blocks: int = 2,
+ attn_resolutions: Optional[list[int]] = None,
+ dropout: float = 0.0,
+ attn_type: str = "vanilla",
+ initializer_range=0.02,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.embed_dim = embed_dim
+ self.num_embeddings = num_embeddings
+ self.double_latent = double_latent
+ self.latent_channels = latent_channels
+ self.resolution = resolution
+ self.in_channels = in_channels
+ self.base_channels = base_channels
+ self.channel_multiplier = channel_multiplier
+ self.num_res_blocks = num_res_blocks
+ self.attn_resolutions = attn_resolutions
+ self.dropout = dropout
+ self.attn_type = attn_type
+ self.initializer_range = initializer_range
+
+
+class ChameleonConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`ChameleonModel`]. It is used to instantiate a
+ chameleon model according to the specified arguments, defining the model architecture. Instantiating a
+ configuration with the defaults will yield a similar configuration to that of the
+ [meta/chameleon-7B](https://huggingface.co/meta/chameleon-7B).
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 65536):
+ Vocabulary size of the chameleon model. Defines the number of different tokens that can be represented by the
+ `inputs_ids` passed when calling [`ChameleonModel`]; this includes text and image tokens.
+ hidden_size (`int`, *optional*, defaults to 4096):
+ Dimension of the hidden representations.
+ intermediate_size (`int`, *optional*, defaults to 11008):
+ Dimension of the MLP representations.
+ num_hidden_layers (`int`, *optional*, defaults to 32):
+ Number of hidden layers in the Transformer decoder.
+ num_attention_heads (`int`, *optional*, defaults to 32):
+ Number of attention heads for each attention layer in the Transformer decoder.
+ num_key_value_heads (`int`, *optional*, defaults to 32):
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+ `num_key_value_heads=1`, the model will use Multi Query Attention (MQA), otherwise GQA is used. When
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+ by meanpooling all the original heads within that group. For more details, check out [this
+ paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
+ `num_attention_heads`.
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+ The non-linear activation function (function or string) in the decoder.
+ max_position_embeddings (`int`, *optional*, defaults to 4096):
+ The maximum sequence length that this model might ever be used with. Chameleon supports up to 4096 tokens.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-05):
+ The epsilon used by the rms normalization layers.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ pad_token_id (`int`, *optional*):
+ Padding token id.
+ bos_token_id (`int`, *optional*, defaults to 1):
+ Beginning of stream token id.
+ eos_token_id (`int`, *optional*, defaults to 2):
+ End of stream token id.
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether to tie the weight embeddings.
+ rope_theta (`float`, *optional*, defaults to 10000.0):
+ The base period of the RoPE embeddings.
+ rope_scaling (`Dict`, *optional*):
+ Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
+ strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
+ `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
+ `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
+ these scaling strategies behave:
+ https://www.reddit.com/r/Localchameleon/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
+ experimental feature, subject to breaking API changes in future versions.
+ attention_bias (`bool`, *optional*, defaults to `False`):
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ model_parallel_size (`int`, *optional*, defaults to 1):
+ Number of shards used when training the model. This will be used in qk layernorm because the original Chameleon inference
+ doesn't do reduction in those layers and each rank has its own biases.
+ swin_norm (`bool`, *optional*, defaults to `False`):
+ Use Swin Transformer normalization.
+ vq_config (`dict`, *optional*):
+ ChameleonVQVAEConfig instance containing the configuration for the VQ-VAE model.
+ vocabulary_map (`dict`, *optional*):
+ A dictionary containing the vocabulary map from the tokenizer. Used to obtain tokens from the image inputs.
+ mlp_bias (`bool`, *optional*, defaults to `False`):
+ Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers.
+
+
+ ```python
+ >>> from transformers import ChameleonModel, ChameleonConfig
+
+ >>> # Initializing a chameleon-7b style configuration
+ >>> configuration = ChameleonConfig()
+
+ >>> # Initializing a model from the chameleon-7b style configuration
+ >>> model = ChameleonModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "chameleon"
+ sub_configs = {"vq_config": ChameleonVQVAEConfig}
+ keys_to_ignore_at_inference = ["past_key_values"]
+
+ def __init__(
+ self,
+ vocab_size=65536,
+ hidden_size=4096,
+ intermediate_size=11008,
+ num_hidden_layers=32,
+ num_attention_heads=32,
+ num_key_value_heads=32,
+ hidden_act="silu",
+ max_position_embeddings=4096,
+ initializer_range=0.02,
+ rms_norm_eps=1e-05,
+ use_cache=True,
+ pad_token_id=None,
+ bos_token_id=1,
+ eos_token_id=2,
+ tie_word_embeddings=False,
+ rope_theta=10000.0,
+ rope_scaling=None,
+ attention_bias=False,
+ attention_dropout=0.0,
+ model_parallel_size=1,
+ swin_norm=False,
+ vq_config=None,
+ vocabulary_map=None,
+ mlp_bias=False,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.mlp_bias = mlp_bias
+
+ self.num_key_value_heads = num_key_value_heads
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.rms_norm_eps = rms_norm_eps
+ self.use_cache = use_cache
+ self.rope_theta = rope_theta
+ self.rope_scaling = rope_scaling
+ self._rope_scaling_validation()
+ self.attention_bias = attention_bias
+ self.attention_dropout = attention_dropout
+ self.model_parallel_size = model_parallel_size
+ self.swin_norm = swin_norm
+
+ if vq_config is None:
+ vq_config = {}
+ logger.info("vq_config is None. initializing the ChameleonVQConfig with default values.")
+
+ self.vq_config = ChameleonVQVAEConfig(**vq_config)
+
+ self.vocabulary_map = vocabulary_map
+ self.image_token_id = vocabulary_map.get("") if vocabulary_map is not None else None
+
+ super().__init__(
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ tie_word_embeddings=tie_word_embeddings,
+ **kwargs,
+ )
+
+ def _rope_scaling_validation(self):
+ """
+ Validate the `rope_scaling` configuration.
+ """
+ if self.rope_scaling is None:
+ return
+
+ if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
+ raise ValueError(
+ "`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, "
+ f"got {self.rope_scaling}"
+ )
+ rope_scaling_type = self.rope_scaling.get("type", None)
+ rope_scaling_factor = self.rope_scaling.get("factor", None)
+ if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
+ raise ValueError(
+ f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
+ )
+ if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
+ raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
+
+
+__all__ = ["ChameleonConfig", "ChameleonVQVAEConfig"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/chameleon/image_processing_chameleon.py b/venv/lib/python3.13/site-packages/transformers/models/chameleon/image_processing_chameleon.py
new file mode 100644
index 0000000000000000000000000000000000000000..9cae9d7bdd34d8b7a2292059ae1e92b587891272
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/chameleon/image_processing_chameleon.py
@@ -0,0 +1,341 @@
+# coding=utf-8
+# Copyright 2024 Meta Inc. and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Image processor class for Chameleon."""
+
+from typing import Optional, Union
+
+import numpy as np
+
+from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
+from ...image_transforms import get_resize_output_image_size, resize, to_channel_dimension_format
+from ...image_utils import (
+ ChannelDimension,
+ ImageInput,
+ PILImageResampling,
+ infer_channel_dimension_format,
+ is_scaled_image,
+ make_flat_list_of_images,
+ to_numpy_array,
+ valid_images,
+ validate_preprocess_arguments,
+)
+from ...utils import TensorType, filter_out_non_signature_kwargs, is_vision_available, logging
+
+
+logger = logging.get_logger(__name__)
+
+if is_vision_available():
+ import PIL
+
+
+class ChameleonImageProcessor(BaseImageProcessor):
+ r"""
+ Constructs a Chameleon image processor.
+
+ Args:
+ do_resize (`bool`, *optional*, defaults to `True`):
+ Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by
+ `do_resize` in the `preprocess` method.
+ size (`dict[str, int]`, *optional*, defaults to `{"shortest_edge": 512}`):
+ Size of the image after resizing. The shortest edge of the image is resized to size["shortest_edge"], with
+ the longest edge resized to keep the input aspect ratio. Can be overridden by `size` in the `preprocess`
+ method.
+ resample (`PILImageResampling`, *optional*, defaults to 1):
+ Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method.
+ do_center_crop (`bool`, *optional*, defaults to `True`):
+ Whether to center crop the image to the specified `crop_size`. Can be overridden by `do_center_crop` in the
+ `preprocess` method.
+ crop_size (`dict[str, int]`, *optional*, defaults to `{"height": 512, "width": 512}`):
+ Size of the output image after applying `center_crop`. Can be overridden by `crop_size` in the `preprocess`
+ method.
+ do_rescale (`bool`, *optional*, defaults to `True`):
+ Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
+ the `preprocess` method.
+ rescale_factor (`int` or `float`, *optional*, defaults to 0.0078):
+ Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
+ method.
+ do_normalize (`bool`, *optional*, defaults to `True`):
+ Whether to normalize the image. Can be overridden by `do_normalize` in the `preprocess` method.
+ image_mean (`float` or `list[float]`, *optional*, defaults to `[1.0, 1.0, 1.0]`):
+ Mean to use if normalizing the image. This is a float or list of floats the length of the number of
+ channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
+ image_std (`float` or `list[float]`, *optional*, defaults to `[1.0, 1.0, 1.0]`):
+ Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
+ number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
+ do_convert_rgb (`bool`, *optional*, defaults to `True`):
+ Whether to convert the image to RGB.
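+
+ Example (a hedged sketch; the 512x512 output shape follows from the default `size` and `crop_size` above,
+ and the input is assumed to be a plain uint8 NumPy array):
+
+ ```python
+ >>> import numpy as np
+ >>> from transformers import ChameleonImageProcessor
+
+ >>> processor = ChameleonImageProcessor()
+ >>> image = np.full((600, 800, 3), 128, dtype=np.uint8)  # any HxWxC uint8 image
+ >>> processor(images=image, return_tensors="pt")["pixel_values"].shape
+ torch.Size([1, 3, 512, 512])
+ ```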
+ """
+
+ model_input_names = ["pixel_values"]
+
+ def __init__(
+ self,
+ do_resize: bool = True,
+ size: Optional[dict[str, int]] = None,
+ resample: PILImageResampling = PIL.Image.LANCZOS,
+ do_center_crop: bool = True,
+ crop_size: Optional[dict[str, int]] = None,
+ do_rescale: bool = True,
+ rescale_factor: Union[int, float] = 0.0078,
+ do_normalize: bool = True,
+ image_mean: Optional[Union[float, list[float]]] = None,
+ image_std: Optional[Union[float, list[float]]] = None,
+ do_convert_rgb: bool = True,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ size = size if size is not None else {"shortest_edge": 512}
+ size = get_size_dict(size, default_to_square=False)
+ crop_size = crop_size if crop_size is not None else {"height": 512, "width": 512}
+ crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size")
+
+ self.do_resize = do_resize
+ self.size = size
+ self.resample = resample
+ self.do_center_crop = do_center_crop
+ self.crop_size = crop_size
+ self.do_rescale = do_rescale
+ self.rescale_factor = rescale_factor
+ self.do_normalize = do_normalize
+ self.image_mean = image_mean if image_mean is not None else [1.0, 1.0, 1.0]
+ self.image_std = image_std if image_std is not None else [1.0, 1.0, 1.0]
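+ # Note: with the default rescale_factor (~1/127.5) and mean/std of 1.0, pixel values end up roughly in [-1, 1].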
+ self.do_convert_rgb = do_convert_rgb
+
+ # Copied from transformers.models.clip.image_processing_clip.CLIPImageProcessor.resize
+ def resize(
+ self,
+ image: np.ndarray,
+ size: dict[str, int],
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs,
+ ) -> np.ndarray:
+ """
+ Resize an image. The shortest edge of the image is resized to size["shortest_edge"], with the longest edge
+ resized to keep the input aspect ratio.
+
+ Args:
+ image (`np.ndarray`):
+ Image to resize.
+ size (`dict[str, int]`):
+ Size of the output image.
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
+ Resampling filter to use when resizing the image.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format of the input image. If not provided, it will be inferred.
+ """
+ default_to_square = True
+ if "shortest_edge" in size:
+ size = size["shortest_edge"]
+ default_to_square = False
+ elif "height" in size and "width" in size:
+ size = (size["height"], size["width"])
+ else:
+ raise ValueError("Size must contain either 'shortest_edge' or 'height' and 'width'.")
+
+ output_size = get_resize_output_image_size(
+ image,
+ size=size,
+ default_to_square=default_to_square,
+ input_data_format=input_data_format,
+ )
+ return resize(
+ image,
+ size=output_size,
+ resample=resample,
+ data_format=data_format,
+ input_data_format=input_data_format,
+ **kwargs,
+ )
+
+ @filter_out_non_signature_kwargs()
+ def preprocess(
+ self,
+ images: ImageInput,
+ do_resize: Optional[bool] = None,
+ size: Optional[dict[str, int]] = None,
+ resample: Optional[PILImageResampling] = None,
+ do_center_crop: Optional[bool] = None,
+ crop_size: Optional[int] = None,
+ do_rescale: Optional[bool] = None,
+ rescale_factor: Optional[float] = None,
+ do_normalize: Optional[bool] = None,
+ image_mean: Optional[Union[float, list[float]]] = None,
+ image_std: Optional[Union[float, list[float]]] = None,
+ do_convert_rgb: Optional[bool] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ ) -> PIL.Image.Image:
+ """
+ Preprocess an image or batch of images.
+
+ Args:
+ images (`ImageInput`):
+ Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
+ passing in images with pixel values between 0 and 1, set `do_rescale=False`.
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+ Whether to resize the image.
+ size (`dict[str, int]`, *optional*, defaults to `self.size`):
+ Size of the image after resizing. Shortest edge of the image is resized to size["shortest_edge"], with
+ the longest edge resized to keep the input aspect ratio.
+ resample (`int`, *optional*, defaults to `self.resample`):
+ Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
+ has an effect if `do_resize` is set to `True`.
+ do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`):
+ Whether to center crop the image.
+ crop_size (`dict[str, int]`, *optional*, defaults to `self.crop_size`):
+ Size of the center crop. Only has an effect if `do_center_crop` is set to `True`.
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+ Whether to rescale the image.
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
+ Whether to normalize the image.
+ image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`):
+ Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
+ image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`):
+ Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
+ `True`.
+ do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
+ Whether to convert the image to RGB.
+ return_tensors (`str` or `TensorType`, *optional*):
+ The type of tensors to return. Can be one of:
+ - Unset: Return a list of `np.ndarray`.
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
+ The channel dimension format for the output image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - Unset: Use the channel dimension format of the input image.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
+ from the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+ """
+ do_resize = do_resize if do_resize is not None else self.do_resize
+ size = size if size is not None else self.size
+ size = get_size_dict(size, param_name="size", default_to_square=False)
+ resample = resample if resample is not None else self.resample
+ do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
+ crop_size = crop_size if crop_size is not None else self.crop_size
+ crop_size = get_size_dict(crop_size, param_name="crop_size", default_to_square=True)
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
+ image_mean = image_mean if image_mean is not None else self.image_mean
+ image_std = image_std if image_std is not None else self.image_std
+ do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
+
+ images = self.fetch_images(images)
+ images = make_flat_list_of_images(images)
+
+ if not valid_images(images):
+ raise ValueError(
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
+ "torch.Tensor, tf.Tensor or jax.ndarray."
+ )
+
+ validate_preprocess_arguments(
+ do_rescale=do_rescale,
+ rescale_factor=rescale_factor,
+ do_normalize=do_normalize,
+ image_mean=image_mean,
+ image_std=image_std,
+ do_center_crop=do_center_crop,
+ crop_size=crop_size,
+ do_resize=do_resize,
+ size=size,
+ resample=resample,
+ )
+
+ if do_convert_rgb:
+ images = [self.blend_rgba(image) for image in images]
+
+ # All transformations expect numpy arrays.
+ images = [to_numpy_array(image) for image in images]
+
+ if do_rescale and is_scaled_image(images[0]):
+ logger.warning_once(
+ "It looks like you are trying to rescale already rescaled images. If the input"
+ " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
+ )
+
+ if input_data_format is None:
+ # We assume that all images have the same channel dimension format.
+ input_data_format = infer_channel_dimension_format(images[0])
+ all_images = []
+ for image in images:
+ if do_resize:
+ image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
+
+ if do_center_crop:
+ image = self.center_crop(image=image, size=crop_size, input_data_format=input_data_format)
+
+ if do_rescale:
+ image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
+
+ if do_normalize:
+ image = self.normalize(
+ image=image, mean=image_mean, std=image_std, input_data_format=input_data_format
+ )
+
+ all_images.append(image)
+ images = [
+ to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
+ for image in all_images
+ ]
+
+ data = {"pixel_values": images}
+ return BatchFeature(data=data, tensor_type=return_tensors)
+
+ def blend_rgba(self, image: ImageInput) -> ImageInput:
+ """
+ Convert an image to RGB by blending the transparency layer if it is in RGBA format.
+ If the image is not a `PIL.Image`, it is simply returned without modification.
+
+ Args:
+ image (`ImageInput`):
+ Image to convert.
+ """
+
+ if not isinstance(image, PIL.Image.Image):
+ return image
+ elif image.mode == "RGB":
+ return image
+
+ img_rgba = np.array(image.convert("RGBA"))
+
+ # If there is no transparency layer, simply convert and return.
+ if not (img_rgba[:, :, 3] < 255).any():
+ return image.convert("RGB")
+
+ # There is a transparency layer, blend it with a white background.
+ # Calculate the alpha proportion for blending.
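+ # Per pixel: out = (1 - alpha) * 255 + alpha * rgb, i.e. standard alpha compositing over a white background.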
+ alpha = img_rgba[:, :, 3] / 255.0
+ img_rgb = (1 - alpha[:, :, np.newaxis]) * 255 + alpha[:, :, np.newaxis] * img_rgba[:, :, :3]
+ return PIL.Image.fromarray(img_rgb.astype("uint8"), "RGB")
+
+
+__all__ = ["ChameleonImageProcessor"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/chameleon/image_processing_chameleon_fast.py b/venv/lib/python3.13/site-packages/transformers/models/chameleon/image_processing_chameleon_fast.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d102614f7df3700845c00f9d8bfa217930c776b
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/chameleon/image_processing_chameleon_fast.py
@@ -0,0 +1,112 @@
+# coding=utf-8
+# Copyright 2025 Meta Inc. and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Fast Image processor class for Chameleon."""
+
+from typing import Optional
+
+import numpy as np
+import PIL
+import torch
+from torchvision.transforms.v2 import functional as F
+
+from ...image_processing_utils_fast import BaseImageProcessorFast
+from ...image_utils import ImageInput, PILImageResampling, SizeDict
+from ...utils import auto_docstring, logging
+
+
+logger = logging.get_logger(__name__)
+
+
+@auto_docstring
+class ChameleonImageProcessorFast(BaseImageProcessorFast):
+ resample = PILImageResampling.LANCZOS
+ image_mean = [1.0, 1.0, 1.0]
+ image_std = [1.0, 1.0, 1.0]
+ size = {"shortest_edge": 512}
+ default_to_square = False
+ crop_size = {"height": 512, "width": 512}
+ do_resize = True
+ do_center_crop = True
+ do_rescale = True
+ rescale_factor = 0.0078
+ do_normalize = True
+ do_convert_rgb = True
+
+ def convert_to_rgb(self, image: ImageInput) -> ImageInput:
+ """
+ Convert an image to RGB by blending the transparency layer if it is in RGBA format.
+ If the image is not a `PIL.Image`, it is simply returned without modification.
+
+ Args:
+ image (`ImageInput`):
+ Image to convert.
+ """
+
+ if not isinstance(image, PIL.Image.Image):
+ return image
+ elif image.mode == "RGB":
+ return image
+
+ img_rgba = np.array(image.convert("RGBA"))
+
+ # If there is no transparency layer, simply convert and return.
+ if not (img_rgba[:, :, 3] < 255).any():
+ return image.convert("RGB")
+
+ # There is a transparency layer, blend it with a white background.
+ # Calculate the alpha proportion for blending.
+ alpha = img_rgba[:, :, 3] / 255.0
+ img_rgb = (1 - alpha[:, :, np.newaxis]) * 255 + alpha[:, :, np.newaxis] * img_rgba[:, :, :3]
+ return PIL.Image.fromarray(img_rgb.astype("uint8"), "RGB")
+
+ def resize(
+ self,
+ image: "torch.Tensor",
+ size: SizeDict,
+ interpolation: Optional["F.InterpolationMode"] = None,
+ **kwargs,
+ ) -> "torch.Tensor":
+ """
+ Resize an image to `(size["height"], size["width"])`.
+
+ Args:
+ image (`torch.Tensor`):
+ Image to resize.
+ size (`SizeDict`):
+ Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
+ interpolation (`InterpolationMode`, *optional*, defaults to `InterpolationMode.BILINEAR`):
+ `InterpolationMode` filter to use when resizing the image, e.g. `InterpolationMode.BICUBIC`.
+
+ Returns:
+ `torch.Tensor`: The resized image.
+ """
+ interpolation = interpolation if interpolation is not None else F.InterpolationMode.BILINEAR
+ if interpolation == F.InterpolationMode.LANCZOS:
+ logger.warning_once(
+ "You have used fast image processor with LANCZOS resample which not yet supported for torch.Tensor. "
+ "BICUBIC resample will be used as an alternative. Please fall back to slow image processor if you "
+ "want full consistency with the original model."
+ )
+ interpolation = F.InterpolationMode.BICUBIC
+
+ return super().resize(
+ image=image,
+ size=size,
+ interpolation=interpolation,
+ **kwargs,
+ )
+
+
+__all__ = ["ChameleonImageProcessorFast"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/chameleon/modeling_chameleon.py b/venv/lib/python3.13/site-packages/transformers/models/chameleon/modeling_chameleon.py
new file mode 100644
index 0000000000000000000000000000000000000000..033b8ecd7c630208602e1ae4d38b95df7718aaed
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/chameleon/modeling_chameleon.py
@@ -0,0 +1,1169 @@
+# coding=utf-8
+# Copyright 2024 Meta Inc. and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch Chameleon model."""
+
+from functools import cached_property
+from typing import Callable, Optional, Union
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache
+from ...generation import GenerationMixin
+from ...masking_utils import create_causal_mask
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import (
+ TransformersKwargs,
+ auto_docstring,
+ can_return_tuple,
+ logging,
+)
+from ...utils.deprecation import deprecate_kwarg
+from .configuration_chameleon import ChameleonConfig, ChameleonVQVAEConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Chameleon
+class ChameleonRMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ ChameleonRMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+ def extra_repr(self):
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+
+# copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Chameleon
+# TODO(joao): add me back asap :)
+class ChameleonRotaryEmbedding(nn.Module):
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
+
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
+ super().__init__()
+ self.scaling_factor = scaling_factor
+ self.dim = dim
+ self.max_position_embeddings = max_position_embeddings
+ self.base = base
+ inv_freq = 1.0 / (
+ self.base
+ ** (torch.arange(0, self.dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / self.dim)
+ )
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+ # For BC we register cos and sin cached
+ self.max_seq_len_cached = max_position_embeddings
+
+ @torch.no_grad()
+ def forward(self, x, position_ids):
+ # x: [bs, num_attention_heads, seq_len, head_size]
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+ position_ids_expanded = position_ids[:, None, :].float()
+ # Force float32 since bfloat16 loses precision on long contexts
+ # See https://github.com/huggingface/transformers/pull/29285
+ device_type = x.device.type
+ device_type = device_type if device_type != "mps" else "cpu"
+ with torch.autocast(device_type=device_type, enabled=False):
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos = emb.cos()
+ sin = emb.sin()
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+class ChameleonLinearScalingRotaryEmbedding(ChameleonRotaryEmbedding):
+ """ChameleonRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
+
+ def forward(self, x, position_ids):
+ # difference to the original RoPE: a scaling factor is applied to the position ids
+ position_ids = position_ids.float() / self.scaling_factor
+ cos, sin = super().forward(x, position_ids)
+ return cos, sin
+
+
+class ChameleonDynamicNTKScalingRotaryEmbedding(ChameleonRotaryEmbedding):
+ """ChameleonRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
+
+ def forward(self, x, position_ids):
+ # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length
+ seq_len = torch.max(position_ids) + 1
+ if seq_len > self.max_position_embeddings:
+ base = self.base * (
+ (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
+ ) ** (self.dim / (self.dim - 2))
+ inv_freq = 1.0 / (
+ base
+ ** (torch.arange(0, self.dim, 2, dtype=torch.int64).to(device=x.device, dtype=torch.float) / self.dim)
+ )
+ self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: this may break with compilation
+
+ cos, sin = super().forward(x, position_ids)
+ return cos, sin
+
+
+# Copied from transformers.models.llama.modeling_llama.rotate_half
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+ """Applies Rotary Position Embedding to the query and key tensors.
+
+ Args:
+ q (`torch.Tensor`): The query tensor.
+ k (`torch.Tensor`): The key tensor.
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
+ position_ids (`torch.Tensor`, *optional*):
+ Deprecated and unused.
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+ Returns:
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+ """
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
+
+
+# Copied from transformers.models.llama.modeling_llama.LlamaMLP with Llama->Chameleon
+class ChameleonMLP(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.intermediate_size = config.intermediate_size
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
+ self.act_fn = ACT2FN[config.hidden_act]
+
+ # Ignore copy
+ def forward(self, x):
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+ return down_proj
+
+
+class ChameleonLayerNorm(nn.LayerNorm):
+ """
+ LayerNorm but computes stats only over the last dim because Chameleon applies gamma and beta
+ from each shard separately to each head, instead of reducing. We can apply each head's own
+ gamma/beta by repeat-interleaving weights from each shard, but the stats have to be computed
+ in the last dimension. This module applies gamma/beta manually to fulfill this requirement.
+ """
+
+ def __init__(self, hidden_size, *args, **kwargs):
+ super().__init__(hidden_size, *args, **kwargs)
+ self.normalized_shape = (hidden_size[-1],)
+
+ def forward(self, hidden_states):
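+ # Normalize over the last dim (head_dim) only, then apply the full (num_heads, head_dim) weight and bias
+ # so every head keeps its own gamma/beta, as described in the class docstring.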
+ hidden_states = F.layer_norm(hidden_states, self.normalized_shape, None, None, eps=1e-5)
+ hidden_states = hidden_states * self.weight + self.bias
+ return hidden_states
+
+
+# Copied from transformers.models.llama.modeling_llama.repeat_kv
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+ """
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+# Copied from transformers.models.llama.modeling_llama.eager_attention_forward
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: float,
+ dropout: float = 0.0,
+ **kwargs: Unpack[TransformersKwargs],
+):
+ key_states = repeat_kv(key, module.num_key_value_groups)
+ value_states = repeat_kv(value, module.num_key_value_groups)
+
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+ if attention_mask is not None:
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
+class ChameleonAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config: ChameleonConfig, layer_idx: Optional[int] = None):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ if layer_idx is None:
+ logger.warning_once(
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
+ "when creating this class."
+ )
+
+ self.attention_dropout = config.attention_dropout
+ self.hidden_size = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.hidden_size // self.num_heads
+ self.num_key_value_heads = config.num_key_value_heads
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+ self.max_position_embeddings = config.max_position_embeddings
+ self.rope_theta = config.rope_theta
+ self.is_causal = True
+ self.model_parallel_size = config.model_parallel_size
+ self.scaling = self.head_dim**-0.5
+
+ if (self.head_dim * self.num_heads) != self.hidden_size:
+ raise ValueError(
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
+ f" and `num_heads`: {self.num_heads})."
+ )
+
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
+ self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias)
+ self.q_norm = ChameleonLayerNorm((self.num_heads, self.head_dim))
+ self.k_norm = ChameleonLayerNorm((self.num_key_value_heads, self.head_dim))
+ self._init_rope()
+
+ # copied from transformers.models.llama.modeling_llama.LlamaAttention._init_rope with Llama->Chameleon
+ # TODO(joao): add me back asap :)
+ def _init_rope(self):
+ if self.config.rope_scaling is None:
+ self.rotary_emb = ChameleonRotaryEmbedding(
+ self.head_dim,
+ max_position_embeddings=self.max_position_embeddings,
+ base=self.rope_theta,
+ )
+ else:
+ scaling_type = self.config.rope_scaling["type"]
+ scaling_factor = self.config.rope_scaling["factor"]
+ if scaling_type == "linear":
+ self.rotary_emb = ChameleonLinearScalingRotaryEmbedding(
+ self.head_dim,
+ max_position_embeddings=self.max_position_embeddings,
+ scaling_factor=scaling_factor,
+ base=self.rope_theta,
+ )
+ elif scaling_type == "dynamic":
+ self.rotary_emb = ChameleonDynamicNTKScalingRotaryEmbedding(
+ self.head_dim,
+ max_position_embeddings=self.max_position_embeddings,
+ scaling_factor=scaling_factor,
+ base=self.rope_theta,
+ )
+ else:
+ raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs,
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
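+ # Reshape to (batch * seq, num_heads, head_dim) so ChameleonLayerNorm can normalize each head separately
+ # (Chameleon's qk-norm) before the rotary embedding is applied.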
+ query_states = query_states.reshape(-1, self.num_heads, self.head_dim)
+ query_states = self.q_norm(query_states)
+
+ key_states = key_states.reshape(-1, self.num_key_value_heads, self.head_dim)
+ key_states = self.k_norm(key_states)
+
+ query_states = query_states.reshape(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.reshape(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_values is not None:
+ # sin and cos are specific to RoPE models; position_ids needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ scaling=self.scaling,
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+
+ return attn_output, attn_weights
+
+
+# copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with Llama->Chameleon, LLAMA->CHAMELEON
+class ChameleonDecoderLayer(GradientCheckpointingLayer):
+ def __init__(self, config: ChameleonConfig, layer_idx: int):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+
+ self.self_attn = ChameleonAttention(config=config, layer_idx=layer_idx)
+
+ self.mlp = ChameleonMLP(config)
+ self.input_layernorm = ChameleonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_attention_layernorm = ChameleonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs,
+ ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`, *optional*):
+ attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
+ query_sequence_length, key_sequence_length)` if default attention is used.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ past_key_values (`Cache`, *optional*): cached past key and value projection states
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+ Indices depicting the position of the input sequence tokens in the sequence.
+ kwargs (`dict`, *optional*):
+ Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code
+ into the model.
+ """
+ residual = hidden_states
+
+ hidden_states = self.input_layernorm(hidden_states)
+
+ # Self Attention
+ hidden_states, self_attn_weights = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ **kwargs,
+ )
+ hidden_states = residual + hidden_states
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights,)
+
+ return outputs
+
+
+class ChameleonSwinDecoderLayer(GradientCheckpointingLayer):
+ def __init__(self, config: ChameleonConfig, layer_idx: int):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+
+ self.self_attn = ChameleonAttention(config=config, layer_idx=layer_idx)
+
+ self.mlp = ChameleonMLP(config)
+ self.input_layernorm = ChameleonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_attention_layernorm = ChameleonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs,
+ ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`):
+ input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`, *optional*):
+ attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
+ query_sequence_length, key_sequence_length)` if default attention is used.
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings
+ past_key_values (`Cache`, *optional*): cached past key and value projection states
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+ Indices depicting the position of the input sequence tokens in the sequence.
+ """
+
+ residual = hidden_states
+
+ # Self Attention
+ hidden_states, self_attn_weights = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ **kwargs,
+ )
+ hidden_states = self.input_layernorm(hidden_states)
+ hidden_states = residual + hidden_states
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states = residual + hidden_states
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights,)
+
+ return outputs
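+ # Note (illustrative, not part of the library code): unlike `ChameleonDecoderLayer` above,
+ # which normalizes *before* the attention and MLP sub-layers (pre-norm), this "swin norm"
+ # variant applies `input_layernorm` / `post_attention_layernorm` to the sub-layer *outputs*
+ # before the residual additions; the model selects it when `config.swin_norm` is set.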
+
+
+class ChameleonVQVAEVectorQuantizer(nn.Module):
+ """
+ A module for vector quantization using learned embedding vectors.
+
+ This module implements the quantization process similar to the one described in
+ the VQ-VAE (Vector Quantized Variational AutoEncoder) paper. It quantizes continuous
+ input vectors into discrete codebook vectors, which are learned during training.
+ The current implementation improves over previous ones by avoiding costly matrix multiplications
+ and allowing for post-hoc remapping of indices.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ self.num_embeddings = config.num_embeddings
+ self.embedding_dim = config.embed_dim
+ self.beta = getattr(config, "beta", 0.25)
+
+ self.embedding = nn.Embedding(self.num_embeddings, self.embedding_dim)
+
+ def forward(self, hidden_state: torch.Tensor):
+ hidden_state = hidden_state.permute(0, 2, 3, 1).contiguous()
+ hidden_state_flattened = hidden_state.view(-1, self.embedding_dim)
+
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
+ distances = (
+ torch.sum(hidden_state_flattened**2, dim=1, keepdim=True)
+ + torch.sum(self.embedding.weight**2, dim=1)
+ - 2 * torch.einsum("bd,dn->bn", hidden_state_flattened, self.embedding.weight.transpose(0, 1))
+ )
+
+ min_encoding_indices = torch.argmin(distances, dim=1)
+ hidden_state_quant = self.embedding(min_encoding_indices).view(hidden_state.shape)
+
+ # compute loss for embedding
+ loss = torch.mean((hidden_state_quant.detach() - hidden_state) ** 2) + self.beta * torch.mean(
+ (hidden_state_quant - hidden_state.detach()) ** 2
+ )
+
+ # preserve gradients
+ hidden_state_quant = hidden_state + (hidden_state_quant - hidden_state).detach()
+
+ # reshape back to match original input shape
+ hidden_state_quant = hidden_state_quant.permute(0, 3, 1, 2).contiguous()
+
+ return hidden_state_quant, loss, min_encoding_indices
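+ # Illustrative sketch (not part of the library code): exercising the quantizer above in
+ # isolation. `toy_config` and the tensor shapes are assumptions for demonstration only.
+ #
+ # import torch
+ # from types import SimpleNamespace
+ #
+ # toy_config = SimpleNamespace(num_embeddings=8, embed_dim=4)
+ # quantizer = ChameleonVQVAEVectorQuantizer(toy_config)
+ # features = torch.randn(2, 4, 3, 3)          # (batch, embed_dim, height, width)
+ # quant, loss, indices = quantizer(features)
+ # # quant: (2, 4, 3, 3); the straight-through estimator lets gradients flow back to `features`
+ # # indices: (2 * 3 * 3,) flat codebook ids, one per spatial position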
+
+
+class ChameleonVQVAEEncoderConvDownsample(nn.Module):
+ def __init__(self, in_channels):
+ super().__init__()
+ self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
+
+ def forward(self, hidden_states):
+ # no asymmetric padding in torch conv, must do it ourselves
+ hidden_states = F.pad(hidden_states, pad=(0, 1, 0, 1), mode="constant", value=0)
+ hidden_states = self.conv(hidden_states)
+ return hidden_states
+
+
+class ChameleonVQVAEEncoderResnetBlock(nn.Module):
+ def __init__(
+ self,
+ config,
+ in_channels,
+ out_channels=None,
+ conv_shortcut=False,
+ ):
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = in_channels if out_channels is None else out_channels
+ self.use_conv_shortcut = conv_shortcut
+
+ self.norm1 = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+ self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+ self.norm2 = torch.nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
+ self.dropout = torch.nn.Dropout(config.dropout)
+ self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+ else:
+ self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+
+ def forward(self, hidden_states):
+ residual = hidden_states
+ hidden_states = self.norm1(hidden_states)
+ hidden_states *= torch.sigmoid(hidden_states)
+ hidden_states = self.conv1(hidden_states)
+
+ hidden_states = self.norm2(hidden_states)
+ hidden_states *= torch.sigmoid(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.conv2(hidden_states)
+
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ residual = self.conv_shortcut(residual)
+ else:
+ residual = self.nin_shortcut(residual)
+
+ return residual + hidden_states
+
+
+class ChameleonVQVAEEncoderAttnBlock(nn.Module):
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+ self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+ self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+ self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+ self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+
+ def forward(self, hidden_states):
+ residual = hidden_states
+ hidden_states = self.norm(hidden_states)
+ query_states = self.q(hidden_states)
+ key_states = self.k(hidden_states)
+ value_states = self.v(hidden_states)
+
+ # compute attention
+ batch_size, channels, height, width = query_states.shape
+ query_states = query_states.reshape(batch_size, channels, height * width).permute(0, 2, 1)
+ key_states = key_states.reshape(batch_size, channels, height * width)
+ attn_weights = torch.bmm(query_states, key_states)
+ attn_weights = attn_weights * (int(channels) ** (-0.5))
+ attn_weights = F.softmax(attn_weights, dim=2)
+
+ # attend to values
+ value_states = value_states.reshape(batch_size, channels, height * width)
+ attn_weights = attn_weights.permute(0, 2, 1)
+ attn_output = torch.bmm(value_states, attn_weights).reshape(batch_size, channels, height, width)
+
+ attn_output = self.proj_out(attn_output)
+ return residual + attn_output
+
+
+class ChameleonVQVAEEncoder(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+
+ self.num_resolutions = len(config.channel_multiplier)
+ self.num_res_blocks = config.num_res_blocks
+ base_channels = config.base_channels
+ resolution = config.resolution
+ in_channels = config.in_channels
+ double_latent = config.double_latent
+ latent_channels = config.latent_channels
+ channel_multiplier = config.channel_multiplier
+
+ self.conv_in = torch.nn.Conv2d(in_channels, base_channels, kernel_size=3, stride=1, padding=1)
+
+ curr_res = resolution
+ in_channel_multiplier = (1,) + tuple(channel_multiplier)
+ self.in_channel_multiplier = in_channel_multiplier
+ self.down = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_in = base_channels * in_channel_multiplier[i_level]
+ block_out = base_channels * channel_multiplier[i_level]
+ for i_block in range(self.num_res_blocks):
+ block.append(
+ ChameleonVQVAEEncoderResnetBlock(
+ config=config,
+ in_channels=block_in,
+ out_channels=block_out,
+ )
+ )
+ block_in = block_out
+ if (
+ config.attn_resolutions is not None
+ and curr_res in config.attn_resolutions
+ and config.attn_type == "vanilla"
+ ):
+ attn.append(ChameleonVQVAEEncoderAttnBlock(block_in))
+
+ down = nn.Module()
+ down.block = block
+ down.attn = attn
+ if i_level != self.num_resolutions - 1:
+ down.downsample = ChameleonVQVAEEncoderConvDownsample(block_in)
+ curr_res = curr_res // 2
+ self.down.append(down)
+
+ self.mid = nn.Module()
+ self.mid.block_1 = ChameleonVQVAEEncoderResnetBlock(
+ config=config,
+ in_channels=block_in,
+ out_channels=block_in,
+ )
+ self.mid.attn_1 = ChameleonVQVAEEncoderAttnBlock(block_in) if config.attn_type == "vanilla" else nn.Identity()
+ self.mid.block_2 = ChameleonVQVAEEncoderResnetBlock(
+ config=config,
+ in_channels=block_in,
+ out_channels=block_in,
+ )
+
+ self.norm_out = torch.nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
+ self.conv_out = torch.nn.Conv2d(
+ block_in,
+ 2 * latent_channels if double_latent else latent_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ )
+
+ def forward(self, pixel_values: torch.LongTensor):
+ # downsampling
+ hidden_states = [self.conv_in(pixel_values)]
+ for i_level in range(self.num_resolutions):
+ for i_block in range(self.num_res_blocks):
+ hidden_state = self.down[i_level].block[i_block](
+ hidden_states[-1],
+ )
+ if len(self.down[i_level].attn) > 0:
+ hidden_state = self.down[i_level].attn[i_block](hidden_state)
+ hidden_states.append(hidden_state)
+ if i_level != self.num_resolutions - 1:
+ hidden_states.append(self.down[i_level].downsample(hidden_states[-1]))
+
+ # middle
+ last_hidden_state = hidden_states[-1]
+ last_hidden_state = self.mid.block_1(last_hidden_state)
+ last_hidden_state = self.mid.attn_1(last_hidden_state)
+ last_hidden_state = self.mid.block_2(last_hidden_state)
+
+ # end
+ last_hidden_state = self.norm_out(last_hidden_state)
+ last_hidden_state *= torch.sigmoid(last_hidden_state)
+ last_hidden_state = self.conv_out(last_hidden_state)
+ return last_hidden_state
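+ # Note (illustrative, not part of the library code): with `resolution = R` and
+ # `num_resolutions = N`, the encoder downsamples N - 1 times, so the tensor returned above
+ # has a spatial grid of roughly (R / 2**(N - 1)) x (R / 2**(N - 1)) with `latent_channels`
+ # channels (doubled when `double_latent` is set).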
+
+
+class ChameleonImageVocabularyMapping:
+ """
+ A class for mapping discrete image tokens from VQGAN to BPE tokens.
+ """
+
+ def __init__(self, vocab_map):
+ self.vocab_map = vocab_map
+ self.image_token_id = vocab_map.get("<image>")
+
+ @cached_property
+ def val2name(self):
+ return {v: k for k, v in self.vocab_map.items()}
+
+ @cached_property
+ def image_tokens(self):
+ return sorted([val for name, val in self.vocab_map.items() if name.startswith("IMGIMG")])
+
+ @cached_property
+ def bpe2img(self):
+ img_tkn_chr_mapping = {chr(ord("A") + i): str(i) for i in range(10)}
+
+ def remap(old_name: str) -> str:
+ return "".join(img_tkn_chr_mapping.get(c, c) for c in old_name[len("IMGIMG") : -1])
+
+ return {tok: int(remap(self.val2name[tok])) for tok in self.image_tokens}
+
+ @cached_property
+ def img2bpe(self):
+ return {v: k for k, v in self.bpe2img.items()}
+
+ @cached_property
+ def bpe2img_search_tensors(self):
+ return torch.tensor(sorted(self.bpe2img.keys())), torch.tensor(sorted(self.bpe2img.values()))
+
+ @cached_property
+ def img2bpe_mapping_tensor(self):
+ mapping = torch.zeros(max(self.img2bpe.keys()) + 1, dtype=torch.int)
+ for k, v in self.img2bpe.items():
+ mapping[k] = v
+ return mapping
+
+ def convert_img2bpe(self, img_batch: torch.Tensor) -> torch.Tensor:
+ device = img_batch.device
+ img_tokens = self.img2bpe_mapping_tensor[img_batch.to("cpu")]
+ return img_tokens.to(device)
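+ # Illustrative sketch (not part of the library code): a toy vocabulary showing how the
+ # name-based remapping above works. The names and ids are made up; real checkpoints define
+ # thousands of "IMGIMG..." entries.
+ #
+ # toy_map = {"IMGIMGAAZ": 6, "IMGIMGBCZ": 5, "<image>": 7}
+ # mapping = ChameleonImageVocabularyMapping(toy_map)
+ # mapping.bpe2img                                   # {5: 12, 6: 0}  ("BC" -> "12", "AA" -> "00")
+ # mapping.img2bpe                                   # {12: 5, 0: 6}
+ # mapping.convert_img2bpe(torch.tensor([[12, 0]]))  # tensor([[5, 6]])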
+
+
+@auto_docstring
+class ChameleonPreTrainedModel(PreTrainedModel):
+ config: ChameleonConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["ChameleonDecoderLayer", "ChameleonSwinDecoderLayer"]
+ _skip_keys_device_placement = ["past_key_values", "causal_mask"]
+ _supports_flash_attn = True
+ _supports_sdpa = True
+
+ _can_compile_fullgraph = True
+ _supports_param_buffer_assignment = False
+ _supports_flex_attn = True
+ _supports_attention_backend = True
+
+
+@auto_docstring(
+ custom_intro="""
+ The VQ-VAE model used in Chameleon for encoding/decoding images into discrete tokens.
+ This model follows the "Make-a-scene: Scene-based text-to-image generation with human priors" paper from
+ [ Oran Gafni, Adam Polyak, Oron Ashual, Shelly Sheynin, Devi Parikh, and Yaniv
+ Taigman](https://huggingface.co/papers/2203.13131).
+ """
+)
+class ChameleonVQVAE(ChameleonPreTrainedModel):
+ config: ChameleonVQVAEConfig
+ _no_split_modules = [
+ "ChameleonVQVAEVectorQuantizer",
+ "ChameleonVQVAEEncoderAttnBlock",
+ "ChameleonVQVAEEncoderResnetBlock",
+ ]
+
+ def __init__(self, config: ChameleonVQVAEConfig):
+ super().__init__(config)
+
+ self.encoder = ChameleonVQVAEEncoder(config)
+ self.quantize = ChameleonVQVAEVectorQuantizer(config)
+ self.quant_conv = torch.nn.Conv2d(config.latent_channels, config.embed_dim, 1)
+ self.post_quant_conv = torch.nn.Conv2d(config.embed_dim, config.latent_channels, 1)
+ self.eval() # Chameleon's VQ model is frozen
+
+ def encode(self, pixel_values: torch.LongTensor):
+ hidden_states = self.encoder(pixel_values)
+ hidden_states = self.quant_conv(hidden_states)
+ quant, emb_loss, indices = self.quantize(hidden_states)
+ return quant, emb_loss, indices
+
+
+@auto_docstring
+class ChameleonModel(ChameleonPreTrainedModel):
+ def __init__(self, config: ChameleonConfig):
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+ self.vocabulary_mapping = ChameleonImageVocabularyMapping(config.vocabulary_map)
+ decoder_layer = ChameleonDecoderLayer if not self.config.swin_norm else ChameleonSwinDecoderLayer
+ self.layers = nn.ModuleList(
+ [decoder_layer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+ self.norm = ChameleonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.vqmodel = ChameleonVQVAE._from_config(config.vq_config)
+ self.gradient_checkpointing = False
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_image_tokens(self, pixel_values: torch.FloatTensor):
+ """
+ Tokenizes images into discrete tokens with the VQGAN module. Converts the
+ obtained image tokens into BPE tokens and wraps them with "boi" and "eoi"
+ special tokens.
+
+ Args:
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)):
+ The tensors corresponding to the input images.
+ """
+ batch_size = pixel_values.shape[0]
+ _, _, image_toks = self.vqmodel.encode(pixel_values)
+ bpe_toks = self.vocabulary_mapping.convert_img2bpe(image_toks)
+ bpe_toks = bpe_toks.view(batch_size, -1)
+ return bpe_toks
+
+ def get_image_features(self, pixel_values: torch.FloatTensor):
+ """
+ Tokenizes images into discrete tokens with the VQGAN module and embeds
+ them with the text embeddings layer.
+
+ Args:
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)):
+ The tensors corresponding to the input images.
+ """
+ image_tokens = self.get_image_tokens(pixel_values)
+ vision_embeddings = self.get_input_embeddings()(image_tokens)
+ return vision_embeddings
+
+ def get_placeholder_mask(
+ self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor
+ ):
+ """
+ Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
+ equal to the length of multimodal features. If the lengths are different, an error is raised.
+ """
+ if input_ids is None:
+ special_image_mask = inputs_embeds == self.get_input_embeddings()(
+ torch.tensor(self.vocabulary_mapping.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ special_image_mask = special_image_mask.all(-1)
+ else:
+ special_image_mask = input_ids == self.vocabulary_mapping.image_token_id
+
+ n_image_tokens = special_image_mask.sum()
+ special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+ n_image_features = image_features.shape[0] * image_features.shape[1]
+ if inputs_embeds[special_image_mask].numel() != image_features.numel():
+ raise ValueError(
+ f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+ )
+ return special_image_mask
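+ # Illustrative sketch (not part of the library code): the mask returned above is used in
+ # `forward` below to splice image embeddings into the text embeddings. Shapes are assumptions.
+ #
+ # inputs_embeds: (batch, seq_len, hidden)    image_embeds: (num_images, image_seq_len, hidden)
+ # special_image_mask = (input_ids == image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
+ # inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_embeds)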
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> Union[tuple, BaseModelOutputWithPast]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if self.gradient_checkpointing and self.training and use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
+ )
+ use_cache = False
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ if pixel_values is not None:
+ image_embeds = self.get_image_features(pixel_values)
+ special_image_mask = self.get_placeholder_mask(
+ input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds
+ )
+ inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_embeds)
+
+ # torch.jit.trace() doesn't support cache objects in the output
+ if use_cache and past_key_values is None and not torch.jit.is_tracing():
+ past_key_values = DynamicCache(config=self.config)
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position = torch.arange(
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+ )
+
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0)
+
+ causal_mask = create_causal_mask(
+ config=self.config,
+ input_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ cache_position=cache_position,
+ past_key_values=past_key_values,
+ position_ids=position_ids,
+ )
+
+ # embed positions
+ hidden_states = inputs_embeds
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+
+ for decoder_layer in self.layers:
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=causal_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ hidden_states = self.norm(hidden_states)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ if not return_dict:
+ return tuple(
+ v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns] if v is not None
+ )
+
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ Chameleon Model with a head on top used for outputting logits for next token prediction.
+ """
+)
+class ChameleonForConditionalGeneration(ChameleonPreTrainedModel, GenerationMixin):
+ _tied_weights_keys = ["lm_head.weight"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = ChameleonModel(config)
+ self.vocab_size = config.vocab_size
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_image_tokens(self, pixel_values):
+ return self.model.get_image_tokens(pixel_values)
+
+ def get_image_features(self, pixel_values):
+ return self.model.get_image_features(pixel_values)
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> Union[tuple, CausalLMOutputWithPast]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Example:
+
+ ```python
+ >>> from transformers import ChameleonProcessor, ChameleonForConditionalGeneration
+ >>> import torch
+ >>> import requests
+ >>> from PIL import Image
+
+ >>> model = ChameleonForConditionalGeneration.from_pretrained("facebook/chameleon-7b", dtype=torch.bfloat16)
+ >>> processor = ChameleonProcessor.from_pretrained("facebook/chameleon-7b")
+
+ >>> prompt = "I used to know a lot about constellations when I was younger, but as I grew older, I forgot most of what I knew. These are the only two constellations that I really remember now.<image><image>I would like for you to tell me about 3 more constellations and give me a little bit of history about the constellation."
+ >>> image = Image.open(requests.get("https://nineplanets.org/wp-content/uploads/2020/12/the-big-dipper-1.jpg", stream=True).raw)
+ >>> image_2 = Image.open(requests.get("https://www.kxan.com/wp-content/uploads/sites/40/2020/10/ORION.jpg", stream=True).raw)
+
+ >>> inputs = processor(images=[image, image_2], text=prompt, return_tensors="pt").to(model.device, torch.bfloat16)
+
+ >>> generated_ids = model.generate(**inputs, max_new_tokens=100, do_sample=False)
+ >>> processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+ ```"""
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs = self.model(
+ input_ids=input_ids,
+ pixel_values=pixel_values,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=True,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ hidden_states = outputs[0]
+ logits = self.lm_head(hidden_states)
+
+ # Disallow image tokens, which do not include the special begin-image and end-image tokens
+ image_tokens = self.model.vocabulary_mapping.image_tokens
+ logits[:, :, image_tokens] = torch.finfo(logits.dtype).min
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids,
+ pixel_values=None,
+ past_key_values=None,
+ attention_mask=None,
+ inputs_embeds=None,
+ cache_position=None,
+ position_ids=None,
+ use_cache=True,
+ **kwargs,
+ ):
+ # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
+
+ model_inputs = super().prepare_inputs_for_generation(
+ input_ids,
+ pixel_values=pixel_values,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ inputs_embeds=inputs_embeds,
+ cache_position=cache_position,
+ position_ids=position_ids,
+ use_cache=use_cache,
+ **kwargs,
+ )
+
+ if cache_position[0] != 0:
+ # If we're in cached decoding stage, pixel values should be `None` because input ids do not contain special image token anymore
+ # Otherwise we need pixel values to be passed to model
+ model_inputs["pixel_values"] = None
+
+ return model_inputs
+
+
+__all__ = ["ChameleonForConditionalGeneration", "ChameleonModel", "ChameleonPreTrainedModel", "ChameleonVQVAE"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/chameleon/processing_chameleon.py b/venv/lib/python3.13/site-packages/transformers/models/chameleon/processing_chameleon.py
new file mode 100644
index 0000000000000000000000000000000000000000..d481a62b6fc6608bd3088e4766b91ff843680ea0
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/chameleon/processing_chameleon.py
@@ -0,0 +1,196 @@
+# coding=utf-8
+# Copyright 2024 Meta Inc. and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Processor class for Chameleon.
+"""
+
+from typing import Optional, Union
+
+import numpy as np
+
+from ...feature_extraction_utils import BatchFeature
+from ...image_utils import ImageInput
+from ...processing_utils import (
+ MultiModalData,
+ ProcessingKwargs,
+ ProcessorMixin,
+ TextKwargs,
+ Unpack,
+)
+from ...tokenization_utils_base import PreTokenizedInput, TextInput
+
+
+class ChameleonTextKwargs(TextKwargs, total=False):
+ return_for_text_completion: bool
+
+
+class ChameleonProcessorKwargs(ProcessingKwargs, total=False):
+ text_kwargs: ChameleonTextKwargs
+ _defaults = {
+ "text_kwargs": {
+ "padding": False,
+ "return_for_text_completion": False,
+ "return_mm_token_type_ids": False,
+ },
+ "common_kwargs": {
+ "return_tensors": "pt",
+ },
+ }
+
+
+class ChameleonProcessor(ProcessorMixin):
+ r"""
+ Constructs a Chameleon processor which wraps a Chameleon image processor and a Chameleon tokenizer into a single
+ processor.
+
+ [`ChameleonProcessor`] offers all the functionalities of [`ChameleonImageProcessor`] and [`LlamaTokenizerFast`].
+ See the [`~ChameleonProcessor.__call__`] and [`~ChameleonProcessor.decode`] for more information.
+
+ Args:
+ image_processor ([`ChameleonImageProcessor`]):
+ The image processor is a required input.
+ tokenizer ([`LlamaTokenizerFast`]):
+ The tokenizer is a required input.
+ image_seq_length (`int`, *optional*, defaults to 1024):
+ Sequence length of one image embedding.
+ image_token (`str`, *optional*, defaults to `"<image>"`):
+ The special token used to indicate image in the text.
+ """
+
+ attributes = ["image_processor", "tokenizer"]
+ tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast")
+ image_processor_class = "ChameleonImageProcessor"
+
+ def __init__(self, image_processor, tokenizer, image_seq_length: int = 1024, image_token: str = "<image>"):
+ self.image_seq_length = image_seq_length
+ self.image_token = tokenizer.image_token if hasattr(tokenizer, "image_token") else image_token
+ self.image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
+ self.image_start_token = (
+ tokenizer.boi_token if hasattr(tokenizer, "boi_token") else "<racm3:break>"
+ ) # fixed tokens for start and end, so can hardcode
+ self.image_end_token = tokenizer.eoi_token if hasattr(tokenizer, "eoi_token") else "<eoss>"
+ self.image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
+ self.image_start_token_id = tokenizer.convert_tokens_to_ids(self.image_start_token)
+ self.image_end_token_id = tokenizer.convert_tokens_to_ids(self.image_end_token)
+ self.image_ids = [self.image_token_id, self.image_start_token_id, self.image_end_token_id]
+
+ super().__init__(image_processor, tokenizer)
+
+ def __call__(
+ self,
+ images: Optional[ImageInput] = None,
+ text: Optional[Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]]] = None,
+ audio=None,
+ videos=None,
+ **kwargs: Unpack[ChameleonProcessorKwargs],
+ ) -> BatchFeature:
+ """
+ Main method to prepare one or several sequence(s) and image(s) for the model. This method forwards the `text`
+ and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`] if `text` is not `None` to encode
+ the text. To prepare the image(s), this method forwards the `images` and `kwargs` arguments to
+ CLIPImageProcessor's [`~CLIPImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring
+ of the above two methods for more information.
+
+ Args:
+ images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`, `list[torch.Tensor]`):
+ The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
+ tensor. Both channels-first and channels-last formats are supported.
+ text (`str`, `list[str]`, `list[list[str]]`):
+ The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
+ (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
+ `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
+ If set, will return tensors of a particular framework. Acceptable values are:
+
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
+ - `'np'`: Return NumPy `np.ndarray` objects.
+ - `'jax'`: Return JAX `jnp.ndarray` objects.
+
+ Returns:
+ [`BatchFeature`]: A [`BatchFeature`] with the following fields:
+
+ - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
+ - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
+ `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
+ `None`).
+ - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
+ """
+
+ if isinstance(text, str):
+ text = [text]
+ elif not isinstance(text, list) and not isinstance(text[0], str):
+ raise TypeError("Invalid input text. Please provide a string, or a list of strings")
+ if text is None and images is None:
+ raise ValueError("You must provide either text or images")
+
+ output_kwargs = self._merge_kwargs(
+ ChameleonProcessorKwargs,
+ tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+ **kwargs,
+ )
+ return_for_text_completion = output_kwargs["text_kwargs"].pop("return_for_text_completion", False)
+
+ # Replace the image token with the expanded image token sequence
+ prompt_strings = []
+ one_img_tokens = self.image_start_token + (self.image_token * self.image_seq_length) + self.image_end_token
+ for sample in text:
+ sample = sample.replace(self.image_token, one_img_tokens)
+ if not return_for_text_completion:
+ sample += self.tokenizer.sep_token # special Chameleon treatment to add sep for chat mode
+ prompt_strings.append(sample)
+
+ image_inputs = {}
+ if images is not None:
+ image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
+
+ return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
+ return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False)
+ text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"], return_tensors=None)
+ self._check_special_mm_tokens(prompt_strings, text_inputs, modalities=["image"])
+
+ if return_mm_token_type_ids:
+ array_ids = np.array(text_inputs["input_ids"])
+ mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
+ mm_token_type_ids[np.isin(array_ids, self.image_ids)] = 1
+ text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()
+
+ return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors)
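+ # Illustrative sketch (not part of the library code): with the default `image_seq_length`
+ # of 1024, a single image placeholder in the prompt is expanded before tokenization roughly as
+ #
+ # prompt = "Describe this image: <image>"
+ # expanded = prompt.replace("<image>", boi + "<image>" * 1024 + eoi)
+ #
+ # where `boi`/`eoi` stand for the image start/end tokens set in `__init__`, so each image
+ # contributes image_seq_length + 2 tokens (see `_get_num_multimodal_tokens` below).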
+
+ def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
+ """
+ Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
+
+ Args:
+ image_sizes (`list[list[int]]`, *optional*):
+ The input sizes formatted as (height, width) per each image.
+
+ Returns:
+ `MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
+ input modalities, along with other useful data.
+ """
+
+ vision_data = {}
+ if image_sizes is not None:
+ # add 2 for BOI and EOI tokens
+ num_image_tokens = [self.image_seq_length + 2] * len(image_sizes)
+ num_image_patches = [1] * len(image_sizes)
+
+ vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
+
+ return MultiModalData(**vision_data)
+
+
+__all__ = ["ChameleonProcessor"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/code_llama/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/code_llama/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b65c4bddb4b0cd3fa8dfd6a781a3c0f58e30e5a7
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/code_llama/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .tokenization_code_llama import *
+ from .tokenization_code_llama_fast import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/venv/lib/python3.13/site-packages/transformers/models/code_llama/tokenization_code_llama.py b/venv/lib/python3.13/site-packages/transformers/models/code_llama/tokenization_code_llama.py
new file mode 100644
index 0000000000000000000000000000000000000000..94d1b4d659851a5f466c8123dd032ab213a90ed4
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/code_llama/tokenization_code_llama.py
@@ -0,0 +1,454 @@
+# coding=utf-8
+# Copyright 2023 MetaAI and the HuggingFace Inc. team. All rights reserved.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tokenization classes for Code LLaMA."""
+
+import os
+from shutil import copyfile
+from typing import Any, Optional
+
+import sentencepiece as spm
+
+from ...convert_slow_tokenizer import import_protobuf
+from ...tokenization_utils import AddedToken, PreTrainedTokenizer
+from ...utils import logging, requires_backends
+from ...utils.import_utils import requires
+
+
+logger = logging.get_logger(__name__)
+
+VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"}
+
+SPIECE_UNDERLINE = "▁"
+
+B_INST, E_INST = "[INST]", "[/INST]"
+B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
+
+# fmt: off
+DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your \
+answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure\
+ that your responses are socially unbiased and positive in nature.
+
+If a question does not make any sense, or is not factually coherent, explain why instead of answering something not \
+correct. If you don't know the answer to a question, please don't share false information."""
+# fmt: on
+
+
+@requires(backends=("sentencepiece",))
+class CodeLlamaTokenizer(PreTrainedTokenizer):
+ """
+ Construct a CodeLlama tokenizer. Based on byte-level Byte-Pair-Encoding. The default padding token is unset as
+ there is no padding token in the original model.
+
+ The default configuration matches that of
+ [codellama/CodeLlama-7b-Instruct-hf](https://huggingface.co/meta-llama/CodeLlama-7b-Instruct-hf/blob/main/tokenizer_config.json)
+ which supports prompt infilling.
+
+ Args:
+ vocab_file (`str`):
+ Path to the vocabulary file.
+ unk_token (`str`, *optional*, defaults to `"<unk>"`):
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+ token instead.
+ bos_token (`str`, *optional*, defaults to `"<s>"`):
+ The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier token.
+ eos_token (`str`, *optional*, defaults to `"</s>"`):
+ The end of sequence token.
+
+ <Tip>
+
+ When building a sequence using special tokens, this is not the token that is used for the end of sequence.
+ The token used is the `sep_token`.
+
+ </Tip>
+
+ prefix_token (`str`, *optional*, defaults to `"▁<PRE>"`):
+ Prefix token used for infilling.
+ middle_token (`str`, *optional*, defaults to `"▁<MID>"`):
+ Middle token used for infilling.
+ suffix_token (`str`, *optional*, defaults to `"▁<SUF>"`):
+ Suffix token used for infilling.
+ eot_token (`str`, *optional*, defaults to `"▁<EOT>"`):
+ End of text token used for infilling.
+ fill_token (`str`, *optional*, defaults to `"<FILL_ME>"`):
+ The token used to split the input between the prefix and suffix.
+ suffix_first (`bool`, *optional*, defaults to `False`):
+ Whether the input prompt and suffix should be formatted with the suffix first.
+ sp_model_kwargs (`dict`, *optional*):
+ Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
+ SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
+ to set:
+
+ - `enable_sampling`: Enable subword regularization.
+ - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.
+
+ - `nbest_size = {0,1}`: No sampling is performed.
+ - `nbest_size > 1`: samples from the nbest_size results.
+ - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice)
+ using forward-filtering-and-backward-sampling algorithm.
+
+ - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
+ BPE-dropout.
+ add_bos_token (`bool`, *optional*, defaults to `True`):
+ Whether to add a beginning of sequence token at the start of sequences.
+ add_eos_token (`bool`, *optional*, defaults to `False`):
+ Whether to add an end of sequence token at the end of sequences.
+ clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
+ Whether or not to clean up the tokenization spaces.
+ additional_special_tokens (`list[str]`, *optional*):
+ Additional special tokens used by the tokenizer.
+ use_default_system_prompt (`bool`, *optional*, defaults to `False`):
+ Whether or not the default system prompt for Llama should be used.
+ """
+
+ vocab_files_names = VOCAB_FILES_NAMES
+ model_input_names = ["input_ids", "attention_mask"]
+
+ def __init__(
+ self,
+ vocab_file,
+ unk_token="<unk>",
+ bos_token="<s>",
+ eos_token="</s>",
+ prefix_token="▁<PRE>",
+ middle_token="▁<MID>",
+ suffix_token="▁<SUF>",
+ eot_token="▁<EOT>",
+ fill_token="<FILL_ME>",
+ suffix_first=False,
+ sp_model_kwargs: Optional[dict[str, Any]] = None,
+ add_bos_token=True,
+ add_eos_token=False,
+ clean_up_tokenization_spaces=False,
+ additional_special_tokens=None,
+ use_default_system_prompt=False,
+ **kwargs,
+ ):
+ requires_backends(self, "protobuf")
+ self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
+ bos_token = AddedToken(bos_token, normalized=False, special=True) if isinstance(bos_token, str) else bos_token
+ eos_token = AddedToken(eos_token, normalized=False, special=True) if isinstance(eos_token, str) else eos_token
+ unk_token = AddedToken(unk_token, normalized=False, special=True) if isinstance(unk_token, str) else unk_token
+
+ self.use_default_system_prompt = use_default_system_prompt
+ # mark tokens special to skip them
+ additional_special_tokens = additional_special_tokens or []
+ for token in [prefix_token, middle_token, suffix_token, eot_token]:
+ additional_special_tokens += [token] if token is not None else []
+
+ self.vocab_file = vocab_file
+ self.add_bos_token = add_bos_token
+ self.add_eos_token = add_eos_token
+ self._prefix_token = prefix_token
+ self._middle_token = middle_token
+ self._suffix_token = suffix_token
+ self._eot_token = eot_token
+ self.fill_token = fill_token
+ self.suffix_first = suffix_first
+ self.sp_model = self.get_spm_processor()
+
+ super().__init__(
+ bos_token=bos_token,
+ eos_token=eos_token,
+ unk_token=unk_token,
+ add_bos_token=add_bos_token,
+ add_eos_token=add_eos_token,
+ prefix_token=prefix_token,
+ middle_token=middle_token,
+ suffix_token=suffix_token,
+ eot_token=eot_token,
+ fill_token=fill_token,
+ sp_model_kwargs=self.sp_model_kwargs,
+ suffix_first=suffix_first,
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+ additional_special_tokens=additional_special_tokens,
+ use_default_system_prompt=use_default_system_prompt,
+ **kwargs,
+ )
+
+ @property
+ def unk_token_length(self):
+ return len(self.sp_model.encode(str(self.unk_token)))
+
+ def get_spm_processor(self):
+ tokenizer = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+ with open(self.vocab_file, "rb") as f:
+ sp_model = f.read()
+ model_pb2 = import_protobuf()
+ model = model_pb2.ModelProto.FromString(sp_model)
+ normalizer_spec = model_pb2.NormalizerSpec()
+ normalizer_spec.add_dummy_prefix = False
+ model.normalizer_spec.MergeFrom(normalizer_spec)
+ sp_model = model.SerializeToString()
+ tokenizer.LoadFromSerializedProto(sp_model)
+ return tokenizer
+
+ @property
+ def prefix_token(self):
+ return self._prefix_token
+
+ @property
+ def prefix_id(self):
+ if self._prefix_token is None:
+ return None
+ return self.convert_tokens_to_ids(self.prefix_token)
+
+ @property
+ def middle_token(self):
+ return self._middle_token
+
+ @property
+ def middle_id(self):
+ if self._middle_token is None:
+ return None
+ return self.convert_tokens_to_ids(self.middle_token)
+
+ @property
+ def suffix_token(self):
+ return self._suffix_token
+
+ @property
+ def suffix_id(self):
+ if self._suffix_token is None:
+ return None
+ return self.convert_tokens_to_ids(self.suffix_token)
+
+ @property
+ def eot_token(self):
+ return self._eot_token
+
+ @property
+ def eot_id(self):
+ if self._eot_token is None:
+ return None
+ return self.convert_tokens_to_ids(self.eot_token)
+
+ @property
+ def vocab_size(self):
+ """Returns vocab size"""
+ return self.sp_model.get_piece_size()
+
+ # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.get_vocab
+ def get_vocab(self):
+ """Returns vocab as a dict"""
+ vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
+ vocab.update(self.added_tokens_encoder)
+ return vocab
+
+ def tokenize(self, prefix, suffix=None, suffix_first=False, **kwargs) -> list[int]:
+ # add a prefix space to `prefix`
+ if self.fill_token is not None and self.fill_token in prefix and suffix is None:
+ prefix, suffix = prefix.split(self.fill_token)
+
+ if len(prefix) > 0:
+ prefix = SPIECE_UNDERLINE + prefix.replace(SPIECE_UNDERLINE, " ")
+
+ if suffix is None or len(suffix) < 1:
+ tokens = super().tokenize(prefix, **kwargs)
+ if len(tokens) > 1 and tokens[0] == SPIECE_UNDERLINE and tokens[1] in self.all_special_tokens:
+ tokens = tokens[1:]
+ return tokens
+
+ prefix_tokens = self._tokenize(prefix) # prefix has an extra `SPIECE_UNDERLINE`
+
+ if None in (self.prefix_id, self.middle_id, self.suffix_id):
+ raise ValueError(
+ "The input either includes a `prefix` and a `suffix` used for the infilling task,"
+ f" or can be split on the {self.fill_token} token, creating a suffix and prefix,"
+ " but the model does not support `infilling`."
+ )
+ suffix_tokens = self._tokenize(suffix) # make sure CodeLlama sp model does not mess up
+
+ suffix_first = suffix_first if suffix_first is not None else self.suffix_first
+ if suffix_first:
+ # format as " <PRE> <SUF>{suf} <MID> {pre}"
+ return [self.prefix_token, self.suffix_token] + suffix_tokens + [self.middle_token] + prefix_tokens
+ else:
+ # format as " <PRE> {pre} <SUF>{suf} <MID>"
+ return [self.prefix_token] + prefix_tokens + [self.suffix_token] + suffix_tokens + [self.middle_token]
+
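+ # Illustrative sketch (not part of the library code): infilling a prompt that contains the
+ # fill token. The prompt and checkpoint name below are examples only.
+ #
+ # tokenizer = CodeLlamaTokenizer.from_pretrained("codellama/CodeLlama-7b-hf")
+ # pieces = tokenizer.tokenize("def add(a, b):<FILL_ME>\n    return result")
+ # # -> ['▁<PRE>', <prefix pieces>, '▁<SUF>', <suffix pieces>, '▁<MID>']
+ # # the model is then expected to generate the missing middle, terminated by '▁<EOT>'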
+ def _tokenize(self, text, **kwargs):
+ """
+ Returns a tokenized string.
+
+ We de-activated the `add_dummy_prefix` option, thus the sentencepiece internals will always strip any
+ SPIECE_UNDERLINE. For example: `self.sp_model.encode(f"{SPIECE_UNDERLINE}Hey", out_type = str)` will give
+ `['H', 'e', 'y']` instead of `['▁He', 'y']`. Thus we always encode `f"{unk_token}text"` and strip the
+ `unk_token`. Here is an example with `unk_token = "<unk>"` and `unk_token_length = 4`.
+ `self.tokenizer.sp_model.encode("<unk> Hey", out_type = str)[4:]`.
+ """
+ tokens = self.sp_model.encode(text, out_type=str)
+ if not text.startswith((SPIECE_UNDERLINE, " ")):
+ return tokens
+ # 1. Encode string + prefix ex: " Hey"
+ tokens = self.sp_model.encode(self.unk_token + text, out_type=str)
+ # 2. Remove self.unk_token from ['<','unk','>', '▁Hey']
+ return tokens[self.unk_token_length :] if len(tokens) >= self.unk_token_length else tokens
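+ # Illustrative sketch (not part of the library code): the encode-then-slice trick above,
+ # mirroring the docstring example (actual piece boundaries depend on the sentencepiece model).
+ #
+ # pieces = self.sp_model.encode(self.unk_token + " Hey", out_type=str)  # e.g. ['<', 'unk', '>', '▁Hey']
+ # pieces[self.unk_token_length :]                                       # -> ['▁Hey'], leading space kept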
+
+ # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer._convert_token_to_id
+ def _convert_token_to_id(self, token):
+ """Converts a token (str) in an id using the vocab."""
+ return self.sp_model.piece_to_id(token)
+
+ # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer._convert_id_to_token
+ def _convert_id_to_token(self, index):
+ """Converts an index (integer) in a token (str) using the vocab."""
+ token = self.sp_model.IdToPiece(index)
+ return token
+
+ def convert_tokens_to_string(self, tokens):
+ """Converts a sequence of tokens (string) in a single string."""
+ # since we manually add the prefix space, we have to remove it when decoding
+ if tokens[0].startswith(SPIECE_UNDERLINE):
+ tokens[0] = tokens[0][1:]
+
+ current_sub_tokens = []
+ out_string = ""
+ for _, token in enumerate(tokens):
+ # make sure that special tokens are not decoded using sentencepiece model
+ if token in self.all_special_tokens:
+ out_string += self.sp_model.decode(current_sub_tokens) + token
+ current_sub_tokens = []
+ else:
+ current_sub_tokens.append(token)
+ out_string += self.sp_model.decode(current_sub_tokens)
+ return out_string
+
+ # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.save_vocabulary
+ def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> tuple[str]:
+ """
+ Save the vocabulary and special tokens file to a directory.
+
+ Args:
+ save_directory (`str`):
+ The directory in which to save the vocabulary.
+
+ Returns:
+ `Tuple(str)`: Paths to the files saved.
+ """
+ if not os.path.isdir(save_directory):
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
+ return
+ out_vocab_file = os.path.join(
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+ )
+
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
+ copyfile(self.vocab_file, out_vocab_file)
+ elif not os.path.isfile(self.vocab_file):
+ with open(out_vocab_file, "wb") as fi:
+ content_spiece_model = self.sp_model.serialized_model_proto()
+ fi.write(content_spiece_model)
+
+ return (out_vocab_file,)
+
+ # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.build_inputs_with_special_tokens
+ def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
+ bos_token_id = [self.bos_token_id] if self.add_bos_token else []
+ eos_token_id = [self.eos_token_id] if self.add_eos_token else []
+
+ output = bos_token_id + token_ids_0 + eos_token_id
+
+ if token_ids_1 is not None:
+ output = output + bos_token_id + token_ids_1 + eos_token_id
+
+ return output
+
+ # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.get_special_tokens_mask
+ def get_special_tokens_mask(
+ self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None, already_has_special_tokens: bool = False
+ ) -> list[int]:
+ """
+ Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
+ special tokens using the tokenizer `prepare_for_model` method.
+
+ Args:
+ token_ids_0 (`list[int]`):
+ List of IDs.
+ token_ids_1 (`list[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+ already_has_special_tokens (`bool`, *optional*, defaults to `False`):
+ Whether or not the token list is already formatted with special tokens for the model.
+
+ Returns:
+ `list[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+ """
+ if already_has_special_tokens:
+ return super().get_special_tokens_mask(
+ token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
+ )
+
+ bos_token_id = [1] if self.add_bos_token else []
+ eos_token_id = [1] if self.add_eos_token else []
+
+ if token_ids_1 is None:
+ return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id
+ return (
+ bos_token_id
+ + ([0] * len(token_ids_0))
+ + eos_token_id
+ + bos_token_id
+ + ([0] * len(token_ids_1))
+ + eos_token_id
+ )
+
+ # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.create_token_type_ids_from_sequences
+ def create_token_type_ids_from_sequences(
+ self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
+ ) -> list[int]:
+ """
+ Creates a mask from the two sequences passed to be used in a sequence-pair classification task. An ALBERT
+ sequence pair mask has the following format:
+
+ ```
+ 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
+ | first sequence | second sequence |
+ ```
+
+ if token_ids_1 is None, only returns the first portion of the mask (0s).
+
+ Args:
+ token_ids_0 (`list[int]`):
+ List of ids.
+ token_ids_1 (`list[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+
+ Returns:
+ `list[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
+ """
+ bos_token_id = [self.bos_token_id] if self.add_bos_token else []
+ eos_token_id = [self.eos_token_id] if self.add_eos_token else []
+
+ output = [0] * len(bos_token_id + token_ids_0 + eos_token_id)
+
+ if token_ids_1 is not None:
+ output += [1] * len(bos_token_id + token_ids_1 + eos_token_id)
+
+ return output
+
+ def __getstate__(self):
+ state = self.__dict__.copy()
+ state["sp_model"] = None
+ state["sp_model_proto"] = self.sp_model.serialized_model_proto()
+ return state
+
+ def __setstate__(self, d):
+ self.__dict__ = d
+ self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+ self.sp_model.LoadFromSerializedProto(self.sp_model_proto)
+
+
+__all__ = ["CodeLlamaTokenizer"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/code_llama/tokenization_code_llama_fast.py b/venv/lib/python3.13/site-packages/transformers/models/code_llama/tokenization_code_llama_fast.py
new file mode 100644
index 0000000000000000000000000000000000000000..b3978587e7f02512a5344f9ad0a33bf86b839757
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/code_llama/tokenization_code_llama_fast.py
@@ -0,0 +1,374 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+from shutil import copyfile
+from typing import Optional
+
+from tokenizers import normalizers, processors
+
+from ...tokenization_utils_fast import PreTrainedTokenizerFast
+from ...utils import is_sentencepiece_available, logging
+
+
+if is_sentencepiece_available():
+ from .tokenization_code_llama import CodeLlamaTokenizer
+else:
+ CodeLlamaTokenizer = None
+
+logger = logging.get_logger(__name__)
+VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model", "tokenizer_file": "tokenizer.json"}
+
+SPIECE_UNDERLINE = "▁"
+
+
+B_INST, E_INST = "[INST]", "[/INST]"
+ B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
+
+# fmt: off
+DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your \
+answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure\
+ that your responses are socially unbiased and positive in nature.
+
+If a question does not make any sense, or is not factually coherent, explain why instead of answering something not \
+correct. If you don't know the answer to a question, please don't share false information."""
+# fmt: on
+
+
+class CodeLlamaTokenizerFast(PreTrainedTokenizerFast):
+ """
+ Construct a Code Llama tokenizer. Based on byte-level Byte-Pair-Encoding.
+
+ Notably, this uses ByteFallback and no normalization.
+
+ ```python
+ >>> from transformers import CodeLlamaTokenizerFast
+
+ >>> tokenizer = CodeLlamaTokenizerFast.from_pretrained("hf-internal-testing/llama-tokenizer")
+ >>> tokenizer.encode("Hello this is a test")
+ [1, 15043, 445, 338, 263, 1243]
+ ```
+
+ If you want to change the `bos_token` or the `eos_token`, make sure to specify them when initializing the model, or
+ call `tokenizer.update_post_processor()` to make sure that the post-processing is correctly done (otherwise the
+ values of the first token and final token of an encoded sequence will not be correct). For more details, check out
+ the [post-processors](https://huggingface.co/docs/tokenizers/api/post-processors) documentation.
+
+
+ This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
+ refer to this superclass for more information regarding those methods. The default configuration matches that of
+ [meta-llama/CodeLlama-7b-Instruct-hf](https://huggingface.co/meta-llama/CodeLlama-7b-Instruct-hf/blob/main/tokenizer_config.json)
+ which supports prompt infilling.
+
+ Args:
+ vocab_file (`str`, *optional*):
+ [SentencePiece](https://github.com/google/sentencepiece) file (generally has a .model extension) that
+ contains the vocabulary necessary to instantiate a tokenizer.
+ tokenizer_file (`str`, *optional*):
+ [tokenizers](https://github.com/huggingface/tokenizers) file (generally has a .json extension) that
+ contains everything needed to load the tokenizer.
+ clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
+ Whether to clean up spaces after decoding; cleanup consists of removing potential artifacts like extra
+ spaces.
+ unk_token (`str`, *optional*, defaults to `"<unk>"`):
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+ token instead.
+ bos_token (`str`, *optional*, defaults to `"<s>"`):
+ The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier token.
+ eos_token (`str`, *optional*, defaults to `"</s>"`):
+ The end of sequence token.
+ prefix_token (`str`, *optional*, defaults to `"▁<PRE>"`):
+ Prefix token used for infilling.
+ middle_token (`str`, *optional*, defaults to `"▁<MID>"`):
+ Middle token used for infilling.
+ suffix_token (`str`, *optional*, defaults to `"▁<SUF>"`):
+ Suffix token used for infilling.
+ eot_token (`str`, *optional*, defaults to `"▁<EOT>"`):
+ End of text token used for infilling.
+ fill_token (`str`, *optional*, defaults to `"<FILL_ME>"`):
+ The token used to split the input between the prefix and suffix.
+ additional_special_tokens (`list[str]`, *optional*):
+ Additional special tokens used by the tokenizer.
+ add_bos_token (`bool`, *optional*, defaults to `True`):
+ Whether to add a beginning of sequence token at the start of sequences.
+ add_eos_token (`bool`, *optional*, defaults to `False`):
+ Whether to add an end of sequence token at the end of sequences.
+ use_default_system_prompt (`bool`, *optional*, defaults to `False`):
+ Whether or not the default system prompt for Llama should be used.
+ """
+
+ vocab_files_names = VOCAB_FILES_NAMES
+ slow_tokenizer_class = CodeLlamaTokenizer
+ padding_side = "left"
+ model_input_names = ["input_ids", "attention_mask"]
+
+ def __init__(
+ self,
+ vocab_file=None,
+ tokenizer_file=None,
+ clean_up_tokenization_spaces=False,
+ unk_token="<unk>",
+ bos_token="<s>",
+ eos_token="</s>",
+ prefix_token="▁<PRE>",
+ middle_token="▁<MID>",
+ suffix_token="▁<SUF>",
+ eot_token="▁<EOT>",
+ fill_token="<FILL_ME>",
+ additional_special_tokens=None,
+ add_bos_token=True,
+ add_eos_token=False,
+ use_default_system_prompt=False,
+ **kwargs,
+ ):
+ # mark tokens special to skip them
+ additional_special_tokens = additional_special_tokens or []
+ for token in [prefix_token, middle_token, suffix_token, eot_token]:
+ additional_special_tokens += [token] if token is not None else []
+ self.use_default_system_prompt = use_default_system_prompt
+
+ super().__init__(
+ vocab_file=vocab_file,
+ tokenizer_file=tokenizer_file,
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+ additional_special_tokens=additional_special_tokens,
+ unk_token=unk_token,
+ bos_token=bos_token,
+ eos_token=eos_token,
+ add_bos_token=add_bos_token,
+ add_eos_token=add_eos_token,
+ prefix_token=prefix_token,
+ middle_token=middle_token,
+ suffix_token=suffix_token,
+ eot_token=eot_token,
+ fill_token=fill_token,
+ use_default_system_prompt=use_default_system_prompt,
+ **kwargs,
+ )
+ self._add_bos_token = add_bos_token
+ self._add_eos_token = add_eos_token
+ self.update_post_processor()
+
+ self.vocab_file = vocab_file
+
+ self._prefix_token = prefix_token
+ self._middle_token = middle_token
+ self._suffix_token = suffix_token
+ self._eot_token = eot_token
+ self.fill_token = fill_token
+
+ # Copied from transformers.models.llama.tokenization_llama_fast.LlamaTokenizerFast.update_post_processor
+ def update_post_processor(self):
+ """
+ Updates the underlying post processor with the current `bos_token` and `eos_token`.
+ """
+ bos = self.bos_token
+ bos_token_id = self.bos_token_id
+ if bos is None and self.add_bos_token:
+ raise ValueError("add_bos_token = True but bos_token = None")
+
+ eos = self.eos_token
+ eos_token_id = self.eos_token_id
+ if eos is None and self.add_eos_token:
+ raise ValueError("add_eos_token = True but eos_token = None")
+
+ single = f"{(bos + ':0 ') if self.add_bos_token else ''}$A:0{(' ' + eos + ':0') if self.add_eos_token else ''}"
+ pair = f"{single}{(' ' + bos + ':1') if self.add_bos_token else ''} $B:1{(' ' + eos + ':1') if self.add_eos_token else ''}"
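+ # For example, with the defaults above (add_bos_token=True, add_eos_token=False) and bos_token="<s>",
+ # this yields single="<s>:0 $A:0" and pair="<s>:0 $A:0 <s>:1 $B:1" (the :0/:1 suffixes are the token type ids).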
+
+ special_tokens = []
+ if self.add_bos_token:
+ special_tokens.append((bos, bos_token_id))
+ if self.add_eos_token:
+ special_tokens.append((eos, eos_token_id))
+ self._tokenizer.post_processor = processors.TemplateProcessing(
+ single=single, pair=pair, special_tokens=special_tokens
+ )
+
+ @property
+ def prefix_token(self):
+ return self._prefix_token
+
+ @property
+ def prefix_id(self):
+ if self._prefix_token is None:
+ return None
+ return self.convert_tokens_to_ids(self.prefix_token)
+
+ @property
+ def middle_token(self):
+ return self._middle_token
+
+ @property
+ def middle_id(self):
+ if self._middle_token is None:
+ return None
+ return self.convert_tokens_to_ids(self.middle_token)
+
+ @property
+ def suffix_token(self):
+ return self._suffix_token
+
+ @property
+ def suffix_id(self):
+ if self._suffix_token is None:
+ return None
+ return self.convert_tokens_to_ids(self.suffix_token)
+
+ @property
+ def eot_id(self):
+ if self._eot_token is None:
+ return None
+ return self.convert_tokens_to_ids(self.eot_token)
+
+ @property
+ def eot_token(self):
+ return self._eot_token
+
+ @property
+ def add_eos_token(self):
+ return self._add_eos_token
+
+ @property
+ def add_bos_token(self):
+ return self._add_bos_token
+
+ @add_eos_token.setter
+ def add_eos_token(self, value):
+ self._add_eos_token = value
+ self.update_post_processor()
+
+ @add_bos_token.setter
+ def add_bos_token(self, value):
+ self._add_bos_token = value
+ self.update_post_processor()
+
+ def set_infilling_processor(self, reset, suffix_first=False, add_special_tokens=True):
+ """
+ Updates the normalizer to make sure the prompt format for `infilling` is respected. The infilling format is the
+ following: if suffix_first
+ " <PRE> <SUF>{suf} <MID> {pre}"
+ else:
+ " <PRE> {pre} <SUF>{suf} <MID>"
+
+ If `reset` is set to `True`, the `normalizer` and `post_processor` are reset to their "normal" behaviour, which
+ is to add a prefix space for the normalizer, and add a `bos_token` to the input text for the `post_processor`.
+ """
+ if reset:
+ self._tokenizer.normalizer = normalizers.Sequence(
+ [
+ normalizers.Prepend(prepend="▁"),
+ normalizers.Replace(pattern=" ", content="▁"),
+ ]
+ )
+ self.update_post_processor()
+ return
+
+ self._tokenizer.normalizer = normalizers.Replace(pattern=" ", content="▁")
+ pair = [self.bos_token] if self.add_bos_token and add_special_tokens else []
+ special_tokens = [(self.bos_token, self.bos_token_id)] if self.add_bos_token and add_special_tokens else []
+ if suffix_first:
+ # format as " <PRE> <SUF>{suf} <MID> {pre}"
+ pair += [self.prefix_token, self.suffix_token, "$B", self.middle_token, "$A"]
+ special_tokens += [
+ (self.prefix_token, self.prefix_id),
+ (self.suffix_token, self.suffix_id),
+ (self.middle_token, self.middle_id),
+ ]
+ else:
+ # format as " <PRE> {pre} <SUF>{suf} <MID>"
+ pair += [self.prefix_token, "$A", self.suffix_token, "$B", self.middle_token]
+ special_tokens += [
+ (self.prefix_token, self.prefix_id),
+ (self.suffix_token, self.suffix_id),
+ (self.middle_token, self.middle_id),
+ ]
+
+ if self.add_eos_token and add_special_tokens:
+ pair += [self.eos_token]
+ special_tokens += [(self.eos_token, self.eos_token_id)]
+ self._tokenizer.post_processor = processors.TemplateProcessing(
+ single="$A", pair=pair, special_tokens=special_tokens
+ )
+
+ def encode_plus(self, text, text_pair=None, suffix_first=False, add_special_tokens=True, **kwargs):
+ # hack to make sure the input is pre-processed here, outside of the Rust tokenizer
+ text_pair = kwargs.pop("suffix", text_pair)
+ if self.fill_token is not None and self.fill_token in text and text_pair is None:
+ text, text_pair = text.split(self.fill_token)
+
+ if text_pair is None or len(text_pair) < 1:
+ return super().encode_plus(text, text_pair, add_special_tokens=add_special_tokens, **kwargs)
+
+ if None in (self.prefix_id, self.middle_id, self.suffix_id):
+ raise ValueError(
+ "Then input includes a `prefix` and a `suffix` used for the infilling task,"
+ " the `prefix_id, middle_id, suffix_id` must all be initialized. Current"
+ f" values : {self.prefix_id, self.middle_id, self.suffix_id}"
+ )
+
+ self.set_infilling_processor(False, suffix_first=suffix_first, add_special_tokens=add_special_tokens)
+ tokens = super().encode_plus(" " + text, text_pair=text_pair, add_special_tokens=True, **kwargs)
+ self.set_infilling_processor(True)
+ return tokens
+
+ # Copied from transformers.models.llama.tokenization_llama_fast.LlamaTokenizerFast.save_vocabulary
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
+ if not self.can_save_slow_tokenizer:
+ raise ValueError(
+ "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
+ "tokenizer."
+ )
+
+ if not os.path.isdir(save_directory):
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
+ return
+ out_vocab_file = os.path.join(
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+ )
+
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+ copyfile(self.vocab_file, out_vocab_file)
+
+ return (out_vocab_file,)
+
+ def build_inputs_with_special_tokens(
+ self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
+ ) -> list[int]:
+ """
+ Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating
+ and adding special tokens. A Code Llama sequence has the following format, where `X` represents the sequence:
+
+ - single sequence: `<s> X </s>`
+ - pair of sequences: `<s> A B </s>`
+
+ Pairs of sequences are not the expected use case, but they will be handled without a
+ separator.
+
+ Args:
+ token_ids_0 (`list[int]`):
+ List of IDs to which the special tokens will be added.
+ token_ids_1 (`list[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+
+ Returns:
+ `list[int]`: list of [input IDs](../glossary#input-ids) with the appropriate special tokens.
+ """
+ if token_ids_1 is None:
+ return [self.bos_token_id] + token_ids_0 + [self.eos_token_id]
+ return [self.bos_token_id] + token_ids_0 + token_ids_1 + [self.eos_token_id]
+
+
+__all__ = ["CodeLlamaTokenizerFast"]
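A minimal sketch of the infilling path implemented by `encode_plus` and `set_infilling_processor` above, assuming access to a Code Llama checkpoint such as the Instruct variant referenced in the class docstring (exact token ids and outputs depend on the checkpoint):

```python
from transformers import CodeLlamaTokenizerFast

tok = CodeLlamaTokenizerFast.from_pretrained("meta-llama/CodeLlama-7b-Instruct-hf")

# The default fill token splits the prompt into a prefix and a suffix, and the
# infilling post-processor wraps them with the <PRE>/<SUF>/<MID> markers.
prompt = "def remove_non_ascii(s: str) -> str:\n    <FILL_ME>\n    return result"
encoding = tok(prompt, return_tensors="pt")
print(tok.convert_ids_to_tokens(encoding["input_ids"][0])[:6])
```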
diff --git a/venv/lib/python3.13/site-packages/transformers/models/convnextv2/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/convnextv2/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0fd1293963b233da99850c67212dc2998102b126
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/convnextv2/__init__.py
@@ -0,0 +1,28 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_convnextv2 import *
+ from .modeling_convnextv2 import *
+ from .modeling_tf_convnextv2 import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/venv/lib/python3.13/site-packages/transformers/models/convnextv2/configuration_convnextv2.py b/venv/lib/python3.13/site-packages/transformers/models/convnextv2/configuration_convnextv2.py
new file mode 100644
index 0000000000000000000000000000000000000000..53f1825ca57c89249645c98aa54c278d8e7b4a61
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/convnextv2/configuration_convnextv2.py
@@ -0,0 +1,118 @@
+# coding=utf-8
+# Copyright 2023 Meta Platforms, Inc. and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""ConvNeXTV2 model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices
+
+
+logger = logging.get_logger(__name__)
+
+
+class ConvNextV2Config(BackboneConfigMixin, PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`ConvNextV2Model`]. It is used to instantiate a
+ ConvNeXTV2 model according to the specified arguments, defining the model architecture. Instantiating a
+ configuration with the defaults will yield a similar configuration to that of the ConvNeXTV2
+ [facebook/convnextv2-tiny-1k-224](https://huggingface.co/facebook/convnextv2-tiny-1k-224) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ num_channels (`int`, *optional*, defaults to 3):
+ The number of input channels.
+ patch_size (`int`, *optional*, defaults to 4):
+ Patch size to use in the patch embedding layer.
+ num_stages (`int`, *optional*, defaults to 4):
+ The number of stages in the model.
+ hidden_sizes (`list[int]`, *optional*, defaults to `[96, 192, 384, 768]`):
+ Dimensionality (hidden size) at each stage.
+ depths (`list[int]`, *optional*, defaults to `[3, 3, 9, 3]`):
+ Depth (number of blocks) for each stage.
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in each block. If string, `"gelu"`, `"relu"`,
+ `"selu"` and `"gelu_new"` are supported.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12):
+ The epsilon used by the layer normalization layers.
+ drop_path_rate (`float`, *optional*, defaults to 0.0):
+ The drop rate for stochastic depth.
+ image_size (`int`, *optional*, defaults to 224):
+ The size (resolution) of each image.
+ out_features (`list[str]`, *optional*):
+ If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
+ (depending on how many stages the model has). If unset and `out_indices` is set, will default to the
+ corresponding stages. If unset and `out_indices` is unset, will default to the last stage. Must be in the
+ same order as defined in the `stage_names` attribute.
+ out_indices (`list[int]`, *optional*):
+ If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how
+ many stages the model has). If unset and `out_features` is set, will default to the corresponding stages.
+ If unset and `out_features` is unset, will default to the last stage. Must be in the
+ same order as defined in the `stage_names` attribute.
+
+ Example:
+ ```python
+ >>> from transformers import ConvNextV2Config, ConvNextV2Model
+
+ >>> # Initializing a ConvNeXTV2 convnextv2-tiny-1k-224 style configuration
+ >>> configuration = ConvNextV2Config()
+
+ >>> # Initializing a model (with random weights) from the convnextv2-tiny-1k-224 style configuration
+ >>> model = ConvNextV2Model(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "convnextv2"
+
+ def __init__(
+ self,
+ num_channels=3,
+ patch_size=4,
+ num_stages=4,
+ hidden_sizes=None,
+ depths=None,
+ hidden_act="gelu",
+ initializer_range=0.02,
+ layer_norm_eps=1e-12,
+ drop_path_rate=0.0,
+ image_size=224,
+ out_features=None,
+ out_indices=None,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ self.num_channels = num_channels
+ self.patch_size = patch_size
+ self.num_stages = num_stages
+ self.hidden_sizes = [96, 192, 384, 768] if hidden_sizes is None else hidden_sizes
+ self.depths = [3, 3, 9, 3] if depths is None else depths
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.layer_norm_eps = layer_norm_eps
+ self.drop_path_rate = drop_path_rate
+ self.image_size = image_size
+ self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(self.depths) + 1)]
+ self._out_features, self._out_indices = get_aligned_output_features_output_indices(
+ out_features=out_features, out_indices=out_indices, stage_names=self.stage_names
+ )
+
+
+__all__ = ["ConvNextV2Config"]
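The `out_features`/`out_indices` arguments above select which stages a backbone exposes; a minimal sketch (the stage names follow the `stage_names` attribute built in `__init__`):

```python
from transformers import ConvNextV2Config

# Select the last two stages as backbone outputs; out_indices is aligned automatically.
config = ConvNextV2Config(out_features=["stage3", "stage4"])
print(config.stage_names)  # ["stem", "stage1", "stage2", "stage3", "stage4"]
print(config.out_features, config.out_indices)
```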
diff --git a/venv/lib/python3.13/site-packages/transformers/models/convnextv2/modeling_convnextv2.py b/venv/lib/python3.13/site-packages/transformers/models/convnextv2/modeling_convnextv2.py
new file mode 100644
index 0000000000000000000000000000000000000000..bfa5338f5e86fe39549e7d30928f5e72422fc9cc
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/convnextv2/modeling_convnextv2.py
@@ -0,0 +1,447 @@
+# coding=utf-8
+# Copyright 2023 Meta Platforms, Inc. and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch ConvNextV2 model."""
+
+from typing import Optional
+
+import torch
+from torch import nn
+
+from ...activations import ACT2FN
+from ...modeling_outputs import (
+ BackboneOutput,
+ BaseModelOutputWithNoAttention,
+ BaseModelOutputWithPoolingAndNoAttention,
+ ImageClassifierOutputWithNoAttention,
+)
+from ...modeling_utils import PreTrainedModel
+from ...utils import auto_docstring, logging
+from ...utils.backbone_utils import BackboneMixin
+from ...utils.generic import can_return_tuple
+from .configuration_convnextv2 import ConvNextV2Config
+
+
+logger = logging.get_logger(__name__)
+
+
+# Copied from transformers.models.beit.modeling_beit.drop_path
+def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
+ """
+ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+
+ Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
+ however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
+ layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
+ argument.
+ """
+ if drop_prob == 0.0 or not training:
+ return input
+ keep_prob = 1 - drop_prob
+ shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
+ random_tensor.floor_() # binarize
+ output = input.div(keep_prob) * random_tensor
+ return output
+
+
+# Copied from transformers.models.beit.modeling_beit.BeitDropPath with Beit->ConvNextV2
+class ConvNextV2DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+ def __init__(self, drop_prob: Optional[float] = None) -> None:
+ super().__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ return drop_path(hidden_states, self.drop_prob, self.training)
+
+ def extra_repr(self) -> str:
+ return f"p={self.drop_prob}"
+
+
+class ConvNextV2GRN(nn.Module):
+ """GRN (Global Response Normalization) layer"""
+
+ def __init__(self, dim: int):
+ super().__init__()
+ self.weight = nn.Parameter(torch.zeros(1, 1, 1, dim))
+ self.bias = nn.Parameter(torch.zeros(1, 1, 1, dim))
+
+ def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
+ # Compute and normalize global spatial feature maps
+ global_features = torch.linalg.vector_norm(hidden_states, ord=2, dim=(1, 2), keepdim=True)
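+ # hidden_states is channels-last here, (batch_size, height, width, channels), so the spatial L2 norm
+ # gives global_features of shape (batch_size, 1, 1, channels); it is then rescaled by its mean over channels.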
+ norm_features = global_features / (global_features.mean(dim=-1, keepdim=True) + 1e-6)
+ hidden_states = self.weight * (hidden_states * norm_features) + self.bias + hidden_states
+
+ return hidden_states
+
+
+# Copied from transformers.models.convnext.modeling_convnext.ConvNextLayerNorm with ConvNext->ConvNextV2
+class ConvNextV2LayerNorm(nn.LayerNorm):
+ r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
+ The ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch_size, height,
+ width, channels) while channels_first corresponds to inputs with shape (batch_size, channels, height, width).
+ """
+
+ def __init__(self, normalized_shape, *, eps=1e-6, data_format="channels_last", **kwargs):
+ super().__init__(normalized_shape, eps=eps, **kwargs)
+ if data_format not in ["channels_last", "channels_first"]:
+ raise NotImplementedError(f"Unsupported data format: {data_format}")
+ self.data_format = data_format
+
+ def forward(self, features: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ features: Tensor of shape (batch_size, channels, height, width) OR (batch_size, height, width, channels)
+ """
+ if self.data_format == "channels_first":
+ features = features.permute(0, 2, 3, 1)
+ features = super().forward(features)
+ features = features.permute(0, 3, 1, 2)
+ else:
+ features = super().forward(features)
+ return features
+
+
+# Copied from transformers.models.convnext.modeling_convnext.ConvNextEmbeddings with ConvNext->ConvNextV2
+class ConvNextV2Embeddings(nn.Module):
+ """This class is comparable to (and inspired by) the SwinEmbeddings class
+ found in src/transformers/models/swin/modeling_swin.py.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ self.patch_embeddings = nn.Conv2d(
+ config.num_channels, config.hidden_sizes[0], kernel_size=config.patch_size, stride=config.patch_size
+ )
+ self.layernorm = ConvNextV2LayerNorm(config.hidden_sizes[0], eps=1e-6, data_format="channels_first")
+ self.num_channels = config.num_channels
+
+ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
+ num_channels = pixel_values.shape[1]
+ if num_channels != self.num_channels:
+ raise ValueError(
+ "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
+ )
+ embeddings = self.patch_embeddings(pixel_values)
+ embeddings = self.layernorm(embeddings)
+ return embeddings
+
+
+class ConvNextV2Layer(nn.Module):
+ """This corresponds to the `Block` class in the original implementation.
+
+ There are two equivalent implementations: (1) [DwConv, LayerNorm (channels_first), Conv, GELU, 1x1 Conv], all in
+ (N, C, H, W); (2) [DwConv, Permute to (N, H, W, C), LayerNorm (channels_last), Linear, GELU, Linear], then Permute back.
+
+ The authors used (2) as they find it slightly faster in PyTorch.
+
+ Args:
+ config ([`ConvNextV2Config`]): Model configuration class.
+ dim (`int`): Number of input channels.
+ drop_path (`float`): Stochastic depth rate. Default: 0.0.
+ """
+
+ def __init__(self, config, dim, drop_path=0):
+ super().__init__()
+ # depthwise conv
+ self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim)
+ self.layernorm = ConvNextV2LayerNorm(dim, eps=1e-6)
+ # pointwise/1x1 convs, implemented with linear layers
+ self.pwconv1 = nn.Linear(dim, 4 * dim)
+ self.act = ACT2FN[config.hidden_act]
+ self.grn = ConvNextV2GRN(4 * dim)
+ self.pwconv2 = nn.Linear(4 * dim, dim)
+ self.drop_path = ConvNextV2DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+ def forward(self, features: torch.Tensor) -> torch.Tensor:
+ residual = features
+ features = self.dwconv(features)
+ # (batch_size, num_channels, height, width) -> (batch_size, height, width, num_channels)
+ features = features.permute(0, 2, 3, 1)
+ features = self.layernorm(features)
+ features = self.pwconv1(features)
+ features = self.act(features)
+ features = self.grn(features)
+ features = self.pwconv2(features)
+ # (batch_size, height, width, num_channels) -> (batch_size, num_channels, height, width)
+ features = features.permute(0, 3, 1, 2)
+
+ features = residual + self.drop_path(features)
+ return features
+
+
+# Copied from transformers.models.convnext.modeling_convnext.ConvNextStage with ConvNeXT->ConvNeXTV2, ConvNext->ConvNextV2
+class ConvNextV2Stage(nn.Module):
+ """ConvNeXTV2 stage, consisting of an optional downsampling layer + multiple residual blocks.
+
+ Args:
+ config ([`ConvNextV2Config`]): Model configuration class.
+ in_channels (`int`): Number of input channels.
+ out_channels (`int`): Number of output channels.
+ depth (`int`): Number of residual blocks.
+ drop_path_rates(`list[float]`): Stochastic depth rates for each layer.
+ """
+
+ def __init__(self, config, in_channels, out_channels, kernel_size=2, stride=2, depth=2, drop_path_rates=None):
+ super().__init__()
+
+ if in_channels != out_channels or stride > 1:
+ self.downsampling_layer = nn.ModuleList(
+ [
+ ConvNextV2LayerNorm(in_channels, eps=1e-6, data_format="channels_first"),
+ nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride),
+ ]
+ )
+ else:
+ self.downsampling_layer = nn.ModuleList()
+ drop_path_rates = drop_path_rates or [0.0] * depth
+ self.layers = nn.ModuleList(
+ [ConvNextV2Layer(config, dim=out_channels, drop_path=drop_path_rates[j]) for j in range(depth)]
+ )
+
+ def forward(self, features: torch.Tensor) -> torch.Tensor:
+ for layer in self.downsampling_layer:
+ features = layer(features)
+ for layer in self.layers:
+ features = layer(features)
+ return features
+
+
+# Copied from transformers.models.convnext.modeling_convnext.ConvNextEncoder with ConvNext->ConvNextV2
+class ConvNextV2Encoder(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.stages = nn.ModuleList()
+ drop_path_rates = [
+ x.tolist()
+ for x in torch.linspace(0, config.drop_path_rate, sum(config.depths), device="cpu").split(config.depths)
+ ]
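+ # Linearly increasing stochastic depth schedule over all blocks, split into one list of rates per stage.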
+ prev_chs = config.hidden_sizes[0]
+ for i in range(config.num_stages):
+ out_chs = config.hidden_sizes[i]
+ stage = ConvNextV2Stage(
+ config,
+ in_channels=prev_chs,
+ out_channels=out_chs,
+ stride=2 if i > 0 else 1,
+ depth=config.depths[i],
+ drop_path_rates=drop_path_rates[i],
+ )
+ self.stages.append(stage)
+ prev_chs = out_chs
+
+ def forward(
+ self, hidden_states: torch.Tensor, output_hidden_states: Optional[bool] = False
+ ) -> BaseModelOutputWithNoAttention:
+ all_hidden_states = [hidden_states] if output_hidden_states else None
+
+ for layer_module in self.stages:
+ hidden_states = layer_module(hidden_states)
+ if all_hidden_states is not None:
+ all_hidden_states.append(hidden_states)
+
+ return BaseModelOutputWithNoAttention(last_hidden_state=hidden_states, hidden_states=all_hidden_states)
+
+
+@auto_docstring
+class ConvNextV2PreTrainedModel(PreTrainedModel):
+ config: ConvNextV2Config
+ base_model_prefix = "convnextv2"
+ main_input_name = "pixel_values"
+ _no_split_modules = ["ConvNextV2Layer"]
+
+ def _init_weights(self, module):
+ """Initialize the weights"""
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
+ # Slightly different from the TF version which uses truncated_normal for initialization
+ # cf https://github.com/pytorch/pytorch/pull/5617
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, (nn.LayerNorm, ConvNextV2LayerNorm)):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+ elif isinstance(module, ConvNextV2GRN):
+ module.weight.data.zero_()
+ module.bias.data.zero_()
+
+
+@auto_docstring
+# Copied from transformers.models.convnext.modeling_convnext.ConvNextModel with CONVNEXT->CONVNEXTV2, ConvNext->ConvNextV2
+class ConvNextV2Model(ConvNextV2PreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.config = config
+
+ self.embeddings = ConvNextV2Embeddings(config)
+ self.encoder = ConvNextV2Encoder(config)
+
+ # final layernorm layer
+ self.layernorm = nn.LayerNorm(config.hidden_sizes[-1], eps=config.layer_norm_eps)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self, pixel_values: Optional[torch.FloatTensor] = None, output_hidden_states: Optional[bool] = None
+ ) -> BaseModelOutputWithPoolingAndNoAttention:
+ if output_hidden_states is None:
+ output_hidden_states = self.config.output_hidden_states
+
+ if pixel_values is None:
+ raise ValueError("You have to specify pixel_values")
+
+ embedding_output = self.embeddings(pixel_values)
+ encoder_outputs: BaseModelOutputWithNoAttention = self.encoder(
+ embedding_output, output_hidden_states=output_hidden_states
+ )
+ last_hidden_state = encoder_outputs.last_hidden_state
+
+ # global average pooling, (N, C, H, W) -> (N, C)
+ pooled_output = self.layernorm(last_hidden_state.mean([-2, -1]))
+
+ return BaseModelOutputWithPoolingAndNoAttention(
+ last_hidden_state=last_hidden_state,
+ pooler_output=pooled_output,
+ hidden_states=encoder_outputs.hidden_states,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ ConvNextV2 Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for
+ ImageNet.
+ """
+)
+# Copied from transformers.models.convnext.modeling_convnext.ConvNextForImageClassification with CONVNEXT->CONVNEXTV2,ConvNext->ConvNextV2,convnext->convnextv2
+class ConvNextV2ForImageClassification(ConvNextV2PreTrainedModel):
+ accepts_loss_kwargs = False
+
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.num_labels = config.num_labels
+ self.convnextv2 = ConvNextV2Model(config)
+
+ # Classifier head
+ if config.num_labels > 0:
+ self.classifier = nn.Linear(config.hidden_sizes[-1], config.num_labels)
+ else:
+ self.classifier = nn.Identity()
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self, pixel_values: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, **kwargs
+ ) -> ImageClassifierOutputWithNoAttention:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+ outputs: BaseModelOutputWithPoolingAndNoAttention = self.convnextv2(pixel_values, **kwargs)
+ pooled_output = outputs.pooler_output
+ logits = self.classifier(pooled_output)
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(labels=labels, pooled_logits=logits, config=self.config)
+
+ return ImageClassifierOutputWithNoAttention(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ ConvNeXT V2 backbone, to be used with frameworks like DETR and MaskFormer.
+ """
+)
+# Copied from transformers.models.convnext.modeling_convnext.ConvNextBackbone with CONVNEXT->CONVNEXTV2,ConvNext->ConvNextV2,facebook/convnext-tiny-224->facebook/convnextv2-tiny-1k-224
+class ConvNextV2Backbone(ConvNextV2PreTrainedModel, BackboneMixin):
+ has_attentions = False
+
+ def __init__(self, config):
+ super().__init__(config)
+ super()._init_backbone(config)
+
+ self.embeddings = ConvNextV2Embeddings(config)
+ self.encoder = ConvNextV2Encoder(config)
+ self.num_features = [config.hidden_sizes[0]] + config.hidden_sizes
+
+ # Add layer norms to hidden states of out_features
+ hidden_states_norms = {}
+ for stage, num_channels in zip(self._out_features, self.channels):
+ hidden_states_norms[stage] = ConvNextV2LayerNorm(num_channels, data_format="channels_first")
+ self.hidden_states_norms = nn.ModuleDict(hidden_states_norms)
+
+ # initialize weights and apply final processing
+ self.post_init()
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ pixel_values: torch.Tensor,
+ output_hidden_states: Optional[bool] = None,
+ ) -> BackboneOutput:
+ r"""
+ Examples:
+
+ ```python
+ >>> from transformers import AutoImageProcessor, AutoBackbone
+ >>> import torch
+ >>> from PIL import Image
+ >>> import requests
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> processor = AutoImageProcessor.from_pretrained("facebook/convnextv2-tiny-1k-224")
+ >>> model = AutoBackbone.from_pretrained("facebook/convnextv2-tiny-1k-224")
+
+ >>> inputs = processor(image, return_tensors="pt")
+ >>> outputs = model(**inputs)
+ ```"""
+ if output_hidden_states is None:
+ output_hidden_states = self.config.output_hidden_states
+
+ embedding_output = self.embeddings(pixel_values)
+ outputs: BaseModelOutputWithPoolingAndNoAttention = self.encoder(embedding_output, output_hidden_states=True)
+ hidden_states = outputs.hidden_states
+
+ feature_maps = []
+ for stage, hidden_state in zip(self.stage_names, hidden_states):
+ if stage in self.out_features:
+ hidden_state = self.hidden_states_norms[stage](hidden_state)
+ feature_maps.append(hidden_state)
+
+ return BackboneOutput(
+ feature_maps=tuple(feature_maps),
+ hidden_states=hidden_states if output_hidden_states else None,
+ )
+
+
+__all__ = ["ConvNextV2ForImageClassification", "ConvNextV2Model", "ConvNextV2PreTrainedModel", "ConvNextV2Backbone"]
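For completeness, a hedged usage sketch for the image classification head defined above, mirroring the backbone example in the docstring (same checkpoint and image URL; the predicted label depends on the checkpoint):

```python
import requests
import torch
from PIL import Image
from transformers import AutoImageProcessor, ConvNextV2ForImageClassification

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

processor = AutoImageProcessor.from_pretrained("facebook/convnextv2-tiny-1k-224")
model = ConvNextV2ForImageClassification.from_pretrained("facebook/convnextv2-tiny-1k-224")

inputs = processor(image, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits

# Map the highest-scoring logit back to an ImageNet label.
print(model.config.id2label[logits.argmax(-1).item()])
```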
diff --git a/venv/lib/python3.13/site-packages/transformers/models/convnextv2/modeling_tf_convnextv2.py b/venv/lib/python3.13/site-packages/transformers/models/convnextv2/modeling_tf_convnextv2.py
new file mode 100644
index 0000000000000000000000000000000000000000..d370c3008d4701cdd4c3c78572a29218cc55693b
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/convnextv2/modeling_tf_convnextv2.py
@@ -0,0 +1,681 @@
+# coding=utf-8
+# Copyright 2023 Meta Platforms Inc. and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""TF 2.0 ConvNextV2 model."""
+
+from __future__ import annotations
+
+import numpy as np
+import tensorflow as tf
+
+from ...activations_tf import get_tf_activation
+from ...modeling_tf_outputs import (
+ TFBaseModelOutputWithNoAttention,
+ TFBaseModelOutputWithPooling,
+ TFBaseModelOutputWithPoolingAndNoAttention,
+ TFImageClassifierOutputWithNoAttention,
+)
+from ...modeling_tf_utils import (
+ TFModelInputType,
+ TFPreTrainedModel,
+ TFSequenceClassificationLoss,
+ get_initializer,
+ keras,
+ keras_serializable,
+ unpack_inputs,
+)
+from ...tf_utils import shape_list
+from ...utils import (
+ add_code_sample_docstrings,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+)
+from .configuration_convnextv2 import ConvNextV2Config
+
+
+logger = logging.get_logger(__name__)
+
+# General docstring
+_CONFIG_FOR_DOC = "ConvNextV2Config"
+
+# Base docstring
+_CHECKPOINT_FOR_DOC = "facebook/convnextv2-tiny-1k-224"
+_EXPECTED_OUTPUT_SHAPE = [1, 768, 7, 7]
+
+# Image classification docstring
+_IMAGE_CLASS_CHECKPOINT = "facebook/convnextv2-tiny-1k-224"
+_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat"
+
+
+# Copied from transformers.models.convnext.modeling_tf_convnext.TFConvNextDropPath with ConvNext->ConvNextV2
+class TFConvNextV2DropPath(keras.layers.Layer):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+ References:
+ (1) github.com:rwightman/pytorch-image-models
+ """
+
+ def __init__(self, drop_path: float, **kwargs):
+ super().__init__(**kwargs)
+ self.drop_path = drop_path
+
+ def call(self, x: tf.Tensor, training=None):
+ if training:
+ keep_prob = 1 - self.drop_path
+ shape = (tf.shape(x)[0],) + (1,) * (len(tf.shape(x)) - 1)
+ random_tensor = keep_prob + tf.random.uniform(shape, 0, 1)
+ random_tensor = tf.floor(random_tensor)
+ return (x / keep_prob) * random_tensor
+ return x
+
+
+class TFConvNextV2GRN(keras.layers.Layer):
+ """GRN (Global Response Normalization) layer"""
+
+ def __init__(self, config: ConvNextV2Config, dim: int, **kwargs):
+ super().__init__(**kwargs)
+ self.dim = dim
+
+ def build(self, input_shape: tf.TensorShape = None):
+ # PT's `nn.Parameters` must be mapped to a TF layer weight to inherit the same name hierarchy (and vice-versa)
+ self.weight = self.add_weight(
+ name="weight",
+ shape=(1, 1, 1, self.dim),
+ initializer=keras.initializers.Zeros(),
+ )
+ self.bias = self.add_weight(
+ name="bias",
+ shape=(1, 1, 1, self.dim),
+ initializer=keras.initializers.Zeros(),
+ )
+ return super().build(input_shape)
+
+ def call(self, hidden_states: tf.Tensor):
+ global_features = tf.norm(hidden_states, ord="euclidean", axis=(1, 2), keepdims=True)
+ norm_features = global_features / (tf.reduce_mean(global_features, axis=-1, keepdims=True) + 1e-6)
+ hidden_states = self.weight * (hidden_states * norm_features) + self.bias + hidden_states
+ return hidden_states
+
+
+# Copied from transformers.models.convnext.modeling_tf_convnext.TFConvNextEmbeddings with ConvNext->ConvNextV2
+class TFConvNextV2Embeddings(keras.layers.Layer):
+ """This class is comparable to (and inspired by) the SwinEmbeddings class
+ found in src/transformers/models/swin/modeling_swin.py.
+ """
+
+ def __init__(self, config: ConvNextV2Config, **kwargs):
+ super().__init__(**kwargs)
+ self.patch_embeddings = keras.layers.Conv2D(
+ filters=config.hidden_sizes[0],
+ kernel_size=config.patch_size,
+ strides=config.patch_size,
+ name="patch_embeddings",
+ kernel_initializer=get_initializer(config.initializer_range),
+ bias_initializer=keras.initializers.Zeros(),
+ )
+ self.layernorm = keras.layers.LayerNormalization(epsilon=1e-6, name="layernorm")
+ self.num_channels = config.num_channels
+ self.config = config
+
+ def call(self, pixel_values):
+ if isinstance(pixel_values, dict):
+ pixel_values = pixel_values["pixel_values"]
+
+ tf.debugging.assert_equal(
+ shape_list(pixel_values)[1],
+ self.num_channels,
+ message="Make sure that the channel dimension of the pixel values match with the one set in the configuration.",
+ )
+
+ # When running on CPU, `keras.layers.Conv2D` doesn't support `NCHW` format.
+ # So change the input format from `NCHW` to `NHWC`.
+ # shape = (batch_size, in_height, in_width, in_channels)
+ pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1))
+
+ embeddings = self.patch_embeddings(pixel_values)
+ embeddings = self.layernorm(embeddings)
+ return embeddings
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "patch_embeddings", None) is not None:
+ with tf.name_scope(self.patch_embeddings.name):
+ self.patch_embeddings.build([None, None, None, self.config.num_channels])
+ if getattr(self, "layernorm", None) is not None:
+ with tf.name_scope(self.layernorm.name):
+ self.layernorm.build([None, None, None, self.config.hidden_sizes[0]])
+
+
+class TFConvNextV2Layer(keras.layers.Layer):
+ """This corresponds to the `Block` class in the original implementation.
+
+ There are two equivalent implementations: (1) [DwConv, LayerNorm (channels_first), Conv, GELU, 1x1 Conv], all in
+ (N, C, H, W); (2) [DwConv, Permute to (N, H, W, C), LayerNorm (channels_last), Linear, GELU, Linear], then Permute back.
+
+ The authors used (2) as they find it slightly faster in PyTorch. Since we already permuted the inputs to follow
+ NHWC ordering, we can just apply the operations straight-away without the permutation.
+
+ Args:
+ config (`ConvNextV2Config`):
+ Model configuration class.
+ dim (`int`):
+ Number of input channels.
+ drop_path (`float`, *optional*, defaults to 0.0):
+ Stochastic depth rate.
+ """
+
+ def __init__(self, config: ConvNextV2Config, dim: int, drop_path: float = 0.0, **kwargs):
+ super().__init__(**kwargs)
+ self.dim = dim
+ self.config = config
+ self.dwconv = keras.layers.Conv2D(
+ filters=dim,
+ kernel_size=7,
+ padding="same",
+ groups=dim,
+ kernel_initializer=get_initializer(config.initializer_range),
+ bias_initializer=keras.initializers.Zeros(),
+ name="dwconv",
+ ) # depthwise conv
+ self.layernorm = keras.layers.LayerNormalization(
+ epsilon=1e-6,
+ name="layernorm",
+ )
+ self.pwconv1 = keras.layers.Dense(
+ units=4 * dim,
+ kernel_initializer=get_initializer(config.initializer_range),
+ bias_initializer=keras.initializers.Zeros(),
+ name="pwconv1",
+ ) # pointwise/1x1 convs, implemented with linear layers
+ self.act = get_tf_activation(config.hidden_act)
+ self.grn = TFConvNextV2GRN(config, 4 * dim, dtype=tf.float32, name="grn")
+ self.pwconv2 = keras.layers.Dense(
+ units=dim,
+ kernel_initializer=get_initializer(config.initializer_range),
+ bias_initializer=keras.initializers.Zeros(),
+ name="pwconv2",
+ )
+ # Using `layers.Activation` instead of `tf.identity` to better control `training`
+ # behaviour.
+ self.drop_path = (
+ TFConvNextV2DropPath(drop_path, name="drop_path")
+ if drop_path > 0.0
+ else keras.layers.Activation("linear", name="drop_path")
+ )
+
+ def call(self, hidden_states, training=False):
+ input = hidden_states
+ x = self.dwconv(hidden_states)
+ x = self.layernorm(x)
+ x = self.pwconv1(x)
+ x = self.act(x)
+ x = self.grn(x)
+ x = self.pwconv2(x)
+ x = self.drop_path(x, training=training)
+ x = input + x
+ return x
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "dwconv", None) is not None:
+ with tf.name_scope(self.dwconv.name):
+ self.dwconv.build([None, None, None, self.dim])
+ if getattr(self, "layernorm", None) is not None:
+ with tf.name_scope(self.layernorm.name):
+ self.layernorm.build([None, None, None, self.dim])
+ if getattr(self, "pwconv1", None) is not None:
+ with tf.name_scope(self.pwconv1.name):
+ self.pwconv1.build([None, None, self.dim])
+ if getattr(self, "grn", None) is not None:
+ with tf.name_scope(self.grn.name):
+ self.grn.build(None)
+ if getattr(self, "pwconv2", None) is not None:
+ with tf.name_scope(self.pwconv2.name):
+ self.pwconv2.build([None, None, 4 * self.dim])
+ if getattr(self, "drop_path", None) is not None:
+ with tf.name_scope(self.drop_path.name):
+ self.drop_path.build(None)
+
+
+# Copied from transformers.models.convnext.modeling_tf_convnext.TFConvNextStage with ConvNext->ConvNextV2
+class TFConvNextV2Stage(keras.layers.Layer):
+ """ConvNextV2 stage, consisting of an optional downsampling layer + multiple residual blocks.
+
+ Args:
+ config (`ConvNextV2Config`):
+ Model configuration class.
+ in_channels (`int`):
+ Number of input channels.
+ out_channels (`int`):
+ Number of output channels.
+ depth (`int`):
+ Number of residual blocks.
+ drop_path_rates(`list[float]`):
+ Stochastic depth rates for each layer.
+ """
+
+ def __init__(
+ self,
+ config: ConvNextV2Config,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: int = 2,
+ stride: int = 2,
+ depth: int = 2,
+ drop_path_rates: list[float] | None = None,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ if in_channels != out_channels or stride > 1:
+ self.downsampling_layer = [
+ keras.layers.LayerNormalization(
+ epsilon=1e-6,
+ name="downsampling_layer.0",
+ ),
+ # Inputs to this layer will follow NHWC format since we
+ # transposed the inputs from NCHW to NHWC in the `TFConvNextV2Embeddings`
+ # layer. All the outputs throughout the model will be in NHWC
+ # from this point on until the output where we again change to
+ # NCHW.
+ keras.layers.Conv2D(
+ filters=out_channels,
+ kernel_size=kernel_size,
+ strides=stride,
+ kernel_initializer=get_initializer(config.initializer_range),
+ bias_initializer=keras.initializers.Zeros(),
+ name="downsampling_layer.1",
+ ),
+ ]
+ else:
+ self.downsampling_layer = [tf.identity]
+
+ drop_path_rates = drop_path_rates or [0.0] * depth
+ self.layers = [
+ TFConvNextV2Layer(
+ config,
+ dim=out_channels,
+ drop_path=drop_path_rates[j],
+ name=f"layers.{j}",
+ )
+ for j in range(depth)
+ ]
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.stride = stride
+
+ def call(self, hidden_states):
+ for layer in self.downsampling_layer:
+ hidden_states = layer(hidden_states)
+ for layer in self.layers:
+ hidden_states = layer(hidden_states)
+ return hidden_states
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "layers", None) is not None:
+ for layer in self.layers:
+ with tf.name_scope(layer.name):
+ layer.build(None)
+ if self.in_channels != self.out_channels or self.stride > 1:
+ with tf.name_scope(self.downsampling_layer[0].name):
+ self.downsampling_layer[0].build([None, None, None, self.in_channels])
+ with tf.name_scope(self.downsampling_layer[1].name):
+ self.downsampling_layer[1].build([None, None, None, self.in_channels])
+
+
+class TFConvNextV2Encoder(keras.layers.Layer):
+ def __init__(self, config: ConvNextV2Config, **kwargs):
+ super().__init__(**kwargs)
+ self.stages = []
+ drop_path_rates = tf.linspace(0.0, config.drop_path_rate, sum(config.depths))
+ drop_path_rates = tf.split(drop_path_rates, config.depths)
+ drop_path_rates = [x.numpy().tolist() for x in drop_path_rates]
+ prev_chs = config.hidden_sizes[0]
+ for i in range(config.num_stages):
+ out_chs = config.hidden_sizes[i]
+ stage = TFConvNextV2Stage(
+ config,
+ in_channels=prev_chs,
+ out_channels=out_chs,
+ stride=2 if i > 0 else 1,
+ depth=config.depths[i],
+ drop_path_rates=drop_path_rates[i],
+ name=f"stages.{i}",
+ )
+ self.stages.append(stage)
+ prev_chs = out_chs
+
+ def call(
+ self,
+ hidden_states: tf.Tensor,
+ output_hidden_states: bool | None = False,
+ return_dict: bool | None = True,
+ ) -> tuple | TFBaseModelOutputWithNoAttention:
+ all_hidden_states = () if output_hidden_states else None
+
+ for i, layer_module in enumerate(self.stages):
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ hidden_states = layer_module(hidden_states)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, all_hidden_states] if v is not None)
+
+ return TFBaseModelOutputWithNoAttention(last_hidden_state=hidden_states, hidden_states=all_hidden_states)
+
+ def build(self, input_shape=None):
+ for stage in self.stages:
+ with tf.name_scope(stage.name):
+ stage.build(None)
+
+
+@keras_serializable
+class TFConvNextV2MainLayer(keras.layers.Layer):
+ config_class = ConvNextV2Config
+
+ def __init__(self, config: ConvNextV2Config, **kwargs):
+ super().__init__(**kwargs)
+
+ self.config = config
+ self.embeddings = TFConvNextV2Embeddings(config, name="embeddings")
+ self.encoder = TFConvNextV2Encoder(config, name="encoder")
+ self.layernorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm")
+ # We are setting the `data_format` like so because from here on we will revert to the
+ # NCHW output format
+ self.pooler = keras.layers.GlobalAvgPool2D(data_format="channels_last")
+
+ @unpack_inputs
+ def call(
+ self,
+ pixel_values: TFModelInputType | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ training: bool = False,
+ ) -> TFBaseModelOutputWithPooling | tuple[tf.Tensor]:
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if pixel_values is None:
+ raise ValueError("You have to specify pixel_values")
+
+ embedding_output = self.embeddings(pixel_values, training=training)
+
+ encoder_outputs = self.encoder(
+ embedding_output,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+
+ last_hidden_state = encoder_outputs[0]
+
+ # Change to NCHW output format to have uniformity in the modules
+ pooled_output = self.pooler(last_hidden_state)
+ last_hidden_state = tf.transpose(last_hidden_state, perm=(0, 3, 1, 2))
+ pooled_output = self.layernorm(pooled_output)
+
+ # Change the other hidden state outputs to NCHW as well
+ if output_hidden_states:
+ hidden_states = tuple(tf.transpose(h, perm=(0, 3, 1, 2)) for h in encoder_outputs[1])
+
+ if not return_dict:
+ hidden_states = hidden_states if output_hidden_states else ()
+ return (last_hidden_state, pooled_output) + hidden_states
+
+ return TFBaseModelOutputWithPoolingAndNoAttention(
+ last_hidden_state=last_hidden_state,
+ pooler_output=pooled_output,
+ hidden_states=hidden_states if output_hidden_states else encoder_outputs.hidden_states,
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "embeddings", None) is not None:
+ with tf.name_scope(self.embeddings.name):
+ self.embeddings.build(None)
+ if getattr(self, "encoder", None) is not None:
+ with tf.name_scope(self.encoder.name):
+ self.encoder.build(None)
+ if getattr(self, "layernorm", None) is not None:
+ with tf.name_scope(self.layernorm.name):
+ self.layernorm.build([None, self.config.hidden_sizes[-1]])
+
+
+class TFConvNextV2PreTrainedModel(TFPreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = ConvNextV2Config
+ base_model_prefix = "convnextv2"
+ main_input_name = "pixel_values"
+
+
+CONVNEXTV2_START_DOCSTRING = r"""
+ This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads,
+ etc.)
+
+ This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
+ as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matters related to general usage and
+ behavior.
+
+ <Tip>
+
+ TensorFlow models and layers in `transformers` accept two formats as input:
+
+ - having all inputs as keyword arguments (like PyTorch models), or
+ - having all inputs as a list, tuple or dict in the first positional argument.
+
+ The reason the second format is supported is that Keras methods prefer this format when passing inputs to models
+ and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just
+ pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second
+ format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with
+ the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first
+ positional argument:
+
+ - a single Tensor with `pixel_values` only and nothing else: `model(pixel_values)`
+ - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
+ `model([pixel_values, attention_mask])` or `model([pixel_values, attention_mask, token_type_ids])`
+ - a dictionary with one or several input Tensors associated to the input names given in the docstring:
+ `model({"pixel_values": pixel_values, "token_type_ids": token_type_ids})`
+
+ Note that when creating models and layers with
+ [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry
+ about any of this, as you can just pass inputs like you would to any other Python function!
+
+ </Tip>
+
+ Parameters:
+ config ([`ConvNextV2Config`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+CONVNEXTV2_INPUTS_DOCSTRING = r"""
+ Args:
+ pixel_values (`np.ndarray`, `tf.Tensor`, `list[tf.Tensor]`, `dict[str, tf.Tensor]` or `dict[str, np.ndarray]` and each example must have the shape `(batch_size, num_channels, height, width)`):
+ Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
+ [`ConvNextImageProcessor.__call__`] for details.
+
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail. This argument can be used only in eager mode, in graph mode the value in the config will be
+ used instead.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in
+ eager mode, in graph mode the value will always be set to `True`.
+"""
+
+
+@add_start_docstrings(
+ "The bare ConvNextV2 model outputting raw features without any specific head on top.",
+ CONVNEXTV2_START_DOCSTRING,
+)
+class TFConvNextV2Model(TFConvNextV2PreTrainedModel):
+ def __init__(self, config: ConvNextV2Config, *inputs, **kwargs):
+ super().__init__(config, *inputs, **kwargs)
+ self.convnextv2 = TFConvNextV2MainLayer(config, name="convnextv2")
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(CONVNEXTV2_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=TFBaseModelOutputWithPoolingAndNoAttention,
+ config_class=_CONFIG_FOR_DOC,
+ modality="vision",
+ expected_output=_EXPECTED_OUTPUT_SHAPE,
+ )
+ def call(
+ self,
+ pixel_values: TFModelInputType | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ training: bool = False,
+ ) -> TFBaseModelOutputWithPoolingAndNoAttention | tuple[tf.Tensor]:
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if pixel_values is None:
+ raise ValueError("You have to specify pixel_values")
+
+ outputs = self.convnextv2(
+ pixel_values=pixel_values,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+
+ if not return_dict:
+ return outputs[:]
+
+ return TFBaseModelOutputWithPoolingAndNoAttention(
+ last_hidden_state=outputs.last_hidden_state,
+ pooler_output=outputs.pooler_output,
+ hidden_states=outputs.hidden_states,
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "convnextv2", None) is not None:
+ with tf.name_scope(self.convnextv2.name):
+ self.convnextv2.build(None)
+
+
+@add_start_docstrings(
+ """
+ ConvNextV2 Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for
+ ImageNet.
+ """,
+ CONVNEXTV2_START_DOCSTRING,
+)
+class TFConvNextV2ForImageClassification(TFConvNextV2PreTrainedModel, TFSequenceClassificationLoss):
+ def __init__(self, config: ConvNextV2Config, *inputs, **kwargs):
+ super().__init__(config, *inputs, **kwargs)
+
+ self.num_labels = config.num_labels
+ self.convnextv2 = TFConvNextV2MainLayer(config, name="convnextv2")
+
+ # Classifier head
+ self.classifier = keras.layers.Dense(
+ units=config.num_labels,
+ kernel_initializer=get_initializer(config.initializer_range),
+ bias_initializer=keras.initializers.Zeros(),
+ name="classifier",
+ )
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(CONVNEXTV2_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ checkpoint=_IMAGE_CLASS_CHECKPOINT,
+ output_type=TFImageClassifierOutputWithNoAttention,
+ config_class=_CONFIG_FOR_DOC,
+ expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
+ )
+ def call(
+ self,
+ pixel_values: TFModelInputType | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ labels: np.ndarray | tf.Tensor | None = None,
+ training: bool | None = False,
+ ) -> TFImageClassifierOutputWithNoAttention | tuple[tf.Tensor]:
+ r"""
+ labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*):
+ Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1`, a regression loss is computed (Mean-Square loss). If
+ `config.num_labels > 1`, a classification loss is computed (Cross-Entropy).
+ """
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if pixel_values is None:
+ raise ValueError("You have to specify pixel_values")
+
+ outputs = self.convnextv2(
+ pixel_values,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+
+ pooled_output = outputs.pooler_output if return_dict else outputs[1]
+
+ logits = self.classifier(pooled_output)
+ loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits)
+
+ if not return_dict:
+ output = (logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return TFImageClassifierOutputWithNoAttention(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "convnextv2", None) is not None:
+ with tf.name_scope(self.convnextv2.name):
+ self.convnextv2.build(None)
+ if getattr(self, "classifier", None) is not None:
+ with tf.name_scope(self.classifier.name):
+ self.classifier.build([None, None, self.config.hidden_sizes[-1]])
+
+
+__all__ = ["TFConvNextV2ForImageClassification", "TFConvNextV2Model", "TFConvNextV2PreTrainedModel"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/cpmant/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/cpmant/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d92eea75693e728f4c1e9afc199d40751ef6af7a
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/cpmant/__init__.py
@@ -0,0 +1,28 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_cpmant import *
+ from .modeling_cpmant import *
+ from .tokenization_cpmant import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/venv/lib/python3.13/site-packages/transformers/models/cpmant/configuration_cpmant.py b/venv/lib/python3.13/site-packages/transformers/models/cpmant/configuration_cpmant.py
new file mode 100644
index 0000000000000000000000000000000000000000..c3368d67af7ab7db1278bcecdd01f6bf0b6ca59c
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/cpmant/configuration_cpmant.py
@@ -0,0 +1,122 @@
+# coding=utf-8
+# Copyright 2022 The OpenBMB Team and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""CPMAnt model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class CpmAntConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`CpmAntModel`]. It is used to instantiate an
+ CPMAnt model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of the CPMAnt
+ [openbmb/cpm-ant-10b](https://huggingface.co/openbmb/cpm-ant-10b) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 30720):
+ Vocabulary size of the CPMAnt model. Defines the number of different tokens that can be represented by the
+ `input` passed when calling [`CpmAntModel`].
+ hidden_size (`int`, *optional*, defaults to 4096):
+ Dimension of the encoder layers.
+ num_attention_heads (`int`, *optional*, defaults to 32):
+ Number of attention heads in the Transformer encoder.
+ dim_head (`int`, *optional*, defaults to 128):
+ Dimension of attention heads for each attention layer in the Transformer encoder.
+ dim_ff (`int`, *optional*, defaults to 10240):
+ Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+ num_hidden_layers (`int`, *optional*, defaults to 48):
+ Number of layers of the Transformer encoder.
+ dropout_p (`float`, *optional*, defaults to 0.0):
+ The dropout probability for all fully connected layers in the embeddings and encoder.
+ position_bias_num_buckets (`int`, *optional*, defaults to 512):
+ The number of position_bias buckets.
+ position_bias_max_distance (`int`, *optional*, defaults to 2048):
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
+ just in case (e.g., 512 or 1024 or 2048).
+ eps (`float`, *optional*, defaults to 1e-06):
+ The epsilon used by the layer normalization layers.
+ init_std (`float`, *optional*, defaults to 1.0):
+ Initialize parameters with std = init_std.
+ prompt_types (`int`, *optional*, defaults to 32):
+ The number of prompt types.
+ prompt_length (`int`, *optional*, defaults to 32):
+ The length of the prompt.
+ segment_types (`int`, *optional*, defaults to 32):
+ The number of segment types.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether to use cache.
+
+ Example:
+
+ ```python
+ >>> from transformers import CpmAntModel, CpmAntConfig
+
+ >>> # Initializing a CPMAnt cpm-ant-10b style configuration
+ >>> configuration = CpmAntConfig()
+
+ >>> # Initializing a model from the cpm-ant-10b style configuration
+ >>> model = CpmAntModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "cpmant"
+
+ def __init__(
+ self,
+ vocab_size: int = 30720,
+ hidden_size: int = 4096,
+ num_attention_heads: int = 32,
+ dim_head: int = 128,
+ dim_ff: int = 10240,
+ num_hidden_layers: int = 48,
+ dropout_p: float = 0.0,
+ position_bias_num_buckets: int = 512,
+ position_bias_max_distance: int = 2048,
+ eps: float = 1e-6,
+ init_std: float = 1.0,
+ prompt_types: int = 32,
+ prompt_length: int = 32,
+ segment_types: int = 32,
+ use_cache: bool = True,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.prompt_types = prompt_types
+ self.prompt_length = prompt_length
+ self.segment_types = segment_types
+ self.hidden_size = hidden_size
+ self.num_attention_heads = num_attention_heads
+ self.dim_head = dim_head
+ self.dim_ff = dim_ff
+ self.num_hidden_layers = num_hidden_layers
+ self.position_bias_num_buckets = position_bias_num_buckets
+ self.position_bias_max_distance = position_bias_max_distance
+ self.dropout_p = dropout_p
+ self.eps = eps
+ self.use_cache = use_cache
+ self.vocab_size = vocab_size
+ self.init_std = init_std
+
+
+__all__ = ["CpmAntConfig"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/cpmant/modeling_cpmant.py b/venv/lib/python3.13/site-packages/transformers/models/cpmant/modeling_cpmant.py
new file mode 100644
index 0000000000000000000000000000000000000000..15881a64eb3799cc23bf95c5bece4079717af0a8
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/cpmant/modeling_cpmant.py
@@ -0,0 +1,807 @@
+# coding=utf-8
+# Copyright 2022 The OpenBMB Team and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch CPMAnt"""
+
+import math
+from typing import Optional, Union
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+from torch.nn import CrossEntropyLoss
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache
+from ...generation import GenerationMixin
+from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from ...modeling_utils import PreTrainedModel
+from ...utils import auto_docstring, logging
+from .configuration_cpmant import CpmAntConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+class CpmAntLayerNorm(nn.Module):
+ """
+ We use Root Mean Square (RMS) Layer Normalization; see https://huggingface.co/papers/1910.07467 for details.
+ """
+
+ def __init__(self, config: CpmAntConfig):
+ super().__init__()
+
+ self.eps = config.eps
+ self.dim_norm = config.hidden_size
+ self.weight = nn.Parameter(torch.empty(config.hidden_size))
+
+ def forward(self, hidden_states: torch.Tensor):
+ """
+ Args:
+ hidden_states (`torch.Tensor` of shape `(batch, seq_len, dim_in)`)
+ """
+ if hidden_states.size(-1) != self.dim_norm:
+ raise AssertionError("hidden_states.size(-1) != self.dim_norm")
+ old_dtype = hidden_states.dtype
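+ # RMSNorm: scale by rsqrt(mean(x**2) + eps) and multiply by the learned weight; unlike standard LayerNorm there is no mean subtraction and no bias term.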
+ variance = hidden_states.to(torch.float32).pow(2).mean(dim=-1, keepdim=True)
+ hidden_states = (hidden_states * torch.rsqrt(variance + self.eps)).to(old_dtype) * self.weight
+ return hidden_states
+
+
+class CpmAntAttention(nn.Module):
+ def __init__(self, config: CpmAntConfig, layer_idx=None):
+ super().__init__()
+ self.dim_model = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.dim_head = config.dim_head
+ self.layer_idx = layer_idx
+
+ self.project_q = nn.Linear(self.dim_model, self.num_heads * self.dim_head, bias=False)
+ self.project_k = nn.Linear(self.dim_model, self.num_heads * self.dim_head, bias=False)
+ self.project_v = nn.Linear(self.dim_model, self.num_heads * self.dim_head, bias=False)
+
+ self.attention_out = nn.Linear(self.num_heads * self.dim_head, self.dim_model, bias=False)
+
+ self.softmax = torch.nn.Softmax(dim=-1)
+
+ if config.dropout_p is not None:
+ self.dropout = torch.nn.Dropout(p=config.dropout_p)
+ else:
+ self.dropout = None
+
+ def forward(
+ self,
+ hidden_q: torch.Tensor,
+ hidden_kv: torch.Tensor,
+ attention_mask: torch.BoolTensor,
+ position_bias: torch.Tensor,
+ output_attentions: Optional[bool] = False,
+ past_key_values: Optional[Cache] = None,
+ use_cache: Optional[bool] = None,
+ cache_position: Optional[torch.Tensor] = None,
+ ):
+ """
+ Args:
+ hidden_q (`torch.Tensor`):
+ Input of the transformer block (self-attention block). It can be the raw embedding of a batch of sequences.
+ hidden_kv (`torch.Tensor` of shape `(batch, len_k, dim_model)`):
+ Tensor from which the *key* and *value* projections are computed.
+ attention_mask (`torch.Tensor` of shape `(batch, len_seq, len_seq)`):
+ Avoid invalid areas to participate in the calculation of self-attention.
+ position_bias (`torch.Tensor` of shape `(batch, len_seq, len_seq)`):
+ Provide positional information to self-attention block.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers.
+ past_key_values (`Cache`, *optional*):
+ Cached past key and value projection states.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ """
+ batch_size = hidden_q.size(0)
+ len_q = hidden_q.size(1)
+ len_k = hidden_kv.size(1)
+
+ query = self.project_q(hidden_q)
+ key = self.project_k(hidden_kv)
+ value = self.project_v(hidden_kv)
+
+ query = query.view(batch_size, len_q, self.num_heads, self.dim_head).permute(0, 2, 1, 3)
+ key = key.view(batch_size, len_k, self.num_heads, self.dim_head).permute(0, 2, 1, 3)
+ value = value.view(batch_size, len_k, self.num_heads, self.dim_head).permute(0, 2, 1, 3)
+
+ if past_key_values is not None:
+ key, value = past_key_values.update(key, value, self.layer_idx, {"cache_position": cache_position})
+ len_k = key.size(-2)
+
+ # (batch_size, num_heads, len_q, dim_head) @ (batch_size, num_heads, dim_head, len_k) -> (batch_size, num_heads, len_q, len_k)
+ score = torch.matmul(query, key.transpose(-1, -2)) / math.sqrt(self.dim_head)
+ score = score + position_bias
+
+ score = torch.masked_fill(
+ score,
+ attention_mask.view(batch_size, 1, len_q, len_k) == torch.tensor(False),
+ torch.scalar_tensor(float("-inf"), device=score.device, dtype=score.dtype),
+ )
+ score = self.softmax(score)
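+ # Re-apply the mask after softmax so masked positions contribute exactly zero (this also cleans up rows that are entirely masked).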
+
+ score = torch.masked_fill(
+ score,
+ attention_mask.view(batch_size, 1, len_q, len_k) == torch.tensor(False),
+ torch.scalar_tensor(0, device=score.device, dtype=score.dtype),
+ )
+ if output_attentions:
+ attn_weights = score
+ else:
+ attn_weights = None
+
+ if self.dropout is not None:
+ score = self.dropout(score)
+
+ # (batch_size, num_heads, len_q, len_k) @ (batch_size, num_heads, len_k, dim_head) -> (batch_size, num_heads, len_q, dim_head)
+ score = torch.matmul(score, value)
+
+ score = score.view(batch_size, self.num_heads, len_q, self.dim_head).permute(0, 2, 1, 3)
+ score = score.contiguous().view(batch_size, len_q, self.num_heads * self.dim_head)
+
+ score = self.attention_out(score)
+
+ return score, attn_weights
+
+
+class CpmAntSelfAttentionBlock(nn.Module):
+ def __init__(self, config: CpmAntConfig, layer_idx=None):
+ super().__init__()
+ self.layernorm_before_attention = CpmAntLayerNorm(config)
+ self.self_attention = CpmAntAttention(config, layer_idx=layer_idx)
+ if config.dropout_p:
+ self.dropout = torch.nn.Dropout(config.dropout_p)
+ else:
+ self.dropout = None
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: torch.Tensor,
+ position_bias: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = False,
+ past_key_values: Optional[Cache] = None,
+ use_cache: Optional[bool] = None,
+ cache_position: Optional[torch.Tensor] = None,
+ ):
+ """
+ Args:
+ hidden_states (`torch.Tensor` of shape `(batch, len_seq, dim_model)`):
+ Input of the transformer block (self-attention block). It can be the raw embedding of a batch of sequences.
+ attention_mask (`torch.Tensor` of shape `(batch, len_seq, len_seq)`):
+ Avoid invalid areas to participate in the calculation of self-attention.
+ position_bias (`torch.Tensor` of shape `(batch, len_seq, len_seq)`):
+ Provide positional information to self-attention block.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers.
+ past_key_values (`Cache`, *optional*):
+ Cached past key and value projection states.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ """
+ outputs = self.layernorm_before_attention(hidden_states)
+ outputs, attn_weights = self.self_attention(
+ outputs,
+ outputs,
+ attention_mask,
+ position_bias,
+ output_attentions,
+ past_key_values,
+ use_cache,
+ cache_position,
+ )
+
+ if self.dropout is not None:
+ outputs = self.dropout(outputs)
+ hidden_states = hidden_states + outputs
+
+ return hidden_states, attn_weights
+
+
+class CpmAntDenseGatedACT(nn.Module):
+ def __init__(self, config: CpmAntConfig):
+ super().__init__()
+ self.w_0 = nn.Linear(config.hidden_size, config.dim_ff, bias=False)
+ self.w_1 = nn.Linear(config.hidden_size, config.dim_ff, bias=False)
+ self.act = torch.nn.GELU()
+
+ def forward(self, hidden_states: torch.Tensor):
+ """Transform an input tensor from one feature space to another via a nonlinear operation
+
+ Args:
+ hidden_states (`torch.Tensor` of shape `(batch, seq_len, dim_in)`)
+ """
+ gate_score = self.act(self.w_0(hidden_states))
+ hidden_states = self.w_1(hidden_states)
+
+ hidden_states = gate_score * hidden_states
+ return hidden_states
+
+
+class CpmAntFeedForward(nn.Module):
+ def __init__(self, config: CpmAntConfig):
+ super().__init__()
+ self.w_in = CpmAntDenseGatedACT(config)
+ if config.dropout_p is not None:
+ self.dropout = torch.nn.Dropout(config.dropout_p)
+ else:
+ self.dropout = None
+
+ self.w_out = nn.Linear(config.dim_ff, config.hidden_size, bias=False)
+
+ def forward(self, hidden_states: torch.Tensor):
+ """
+ Args:
+ hidden_states (`torch.Tensor` of shape `(batch, seq_len, dim_in)`)
+ """
+ hidden_states = self.w_in(hidden_states)
+
+ if self.dropout is not None:
+ hidden_states = self.dropout(hidden_states)
+
+ hidden_states = self.w_out(hidden_states)
+
+ return hidden_states
+
+
+class CpmAntFFNBlock(nn.Module):
+ def __init__(self, config: CpmAntConfig):
+ super().__init__()
+ self.layernorm_before_ffn = CpmAntLayerNorm(config)
+ self.ffn = CpmAntFeedForward(config)
+ if config.dropout_p:
+ self.dropout = torch.nn.Dropout(config.dropout_p)
+ else:
+ self.dropout = None
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ ):
+ """
+ Args:
+ hidden_states (`torch.Tensor` of shape `(batch, len_seq, dim_model)`):
+ Hidden states before feed forward layer.
+ """
+ ln_outputs = self.layernorm_before_ffn(hidden_states)
+ outputs = self.ffn(ln_outputs)
+ if self.dropout is not None:
+ outputs = self.dropout(outputs)
+ hidden_states = hidden_states + outputs
+ return hidden_states
+
+
+class CpmAntTransformerBlock(nn.Module):
+ def __init__(self, config: CpmAntConfig, layer_idx=None):
+ super().__init__()
+ self.self_att = CpmAntSelfAttentionBlock(config, layer_idx=layer_idx)
+ self.ffn = CpmAntFFNBlock(config)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: torch.Tensor,
+ position_bias: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = False,
+ past_key_values: Optional[Cache] = None,
+ use_cache: Optional[bool] = None,
+ cache_position: Optional[torch.Tensor] = None,
+ ):
+ """
+ Args:
+ hidden_states (`torch.Tensor`):
+ Input to the layer of shape `(batch, seq_len, dim_model)`
+ attention_mask (`torch.Tensor`):
+ Avoid invalid areas to participate in the calculation of shape `(batch, seq_len, seq_len)`
+ position_bias (`torch.Tensor`):
+ Provides position information to attention mechanism of shape `(num_heads, seq_len, seq_len)`
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers.
+ past_key_values (`Cache`, *optional*):
+ Cached past key and value projection states
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ """
+ hidden_states, attn_weights = self.self_att(
+ hidden_states,
+ attention_mask=attention_mask,
+ position_bias=position_bias,
+ output_attentions=output_attentions,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ )
+
+ hidden_states = self.ffn(hidden_states)
+ return hidden_states, attn_weights
+
+
+class CpmAntEncoder(nn.Module):
+ def __init__(self, config: CpmAntConfig):
+ super().__init__()
+ self.num_layers = config.num_hidden_layers
+ self.layers = nn.ModuleList([CpmAntTransformerBlock(config, layer_idx=i) for i in range(self.num_layers)])
+
+ self.output_layernorm = CpmAntLayerNorm(config)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: torch.Tensor,
+ position_bias: torch.Tensor,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ past_key_values: Optional[Cache] = None,
+ use_cache: Optional[bool] = None,
+ cache_position: Optional[torch.Tensor] = None,
+ ):
+ """
+ Args:
+ hidden_states (`torch.Tensor`):
+ Input to the layer of shape `(batch, seq_len, dim_model)`
+ attention_mask (`torch.Tensor`):
+ Avoid invalid areas to participate in the calculation of shape `(batch, seq_len, seq_len)`
+ position_bias (`torch.Tensor`):
+ Provides position information to attention mechanism of shape `(num_heads, seq_len, seq_len)`
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers.
+ past_key_values (`Cache`, *optional*):
+ Cached past key and value projection states
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ """
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+
+ for i, layer in enumerate(self.layers):
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+ layer_outputs = layer(
+ hidden_states,
+ attention_mask,
+ position_bias,
+ output_attentions=output_attentions,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ )
+ hidden_states, attn_weights = layer_outputs
+ if output_attentions:
+ all_self_attns += (attn_weights,)
+
+ hidden_states = self.output_layernorm(hidden_states)
+
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ return hidden_states, all_hidden_states, all_self_attns
+
+
+# Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->CPMAnt
+class CpmAntIntermediate(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
+ if isinstance(config.hidden_act, str):
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.intermediate_act_fn = config.hidden_act
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.intermediate_act_fn(hidden_states)
+ return hidden_states
+
+
+class CpmAntSegmentPositionEmbedding(nn.Module):
+ def __init__(self, config: CpmAntConfig):
+ super().__init__()
+
+ self.num_heads = config.num_attention_heads
+ self.num_buckets = config.position_bias_num_buckets
+ self.max_distance = config.position_bias_max_distance
+ self.num_segments = config.segment_types
+
+ self.relative_attention_bias = nn.Parameter(
+ torch.empty(
+ config.segment_types * config.segment_types + config.position_bias_num_buckets,
+ config.num_attention_heads,
+ )
+ )
+
+ def forward(
+ self,
+ key_pos: torch.Tensor,
+ query_pos: torch.Tensor,
+ key_segment: torch.Tensor,
+ query_segment: torch.Tensor,
+ ):
+ with torch.no_grad():
+ batch = key_pos.size(0)
+ keylen = key_pos.size(1)
+ querylen = query_pos.size(1)
+
+ if key_pos.size(0) != query_pos.size(0):
+ raise AssertionError(
+ f"key_pos.size(0) should be equal to query_pos.size(0), but got {key_pos.size(0)} and {query_pos.size(0)}!"
+ )
+ if keylen != key_segment.size(1):
+ raise AssertionError(
+ f"keylen should be equal to key_segment.size(1), but got {keylen} and {key_segment.size(1)}!"
+ )
+ if querylen != query_segment.size(1):
+ raise AssertionError(
+ f"querylen should be equal to query_segment.size(1), but got {querylen} and {query_segment.size(1)}!"
+ )
+
+ key_pos = key_pos.view(batch, -1, keylen)
+ query_pos = query_pos.view(batch, querylen, -1)
+ key_segment = key_segment.view(batch, -1, keylen)
+ query_segment = query_segment.view(batch, querylen, -1)
+
+ relative_position_bucket = self._segment_relative_position_bucket(query_segment, key_segment)
+ relative_position_bucket = relative_position_bucket + self.num_buckets
+
+ # (batch, len_q, len_k)
+ absolute_position_bucket = self._position_bucket(
+ torch.arange(keylen, dtype=torch.int32, device=relative_position_bucket.device)[None, :]
+ - torch.arange(querylen, dtype=torch.int32, device=relative_position_bucket.device)[:, None],
+ num_buckets=self.num_buckets,
+ max_distance=self.max_distance,
+ )
+ relative_position_bucket = torch.where(
+ (key_segment == query_segment),
+ absolute_position_bucket[None, :, :],
+ relative_position_bucket,
+ )
+
+ # (batch, len_q, len_k, num_heads)
+ embeds = F.embedding(relative_position_bucket, self.relative_attention_bias)
+ # (batch, num_heads, len_q, len_k)
+ embeds = embeds.permute(0, 3, 1, 2).contiguous()
+ return embeds
+
+ def _segment_relative_position_bucket(self, query_segment, key_segment):
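+ # Each (query_segment, key_segment) pair maps to its own bucket: query_segment * num_segments + key_segment.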
+ return query_segment * self.num_segments + key_segment
+
+ def _position_bucket(self, relative_position, num_buckets=32, max_distance=128):
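+ # T5-style relative position bucketing: half of the buckets encode direction, small offsets get exact buckets, and larger offsets are binned logarithmically up to max_distance.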
+ relative_buckets = 0
+ # always bidirectional in CPMAnt
+ num_buckets //= 2
+ relative_buckets = (relative_position > 0).to(torch.int32) * num_buckets
+ relative_position = torch.abs(relative_position)
+ max_exact = num_buckets // 2
+ is_small = relative_position < max_exact
+ relative_position_if_large = max_exact + (
+ torch.log(relative_position.float() / max_exact)
+ / math.log(max_distance / max_exact)
+ * (num_buckets - max_exact)
+ ).to(torch.int32)
+ relative_position_if_large = torch.min(
+ relative_position_if_large,
+ torch.full_like(relative_position_if_large, num_buckets - 1),
+ )
+ relative_buckets += torch.where(is_small, relative_position.to(torch.int32), relative_position_if_large)
+ return relative_buckets
+
+
+# Copied from transformers.models.bert.modeling_bert.BertOutput with Bert->CPMAnt
+class CpmAntOutput(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
+ return hidden_states
+
+
+@auto_docstring
+class CpmAntPreTrainedModel(PreTrainedModel):
+ config: CpmAntConfig
+ base_model_prefix = "cpmant"
+
+ def _init_weights(self, module):
+ """Initialize the weights"""
+ if isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=self.config.init_std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=self.config.init_std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+ elif isinstance(module, CpmAntLayerNorm):
+ module.weight.data.fill_(1.0)
+ elif isinstance(module, CpmAntSegmentPositionEmbedding):
+ module.relative_attention_bias.data.normal_(mean=0.0, std=self.config.init_std)
+
+
+@auto_docstring
+class CpmAntModel(CpmAntPreTrainedModel):
+ def __init__(self, config: CpmAntConfig):
+ super().__init__(config)
+ self.encoder = CpmAntEncoder(config)
+ self.segment_embedding = nn.Embedding(config.segment_types, config.hidden_size)
+ self.input_embedding = nn.Embedding(
+ config.vocab_size + config.prompt_types * config.prompt_length, config.hidden_size
+ )
+ self.position_bias = CpmAntSegmentPositionEmbedding(config)
+ self.prompt_length = config.prompt_length
+ self.vocab_size = config.vocab_size
+
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.input_embedding
+
+ def set_input_embeddings(self, embeddings, **kwargs):
+ self.input_embedding = embeddings
+
+ def _prepare_attention_mask(self, input_ids, span, context, length):
+ batch = input_ids.size(0)
+ seqlen = input_ids.size(1)
+ device = input_ids.device
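+ # Build the attention mask: context tokens are visible to every query, other tokens are attended causally, attention is confined to the same span, and left padding is masked out.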
+ directional_mask_2d = torch.arange(seqlen, device=device) <= torch.arange(seqlen, device=device).view(-1, 1)
+ attention_mask = context[:, None, :] | (
+ context[:, :, None].logical_not() & directional_mask_2d.view(1, seqlen, seqlen)
+ )
+ attention_mask = attention_mask & (span[:, None, :] == span[:, :, None])
+ # mask for left padding
+ mask_1d = (
+ torch.tensor(list(range(seqlen - self.prompt_length))[::-1], device=device)[None, :].repeat(batch, 1)
+ < length[:, None]
+ )
+ mask_1d = torch.cat((torch.ones(batch, self.prompt_length, device=device).bool(), mask_1d), dim=1)
+ attention_mask = mask_1d.view(batch, seqlen, 1) & mask_1d.view(batch, 1, seqlen) & attention_mask
+ return attention_mask
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ past_key_values: Optional[Cache] = None,
+ use_cache: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPast]:
+ r"""
+ input_ids (`torch.Tensor` of shape `(batch_size, seq_len)`):
+ Indices of input sequence tokens in the vocabulary.
+
+ Indices can be obtained using [`CPMAntTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ # add prompts ahead
+ if input_ids.dtype != torch.int32:
+ input_ids = input_ids.to(torch.int32)
+ dtype, device = input_ids.dtype, input_ids.device
+ segment = torch.where(input_ids != 0, 2, 0).to(dtype=dtype, device=device)
+ length = (segment != 0).sum(-1).to(dtype=dtype, device=device)
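+ # Prepend the prompt: prompt token ids live in the extended embedding table, in the range [vocab_size + 2 * prompt_length, vocab_size + 3 * prompt_length).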
+ input_ids = torch.cat(
+ (
+ torch.arange(
+ self.prompt_length * 2 + self.vocab_size,
+ self.prompt_length * 3 + self.vocab_size,
+ dtype=dtype,
+ device=device,
+ ).repeat(input_ids.size(0), 1),
+ input_ids,
+ ),
+ dim=1,
+ )
+ batch, seq_length = input_ids.size()
+ segment = torch.cat((torch.zeros(batch, self.prompt_length, dtype=dtype, device=device), segment), dim=1)
+ context = torch.full((batch, seq_length), 1, dtype=dtype, device=device)
+ position = torch.arange(seq_length, dtype=dtype, device=device).repeat(batch, 1)
+ span = torch.full((batch, seq_length), 0, dtype=dtype, device=device)
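+ # `context` marks positions that may be attended bidirectionally (all of them here), `position` gives absolute positions for the relative position bias, and `span` groups tokens that are allowed to attend to each other.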
+
+ if use_cache and past_key_values is None:
+ past_key_values = DynamicCache(config=self.config)
+ if use_cache and isinstance(past_key_values, tuple):
+ logger.warning_once(
+ "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
+ "You should pass an instance of `DynamicCache` instead, e.g. "
+ "`past_key_values=DynamicCache.from_legacy_cache(past_key_values)`."
+ )
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+
+ past_length = past_key_values.get_seq_length() if past_key_values is not None else 0
+ input_ids = input_ids.contiguous()
+ hidden_states = self.input_embedding(input_ids)
+ segment_states = self.segment_embedding(segment)
+ if past_length != 0:
+ segment_states = segment_states[:, -1:, :]
+
+ hidden_states = hidden_states + segment_states
+
+ attention_mask = self._prepare_attention_mask(input_ids, span, context, length)
+ position_bias = self.position_bias(position, position, segment, segment)
+
+ attention_mask = attention_mask[:, past_length:, :]
+ position_bias = position_bias[:, :, past_length:, :]
+ hidden_states = hidden_states[:, past_length:, :]
+
+ hidden_states, all_hidden_states, all_attentions = self.encoder(
+ hidden_states,
+ attention_mask,
+ position_bias,
+ output_attentions,
+ output_hidden_states,
+ past_key_values,
+ use_cache,
+ cache_position,
+ )
+
+ if past_length == 0:
+ hidden_states = hidden_states[:, self.prompt_length :, :]
+ # drop the prompt
+ if all_attentions is not None:
+ new_attentions = ()
+ for attention in all_attentions:
+ new_attentions += (attention[:, :, self.prompt_length :, self.prompt_length :],)
+ all_attentions = new_attentions
+ if all_hidden_states is not None:
+ new_hidden_states = ()
+ for hidden_state in all_hidden_states:
+ new_hidden_states += (hidden_state[:, self.prompt_length :, :],)
+ all_hidden_states = new_hidden_states
+
+ if not return_dict:
+ return tuple(
+ v for v in [hidden_states, past_key_values, all_hidden_states, all_attentions] if v is not None
+ )
+
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values,
+ hidden_states=all_hidden_states,
+ attentions=all_attentions,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ The CPMAnt Model with a language modeling head on top (linear layer with weights tied to the input embeddings).
+ """
+)
+class CpmAntForCausalLM(CpmAntPreTrainedModel, GenerationMixin):
+ _tied_weights_keys = ["lm_head.weight"]
+
+ def __init__(self, config: CpmAntConfig):
+ super().__init__(config)
+ self.cpmant = CpmAntModel(config)
+
+ # lm_head.weight is tied to cpmant.input_embedding.weight
+ self.lm_head = nn.Linear(
+ config.hidden_size, config.vocab_size + config.prompt_types * config.prompt_length, bias=False
+ )
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ past_key_values: Optional[Cache] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ labels: Optional[torch.Tensor] = None,
+ return_dict: Optional[bool] = None,
+ attention_mask: Optional[torch.Tensor] = None, # dummy parameter for text-generation pipeline
+ cache_position: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> Union[tuple, CausalLMOutputWithPast]:
+ r"""
+ input_ids (`torch.Tensor` of shape `(batch_size, seq_len)`):
+ Indices of input sequence tokens in the vocabulary.
+
+ Indices can be obtained using [`CPMAntTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ labels (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss.
+
+ Example:
+
+ Text Generation with CpmAntForCausalLM.
+ ```python
+ >>> from transformers import CPMAntTokenizer, CpmAntForCausalLM
+
+ >>> texts = "今天天气不错,"
+ >>> model = CpmAntForCausalLM.from_pretrained("openbmb/cpm-ant-10b")
+ >>> tokenizer = CPMAntTokenizer.from_pretrained("openbmb/cpm-ant-10b")
+ >>> input_ids = tokenizer(texts, return_tensors="pt")
+ >>> outputs = model.generate(**input_ids)
+ >>> output_texts = tokenizer.batch_decode(outputs)
+ >>> print(output_texts)
+ ['今天天气不错,阳光明媚,我和妈妈一起去超市买东西。\n在超市里,我看到了一个很好玩的玩具,它的名字叫“机器人”。它有一个圆圆的脑袋,两只圆圆的眼睛,还有一个圆圆的']
+ ```
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ model_output = self.cpmant(
+ input_ids,
+ output_attentions,
+ output_hidden_states,
+ past_key_values,
+ use_cache,
+ return_dict,
+ cache_position,
+ )
+ hidden_states = model_output.last_hidden_state if return_dict else model_output[0]
+
+ logits = self.lm_head(hidden_states)
+
+ loss = None
+ if labels is not None:
+ loss_func = CrossEntropyLoss()
+ loss = loss_func(logits.view(-1, logits.size(-1)), labels.view(-1))
+
+ if not return_dict:
+ output = (logits,) + model_output[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=model_output.past_key_values,
+ hidden_states=model_output.hidden_states,
+ attentions=model_output.attentions,
+ )
+
+ def get_input_embeddings(self):
+ return self.cpmant.input_embedding
+
+ def set_input_embeddings(self, embeddings):
+ self.cpmant.input_embedding = embeddings
+
+ def _reorder_cache(self, past_key_values, beam_idx):
+ past_key_values = [list(each) if each is not None else each for each in past_key_values]
+ for key_value_layer in past_key_values:
+ key_value_layer[0] = key_value_layer[0][beam_idx]
+ key_value_layer[1] = key_value_layer[1][beam_idx]
+ return past_key_values
+
+
+__all__ = ["CpmAntForCausalLM", "CpmAntModel", "CpmAntPreTrainedModel"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/cpmant/tokenization_cpmant.py b/venv/lib/python3.13/site-packages/transformers/models/cpmant/tokenization_cpmant.py
new file mode 100644
index 0000000000000000000000000000000000000000..38cd9f0c6a25dd5ace5a51afc31be8ff8b10b3eb
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/cpmant/tokenization_cpmant.py
@@ -0,0 +1,272 @@
+# coding=utf-8
+# Copyright 2022 The OpenBMB Team and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenization classes for CPMAnt."""
+
+import collections
+import os
+from typing import Optional
+
+from transformers.utils import is_rjieba_available, requires_backends
+
+
+if is_rjieba_available():
+ import rjieba
+
+from ...tokenization_utils import PreTrainedTokenizer
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
+
+
+def load_vocab(vocab_file):
+ """Loads a vocabulary file into a dictionary."""
+ vocab = collections.OrderedDict()
+ with open(vocab_file, "r", encoding="utf-8") as reader:
+ tokens = reader.readlines()
+ for index, token in enumerate(tokens):
+ token = token.rstrip("\n")
+ vocab[token] = index
+ return vocab
+
+
+class WordpieceTokenizer:
+ def __init__(self, vocab, unk_token="<unk>", max_input_chars_per_word=200):
+ self.vocab = vocab
+ self.unk_token = unk_token
+ self.max_input_chars_per_word = max_input_chars_per_word
+
+ def tokenize(self, token):
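+ # Greedy longest-match-first segmentation against the vocabulary; a position with no match emits unk_token and advances one character.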
+ chars = list(token)
+ if len(chars) > self.max_input_chars_per_word:
+ return [self.unk_token]
+
+ start = 0
+ sub_tokens = []
+ while start < len(chars):
+ end = len(chars)
+ cur_substr = None
+ while start < end:
+ substr = "".join(chars[start:end])
+ if substr in self.vocab:
+ cur_substr = substr
+ break
+ end -= 1
+ if cur_substr is None:
+ sub_tokens.append(self.unk_token)
+ start += 1
+ else:
+ sub_tokens.append(cur_substr)
+ start = end
+
+ return sub_tokens
+
+
+class CpmAntTokenizer(PreTrainedTokenizer):
+ """
+ Construct a CPMAnt tokenizer, based on jieba word segmentation followed by WordPiece.
+
+ Args:
+ vocab_file (`str`):
+ Path to the vocabulary file.
+ bod_token (`str`, *optional*, defaults to `"<d>"`):
+ The beginning of document token.
+ eod_token (`str`, *optional*, defaults to `"</d>"`):
+ The end of document token.
+ bos_token (`str`, *optional*, defaults to `"<s>"`):
+ The beginning of sequence token.
+ eos_token (`str`, *optional*, defaults to `"</s>"`):
+ The end of sequence token.
+ pad_token (`str`, *optional*, defaults to `"<pad>"`):
+ The token used for padding.
+ unk_token (`str`, *optional*, defaults to `"<unk>"`):
+ The unknown token.
+ line_token (`str`, *optional*, defaults to `"</n>"`):
+ The line token.
+ space_token (`str`, *optional*, defaults to `"</_>"`):
+ The space token.
+ """
+
+ vocab_files_names = VOCAB_FILES_NAMES
+ model_input_names = ["input_ids", "attention_mask"]
+ add_prefix_space = False
+
+ def __init__(
+ self,
+ vocab_file,
+ bod_token="<d>",
+ eod_token="</d>",
+ bos_token="<s>",
+ eos_token="</s>",
+ pad_token="<pad>",
+ unk_token="<unk>",
+ line_token="</n>",
+ space_token="</_>",
+ padding_side="left",
+ **kwargs,
+ ):
+ requires_backends(self, ["rjieba"])
+ self.bod_token = bod_token
+ self.eod_token = eod_token
+ self.encoder = load_vocab(vocab_file)
+ self.encoder[" "] = self.encoder[space_token]
+ self.encoder["\n"] = self.encoder[line_token]
+
+ del self.encoder[space_token]
+ del self.encoder[line_token]
+
+ self.encoder = collections.OrderedDict(sorted(self.encoder.items(), key=lambda x: x[1]))
+ self.decoder = {v: k for k, v in self.encoder.items()}
+
+ self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.encoder, unk_token=unk_token)
+
+ super().__init__(
+ bod_token=bod_token,
+ eod_token=eod_token,
+ bos_token=bos_token,
+ eos_token=eos_token,
+ pad_token=pad_token,
+ unk_token=unk_token,
+ line_token=line_token,
+ space_token=space_token,
+ padding_side=padding_side,
+ **kwargs,
+ )
+
+ @property
+ def bod_token_id(self):
+ return self.encoder[self.bod_token]
+
+ @property
+ def eod_token_id(self):
+ return self.encoder[self.eod_token]
+
+ @property
+ def newline_id(self):
+ return self.encoder["\n"]
+
+ @property
+ def vocab_size(self) -> int:
+ return len(self.encoder)
+
+ def get_vocab(self):
+ return dict(self.encoder, **self.added_tokens_encoder)
+
+ def _tokenize(self, text):
+ """Tokenize a string."""
+ output_tokens = []
+ for x in rjieba.cut(text, False):
+ output_tokens.extend(self.wordpiece_tokenizer.tokenize(x))
+ return output_tokens
+
+ def _decode(self, token_ids, **kwargs):
+ """Decode ids into a string."""
+ token_ids = [i for i in token_ids if i >= 0]
+ token_ids = [
+ x for x in token_ids if x != self.pad_token_id and x != self.eos_token_id and x != self.bos_token_id
+ ]
+ return super()._decode(token_ids, **kwargs)
+
+ def check(self, token):
+ return token in self.encoder
+
+ def convert_tokens_to_string(self, tokens: list[str]) -> str:
+ return "".join(tokens)
+
+ def _convert_token_to_id(self, token):
+ """Converts a token (str) in an id using the vocab."""
+ return self.encoder.get(token, self.encoder.get(self.unk_token))
+
+ def _convert_id_to_token(self, index):
+ """Converts an index (integer) in a token (str) using the vocab."""
+ return self.decoder.get(index, self.unk_token)
+
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
+ if os.path.isdir(save_directory):
+ vocab_file = os.path.join(
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+ )
+ else:
+ vocab_file = (filename_prefix + "-" if filename_prefix else "") + save_directory
+ index = 0
+ if " " in self.encoder:
+ self.encoder[""] = self.encoder[" "]
+ del self.encoder[" "]
+ if "\n" in self.encoder:
+ self.encoder[""] = self.encoder["\n"]
+ del self.encoder["\n"]
+ self.encoder = collections.OrderedDict(sorted(self.encoder.items(), key=lambda x: x[1]))
+ with open(vocab_file, "w", encoding="utf-8") as writer:
+ for token, token_index in self.encoder.items():
+ if index != token_index:
+ logger.warning(
+ f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive."
+ " Please check that the vocabulary is not corrupted!"
+ )
+ index = token_index
+ writer.write(token + "\n")
+ index += 1
+ return (vocab_file,)
+
+ def build_inputs_with_special_tokens(
+ self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
+ ) -> list[int]:
+ """
+ Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating and
+ adding special tokens. A CPMAnt sequence has the following format:
+
+ - single sequence: `[BOS] Sequence`.
+
+ Args:
+ token_ids_0 (`list[int]`): The first tokenized sequence to which special tokens will be added.
+ token_ids_1 (`list[int]`, *optional*): The optional second tokenized sequence to which special tokens will be added.
+
+ Returns:
+ `list[int]`: The model input with special tokens.
+ """
+ if token_ids_1 is None:
+ return [self.bos_token_id] + token_ids_0
+ return [self.bos_token_id] + token_ids_0 + [self.bos_token_id] + token_ids_1
+
+ def get_special_tokens_mask(
+ self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None, already_has_special_tokens: bool = False
+ ) -> list[int]:
+ """
+ Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
+ special tokens using the tokenizer `prepare_for_model` method.
+
+ Args:
+ token_ids_0 (`list[int]`): List of IDs.
+ token_ids_1 (`list[int]`, *optional*): Optional second list of IDs for sequence pairs.
+ already_has_special_tokens (`bool`, *optional*, defaults to `False`):
+ Whether or not the token list is already formatted with special tokens for the model.
+
+ Returns:
+ `list[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+ """
+
+ if already_has_special_tokens:
+ return super().get_special_tokens_mask(
+ token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
+ )
+
+ if token_ids_1 is not None:
+ return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1))
+ return [1] + ([0] * len(token_ids_0))
+
+
+__all__ = ["CpmAntTokenizer"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/csm/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/csm/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..59468442b52eb71fbcb984c28bb465cec2be91e5
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/csm/__init__.py
@@ -0,0 +1,28 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_csm import *
+ from .modeling_csm import *
+ from .processing_csm import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/venv/lib/python3.13/site-packages/transformers/models/csm/configuration_csm.py b/venv/lib/python3.13/site-packages/transformers/models/csm/configuration_csm.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e56c5f7686e15b8d83d8eb2afafac7b3dc578c9
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/csm/configuration_csm.py
@@ -0,0 +1,440 @@
+# coding=utf-8
+# Copyright 2025 Sesame and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ...configuration_utils import PretrainedConfig
+from ...modeling_rope_utils import rope_config_validation
+from ...utils import logging
+from ..auto.configuration_auto import AutoConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+class CsmDepthDecoderConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`CsmDepthDecoderModel`]. It is used to instantiate a CSM depth decoder
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield
+ a similar configuration to that of the csm-1b.
+
+ e.g. [sesame/csm-1b](https://huggingface.co/sesame/csm-1b)
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ num_codebooks (`int`, *optional*, defaults to 32):
+ Number of codebooks used in the underlying codec model responsible for tokenizing the audio.
+ backbone_hidden_size (`int`, *optional*, defaults to 2048):
+ Dimension of the hidden representations of the backbone model used with this depth decoder.
+ vocab_size (`int`, *optional*, defaults to 2051):
+ Vocabulary size of the CsmDepthDecoder model. Defines the number of different audio tokens that can be represented by each codebook.
+ hidden_size (`int`, *optional*, defaults to 1024):
+ Dimension of the hidden representations.
+ intermediate_size (`int`, *optional*, defaults to 8192):
+ Dimension of the MLP representations.
+ num_hidden_layers (`int`, *optional*, defaults to 4):
+ Number of hidden layers in the Transformer decoder.
+ num_attention_heads (`int`, *optional*, defaults to 8):
+ Number of attention heads for each attention layer in the Transformer decoder.
+ num_key_value_heads (`int`, *optional*, defaults to 2):
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+ by meanpooling all the original heads within that group. For more details, check out [this
+ paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
+ `num_attention_heads`.
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+ The non-linear activation function (function or string) in the decoder.
+ max_position_embeddings (`int`, *optional*, defaults to 33):
+ The maximum sequence length that this model might ever be used with.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-05):
+ The epsilon used by the rms normalization layers.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ pad_token_id (`int`, *optional*, defaults to 2050):
+ Padding token id.
+ bos_token_id (`int`, *optional*):
+ Beginning of stream token id.
+ eos_token_id (`int`, *optional*):
+ End of stream token id.
+ rope_theta (`float`, *optional*, defaults to 500000):
+ The base period of the RoPE embeddings.
+ rope_scaling (`Dict`, *optional*):
+ Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
+ and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
+ accordingly.
+ Expected contents:
+ `rope_type` (`str`):
+ The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
+ 'llama3'], with 'default' being the original RoPE implementation.
+ `factor` (`float`, *optional*):
+ Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
+ most scaling types, a `factor` of x will enable the model to handle sequences of length x *
+ original maximum pre-trained length.
+ `original_max_position_embeddings` (`int`, *optional*):
+ Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
+ pretraining.
+ `attention_factor` (`float`, *optional*):
+ Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
+ computation. If unspecified, it defaults to value recommended by the implementation, using the
+ `factor` field to infer the suggested value.
+ `beta_fast` (`float`, *optional*):
+ Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
+ ramp function. If unspecified, it defaults to 32.
+ `beta_slow` (`float`, *optional*):
+ Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
+ ramp function. If unspecified, it defaults to 1.
+ `short_factor` (`list[float]`, *optional*):
+ Only used with 'longrope'. The scaling factor to be applied to short contexts (<
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+ size divided by the number of attention heads divided by 2
+ `long_factor` (`list[float]`, *optional*):
+ Only used with 'longrope'. The scaling factor to be applied to long contexts (>
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+ size divided by the number of attention heads divided by 2
+ `low_freq_factor` (`float`, *optional*):
+ Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
+ `high_freq_factor` (`float`, *optional*):
+ Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
+ attention_bias (`bool`, *optional*, defaults to `False`):
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ mlp_bias (`bool`, *optional*, defaults to `False`):
+ Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers.
+ head_dim (`int`, *optional*):
+ The attention head dimension. If None, it will default to hidden_size // num_attention_heads
+
+ ```python
+ >>> from transformers import CsmDepthDecoderModel, CsmDepthDecoderConfig
+
+ >>> # Initializing a CsmDepthDecoderConfig
+ >>> configuration = CsmDepthDecoderConfig()
+
+ >>> # Initializing a model from the configuration
+ >>> model = CsmDepthDecoderModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
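+
+ >>> # (illustrative) the defaults can be overridden, e.g. for a shallower depth decoder
+ >>> small_configuration = CsmDepthDecoderConfig(num_hidden_layers=2, num_attention_heads=4, num_key_value_heads=2)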
+ ```"""
+
+ model_type = "csm_depth_decoder_model"
+ base_config_key = "depth_decoder_config"
+ keys_to_ignore_at_inference = ["past_key_values"]
+
+ def __init__(
+ self,
+ num_codebooks=32,
+ backbone_hidden_size=2048,
+ vocab_size=2051,
+ hidden_size=1024,
+ intermediate_size=8192,
+ num_hidden_layers=4,
+ num_attention_heads=8,
+ num_key_value_heads=2,
+ hidden_act="silu",
+ max_position_embeddings=33,
+ initializer_range=0.02,
+ rms_norm_eps=1e-5,
+ use_cache=True,
+ pad_token_id=None,
+ bos_token_id=None,
+ eos_token_id=None,
+ rope_theta=500000,
+ rope_scaling=None,
+ attention_bias=False,
+ attention_dropout=0.0,
+ mlp_bias=False,
+ head_dim=None,
+ **kwargs,
+ ):
+ if kwargs.pop("tie_word_embeddings", False):
+ raise ValueError("`tie_word_embeddings=True` is not supported for CsmDepthDecoderConfig")
+
+ super().__init__(
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ tie_word_embeddings=False,
+ **kwargs,
+ )
+ self.num_codebooks = num_codebooks
+ self.vocab_size = vocab_size
+ self.backbone_hidden_size = backbone_hidden_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+
+ # for backward compatibility
+ if num_key_value_heads is None:
+ num_key_value_heads = num_attention_heads
+
+ self.num_key_value_heads = num_key_value_heads
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.rms_norm_eps = rms_norm_eps
+ self.use_cache = use_cache
+ self.rope_theta = rope_theta
+ self.rope_scaling = rope_scaling
+ self.attention_bias = attention_bias
+ self.attention_dropout = attention_dropout
+ self.mlp_bias = mlp_bias
+ self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads
+ # Validate the correctness of rotary position embeddings parameters
+ # BC: if there is a 'type' field, copy it to 'rope_type'.
+ if self.rope_scaling is not None and "type" in self.rope_scaling:
+ self.rope_scaling["rope_type"] = self.rope_scaling["type"]
+ rope_config_validation(self)
+
+
+class CsmConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`CsmForConditionalGeneration`]. It is used to instantiate a CSM
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of the csm-1b.
+
+ e.g. [sesame/csm-1b](https://huggingface.co/sesame/csm-1b)
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ num_codebooks (`int`, *optional*, defaults to 32):
+ Number of codebooks used in the underlying codec model responsible for tokenizing the audio.
+ vocab_size (`int`, *optional*, defaults to 2051):
+ Vocabulary size of the Csm model. Defines the number of different audio tokens that can be represented by each codebook.
+ text_vocab_size (`int`, *optional*, defaults to 128256):
+ Vocabulary size of the text input for the Csm model. Defines the number of different text tokens that can be represented.
+ hidden_size (`int`, *optional*, defaults to 2048):
+ Dimension of the hidden representations of the backbone model.
+ intermediate_size (`int`, *optional*, defaults to 8192):
+ Dimension of the MLP representations of the backbone model.
+ num_hidden_layers (`int`, *optional*, defaults to 16):
+ Number of hidden layers in the backbone model Transformer decoder.
+ num_attention_heads (`int`, *optional*, defaults to 32):
+ Number of attention heads for each attention layer in the backbone model Transformer decoder.
+ num_key_value_heads (`int`, *optional*, defaults to 8):
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+ by meanpooling all the original heads within that group. For more details, check out [this
+ paper](https://huggingface.co/papers/2305.13245).
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+ The non-linear activation function (function or string) in the backbone model Transformer decoder.
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
+ The maximum sequence length that this model might ever be used with.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-05):
+ The epsilon used by the rms normalization layers.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ pad_token_id (`int`, *optional*, defaults to 128002):
+ Padding token id.
+ codebook_pad_token_id (`int`, *optional*, defaults to 2050):
+ Padding token id for codebook tokens.
+ codebook_eos_token_id (`int`, *optional*, defaults to 0):
+ End of stream token id for codebook tokens.
+ bos_token_id (`int`, *optional*, defaults to 128000):
+ Beginning of stream token id.
+ eos_token_id (`int`, *optional*):
+ End of stream token id.
+ audio_token_id (`int`, *optional*, defaults to 128002):
+ Audio token id in the text input.
+ audio_eos_token_id (`int`, *optional*, defaults to 128003):
+ End of stream token id for audio in the text input.
+ rope_theta (`float`, *optional*, defaults to 500000):
+ The base period of the RoPE embeddings.
+ rope_scaling (`Dict`, *optional*, defaults to `{'factor': 32.0, 'high_freq_factor': 0.5, 'low_freq_factor': 0.125, 'original_max_position_embeddings': 1024, 'rope_type': 'llama3'}`):
+ Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply a new rope type
+ and you expect the model to work on longer `max_position_embeddings`, we recommend updating this value
+ accordingly.
+ Expected contents:
+ `rope_type` (`str`):
+ The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
+ 'llama3'], with 'default' being the original RoPE implementation.
+ `factor` (`float`, *optional*):
+ Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
+ most scaling types, a `factor` of x will enable the model to handle sequences of length x *
+ original maximum pre-trained length.
+ `original_max_position_embeddings` (`int`, *optional*):
+ Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
+ pretraining.
+ `attention_factor` (`float`, *optional*):
+ Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
+ computation. If unspecified, it defaults to the value recommended by the implementation, using the
+ `factor` field to infer the suggested value.
+ `beta_fast` (`float`, *optional*):
+ Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
+ ramp function. If unspecified, it defaults to 32.
+ `beta_slow` (`float`, *optional*):
+ Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
+ ramp function. If unspecified, it defaults to 1.
+ `short_factor` (`list[float]`, *optional*):
+ Only used with 'longrope'. The scaling factor to be applied to short contexts (<
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+ size divided by the number of attention heads divided by 2
+ `long_factor` (`list[float]`, *optional*):
+ Only used with 'longrope'. The scaling factor to be applied to long contexts (>
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+ size divided by the number of attention heads divided by 2
+ `low_freq_factor` (`float`, *optional*):
+ Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
+ `high_freq_factor` (`float`, *optional*):
+ Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
+ attention_bias (`bool`, *optional*, defaults to `False`):
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ mlp_bias (`bool`, *optional*, defaults to `False`):
+ Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers.
+ head_dim (`int`, *optional*):
+ The attention head dimension. If None, it will default to hidden_size // num_attention_heads
+ tie_codebooks_embeddings (`bool`, *optional*, defaults to `True`):
+ Whether to tie the codebook token embeddings of the backbone model to those of the depth decoder.
+ depth_decoder_config (`CsmDepthDecoderConfig`, *optional*):
+ Configuration for the depth decoder.
+ codec_config (`PretrainedConfig`, *optional*):
+ Configuration for the codec.
+
+ ```python
+ >>> from transformers import CsmForConditionalGeneration, CsmConfig
+
+ >>> # Initializing a CsmConfig
+ >>> configuration = CsmConfig()
+
+ >>> # Initializing a model
+ >>> model = CsmForConditionalGeneration(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
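+
+ >>> # (illustrative) sub-configurations can be customized as well, e.g. with a shallower depth decoder
+ >>> from transformers import CsmDepthDecoderConfig
+ >>> custom_configuration = CsmConfig(depth_decoder_config=CsmDepthDecoderConfig(num_hidden_layers=2))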
+ ```"""
+
+ model_type = "csm"
+ base_config_key = "csm_config"
+ keys_to_ignore_at_inference = ["past_key_values"]
+ sub_configs = {
+ "codec_config": AutoConfig,
+ "depth_decoder_config": CsmDepthDecoderConfig,
+ }
+
+ def __init__(
+ self,
+ num_codebooks=32,
+ vocab_size=2051,
+ text_vocab_size=128256,
+ hidden_size=2048,
+ intermediate_size=8192,
+ num_hidden_layers=16,
+ num_attention_heads=32,
+ num_key_value_heads=8,
+ hidden_act="silu",
+ max_position_embeddings=2048,
+ initializer_range=0.02,
+ rms_norm_eps=1e-5,
+ use_cache=True,
+ pad_token_id=128002,
+ codebook_pad_token_id=2050,
+ codebook_eos_token_id=0,
+ bos_token_id=128000,
+ eos_token_id=None,
+ audio_token_id=128002,
+ audio_eos_token_id=128003,
+ rope_theta=500000,
+ rope_scaling=None,
+ attention_bias=False,
+ attention_dropout=0.0,
+ mlp_bias=False,
+ head_dim=None,
+ tie_codebooks_embeddings=True,
+ depth_decoder_config=None,
+ codec_config=None,
+ **kwargs,
+ ):
+ if kwargs.pop("tie_word_embeddings", False):
+ raise ValueError("`tie_word_embeddings=True` is not supported for CsmConfig")
+
+ super().__init__(
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ tie_word_embeddings=False,
+ **kwargs,
+ )
+
+ if depth_decoder_config is None:
+ self.depth_decoder_config = CsmDepthDecoderConfig()
+ logger.info("depth_decoder_config is None, using default depth decoder config.")
+ elif isinstance(depth_decoder_config, dict):
+ self.depth_decoder_config = CsmDepthDecoderConfig(**depth_decoder_config)
+ elif isinstance(depth_decoder_config, CsmDepthDecoderConfig):
+ self.depth_decoder_config = depth_decoder_config
+
+ if codec_config is None:
+ self.codec_config = AutoConfig.for_model("mimi")
+ logger.info("codec_config is None, using default audio encoder config.")
+ elif isinstance(codec_config, dict):
+ self.codec_config = AutoConfig.for_model(**codec_config)
+ elif isinstance(codec_config, PretrainedConfig):
+ self.codec_config = codec_config
+
+ self.text_vocab_size = text_vocab_size
+ self.num_codebooks = num_codebooks
+ self.audio_token_id = audio_token_id
+ self.audio_eos_token_id = audio_eos_token_id
+ self.codebook_pad_token_id = codebook_pad_token_id
+ self.codebook_eos_token_id = codebook_eos_token_id
+ self.tie_codebooks_embeddings = tie_codebooks_embeddings
+
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+
+ # for backward compatibility
+ if num_key_value_heads is None:
+ num_key_value_heads = num_attention_heads
+
+ self.num_key_value_heads = num_key_value_heads
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.rms_norm_eps = rms_norm_eps
+ self.use_cache = use_cache
+ self.rope_theta = rope_theta
+ self.rope_scaling = rope_scaling
+ self.attention_bias = attention_bias
+ self.attention_dropout = attention_dropout
+ self.mlp_bias = mlp_bias
+ self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads
+ # Validate the correctness of rotary position embeddings parameters
+ # BC: if there is a 'type' field, copy it to 'rope_type'.
+ if self.rope_scaling is not None and "type" in self.rope_scaling:
+ self.rope_scaling["rope_type"] = self.rope_scaling["type"]
+ rope_config_validation(self)
+
+
+__all__ = [
+ "CsmDepthDecoderConfig",
+ "CsmConfig",
+]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/csm/generation_csm.py b/venv/lib/python3.13/site-packages/transformers/models/csm/generation_csm.py
new file mode 100644
index 0000000000000000000000000000000000000000..cf8bc141f5d1ddfb1c5eeea6ae9c4ce2590a8caa
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/csm/generation_csm.py
@@ -0,0 +1,491 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any, Optional, Union
+
+import torch
+import torch.nn as nn
+
+from ...generation import (
+ GenerateDecoderOnlyOutput,
+ GenerationConfig,
+ GenerationMixin,
+ GenerationMode,
+)
+from ...generation.logits_process import LogitsProcessorList
+from ...generation.stopping_criteria import MaxLengthCriteria, StoppingCriteriaList
+from ...generation.utils import GenerateNonBeamOutput
+from ...utils import logging
+
+
+if TYPE_CHECKING:
+ from ...generation.streamers import BaseStreamer
+
+
+logger = logging.get_logger(__name__)
+
+
+@dataclass
+class CsmGenerateOutput(GenerateDecoderOnlyOutput):
+ """
+ Outputs of CsmForConditionalGeneration.generate.
+
+ Args:
+ sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
+ if all batches finished early due to the `eos_token_id`.
+ scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True`):
+ Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
+ at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
+ each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
+ logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True`):
+ Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
+ at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
+ each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
+ attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`):
+ Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
+ `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
+ hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`):
+ Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
+ `torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`.
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True`):
+ Returns the model cache, used to speed up decoding. Different models have a different cache format; check the model's documentation for details.
+ audio (`list(torch.FloatTensor)` of length `batch_size`):
+ The generated audio.
+ """
+
+ audio: Optional[list[torch.Tensor]] = None
+
+
+class CsmGenerationMixin(GenerationMixin):
+ def _get_stopping_criteria(
+ self,
+ *args,
+ **kwargs,
+ ) -> StoppingCriteriaList:
+ criteria = super()._get_stopping_criteria(*args, **kwargs)
+
+ kept_criteria = StoppingCriteriaList()
+ for criterion in criteria:
+ if not isinstance(criterion, MaxLengthCriteria):
+ logger.warning(
+ f"Csm does not support {criterion.__class__.__name__} stopping criteria, it will be ignored."
+ )
+ else:
+ kept_criteria.append(criterion)
+ return kept_criteria
+
+ def _prepare_generation_config(
+ self, generation_config: Optional[GenerationConfig], use_model_defaults: Optional[bool] = None, **kwargs: Any
+ ) -> tuple[GenerationConfig, dict]:
+ """
+ This method overrides [~generation.utils.GenerationMixin._prepare_generation_config].
+ It ensures that the depth decoder generation config is initialized and that arguments passed with the `depth_decoder_` prefix are properly handled.
+ """
+ # extract depth decoder kwargs and remove them from the main kwargs
+ depth_decoder_kwargs = {
+ k[len("depth_decoder_") :]: v for k, v in kwargs.items() if k.startswith("depth_decoder_")
+ }
+
+ # remove the depth decoder keys from the original kwargs
+ kwargs = {k: v for k, v in kwargs.items() if not k.startswith("depth_decoder_")}
+
+ # initialize the generation config
+ generation_config, model_kwargs = super()._prepare_generation_config(
+ generation_config, use_model_defaults, **kwargs
+ )
+ self.depth_decoder.generation_config.update(**depth_decoder_kwargs)
+
+ # ensure the depth decoder generation config is valid
+ depth_decoder_min_new_tokens = getattr(self.depth_decoder.generation_config, "min_new_tokens") or (
+ self.config.num_codebooks - 1
+ )
+ depth_decoder_max_new_tokens = getattr(self.depth_decoder.generation_config, "max_new_tokens") or (
+ self.config.num_codebooks - 1
+ )
+
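+ # the depth decoder must generate exactly `num_codebooks - 1` codebook tokens (codebooks 1 to N-1),
+ # since the first codebook token is sampled by the backbone model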
+ if {depth_decoder_min_new_tokens, depth_decoder_max_new_tokens} != {self.config.num_codebooks - 1}:
+ raise ValueError(
+ f"depth_decoder_generation_config's min_new_tokens ({depth_decoder_min_new_tokens}) and max_new_tokens ({depth_decoder_max_new_tokens}) must be equal to self.config.num_codebooks - 1 ({self.config.num_codebooks - 1})"
+ )
+ elif self.depth_decoder.generation_config.return_dict_in_generate:
+ logger.warning(
+ "depth_decoder_generation_config.return_dict_in_generate is set to True, but this will be ignored as the depth decoder model does not return a dictionary in generate"
+ )
+ self.depth_decoder.generation_config.return_dict_in_generate = False
+
+ self.depth_decoder.generation_config.min_new_tokens = depth_decoder_min_new_tokens
+ self.depth_decoder.generation_config.max_new_tokens = depth_decoder_max_new_tokens
+
+ # Monkey patch the get_generation_mode method to support CSM model
+ original_get_generation_mode = generation_config.get_generation_mode
+
+ def patched_get_generation_mode(assistant_model=None):
+ generation_mode = original_get_generation_mode(assistant_model)
+ if generation_mode not in [GenerationMode.GREEDY_SEARCH, GenerationMode.SAMPLE]:
+ raise ValueError(
+ f"Generation mode {generation_mode} is not supported for CSM model. Please set generation parameters to use greedy or sampling generation."
+ )
+
+ return generation_mode
+
+ generation_config.get_generation_mode = patched_get_generation_mode
+
+ return generation_config, model_kwargs
+
+ def _sample(
+ self,
+ input_ids: torch.LongTensor,
+ logits_processor: LogitsProcessorList,
+ stopping_criteria: StoppingCriteriaList,
+ generation_config: GenerationConfig,
+ synced_gpus: bool = False,
+ streamer: Optional["BaseStreamer"] = None,
+ **model_kwargs,
+ ) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
+ """
+ This method overrides [~generation.utils.GenerationMixin._sample].
+ To ease maintenance, modifications are marked with the comment "Csm specific".
+
+ Indeed, the Csm model requires a custom generation sampling step:
+ 1. Infer the backbone model to sample the first codebook token
+ 2. Call generate on the depth decoder with the first codebook token as input_ids to sample the next codebook tokens
+ 3. Use these generated codebook tokens as input_ids to sample the next first codebook token using the backbone model
+ 4. Repeat until a stopping criterion is met
+
+ Csm supports two stopping criteria:
+ - stop when the generated sequence is at max_length
+ - stop when all the generated codebook tokens are the codebook_eos_token_id
+ """
+ # init values
+ # *************** Csm specific ***************
+ pad_token_id = self.config.codebook_pad_token_id
+ has_eos_stopping_criteria = generation_config._eos_token_tensor is not None
+ # ============================================
+ output_attentions = generation_config.output_attentions
+ output_hidden_states = generation_config.output_hidden_states
+ output_scores = generation_config.output_scores
+ output_logits = generation_config.output_logits
+ return_dict_in_generate = generation_config.return_dict_in_generate
+ do_sample = generation_config.do_sample
+
+ # init attention / hidden states / scores tuples
+ scores = () if (return_dict_in_generate and output_scores) else None
+ raw_logits = () if (return_dict_in_generate and output_logits) else None
+ decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
+ decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
+
+ # keep track of which sequences are already finished
+ batch_size, cur_len = input_ids.shape[:2]
+ this_peer_finished = False
+ unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
+ model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)
+
+ # *************** Csm specific ***************
+ if input_ids.ndim == 2 and model_kwargs.get("inputs_embeds") is None:
+ # in the case where the passed input_ids correspond to text tokens, i.e. don't have a third dimension for codebook ids,
+ # we need to subtract the input length from the MaxLengthCriteria, since such input tokens are not returned in the output
+ for criterion in stopping_criteria:
+ if isinstance(criterion, MaxLengthCriteria):
+ criterion.max_length -= cur_len
+ # ============================================
+
+ model_forward = self.__call__
+ compile_forward = self._valid_auto_compile_criteria(model_kwargs, generation_config)
+ if compile_forward:
+ os.environ["TOKENIZERS_PARALLELISM"] = "0"
+ model_forward = self.get_compiled_call(generation_config.compile_config)
+
+ is_prefill = True
+ while self._has_unfinished_sequences(
+ this_peer_finished,
+ synced_gpus,
+ device=input_ids.device,
+ ):
+ # prepare model inputs
+ model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
+
+ # prepare variable output controls (note: some models won't accept all output controls)
+ model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
+ # *************** Csm specific ***************
+ model_inputs.update({"output_hidden_states": True})
+ # ============================================
+
+ if is_prefill:
+ outputs = self(**model_inputs, return_dict=True)
+ is_prefill = False
+ else:
+ outputs = model_forward(**model_inputs, return_dict=True)
+
+ # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
+ model_kwargs = self._update_model_kwargs_for_generation(
+ outputs,
+ model_kwargs,
+ )
+ if synced_gpus and this_peer_finished:
+ continue
+
+ # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
+ # (the clone itself is always small)
+ next_token_logits = outputs.logits[:, -1, :].clone().float()
+ next_token_logits = next_token_logits.to(input_ids.device)
+
+ # pre-process distribution
+ next_token_scores = logits_processor(input_ids, next_token_logits)
+
+ # Store scores, attentions and hidden_states when required
+ if return_dict_in_generate:
+ if output_scores:
+ scores += (next_token_scores,)
+ if output_logits:
+ raw_logits += (next_token_logits,)
+ if output_attentions:
+ decoder_attentions += (outputs.attentions,)
+
+ if output_hidden_states:
+ decoder_hidden_states += (outputs.hidden_states,)
+
+ # token selection
+ if do_sample:
+ probs = nn.functional.softmax(next_token_scores, dim=-1)
+ # TODO (joao): this OP throws "skipping cudagraphs due to ['incompatible ops']", find solution
+ next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
+ else:
+ next_tokens = torch.argmax(next_token_scores, dim=-1)
+
+ # *************** Csm specific ***************
+ # infer the depth decoder
+ first_codebook_ids = next_tokens[:, None]
+ # adds a placeholder at position 0 that will be replaced by the backbone_last_hidden_state
+ depth_decoder_input_ids = nn.functional.pad(first_codebook_ids, (1, 0), value=0)
+ backbone_last_hidden_state = outputs.hidden_states[-1][:, -1, :]
+
+ depth_decoder_outputs = self.depth_decoder.generate(
+ input_ids=depth_decoder_input_ids, backbone_last_hidden_state=backbone_last_hidden_state.clone()
+ )
+ codebook_ids = (
+ depth_decoder_outputs
+ if isinstance(depth_decoder_outputs, torch.Tensor)
+ else depth_decoder_outputs.sequences
+ )
+ # remove the place holder in position 0
+ codebook_ids = codebook_ids[:, 1:]
+ next_tokens = codebook_ids
+
+ # finished sentences should have their next token be a padding token
+ if has_eos_stopping_criteria:
+ next_tokens = next_tokens * unfinished_sequences.unsqueeze(-1) + pad_token_id * (
+ 1 - unfinished_sequences.unsqueeze(-1)
+ )
+
+ # update generated ids, model inputs, and length for next step
+ if input_ids.ndim == 2:
+ input_ids = next_tokens[:, None, :]
+ else:
+ input_ids = torch.cat([input_ids, next_tokens[:, None, :]], dim=1)
+ # ============================================
+
+ if streamer is not None:
+ streamer.put(next_tokens.cpu())
+
+ # *************** Csm specific ***************
+ # for the eos stopping criterion, it is expected that the eos token is the same for each codebook
+ unfinished_sequences = unfinished_sequences & ~(
+ input_ids[:, -1, :-1] == self.config.codebook_eos_token_id
+ ).all(-1)
+ # ============================================
+ unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
+ this_peer_finished = unfinished_sequences.max() == 0
+ cur_len += 1
+
+ # This is needed to properly delete outputs.logits which may be very large for first iteration
+ # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
+ del outputs
+
+ # *************** Csm specific ***************
+ del depth_decoder_outputs
+ # ============================================
+
+ if streamer is not None:
+ streamer.end()
+
+ if return_dict_in_generate:
+ return GenerateDecoderOnlyOutput(
+ sequences=input_ids,
+ scores=scores,
+ logits=raw_logits,
+ attentions=decoder_attentions,
+ hidden_states=decoder_hidden_states,
+ past_key_values=model_kwargs.get("past_key_values"),
+ )
+ else:
+ return input_ids
+
+ def generate(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ input_values: Optional[torch.Tensor] = None,
+ input_values_cutoffs: Optional[torch.Tensor] = None,
+ generation_config: Optional[GenerationConfig] = None,
+ logits_processor: Optional[LogitsProcessorList] = None,
+ stopping_criteria: Optional[StoppingCriteriaList] = None,
+ synced_gpus: Optional[bool] = None,
+ streamer: Optional["BaseStreamer"] = None,
+ output_audio: Optional[bool] = False,
+ **kwargs,
+ ) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
+ r"""
+ This method overrides [`~generation.utils.GenerationMixin.generate`] to match the specifics of the Csm model.
+ Indeed, the Csm model requires a custom generation sampling step:
+ 1. Infer the backbone model to sample the first codebook token
+ 2. Call generate on the depth decoder with the first codebook token as `input_ids` to sample the next codebook tokens
+ 3. Use these generated codebook tokens as `input_ids` to sample the next first codebook token using the backbone model
+ 4. Repeat until a stopping criterion is met
+
+ Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the
+ model's default generation configuration. You can override any `generation_config` by passing the corresponding
+ parameters to generate(), e.g. `.generate(inputs, do_sample=True)`.
+
+
+ Parameters:
+ input_ids (`torch.Tensor` of shape (batch_size, seq_length), *optional*):
+ The sequence used as a prompt for the backbone model.
+ input_values (`torch.Tensor` of shape (batch_size, channels, max_concatenated_audio_length), *optional*):
+ The batched audio input values, where each batch entry contains the concatenation of all audio segments for that entry.
+ These values will be encoded into codebook tokens using the codec model and merged with the text input ids provided in `input_ids`.
+ input_values_cutoffs (`torch.Tensor` of shape (batch_size, max_num_audio), *optional*):
+ Specify the end positions of audio segments within each batch entry, relative to the concatenated audio input.
+ If a batch entry has fewer segments than the maximum, it is padded with -1. For example, in a batch of 2 sequences
+ where the first contains 2 audio segments of length l1, and the second contains 1 audio segment of length l2,
+ the input_values_cutoffs would be: [[l1, 2 * l1], [l2, -1]].
+ generation_config ([`~generation.GenerationConfig`], *optional*):
+ The generation configuration to be used as base parametrization for the generation call. `**kwargs`
+ passed to generate matching the attributes of `generation_config` will override them. If
+ `generation_config` is not provided, the default will be used, which has the following loading
+ priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
+ configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
+ default values, whose documentation should be checked to parameterize generation.
+ logits_processor (`LogitsProcessorList`, *optional*):
+ Custom logits processors that complement the default logits processors built from arguments and
+ generation config. If a logit processor is passed that is already created with the arguments or a
+ generation config an error is thrown. This feature is intended for advanced users.
+ stopping_criteria (`StoppingCriteriaList`, *optional*):
+ Custom stopping criteria that complements the default stopping criteria built from arguments and a
+ generation config. If a stopping criteria is passed that is already created with the arguments or a
+ generation config an error is thrown. If your stopping criteria depends on the `scores` input, make
+ sure you pass `return_dict_in_generate=True, output_scores=True` to `generate`. This feature is
+ intended for advanced users.
+ synced_gpus (`bool`, *optional*):
+ Whether to continue running the while loop until max_length. Unless overridden, this flag will be set
+ to `True` if using `FullyShardedDataParallel` or DeepSpeed ZeRO Stage 3 with multiple GPUs to avoid
+ deadlocking if one GPU finishes generating before other GPUs. Otherwise, defaults to `False`.
+ streamer (`BaseStreamer`, *optional*):
+ Streamer object that will be used to stream the generated sequences. Generated tokens are passed
+ through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
+ output_audio (`bool`, *optional*):
+ Whether to return the generated audio.
+ kwargs (`dict[str, Any]`, *optional*):
+ Ad hoc parametrization of `generation_config` and/or additional model-specific kwargs that will be
+ forwarded to the `forward` function of the model. Depth decoder specific kwargs should be prefixed with *depth_decoder_*.
+
+ Return:
+ [`CsmGenerateOutput`] or `torch.LongTensor` or `list[torch.FloatTensor]`: A [`CsmGenerateOutput`]
+ (if `return_dict_in_generate=True` or when `config.return_dict_in_generate=True`) or a `torch.LongTensor` when `output_audio=False`
+ or a `list[torch.FloatTensor]` otherwise.
+
+ Example:
+
+ ```python
+ >>> import torch
+ >>> from transformers import CsmProcessor, CsmForConditionalGeneration
+ >>> from datasets import load_dataset, Audio
+
+ >>> model_id = "sesame/csm-1b"
+ >>> torch_device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ >>> processor = CsmProcessor.from_pretrained(model_id)
+
+ >>> ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train")
+ >>> # ensure the audio is 24kHz
+ >>> ds = ds.cast_column("audio", Audio(sampling_rate=24000))
+
+ >>> conversation = []
+ >>> # prepare a conversation with text and corresponding audio
+ >>> for text, audio, speaker_id in zip(ds[:4]["text"], ds[:4]["audio"], ds[:4]["speaker_id"]):
+ ... conversation.append(
+ ... {
+ ... "role": f"{speaker_id}",
+ ... "content": [{"type": "text", "text": text}, {"type": "audio", "path": audio["array"]}],
+ ... }
+ ... )
+
+ >>> # text prompt
+ >>> conversation.append({"role": f"{ds[4]['speaker_id']}", "content": [{"type": "text", "text": ds[4]["text"]}]})
+
+ >>> inputs = processor.apply_chat_template(
+ ... conversation,
+ ... tokenize=True,
+ ... return_dict=True,
+ ... ).to(torch_device)
+
+ >>> model = CsmForConditionalGeneration.from_pretrained(model_id, device_map=torch_device)
+ >>> audio = model.generate(**inputs, output_audio=True)
+ >>> processor.save_audio(audio, "output.wav")
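+
+ >>> # (illustrative) depth decoder generation parameters can be forwarded with the `depth_decoder_` prefix
+ >>> audio = model.generate(**inputs, output_audio=True, depth_decoder_do_sample=True, depth_decoder_temperature=0.9)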
+ ```
+ """
+ generate_output = super().generate(
+ input_ids=input_ids,
+ input_values=input_values,
+ input_values_cutoffs=input_values_cutoffs,
+ generation_config=generation_config,
+ logits_processor=logits_processor,
+ stopping_criteria=stopping_criteria,
+ synced_gpus=synced_gpus,
+ streamer=streamer,
+ **kwargs,
+ )
+
+ generate_returned_dict = not isinstance(generate_output, torch.Tensor)
+ audio = None
+ if output_audio:
+ generated_audio_codes = generate_output.sequences if generate_returned_dict else generate_output
+
+ # infer the codec model
+ audio = []
+ with torch.no_grad():
+ # =======================================
+ # TODO: @eustlb, this should be batched !!!
+ # but requires making sure batched inference of the codec model works as intended
+ for audio_codes_batch in generated_audio_codes:
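+ # truncate at the first frame where every codebook is the eos token, then decode with the codec model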
+ eos_idxs = (audio_codes_batch == self.config.codebook_eos_token_id).all(dim=-1).nonzero()
+ if eos_idxs.numel() != 0:
+ cutoff_idx = eos_idxs.min()
+ else:
+ cutoff_idx = audio_codes_batch.shape[0]
+
+ audio_codes_batch = audio_codes_batch[:cutoff_idx]
+ codec_decode_output = self.codec_model.decode(audio_codes_batch.transpose(0, 1).unsqueeze(0))
+ audio.append(codec_decode_output.audio_values[0, 0])
+ # =======================================
+
+ if generate_returned_dict:
+ return CsmGenerateOutput(audio=audio, **generate_output)
+ elif output_audio:
+ return audio
+ else:
+ return generate_output
diff --git a/venv/lib/python3.13/site-packages/transformers/models/csm/modeling_csm.py b/venv/lib/python3.13/site-packages/transformers/models/csm/modeling_csm.py
new file mode 100644
index 0000000000000000000000000000000000000000..18b80b1ef12d1b28a3ba381815bf406b4a649def
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/csm/modeling_csm.py
@@ -0,0 +1,1088 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/csm/modular_csm.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_csm.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2025 Sesame and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import Callable, Optional, Union
+
+import torch
+import torch.nn as nn
+
+from transformers.utils.generic import check_model_inputs
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache
+from ...generation import GenerationMixin
+from ...integrations import use_kernel_forward_from_hub
+from ...masking_utils import create_causal_mask
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, logging
+from ...utils.deprecation import deprecate_kwarg
+from ..auto import AutoModel
+from .configuration_csm import CsmConfig, CsmDepthDecoderConfig
+from .generation_csm import CsmGenerationMixin
+
+
+logger = logging.get_logger(__name__)
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for the model autoregressive outputs.
+ """
+)
+class CsmOutputWithPast(ModelOutput):
+ r"""
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Language modeling loss (for next-token prediction).
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ depth_decoder_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Language modeling loss (for next-token prediction) of the depth decoder model.
+ depth_decoder_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+ Prediction scores of the depth decoder (scores for each vocabulary token before SoftMax).
+ depth_decoder_past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+ depth_decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ depth_decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+ backbone_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Language modeling loss (for next-token prediction) of the backbone model.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: Optional[torch.FloatTensor] = None
+ past_key_values: Optional[Cache] = None
+ hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+ depth_decoder_loss: Optional[torch.FloatTensor] = None
+ depth_decoder_logits: Optional[torch.FloatTensor] = None
+ depth_decoder_past_key_values: Optional[Cache] = None
+ depth_decoder_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+ depth_decoder_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+ backbone_loss: Optional[torch.FloatTensor] = None
+
+
+@use_kernel_forward_from_hub("RMSNorm")
+class CsmRMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ CsmRMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
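+ # RMSNorm: y = weight * x / sqrt(mean(x^2) + eps), computed in float32 and cast back to the input dtype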
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+ def extra_repr(self):
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+
+class CsmRotaryEmbedding(nn.Module):
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
+
+ def __init__(self, config: CsmConfig, device=None):
+ super().__init__()
+ # BC: "rope_type" was originally "type"
+ if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+ else:
+ self.rope_type = "default"
+ self.max_seq_len_cached = config.max_position_embeddings
+ self.original_max_seq_len = config.max_position_embeddings
+
+ self.config = config
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+ self.original_inv_freq = self.inv_freq
+
+ @torch.no_grad()
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
+ def forward(self, x, position_ids):
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
+ position_ids_expanded = position_ids[:, None, :].float()
+
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
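+ # (batch, head_dim/2, 1) @ (batch, 1, seq_len) -> (batch, head_dim/2, seq_len), transposed to
+ # (batch, seq_len, head_dim/2) and duplicated so cos/sin cover the full head dimension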
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos = emb.cos() * self.attention_scaling
+ sin = emb.sin() * self.attention_scaling
+
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+class CsmMLP(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.intermediate_size = config.intermediate_size
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
+ self.act_fn = ACT2FN[config.hidden_act]
+
+ def forward(self, x):
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+ return down_proj
+
+
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+ """Applies Rotary Position Embedding to the query and key tensors.
+
+ Args:
+ q (`torch.Tensor`): The query tensor.
+ k (`torch.Tensor`): The key tensor.
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
+ position_ids (`torch.Tensor`, *optional*):
+ Deprecated and unused.
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+ Returns:
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+ """
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+ """
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: float,
+ dropout: float = 0.0,
+ **kwargs: Unpack[TransformersKwargs],
+):
+ key_states = repeat_kv(key, module.num_key_value_groups)
+ value_states = repeat_kv(value, module.num_key_value_groups)
+
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+ if attention_mask is not None:
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
+class CsmAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config: CsmConfig, layer_idx: int):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
+ self.scaling = self.head_dim**-0.5
+ self.attention_dropout = config.attention_dropout
+ self.is_causal = True
+
+ self.q_proj = nn.Linear(
+ config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.k_proj = nn.Linear(
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.v_proj = nn.Linear(
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.o_proj = nn.Linear(
+ config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
+ )
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
+ attention_mask: Optional[torch.Tensor],
+ past_key_values: Optional[Cache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ input_shape = hidden_states.shape[:-1]
+ hidden_shape = (*input_shape, -1, self.head_dim)
+
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_values is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ scaling=self.scaling,
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+ return attn_output, attn_weights
+
+
+class CsmDecoderLayer(GradientCheckpointingLayer):
+ def __init__(self, config: CsmConfig, layer_idx: int):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+
+ self.self_attn = CsmAttention(config=config, layer_idx=layer_idx)
+
+ self.mlp = CsmMLP(config)
+ self.input_layernorm = CsmRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_attention_layernorm = CsmRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> torch.Tensor:
+ residual = hidden_states
+ hidden_states = self.input_layernorm(hidden_states)
+ # Self Attention
+ hidden_states, _ = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+ hidden_states = residual + hidden_states
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+ return hidden_states
+
+
+@auto_docstring(
+ custom_intro="""
+ The bare Csm Model outputting raw hidden-states without any specific head on top.
+ """
+)
+class CsmPreTrainedModel(PreTrainedModel):
+ config: CsmConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["CsmDecoderLayer"]
+ _skip_keys_device_placement = ["past_key_values"]
+ _supports_flash_attn = True
+ _supports_sdpa = True
+ # flex attention is not supported because of the Mimi codec model
+ # _supports_flex_attn = True
+
+ _can_compile_fullgraph = True
+ _supports_attention_backend = True
+ _can_record_outputs = {
+ "hidden_states": CsmDecoderLayer,
+ "attentions": CsmAttention,
+ }
+
+ def _init_weights(self, module):
+ super()._init_weights(module)
+ if isinstance(module, CsmCodebooksHead):
+ num_codebooks = module.num_codebooks
+ for i in range(num_codebooks - 1):
+ module.weight.data[i].normal_(mean=0.0, std=self.config.initializer_range)
+
+
+@auto_docstring
+class CsmDepthDecoderModel(CsmPreTrainedModel):
+ config: CsmDepthDecoderConfig
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+ self.embed_tokens = nn.Embedding((config.num_codebooks * config.vocab_size), config.backbone_hidden_size)
+ self.layers = nn.ModuleList(
+ [CsmDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+ self.norm = CsmRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.rotary_emb = CsmRotaryEmbedding(config=config)
+ self.gradient_checkpointing = False
+ self.inputs_embeds_projector = nn.Linear(config.backbone_hidden_size, config.hidden_size, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @check_model_inputs()
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ backbone_last_hidden_state: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> Union[tuple, BaseModelOutputWithPast]:
+ r"""
+ backbone_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, backbone_hidden_size)`, *optional*):
+ The last hidden state of the backbone model. Such input is required when the first codebook token (the one generated by the backbone model)
+ is provided in the `input_ids` argument.
+ """
+ if position_ids is not None and not torch.compiler.is_compiling():
+ logger.warning_once(
+ "Custom `position_ids` were provided but will be ignored. CSM depth decoder automatically determines position_ids "
+ "from `cache_position` and as it requires them to be identical across the batch, the provided position_ids will be ignored."
+ )
+ position_ids = None
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds.")
+
+ if use_cache and past_key_values is None:
+ past_key_values = DynamicCache(config=self.config)
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ inputs_seq_length = inputs_embeds.shape[1] if inputs_embeds is not None else input_ids.shape[1]
+ device = inputs_embeds.device if inputs_embeds is not None else input_ids.device
+ cache_position = torch.arange(past_seen_tokens, past_seen_tokens + inputs_seq_length, device=device)
+
+ if inputs_embeds is None:
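+ # each codebook owns a contiguous slice of size `vocab_size` in the shared embedding table,
+ # so token `t` of codebook `c` is looked up at row `c * vocab_size + t`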
+ codebook_idxs = torch.clamp(cache_position - 1, min=0)
+ offset = codebook_idxs * self.vocab_size
+ inputs_embeds = self.embed_tokens(input_ids + offset)
+
+ input_ids_are_first_codebook = cache_position[0] == 0
+ if backbone_last_hidden_state is not None:
+ inputs_embeds[:, 0] = backbone_last_hidden_state
+ else:
+ if not torch.compiler.is_compiling() and input_ids_are_first_codebook:
+ logger.warning(
+ "When the first codebook token is provided, `backbone_last_hidden_state` should also be provided for correct inference."
+ )
+
+ inputs_embeds = self.inputs_embeds_projector(inputs_embeds)
+
+ causal_mask = create_causal_mask(
+ config=self.config,
+ input_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ cache_position=cache_position,
+ past_key_values=past_key_values,
+ position_ids=position_ids,
+ )
+
+ hidden_states = inputs_embeds
+
+ # create position embeddings to be shared across the decoder layers
+ position_ids = cache_position.unsqueeze(0)
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+ for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+ hidden_states = decoder_layer(
+ hidden_states,
+ attention_mask=causal_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+
+ hidden_states = self.norm(hidden_states)
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values if use_cache else None,
+ )
+
+
+class CsmCodebooksHead(nn.Module):
+ def __init__(self, hidden_size, num_codebooks, vocab_size):
+ super().__init__()
+ self.num_codebooks = num_codebooks
+ self.weight = nn.Parameter(torch.empty(self.num_codebooks - 1, hidden_size, vocab_size))
+
+ def forward(self, hidden_states, cache_position=None):
+ if cache_position is None:
+ seq_length = hidden_states.shape[1]
+ codebook_weight = self.weight[torch.arange(seq_length)]
+ else:
+ codebook_idxs = cache_position - 1
+ codebook_weight = self.weight[codebook_idxs]
+
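+ # Apply a separate linear projection (hidden_size -> vocab_size) per position, using the weight selected
+ # above for its codebook, then stack the per-codebook logits along dim 1.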
+ hidden_states = [
+ nn.functional.linear(hidden_states[:, codebook_idx, :], codebook_weight[codebook_idx].T)
+ for codebook_idx in range(codebook_weight.shape[0])
+ ]
+ hidden_states = torch.stack(hidden_states, dim=1)
+
+ return hidden_states
+
+
+@auto_docstring(
+ custom_intro="""
+ The CsmDepthDecoder Model transformer, with a [`CsmCodebooksHead`] on top,
+ which can be seen as a position-specific language modeling head, allowing a different linear layer to be used for each codebook
+ (e.g. position 0 is the first codebook and uses the first codebook head, etc.)
+ """
+)
+class CsmDepthDecoderForCausalLM(CsmPreTrainedModel, GenerationMixin):
+ _tied_weights_keys = None
+ _tp_plan = None
+ _pp_plan = None
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = CsmDepthDecoderModel(config)
+ self.vocab_size = config.vocab_size
+ self.codebooks_head = CsmCodebooksHead(config.hidden_size, config.num_codebooks, config.vocab_size)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ backbone_last_hidden_state: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> Union[tuple, CausalLMOutputWithPast]:
+ r"""
+ backbone_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, backbone_hidden_size)`, *optional*):
+ The last hidden state of the backbone model. Such input is required when the first codebook token (the one generated by the backbone model)
+ is provided in the `input_ids` argument.
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+ """
+ outputs = self.model(
+ input_ids=input_ids,
+ backbone_last_hidden_state=backbone_last_hidden_state,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ hidden_states = outputs[0]
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ if isinstance(logits_to_keep, int):
+ if logits_to_keep == 0:
+ # skip idx 0 logits since it's for the concatenated backbone last hidden state
+ slice_indices = slice(1, None)
+ else:
+ slice_indices = slice(-logits_to_keep, None)
+ else:
+ slice_indices = logits_to_keep
+
+ logits = self.codebooks_head(
+ hidden_states[:, slice_indices, :], cache_position[slice_indices] if cache_position is not None else None
+ )
+ logits = logits.contiguous()
+
+ loss = None
+ if labels is not None:
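+ # Drop the first codebook label: the depth decoder only predicts codebooks 1..num_codebooks - 1
+ # (codebook 0 is predicted by the backbone), and the logits above already start at the matching position.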
+ shift_labels = labels[..., 1:].contiguous()
+ loss = self.loss_function(
+ logits=logits, labels=None, vocab_size=self.config.vocab_size, shift_labels=shift_labels, **kwargs
+ )
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids: torch.LongTensor,
+ past_key_values: Optional[Cache] = None,
+ attention_mask: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs,
+ ):
+ model_inputs = super().prepare_inputs_for_generation(
+ input_ids, past_key_values, attention_mask, inputs_embeds, cache_position, **kwargs
+ )
+
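+ # `backbone_last_hidden_state` is only needed on the first depth-decoder step (cache_position starts at 0),
+ # where it fills position 0; it is dropped from the inputs on subsequent steps.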
+ is_first_generation_step = model_inputs["cache_position"][0] == 0
+ if not is_first_generation_step:
+ model_inputs.pop("backbone_last_hidden_state")
+
+ # csm depth decoder does not use position_ids
+ model_inputs.pop("position_ids")
+
+ return model_inputs
+
+
+class CsmBackboneModelEmbeddings(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.embed_audio_tokens = nn.Embedding((config.num_codebooks * config.vocab_size), config.hidden_size)
+ self.register_buffer(
+ "audio_tokens_offsets", torch.arange(config.num_codebooks) * config.vocab_size, persistent=False
+ )
+
+ def forward(self, input_ids):
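+ # input_ids has shape (batch_size, seq_length, num_codebooks): each codebook is offset into its own slice
+ # of the flat embedding table, and the per-codebook embeddings are summed into a single embedding per frame.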
+ input_embeds = self.embed_audio_tokens(input_ids + self.audio_tokens_offsets)
+ input_embeds = input_embeds.sum(dim=2)
+ return input_embeds
+
+
+@auto_docstring
+class CsmBackboneModel(CsmPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+ self.embed_tokens = CsmBackboneModelEmbeddings(config)
+ self.layers = nn.ModuleList(
+ [CsmDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+ self.norm = CsmRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.rotary_emb = CsmRotaryEmbedding(config=config)
+ self.gradient_checkpointing = False
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @check_model_inputs()
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> BaseModelOutputWithPast:
+ r"""
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks) or (batch_size, sequence_length)`):
+ 1. (batch_size, sequence_length): corresponds to the input sequence prepared with the processor from the text prompt. Such input
+ requires `input_values` to be provided so that audio can be encoded in codebook tokens and then merged with the text tokens.
+
+ 2. (batch_size, sequence_length, num_codebooks): codebook tokens generated during the autoregressive decoding. Such input is not meant to be used by end users.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ """
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if inputs_embeds is None:
+ inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)
+
+ if use_cache and past_key_values is None:
+ past_key_values = DynamicCache(config=self.config)
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position: torch.Tensor = torch.arange(
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+ )
+
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0)
+
+ causal_mask = create_causal_mask(
+ config=self.config,
+ input_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ cache_position=cache_position,
+ past_key_values=past_key_values,
+ position_ids=position_ids,
+ )
+
+ hidden_states = inputs_embeds
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+ for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+ hidden_states = decoder_layer(
+ hidden_states,
+ attention_mask=causal_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+
+ hidden_states = self.norm(hidden_states)
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ The Csm model consists of two llama-like auto-regressive transformer models: a backbone model that predicts the first codebook token and a depth decoder that predicts the other codebook tokens.
+ """
+)
+class CsmForConditionalGeneration(CsmPreTrainedModel, CsmGenerationMixin):
+ _tied_weights_keys = [
+ "backbone_model.embed_tokens.embed_audio_tokens.weight",
+ "depth_decoder.model.embed_tokens.weight",
+ ]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.vocab_size = config.vocab_size
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+ self.embed_text_tokens = nn.Embedding(config.text_vocab_size, config.hidden_size)
+ self.backbone_model = CsmBackboneModel._from_config(config)
+ self.depth_decoder = CsmDepthDecoderForCausalLM._from_config(config.depth_decoder_config)
+ self.codec_model = AutoModel.from_config(config.codec_config)
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.backbone_model.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.backbone_model.embed_tokens = value
+
+ def _tie_weights(self):
+ if self.config.tie_codebooks_embeddings:
+ self._tie_or_clone_weights(
+ self.backbone_model.embed_tokens.embed_audio_tokens,
+ self.depth_decoder.model.embed_tokens,
+ )
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ if kwargs.get("output_loading_info", False):
+ model, loading_info = super().from_pretrained(*args, **kwargs)
+ else:
+ model = super().from_pretrained(*args, **kwargs)
+
+ # copy depth decoder generation conf attr to the depth decoder generation config
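+ # e.g. an attribute named `depth_decoder_temperature` on the model generation config (illustrative name)
+ # becomes `temperature` on `model.depth_decoder.generation_config`.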
+ prefix = "depth_decoder_"
+ prefix_len = len(prefix)
+ depth_decoder_attrs = {
+ attr[prefix_len:]: value
+ for attr, value in vars(model.generation_config).items()
+ if attr.startswith(prefix)
+ }
+
+ vars(model.depth_decoder.generation_config).update({"_from_model_config": False, **depth_decoder_attrs})
+
+ # remove the depth decoder generation conf attr from the model generation config
+ for attr in depth_decoder_attrs:
+ delattr(model.generation_config, prefix + attr)
+
+ if "output_loading_info" in kwargs:
+ return model, loading_info
+ else:
+ return model
+
+ def save_pretrained(self, *args, **kwargs):
+ # copy the depth decoder generation config attributes to the model generation config
+ prefix = "depth_decoder_"
+ depth_decoder_attrs = self.depth_decoder.generation_config.to_diff_dict()
+ depth_decoder_attrs.pop("transformers_version", None)
+ for attr, value in depth_decoder_attrs.items():
+ setattr(self.generation_config, prefix + attr, value)
+
+ super().save_pretrained(*args, **kwargs)
+
+ def _merge_input_ids_with_input_values(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ input_values: Optional[torch.Tensor] = None,
+ input_values_cutoffs: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ ) -> Optional[torch.Tensor]:
+ """
+ Merges the input_ids and input_values to produce a single inputs_embeds tensor:
+ 1 - Runs the codec model on the input_values to retrieve codebook tokens.
+ 2 - Embeds the codebook tokens and places them at the correct positions in the inputs_embeds tensor.
+ 3 - If labels are provided, expands them to match the codebook dimensions and places the target codebook tokens at the corresponding positions in the labels tensor.
+
+ Args:
+ input_ids (`torch.Tensor` of shape `(batch_size, sequence_length)`):
+ The input ids to embed.
+ input_values (`torch.Tensor` of shape `(batch_size, channels, audio_sequence_length)`):
+ The audio input values to embed.
+ input_values_cutoffs (`torch.Tensor` of shape `(batch_size, max_num_audio)`):
+ The cutoffs of the audio input values relative to their batch entry, padded with -1 when there is no audio.
+ """
+ inputs_embeds = self.embed_text_tokens(input_ids)
+
+ if input_values is not None:
+ # infer input_values_mask
+ input_values_cutoffs = nn.functional.pad(input_values_cutoffs, (1, 0))
+ audio_lengths = input_values_cutoffs[input_values_cutoffs >= 0].diff()
+ audio_lengths = audio_lengths[audio_lengths > 0]
+ input_values_mask = torch.arange(input_values_cutoffs.max(), device=input_values.device).expand(
+ len(audio_lengths), -1
+ )
+ input_values_mask = input_values_mask < audio_lengths.unsqueeze(1)
+
+ # =======================================
+ # TODO: @eustlb, this should be batched !!!
+ # but requires making sure batched inference of the codec model works as intended
+ with torch.no_grad():
+ audio_tokens_list = []
+ for batch_input_values, batch_input_values_cutoffs in zip(input_values, input_values_cutoffs):
+ batch_input_values_cutoffs = batch_input_values_cutoffs[batch_input_values_cutoffs >= 0]
+ for i in range(batch_input_values_cutoffs.shape[0] - 1):
+ start_idx = batch_input_values_cutoffs[i]
+ end_idx = batch_input_values_cutoffs[i + 1]
+ audio_batch = batch_input_values[..., start_idx:end_idx]
+ codec_outputs = self.codec_model.encode(audio_batch.unsqueeze(0))
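+ # audio_codes has shape (1, num_codebooks, num_frames); transpose it to (1, num_frames, num_codebooks).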
+ codebook_ids = codec_outputs.audio_codes.transpose(1, -1)
+ audio_tokens_list.append(codebook_ids[0])
+
+ max_audio_frames = max(el.shape[0] for el in audio_tokens_list)
+ batched_audio_token_ids = torch.stack(
+ [nn.functional.pad(el, (0, 0, 0, max_audio_frames - el.shape[0])) for el in audio_tokens_list]
+ )
+ audio_codes_mask = self.codec_model.get_audio_codes_mask(input_values_mask)
+ # =======================================
+ audio_token_id = self.config.audio_token_id
+ audio_token_mask = input_ids == audio_token_id
+
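+ # Scatter the audio-frame embeddings into the text sequence: each audio placeholder token is replaced
+ # by the embedding of its corresponding frame (audio_codes_mask selects the valid, non-padded frames).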
+ audio_embeds = self.backbone_model.embed_tokens(batched_audio_token_ids)
+ inputs_embeds[audio_token_mask] = audio_embeds[audio_codes_mask]
+
+ # same for the audio eos token
+ audio_eos_frame_ids = (
+ torch.ones((1, 1, self.config.num_codebooks), device=input_ids.device, dtype=torch.long)
+ * self.config.codebook_eos_token_id
+ )
+ audio_eos_embeds = self.backbone_model.embed_tokens(audio_eos_frame_ids).squeeze(1)
+
+ audio_eos_token_mask = input_ids == self.config.audio_eos_token_id
+ inputs_embeds[audio_eos_token_mask] = audio_eos_embeds.repeat(audio_eos_token_mask.sum(), 1)
+
+ # if the labels are provided, we need to expand the labels to (batch_size, seq_length, num_codebooks)
+ if labels is not None:
+ labels_expanded = labels.unsqueeze(-1).repeat(1, 1, self.config.num_codebooks)
+ labels_expanded[audio_token_mask] = batched_audio_token_ids[audio_codes_mask]
+ labels_expanded[audio_eos_token_mask] = audio_eos_frame_ids
+ # mask depth decoder
+ depth_decoder_ignore_frames_idxs = (labels == -101).nonzero(as_tuple=True)
+ labels_expanded[depth_decoder_ignore_frames_idxs[0], depth_decoder_ignore_frames_idxs[1], 1:] = -100
+ labels = labels_expanded
+
+ return {"inputs_embeds": inputs_embeds, "labels": labels}
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids: torch.LongTensor,
+ past_key_values: Optional[Cache] = None,
+ attention_mask: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs,
+ ):
+ model_inputs = super().prepare_inputs_for_generation(
+ input_ids=input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ inputs_embeds=inputs_embeds,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ if input_ids is not None and input_ids.ndim == 2 and model_inputs.get("inputs_embeds") is None:
+ merged_inputs = self._merge_input_ids_with_input_values(
+ input_ids=input_ids,
+ input_values=kwargs.get("input_values"),
+ input_values_cutoffs=kwargs.get("input_values_cutoffs"),
+ labels=kwargs.get("labels"),
+ )
+ model_inputs.update(
+ {"inputs_embeds": merged_inputs["inputs_embeds"], "labels": merged_inputs["labels"], "input_ids": None}
+ )
+
+ return model_inputs
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ input_values: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ input_values_cutoffs: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> Union[tuple, CsmOutputWithPast]:
+ r"""
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks) or (batch_size, sequence_length)`):
+ 1. (batch_size, sequence_length): corresponds to the input sequence prepared with the processor from the text prompt. Such input
+ requires `input_values` to be provided so that audio can be encoded in codebook tokens and then merged with the text tokens.
+
+ 2. (batch_size, sequence_length, num_codebooks): codebook tokens generated during the autoregressive decoding. Such input is not meant to be used by end users.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ input_values_cutoffs (`torch.Tensor` of shape `(batch_size, max_num_audio)`, *optional*):
+ Specify the end positions of audio segments within each batch entry, relative to the concatenated audio input.
+ If a batch entry has fewer segments than the maximum, it is padded with -1. For example, in a batch of 2 sequences
+ where the first contains 2 audio segments of length l1, and the second contains 1 audio segment of length l2,
+ the input_values_cutoffs would be: [[l1, 2 * l1], [l2, -1]].
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should be in `[config.audio_token_id, -100, -101]`.
+ Requires targeted `input_values` to be provided, as audio tokens will be inferred from them using the `codec_model`.
+ - `config.audio_token_id` indicates an audio frame (considering sequence length elements as frames)
+ - `-100` will be ignored in the loss computation
+ - `-101` indicates the audio frame will be used only for the backbone model (using the first codebook token as labels)
+
+ Such labels can be prepared using `output_labels=True` when calling [`CsmProcessor`].
+ logits_to_keep (`int` or `torch.Tensor`, *optional*):
+ Kept for compatibility. Only the following values are supported:
+ 1. `0`, which is equivalent to keeping all logits, used in the training regime
+ 2. `1`, which is equivalent to keeping only the last logit, used in the generation regime
+
+ Example:
+
+ ```python
+ >>> import torch
+ >>> from transformers import CsmForConditionalGeneration, AutoProcessor
+ >>> from datasets import load_dataset, Audio
+
+ >>> model_id = "sesame/csm-1b"
+ >>> torch_device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ >>> processor = AutoProcessor.from_pretrained(model_id)
+
+ >>> ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train")
+ >>> # ensure the audio is 24kHz
+ >>> ds = ds.cast_column("audio", Audio(sampling_rate=24000))
+
+ >>> conversation = []
+ >>> # prepare a conversation with text and corresponding audio
+ >>> for text, audio, speaker_id in zip(ds[:4]["text"], ds[:4]["audio"], ds[:4]["speaker_id"]):
+ ... conversation.append(
+ ... {
+ ... "role": f"{speaker_id}",
+ ... "content": [{"type": "text", "text": text}, {"type": "audio", "path": audio["array"]}],
+ ... }
+ ... )
+
+ >>> inputs = processor.apply_chat_template(
+ ... conversation,
+ ... tokenize=True,
+ ... return_dict=True,
+ ... output_labels=True,
+ ... ).to(torch_device)
+
+ >>> model = CsmForConditionalGeneration.from_pretrained(model_id, device_map=torch_device)
+ >>> output = model(**inputs)
+ >>> output.loss.backward()
+ ```"""
+ if input_ids is not None and input_ids.ndim == 2:
+ merged_inputs = self._merge_input_ids_with_input_values(
+ input_ids, input_values, input_values_cutoffs, labels
+ )
+ inputs_embeds = merged_inputs["inputs_embeds"]
+ labels = merged_inputs["labels"]
+ input_ids = None
+
+ backbone_outputs = self.backbone_model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ backbone_hidden_states = backbone_outputs[0]
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
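+ # Note: logits_to_keep == 0 gives slice(0, None), i.e. all positions (training); logits_to_keep == 1 keeps only the last position (generation).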
+ backbone_logits = self.lm_head(backbone_hidden_states[:, slice_indices, :])
+
+ loss = None
+ backbone_loss = None
+ depth_decoder_loss = None
+ depth_decoder_outputs = None
+ if labels is not None:
+ # select first codebook as labels for the backbone model
+ backbone_labels = labels[:, :, 0]
+ backbone_loss = self.loss_function(
+ logits=backbone_logits, labels=backbone_labels, vocab_size=self.config.vocab_size, **kwargs
+ )
+
+ # for the depth decoder, we need to select the frames to train on
+ # those are frames where the label is not uniformly `ignore_index` along the codebook dimension
+ train_mask = ~(labels[:, :, 1:] == -100).all(dim=-1)
+ depth_decoder_input_ids = labels[train_mask][..., : self.config.num_codebooks - 1]
+ # add a placeholder at position 0 that will be replaced by the backbone_last_hidden_state
+ depth_decoder_input_ids = nn.functional.pad(depth_decoder_input_ids, (1, 0), value=0)
+
+ train_idxs = train_mask.nonzero(as_tuple=True)
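+ # The backbone hidden state at position t - 1 is the one from which frame t was predicted,
+ # so it is used to condition the depth decoder for frame t.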
+ backbone_last_hidden_states = backbone_hidden_states[train_idxs[0], train_idxs[1] - 1, :]
+ depth_decoder_labels = labels[train_mask]
+
+ depth_decoder_outputs = self.depth_decoder(
+ input_ids=depth_decoder_input_ids,
+ backbone_last_hidden_state=backbone_last_hidden_states,
+ use_cache=use_cache,
+ return_dict=True,
+ labels=depth_decoder_labels,
+ **kwargs,
+ )
+
+ depth_decoder_loss = depth_decoder_outputs.loss
+ loss = backbone_loss + depth_decoder_loss
+
+ return CsmOutputWithPast(
+ loss=loss,
+ backbone_loss=backbone_loss,
+ depth_decoder_loss=depth_decoder_loss,
+ logits=backbone_logits,
+ past_key_values=backbone_outputs.past_key_values,
+ hidden_states=backbone_outputs.hidden_states,
+ attentions=backbone_outputs.attentions,
+ depth_decoder_logits=depth_decoder_outputs.logits if depth_decoder_outputs is not None else None,
+ depth_decoder_past_key_values=depth_decoder_outputs.past_key_values
+ if depth_decoder_outputs is not None
+ else None,
+ depth_decoder_hidden_states=depth_decoder_outputs.hidden_states
+ if depth_decoder_outputs is not None
+ else None,
+ depth_decoder_attentions=depth_decoder_outputs.attentions if depth_decoder_outputs is not None else None,
+ )
+
+
+__all__ = [
+ "CsmPreTrainedModel",
+ "CsmBackboneModel",
+ "CsmDepthDecoderModel",
+ "CsmDepthDecoderForCausalLM",
+ "CsmForConditionalGeneration",
+]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/csm/modular_csm.py b/venv/lib/python3.13/site-packages/transformers/models/csm/modular_csm.py
new file mode 100644
index 0000000000000000000000000000000000000000..89a6e52a063b4a491f1c2df2a06648d35ea39789
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/csm/modular_csm.py
@@ -0,0 +1,766 @@
+# coding=utf-8
+# Copyright 2025 Sesame and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import Optional, Union
+
+import torch
+import torch.nn as nn
+
+from transformers.utils.generic import check_model_inputs
+
+from ...cache_utils import Cache, DynamicCache
+from ...generation import GenerationMixin
+from ...masking_utils import create_causal_mask
+from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from ...modeling_utils import PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import ModelOutput, auto_docstring, can_return_tuple, logging
+from ..auto import AutoModel
+from ..llama.modeling_llama import (
+ LlamaAttention,
+ LlamaDecoderLayer,
+ LlamaForCausalLM,
+ LlamaMLP,
+ LlamaModel,
+ LlamaRMSNorm,
+ LlamaRotaryEmbedding,
+ TransformersKwargs,
+)
+from .configuration_csm import CsmConfig, CsmDepthDecoderConfig
+from .generation_csm import CsmGenerationMixin
+
+
+logger = logging.get_logger(__name__)
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for the model's autoregressive outputs.
+ """
+)
+class CsmOutputWithPast(ModelOutput):
+ r"""
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Language modeling loss (for next-token prediction).
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ depth_decoder_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Language modeling loss (for next-token prediction) of the depth decoder model.
+ depth_decoder_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+ Prediction scores of the depth decoder (scores for each vocabulary token before SoftMax).
+ depth_decoder_past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+ depth_decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ depth_decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+ backbone_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Language modeling loss (for next-token prediction) of the backbone model.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: Optional[torch.FloatTensor] = None
+ past_key_values: Optional[Cache] = None
+ hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+ depth_decoder_loss: Optional[torch.FloatTensor] = None
+ depth_decoder_logits: Optional[torch.FloatTensor] = None
+ depth_decoder_past_key_values: Optional[Cache] = None
+ depth_decoder_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+ depth_decoder_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+ backbone_loss: Optional[torch.FloatTensor] = None
+
+
+# manually specify names for correct naming when converting from modular
+class CsmRMSNorm(LlamaRMSNorm):
+ pass
+
+
+class CsmRotaryEmbedding(LlamaRotaryEmbedding):
+ pass
+
+
+class CsmMLP(LlamaMLP):
+ pass
+
+
+class CsmAttention(LlamaAttention):
+ pass
+
+
+class CsmDecoderLayer(LlamaDecoderLayer):
+ pass
+
+
+@auto_docstring(
+ custom_intro="""
+ The bare Csm Model outputting raw hidden-states without any specific head on top.
+ """
+)
+class CsmPreTrainedModel(PreTrainedModel):
+ config: CsmConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["CsmDecoderLayer"]
+ _skip_keys_device_placement = ["past_key_values"]
+ _supports_flash_attn = True
+ _supports_sdpa = True
+ # flex attention is not supported because of the Mimi codec model
+ # _supports_flex_attn = True
+
+ _can_compile_fullgraph = True
+ _supports_attention_backend = True
+ _can_record_outputs = {
+ "hidden_states": CsmDecoderLayer,
+ "attentions": CsmAttention,
+ }
+
+ def _init_weights(self, module):
+ super()._init_weights(module)
+ if isinstance(module, CsmCodebooksHead):
+ num_codebooks = module.num_codebooks
+ for i in range(num_codebooks - 1):
+ module.weight.data[i].normal_(mean=0.0, std=self.config.initializer_range)
+
+
+@auto_docstring
+class CsmDepthDecoderModel(LlamaModel, CsmPreTrainedModel):
+ config: CsmDepthDecoderConfig
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.embed_tokens = nn.Embedding((config.num_codebooks * config.vocab_size), config.backbone_hidden_size)
+ self.inputs_embeds_projector = nn.Linear(config.backbone_hidden_size, config.hidden_size, bias=False)
+
+ @check_model_inputs()
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ backbone_last_hidden_state: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> Union[tuple, BaseModelOutputWithPast]:
+ r"""
+ backbone_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, backbone_hidden_size)`, *optional*):
+ The last hidden state of the backbone model. Such input is required when the first codebook token (the one generated by the backbone model)
+ is provided in the `input_ids` argument.
+ """
+ if position_ids is not None and not torch.compiler.is_compiling():
+ logger.warning_once(
+ "Custom `position_ids` were provided but will be ignored. CSM depth decoder automatically determines position_ids "
+ "from `cache_position` and as it requires them to be identical across the batch, the provided position_ids will be ignored."
+ )
+ position_ids = None
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds.")
+
+ if use_cache and past_key_values is None:
+ past_key_values = DynamicCache(config=self.config)
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ inputs_seq_length = inputs_embeds.shape[1] if inputs_embeds is not None else input_ids.shape[1]
+ device = inputs_embeds.device if inputs_embeds is not None else input_ids.device
+ cache_position = torch.arange(past_seen_tokens, past_seen_tokens + inputs_seq_length, device=device)
+
+ if inputs_embeds is None:
+ codebook_idxs = torch.clamp(cache_position - 1, min=0)
+ offset = codebook_idxs * self.vocab_size
+ inputs_embeds = self.embed_tokens(input_ids + offset)
+
+ input_ids_are_first_codebook = cache_position[0] == 0
+ if backbone_last_hidden_state is not None:
+ inputs_embeds[:, 0] = backbone_last_hidden_state
+ else:
+ if not torch.compiler.is_compiling() and input_ids_are_first_codebook:
+ logger.warning(
+ "When the first codebook token is provided, `backbone_last_hidden_state` should also be provided for correct inference."
+ )
+
+ inputs_embeds = self.inputs_embeds_projector(inputs_embeds)
+
+ causal_mask = create_causal_mask(
+ config=self.config,
+ input_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ cache_position=cache_position,
+ past_key_values=past_key_values,
+ position_ids=position_ids,
+ )
+
+ hidden_states = inputs_embeds
+
+ # create position embeddings to be shared across the decoder layers
+ position_ids = cache_position.unsqueeze(0)
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+ for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+ hidden_states = decoder_layer(
+ hidden_states,
+ attention_mask=causal_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+
+ hidden_states = self.norm(hidden_states)
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values if use_cache else None,
+ )
+
+
+class CsmCodebooksHead(nn.Module):
+ def __init__(self, hidden_size, num_codebooks, vocab_size):
+ super().__init__()
+ self.num_codebooks = num_codebooks
+ self.weight = nn.Parameter(torch.empty(self.num_codebooks - 1, hidden_size, vocab_size))
+
+ def forward(self, hidden_states, cache_position=None):
+ if cache_position is None:
+ seq_length = hidden_states.shape[1]
+ codebook_weight = self.weight[torch.arange(seq_length)]
+ else:
+ codebook_idxs = cache_position - 1
+ codebook_weight = self.weight[codebook_idxs]
+
+ hidden_states = [
+ nn.functional.linear(hidden_states[:, codebook_idx, :], codebook_weight[codebook_idx].T)
+ for codebook_idx in range(codebook_weight.shape[0])
+ ]
+ hidden_states = torch.stack(hidden_states, dim=1)
+
+ return hidden_states
+
+
+@auto_docstring(
+ custom_intro="""
+ The CsmDepthDecoder Model transformer, with a [`CsmCodebooksHead`] on top,
+ which can be seen as a position-specific language modeling head, allowing a different linear layer to be used for each codebook
+ (e.g. position 0 is the first codebook and uses the first codebook head, etc.)
+ """
+)
+class CsmDepthDecoderForCausalLM(LlamaForCausalLM, GenerationMixin):
+ _tied_weights_keys = None
+ _tp_plan = None
+ _pp_plan = None
+
+ def __init__(self, config):
+ super().__init__(config)
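+ # Llama's single lm_head is replaced by a per-codebook head: each position is projected with its own weight matrix.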
+ del self.lm_head
+ self.codebooks_head = CsmCodebooksHead(config.hidden_size, config.num_codebooks, config.vocab_size)
+ self.model = CsmDepthDecoderModel(config)
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids: torch.LongTensor,
+ past_key_values: Optional[Cache] = None,
+ attention_mask: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs,
+ ):
+ model_inputs = super().prepare_inputs_for_generation(
+ input_ids, past_key_values, attention_mask, inputs_embeds, cache_position, **kwargs
+ )
+
+ is_first_generation_step = model_inputs["cache_position"][0] == 0
+ if not is_first_generation_step:
+ model_inputs.pop("backbone_last_hidden_state")
+
+ # csm depth decoder does not use position_ids
+ model_inputs.pop("position_ids")
+
+ return model_inputs
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ backbone_last_hidden_state: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> Union[tuple, CausalLMOutputWithPast]:
+ r"""
+ backbone_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, backbone_hidden_size)`, *optional*):
+ The last hidden state of the backbone model. Such input is required when the first codebook token (the one generated by the backbone model)
+ is provided in the `input_ids` argument.
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+ """
+ outputs = self.model(
+ input_ids=input_ids,
+ backbone_last_hidden_state=backbone_last_hidden_state,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ hidden_states = outputs[0]
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ if isinstance(logits_to_keep, int):
+ if logits_to_keep == 0:
+ # skip idx 0 logits since it's for the concatenated backbone last hidden state
+ slice_indices = slice(1, None)
+ else:
+ slice_indices = slice(-logits_to_keep, None)
+ else:
+ slice_indices = logits_to_keep
+
+ logits = self.codebooks_head(
+ hidden_states[:, slice_indices, :], cache_position[slice_indices] if cache_position is not None else None
+ )
+ logits = logits.contiguous()
+
+ loss = None
+ if labels is not None:
+ shift_labels = labels[..., 1:].contiguous()
+ loss = self.loss_function(
+ logits=logits, labels=None, vocab_size=self.config.vocab_size, shift_labels=shift_labels, **kwargs
+ )
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+class CsmBackboneModelEmbeddings(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.embed_audio_tokens = nn.Embedding((config.num_codebooks * config.vocab_size), config.hidden_size)
+ self.register_buffer(
+ "audio_tokens_offsets", torch.arange(config.num_codebooks) * config.vocab_size, persistent=False
+ )
+
+ def forward(self, input_ids):
+ input_embeds = self.embed_audio_tokens(input_ids + self.audio_tokens_offsets)
+ input_embeds = input_embeds.sum(dim=2)
+ return input_embeds
+
+
+@auto_docstring
+class CsmBackboneModel(LlamaModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.embed_tokens = CsmBackboneModelEmbeddings(config)
+
+ @check_model_inputs()
+ @auto_docstring
+ def forward(self, **super_kwargs):
+ r"""
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks) or (batch_size, sequence_length)`):
+ 1. (batch_size, sequence_length): corresponds to the input sequence prepared with the processor from the text prompt. Such input
+ requires `input_values` to be provided so that audio can be encoded in codebook tokens and then merged with the text tokens.
+
+ 2. (batch_size, sequence_length, num_codebooks): codebook tokens generated during the autoregressive decoding. Such input is not meant to be used by end users.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ """
+ return super().forward(**super_kwargs)
+
+
+@auto_docstring(
+ custom_intro="""
+ The Csm model consists of two llama-like auto-regressive transformer models: a backbone model that predicts the first codebook token and a depth decoder that predicts the other codebook tokens.
+ """
+)
+class CsmForConditionalGeneration(CsmPreTrainedModel, CsmGenerationMixin):
+ _tied_weights_keys = [
+ "backbone_model.embed_tokens.embed_audio_tokens.weight",
+ "depth_decoder.model.embed_tokens.weight",
+ ]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.vocab_size = config.vocab_size
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+ self.embed_text_tokens = nn.Embedding(config.text_vocab_size, config.hidden_size)
+ self.backbone_model = CsmBackboneModel._from_config(config)
+ self.depth_decoder = CsmDepthDecoderForCausalLM._from_config(config.depth_decoder_config)
+ self.codec_model = AutoModel.from_config(config.codec_config)
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.backbone_model.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.backbone_model.embed_tokens = value
+
+ def _tie_weights(self):
+ if self.config.tie_codebooks_embeddings:
+ self._tie_or_clone_weights(
+ self.backbone_model.embed_tokens.embed_audio_tokens,
+ self.depth_decoder.model.embed_tokens,
+ )
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ if kwargs.get("output_loading_info", False):
+ model, loading_info = super().from_pretrained(*args, **kwargs)
+ else:
+ model = super().from_pretrained(*args, **kwargs)
+
+ # copy depth decoder generation conf attr to the depth decoder generation config
+ prefix = "depth_decoder_"
+ prefix_len = len(prefix)
+ depth_decoder_attrs = {
+ attr[prefix_len:]: value
+ for attr, value in vars(model.generation_config).items()
+ if attr.startswith(prefix)
+ }
+
+ vars(model.depth_decoder.generation_config).update({"_from_model_config": False, **depth_decoder_attrs})
+
+ # remove the depth decoder generation conf attr from the model generation config
+ for attr in depth_decoder_attrs:
+ delattr(model.generation_config, prefix + attr)
+
+ if "output_loading_info" in kwargs:
+ return model, loading_info
+ else:
+ return model
+
+ def save_pretrained(self, *args, **kwargs):
+ # copy the depth decoder generation config attributes to the model generation config
+ prefix = "depth_decoder_"
+ depth_decoder_attrs = self.depth_decoder.generation_config.to_diff_dict()
+ depth_decoder_attrs.pop("transformers_version", None)
+ for attr, value in depth_decoder_attrs.items():
+ setattr(self.generation_config, prefix + attr, value)
+
+ super().save_pretrained(*args, **kwargs)
+
+ def _merge_input_ids_with_input_values(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ input_values: Optional[torch.Tensor] = None,
+ input_values_cutoffs: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ ) -> Optional[torch.Tensor]:
+ """
+ Merges the input_ids and input_values to produce a single inputs_embeds tensor:
+ 1 - Runs the codec model on the input_values to retrieve codebook tokens.
+ 2 - Embeds the codebook tokens and places them at the correct positions in the inputs_embeds tensor.
+ 3 - If labels are provided, expands them to match the codebook dimensions and places the target codebook tokens at the corresponding positions in the labels tensor.
+
+ Args:
+ input_ids (`torch.Tensor` of shape `(batch_size, sequence_length)`):
+ The input ids to embed.
+ input_values (`torch.Tensor` of shape `(batch_size, channels, audio_sequence_length)`):
+ The audio input values to embed.
+ input_values_cutoffs (`torch.Tensor` of shape `(batch_size, max_num_audio)`):
+ The cutoffs of the audio input values relative to their batch entry, padded with -1 when there is no audio.
+ """
+ inputs_embeds = self.embed_text_tokens(input_ids)
+
+ if input_values is not None:
+ # infer input_values_mask
+ input_values_cutoffs = nn.functional.pad(input_values_cutoffs, (1, 0))
+ audio_lengths = input_values_cutoffs[input_values_cutoffs >= 0].diff()
+ audio_lengths = audio_lengths[audio_lengths > 0]
+ input_values_mask = torch.arange(input_values_cutoffs.max(), device=input_values.device).expand(
+ len(audio_lengths), -1
+ )
+ input_values_mask = input_values_mask < audio_lengths.unsqueeze(1)
+
+ # =======================================
+ # TODO: @eustlb, this should be batched !!!
+ # but requires making sure batched inference of the codec model works as intended
+ with torch.no_grad():
+ audio_tokens_list = []
+ for batch_input_values, batch_input_values_cutoffs in zip(input_values, input_values_cutoffs):
+ batch_input_values_cutoffs = batch_input_values_cutoffs[batch_input_values_cutoffs >= 0]
+ for i in range(batch_input_values_cutoffs.shape[0] - 1):
+ start_idx = batch_input_values_cutoffs[i]
+ end_idx = batch_input_values_cutoffs[i + 1]
+ audio_batch = batch_input_values[..., start_idx:end_idx]
+ codec_outputs = self.codec_model.encode(audio_batch.unsqueeze(0))
+ codebook_ids = codec_outputs.audio_codes.transpose(1, -1)
+ audio_tokens_list.append(codebook_ids[0])
+
+ max_audio_frames = max(el.shape[0] for el in audio_tokens_list)
+ batched_audio_token_ids = torch.stack(
+ [nn.functional.pad(el, (0, 0, 0, max_audio_frames - el.shape[0])) for el in audio_tokens_list]
+ )
+ audio_codes_mask = self.codec_model.get_audio_codes_mask(input_values_mask)
+ # =======================================
+ audio_token_id = self.config.audio_token_id
+ audio_token_mask = input_ids == audio_token_id
+
+ audio_embeds = self.backbone_model.embed_tokens(batched_audio_token_ids)
+ inputs_embeds[audio_token_mask] = audio_embeds[audio_codes_mask]
+
+ # same for the audio eos token
+ audio_eos_frame_ids = (
+ torch.ones((1, 1, self.config.num_codebooks), device=input_ids.device, dtype=torch.long)
+ * self.config.codebook_eos_token_id
+ )
+ audio_eos_embeds = self.backbone_model.embed_tokens(audio_eos_frame_ids).squeeze(1)
+
+ audio_eos_token_mask = input_ids == self.config.audio_eos_token_id
+ inputs_embeds[audio_eos_token_mask] = audio_eos_embeds.repeat(audio_eos_token_mask.sum(), 1)
+
+ # if the labels are provided, we need to expand the labels to (batch_size, seq_length, num_codebooks)
+ if labels is not None:
+ labels_expanded = labels.unsqueeze(-1).repeat(1, 1, self.config.num_codebooks)
+ labels_expanded[audio_token_mask] = batched_audio_token_ids[audio_codes_mask]
+ labels_expanded[audio_eos_token_mask] = audio_eos_frame_ids
+ # mask depth decoder
+ depth_decoder_ignore_frames_idxs = (labels == -101).nonzero(as_tuple=True)
+ labels_expanded[depth_decoder_ignore_frames_idxs[0], depth_decoder_ignore_frames_idxs[1], 1:] = -100
+ labels = labels_expanded
+
+ return {"inputs_embeds": inputs_embeds, "labels": labels}
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids: torch.LongTensor,
+ past_key_values: Optional[Cache] = None,
+ attention_mask: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs,
+ ):
+ model_inputs = super().prepare_inputs_for_generation(
+ input_ids=input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ inputs_embeds=inputs_embeds,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ if input_ids is not None and input_ids.ndim == 2 and model_inputs.get("inputs_embeds") is None:
+ merged_inputs = self._merge_input_ids_with_input_values(
+ input_ids=input_ids,
+ input_values=kwargs.get("input_values"),
+ input_values_cutoffs=kwargs.get("input_values_cutoffs"),
+ labels=kwargs.get("labels"),
+ )
+ model_inputs.update(
+ {"inputs_embeds": merged_inputs["inputs_embeds"], "labels": merged_inputs["labels"], "input_ids": None}
+ )
+
+ return model_inputs
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ input_values: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ input_values_cutoffs: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> Union[tuple, CsmOutputWithPast]:
+ r"""
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks) or (batch_size, sequence_length)`):
+ 1. (batch_size, sequence_length): corresponds to the input sequence prepared with the processor from the text prompt. Such input
+ requires `input_values` to be provided so that audio can be encoded in codebook tokens and then merged with the text tokens.
+
+ 2. (batch_size, sequence_length, num_codebooks): codebook tokens generated during the autoregressive decoding. Such input is not meant to be used by end users.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ input_values_cutoffs (`torch.Tensor` of shape `(batch_size, max_num_audio)`, *optional*):
+ Specify the end positions of audio segments within each batch entry, relative to the concatenated audio input.
+ If a batch entry has fewer segments than the maximum, it is padded with -1. For example, in a batch of 2 sequences
+ where the first contains 2 audio segments of length l1, and the second contains 1 audio segment of length l2,
+ the input_values_cutoffs would be: [[l1, 2 * l1], [l2, -1]].
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should be in `[config.audio_token_id, -100, -101]`.
+ Requires targeted `input_values` to be provided, as audio tokens will be inferred from them using the `codec_model`.
+ - `config.audio_token_id` indicates an audio frame (considering sequence length elements as frames)
+ - `-100` will be ignored in the loss computation
+ - `-101` indicates the audio frame will be used only for the backbone model (using the first codebook token as labels)
+
+ Such labels can be prepared using `output_labels=True` when calling [`CsmProcessor`].
+ logits_to_keep (`int` or `torch.Tensor`, *optional*):
+ Kept for compatibility. Only the following values are supported:
+ 1. `0`, which is equivalent to keeping all logits, used in the training regime
+ 2. `1`, which is equivalent to keeping only the last logit, used in the generation regime
+
+ Example:
+
+ ```python
+ >>> import torch
+ >>> from transformers import CsmForConditionalGeneration, AutoProcessor
+ >>> from datasets import load_dataset, Audio
+
+ >>> model_id = "sesame/csm-1b"
+ >>> torch_device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ >>> processor = AutoProcessor.from_pretrained(model_id)
+
+ >>> ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train")
+ >>> # ensure the audio is 24kHz
+ >>> ds = ds.cast_column("audio", Audio(sampling_rate=24000))
+
+ >>> conversation = []
+ >>> # prepare a conversation with text and corresponding audio
+ >>> for text, audio, speaker_id in zip(ds[:4]["text"], ds[:4]["audio"], ds[:4]["speaker_id"]):
+ ... conversation.append(
+ ... {
+ ... "role": f"{speaker_id}",
+ ... "content": [{"type": "text", "text": text}, {"type": "audio", "path": audio["array"]}],
+ ... }
+ ... )
+
+ >>> inputs = processor.apply_chat_template(
+ ... conversation,
+ ... tokenize=True,
+ ... return_dict=True,
+ ... output_labels=True,
+ ... ).to(torch_device)
+
+ >>> model = CsmForConditionalGeneration.from_pretrained(model_id, device_map=torch_device)
+ >>> output = model(**inputs)
+ >>> output.loss.backward()
+ ```"""
+ if input_ids is not None and input_ids.ndim == 2:
+ merged_inputs = self._merge_input_ids_with_input_values(
+ input_ids, input_values, input_values_cutoffs, labels
+ )
+ inputs_embeds = merged_inputs["inputs_embeds"]
+ labels = merged_inputs["labels"]
+ input_ids = None
+
+ backbone_outputs = self.backbone_model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ backbone_hidden_states = backbone_outputs[0]
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ backbone_logits = self.lm_head(backbone_hidden_states[:, slice_indices, :])
+
+ loss = None
+ backbone_loss = None
+ depth_decoder_loss = None
+ depth_decoder_outputs = None
+ if labels is not None:
+ # select first codebook as labels for the backbone model
+ backbone_labels = labels[:, :, 0]
+ backbone_loss = self.loss_function(
+ logits=backbone_logits, labels=backbone_labels, vocab_size=self.config.vocab_size, **kwargs
+ )
+
+ # for the depth decoder, we need to select the frames to train on
+ # those are frames where the label is not uniformly `ignore_index` along the codebook dimension
+ train_mask = ~(labels[:, :, 1:] == -100).all(dim=-1)
+ depth_decoder_input_ids = labels[train_mask][..., : self.config.num_codebooks - 1]
+ # add a placeholder at position 0 that will be replaced by the backbone_last_hidden_state
+ depth_decoder_input_ids = nn.functional.pad(depth_decoder_input_ids, (1, 0), value=0)
+
+ train_idxs = train_mask.nonzero(as_tuple=True)
+ backbone_last_hidden_states = backbone_hidden_states[train_idxs[0], train_idxs[1] - 1, :]
+ depth_decoder_labels = labels[train_mask]
+
+ depth_decoder_outputs = self.depth_decoder(
+ input_ids=depth_decoder_input_ids,
+ backbone_last_hidden_state=backbone_last_hidden_states,
+ use_cache=use_cache,
+ return_dict=True,
+ labels=depth_decoder_labels,
+ **kwargs,
+ )
+
+ depth_decoder_loss = depth_decoder_outputs.loss
+ loss = backbone_loss + depth_decoder_loss
+
+ return CsmOutputWithPast(
+ loss=loss,
+ backbone_loss=backbone_loss,
+ depth_decoder_loss=depth_decoder_loss,
+ logits=backbone_logits,
+ past_key_values=backbone_outputs.past_key_values,
+ hidden_states=backbone_outputs.hidden_states,
+ attentions=backbone_outputs.attentions,
+ depth_decoder_logits=depth_decoder_outputs.logits if depth_decoder_outputs is not None else None,
+ depth_decoder_past_key_values=depth_decoder_outputs.past_key_values
+ if depth_decoder_outputs is not None
+ else None,
+ depth_decoder_hidden_states=depth_decoder_outputs.hidden_states
+ if depth_decoder_outputs is not None
+ else None,
+ depth_decoder_attentions=depth_decoder_outputs.attentions if depth_decoder_outputs is not None else None,
+ )
+
+
+__all__ = [
+ "CsmPreTrainedModel",
+ "CsmBackboneModel",
+ "CsmDepthDecoderModel",
+ "CsmDepthDecoderForCausalLM",
+ "CsmForConditionalGeneration",
+]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/csm/processing_csm.py b/venv/lib/python3.13/site-packages/transformers/models/csm/processing_csm.py
new file mode 100644
index 0000000000000000000000000000000000000000..95596f4a3a9e1d4b06f3f7e1fcfb94377b66c035
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/csm/processing_csm.py
@@ -0,0 +1,373 @@
+# coding=utf-8
+# Copyright 2025 Sesame and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from pathlib import Path
+from typing import Any, Optional, Union
+
+import numpy as np
+
+from ...utils import is_soundfile_available, is_torch_available
+
+
+if is_torch_available():
+ import torch
+
+if is_soundfile_available():
+ import soundfile as sf
+
+from ...audio_utils import AudioInput, make_list_of_audio
+from ...feature_extraction_utils import BatchFeature
+from ...processing_utils import AudioKwargs, ProcessingKwargs, ProcessorMixin, Unpack
+from ...tokenization_utils_base import PreTokenizedInput, TextInput
+
+
+class CsmAudioKwargs(AudioKwargs, total=False):
+ encoded_length_kwargs: Optional[dict[str, Any]]
+
+
+class CsmProcessorKwargs(ProcessingKwargs, total=False):
+ audio_kwargs: CsmAudioKwargs
+ _defaults = {
+ "text_kwargs": {
+ "padding": True,
+ "padding_side": "left",
+ "add_special_tokens": False,
+ },
+ "audio_kwargs": {
+ "encoded_length_kwargs": {
+ "kernel_sizes": [7, 3, 1, 8, 3, 1, 10, 3, 1, 12, 3, 1, 16, 3, 4],
+ "strides": [1, 1, 1, 4, 1, 1, 5, 1, 1, 6, 1, 1, 8, 1, 2],
+ "dilations": [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
+ "use_causal_conv": True,
+ },
+ "sampling_rate": 24000,
+ },
+ "common_kwargs": {"return_tensors": "pt"},
+ }
+
+
+class CsmProcessor(ProcessorMixin):
+ r"""
+ Constructs a Csm processor which wraps [`EncodecFeatureExtractor`] and
+ [`PreTrainedTokenizerFast`] into a single processor that inherits both the audio feature extraction and
+ tokenizer functionalities. See the [`~CsmProcessor.__call__`] for more
+ information.
+ The preferred way of passing kwargs is as a dictionary per modality, see usage example below.
+ ```python
+ from transformers import CsmProcessor
+ from datasets import load_dataset
+
+ ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train")
+ audio = ds[0]["audio"]["array"]
+
+ processor = CsmProcessor.from_pretrained("sesame/csm-1b")
+
+ processor(
+ text=["<|begin_of_text|>[0]What are you working on?<|end_of_text|><|AUDIO|><|audio_eos|><|begin_of_text|>[1]I'm figuring out my budget.<|end_of_text|>"],
+ audio=audio,
+ text_kwargs = {"padding": False},
+ audio_kwargs = {"sampling_rate": 16000},
+ common_kwargs = {"return_tensors": "pt"},
+ )
+ # this should error out because EncodecFeatureExtractor expects a 24kHz audio :)
+ ```
+
+ Args:
+ feature_extractor ([`EncodecFeatureExtractor`]):
+ The feature extractor is a required input.
+ tokenizer ([`PreTrainedTokenizer`, `PreTrainedTokenizerFast`]):
+ The tokenizer is a required input.
+ chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
+ in a chat into a tokenizable string.
+
+ """
+
+ attributes = ["feature_extractor", "tokenizer"]
+ feature_extractor_class = "EncodecFeatureExtractor"
+ tokenizer_class = "PreTrainedTokenizerFast"
+
+ def __init__(
+ self,
+ feature_extractor,
+ tokenizer,
+ chat_template=None,
+ ):
+ if not hasattr(tokenizer, "audio_token"):
+ self.audio_token = "<|AUDIO|>"
+ self.audio_token_id = tokenizer.convert_tokens_to_ids(self.audio_token)
+ else:
+ self.audio_token = tokenizer.audio_token
+ self.audio_token_id = tokenizer.audio_token_id
+
+ if not hasattr(tokenizer, "audio_eos_token"):
+ self.audio_eos_token = "<|audio_eos|>"
+ self.audio_eos_token_id = tokenizer.convert_tokens_to_ids(self.audio_eos_token)
+ else:
+ self.audio_eos_token = tokenizer.audio_eos_token
+ self.audio_eos_token_id = tokenizer.audio_eos_token_id
+
+ super().__init__(feature_extractor, tokenizer, chat_template=chat_template)
+
+ @staticmethod
+ def _get_encoded_length(audio_length, kernel_sizes=None, strides=None, dilations=None, use_causal_conv=None):
+ """
+ Compute the length of the encoded audio sequence.
+
+ Args:
+ audio_length (int): The length of the audio sequence.
+ kernel_sizes (list[int]): The kernel sizes for the convolutional layers.
+ strides (list[int]): The strides for the convolutional layers.
+ dilations (list[int]): The dilations for the convolutional layers.
+ use_causal_conv (bool): Whether to use causal convolutions.
+ """
+ cur_length = audio_length
+
+ if kernel_sizes is None or strides is None or dilations is None or use_causal_conv is None:
+ return cur_length
+
+ for kernel_size, stride, dilation in zip(kernel_sizes, strides, dilations):
+ effective_kernel_size = (kernel_size - 1) * dilation + 1
+ padding_total = kernel_size - stride
+ padding_right = padding_total // 2
+ padding_left = padding_total - padding_right
+
+ n_frames = (cur_length - effective_kernel_size + padding_total) / stride + 1
+ n_frames = math.ceil(n_frames) - 1
+ ideal_length = n_frames * stride + kernel_size - padding_total
+ extra_padding = ideal_length - cur_length
+
+ if use_causal_conv:
+ padding_left = padding_total
+ padding_right = extra_padding
+ else:
+ padding_right = padding_right + extra_padding
+
+ cur_length = cur_length + padding_left + padding_right
+ cur_length = (cur_length - dilation * (kernel_size - 1) - 1) // stride + 1
+
+ return cur_length
+
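+ # Worked example for the helper above (illustrative, using the default
+ # encoded_length_kwargs defined in CsmProcessorKwargs): the strides multiply to
+ # 4 * 5 * 6 * 8 * 2 = 1920, i.e. roughly one audio token per 1920 samples, or
+ # about 12.5 tokens per second of 24 kHz audio; with the causal padding applied
+ # at each layer, a 24000-sample clip yields _get_encoded_length(24000, **kwargs) == 13.
+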
+ def save_audio(
+ self,
+ audio: AudioInput,
+ saving_path: Union[str, Path, list[Union[str, Path]]],
+ **kwargs: Unpack[CsmProcessorKwargs],
+ ):
+ # TODO: @eustlb, this should be in AudioProcessor
+ if not is_soundfile_available():
+ raise ImportError("Please install `soundfile` to save audio files.")
+
+ # ensure correct audio input
+ audio = make_list_of_audio(audio)
+
+ # ensure correct saving path
+ if isinstance(saving_path, (str, Path)):
+ saving_path = [saving_path]
+ elif not (isinstance(saving_path, (list, tuple)) and all(isinstance(p, (str, Path)) for p in saving_path)):
+ raise ValueError("Invalid input path. Please provide a string, or a list of strings")
+
+ if len(audio) != len(saving_path):
+ raise ValueError("The number of audio and saving paths must be the same")
+
+ output_kwargs = self._merge_kwargs(
+ CsmProcessorKwargs,
+ **kwargs,
+ )
+ audio_kwargs = output_kwargs["audio_kwargs"]
+ sampling_rate = audio_kwargs["sampling_rate"]
+
+ for audio_value, p in zip(audio, saving_path):
+ if isinstance(audio_value, torch.Tensor):
+ audio_value = audio_value.cpu().float().numpy()
+ sf.write(p, audio_value, sampling_rate)
+
+ def __call__(
+ self,
+ text: Optional[Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]]],
+ audio: Optional[AudioInput] = None,
+ output_labels: Optional[bool] = False,
+ depth_decoder_labels_ratio: Optional[float] = 1.0,
+ **kwargs: Unpack[CsmProcessorKwargs],
+ ):
+ r"""
+ Main method to prepare text(s) and audio to be fed as input to the model. This method forwards the `text`
+ arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizerFast.__call__`] to encode
+ the text. To prepare the audio, this method forwards the `audio` arguments to
+ EncodecFeatureExtractor's [`~EncodecFeatureExtractor.__call__`]. Please refer
+ to the docstring of the above two methods for more information.
+
+ Args:
+ audio (`np.ndarray`, `torch.Tensor`, `list[np.ndarray]`, `list[torch.Tensor]`):
+ The audio or batch of audio to be prepared. Each audio can be a NumPy array or PyTorch
+ tensor.
+ text (`str`, `list[str]`, `list[list[str]]`):
+ The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
+ (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
+ `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
+ output_labels (bool, *optional*, defaults to `False`):
+ Whether to return labels for training. Indices will be in `[config.audio_token_id, -100, -101]`.
+ - `config.audio_token_id` indicates an audio frame (considering sequence length elements as frames)
+ - `-100` will be ignored in the loss computation
+ - `-101` indicates the audio frame will be used only for the backbone model (using the first codebook token as labels)
+ depth_decoder_labels_ratio (float, *optional*, defaults to `1.0`):
+ The ratio of audio frames to keep for the depth decoder labels.
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
+ If set, will return tensors of a particular framework. Acceptable values are:
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
+ - `'np'`: Return NumPy `np.ndarray` objects.
+ - `'jax'`: Return JAX `jnp.ndarray` objects.
+ Returns:
+ [`BatchFeature`]: A [`BatchFeature`] with the following fields:
+
+ - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
+ - **input_values** -- List of audio values to be fed to a model. Returned when `audio` is not `None`.
+ - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
+ `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
+ `None`).
+ - **labels** -- List of labels for the audio frames. Returned when `output_labels=True`.
+ """
+
+ output_kwargs = self._merge_kwargs(
+ CsmProcessorKwargs,
+ tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+ **kwargs,
+ )
+
+ text_kwargs = output_kwargs["text_kwargs"]
+ audio_kwargs = output_kwargs["audio_kwargs"]
+ common_kwargs = output_kwargs["common_kwargs"]
+
+ return_tensors = common_kwargs.pop("return_tensors", None)
+ if return_tensors != "pt":
+ raise ValueError(f"{self.__class__.__name__} only supports `return_tensors='pt'`.")
+
+ if isinstance(text, str):
+ text = [text]
+ elif not (isinstance(text, (list, tuple)) and all(isinstance(t, str) for t in text)):
+ raise ValueError("Invalid input text. Please provide a string, or a list of strings")
+ n_audio_in_text = [t.count(self.audio_token) for t in text]
+
+ n_audio = 0
+ if audio is not None:
+ audio = make_list_of_audio(audio)
+ n_audio = len(audio)
+
+ if sum(n_audio_in_text) > 0 and n_audio != sum(n_audio_in_text):
+ if audio is None:
+ raise ValueError("No audio were provided, but there are audio tokens in the prompt")
+ else:
+ raise ValueError(
+ f"The number of audio tokens in each text ({n_audio_in_text}) should be the same as the "
+ f"number of provided audios ({n_audio})."
+ )
+
+ if audio is not None:
+ encoded_length_kwargs = audio_kwargs.pop("encoded_length_kwargs", {})
+ num_audio_tokens_list = [
+ self._get_encoded_length(audio_array.shape[-1], **encoded_length_kwargs) for audio_array in audio
+ ]
+ num_audio_tokens_list_copy = num_audio_tokens_list.copy()
+
+ # expand the text to repeat the audio token for the corresponding number of frames
+ expanded_text = []
+ for sample in text:
+ replace_str = []
+ while self.audio_token in sample:
+ num_audio_tokens = num_audio_tokens_list_copy.pop(0)
+ expanded_audio_token = self.audio_token * num_audio_tokens
+
+ replace_str.append(expanded_audio_token)
+ sample = sample.replace(self.audio_token, "<placeholder>", 1)
+
+ while "<placeholder>" in sample:
+ sample = sample.replace("<placeholder>", replace_str.pop(0), 1)
+ expanded_text.append(sample)
+
+ text = expanded_text
+
+ encoding = self.tokenizer(text, **text_kwargs)
+ data = {}
+ data.update(encoding)
+
+ if audio is not None:
+ audio_kwargs.pop("return_attention_mask", None) # not supported by the feature extractor
+
+ concatenated_audio, input_values_cutoffs = [], []
+ offset = 0
+ for n_audio in n_audio_in_text:
+ if n_audio == 0:
+ concatenated_audio.append(np.zeros(0))
+ input_values_cutoffs.append(torch.tensor([-1]))
+ else:
+ concatenated_audio.append(
+ np.concatenate(
+ [
+ el.cpu().numpy() if isinstance(el, torch.Tensor) else el
+ for el in audio[offset : offset + n_audio]
+ ],
+ axis=-1,
+ )
+ )
+ input_values_cutoffs.append(
+ torch.tensor([el.shape[-1] for el in audio[offset : offset + n_audio]]).cumsum(dim=-1)
+ )
+ offset += n_audio
+
+ audio_inputs = self.feature_extractor(concatenated_audio, **audio_kwargs)
+ audio_inputs.pop("padding_mask", None) # not applicable here
+ data.update(audio_inputs)
+
+ # pad and stack the audio cut idxs
+ max_len = max(cut_idxs.shape[-1] for cut_idxs in input_values_cutoffs)
+ input_values_cutoffs = [
+ torch.nn.functional.pad(cut_idxs, (0, max_len - cut_idxs.shape[-1]), value=-1)
+ for cut_idxs in input_values_cutoffs
+ ]
+ data["input_values_cutoffs"] = torch.stack(input_values_cutoffs, dim=0)
+
+ if output_labels:
+ audio_frame_idxs = (data["input_ids"] == self.audio_token_id).nonzero()
+ n_audio_frames = audio_frame_idxs.shape[0]
+
+ if depth_decoder_labels_ratio <= 1.0:
+ rand_idxs = torch.randperm(n_audio_frames)[: int(n_audio_frames * (1 - depth_decoder_labels_ratio))]
+ skip_frames_idxs = audio_frame_idxs[rand_idxs]
+ else:
+ skip_frames_idxs = audio_frame_idxs
+
+ labels = torch.where(
+ (data["input_ids"] == self.audio_token_id) | (data["input_ids"] == self.audio_eos_token_id),
+ data["input_ids"],
+ -100,
+ )
+ labels[skip_frames_idxs[:, 0], skip_frames_idxs[:, 1]] = -101
+
+ data["labels"] = labels
+
+ return BatchFeature(data=data, tensor_type=return_tensors)
+
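+ # Illustrative training-time call (variable names are placeholders): calling
+ # `batch = processor(text=text, audio=audio, output_labels=True)` returns
+ # `input_ids`, `attention_mask`, `input_values`, `input_values_cutoffs` and
+ # `labels`, with `labels` aligned to `input_ids` as described above.
+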
+ @property
+ def model_input_names(self):
+ tokenizer_input_names = self.tokenizer.model_input_names
+ feature_extractor_input_names = self.feature_extractor.model_input_names
+
+ # Remove `padding_mask`, it is popped and not used when processing. Make a copy of list when removing
+ # otherwise `self.feature_extractor.model_input_names` is also modified
+ feature_extractor_input_names = [name for name in feature_extractor_input_names if name != "padding_mask"]
+ return list(tokenizer_input_names + feature_extractor_input_names + ["input_values_cutoffs"])
+
+
+__all__ = ["CsmProcessor"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/d_fine/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/d_fine/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..879b53709bc673bcf28553a51175f06fa1e362c0
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/d_fine/__init__.py
@@ -0,0 +1,29 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_d_fine import *
+ from .modeling_d_fine import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/venv/lib/python3.13/site-packages/transformers/models/d_fine/configuration_d_fine.py b/venv/lib/python3.13/site-packages/transformers/models/d_fine/configuration_d_fine.py
new file mode 100644
index 0000000000000000000000000000000000000000..7484d9a347e534f3ebfff5a0776a0413a5a416dc
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/d_fine/configuration_d_fine.py
@@ -0,0 +1,433 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/d_fine/modular_d_fine.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_d_fine.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2025 Baidu Inc and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+from ...utils.backbone_utils import verify_backbone_config_arguments
+from ..auto import CONFIG_MAPPING
+
+
+logger = logging.get_logger(__name__)
+
+
+# TODO: Attribute map assignment logic should be fixed in modular
+# as well as super() call parsing because otherwise we cannot re-write args after initialization
+class DFineConfig(PretrainedConfig):
+ """
+ This is the configuration class to store the configuration of a [`DFineModel`]. It is used to instantiate a D-FINE
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+ defaults will yield a similar configuration to that of D-FINE-X-COCO [ustc-community/dfine-xlarge-coco](https://huggingface.co/ustc-community/dfine-xlarge-coco).
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ initializer_range (`float`, *optional*, defaults to 0.01):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ initializer_bias_prior_prob (`float`, *optional*):
+ The prior probability used by the bias initializer to initialize biases for `enc_score_head` and `class_embed`.
+ If `None`, `prior_prob` computed as `prior_prob = 1 / (num_labels + 1)` while initializing model weights.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-05):
+ The epsilon used by the layer normalization layers.
+ batch_norm_eps (`float`, *optional*, defaults to 1e-05):
+ The epsilon used by the batch normalization layers.
+ backbone_config (`Dict`, *optional*, defaults to `HGNetV2Config()`):
+ The configuration of the backbone model.
+ backbone (`str`, *optional*):
+ Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
+ will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone`
+ is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
+ use_pretrained_backbone (`bool`, *optional*, defaults to `False`):
+ Whether to use pretrained weights for the backbone.
+ use_timm_backbone (`bool`, *optional*, defaults to `False`):
+ Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
+ library.
+ freeze_backbone_batch_norms (`bool`, *optional*, defaults to `True`):
+ Whether to freeze the batch normalization layers in the backbone.
+ backbone_kwargs (`dict`, *optional*):
+ Keyword arguments to be passed to AutoBackbone when loading from a checkpoint
+ e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set.
+ encoder_hidden_dim (`int`, *optional*, defaults to 256):
+ Dimension of the layers in hybrid encoder.
+ encoder_in_channels (`list`, *optional*, defaults to `[512, 1024, 2048]`):
+ Multi-level feature inputs for the encoder.
+ feat_strides (`list[int]`, *optional*, defaults to `[8, 16, 32]`):
+ Strides used in each feature map.
+ encoder_layers (`int`, *optional*, defaults to 1):
+ Total number of layers to be used by the encoder.
+ encoder_ffn_dim (`int`, *optional*, defaults to 1024):
+ Dimension of the "intermediate" (often named feed-forward) layer in decoder.
+ encoder_attention_heads (`int`, *optional*, defaults to 8):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ dropout (`float`, *optional*, defaults to 0.0):
+ The ratio for all dropout layers.
+ activation_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for activations inside the fully connected layer.
+ encode_proj_layers (`list[int]`, *optional*, defaults to `[2]`):
+ Indexes of the projected layers to be used in the encoder.
+ positional_encoding_temperature (`int`, *optional*, defaults to 10000):
+ The temperature parameter used to create the positional encodings.
+ encoder_activation_function (`str`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
+ activation_function (`str`, *optional*, defaults to `"silu"`):
+ The non-linear activation function (function or string) in the general layer. If string, `"gelu"`,
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
+ eval_size (`tuple[int, int]`, *optional*):
+ Height and width used to compute the effective height and width of the position embeddings after taking
+ into account the stride.
+ normalize_before (`bool`, *optional*, defaults to `False`):
+ Determine whether to apply layer normalization in the transformer encoder layer before self-attention and
+ feed-forward modules.
+ hidden_expansion (`float`, *optional*, defaults to 1.0):
+ Expansion ratio to enlarge the dimension size of RepVGGBlock and CSPRepLayer.
+ d_model (`int`, *optional*, defaults to 256):
+ Dimension of the layers excluding the hybrid encoder.
+ num_queries (`int`, *optional*, defaults to 300):
+ Number of object queries.
+ decoder_in_channels (`list`, *optional*, defaults to `[256, 256, 256]`):
+ Multi-level feature dimensions for the decoder.
+ decoder_ffn_dim (`int`, *optional*, defaults to 1024):
+ Dimension of the "intermediate" (often named feed-forward) layer in decoder.
+ num_feature_levels (`int`, *optional*, defaults to 3):
+ The number of input feature levels.
+ decoder_n_points (`int`, *optional*, defaults to 4):
+ The number of sampled keys in each feature level for each attention head in the decoder.
+ decoder_layers (`int`, *optional*, defaults to 6):
+ Number of decoder layers.
+ decoder_attention_heads (`int`, *optional*, defaults to 8):
+ Number of attention heads for each attention layer in the Transformer decoder.
+ decoder_activation_function (`str`, *optional*, defaults to `"relu"`):
+ The non-linear activation function (function or string) in the decoder. If string, `"gelu"`,
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ num_denoising (`int`, *optional*, defaults to 100):
+ The total number of denoising tasks or queries to be used for contrastive denoising.
+ label_noise_ratio (`float`, *optional*, defaults to 0.5):
+ The fraction of denoising labels to which random noise should be added.
+ box_noise_scale (`float`, *optional*, defaults to 1.0):
+ Scale or magnitude of noise to be added to the bounding boxes.
+ learn_initial_query (`bool`, *optional*, defaults to `False`):
+ Indicates whether the initial query embeddings for the decoder should be learned during training.
+ anchor_image_size (`tuple[int, int]`, *optional*):
+ Height and width of the input image used during evaluation to generate the bounding box anchors. If `None`, anchors are generated automatically.
+ with_box_refine (`bool`, *optional*, defaults to `True`):
+ Whether to apply iterative bounding box refinement, where each decoder layer refines the bounding boxes
+ based on the predictions from the previous layer.
+ is_encoder_decoder (`bool`, *optional*, defaults to `True`):
+ Whether the architecture has an encoder decoder structure.
+ matcher_alpha (`float`, *optional*, defaults to 0.25):
+ Parameter alpha used by the Hungarian Matcher.
+ matcher_gamma (`float`, *optional*, defaults to 2.0):
+ Parameter gamma used by the Hungarian Matcher.
+ matcher_class_cost (`float`, *optional*, defaults to 2.0):
+ The relative weight of the class loss used by the Hungarian Matcher.
+ matcher_bbox_cost (`float`, *optional*, defaults to 5.0):
+ The relative weight of the bounding box loss used by the Hungarian Matcher.
+ matcher_giou_cost (`float`, *optional*, defaults to 2.0):
+ The relative weight of the GIoU loss used by the Hungarian Matcher.
+ use_focal_loss (`bool`, *optional*, defaults to `True`):
+ Parameter informing if focal loss should be used.
+ auxiliary_loss (`bool`, *optional*, defaults to `True`):
+ Whether auxiliary decoding losses (loss at each decoder layer) are to be used.
+ focal_loss_alpha (`float`, *optional*, defaults to 0.75):
+ Parameter alpha used to compute the focal loss.
+ focal_loss_gamma (`float`, *optional*, defaults to 2.0):
+ Parameter gamma used to compute the focal loss.
+ weight_loss_vfl (`float`, *optional*, defaults to 1.0):
+ Relative weight of the varifocal loss in the object detection loss.
+ weight_loss_bbox (`float`, *optional*, defaults to 5.0):
+ Relative weight of the L1 bounding box loss in the object detection loss.
+ weight_loss_giou (`float`, *optional*, defaults to 2.0):
+ Relative weight of the generalized IoU loss in the object detection loss.
+ weight_loss_fgl (`float`, *optional*, defaults to 0.15):
+ Relative weight of the fine-grained localization loss in the object detection loss.
+ weight_loss_ddf (`float`, *optional*, defaults to 1.5):
+ Relative weight of the decoupled distillation focal loss in the object detection loss.
+ eos_coefficient (`float`, *optional*, defaults to 0.0001):
+ Relative classification weight of the 'no-object' class in the object detection loss.
+ eval_idx (`int`, *optional*, defaults to -1):
+ Index of the decoder layer to use for evaluation. If negative, counts from the end
+ (e.g., -1 means use the last layer). This allows for early prediction in the decoder
+ stack while still training later layers.
+ layer_scale (`float`, *optional*, defaults to `1.0`):
+ Scaling factor for the hidden dimension in later decoder layers. Used to adjust the
+ model capacity after the evaluation layer.
+ max_num_bins (`int`, *optional*, defaults to 32):
+ Maximum number of bins for the distribution-guided bounding box refinement.
+ Higher values allow for more fine-grained localization but increase computation.
+ reg_scale (`float`, *optional*, defaults to 4.0):
+ Scale factor for the regression distribution. Controls the range and granularity
+ of the bounding box refinement process.
+ depth_mult (`float`, *optional*, defaults to 1.0):
+ Multiplier for the number of blocks in RepNCSPELAN4 layers. Used to scale the model's
+ depth while maintaining its architecture.
+ top_prob_values (`int`, *optional*, defaults to 4):
+ Number of top probability values to consider from each corner's distribution.
+ lqe_hidden_dim (`int`, *optional*, defaults to 64):
+ Hidden dimension size for the Location Quality Estimator (LQE) network.
+ lqe_layers (`int`, *optional*, defaults to 2):
+ Number of layers in the Location Quality Estimator MLP.
+ decoder_offset_scale (`float`, *optional*, defaults to 0.5):
+ Offset scale used in deformable attention.
+ decoder_method (`str`, *optional*, defaults to `"default"`):
+ The method to use for the decoder: `"default"` or `"discrete"`.
+ up (`float`, *optional*, defaults to 0.5):
+ Controls the upper bounds of the Weighting Function.
+ """
+
+ model_type = "d_fine"
+ layer_types = ["basic", "bottleneck"]
+ attribute_map = {
+ "hidden_size": "d_model",
+ "num_attention_heads": "encoder_attention_heads",
+ }
+
+ def __init__(
+ self,
+ initializer_range=0.01,
+ initializer_bias_prior_prob=None,
+ layer_norm_eps=1e-5,
+ batch_norm_eps=1e-5,
+ # backbone
+ backbone_config=None,
+ backbone=None,
+ use_pretrained_backbone=False,
+ use_timm_backbone=False,
+ freeze_backbone_batch_norms=True,
+ backbone_kwargs=None,
+ # encoder HybridEncoder
+ encoder_hidden_dim=256,
+ encoder_in_channels=[512, 1024, 2048],
+ feat_strides=[8, 16, 32],
+ encoder_layers=1,
+ encoder_ffn_dim=1024,
+ encoder_attention_heads=8,
+ dropout=0.0,
+ activation_dropout=0.0,
+ encode_proj_layers=[2],
+ positional_encoding_temperature=10000,
+ encoder_activation_function="gelu",
+ activation_function="silu",
+ eval_size=None,
+ normalize_before=False,
+ hidden_expansion=1.0,
+ # decoder DFineTransformer
+ d_model=256,
+ num_queries=300,
+ decoder_in_channels=[256, 256, 256],
+ decoder_ffn_dim=1024,
+ num_feature_levels=3,
+ decoder_n_points=4,
+ decoder_layers=6,
+ decoder_attention_heads=8,
+ decoder_activation_function="relu",
+ attention_dropout=0.0,
+ num_denoising=100,
+ label_noise_ratio=0.5,
+ box_noise_scale=1.0,
+ learn_initial_query=False,
+ anchor_image_size=None,
+ with_box_refine=True,
+ is_encoder_decoder=True,
+ # Loss
+ matcher_alpha=0.25,
+ matcher_gamma=2.0,
+ matcher_class_cost=2.0,
+ matcher_bbox_cost=5.0,
+ matcher_giou_cost=2.0,
+ use_focal_loss=True,
+ auxiliary_loss=True,
+ focal_loss_alpha=0.75,
+ focal_loss_gamma=2.0,
+ weight_loss_vfl=1.0,
+ weight_loss_bbox=5.0,
+ weight_loss_giou=2.0,
+ weight_loss_fgl=0.15,
+ weight_loss_ddf=1.5,
+ eos_coefficient=1e-4,
+ eval_idx=-1,
+ layer_scale=1,
+ max_num_bins=32,
+ reg_scale=4.0,
+ depth_mult=1.0,
+ top_prob_values=4,
+ lqe_hidden_dim=64,
+ lqe_layers=2,
+ decoder_offset_scale=0.5,
+ decoder_method="default",
+ up=0.5,
+ **kwargs,
+ ):
+ self.initializer_range = initializer_range
+ self.initializer_bias_prior_prob = initializer_bias_prior_prob
+ self.layer_norm_eps = layer_norm_eps
+ self.batch_norm_eps = batch_norm_eps
+ # backbone
+ if backbone_config is None and backbone is None:
+ logger.info(
+ "`backbone_config` and `backbone` are `None`. Initializing the config with the default `HGNet-V2` backbone."
+ )
+ backbone_model_type = "hgnet_v2"
+ config_class = CONFIG_MAPPING[backbone_model_type]
+ # CONFIG_MAPPING["hgnet_v2"] resolves to the HGNet-V2 backbone config class
+ backbone_config = config_class(
+ num_channels=3,
+ embedding_size=64,
+ hidden_sizes=[256, 512, 1024, 2048],
+ depths=[3, 4, 6, 3],
+ layer_type="bottleneck",
+ hidden_act="relu",
+ downsample_in_first_stage=False,
+ downsample_in_bottleneck=False,
+ out_features=None,
+ out_indices=[2, 3, 4],
+ )
+ elif isinstance(backbone_config, dict):
+ backbone_model_type = backbone_config.pop("model_type")
+ config_class = CONFIG_MAPPING[backbone_model_type]
+ backbone_config = config_class.from_dict(backbone_config)
+
+ verify_backbone_config_arguments(
+ use_timm_backbone=use_timm_backbone,
+ use_pretrained_backbone=use_pretrained_backbone,
+ backbone=backbone,
+ backbone_config=backbone_config,
+ backbone_kwargs=backbone_kwargs,
+ )
+
+ self.backbone_config = backbone_config
+ self.backbone = backbone
+ self.use_pretrained_backbone = use_pretrained_backbone
+ self.use_timm_backbone = use_timm_backbone
+ self.freeze_backbone_batch_norms = freeze_backbone_batch_norms
+ self.backbone_kwargs = backbone_kwargs
+ # encoder
+ self.encoder_hidden_dim = encoder_hidden_dim
+ self.encoder_in_channels = encoder_in_channels
+ self.feat_strides = feat_strides
+ self.encoder_attention_heads = encoder_attention_heads
+ self.encoder_ffn_dim = encoder_ffn_dim
+ self.dropout = dropout
+ self.activation_dropout = activation_dropout
+ self.encode_proj_layers = encode_proj_layers
+ self.encoder_layers = encoder_layers
+ self.positional_encoding_temperature = positional_encoding_temperature
+ self.eval_size = eval_size
+ self.normalize_before = normalize_before
+ self.encoder_activation_function = encoder_activation_function
+ self.activation_function = activation_function
+ self.hidden_expansion = hidden_expansion
+ # decoder
+ self.d_model = d_model
+ self.num_queries = num_queries
+ self.decoder_ffn_dim = decoder_ffn_dim
+ self.decoder_in_channels = decoder_in_channels
+ self.num_feature_levels = num_feature_levels
+ self.decoder_n_points = decoder_n_points
+ self.decoder_layers = decoder_layers
+ self.decoder_attention_heads = decoder_attention_heads
+ self.decoder_activation_function = decoder_activation_function
+ self.attention_dropout = attention_dropout
+ self.num_denoising = num_denoising
+ self.label_noise_ratio = label_noise_ratio
+ self.box_noise_scale = box_noise_scale
+ self.learn_initial_query = learn_initial_query
+ self.anchor_image_size = anchor_image_size
+ self.auxiliary_loss = auxiliary_loss
+ self.with_box_refine = with_box_refine
+ # Loss
+ self.matcher_alpha = matcher_alpha
+ self.matcher_gamma = matcher_gamma
+ self.matcher_class_cost = matcher_class_cost
+ self.matcher_bbox_cost = matcher_bbox_cost
+ self.matcher_giou_cost = matcher_giou_cost
+ self.use_focal_loss = use_focal_loss
+ self.focal_loss_alpha = focal_loss_alpha
+ self.focal_loss_gamma = focal_loss_gamma
+ self.weight_loss_vfl = weight_loss_vfl
+ self.weight_loss_bbox = weight_loss_bbox
+ self.weight_loss_giou = weight_loss_giou
+ self.weight_loss_fgl = weight_loss_fgl
+ self.weight_loss_ddf = weight_loss_ddf
+ self.eos_coefficient = eos_coefficient
+ # add the new attributes with the given values or defaults
+ self.eval_idx = eval_idx
+ self.layer_scale = layer_scale
+ self.max_num_bins = max_num_bins
+ self.reg_scale = reg_scale
+ self.depth_mult = depth_mult
+ self.decoder_offset_scale = decoder_offset_scale
+ self.decoder_method = decoder_method
+ self.top_prob_values = top_prob_values
+ self.lqe_hidden_dim = lqe_hidden_dim
+ self.lqe_layers = lqe_layers
+ self.up = up
+
+ if isinstance(self.decoder_n_points, list):
+ if len(self.decoder_n_points) != self.num_feature_levels:
+ raise ValueError(
+ f"Length of decoder_n_points list ({len(self.decoder_n_points)}) must match num_feature_levels ({self.num_feature_levels})."
+ )
+
+ head_dim = self.d_model // self.decoder_attention_heads
+ if head_dim * self.decoder_attention_heads != self.d_model:
+ raise ValueError(
+ f"Embedded dimension {self.d_model} must be divisible by decoder_attention_heads {self.decoder_attention_heads}"
+ )
+ super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs)
+
+ @property
+ def num_attention_heads(self) -> int:
+ return self.encoder_attention_heads
+
+ @property
+ def hidden_size(self) -> int:
+ return self.d_model
+
+ @property
+ def sub_configs(self):
+ return (
+ {"backbone_config": type(self.backbone_config)}
+ if getattr(self, "backbone_config", None) is not None
+ else {}
+ )
+
+ @classmethod
+ def from_backbone_configs(cls, backbone_config: PretrainedConfig, **kwargs):
+ """Instantiate a [`DFineConfig`] (or a derived class) from a pre-trained backbone model configuration and DETR model
+ configuration.
+
+ Args:
+ backbone_config ([`PretrainedConfig`]):
+ The backbone configuration.
+
+ Returns:
+ [`DFineConfig`]: An instance of a configuration object
+ """
+ return cls(
+ backbone_config=backbone_config,
+ **kwargs,
+ )
+
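+# Minimal usage sketch (illustrative; the overridden values are arbitrary): build
+# a default config, tweak a few decoder hyper-parameters, then instantiate a model.
+#
+#     from transformers import DFineConfig, DFineForObjectDetection
+#
+#     config = DFineConfig(decoder_layers=4, num_queries=100)
+#     model = DFineForObjectDetection(config)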
+
+__all__ = ["DFineConfig"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/d_fine/modeling_d_fine.py b/venv/lib/python3.13/site-packages/transformers/models/d_fine/modeling_d_fine.py
new file mode 100644
index 0000000000000000000000000000000000000000..cdc008e3c7bbfc6f9f6c14be256c782c11846feb
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/d_fine/modeling_d_fine.py
@@ -0,0 +1,2196 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/d_fine/modular_d_fine.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_d_fine.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2025 Baidu Inc and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import math
+from dataclasses import dataclass
+from typing import Any, Optional, Union
+
+import torch
+import torch.nn.functional as F
+import torch.nn.init as init
+from torch import Tensor, nn
+
+from ...activations import ACT2CLS, ACT2FN
+from ...image_transforms import center_to_corners_format, corners_to_center_format
+from ...modeling_outputs import BaseModelOutput
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import compile_compatible_method_lru_cache
+from ...utils import ModelOutput, auto_docstring, is_torchdynamo_compiling, torch_int
+from ...utils.backbone_utils import load_backbone
+from .configuration_d_fine import DFineConfig
+
+
+def multi_scale_deformable_attention_v2(
+ value: Tensor,
+ value_spatial_shapes: Tensor,
+ sampling_locations: Tensor,
+ attention_weights: Tensor,
+ num_points_list: list[int],
+ method="default",
+) -> Tensor:
+ batch_size, _, num_heads, hidden_dim = value.shape
+ _, num_queries, num_heads, num_levels, num_points = sampling_locations.shape
+ value_list = (
+ value.permute(0, 2, 3, 1)
+ .flatten(0, 1)
+ .split([height * width for height, width in value_spatial_shapes], dim=-1)
+ )
+ # sampling_offsets [8, 480, 8, 12, 2]
+ if method == "default":
+ sampling_grids = 2 * sampling_locations - 1
+ elif method == "discrete":
+ sampling_grids = sampling_locations
+ sampling_grids = sampling_grids.permute(0, 2, 1, 3, 4).flatten(0, 1)
+ sampling_grids = sampling_grids.split(num_points_list, dim=-2)
+ sampling_value_list = []
+ for level_id, (height, width) in enumerate(value_spatial_shapes):
+ # batch_size, height*width, num_heads, hidden_dim
+ # -> batch_size, height*width, num_heads*hidden_dim
+ # -> batch_size, num_heads*hidden_dim, height*width
+ # -> batch_size*num_heads, hidden_dim, height, width
+ value_l_ = value_list[level_id].reshape(batch_size * num_heads, hidden_dim, height, width)
+ # batch_size, num_queries, num_heads, num_points, 2
+ # -> batch_size, num_heads, num_queries, num_points, 2
+ # -> batch_size*num_heads, num_queries, num_points, 2
+ sampling_grid_l_ = sampling_grids[level_id]
+ # batch_size*num_heads, hidden_dim, num_queries, num_points
+ if method == "default":
+ sampling_value_l_ = nn.functional.grid_sample(
+ value_l_, sampling_grid_l_, mode="bilinear", padding_mode="zeros", align_corners=False
+ )
+ elif method == "discrete":
+ sampling_coord = (sampling_grid_l_ * torch.tensor([[width, height]], device=value.device) + 0.5).to(
+ torch.int64
+ )
+
+ # Separate clamping for x and y coordinates
+ sampling_coord_x = sampling_coord[..., 0].clamp(0, width - 1)
+ sampling_coord_y = sampling_coord[..., 1].clamp(0, height - 1)
+
+ # Combine the clamped coordinates
+ sampling_coord = torch.stack([sampling_coord_x, sampling_coord_y], dim=-1)
+ sampling_coord = sampling_coord.reshape(batch_size * num_heads, num_queries * num_points_list[level_id], 2)
+ sampling_idx = (
+ torch.arange(sampling_coord.shape[0], device=value.device)
+ .unsqueeze(-1)
+ .repeat(1, sampling_coord.shape[1])
+ )
+ sampling_value_l_ = value_l_[sampling_idx, :, sampling_coord[..., 1], sampling_coord[..., 0]]
+ sampling_value_l_ = sampling_value_l_.permute(0, 2, 1).reshape(
+ batch_size * num_heads, hidden_dim, num_queries, num_points_list[level_id]
+ )
+ sampling_value_list.append(sampling_value_l_)
+ # (batch_size, num_queries, num_heads, num_levels, num_points)
+ # -> (batch_size, num_heads, num_queries, num_levels, num_points)
+ # -> (batch_size, num_heads, 1, num_queries, num_levels*num_points)
+ attention_weights = attention_weights.permute(0, 2, 1, 3).reshape(
+ batch_size * num_heads, 1, num_queries, sum(num_points_list)
+ )
+ output = (
+ (torch.concat(sampling_value_list, dim=-1) * attention_weights)
+ .sum(-1)
+ .view(batch_size, num_heads * hidden_dim, num_queries)
+ )
+ return output.transpose(1, 2).contiguous()
+
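+# Shape sketch for the helper above (illustrative numbers): with batch_size=2,
+# num_heads=8, hidden_dim=32, value_spatial_shapes [(80, 80), (40, 40), (20, 20)]
+# (so 80*80 + 40*40 + 20*20 = 8400 positions) and num_points_list=[4, 4, 4],
+# `value` is (2, 8400, 8, 32), `sampling_locations` is (2, num_queries, 8, 12, 2),
+# `attention_weights` is (2, num_queries, 8, 12), and the returned tensor is
+# (2, num_queries, 8 * 32).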
+
+class DFineMultiscaleDeformableAttention(nn.Module):
+ def __init__(self, config: DFineConfig):
+ """
+ D-Fine version of multiscale deformable attention
+ """
+ super().__init__()
+ self.d_model = config.d_model
+ self.n_heads = config.decoder_attention_heads
+ self.n_levels = config.num_feature_levels
+ self.offset_scale = config.decoder_offset_scale
+ self.decoder_method = config.decoder_method
+ self.n_points = config.decoder_n_points
+
+ if isinstance(self.n_points, list):
+ num_points_list = self.n_points
+ else:
+ num_points_list = [self.n_points for _ in range(self.n_levels)]
+
+ self.num_points_list = num_points_list
+ num_points_scale = [1 / n for n in self.num_points_list for _ in range(n)]
+ self.register_buffer("num_points_scale", torch.tensor(num_points_scale, dtype=torch.float32))
+
+ self.total_points = self.n_heads * sum(self.num_points_list)
+
+ self.sampling_offsets = nn.Linear(self.d_model, self.total_points * 2)
+ self.attention_weights = nn.Linear(self.d_model, self.total_points)
+
+ self.ms_deformable_attn_core = multi_scale_deformable_attention_v2
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ reference_points=None,
+ encoder_hidden_states=None,
+ spatial_shapes=None,
+ spatial_shapes_list=None,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ batch_size, num_queries, _ = hidden_states.shape
+ batch_size, sequence_length, _ = encoder_hidden_states.shape
+
+ if not is_torchdynamo_compiling() and (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() != sequence_length:
+ raise ValueError(
+ "Make sure to align the spatial shapes with the sequence length of the encoder hidden states"
+ )
+
+ # Reshape for multi-head attention
+ value = encoder_hidden_states.reshape(batch_size, sequence_length, self.n_heads, self.d_model // self.n_heads)
+ if attention_mask is not None:
+ value = value.masked_fill(~attention_mask[..., None], float(0))
+
+ sampling_offsets: torch.Tensor = self.sampling_offsets(hidden_states)
+ sampling_offsets = sampling_offsets.reshape(
+ batch_size, num_queries, self.n_heads, sum(self.num_points_list), 2
+ )
+
+ attention_weights = self.attention_weights(hidden_states).reshape(
+ batch_size, num_queries, self.n_heads, sum(self.num_points_list)
+ )
+ attention_weights = F.softmax(attention_weights, dim=-1)
+
+ if reference_points.shape[-1] == 2:
+ offset_normalizer = torch.tensor(spatial_shapes)
+ offset_normalizer = offset_normalizer.flip([1]).reshape(1, 1, 1, self.n_levels, 1, 2)
+ sampling_locations = (
+ reference_points.reshape(batch_size, sequence_length, 1, self.n_levels, 1, 2)
+ + sampling_offsets / offset_normalizer
+ )
+ elif reference_points.shape[-1] == 4:
+ # reference_points [8, 480, None, 1, 4]
+ # sampling_offsets [8, 480, 8, 12, 2]
+ num_points_scale = self.num_points_scale.to(dtype=hidden_states.dtype).unsqueeze(-1)
+ offset = sampling_offsets * num_points_scale * reference_points[:, :, None, :, 2:] * self.offset_scale
+ sampling_locations = reference_points[:, :, None, :, :2] + offset
+ else:
+ raise ValueError(
+ f"Last dim of reference_points must be 2 or 4, but get {reference_points.shape[-1]} instead."
+ )
+
+ output = self.ms_deformable_attn_core(
+ value,
+ spatial_shapes_list,
+ sampling_locations,
+ attention_weights,
+ self.num_points_list,
+ self.decoder_method,
+ )
+
+ return output, attention_weights
+
+
+class DFineGate(nn.Module):
+ def __init__(self, d_model: int):
+ super().__init__()
+ self.gate = nn.Linear(2 * d_model, 2 * d_model)
+ self.norm = nn.LayerNorm(d_model)
+
+ def forward(self, second_residual: torch.Tensor, hidden_states: torch.Tensor) -> torch.Tensor:
+ gate_input = torch.cat([second_residual, hidden_states], dim=-1)
+ gates = torch.sigmoid(self.gate(gate_input))
+ gate1, gate2 = gates.chunk(2, dim=-1)
+ hidden_states = self.norm(gate1 * second_residual + gate2 * hidden_states)
+ return hidden_states
+
+
+class DFineMultiheadAttention(nn.Module):
+ """
+ Multi-headed attention from 'Attention Is All You Need' paper.
+
+ Here, we add position embeddings to the queries and keys (as explained in the Deformable DETR paper).
+ """
+
+ def __init__(
+ self,
+ embed_dim: int,
+ num_heads: int,
+ dropout: float = 0.0,
+ bias: bool = True,
+ ):
+ super().__init__()
+ self.embed_dim = embed_dim
+ self.num_heads = num_heads
+ self.dropout = dropout
+ self.head_dim = embed_dim // num_heads
+ if self.head_dim * num_heads != self.embed_dim:
+ raise ValueError(
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
+ f" {num_heads})."
+ )
+ self.scaling = self.head_dim**-0.5
+
+ self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+ self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+ self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+
+ def _reshape(self, tensor: torch.Tensor, seq_len: int, batch_size: int):
+ return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+
+ def with_pos_embed(self, tensor: torch.Tensor, position_embeddings: Optional[Tensor]):
+ return tensor if position_embeddings is None else tensor + position_embeddings
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_embeddings: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
+ """Input shape: Batch x Time x Channel"""
+
+ batch_size, target_len, embed_dim = hidden_states.size()
+ # add position embeddings to the hidden states before projecting to queries and keys
+ hidden_states_original = hidden_states
+ if position_embeddings is not None:
+ hidden_states = self.with_pos_embed(hidden_states, position_embeddings)
+
+ # get queries, keys and values
+ query_states = self.q_proj(hidden_states) * self.scaling
+ key_states = self._reshape(self.k_proj(hidden_states), -1, batch_size)
+ value_states = self._reshape(self.v_proj(hidden_states_original), -1, batch_size)
+
+ proj_shape = (batch_size * self.num_heads, -1, self.head_dim)
+ query_states = self._reshape(query_states, target_len, batch_size).view(*proj_shape)
+ key_states = key_states.view(*proj_shape)
+ value_states = value_states.view(*proj_shape)
+
+ source_len = key_states.size(1)
+
+ attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
+
+ if attn_weights.size() != (batch_size * self.num_heads, target_len, source_len):
+ raise ValueError(
+ f"Attention weights should be of size {(batch_size * self.num_heads, target_len, source_len)}, but is"
+ f" {attn_weights.size()}"
+ )
+
+ # expand attention_mask
+ if attention_mask is not None:
+ # [seq_len, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len]
+ attention_mask = attention_mask.expand(batch_size, 1, *attention_mask.size())
+
+ if attention_mask is not None:
+ if attention_mask.size() != (batch_size, 1, target_len, source_len):
+ raise ValueError(
+ f"Attention mask should be of size {(batch_size, 1, target_len, source_len)}, but is"
+ f" {attention_mask.size()}"
+ )
+ if attention_mask.dtype == torch.bool:
+ attention_mask = torch.zeros_like(attention_mask, dtype=attn_weights.dtype).masked_fill_(
+ attention_mask, -torch.inf
+ )
+ attn_weights = attn_weights.view(batch_size, self.num_heads, target_len, source_len) + attention_mask
+ attn_weights = attn_weights.view(batch_size * self.num_heads, target_len, source_len)
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+ if output_attentions:
+ # this operation is a bit awkward, but it's required to
+ # make sure that attn_weights keeps its gradient.
+ # In order to do so, attn_weights has to be reshaped
+ # twice and has to be reused in the following
+ attn_weights_reshaped = attn_weights.view(batch_size, self.num_heads, target_len, source_len)
+ attn_weights = attn_weights_reshaped.view(batch_size * self.num_heads, target_len, source_len)
+ else:
+ attn_weights_reshaped = None
+
+ attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+
+ attn_output = torch.bmm(attn_probs, value_states)
+
+ if attn_output.size() != (batch_size * self.num_heads, target_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(batch_size, self.num_heads, target_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.view(batch_size, self.num_heads, target_len, self.head_dim)
+ attn_output = attn_output.transpose(1, 2)
+ attn_output = attn_output.reshape(batch_size, target_len, embed_dim)
+
+ attn_output = self.out_proj(attn_output)
+
+ return attn_output, attn_weights_reshaped
+
+
+class DFineDecoderLayer(nn.Module):
+ def __init__(self, config: DFineConfig):
+ super().__init__()
+ # self-attention
+ self.self_attn = DFineMultiheadAttention(
+ embed_dim=config.d_model,
+ num_heads=config.decoder_attention_heads,
+ dropout=config.attention_dropout,
+ )
+ self.dropout = config.dropout
+ self.activation_fn = ACT2FN[config.decoder_activation_function]
+ self.activation_dropout = config.activation_dropout
+
+ self.self_attn_layer_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps)
+
+ # override the encoder attention module with d-fine version
+ self.encoder_attn = DFineMultiscaleDeformableAttention(config=config)
+ # feedforward neural networks
+ self.fc1 = nn.Linear(config.d_model, config.decoder_ffn_dim)
+ self.fc2 = nn.Linear(config.decoder_ffn_dim, config.d_model)
+ self.final_layer_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps)
+ # gate
+ self.gateway = DFineGate(config.d_model)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: Optional[torch.Tensor] = None,
+ reference_points=None,
+ spatial_shapes=None,
+ spatial_shapes_list=None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = False,
+ ) -> tuple[torch.Tensor, Any, Any]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`):
+ Input to the layer of shape `(seq_len, batch, embed_dim)`.
+ position_embeddings (`torch.FloatTensor`, *optional*):
+ Position embeddings that are added to the queries and keys in the self-attention layer.
+ reference_points (`torch.FloatTensor`, *optional*):
+ Reference points.
+ spatial_shapes (`torch.LongTensor`, *optional*):
+ Spatial shapes.
+ spatial_shapes_list (`list[tuple[int, int]]`, *optional*):
+ Spatial shapes of the feature maps as a list of `(height, width)` tuples.
+ encoder_hidden_states (`torch.FloatTensor`):
+ cross attention input to the layer of shape `(seq_len, batch, embed_dim)`
+ encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
+ `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
+ values.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ """
+ # Self Attention
+ hidden_states_2, self_attn_weights = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=encoder_attention_mask,
+ position_embeddings=position_embeddings,
+ output_attentions=output_attentions,
+ )
+
+ hidden_states_2 = nn.functional.dropout(hidden_states_2, p=self.dropout, training=self.training)
+ hidden_states = hidden_states + hidden_states_2
+ hidden_states = self.self_attn_layer_norm(hidden_states)
+ residual = hidden_states
+
+ # Cross-Attention
+ cross_attn_weights = None
+ hidden_states = hidden_states if position_embeddings is None else hidden_states + position_embeddings
+ hidden_states_2, cross_attn_weights = self.encoder_attn(
+ hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ reference_points=reference_points,
+ spatial_shapes=spatial_shapes,
+ spatial_shapes_list=spatial_shapes_list,
+ )
+
+ hidden_states_2 = nn.functional.dropout(hidden_states_2, p=self.dropout, training=self.training)
+ hidden_states = self.gateway(residual, hidden_states_2)
+
+ # Fully Connected
+ hidden_states_2 = self.activation_fn(self.fc1(hidden_states))
+ hidden_states_2 = nn.functional.dropout(hidden_states_2, p=self.activation_dropout, training=self.training)
+ hidden_states_2 = self.fc2(hidden_states_2)
+ hidden_states_2 = nn.functional.dropout(hidden_states_2, p=self.dropout, training=self.training)
+ hidden_states = hidden_states + hidden_states_2
+ hidden_states = self.final_layer_norm(hidden_states.clamp(min=-65504, max=65504))
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights, cross_attn_weights)
+
+ return outputs
+
+
+@auto_docstring
+class DFinePreTrainedModel(PreTrainedModel):
+ config: DFineConfig
+ base_model_prefix = "d_fine"
+ main_input_name = "pixel_values"
+ _no_split_modules = [r"DFineHybridEncoder", r"DFineDecoderLayer"]
+
+ def _init_weights(self, module):
+ """Initialize the weights"""
+ # initialize linear layer bias value according to a given probability value.
+ if isinstance(module, (DFineForObjectDetection, DFineDecoder)):
+ if module.class_embed is not None:
+ for layer in module.class_embed:
+ prior_prob = self.config.initializer_bias_prior_prob or 1 / (self.config.num_labels + 1)
+ bias = float(-math.log((1 - prior_prob) / prior_prob))
+ nn.init.xavier_uniform_(layer.weight)
+ nn.init.constant_(layer.bias, bias)
+
+ if module.bbox_embed is not None:
+ for layer in module.bbox_embed:
+ nn.init.constant_(layer.layers[-1].weight, 0)
+ nn.init.constant_(layer.layers[-1].bias, 0)
+
+ if hasattr(module, "reg_scale"):
+ module.reg_scale.fill_(self.config.reg_scale)
+
+ if hasattr(module, "up"):
+ module.up.fill_(self.config.up)
+
+ if isinstance(module, DFineMultiscaleDeformableAttention):
+ nn.init.constant_(module.sampling_offsets.weight.data, 0.0)
+ default_dtype = torch.get_default_dtype()
+ thetas = torch.arange(module.n_heads, dtype=torch.int64).to(default_dtype) * (
+ 2.0 * math.pi / module.n_heads
+ )
+ grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
+ grid_init = grid_init / grid_init.abs().max(-1, keepdim=True).values
+ grid_init = grid_init.reshape(module.n_heads, 1, 2).tile([1, sum(module.num_points_list), 1])
+ scaling = torch.concat([torch.arange(1, n + 1) for n in module.num_points_list]).reshape(1, -1, 1)
+ grid_init *= scaling
+ with torch.no_grad():
+ module.sampling_offsets.bias.data[...] = grid_init.flatten()
+
+ nn.init.constant_(module.attention_weights.weight.data, 0.0)
+ nn.init.constant_(module.attention_weights.bias.data, 0.0)
+
+ if isinstance(module, DFineModel):
+ prior_prob = self.config.initializer_bias_prior_prob or 1 / (self.config.num_labels + 1)
+ bias = float(-math.log((1 - prior_prob) / prior_prob))
+ nn.init.xavier_uniform_(module.enc_score_head.weight)
+ nn.init.constant_(module.enc_score_head.bias, bias)
+
+ if isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)):
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.bias is not None:
+ module.bias.data.zero_()
+
+ if isinstance(module, DFineGate):
+ bias = float(-math.log((1 - 0.5) / 0.5))
+ init.constant_(module.gate.bias, bias)
+ init.constant_(module.gate.weight, 0)
+
+ if isinstance(module, DFineLQE):
+ init.constant_(module.reg_conf.layers[-1].bias, 0)
+ init.constant_(module.reg_conf.layers[-1].weight, 0)
+
+ if isinstance(module, nn.LayerNorm):
+ module.weight.data.fill_(1.0)
+ module.bias.data.zero_()
+
+ if hasattr(module, "weight_embedding") and self.config.learn_initial_query:
+ nn.init.xavier_uniform_(module.weight_embedding.weight)
+ if hasattr(module, "denoising_class_embed") and self.config.num_denoising > 0:
+ nn.init.xavier_uniform_(module.denoising_class_embed.weight)
+
+
+class DFineIntegral(nn.Module):
+ """
+ A static layer that calculates integral results from a distribution.
+
+ This layer computes the target location using the formula: `sum{Pr(n) * W(n)}`,
+ where Pr(n) is the softmax probability vector representing the discrete
+ distribution, and W(n) is the non-uniform Weighting Function.
+
+ Args:
+ config (`DFineConfig`): Model configuration; `config.max_num_bins` sets the max number of
+ discrete bins (default 32) and can be adjusted based on the dataset or task requirements.
+ """
+
+ def __init__(self, config: DFineConfig):
+ super().__init__()
+ self.max_num_bins = config.max_num_bins
+
+ def forward(self, pred_corners: torch.Tensor, project: torch.Tensor) -> torch.Tensor:
+ batch_size, num_queries, _ = pred_corners.shape
+ pred_corners = F.softmax(pred_corners.reshape(-1, self.max_num_bins + 1), dim=1)
+ pred_corners = F.linear(pred_corners, project.to(pred_corners.device)).reshape(-1, 4)
+ pred_corners = pred_corners.reshape(batch_size, num_queries, -1)
+ return pred_corners
+
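+
+# Illustrative sketch (not part of the model): how `DFineIntegral` turns a discrete corner
+# distribution into a continuous offset, i.e. the expectation sum{Pr(n) * W(n)}. The bin count and
+# the linear weighting vector below are made-up assumptions for demonstration only; the model takes
+# both from `DFineConfig` and `weighting_function`. Relies on the module-level `torch`/`F` imports.
+def _integral_decoding_sketch():
+    max_num_bins = 8
+    # one image, one query, (max_num_bins + 1) logits for each of the four box edges
+    pred_corners = torch.randn(1, 1, 4 * (max_num_bins + 1))
+    # toy weighting vector W(n); the model uses the non-uniform one from `weighting_function`
+    project = torch.linspace(-1.0, 1.0, max_num_bins + 1)
+    prob = F.softmax(pred_corners.reshape(-1, max_num_bins + 1), dim=1)  # Pr(n) per edge
+    offsets = F.linear(prob, project.unsqueeze(0))  # expectation sum{Pr(n) * W(n)}
+    return offsets.reshape(1, 1, 4)  # one signed offset per box edge
+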
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for outputs of the DFineDecoder. This class adds two attributes to
+ BaseModelOutputWithCrossAttentions, namely:
+ - a stacked tensor of intermediate decoder hidden states (i.e. the output of each decoder layer)
+ - a stacked tensor of intermediate reference points.
+ """
+)
+class DFineDecoderOutput(ModelOutput):
+ r"""
+ intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`):
+ Stacked intermediate hidden states (output of each layer of the decoder).
+ intermediate_logits (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, sequence_length, config.num_labels)`):
+ Stacked intermediate logits (logits of each layer of the decoder).
+ intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
+ Stacked intermediate reference points (reference points of each layer of the decoder).
+ intermediate_predicted_corners (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
+ Stacked intermediate predicted corners (predicted corners of each layer of the decoder).
+ initial_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
+ Stacked initial reference points (initial reference points of each layer of the decoder).
+ cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`. Attentions weights of the decoder's cross-attention layer, after the attention softmax,
+ used to compute the weighted average in the cross-attention heads.
+ """
+
+ last_hidden_state: Optional[torch.FloatTensor] = None
+ intermediate_hidden_states: Optional[torch.FloatTensor] = None
+ intermediate_logits: Optional[torch.FloatTensor] = None
+ intermediate_reference_points: Optional[torch.FloatTensor] = None
+ intermediate_predicted_corners: Optional[torch.FloatTensor] = None
+ initial_reference_points: Optional[torch.FloatTensor] = None
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
+ attentions: Optional[tuple[torch.FloatTensor]] = None
+ cross_attentions: Optional[tuple[torch.FloatTensor]] = None
+
+
+def inverse_sigmoid(x, eps=1e-5):
+ x = x.clamp(min=0, max=1)
+ x1 = x.clamp(min=eps)
+ x2 = (1 - x).clamp(min=eps)
+ return torch.log(x1 / x2)
+
+
+def weighting_function(max_num_bins: int, up: torch.Tensor, reg_scale: int) -> torch.Tensor:
+ """
+ Generates the non-uniform Weighting Function W(n) for bounding box regression.
+
+ Args:
+ max_num_bins (int): Max number of the discrete bins.
+ up (Tensor): Controls upper bounds of the sequence,
+ where maximum offset is ±up * H / W.
+ reg_scale (float): Controls the curvature of the Weighting Function.
+ Larger values result in flatter weights near the central axis W(max_num_bins/2)=0
+ and steeper weights at both ends.
+ Returns:
+ Tensor: Sequence of Weighting Function.
+ """
+ upper_bound1 = abs(up[0]) * abs(reg_scale)
+ upper_bound2 = abs(up[0]) * abs(reg_scale) * 2
+ step = (upper_bound1 + 1) ** (2 / (max_num_bins - 2))
+ left_values = [-((step) ** i) + 1 for i in range(max_num_bins // 2 - 1, 0, -1)]
+ right_values = [(step) ** i - 1 for i in range(1, max_num_bins // 2)]
+ values = [-upper_bound2] + left_values + [torch.zeros_like(up[0][None])] + right_values + [upper_bound2]
+ values = torch.cat(values, 0)
+ return values
+
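+
+# Illustrative sketch (not part of the model): inspecting the weighting vector W(n). The values
+# passed below are made-up assumptions; the decoder passes its learned `up` and `reg_scale`
+# buffers (both 1-element tensors) instead.
+def _weighting_function_sketch():
+    project = weighting_function(max_num_bins=32, up=torch.tensor([0.5]), reg_scale=torch.tensor([4.0]))
+    # W(n) has max_num_bins + 1 entries, is exactly zero at the center bin and grows towards both ends
+    assert project.shape == (33,)
+    assert project[16].item() == 0.0
+    return project
+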
+
+def distance2bbox(points, distance: torch.Tensor, reg_scale: float) -> torch.Tensor:
+ """
+ Decodes edge-distances into bounding box coordinates.
+
+ Args:
+ points (`torch.Tensor`):
+ (batch_size, num_boxes, 4) or (num_boxes, 4) format, representing [x_center, y_center, width, height]
+ distance (`torch.Tensor`):
+ (batch_size, num_boxes, 4) or (num_boxes, 4), representing distances from the point to the left, top, right, and bottom boundaries.
+ reg_scale (`float`):
+ Controls the curvature of the Weighting Function.
+ Returns:
+ `torch.Tensor`: Bounding boxes in (batch_size, num_boxes, 4) or (num_boxes, 4) format, representing [x_center, y_center, width, height]
+ """
+ reg_scale = abs(reg_scale)
+ top_left_x = points[..., 0] - (0.5 * reg_scale + distance[..., 0]) * (points[..., 2] / reg_scale)
+ top_left_y = points[..., 1] - (0.5 * reg_scale + distance[..., 1]) * (points[..., 3] / reg_scale)
+ bottom_right_x = points[..., 0] + (0.5 * reg_scale + distance[..., 2]) * (points[..., 2] / reg_scale)
+ bottom_right_y = points[..., 1] + (0.5 * reg_scale + distance[..., 3]) * (points[..., 3] / reg_scale)
+
+ bboxes = torch.stack([top_left_x, top_left_y, bottom_right_x, bottom_right_y], -1)
+
+ return corners_to_center_format(bboxes)
+
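+
+# Illustrative sketch (not part of the model): with all-zero predicted distances, `distance2bbox`
+# returns the reference box unchanged, so the decoded distances act as residual corrections on top
+# of the reference points. The box values are made-up assumptions for demonstration.
+def _distance2bbox_sketch():
+    reference = torch.tensor([[0.5, 0.5, 0.2, 0.4]])  # [x_center, y_center, width, height]
+    decoded = distance2bbox(reference, torch.zeros(1, 4), reg_scale=4.0)
+    assert torch.allclose(decoded, reference)
+    return decoded
+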
+
+class DFineDecoder(DFinePreTrainedModel):
+ """
+ D-FINE Decoder implementing Fine-grained Distribution Refinement (FDR).
+
+ This decoder refines object detection predictions through iterative updates across multiple layers,
+ utilizing attention mechanisms, location quality estimators, and distribution refinement techniques
+ to improve bounding box accuracy and robustness.
+ """
+
+ def __init__(self, config: DFineConfig):
+ super().__init__(config)
+ self.eval_idx = config.eval_idx if config.eval_idx >= 0 else config.decoder_layers + config.eval_idx
+
+ self.dropout = config.dropout
+ self.layers = nn.ModuleList(
+ [DFineDecoderLayer(config) for _ in range(config.decoder_layers)]
+ + [DFineDecoderLayer(config) for _ in range(config.decoder_layers - self.eval_idx - 1)]
+ )
+ self.query_pos_head = DFineMLPPredictionHead(config, 4, 2 * config.d_model, config.d_model, num_layers=2)
+
+ # hack implementation for iterative bounding box refinement and two-stage Deformable DETR
+ self.bbox_embed = None
+ self.class_embed = None
+ self.reg_scale = nn.Parameter(torch.tensor([config.reg_scale]), requires_grad=False)
+ self.max_num_bins = config.max_num_bins
+ self.d_model = config.d_model
+ self.layer_scale = config.layer_scale
+ self.pre_bbox_head = DFineMLP(config.hidden_size, config.hidden_size, 4, 3)
+ self.integral = DFineIntegral(config)
+ self.num_head = config.decoder_attention_heads
+ self.up = nn.Parameter(torch.tensor([config.up]), requires_grad=False)
+ self.lqe_layers = nn.ModuleList([DFineLQE(config) for _ in range(config.decoder_layers)])
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def forward(
+ self,
+ encoder_hidden_states: torch.Tensor,
+ reference_points: torch.Tensor,
+ inputs_embeds: torch.Tensor,
+ spatial_shapes,
+ level_start_index=None,
+ spatial_shapes_list=None,
+ output_hidden_states=None,
+ encoder_attention_mask=None,
+ memory_mask=None,
+ output_attentions=None,
+ return_dict=None,
+ ) -> DFineDecoderOutput:
+ r"""
+ Args:
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`):
+ The query embeddings that are passed into the decoder.
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
+ of the decoder.
+ encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing cross-attention on padding pixel_values of the encoder. Mask values selected
+ in `[0, 1]`:
+ - 1 for pixels that are real (i.e. **not masked**),
+ - 0 for pixels that are padding (i.e. **masked**).
+ position_embeddings (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
+ Position embeddings that are added to the queries and keys in each self-attention layer.
+ reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)` if `as_two_stage` else `(batch_size, num_queries, 2)`, *optional*):
+ Reference point in range `[0, 1]`, top-left (0,0), bottom-right (1, 1), including padding area.
+ spatial_shapes (`torch.FloatTensor` of shape `(num_feature_levels, 2)`):
+ Spatial shapes of the feature maps.
+ level_start_index (`torch.LongTensor` of shape `(num_feature_levels)`, *optional*):
+ Indexes for the start of each feature level. In range `[0, sequence_length]`.
+ valid_ratios (`torch.FloatTensor` of shape `(batch_size, num_feature_levels, 2)`, *optional*):
+ Ratio of valid area in each feature level.
+
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+ for more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if inputs_embeds is not None:
+ hidden_states = inputs_embeds
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
+ intermediate = ()
+ intermediate_reference_points = ()
+ intermediate_logits = ()
+ intermediate_predicted_corners = ()
+ initial_reference_points = ()
+
+ output_detach = pred_corners_undetach = 0
+
+ project = weighting_function(self.max_num_bins, self.up, self.reg_scale)
+ ref_points_detach = F.sigmoid(reference_points)
+
+ for i, decoder_layer in enumerate(self.layers):
+ ref_points_input = ref_points_detach.unsqueeze(2)
+ query_pos_embed = self.query_pos_head(ref_points_detach).clamp(min=-10, max=10)
+
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ output = decoder_layer(
+ hidden_states=hidden_states,
+ position_embeddings=query_pos_embed,
+ reference_points=ref_points_input,
+ spatial_shapes=spatial_shapes,
+ spatial_shapes_list=spatial_shapes_list,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ output_attentions=output_attentions,
+ )
+
+ hidden_states = output[0]
+
+ if i == 0:
+ # Initial bounding box predictions with inverse sigmoid refinement
+ new_reference_points = F.sigmoid(self.pre_bbox_head(output[0]) + inverse_sigmoid(ref_points_detach))
+ ref_points_initial = new_reference_points.detach()
+
+ # Refine bounding box corners using FDR, integrating previous layer's corrections
+ if self.bbox_embed is not None:
+ pred_corners = self.bbox_embed[i](hidden_states + output_detach) + pred_corners_undetach
+ inter_ref_bbox = distance2bbox(
+ ref_points_initial, self.integral(pred_corners, project), self.reg_scale
+ )
+ pred_corners_undetach = pred_corners
+ ref_points_detach = inter_ref_bbox.detach()
+
+ output_detach = hidden_states.detach()
+
+ intermediate += (hidden_states,)
+
+ if self.class_embed is not None and (self.training or i == self.eval_idx):
+ scores = self.class_embed[i](hidden_states)
+ # Add initial logits and reference points with pre-bbox head
+ if i == 0:
+ intermediate_logits += (scores,)
+ intermediate_reference_points += (new_reference_points,)
+ # LQE does not affect the performance here.
+ scores = self.lqe_layers[i](scores, pred_corners)
+ intermediate_logits += (scores,)
+ intermediate_reference_points += (inter_ref_bbox,)
+ initial_reference_points += (ref_points_initial,)
+ intermediate_predicted_corners += (pred_corners,)
+
+ if output_attentions:
+ all_self_attns += (output[1],)
+
+ if encoder_hidden_states is not None:
+ all_cross_attentions += (output[2],)
+
+ # Keep batch_size as first dimension
+ intermediate = torch.stack(intermediate, dim=1)
+ if self.class_embed is not None and self.bbox_embed is not None:
+ intermediate_logits = torch.stack(intermediate_logits, dim=1)
+ intermediate_predicted_corners = torch.stack(intermediate_predicted_corners, dim=1)
+ initial_reference_points = torch.stack(initial_reference_points, dim=1)
+ intermediate_reference_points = torch.stack(intermediate_reference_points, dim=1)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ if not return_dict:
+ return tuple(
+ v
+ for v in [
+ hidden_states,
+ intermediate,
+ intermediate_logits,
+ intermediate_reference_points,
+ intermediate_predicted_corners,
+ initial_reference_points,
+ all_hidden_states,
+ all_self_attns,
+ all_cross_attentions,
+ ]
+ if v is not None
+ )
+
+ return DFineDecoderOutput(
+ last_hidden_state=hidden_states,
+ intermediate_hidden_states=intermediate,
+ intermediate_logits=intermediate_logits,
+ intermediate_reference_points=intermediate_reference_points,
+ intermediate_predicted_corners=intermediate_predicted_corners,
+ initial_reference_points=initial_reference_points,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ cross_attentions=all_cross_attentions,
+ )
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for outputs of the D-FINE encoder-decoder model.
+ """
+)
+class DFineModelOutput(ModelOutput):
+ r"""
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the decoder of the model.
+ intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`):
+ Stacked intermediate hidden states (output of each layer of the decoder).
+ intermediate_logits (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, sequence_length, config.num_labels)`):
+ Stacked intermediate logits (logits of each layer of the decoder).
+ intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
+ Stacked intermediate reference points (reference points of each layer of the decoder).
+ intermediate_predicted_corners (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
+ Stacked intermediate predicted corners (predicted corners of each layer of the decoder).
+ initial_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
+ Initial reference points used for the first decoder layer.
+ init_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
+ Initial reference points sent through the Transformer decoder.
+ enc_topk_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`):
+ Predicted bounding boxes scores where the top `config.two_stage_num_proposals` scoring bounding boxes are
+ picked as region proposals in the encoder stage. Output of bounding box binary classification (i.e.
+ foreground and background).
+ enc_topk_bboxes (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`):
+ Normalized coordinates (sigmoid of the logits) of the top scoring bounding boxes selected in the encoder stage.
+ enc_outputs_class (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
+ Predicted bounding boxes scores where the top `config.two_stage_num_proposals` scoring bounding boxes are
+ picked as region proposals in the first stage. Output of bounding box binary classification (i.e.
+ foreground and background).
+ enc_outputs_coord_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
+ Logits of predicted bounding boxes coordinates in the first stage.
+ denoising_meta_values (`dict`):
+ Extra dictionary for the denoising related values.
+ """
+
+ last_hidden_state: Optional[torch.FloatTensor] = None
+ intermediate_hidden_states: Optional[torch.FloatTensor] = None
+ intermediate_logits: Optional[torch.FloatTensor] = None
+ intermediate_reference_points: Optional[torch.FloatTensor] = None
+ intermediate_predicted_corners: Optional[torch.FloatTensor] = None
+ initial_reference_points: Optional[torch.FloatTensor] = None
+ decoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None
+ decoder_attentions: Optional[tuple[torch.FloatTensor]] = None
+ cross_attentions: Optional[tuple[torch.FloatTensor]] = None
+ encoder_last_hidden_state: Optional[torch.FloatTensor] = None
+ encoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None
+ encoder_attentions: Optional[tuple[torch.FloatTensor]] = None
+ init_reference_points: Optional[torch.FloatTensor] = None
+ enc_topk_logits: Optional[torch.FloatTensor] = None
+ enc_topk_bboxes: Optional[torch.FloatTensor] = None
+ enc_outputs_class: Optional[torch.FloatTensor] = None
+ enc_outputs_coord_logits: Optional[torch.FloatTensor] = None
+ denoising_meta_values: Optional[dict] = None
+
+
+class DFineFrozenBatchNorm2d(nn.Module):
+ """
+ BatchNorm2d where the batch statistics and the affine parameters are fixed.
+
+ Copy-paste from torchvision.misc.ops with added eps before rsqrt, without which models other than
+ torchvision.models.resnet[18,34,50,101] produce nans.
+ """
+
+ def __init__(self, n):
+ super().__init__()
+ self.register_buffer("weight", torch.ones(n))
+ self.register_buffer("bias", torch.zeros(n))
+ self.register_buffer("running_mean", torch.zeros(n))
+ self.register_buffer("running_var", torch.ones(n))
+
+ def _load_from_state_dict(
+ self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+ ):
+ num_batches_tracked_key = prefix + "num_batches_tracked"
+ if num_batches_tracked_key in state_dict:
+ del state_dict[num_batches_tracked_key]
+
+ super()._load_from_state_dict(
+ state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+ )
+
+ def forward(self, x):
+ # move reshapes to the beginning
+ # to make it user-friendly
+ weight = self.weight.reshape(1, -1, 1, 1)
+ bias = self.bias.reshape(1, -1, 1, 1)
+ running_var = self.running_var.reshape(1, -1, 1, 1)
+ running_mean = self.running_mean.reshape(1, -1, 1, 1)
+ epsilon = 1e-5
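+ # folding the normalization into a single affine map:
+ # (x - mean) / sqrt(var + eps) * weight + bias == x * scale + (bias - mean * scale)
+ # with scale = weight / sqrt(var + eps)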
+ scale = weight * (running_var + epsilon).rsqrt()
+ bias = bias - running_mean * scale
+ return x * scale + bias
+
+
+def replace_batch_norm(model):
+ r"""
+ Recursively replace all `torch.nn.BatchNorm2d` with `DFineFrozenBatchNorm2d`.
+
+ Args:
+ model (torch.nn.Module):
+ input model
+ """
+ for name, module in model.named_children():
+ if isinstance(module, nn.BatchNorm2d):
+ new_module = DFineFrozenBatchNorm2d(module.num_features)
+
+ if module.weight.device != torch.device("meta"):
+ new_module.weight.data.copy_(module.weight)
+ new_module.bias.data.copy_(module.bias)
+ new_module.running_mean.data.copy_(module.running_mean)
+ new_module.running_var.data.copy_(module.running_var)
+
+ model._modules[name] = new_module
+
+ if len(list(module.children())) > 0:
+ replace_batch_norm(module)
+
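+
+# Illustrative sketch (not part of the model): freezing the batch norm layers of a small, made-up
+# module. After the swap, forward passes use the stored running statistics and the norm layers no
+# longer expose trainable parameters.
+def _replace_batch_norm_sketch():
+    toy = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3, padding=1), nn.BatchNorm2d(8), nn.ReLU())
+    replace_batch_norm(toy)
+    assert isinstance(toy[1], DFineFrozenBatchNorm2d)
+    assert len(list(toy[1].parameters())) == 0  # buffers only, nothing to train
+    return toy(torch.randn(1, 3, 16, 16))
+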
+
+class DFineConvEncoder(nn.Module):
+ """
+ Convolutional backbone using the modeling_d_fine_resnet.py.
+
+ nn.BatchNorm2d layers are replaced by DFineFrozenBatchNorm2d as defined above.
+ https://github.com/lyuwenyu/RT-DETR/blob/main/DFine_pytorch/src/nn/backbone/presnet.py#L142
+ """
+
+ def __init__(self, config):
+ super().__init__()
+
+ backbone = load_backbone(config)
+
+ if config.freeze_backbone_batch_norms:
+ # replace batch norm by frozen batch norm
+ with torch.no_grad():
+ replace_batch_norm(backbone)
+ self.model = backbone
+ self.intermediate_channel_sizes = self.model.channels
+
+ def forward(self, pixel_values: torch.Tensor, pixel_mask: torch.Tensor):
+ # send pixel_values through the model to get list of feature maps
+ features = self.model(pixel_values).feature_maps
+
+ out = []
+ for feature_map in features:
+ # downsample pixel_mask to match shape of corresponding feature_map
+ mask = nn.functional.interpolate(pixel_mask[None].float(), size=feature_map.shape[-2:]).to(torch.bool)[0]
+ out.append((feature_map, mask))
+ return out
+
+
+def get_contrastive_denoising_training_group(
+ targets,
+ num_classes,
+ num_queries,
+ class_embed,
+ num_denoising_queries=100,
+ label_noise_ratio=0.5,
+ box_noise_scale=1.0,
+):
+ """
+ Creates a contrastive denoising training group using ground-truth samples. It adds noise to labels and boxes.
+
+ Args:
+ targets (`list[dict]`):
+ The target objects, each containing 'class_labels' and 'boxes' for objects in an image.
+ num_classes (`int`):
+ Total number of classes in the dataset.
+ num_queries (`int`):
+ Number of query slots in the transformer.
+ class_embed (`callable`):
+ A function or a model layer to embed class labels.
+ num_denoising_queries (`int`, *optional*, defaults to 100):
+ Number of denoising queries.
+ label_noise_ratio (`float`, *optional*, defaults to 0.5):
+ Ratio of noise applied to labels.
+ box_noise_scale (`float`, *optional*, defaults to 1.0):
+ Scale of noise applied to bounding boxes.
+ Returns:
+ `tuple` comprising various elements:
+ - **input_query_class** (`torch.FloatTensor`) --
+ Class queries with applied label noise.
+ - **input_query_bbox** (`torch.FloatTensor`) --
+ Bounding box queries with applied box noise.
+ - **attn_mask** (`torch.FloatTensor`) --
+ Attention mask for separating denoising and reconstruction queries.
+ - **denoising_meta_values** (`dict`) --
+ Metadata including denoising positive indices, number of groups, and split sizes.
+ """
+
+ if num_denoising_queries <= 0:
+ return None, None, None, None
+
+ num_ground_truths = [len(t["class_labels"]) for t in targets]
+ device = targets[0]["class_labels"].device
+
+ max_gt_num = max(num_ground_truths)
+ if max_gt_num == 0:
+ return None, None, None, None
+
+ num_groups_denoising_queries = num_denoising_queries // max_gt_num
+ num_groups_denoising_queries = 1 if num_groups_denoising_queries == 0 else num_groups_denoising_queries
+ # pad gt to max_num of a batch
+ batch_size = len(num_ground_truths)
+
+ input_query_class = torch.full([batch_size, max_gt_num], num_classes, dtype=torch.int32, device=device)
+ input_query_bbox = torch.zeros([batch_size, max_gt_num, 4], device=device)
+ pad_gt_mask = torch.zeros([batch_size, max_gt_num], dtype=torch.bool, device=device)
+
+ for i in range(batch_size):
+ num_gt = num_ground_truths[i]
+ if num_gt > 0:
+ input_query_class[i, :num_gt] = targets[i]["class_labels"]
+ input_query_bbox[i, :num_gt] = targets[i]["boxes"]
+ pad_gt_mask[i, :num_gt] = 1
+ # each group has positive and negative queries.
+ input_query_class = input_query_class.tile([1, 2 * num_groups_denoising_queries])
+ input_query_bbox = input_query_bbox.tile([1, 2 * num_groups_denoising_queries, 1])
+ pad_gt_mask = pad_gt_mask.tile([1, 2 * num_groups_denoising_queries])
+ # positive and negative mask
+ negative_gt_mask = torch.zeros([batch_size, max_gt_num * 2, 1], device=device)
+ negative_gt_mask[:, max_gt_num:] = 1
+ negative_gt_mask = negative_gt_mask.tile([1, num_groups_denoising_queries, 1])
+ positive_gt_mask = 1 - negative_gt_mask
+ # contrastive denoising training positive index
+ positive_gt_mask = positive_gt_mask.squeeze(-1) * pad_gt_mask
+ denoise_positive_idx = torch.nonzero(positive_gt_mask)[:, 1]
+ denoise_positive_idx = torch.split(
+ denoise_positive_idx, [n * num_groups_denoising_queries for n in num_ground_truths]
+ )
+ # total denoising queries
+ num_denoising_queries = torch_int(max_gt_num * 2 * num_groups_denoising_queries)
+
+ if label_noise_ratio > 0:
+ mask = torch.rand_like(input_query_class, dtype=torch.float) < (label_noise_ratio * 0.5)
+ # randomly replace a fraction of the labels with a random class
+ new_label = torch.randint_like(mask, 0, num_classes, dtype=input_query_class.dtype)
+ input_query_class = torch.where(mask & pad_gt_mask, new_label, input_query_class)
+
+ if box_noise_scale > 0:
+ known_bbox = center_to_corners_format(input_query_bbox)
+ diff = torch.tile(input_query_bbox[..., 2:] * 0.5, [1, 1, 2]) * box_noise_scale
+ rand_sign = torch.randint_like(input_query_bbox, 0, 2) * 2.0 - 1.0
+ rand_part = torch.rand_like(input_query_bbox)
+ rand_part = (rand_part + 1.0) * negative_gt_mask + rand_part * (1 - negative_gt_mask)
+ rand_part *= rand_sign
+ known_bbox += rand_part * diff
+ known_bbox.clip_(min=0.0, max=1.0)
+ input_query_bbox = corners_to_center_format(known_bbox)
+ input_query_bbox = inverse_sigmoid(input_query_bbox)
+
+ input_query_class = class_embed(input_query_class)
+
+ target_size = num_denoising_queries + num_queries
+ attn_mask = torch.full([target_size, target_size], 0, dtype=torch.float, device=device)
+ # matching queries cannot see the denoising (reconstruction) queries
+ attn_mask[num_denoising_queries:, :num_denoising_queries] = -torch.inf
+
+ # denoising groups cannot see each other
+ for i in range(num_groups_denoising_queries):
+ idx_block_start = max_gt_num * 2 * i
+ idx_block_end = max_gt_num * 2 * (i + 1)
+ attn_mask[idx_block_start:idx_block_end, :idx_block_start] = -torch.inf
+ attn_mask[idx_block_start:idx_block_end, idx_block_end:num_denoising_queries] = -torch.inf
+
+ denoising_meta_values = {
+ "dn_positive_idx": denoise_positive_idx,
+ "dn_num_group": num_groups_denoising_queries,
+ "dn_num_split": [num_denoising_queries, num_queries],
+ }
+
+ return input_query_class, input_query_bbox, attn_mask, denoising_meta_values
+
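+
+# Illustrative sketch (not part of the model): building a contrastive denoising group from two
+# made-up images with one and two ground-truth boxes. All sizes are assumptions chosen for
+# demonstration; the model uses `config.num_denoising`, `config.num_labels` and its own
+# `denoising_class_embed`.
+def _denoising_group_sketch():
+    targets = [
+        {"class_labels": torch.tensor([1]), "boxes": torch.tensor([[0.5, 0.5, 0.2, 0.2]])},
+        {"class_labels": torch.tensor([0, 2]), "boxes": torch.tensor([[0.3, 0.3, 0.1, 0.1], [0.7, 0.6, 0.2, 0.3]])},
+    ]
+    class_embed = nn.Embedding(3 + 1, 16, padding_idx=3)  # num_classes + 1, like the model
+    query_class, query_bbox, attn_mask, meta = get_contrastive_denoising_training_group(
+        targets, num_classes=3, num_queries=10, class_embed=class_embed, num_denoising_queries=8
+    )
+    # max_gt_num = 2 -> 8 // 2 = 4 groups, each holding 2 positive + 2 negative queries = 16 in total
+    assert query_class.shape == (2, 16, 16) and query_bbox.shape == (2, 16, 4)
+    assert attn_mask.shape == (16 + 10, 16 + 10)
+    return meta
+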
+
+@auto_docstring(
+ custom_intro="""
+ D-FINE Model (consisting of a backbone and encoder-decoder) outputting raw hidden states without any head on top.
+ """
+)
+class DFineModel(DFinePreTrainedModel):
+ def __init__(self, config: DFineConfig):
+ super().__init__(config)
+
+ # Create backbone
+ self.backbone = DFineConvEncoder(config)
+ intermediate_channel_sizes = self.backbone.intermediate_channel_sizes
+ num_backbone_outs = len(config.decoder_in_channels)
+ encoder_input_proj_list = []
+ for _ in range(num_backbone_outs):
+ in_channels = intermediate_channel_sizes[_]
+ encoder_input_proj_list.append(
+ nn.Sequential(
+ nn.Conv2d(in_channels, config.encoder_hidden_dim, kernel_size=1, bias=False),
+ nn.BatchNorm2d(config.encoder_hidden_dim),
+ )
+ )
+ self.encoder_input_proj = nn.ModuleList(encoder_input_proj_list)
+ self.encoder = DFineHybridEncoder(config=config)
+
+ # denoising part
+ if config.num_denoising > 0:
+ self.denoising_class_embed = nn.Embedding(
+ config.num_labels + 1, config.d_model, padding_idx=config.num_labels
+ )
+
+ # decoder embedding
+ if config.learn_initial_query:
+ self.weight_embedding = nn.Embedding(config.num_queries, config.d_model)
+
+ # encoder head
+ self.enc_output = nn.Sequential(
+ nn.Linear(config.d_model, config.d_model),
+ nn.LayerNorm(config.d_model, eps=config.layer_norm_eps),
+ )
+ self.enc_score_head = nn.Linear(config.d_model, config.num_labels)
+ self.enc_bbox_head = DFineMLPPredictionHead(config, config.d_model, config.d_model, 4, num_layers=3)
+
+ # init encoder output anchors and valid_mask
+ if config.anchor_image_size:
+ self.anchors, self.valid_mask = self.generate_anchors(dtype=self.dtype)
+ num_backbone_outs = len(config.decoder_in_channels)
+ decoder_input_proj_list = []
+ for _ in range(num_backbone_outs):
+ in_channels = config.decoder_in_channels[_]
+ decoder_input_proj_list.append(
+ nn.Sequential(
+ nn.Conv2d(in_channels, config.d_model, kernel_size=1, bias=False),
+ nn.BatchNorm2d(config.d_model, config.batch_norm_eps),
+ )
+ )
+ for _ in range(config.num_feature_levels - num_backbone_outs):
+ decoder_input_proj_list.append(
+ nn.Sequential(
+ nn.Conv2d(in_channels, config.d_model, kernel_size=3, stride=2, padding=1, bias=False),
+ nn.BatchNorm2d(config.d_model, config.batch_norm_eps),
+ )
+ )
+ in_channels = config.d_model
+ self.decoder = DFineDecoder(config)
+ decoder_input_proj = []
+ in_channels = config.decoder_in_channels[-1]
+ for _ in range(num_backbone_outs):
+ if config.hidden_size == config.decoder_in_channels[-1]:
+ decoder_input_proj.append(nn.Identity())
+ else:
+ conv = nn.Conv2d(in_channels, config.d_model, kernel_size=1, bias=False)
+ batchnorm = nn.BatchNorm2d(config.d_model, config.batch_norm_eps)
+ decoder_input_proj.append(nn.Sequential(conv, batchnorm))
+ for _ in range(config.num_feature_levels - num_backbone_outs):
+ if config.hidden_size == config.decoder_in_channels[-1]:
+ decoder_input_proj.append(nn.Identity())
+ else:
+ conv = nn.Conv2d(in_channels, config.d_model, kernel_size=3, stride=2, padding=1, bias=False)
+ batchnorm = nn.BatchNorm2d(config.d_model, config.batch_norm_eps)
+ decoder_input_proj.append(nn.Sequential(conv, batchnorm))
+ self.decoder_input_proj = nn.ModuleList(decoder_input_proj)
+
+ self.post_init()
+
+ def get_encoder(self):
+ return self.encoder
+
+ def freeze_backbone(self):
+ for param in self.backbone.parameters():
+ param.requires_grad_(False)
+
+ def unfreeze_backbone(self):
+ for param in self.backbone.parameters():
+ param.requires_grad_(True)
+
+ @compile_compatible_method_lru_cache(maxsize=32)
+ def generate_anchors(self, spatial_shapes=None, grid_size=0.05, device="cpu", dtype=torch.float32):
+ if spatial_shapes is None:
+ spatial_shapes = [
+ [int(self.config.anchor_image_size[0] / s), int(self.config.anchor_image_size[1] / s)]
+ for s in self.config.feat_strides
+ ]
+ anchors = []
+ for level, (height, width) in enumerate(spatial_shapes):
+ grid_y, grid_x = torch.meshgrid(
+ torch.arange(end=height, device=device).to(dtype),
+ torch.arange(end=width, device=device).to(dtype),
+ indexing="ij",
+ )
+ grid_xy = torch.stack([grid_x, grid_y], -1)
+ grid_xy = grid_xy.unsqueeze(0) + 0.5
+ grid_xy[..., 0] /= width
+ grid_xy[..., 1] /= height
+ wh = torch.ones_like(grid_xy) * grid_size * (2.0**level)
+ anchors.append(torch.concat([grid_xy, wh], -1).reshape(-1, height * width, 4))
+ # define the valid range for anchor coordinates
+ eps = 1e-2
+ anchors = torch.concat(anchors, 1)
+ valid_mask = ((anchors > eps) * (anchors < 1 - eps)).all(-1, keepdim=True)
+ anchors = torch.log(anchors / (1 - anchors))
+ anchors = torch.where(valid_mask, anchors, torch.tensor(torch.finfo(dtype).max, dtype=dtype, device=device))
+
+ return anchors, valid_mask
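+
+ # Worked example for `generate_anchors` above (comments only): with `anchor_image_size=(640, 640)`
+ # and `feat_strides=[8, 16, 32]`, the default spatial shapes are [[80, 80], [40, 40], [20, 20]],
+ # giving 80*80 + 40*40 + 20*20 = 8400 anchors of the form [x_center, y_center, width, height] with
+ # widths/heights of 0.05, 0.1 and 0.2 per level. Anchors are mapped to logit space with
+ # log(a / (1 - a)), and anchors not strictly inside (eps, 1 - eps) are replaced by the dtype's
+ # maximum value via `valid_mask`.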
+
+ @auto_docstring
+ def forward(
+ self,
+ pixel_values: torch.FloatTensor,
+ pixel_mask: Optional[torch.LongTensor] = None,
+ encoder_outputs: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[list[dict]] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple[torch.FloatTensor], DFineModelOutput]:
+ r"""
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing the flattened feature map (output of the backbone + projection layer), you
+ can choose to directly pass a flattened representation of an image.
+ decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
+ Optionally, instead of initializing the queries with a tensor of zeros, you can choose to directly pass an
+ embedded representation.
+ labels (`list[Dict]` of len `(batch_size,)`, *optional*):
+ Labels for computing the bipartite matching loss. List of dicts, each dictionary containing at least the
+ following 2 keys: 'class_labels' and 'boxes' (the class labels and bounding boxes of an image in the batch
+ respectively). The class labels themselves should be a `torch.LongTensor` of len `(number of bounding boxes
+ in the image,)` and the boxes a `torch.FloatTensor` of shape `(number of bounding boxes in the image, 4)`.
+
+ Examples:
+
+ ```python
+ >>> from transformers import AutoImageProcessor, DFineModel
+ >>> from PIL import Image
+ >>> import requests
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> image_processor = AutoImageProcessor.from_pretrained("PekingU/DFine_r50vd")
+ >>> model = DFineModel.from_pretrained("PekingU/DFine_r50vd")
+
+ >>> inputs = image_processor(images=image, return_tensors="pt")
+
+ >>> outputs = model(**inputs)
+
+ >>> last_hidden_states = outputs.last_hidden_state
+ >>> list(last_hidden_states.shape)
+ [1, 300, 256]
+ ```"""
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ batch_size, num_channels, height, width = pixel_values.shape
+ device = pixel_values.device
+
+ if pixel_mask is None:
+ pixel_mask = torch.ones(((batch_size, height, width)), device=device)
+
+ features = self.backbone(pixel_values, pixel_mask)
+
+ proj_feats = [self.encoder_input_proj[level](source) for level, (source, mask) in enumerate(features)]
+
+ if encoder_outputs is None:
+ encoder_outputs = self.encoder(
+ proj_feats,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
+ elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
+ encoder_outputs = BaseModelOutput(
+ last_hidden_state=encoder_outputs[0],
+ hidden_states=encoder_outputs[1] if output_hidden_states else None,
+ attentions=encoder_outputs[2]
+ if len(encoder_outputs) > 2
+ else encoder_outputs[1]
+ if output_attentions
+ else None,
+ )
+
+ # Equivalent to def _get_encoder_input
+ # https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/DFine_pytorch/src/zoo/DFine/DFine_decoder.py#L412
+ sources = []
+ for level, source in enumerate(encoder_outputs[0]):
+ sources.append(self.decoder_input_proj[level](source))
+
+ # Lowest resolution feature maps are obtained via 3x3 stride 2 convolutions on the final stage
+ if self.config.num_feature_levels > len(sources):
+ _len_sources = len(sources)
+ sources.append(self.decoder_input_proj[_len_sources](encoder_outputs[0])[-1])
+ for i in range(_len_sources + 1, self.config.num_feature_levels):
+ sources.append(self.decoder_input_proj[i](encoder_outputs[0][-1]))
+
+ # Prepare encoder inputs (by flattening)
+ source_flatten = []
+ spatial_shapes_list = []
+ spatial_shapes = torch.empty((len(sources), 2), device=device, dtype=torch.long)
+ for level, source in enumerate(sources):
+ height, width = source.shape[-2:]
+ spatial_shapes[level, 0] = height
+ spatial_shapes[level, 1] = width
+ spatial_shapes_list.append((height, width))
+ source = source.flatten(2).transpose(1, 2)
+ source_flatten.append(source)
+ source_flatten = torch.cat(source_flatten, 1)
+ level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
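+ # for example, with spatial shapes [[80, 80], [40, 40], [20, 20]] the flattened levels start at
+ # indices [0, 6400, 8000] in `source_flatten`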
+
+ # prepare denoising training
+ if self.training and self.config.num_denoising > 0 and labels is not None:
+ (
+ denoising_class,
+ denoising_bbox_unact,
+ attention_mask,
+ denoising_meta_values,
+ ) = get_contrastive_denoising_training_group(
+ targets=labels,
+ num_classes=self.config.num_labels,
+ num_queries=self.config.num_queries,
+ class_embed=self.denoising_class_embed,
+ num_denoising_queries=self.config.num_denoising,
+ label_noise_ratio=self.config.label_noise_ratio,
+ box_noise_scale=self.config.box_noise_scale,
+ )
+ else:
+ denoising_class, denoising_bbox_unact, attention_mask, denoising_meta_values = None, None, None, None
+
+ batch_size = len(source_flatten)
+ device = source_flatten.device
+ dtype = source_flatten.dtype
+
+ # prepare input for decoder
+ if self.training or self.config.anchor_image_size is None:
+ # Pass spatial_shapes as tuple to make it hashable and make sure
+ # lru_cache is working for generate_anchors()
+ spatial_shapes_tuple = tuple(spatial_shapes_list)
+ anchors, valid_mask = self.generate_anchors(spatial_shapes_tuple, device=device, dtype=dtype)
+ else:
+ anchors, valid_mask = self.anchors, self.valid_mask
+ anchors, valid_mask = anchors.to(device, dtype), valid_mask.to(device, dtype)
+
+ # use the valid_mask to selectively retain values in the feature map where the mask is `True`
+ memory = valid_mask.to(source_flatten.dtype) * source_flatten
+
+ output_memory = self.enc_output(memory)
+
+ enc_outputs_class = self.enc_score_head(output_memory)
+ enc_outputs_coord_logits = self.enc_bbox_head(output_memory) + anchors
+
+ _, topk_ind = torch.topk(enc_outputs_class.max(-1).values, self.config.num_queries, dim=1)
+
+ reference_points_unact = enc_outputs_coord_logits.gather(
+ dim=1, index=topk_ind.unsqueeze(-1).repeat(1, 1, enc_outputs_coord_logits.shape[-1])
+ )
+
+ enc_topk_bboxes = F.sigmoid(reference_points_unact)
+ if denoising_bbox_unact is not None:
+ reference_points_unact = torch.concat([denoising_bbox_unact, reference_points_unact], 1)
+
+ enc_topk_logits = enc_outputs_class.gather(
+ dim=1, index=topk_ind.unsqueeze(-1).repeat(1, 1, enc_outputs_class.shape[-1])
+ )
+
+ # extract region features
+ if self.config.learn_initial_query:
+ target = self.weight_embedding.tile([batch_size, 1, 1])
+ else:
+ target = output_memory.gather(dim=1, index=topk_ind.unsqueeze(-1).repeat(1, 1, output_memory.shape[-1]))
+ target = target.detach()
+
+ if denoising_class is not None:
+ target = torch.concat([denoising_class, target], 1)
+
+ init_reference_points = reference_points_unact.detach()
+
+ # decoder
+ decoder_outputs = self.decoder(
+ inputs_embeds=target,
+ encoder_hidden_states=source_flatten,
+ encoder_attention_mask=attention_mask,
+ reference_points=init_reference_points,
+ spatial_shapes=spatial_shapes,
+ spatial_shapes_list=spatial_shapes_list,
+ level_start_index=level_start_index,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ if not return_dict:
+ enc_outputs = tuple(
+ value
+ for value in [enc_topk_logits, enc_topk_bboxes, enc_outputs_class, enc_outputs_coord_logits]
+ if value is not None
+ )
+ dn_outputs = tuple(value if value is not None else None for value in [denoising_meta_values])
+ tuple_outputs = decoder_outputs + encoder_outputs + (init_reference_points,) + enc_outputs + dn_outputs
+
+ return tuple_outputs
+
+ return DFineModelOutput(
+ last_hidden_state=decoder_outputs.last_hidden_state,
+ intermediate_hidden_states=decoder_outputs.intermediate_hidden_states,
+ intermediate_logits=decoder_outputs.intermediate_logits,
+ intermediate_reference_points=decoder_outputs.intermediate_reference_points,
+ intermediate_predicted_corners=decoder_outputs.intermediate_predicted_corners,
+ initial_reference_points=decoder_outputs.initial_reference_points,
+ decoder_hidden_states=decoder_outputs.hidden_states,
+ decoder_attentions=decoder_outputs.attentions,
+ cross_attentions=decoder_outputs.cross_attentions,
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
+ encoder_hidden_states=encoder_outputs.hidden_states,
+ encoder_attentions=encoder_outputs.attentions,
+ init_reference_points=init_reference_points,
+ enc_topk_logits=enc_topk_logits,
+ enc_topk_bboxes=enc_topk_bboxes,
+ enc_outputs_class=enc_outputs_class,
+ enc_outputs_coord_logits=enc_outputs_coord_logits,
+ denoising_meta_values=denoising_meta_values,
+ )
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Output type of [`DFineForObjectDetection`].
+ """
+)
+class DFineObjectDetectionOutput(ModelOutput):
+ r"""
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` are provided):
+ Total loss as a linear combination of a negative log-likelihood (cross-entropy) for class prediction and a
+ bounding box loss. The latter is defined as a linear combination of the L1 loss and the generalized
+ scale-invariant IoU loss.
+ loss_dict (`Dict`, *optional*):
+ A dictionary containing the individual losses. Useful for logging.
+ logits (`torch.FloatTensor` of shape `(batch_size, num_queries, num_classes + 1)`):
+ Classification logits (including no-object) for all queries.
+ pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
+ Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These
+ values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding
+ possible padding). You can use [`~DFineImageProcessor.post_process_object_detection`] to retrieve the
+ unnormalized (absolute) bounding boxes.
+ auxiliary_outputs (`list[Dict]`, *optional*):
+ Optional, only returned when auxiliary losses are activated (i.e. `config.auxiliary_loss` is set to `True`)
+ and labels are provided. It is a list of dictionaries containing the two above keys (`logits` and
+ `pred_boxes`) for each decoder layer.
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the decoder of the model.
+ intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`):
+ Stacked intermediate hidden states (output of each layer of the decoder).
+ intermediate_logits (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, config.num_labels)`):
+ Stacked intermediate logits (logits of each layer of the decoder).
+ intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
+ Stacked intermediate reference points (reference points of each layer of the decoder).
+ intermediate_predicted_corners (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
+ Stacked intermediate predicted corners (predicted corners of each layer of the decoder).
+ initial_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
+ Stacked initial reference points (initial reference points of each layer of the decoder).
+ init_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
+ Initial reference points sent through the Transformer decoder.
+ enc_topk_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
+ Classification logits of the top scoring bounding boxes picked as region proposals in the encoder stage.
+ enc_topk_bboxes (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
+ Normalized coordinates (sigmoid of the logits) of the top scoring bounding boxes selected in the encoder stage.
+ enc_outputs_class (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
+ Predicted bounding boxes scores where the top `config.two_stage_num_proposals` scoring bounding boxes are
+ picked as region proposals in the first stage. Output of bounding box binary classification (i.e.
+ foreground and background).
+ enc_outputs_coord_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
+ Logits of predicted bounding boxes coordinates in the first stage.
+ denoising_meta_values (`dict`):
+ Extra dictionary for the denoising related values.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ loss_dict: Optional[dict] = None
+ logits: Optional[torch.FloatTensor] = None
+ pred_boxes: Optional[torch.FloatTensor] = None
+ auxiliary_outputs: Optional[list[dict]] = None
+ last_hidden_state: Optional[torch.FloatTensor] = None
+ intermediate_hidden_states: Optional[torch.FloatTensor] = None
+ intermediate_logits: Optional[torch.FloatTensor] = None
+ intermediate_reference_points: Optional[torch.FloatTensor] = None
+ intermediate_predicted_corners: Optional[torch.FloatTensor] = None
+ initial_reference_points: Optional[torch.FloatTensor] = None
+ decoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None
+ decoder_attentions: Optional[tuple[torch.FloatTensor]] = None
+ cross_attentions: Optional[tuple[torch.FloatTensor]] = None
+ encoder_last_hidden_state: Optional[torch.FloatTensor] = None
+ encoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None
+ encoder_attentions: Optional[tuple[torch.FloatTensor]] = None
+ init_reference_points: Optional[tuple[torch.FloatTensor]] = None
+ enc_topk_logits: Optional[torch.FloatTensor] = None
+ enc_topk_bboxes: Optional[torch.FloatTensor] = None
+ enc_outputs_class: Optional[torch.FloatTensor] = None
+ enc_outputs_coord_logits: Optional[torch.FloatTensor] = None
+ denoising_meta_values: Optional[dict] = None
+
+
+@auto_docstring(
+ custom_intro="""
+ D-FINE Model (consisting of a backbone and encoder-decoder) outputting bounding boxes and logits to be further
+ decoded into scores and classes.
+ """
+)
+class DFineForObjectDetection(DFinePreTrainedModel):
+ # When using clones, all layers > 0 will be clones, but layer 0 *is* required
+ _tied_weights_keys = ["bbox_embed", "class_embed"]
+ # We can't initialize the model on meta device as some weights are modified during the initialization
+ _no_split_modules = None
+
+ def __init__(self, config: DFineConfig):
+ super().__init__(config)
+
+ # D-FINE encoder-decoder model
+ self.eval_idx = config.eval_idx if config.eval_idx >= 0 else config.decoder_layers + config.eval_idx
+ self.model = DFineModel(config)
+ scaled_dim = round(config.layer_scale * config.hidden_size)
+ num_pred = config.decoder_layers
+ self.class_embed = nn.ModuleList([nn.Linear(config.d_model, config.num_labels) for _ in range(num_pred)])
+ self.bbox_embed = nn.ModuleList(
+ [
+ DFineMLP(config.hidden_size, config.hidden_size, 4 * (config.max_num_bins + 1), 3)
+ for _ in range(self.eval_idx + 1)
+ ]
+ + [
+ DFineMLP(scaled_dim, scaled_dim, 4 * (config.max_num_bins + 1), 3)
+ for _ in range(config.decoder_layers - self.eval_idx - 1)
+ ]
+ )
+
+ # at this point self.model.decoder.bbox_embed and class_embed are still None; tie them to the heads defined above
+ self.model.decoder.class_embed = self.class_embed
+ self.model.decoder.bbox_embed = self.bbox_embed
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @torch.jit.unused
+ def _set_aux_loss(self, outputs_class, outputs_coord):
+ # this is a workaround to make torchscript happy, as torchscript
+ # doesn't support dictionary with non-homogeneous values, such
+ # as a dict having both a Tensor and a list.
+ return [{"logits": a, "pred_boxes": b} for a, b in zip(outputs_class, outputs_coord)]
+
+ @auto_docstring
+ def forward(
+ self,
+ pixel_values: torch.FloatTensor,
+ pixel_mask: Optional[torch.LongTensor] = None,
+ encoder_outputs: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[list[dict]] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ **kwargs,
+ ) -> Union[tuple[torch.FloatTensor], DFineObjectDetectionOutput]:
+ r"""
+ Example:
+
+ ```python
+ >>> import torch
+ >>> from transformers.image_utils import load_image
+ >>> from transformers import AutoImageProcessor, DFineForObjectDetection
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = load_image(url)
+
+ >>> image_processor = AutoImageProcessor.from_pretrained("ustc-community/dfine-xlarge-coco")
+ >>> model = DFineForObjectDetection.from_pretrained("ustc-community/dfine-xlarge-coco")
+
+ >>> # prepare image for the model
+ >>> inputs = image_processor(images=image, return_tensors="pt")
+
+ >>> # forward pass
+ >>> outputs = model(**inputs)
+
+ >>> logits = outputs.logits
+ >>> list(logits.shape)
+ [1, 300, 80]
+
+ >>> boxes = outputs.pred_boxes
+ >>> list(boxes.shape)
+ [1, 300, 4]
+
+ >>> # convert outputs (bounding boxes and class logits) to Pascal VOC format (xmin, ymin, xmax, ymax)
+ >>> target_sizes = torch.tensor([image.size[::-1]])
+ >>> results = image_processor.post_process_object_detection(outputs, threshold=0.9, target_sizes=target_sizes)
+ >>> result = results[0] # first image in batch
+
+ >>> for score, label, box in zip(result["scores"], result["labels"], result["boxes"]):
+ ... box = [round(i, 2) for i in box.tolist()]
+ ... print(
+ ... f"Detected {model.config.id2label[label.item()]} with confidence "
+ ... f"{round(score.item(), 3)} at location {box}"
+ ... )
+ Detected cat with confidence 0.958 at location [344.49, 23.4, 639.84, 374.27]
+ Detected cat with confidence 0.956 at location [11.71, 53.52, 316.64, 472.33]
+ Detected remote with confidence 0.947 at location [40.46, 73.7, 175.62, 117.57]
+ Detected sofa with confidence 0.918 at location [0.59, 1.88, 640.25, 474.74]
+ ```
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.model(
+ pixel_values,
+ pixel_mask=pixel_mask,
+ encoder_outputs=encoder_outputs,
+ inputs_embeds=inputs_embeds,
+ decoder_inputs_embeds=decoder_inputs_embeds,
+ labels=labels,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ denoising_meta_values = (
+ outputs.denoising_meta_values if return_dict else outputs[-1] if self.training else None
+ )
+
+ outputs_class = outputs.intermediate_logits if return_dict else outputs[2]
+ outputs_coord = outputs.intermediate_reference_points if return_dict else outputs[3]
+ predicted_corners = outputs.intermediate_predicted_corners if return_dict else outputs[4]
+ initial_reference_points = outputs.initial_reference_points if return_dict else outputs[5]
+
+ logits = outputs_class[:, -1]
+ pred_boxes = outputs_coord[:, -1]
+
+ loss, loss_dict, auxiliary_outputs, enc_topk_logits, enc_topk_bboxes = None, None, None, None, None
+ if labels is not None:
+ enc_topk_logits = outputs.enc_topk_logits if return_dict else outputs[-5]
+ enc_topk_bboxes = outputs.enc_topk_bboxes if return_dict else outputs[-4]
+ loss, loss_dict, auxiliary_outputs = self.loss_function(
+ logits,
+ labels,
+ self.device,
+ pred_boxes,
+ self.config,
+ outputs_class,
+ outputs_coord,
+ enc_topk_logits=enc_topk_logits,
+ enc_topk_bboxes=enc_topk_bboxes,
+ denoising_meta_values=denoising_meta_values,
+ predicted_corners=predicted_corners,
+ initial_reference_points=initial_reference_points,
+ **kwargs,
+ )
+
+ if not return_dict:
+ if auxiliary_outputs is not None:
+ output = (logits, pred_boxes) + (auxiliary_outputs,) + outputs
+ else:
+ output = (logits, pred_boxes) + outputs
+ return ((loss, loss_dict) + output) if loss is not None else output
+
+ return DFineObjectDetectionOutput(
+ loss=loss,
+ loss_dict=loss_dict,
+ logits=logits,
+ pred_boxes=pred_boxes,
+ auxiliary_outputs=auxiliary_outputs,
+ last_hidden_state=outputs.last_hidden_state,
+ intermediate_hidden_states=outputs.intermediate_hidden_states,
+ intermediate_logits=outputs.intermediate_logits,
+ intermediate_reference_points=outputs.intermediate_reference_points,
+ intermediate_predicted_corners=outputs.intermediate_predicted_corners,
+ initial_reference_points=outputs.initial_reference_points,
+ decoder_hidden_states=outputs.decoder_hidden_states,
+ decoder_attentions=outputs.decoder_attentions,
+ cross_attentions=outputs.cross_attentions,
+ encoder_last_hidden_state=outputs.encoder_last_hidden_state,
+ encoder_hidden_states=outputs.encoder_hidden_states,
+ encoder_attentions=outputs.encoder_attentions,
+ init_reference_points=outputs.init_reference_points,
+ enc_topk_logits=outputs.enc_topk_logits,
+ enc_topk_bboxes=outputs.enc_topk_bboxes,
+ enc_outputs_class=outputs.enc_outputs_class,
+ enc_outputs_coord_logits=outputs.enc_outputs_coord_logits,
+ denoising_meta_values=outputs.denoising_meta_values,
+ )
+
+
+# taken from https://github.com/facebookresearch/detr/blob/master/models/detr.py
+class DFineMLPPredictionHead(nn.Module):
+ """
+ Very simple multi-layer perceptron (MLP, also called FFN), used to predict the normalized center coordinates,
+ height and width of a bounding box w.r.t. an image.
+
+ Copied from https://github.com/facebookresearch/detr/blob/master/models/detr.py
+ Origin from https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/DFine_paddle/ppdet/modeling/transformers/utils.py#L453
+
+ """
+
+ def __init__(self, config, input_dim, d_model, output_dim, num_layers):
+ super().__init__()
+ self.num_layers = num_layers
+ h = [d_model] * (num_layers - 1)
+ self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
+
+ def forward(self, x):
+ for i, layer in enumerate(self.layers):
+ x = nn.functional.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
+ return x
+
+
+class DFineMLP(nn.Module):
+ def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int, act: str = "relu"):
+ super().__init__()
+ self.num_layers = num_layers
+ hidden_dims = [hidden_dim] * (num_layers - 1)
+ input_dims = [input_dim] + hidden_dims
+ output_dims = hidden_dims + [output_dim]
+ self.layers = nn.ModuleList(nn.Linear(in_dim, out_dim) for in_dim, out_dim in zip(input_dims, output_dims))
+ self.act = ACT2CLS[act]()
+
+ def forward(self, stat_features: torch.Tensor) -> torch.Tensor:
+ for i, layer in enumerate(self.layers):
+ stat_features = self.act(layer(stat_features)) if i < self.num_layers - 1 else layer(stat_features)
+ return stat_features
+
+
+class DFineLQE(nn.Module):
+ def __init__(self, config: DFineConfig):
+ super().__init__()
+ self.top_prob_values = config.top_prob_values
+ self.max_num_bins = config.max_num_bins
+ self.reg_conf = DFineMLP(4 * (self.top_prob_values + 1), config.lqe_hidden_dim, 1, config.lqe_layers)
+
+ def forward(self, scores: torch.Tensor, pred_corners: torch.Tensor) -> torch.Tensor:
+ batch_size, length, _ = pred_corners.size()
+ prob = F.softmax(pred_corners.reshape(batch_size, length, 4, self.max_num_bins + 1), dim=-1)
+ prob_topk, _ = prob.topk(self.top_prob_values, dim=-1)
+ stat = torch.cat([prob_topk, prob_topk.mean(dim=-1, keepdim=True)], dim=-1)
+ quality_score = self.reg_conf(stat.reshape(batch_size, length, -1))
+ scores = scores + quality_score
+ return scores
+
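+
+# Illustrative sketch (not part of the model): the statistics fed to the LQE head. For each of the
+# four box edges, the top-k bin probabilities of the corner distribution and their mean are
+# concatenated, so the head sees 4 * (top_prob_values + 1) features per query. The sizes below are
+# made-up assumptions for demonstration.
+def _lqe_statistics_sketch():
+    max_num_bins, top_prob_values = 32, 4
+    pred_corners = torch.randn(2, 5, 4 * (max_num_bins + 1))  # (batch_size, num_queries, 4 * bins)
+    prob = F.softmax(pred_corners.reshape(2, 5, 4, max_num_bins + 1), dim=-1)
+    prob_topk, _ = prob.topk(top_prob_values, dim=-1)
+    stat = torch.cat([prob_topk, prob_topk.mean(dim=-1, keepdim=True)], dim=-1)
+    assert stat.reshape(2, 5, -1).shape == (2, 5, 4 * (top_prob_values + 1))
+    return stat
+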
+
+class DFineConvNormLayer(nn.Module):
+ def __init__(
+ self,
+ config: DFineConfig,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: int,
+ stride: int,
+ groups: int = 1,
+ padding: Optional[int] = None,
+ activation: Optional[str] = None,
+ ):
+ super().__init__()
+ self.conv = nn.Conv2d(
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride,
+ groups=groups,
+ padding=(kernel_size - 1) // 2 if padding is None else padding,
+ bias=False,
+ )
+ self.norm = nn.BatchNorm2d(out_channels, config.batch_norm_eps)
+ self.activation = nn.Identity() if activation is None else ACT2CLS[activation]()
+
+ def forward(self, hidden_state):
+ hidden_state = self.conv(hidden_state)
+ hidden_state = self.norm(hidden_state)
+ hidden_state = self.activation(hidden_state)
+ return hidden_state
+
+
+class DFineRepVggBlock(nn.Module):
+ """
+ RepVGG architecture block introduced by the work "RepVGG: Making VGG-style ConvNets Great Again".
+ """
+
+ def __init__(self, config: DFineConfig, in_channels: int, out_channels: int):
+ super().__init__()
+
+ activation = config.activation_function
+ hidden_channels = in_channels
+ self.conv1 = DFineConvNormLayer(config, hidden_channels, out_channels, 3, 1, padding=1)
+ self.conv2 = DFineConvNormLayer(config, hidden_channels, out_channels, 1, 1, padding=0)
+ self.activation = nn.Identity() if activation is None else ACT2CLS[activation]()
+
+ def forward(self, x):
+ y = self.conv1(x) + self.conv2(x)
+ return self.activation(y)
+
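+
+# Design note on `DFineRepVggBlock` above: the block follows "RepVGG: Making VGG-style ConvNets
+# Great Again" and sums a 3x3 and a 1x1 convolutional branch. In the original RepVGG formulation
+# the branches can be fused into a single 3x3 convolution for inference; this implementation keeps
+# the two branches separate.
+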
+
+class DFineCSPRepLayer(nn.Module):
+ """
+ Cross Stage Partial (CSP) network layer with RepVGG blocks.
+ """
+
+ def __init__(
+ self, config: DFineConfig, in_channels: int, out_channels: int, num_blocks: int, expansion: float = 1.0
+ ):
+ super().__init__()
+ activation = config.activation_function
+
+ hidden_channels = int(out_channels * expansion)
+ self.conv1 = DFineConvNormLayer(config, in_channels, hidden_channels, 1, 1, activation=activation)
+ self.conv2 = DFineConvNormLayer(config, in_channels, hidden_channels, 1, 1, activation=activation)
+ self.bottlenecks = nn.ModuleList(
+ [DFineRepVggBlock(config, hidden_channels, hidden_channels) for _ in range(num_blocks)]
+ )
+ if hidden_channels != out_channels:
+ self.conv3 = DFineConvNormLayer(config, hidden_channels, out_channels, 1, 1, activation=activation)
+ else:
+ self.conv3 = nn.Identity()
+
+ def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
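+ # two parallel 1x1 projections; one branch passes through the RepVGG bottlenecks before both are summed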
+ hidden_state_1 = self.conv1(hidden_state)
+ for bottleneck in self.bottlenecks:
+ hidden_state_1 = bottleneck(hidden_state_1)
+ hidden_state_2 = self.conv2(hidden_state)
+ hidden_state_3 = self.conv3(hidden_state_1 + hidden_state_2)
+ return hidden_state_3
+
+
+class DFineRepNCSPELAN4(nn.Module):
+ def __init__(self, config: DFineConfig, act: str = "silu", numb_blocks: int = 3):
+ super().__init__()
+ conv1_dim = config.encoder_hidden_dim * 2
+ conv2_dim = config.encoder_hidden_dim
+ conv3_dim = config.encoder_hidden_dim * 2
+ conv4_dim = round(config.hidden_expansion * config.encoder_hidden_dim // 2)
+ self.conv_dim = conv3_dim // 2
+ self.conv1 = DFineConvNormLayer(config, conv1_dim, conv3_dim, 1, 1, activation=act)
+ self.csp_rep1 = DFineCSPRepLayer(config, conv3_dim // 2, conv4_dim, num_blocks=numb_blocks)
+ self.conv2 = DFineConvNormLayer(config, conv4_dim, conv4_dim, 3, 1, activation=act)
+ self.csp_rep2 = DFineCSPRepLayer(config, conv4_dim, conv4_dim, num_blocks=numb_blocks)
+ self.conv3 = DFineConvNormLayer(config, conv4_dim, conv4_dim, 3, 1, activation=act)
+ self.conv4 = DFineConvNormLayer(config, conv3_dim + (2 * conv4_dim), conv2_dim, 1, 1, activation=act)
+
+ def forward(self, input_features: torch.Tensor) -> torch.Tensor:
+ # Split initial features into two branches after first convolution
+ split_features = list(self.conv1(input_features).split((self.conv_dim, self.conv_dim), 1))
+
+ # Process branches sequentially
+ branch1 = self.csp_rep1(split_features[-1])
+ branch1 = self.conv2(branch1)
+ branch2 = self.csp_rep2(branch1)
+ branch2 = self.conv3(branch2)
+
+ split_features.extend([branch1, branch2])
+ merged_features = torch.cat(split_features, 1)
+ merged_features = self.conv4(merged_features)
+ return merged_features
+
+
+class DFineSCDown(nn.Module):
+ def __init__(self, config: DFineConfig, kernel_size: int, stride: int):
+ super().__init__()
+ self.conv1 = DFineConvNormLayer(config, config.encoder_hidden_dim, config.encoder_hidden_dim, 1, 1)
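+ # conv2 is a depthwise strided convolution (groups == encoder_hidden_dim), making this a separable downsampling block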
+ self.conv2 = DFineConvNormLayer(
+ config,
+ config.encoder_hidden_dim,
+ config.encoder_hidden_dim,
+ kernel_size,
+ stride,
+ config.encoder_hidden_dim,
+ )
+
+ def forward(self, input_features: torch.Tensor) -> torch.Tensor:
+ input_features = self.conv1(input_features)
+ input_features = self.conv2(input_features)
+ return input_features
+
+
+class DFineEncoderLayer(nn.Module):
+ def __init__(self, config: DFineConfig):
+ super().__init__()
+ self.normalize_before = config.normalize_before
+
+ # self-attention
+ self.self_attn = DFineMultiheadAttention(
+ embed_dim=config.encoder_hidden_dim,
+ num_heads=config.num_attention_heads,
+ dropout=config.dropout,
+ )
+ self.self_attn_layer_norm = nn.LayerNorm(config.encoder_hidden_dim, eps=config.layer_norm_eps)
+ self.dropout = config.dropout
+ self.activation_fn = ACT2FN[config.encoder_activation_function]
+ self.activation_dropout = config.activation_dropout
+ self.fc1 = nn.Linear(config.encoder_hidden_dim, config.encoder_ffn_dim)
+ self.fc2 = nn.Linear(config.encoder_ffn_dim, config.encoder_hidden_dim)
+ self.final_layer_norm = nn.LayerNorm(config.encoder_hidden_dim, eps=config.layer_norm_eps)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: torch.Tensor,
+ position_embeddings: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ **kwargs,
+ ):
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`): attention mask of size
+ `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
+ values.
+ position_embeddings (`torch.FloatTensor`, *optional*):
+ Object queries (also called content embeddings), to be added to the hidden states.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ """
+ residual = hidden_states
+ if self.normalize_before:
+ hidden_states = self.self_attn_layer_norm(hidden_states)
+
+ hidden_states, attn_weights = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_embeddings=position_embeddings,
+ output_attentions=output_attentions,
+ )
+
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+ hidden_states = residual + hidden_states
+ if not self.normalize_before:
+ hidden_states = self.self_attn_layer_norm(hidden_states)
+
+ if self.normalize_before:
+ hidden_states = self.final_layer_norm(hidden_states)
+ residual = hidden_states
+
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
+ hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
+
+ hidden_states = self.fc2(hidden_states)
+
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+
+ hidden_states = residual + hidden_states
+ if not self.normalize_before:
+ hidden_states = self.final_layer_norm(hidden_states)
+
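+ # during training, clamp activations whenever inf/nan appears to keep half-precision training stable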
+ if self.training:
+ if torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any():
+ clamp_value = torch.finfo(hidden_states.dtype).max - 1000
+ hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (attn_weights,)
+
+ return outputs
+
+
+class DFineEncoder(nn.Module):
+ def __init__(self, config: DFineConfig):
+ super().__init__()
+
+ self.layers = nn.ModuleList([DFineEncoderLayer(config) for _ in range(config.encoder_layers)])
+
+ def forward(self, src, src_mask=None, pos_embed=None, output_attentions: bool = False) -> torch.Tensor:
+ hidden_states = src
+ for layer in self.layers:
+ hidden_states = layer(
+ hidden_states,
+ attention_mask=src_mask,
+ position_embeddings=pos_embed,
+ output_attentions=output_attentions,
+ )
+ return hidden_states
+
+
+class DFineHybridEncoder(nn.Module):
+ """
+ Encoder consisting of a projection layer, a set of `DFineEncoder`, a top-down Feature Pyramid Network
+ (FPN) and a bottom-up Path Aggregation Network (PAN). More details in the paper: https://huggingface.co/papers/2304.08069
+
+ Args:
+ config: DFineConfig
+ """
+
+ def __init__(self, config: DFineConfig):
+ super().__init__()
+ self.config = config
+ self.in_channels = config.encoder_in_channels
+ self.num_fpn_stages = len(self.in_channels) - 1
+ self.feat_strides = config.feat_strides
+ self.encoder_hidden_dim = config.encoder_hidden_dim
+ self.encode_proj_layers = config.encode_proj_layers
+ self.positional_encoding_temperature = config.positional_encoding_temperature
+ self.eval_size = config.eval_size
+ self.out_channels = [self.encoder_hidden_dim for _ in self.in_channels]
+ self.out_strides = self.feat_strides
+
+ # encoder transformer
+ self.encoder = nn.ModuleList([DFineEncoder(config) for _ in range(len(self.encode_proj_layers))])
+ # top-down fpn
+ self.lateral_convs = nn.ModuleList()
+ self.fpn_blocks = nn.ModuleList()
+ for _ in range(len(self.in_channels) - 1, 0, -1):
+ lateral_layer = DFineConvNormLayer(config, self.encoder_hidden_dim, self.encoder_hidden_dim, 1, 1)
+ self.lateral_convs.append(lateral_layer)
+ num_blocks = round(3 * config.depth_mult)
+ fpn_layer = DFineRepNCSPELAN4(config, numb_blocks=num_blocks)
+ self.fpn_blocks.append(fpn_layer)
+
+ # bottom-up pan
+ self.downsample_convs = nn.ModuleList()
+ self.pan_blocks = nn.ModuleList()
+ for _ in range(len(self.in_channels) - 1):
+ self.downsample_convs.append(DFineSCDown(config, 3, 2))
+ num_blocks = round(3 * config.depth_mult)
+ self.pan_blocks.append(DFineRepNCSPELAN4(config, numb_blocks=num_blocks))
+
+ @staticmethod
+ def build_2d_sincos_position_embedding(
+ width, height, embed_dim=256, temperature=10000.0, device="cpu", dtype=torch.float32
+ ):
+ grid_w = torch.arange(torch_int(width), device=device).to(dtype)
+ grid_h = torch.arange(torch_int(height), device=device).to(dtype)
+ grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing="ij")
+ if embed_dim % 4 != 0:
+ raise ValueError("Embed dimension must be divisible by 4 for 2D sin-cos position embedding")
+ pos_dim = embed_dim // 4
+ omega = torch.arange(pos_dim, device=device).to(dtype) / pos_dim
+ omega = 1.0 / (temperature**omega)
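+ # each spatial axis contributes sin and cos features over pos_dim frequencies, so 4 * pos_dim == embed_dim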
+
+ out_w = grid_w.flatten()[..., None] @ omega[None]
+ out_h = grid_h.flatten()[..., None] @ omega[None]
+
+ return torch.concat([out_w.sin(), out_w.cos(), out_h.sin(), out_h.cos()], dim=1)[None, :, :]
+
+ def forward(
+ self,
+ inputs_embeds=None,
+ attention_mask=None,
+ position_embeddings=None,
+ spatial_shapes=None,
+ level_start_index=None,
+ valid_ratios=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ ):
+ r"""
+ Args:
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Flattened feature map (output of the backbone + projection layer) that is passed to the encoder.
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding pixel features. Mask values selected in `[0, 1]`:
+ - 1 for pixel features that are real (i.e. **not masked**),
+ - 0 for pixel features that are padding (i.e. **masked**).
+ [What are attention masks?](../glossary#attention-mask)
+ position_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Position embeddings that are added to the queries and keys in each self-attention layer.
+ spatial_shapes (`torch.LongTensor` of shape `(num_feature_levels, 2)`):
+ Spatial shapes of each feature map.
+ level_start_index (`torch.LongTensor` of shape `(num_feature_levels)`):
+ Starting index of each feature map.
+ valid_ratios (`torch.FloatTensor` of shape `(batch_size, num_feature_levels, 2)`):
+ Ratio of valid area in each feature level.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+ for more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ hidden_states = inputs_embeds
+
+ encoder_states = () if output_hidden_states else None
+ all_attentions = () if output_attentions else None
+
+ # encoder
+ if self.config.encoder_layers > 0:
+ for i, enc_ind in enumerate(self.encode_proj_layers):
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states[enc_ind],)
+ height, width = hidden_states[enc_ind].shape[2:]
+ # flatten [batch, channel, height, width] to [batch, height*width, channel]
+ src_flatten = hidden_states[enc_ind].flatten(2).permute(0, 2, 1)
+ if self.training or self.eval_size is None:
+ pos_embed = self.build_2d_sincos_position_embedding(
+ width,
+ height,
+ self.encoder_hidden_dim,
+ self.positional_encoding_temperature,
+ device=src_flatten.device,
+ dtype=src_flatten.dtype,
+ )
+ else:
+ pos_embed = None
+
+ layer_outputs = self.encoder[i](
+ src_flatten,
+ pos_embed=pos_embed,
+ output_attentions=output_attentions,
+ )
+ hidden_states[enc_ind] = (
+ layer_outputs[0].permute(0, 2, 1).reshape(-1, self.encoder_hidden_dim, height, width).contiguous()
+ )
+
+ if output_attentions:
+ all_attentions = all_attentions + (layer_outputs[1],)
+
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states[enc_ind],)
+
+ # top-down FPN
+ fpn_feature_maps = [hidden_states[-1]]
+ for idx, (lateral_conv, fpn_block) in enumerate(zip(self.lateral_convs, self.fpn_blocks)):
+ backbone_feature_map = hidden_states[self.num_fpn_stages - idx - 1]
+ top_fpn_feature_map = fpn_feature_maps[-1]
+ # apply lateral block
+ top_fpn_feature_map = lateral_conv(top_fpn_feature_map)
+ fpn_feature_maps[-1] = top_fpn_feature_map
+ # apply fpn block
+ top_fpn_feature_map = F.interpolate(top_fpn_feature_map, scale_factor=2.0, mode="nearest")
+ fused_feature_map = torch.concat([top_fpn_feature_map, backbone_feature_map], dim=1)
+ new_fpn_feature_map = fpn_block(fused_feature_map)
+ fpn_feature_maps.append(new_fpn_feature_map)
+
+ fpn_feature_maps = fpn_feature_maps[::-1]
+
+ # bottom-up PAN
+ pan_feature_maps = [fpn_feature_maps[0]]
+ for idx, (downsample_conv, pan_block) in enumerate(zip(self.downsample_convs, self.pan_blocks)):
+ top_pan_feature_map = pan_feature_maps[-1]
+ fpn_feature_map = fpn_feature_maps[idx + 1]
+ downsampled_feature_map = downsample_conv(top_pan_feature_map)
+ fused_feature_map = torch.concat([downsampled_feature_map, fpn_feature_map], dim=1)
+ new_pan_feature_map = pan_block(fused_feature_map)
+ pan_feature_maps.append(new_pan_feature_map)
+
+ if not return_dict:
+ return tuple(v for v in [pan_feature_maps, encoder_states, all_attentions] if v is not None)
+ return BaseModelOutput(
+ last_hidden_state=pan_feature_maps, hidden_states=encoder_states, attentions=all_attentions
+ )
+
+
+__all__ = ["DFineModel", "DFinePreTrainedModel", "DFineForObjectDetection"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/d_fine/modular_d_fine.py b/venv/lib/python3.13/site-packages/transformers/models/d_fine/modular_d_fine.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a41fb23308eefb77878b3a3f6bee4a25992eb10
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/d_fine/modular_d_fine.py
@@ -0,0 +1,1229 @@
+# coding=utf-8
+# Copyright 2025 Baidu Inc and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import math
+from typing import Any, Optional
+
+import torch
+import torch.nn.functional as F
+import torch.nn.init as init
+from torch import nn
+
+from ...activations import ACT2CLS
+from ...configuration_utils import PretrainedConfig
+from ...image_transforms import corners_to_center_format
+from ...utils import is_torchdynamo_compiling, logging
+from ...utils.backbone_utils import verify_backbone_config_arguments
+from ..auto import CONFIG_MAPPING
+from ..rt_detr.modeling_rt_detr import (
+ RTDetrConvNormLayer,
+ RTDetrDecoder,
+ RTDetrDecoderLayer,
+ RTDetrDecoderOutput,
+ RTDetrEncoder,
+ RTDetrForObjectDetection,
+ RTDetrHybridEncoder,
+ RTDetrMLPPredictionHead,
+ RTDetrModel,
+ RTDetrPreTrainedModel,
+ RTDetrRepVggBlock,
+ inverse_sigmoid,
+)
+from ..rt_detr_v2.modeling_rt_detr_v2 import multi_scale_deformable_attention_v2
+
+
+logger = logging.get_logger(__name__)
+
+
+# TODO: Attribute map assignment logic should be fixed in modular
+# as well as super() call parsing because otherwise we cannot re-write args after initialization
+class DFineConfig(PretrainedConfig):
+ """
+ This is the configuration class to store the configuration of a [`DFineModel`]. It is used to instantiate a D-FINE
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+ defaults will yield a similar configuration to that of D-FINE-X-COCO [ustc-community/dfine-xlarge-coco](https://huggingface.co/ustc-community/dfine-xlarge-coco).
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ initializer_range (`float`, *optional*, defaults to 0.01):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ initializer_bias_prior_prob (`float`, *optional*):
+ The prior probability used by the bias initializer to initialize biases for `enc_score_head` and `class_embed`.
+ If `None`, `prior_prob` computed as `prior_prob = 1 / (num_labels + 1)` while initializing model weights.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-05):
+ The epsilon used by the layer normalization layers.
+ batch_norm_eps (`float`, *optional*, defaults to 1e-05):
+ The epsilon used by the batch normalization layers.
+ backbone_config (`Dict`, *optional*, defaults to `RTDetrResNetConfig()`):
+ The configuration of the backbone model.
+ backbone (`str`, *optional*):
+ Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
+ will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone`
+ is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
+ use_pretrained_backbone (`bool`, *optional*, defaults to `False`):
+ Whether to use pretrained weights for the backbone.
+ use_timm_backbone (`bool`, *optional*, defaults to `False`):
+ Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
+ library.
+ freeze_backbone_batch_norms (`bool`, *optional*, defaults to `True`):
+ Whether to freeze the batch normalization layers in the backbone.
+ backbone_kwargs (`dict`, *optional*):
+ Keyword arguments to be passed to AutoBackbone when loading from a checkpoint
+ e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set.
+ encoder_hidden_dim (`int`, *optional*, defaults to 256):
+ Dimension of the layers in hybrid encoder.
+ encoder_in_channels (`list`, *optional*, defaults to `[512, 1024, 2048]`):
+ Multi-level features input for the encoder.
+ feat_strides (`list[int]`, *optional*, defaults to `[8, 16, 32]`):
+ Strides used in each feature map.
+ encoder_layers (`int`, *optional*, defaults to 1):
+ Total number of layers to be used by the encoder.
+ encoder_ffn_dim (`int`, *optional*, defaults to 1024):
+ Dimension of the "intermediate" (often named feed-forward) layer in decoder.
+ encoder_attention_heads (`int`, *optional*, defaults to 8):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ dropout (`float`, *optional*, defaults to 0.0):
+ The ratio for all dropout layers.
+ activation_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for activations inside the fully connected layer.
+ encode_proj_layers (`list[int]`, *optional*, defaults to `[2]`):
+ Indexes of the projected layers to be used in the encoder.
+ positional_encoding_temperature (`int`, *optional*, defaults to 10000):
+ The temperature parameter used to create the positional encodings.
+ encoder_activation_function (`str`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
+ activation_function (`str`, *optional*, defaults to `"silu"`):
+ The non-linear activation function (function or string) in the general layer. If string, `"gelu"`,
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
+ eval_size (`tuple[int, int]`, *optional*):
+ Height and width used to compute the effective height and width of the position embeddings after taking
+ the stride into account.
+ normalize_before (`bool`, *optional*, defaults to `False`):
+ Determine whether to apply layer normalization in the transformer encoder layer before self-attention and
+ feed-forward modules.
+ hidden_expansion (`float`, *optional*, defaults to 1.0):
+ Expansion ratio to enlarge the dimension size of RepVGGBlock and CSPRepLayer.
+ d_model (`int`, *optional*, defaults to 256):
+ Dimension of the layers exclude hybrid encoder.
+ num_queries (`int`, *optional*, defaults to 300):
+ Number of object queries.
+ decoder_in_channels (`list`, *optional*, defaults to `[256, 256, 256]`):
+ Multi-level feature dimensions for the decoder.
+ decoder_ffn_dim (`int`, *optional*, defaults to 1024):
+ Dimension of the "intermediate" (often named feed-forward) layer in decoder.
+ num_feature_levels (`int`, *optional*, defaults to 3):
+ The number of input feature levels.
+ decoder_n_points (`int`, *optional*, defaults to 4):
+ The number of sampled keys in each feature level for each attention head in the decoder.
+ decoder_layers (`int`, *optional*, defaults to 6):
+ Number of decoder layers.
+ decoder_attention_heads (`int`, *optional*, defaults to 8):
+ Number of attention heads for each attention layer in the Transformer decoder.
+ decoder_activation_function (`str`, *optional*, defaults to `"relu"`):
+ The non-linear activation function (function or string) in the decoder. If string, `"gelu"`,
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ num_denoising (`int`, *optional*, defaults to 100):
+ The total number of denoising tasks or queries to be used for contrastive denoising.
+ label_noise_ratio (`float`, *optional*, defaults to 0.5):
+ The fraction of denoising labels to which random noise should be added.
+ box_noise_scale (`float`, *optional*, defaults to 1.0):
+ Scale or magnitude of noise to be added to the bounding boxes.
+ learn_initial_query (`bool`, *optional*, defaults to `False`):
+ Indicates whether the initial query embeddings for the decoder should be learned during training.
+ anchor_image_size (`tuple[int, int]`, *optional*):
+ Height and width of the input image used during evaluation to generate the bounding box anchors. If `None`, anchors are generated automatically.
+ with_box_refine (`bool`, *optional*, defaults to `True`):
+ Whether to apply iterative bounding box refinement, where each decoder layer refines the bounding boxes
+ based on the predictions from the previous layer.
+ is_encoder_decoder (`bool`, *optional*, defaults to `True`):
+ Whether the architecture has an encoder decoder structure.
+ matcher_alpha (`float`, *optional*, defaults to 0.25):
+ Parameter alpha used by the Hungarian Matcher.
+ matcher_gamma (`float`, *optional*, defaults to 2.0):
+ Parameter gamma used by the Hungarian Matcher.
+ matcher_class_cost (`float`, *optional*, defaults to 2.0):
+ The relative weight of the class loss used by the Hungarian Matcher.
+ matcher_bbox_cost (`float`, *optional*, defaults to 5.0):
+ The relative weight of the bounding box loss used by the Hungarian Matcher.
+ matcher_giou_cost (`float`, *optional*, defaults to 2.0):
+ The relative weight of the GIoU loss used by the Hungarian Matcher.
+ use_focal_loss (`bool`, *optional*, defaults to `True`):
+ Parameter informing if focal loss should be used.
+ auxiliary_loss (`bool`, *optional*, defaults to `True`):
+ Whether auxiliary decoding losses (loss at each decoder layer) are to be used.
+ focal_loss_alpha (`float`, *optional*, defaults to 0.75):
+ Parameter alpha used to compute the focal loss.
+ focal_loss_gamma (`float`, *optional*, defaults to 2.0):
+ Parameter gamma used to compute the focal loss.
+ weight_loss_vfl (`float`, *optional*, defaults to 1.0):
+ Relative weight of the varifocal loss in the object detection loss.
+ weight_loss_bbox (`float`, *optional*, defaults to 5.0):
+ Relative weight of the L1 bounding box loss in the object detection loss.
+ weight_loss_giou (`float`, *optional*, defaults to 2.0):
+ Relative weight of the generalized IoU loss in the object detection loss.
+ weight_loss_fgl (`float`, *optional*, defaults to 0.15):
+ Relative weight of the fine-grained localization loss in the object detection loss.
+ weight_loss_ddf (`float`, *optional*, defaults to 1.5):
+ Relative weight of the decoupled distillation focal loss in the object detection loss.
+ eos_coefficient (`float`, *optional*, defaults to 0.0001):
+ Relative classification weight of the 'no-object' class in the object detection loss.
+ eval_idx (`int`, *optional*, defaults to -1):
+ Index of the decoder layer to use for evaluation. If negative, counts from the end
+ (e.g., -1 means use the last layer). This allows for early prediction in the decoder
+ stack while still training later layers.
+ layer_scale (`float`, *optional*, defaults to `1.0`):
+ Scaling factor for the hidden dimension in later decoder layers. Used to adjust the
+ model capacity after the evaluation layer.
+ max_num_bins (`int`, *optional*, defaults to 32):
+ Maximum number of bins for the distribution-guided bounding box refinement.
+ Higher values allow for more fine-grained localization but increase computation.
+ reg_scale (`float`, *optional*, defaults to 4.0):
+ Scale factor for the regression distribution. Controls the range and granularity
+ of the bounding box refinement process.
+ depth_mult (`float`, *optional*, defaults to 1.0):
+ Multiplier for the number of blocks in RepNCSPELAN4 layers. Used to scale the model's
+ depth while maintaining its architecture.
+ top_prob_values (`int`, *optional*, defaults to 4):
+ Number of top probability values to consider from each corner's distribution.
+ lqe_hidden_dim (`int`, *optional*, defaults to 64):
+ Hidden dimension size for the Location Quality Estimator (LQE) network.
+ lqe_layers (`int`, *optional*, defaults to 2):
+ Number of layers in the Location Quality Estimator MLP.
+ decoder_offset_scale (`float`, *optional*, defaults to 0.5):
+ Offset scale used in deformable attention.
+ decoder_method (`str`, *optional*, defaults to `"default"`):
+ The method to use for the decoder: `"default"` or `"discrete"`.
+ up (`float`, *optional*, defaults to 0.5):
+ Controls the upper bounds of the Weighting Function.
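+
+ Examples (a minimal usage sketch; both classes are defined in this module):
+
+ ```python
+ >>> from transformers import DFineConfig, DFineModel
+
+ >>> # Initializing a D-FINE configuration with the defaults documented above
+ >>> configuration = DFineConfig()
+
+ >>> # Initializing a model (with random weights) from that configuration
+ >>> model = DFineModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```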
+ """
+
+ model_type = "d_fine"
+ layer_types = ["basic", "bottleneck"]
+ attribute_map = {
+ "hidden_size": "d_model",
+ "num_attention_heads": "encoder_attention_heads",
+ }
+
+ def __init__(
+ self,
+ initializer_range=0.01,
+ initializer_bias_prior_prob=None,
+ layer_norm_eps=1e-5,
+ batch_norm_eps=1e-5,
+ # backbone
+ backbone_config=None,
+ backbone=None,
+ use_pretrained_backbone=False,
+ use_timm_backbone=False,
+ freeze_backbone_batch_norms=True,
+ backbone_kwargs=None,
+ # encoder HybridEncoder
+ encoder_hidden_dim=256,
+ encoder_in_channels=[512, 1024, 2048],
+ feat_strides=[8, 16, 32],
+ encoder_layers=1,
+ encoder_ffn_dim=1024,
+ encoder_attention_heads=8,
+ dropout=0.0,
+ activation_dropout=0.0,
+ encode_proj_layers=[2],
+ positional_encoding_temperature=10000,
+ encoder_activation_function="gelu",
+ activation_function="silu",
+ eval_size=None,
+ normalize_before=False,
+ hidden_expansion=1.0,
+ # decoder DFineTransformer
+ d_model=256,
+ num_queries=300,
+ decoder_in_channels=[256, 256, 256],
+ decoder_ffn_dim=1024,
+ num_feature_levels=3,
+ decoder_n_points=4,
+ decoder_layers=6,
+ decoder_attention_heads=8,
+ decoder_activation_function="relu",
+ attention_dropout=0.0,
+ num_denoising=100,
+ label_noise_ratio=0.5,
+ box_noise_scale=1.0,
+ learn_initial_query=False,
+ anchor_image_size=None,
+ with_box_refine=True,
+ is_encoder_decoder=True,
+ # Loss
+ matcher_alpha=0.25,
+ matcher_gamma=2.0,
+ matcher_class_cost=2.0,
+ matcher_bbox_cost=5.0,
+ matcher_giou_cost=2.0,
+ use_focal_loss=True,
+ auxiliary_loss=True,
+ focal_loss_alpha=0.75,
+ focal_loss_gamma=2.0,
+ weight_loss_vfl=1.0,
+ weight_loss_bbox=5.0,
+ weight_loss_giou=2.0,
+ weight_loss_fgl=0.15,
+ weight_loss_ddf=1.5,
+ eos_coefficient=1e-4,
+ eval_idx=-1,
+ layer_scale=1,
+ max_num_bins=32,
+ reg_scale=4.0,
+ depth_mult=1.0,
+ top_prob_values=4,
+ lqe_hidden_dim=64,
+ lqe_layers=2,
+ decoder_offset_scale=0.5,
+ decoder_method="default",
+ up=0.5,
+ **kwargs,
+ ):
+ self.initializer_range = initializer_range
+ self.initializer_bias_prior_prob = initializer_bias_prior_prob
+ self.layer_norm_eps = layer_norm_eps
+ self.batch_norm_eps = batch_norm_eps
+ # backbone
+ if backbone_config is None and backbone is None:
+ logger.info(
+ "`backbone_config` and `backbone` are `None`. Initializing the config with the default `HGNet-V2` backbone."
+ )
+ backbone_model_type = "hgnet_v2"
+ config_class = CONFIG_MAPPING[backbone_model_type]
+ # this will map it to RTDetrResNetConfig
+ # note: we can instead create HGNetV2Config
+ # and we would need to create HGNetV2Backbone
+ backbone_config = config_class(
+ num_channels=3,
+ embedding_size=64,
+ hidden_sizes=[256, 512, 1024, 2048],
+ depths=[3, 4, 6, 3],
+ layer_type="bottleneck",
+ hidden_act="relu",
+ downsample_in_first_stage=False,
+ downsample_in_bottleneck=False,
+ out_features=None,
+ out_indices=[2, 3, 4],
+ )
+ elif isinstance(backbone_config, dict):
+ backbone_model_type = backbone_config.pop("model_type")
+ config_class = CONFIG_MAPPING[backbone_model_type]
+ backbone_config = config_class.from_dict(backbone_config)
+
+ verify_backbone_config_arguments(
+ use_timm_backbone=use_timm_backbone,
+ use_pretrained_backbone=use_pretrained_backbone,
+ backbone=backbone,
+ backbone_config=backbone_config,
+ backbone_kwargs=backbone_kwargs,
+ )
+
+ self.backbone_config = backbone_config
+ self.backbone = backbone
+ self.use_pretrained_backbone = use_pretrained_backbone
+ self.use_timm_backbone = use_timm_backbone
+ self.freeze_backbone_batch_norms = freeze_backbone_batch_norms
+ self.backbone_kwargs = backbone_kwargs
+ # encoder
+ self.encoder_hidden_dim = encoder_hidden_dim
+ self.encoder_in_channels = encoder_in_channels
+ self.feat_strides = feat_strides
+ self.encoder_attention_heads = encoder_attention_heads
+ self.encoder_ffn_dim = encoder_ffn_dim
+ self.dropout = dropout
+ self.activation_dropout = activation_dropout
+ self.encode_proj_layers = encode_proj_layers
+ self.encoder_layers = encoder_layers
+ self.positional_encoding_temperature = positional_encoding_temperature
+ self.eval_size = eval_size
+ self.normalize_before = normalize_before
+ self.encoder_activation_function = encoder_activation_function
+ self.activation_function = activation_function
+ self.hidden_expansion = hidden_expansion
+ # decoder
+ self.d_model = d_model
+ self.num_queries = num_queries
+ self.decoder_ffn_dim = decoder_ffn_dim
+ self.decoder_in_channels = decoder_in_channels
+ self.num_feature_levels = num_feature_levels
+ self.decoder_n_points = decoder_n_points
+ self.decoder_layers = decoder_layers
+ self.decoder_attention_heads = decoder_attention_heads
+ self.decoder_activation_function = decoder_activation_function
+ self.attention_dropout = attention_dropout
+ self.num_denoising = num_denoising
+ self.label_noise_ratio = label_noise_ratio
+ self.box_noise_scale = box_noise_scale
+ self.learn_initial_query = learn_initial_query
+ self.anchor_image_size = anchor_image_size
+ self.auxiliary_loss = auxiliary_loss
+ self.with_box_refine = with_box_refine
+ # Loss
+ self.matcher_alpha = matcher_alpha
+ self.matcher_gamma = matcher_gamma
+ self.matcher_class_cost = matcher_class_cost
+ self.matcher_bbox_cost = matcher_bbox_cost
+ self.matcher_giou_cost = matcher_giou_cost
+ self.use_focal_loss = use_focal_loss
+ self.focal_loss_alpha = focal_loss_alpha
+ self.focal_loss_gamma = focal_loss_gamma
+ self.weight_loss_vfl = weight_loss_vfl
+ self.weight_loss_bbox = weight_loss_bbox
+ self.weight_loss_giou = weight_loss_giou
+ self.weight_loss_fgl = weight_loss_fgl
+ self.weight_loss_ddf = weight_loss_ddf
+ self.eos_coefficient = eos_coefficient
+ # add the new attributes with the given values or defaults
+ self.eval_idx = eval_idx
+ self.layer_scale = layer_scale
+ self.max_num_bins = max_num_bins
+ self.reg_scale = reg_scale
+ self.depth_mult = depth_mult
+ self.decoder_offset_scale = decoder_offset_scale
+ self.decoder_method = decoder_method
+ self.top_prob_values = top_prob_values
+ self.lqe_hidden_dim = lqe_hidden_dim
+ self.lqe_layers = lqe_layers
+ self.up = up
+
+ if isinstance(self.decoder_n_points, list):
+ if len(self.decoder_n_points) != self.num_feature_levels:
+ raise ValueError(
+ f"Length of decoder_n_points list ({len(self.decoder_n_points)}) must match num_feature_levels ({self.num_feature_levels})."
+ )
+
+ head_dim = self.d_model // self.decoder_attention_heads
+ if head_dim * self.decoder_attention_heads != self.d_model:
+ raise ValueError(
+ f"Embedded dimension {self.d_model} must be divisible by decoder_attention_heads {self.decoder_attention_heads}"
+ )
+ super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs)
+
+ @property
+ def num_attention_heads(self) -> int:
+ return self.encoder_attention_heads
+
+ @property
+ def hidden_size(self) -> int:
+ return self.d_model
+
+ @property
+ def sub_configs(self):
+ return (
+ {"backbone_config": type(self.backbone_config)}
+ if getattr(self, "backbone_config", None) is not None
+ else {}
+ )
+
+ @classmethod
+ def from_backbone_configs(cls, backbone_config: PretrainedConfig, **kwargs):
+ """Instantiate a [`DFineConfig`] (or a derived class) from a pre-trained backbone model configuration and DETR model
+ configuration.
+
+ Args:
+ backbone_config ([`PretrainedConfig`]):
+ The backbone configuration.
+
+ Returns:
+ [`DFineConfig`]: An instance of a configuration object
+ """
+ return cls(
+ backbone_config=backbone_config,
+ **kwargs,
+ )
+
+
+class DFineMultiscaleDeformableAttention(nn.Module):
+ def __init__(self, config: DFineConfig):
+ """
+ D-Fine version of multiscale deformable attention
+ """
+ super().__init__()
+ self.d_model = config.d_model
+ self.n_heads = config.decoder_attention_heads
+ self.n_levels = config.num_feature_levels
+ self.offset_scale = config.decoder_offset_scale
+ self.decoder_method = config.decoder_method
+ self.n_points = config.decoder_n_points
+
+ if isinstance(self.n_points, list):
+ num_points_list = self.n_points
+ else:
+ num_points_list = [self.n_points for _ in range(self.n_levels)]
+
+ self.num_points_list = num_points_list
+ num_points_scale = [1 / n for n in self.num_points_list for _ in range(n)]
+ self.register_buffer("num_points_scale", torch.tensor(num_points_scale, dtype=torch.float32))
+
+ self.total_points = self.n_heads * sum(self.num_points_list)
+
+ self.sampling_offsets = nn.Linear(self.d_model, self.total_points * 2)
+ self.attention_weights = nn.Linear(self.d_model, self.total_points)
+
+ self.ms_deformable_attn_core = multi_scale_deformable_attention_v2
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ reference_points=None,
+ encoder_hidden_states=None,
+ spatial_shapes=None,
+ spatial_shapes_list=None,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ batch_size, num_queries, _ = hidden_states.shape
+ batch_size, sequence_length, _ = encoder_hidden_states.shape
+
+ if not is_torchdynamo_compiling() and (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() != sequence_length:
+ raise ValueError(
+ "Make sure to align the spatial shapes with the sequence length of the encoder hidden states"
+ )
+
+ # Reshape for multi-head attention
+ value = encoder_hidden_states.reshape(batch_size, sequence_length, self.n_heads, self.d_model // self.n_heads)
+ if attention_mask is not None:
+ value = value.masked_fill(~attention_mask[..., None], float(0))
+
+ sampling_offsets: torch.Tensor = self.sampling_offsets(hidden_states)
+ sampling_offsets = sampling_offsets.reshape(
+ batch_size, num_queries, self.n_heads, sum(self.num_points_list), 2
+ )
+
+ attention_weights = self.attention_weights(hidden_states).reshape(
+ batch_size, num_queries, self.n_heads, sum(self.num_points_list)
+ )
+ attention_weights = F.softmax(attention_weights, dim=-1)
+
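+ # 2-dim reference points are (x, y) centers normalized by the feature-level resolution;
+ # 4-dim reference points are (cx, cy, w, h) boxes whose offsets are scaled by the box size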
+ if reference_points.shape[-1] == 2:
+ offset_normalizer = torch.tensor(spatial_shapes)
+ offset_normalizer = offset_normalizer.flip([1]).reshape(1, 1, 1, self.n_levels, 1, 2)
+ sampling_locations = (
+ reference_points.reshape(batch_size, sequence_length, 1, self.n_levels, 1, 2)
+ + sampling_offsets / offset_normalizer
+ )
+ elif reference_points.shape[-1] == 4:
+ # reference_points [8, 480, None, 1, 4]
+ # sampling_offsets [8, 480, 8, 12, 2]
+ num_points_scale = self.num_points_scale.to(dtype=hidden_states.dtype).unsqueeze(-1)
+ offset = sampling_offsets * num_points_scale * reference_points[:, :, None, :, 2:] * self.offset_scale
+ sampling_locations = reference_points[:, :, None, :, :2] + offset
+ else:
+ raise ValueError(
+ f"Last dim of reference_points must be 2 or 4, but get {reference_points.shape[-1]} instead."
+ )
+
+ output = self.ms_deformable_attn_core(
+ value,
+ spatial_shapes_list,
+ sampling_locations,
+ attention_weights,
+ self.num_points_list,
+ self.decoder_method,
+ )
+
+ return output, attention_weights
+
+
+class DFineGate(nn.Module):
+ def __init__(self, d_model: int):
+ super().__init__()
+ self.gate = nn.Linear(2 * d_model, 2 * d_model)
+ self.norm = nn.LayerNorm(d_model)
+
+ def forward(self, second_residual: torch.Tensor, hidden_states: torch.Tensor) -> torch.Tensor:
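+ # gated fusion: two sigmoid gates weight the residual and the new hidden states before LayerNorm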
+ gate_input = torch.cat([second_residual, hidden_states], dim=-1)
+ gates = torch.sigmoid(self.gate(gate_input))
+ gate1, gate2 = gates.chunk(2, dim=-1)
+ hidden_states = self.norm(gate1 * second_residual + gate2 * hidden_states)
+ return hidden_states
+
+
+class DFineDecoderLayer(RTDetrDecoderLayer):
+ def __init__(self, config: DFineConfig):
+ super().__init__(config)
+
+ # override the encoder attention module with d-fine version
+ self.encoder_attn = DFineMultiscaleDeformableAttention(config=config)
+ # gate
+ self.gateway = DFineGate(config.d_model)
+
+ del self.encoder_attn_layer_norm
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: Optional[torch.Tensor] = None,
+ reference_points=None,
+ spatial_shapes=None,
+ spatial_shapes_list=None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = False,
+ ) -> tuple[torch.Tensor, Any, Any]:
+ # Self Attention
+ hidden_states_2, self_attn_weights = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=encoder_attention_mask,
+ position_embeddings=position_embeddings,
+ output_attentions=output_attentions,
+ )
+
+ hidden_states_2 = nn.functional.dropout(hidden_states_2, p=self.dropout, training=self.training)
+ hidden_states = hidden_states + hidden_states_2
+ hidden_states = self.self_attn_layer_norm(hidden_states)
+ residual = hidden_states
+
+ # Cross-Attention
+ cross_attn_weights = None
+ hidden_states = hidden_states if position_embeddings is None else hidden_states + position_embeddings
+ hidden_states_2, cross_attn_weights = self.encoder_attn(
+ hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ reference_points=reference_points,
+ spatial_shapes=spatial_shapes,
+ spatial_shapes_list=spatial_shapes_list,
+ )
+
+ hidden_states_2 = nn.functional.dropout(hidden_states_2, p=self.dropout, training=self.training)
+ hidden_states = self.gateway(residual, hidden_states_2)
+
+ # Fully Connected
+ hidden_states_2 = self.activation_fn(self.fc1(hidden_states))
+ hidden_states_2 = nn.functional.dropout(hidden_states_2, p=self.activation_dropout, training=self.training)
+ hidden_states_2 = self.fc2(hidden_states_2)
+ hidden_states_2 = nn.functional.dropout(hidden_states_2, p=self.dropout, training=self.training)
+ hidden_states = hidden_states + hidden_states_2
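+ # clamp to the float16 representable range before the final LayerNorm to avoid overflow in half precision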
+ hidden_states = self.final_layer_norm(hidden_states.clamp(min=-65504, max=65504))
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights, cross_attn_weights)
+
+ return outputs
+
+
+class DFinePreTrainedModel(RTDetrPreTrainedModel):
+ def _init_weights(self, module):
+ # initialize linear layer bias value according to a given probability value.
+ if isinstance(module, (DFineForObjectDetection, DFineDecoder)):
+ if module.class_embed is not None:
+ for layer in module.class_embed:
+ prior_prob = self.config.initializer_bias_prior_prob or 1 / (self.config.num_labels + 1)
+ bias = float(-math.log((1 - prior_prob) / prior_prob))
+ nn.init.xavier_uniform_(layer.weight)
+ nn.init.constant_(layer.bias, bias)
+
+ if module.bbox_embed is not None:
+ for layer in module.bbox_embed:
+ nn.init.constant_(layer.layers[-1].weight, 0)
+ nn.init.constant_(layer.layers[-1].bias, 0)
+
+ if hasattr(module, "reg_scale"):
+ module.reg_scale.fill_(self.config.reg_scale)
+
+ if hasattr(module, "up"):
+ module.up.fill_(self.config.up)
+
+ if isinstance(module, DFineMultiscaleDeformableAttention):
+ nn.init.constant_(module.sampling_offsets.weight.data, 0.0)
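+ # the bias is set so that, at init, each attention head samples along a distinct direction,
+ # with offsets growing linearly with the sampling point index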
+ default_dtype = torch.get_default_dtype()
+ thetas = torch.arange(module.n_heads, dtype=torch.int64).to(default_dtype) * (
+ 2.0 * math.pi / module.n_heads
+ )
+ grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
+ grid_init = grid_init / grid_init.abs().max(-1, keepdim=True).values
+ grid_init = grid_init.reshape(module.n_heads, 1, 2).tile([1, sum(module.num_points_list), 1])
+ scaling = torch.concat([torch.arange(1, n + 1) for n in module.num_points_list]).reshape(1, -1, 1)
+ grid_init *= scaling
+ with torch.no_grad():
+ module.sampling_offsets.bias.data[...] = grid_init.flatten()
+
+ nn.init.constant_(module.attention_weights.weight.data, 0.0)
+ nn.init.constant_(module.attention_weights.bias.data, 0.0)
+
+ if isinstance(module, DFineModel):
+ prior_prob = self.config.initializer_bias_prior_prob or 1 / (self.config.num_labels + 1)
+ bias = float(-math.log((1 - prior_prob) / prior_prob))
+ nn.init.xavier_uniform_(module.enc_score_head.weight)
+ nn.init.constant_(module.enc_score_head.bias, bias)
+
+ if isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)):
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.bias is not None:
+ module.bias.data.zero_()
+
+ if isinstance(module, DFineGate):
+ bias = float(-math.log((1 - 0.5) / 0.5))
+ init.constant_(module.gate.bias, bias)
+ init.constant_(module.gate.weight, 0)
+
+ if isinstance(module, DFineLQE):
+ init.constant_(module.reg_conf.layers[-1].bias, 0)
+ init.constant_(module.reg_conf.layers[-1].weight, 0)
+
+ if isinstance(module, nn.LayerNorm):
+ module.weight.data.fill_(1.0)
+ module.bias.data.zero_()
+
+ if hasattr(module, "weight_embedding") and self.config.learn_initial_query:
+ nn.init.xavier_uniform_(module.weight_embedding.weight)
+ if hasattr(module, "denoising_class_embed") and self.config.num_denoising > 0:
+ nn.init.xavier_uniform_(module.denoising_class_embed.weight)
+
+
+class DFineIntegral(nn.Module):
+ """
+ A static layer that calculates integral results from a distribution.
+
+ This layer computes the target location using the formula: `sum{Pr(n) * W(n)}`,
+ where Pr(n) is the softmax probability vector representing the discrete
+ distribution, and W(n) is the non-uniform Weighting Function.
+
+ Args:
+ max_num_bins (int): Max number of the discrete bins. Default is 32.
+ It can be adjusted based on the dataset or task requirements.
+ """
+
+ def __init__(self, config: DFineConfig):
+ super().__init__()
+ self.max_num_bins = config.max_num_bins
+
+ def forward(self, pred_corners: torch.Tensor, project: torch.Tensor) -> torch.Tensor:
+ batch_size, num_queries, _ = pred_corners.shape
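+ # expected value under the per-edge bin distribution: softmax probabilities are projected
+ # onto the weighting function W(n) (the `project` vector), giving one offset per box edge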
+ pred_corners = F.softmax(pred_corners.reshape(-1, self.max_num_bins + 1), dim=1)
+ pred_corners = F.linear(pred_corners, project.to(pred_corners.device)).reshape(-1, 4)
+ pred_corners = pred_corners.reshape(batch_size, num_queries, -1)
+ return pred_corners
+
+
+class DFineDecoderOutput(RTDetrDecoderOutput):
+ pass
+
+
+class DFineDecoder(RTDetrDecoder):
+ """
+ D-FINE Decoder implementing Fine-grained Distribution Refinement (FDR).
+
+ This decoder refines object detection predictions through iterative updates across multiple layers,
+ utilizing attention mechanisms, location quality estimators, and distribution refinement techniques
+ to improve bounding box accuracy and robustness.
+ """
+
+ def __init__(self, config: DFineConfig):
+ self.eval_idx = config.eval_idx if config.eval_idx >= 0 else config.decoder_layers + config.eval_idx
+ super().__init__(config=config)
+ self.reg_scale = nn.Parameter(torch.tensor([config.reg_scale]), requires_grad=False)
+ self.max_num_bins = config.max_num_bins
+ self.d_model = config.d_model
+ self.layer_scale = config.layer_scale
+ self.pre_bbox_head = DFineMLP(config.hidden_size, config.hidden_size, 4, 3)
+ self.integral = DFineIntegral(config)
+ self.num_head = config.decoder_attention_heads
+ self.up = nn.Parameter(torch.tensor([config.up]), requires_grad=False)
+ self.lqe_layers = nn.ModuleList([DFineLQE(config) for _ in range(config.decoder_layers)])
+ self.layers = nn.ModuleList(
+ [DFineDecoderLayer(config) for _ in range(config.decoder_layers)]
+ + [DFineDecoderLayer(config) for _ in range(config.decoder_layers - self.eval_idx - 1)]
+ )
+
+ def forward(
+ self,
+ encoder_hidden_states: torch.Tensor,
+ reference_points: torch.Tensor,
+ inputs_embeds: torch.Tensor,
+ spatial_shapes,
+ level_start_index=None,
+ spatial_shapes_list=None,
+ output_hidden_states=None,
+ encoder_attention_mask=None,
+ memory_mask=None,
+ output_attentions=None,
+ return_dict=None,
+ ) -> DFineDecoderOutput:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if inputs_embeds is not None:
+ hidden_states = inputs_embeds
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
+ intermediate = ()
+ intermediate_reference_points = ()
+ intermediate_logits = ()
+ intermediate_predicted_corners = ()
+ initial_reference_points = ()
+
+ output_detach = pred_corners_undetach = 0
+
+ project = weighting_function(self.max_num_bins, self.up, self.reg_scale)
+ ref_points_detach = F.sigmoid(reference_points)
+
+ for i, decoder_layer in enumerate(self.layers):
+ ref_points_input = ref_points_detach.unsqueeze(2)
+ query_pos_embed = self.query_pos_head(ref_points_detach).clamp(min=-10, max=10)
+
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ output = decoder_layer(
+ hidden_states=hidden_states,
+ position_embeddings=query_pos_embed,
+ reference_points=ref_points_input,
+ spatial_shapes=spatial_shapes,
+ spatial_shapes_list=spatial_shapes_list,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ output_attentions=output_attentions,
+ )
+
+ hidden_states = output[0]
+
+ if i == 0:
+ # Initial bounding box predictions with inverse sigmoid refinement
+ new_reference_points = F.sigmoid(self.pre_bbox_head(output[0]) + inverse_sigmoid(ref_points_detach))
+ ref_points_initial = new_reference_points.detach()
+
+ # Refine bounding box corners using FDR, integrating previous layer's corrections
+ if self.bbox_embed is not None:
+ pred_corners = self.bbox_embed[i](hidden_states + output_detach) + pred_corners_undetach
+ inter_ref_bbox = distance2bbox(
+ ref_points_initial, self.integral(pred_corners, project), self.reg_scale
+ )
+ pred_corners_undetach = pred_corners
+ ref_points_detach = inter_ref_bbox.detach()
+
+ output_detach = hidden_states.detach()
+
+ intermediate += (hidden_states,)
+
+ if self.class_embed is not None and (self.training or i == self.eval_idx):
+ scores = self.class_embed[i](hidden_states)
+ # Add initial logits and reference points with pre-bbox head
+ if i == 0:
+ intermediate_logits += (scores,)
+ intermediate_reference_points += (new_reference_points,)
+ # Lqe does not affect the performance here.
+ scores = self.lqe_layers[i](scores, pred_corners)
+ intermediate_logits += (scores,)
+ intermediate_reference_points += (inter_ref_bbox,)
+ initial_reference_points += (ref_points_initial,)
+ intermediate_predicted_corners += (pred_corners,)
+
+ if output_attentions:
+ all_self_attns += (output[1],)
+
+ if encoder_hidden_states is not None:
+ all_cross_attentions += (output[2],)
+
+ # Keep batch_size as first dimension
+ intermediate = torch.stack(intermediate)
+ if self.class_embed is not None and self.bbox_embed is not None:
+ intermediate_logits = torch.stack(intermediate_logits, dim=1)
+ intermediate_predicted_corners = torch.stack(intermediate_predicted_corners, dim=1)
+ initial_reference_points = torch.stack(initial_reference_points, dim=1)
+ intermediate_reference_points = torch.stack(intermediate_reference_points, dim=1)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ if not return_dict:
+ return tuple(
+ v
+ for v in [
+ hidden_states,
+ intermediate,
+ intermediate_logits,
+ intermediate_reference_points,
+ intermediate_predicted_corners,
+ initial_reference_points,
+ all_hidden_states,
+ all_self_attns,
+ all_cross_attentions,
+ ]
+ if v is not None
+ )
+
+ return DFineDecoderOutput(
+ last_hidden_state=hidden_states,
+ intermediate_hidden_states=intermediate,
+ intermediate_logits=intermediate_logits,
+ intermediate_reference_points=intermediate_reference_points,
+ intermediate_predicted_corners=intermediate_predicted_corners,
+ initial_reference_points=initial_reference_points,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ cross_attentions=all_cross_attentions,
+ )
+
+
+class DFineModel(RTDetrModel):
+ def __init__(self, config: DFineConfig):
+ super().__init__(config)
+ del self.decoder_input_proj
+ self.encoder = DFineHybridEncoder(config=config)
+ num_backbone_outs = len(config.decoder_in_channels)
+ decoder_input_proj = []
+ in_channels = config.decoder_in_channels[-1]
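+ # project each feature level to d_model (identity when the channel count already matches);
+ # extra feature levels beyond the backbone outputs are built with stride-2 3x3 convs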
+ for _ in range(num_backbone_outs):
+ if config.hidden_size == config.decoder_in_channels[-1]:
+ decoder_input_proj.append(nn.Identity())
+ else:
+ conv = nn.Conv2d(in_channels, config.d_model, kernel_size=1, bias=False)
+ batchnorm = nn.BatchNorm2d(config.d_model, config.batch_norm_eps)
+ decoder_input_proj.append(nn.Sequential(conv, batchnorm))
+ for _ in range(config.num_feature_levels - num_backbone_outs):
+ if config.hidden_size == config.decoder_in_channels[-1]:
+ decoder_input_proj.append(nn.Identity())
+ else:
+ conv = nn.Conv2d(in_channels, config.d_model, kernel_size=3, stride=2, padding=1, bias=False)
+ batchnorm = nn.BatchNorm2d(config.d_model, config.batch_norm_eps)
+ decoder_input_proj.append(nn.Sequential(conv, batchnorm))
+ self.decoder_input_proj = nn.ModuleList(decoder_input_proj)
+ self.decoder = DFineDecoder(config)
+
+
+class DFineForObjectDetection(RTDetrForObjectDetection, DFinePreTrainedModel):
+ def __init__(self, config: DFineConfig):
+ DFinePreTrainedModel.__init__(self, config)
+
+ # D-FINE encoder-decoder model
+ self.eval_idx = config.eval_idx if config.eval_idx >= 0 else config.decoder_layers + config.eval_idx
+ self.model = DFineModel(config)
+ scaled_dim = round(config.layer_scale * config.hidden_size)
+ num_pred = config.decoder_layers
+ self.class_embed = nn.ModuleList([nn.Linear(config.d_model, config.num_labels) for _ in range(num_pred)])
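+ # each bbox head predicts a distribution over max_num_bins + 1 bins for each of the 4 box edges (FDR)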
+ self.bbox_embed = nn.ModuleList(
+ [
+ DFineMLP(config.hidden_size, config.hidden_size, 4 * (config.max_num_bins + 1), 3)
+ for _ in range(self.eval_idx + 1)
+ ]
+ + [
+ DFineMLP(scaled_dim, scaled_dim, 4 * (config.max_num_bins + 1), 3)
+ for _ in range(config.decoder_layers - self.eval_idx - 1)
+ ]
+ )
+
+ # here self.model.decoder.bbox_embed is null, but not self.bbox_embed
+ self.model.decoder.class_embed = self.class_embed
+ self.model.decoder.bbox_embed = self.bbox_embed
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def forward(self, **super_kwargs):
+ r"""
+ Example:
+
+ ```python
+ >>> import torch
+ >>> from transformers.image_utils import load_image
+ >>> from transformers import AutoImageProcessor, DFineForObjectDetection
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = load_image(url)
+
+ >>> image_processor = AutoImageProcessor.from_pretrained("ustc-community/dfine-xlarge-coco")
+ >>> model = DFineForObjectDetection.from_pretrained("ustc-community/dfine-xlarge-coco")
+
+ >>> # prepare image for the model
+ >>> inputs = image_processor(images=image, return_tensors="pt")
+
+ >>> # forward pass
+ >>> outputs = model(**inputs)
+
+ >>> logits = outputs.logits
+ >>> list(logits.shape)
+ [1, 300, 80]
+
+ >>> boxes = outputs.pred_boxes
+ >>> list(boxes.shape)
+ [1, 300, 4]
+
+ >>> # convert outputs (bounding boxes and class logits) to Pascal VOC format (xmin, ymin, xmax, ymax)
+ >>> target_sizes = torch.tensor([image.size[::-1]])
+ >>> results = image_processor.post_process_object_detection(outputs, threshold=0.9, target_sizes=target_sizes)
+ >>> result = results[0] # first image in batch
+
+ >>> for score, label, box in zip(result["scores"], result["labels"], result["boxes"]):
+ ... box = [round(i, 2) for i in box.tolist()]
+ ... print(
+ ... f"Detected {model.config.id2label[label.item()]} with confidence "
+ ... f"{round(score.item(), 3)} at location {box}"
+ ... )
+ Detected cat with confidence 0.958 at location [344.49, 23.4, 639.84, 374.27]
+ Detected cat with confidence 0.956 at location [11.71, 53.52, 316.64, 472.33]
+ Detected remote with confidence 0.947 at location [40.46, 73.7, 175.62, 117.57]
+ Detected sofa with confidence 0.918 at location [0.59, 1.88, 640.25, 474.74]
+ ```
+ """
+ super().forward(**super_kwargs)
+
+
+def weighting_function(max_num_bins: int, up: torch.Tensor, reg_scale: int) -> torch.Tensor:
+ """
+ Generates the non-uniform Weighting Function W(n) for bounding box regression.
+
+ Args:
+ max_num_bins (int): Max number of the discrete bins.
+ up (Tensor): Controls upper bounds of the sequence,
+ where maximum offset is ±up * H / W.
+ reg_scale (float): Controls the curvature of the Weighting Function.
+ Larger values result in flatter weights near the central axis W(max_num_bins/2)=0
+ and steeper weights at both ends.
+ Returns:
+ Tensor: Sequence of Weighting Function.
+ """
+ upper_bound1 = abs(up[0]) * abs(reg_scale)
+ upper_bound2 = abs(up[0]) * abs(reg_scale) * 2
+ step = (upper_bound1 + 1) ** (2 / (max_num_bins - 2))
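+ # bin values grow exponentially away from the center, so W(max_num_bins / 2) = 0 and the extremes reach ±upper_bound2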
+ left_values = [-((step) ** i) + 1 for i in range(max_num_bins // 2 - 1, 0, -1)]
+ right_values = [(step) ** i - 1 for i in range(1, max_num_bins // 2)]
+ values = [-upper_bound2] + left_values + [torch.zeros_like(up[0][None])] + right_values + [upper_bound2]
+ values = torch.cat(values, 0)
+ return values
+
+
+class DFineMLPPredictionHead(RTDetrMLPPredictionHead):
+ pass
+
+
+def distance2bbox(points, distance: torch.Tensor, reg_scale: float) -> torch.Tensor:
+ """
+ Decodes edge-distances into bounding box coordinates.
+
+ Args:
+ points (`torch.Tensor`):
+ (batch_size, num_boxes, 4) or (num_boxes, 4) format, representing [x_center, y_center, width, height]
+ distance (`torch.Tensor`):
+ (batch_size, num_boxes, 4) or (num_boxes, 4), representing distances from the point to the left, top, right, and bottom boundaries.
+ reg_scale (`float`):
+ Controls the curvature of the Weighting Function.
+ Returns:
+ `torch.Tensor`: Bounding boxes in (batch_size, num_boxes, 4) or (num_boxes, 4) format, representing [x_center, y_center, width, height]
+ """
+ reg_scale = abs(reg_scale)
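+ # each edge distance is scaled by the corresponding box extent divided by reg_scale and offset by half a unit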
+ top_left_x = points[..., 0] - (0.5 * reg_scale + distance[..., 0]) * (points[..., 2] / reg_scale)
+ top_left_y = points[..., 1] - (0.5 * reg_scale + distance[..., 1]) * (points[..., 3] / reg_scale)
+ bottom_right_x = points[..., 0] + (0.5 * reg_scale + distance[..., 2]) * (points[..., 2] / reg_scale)
+ bottom_right_y = points[..., 1] + (0.5 * reg_scale + distance[..., 3]) * (points[..., 3] / reg_scale)
+
+ bboxes = torch.stack([top_left_x, top_left_y, bottom_right_x, bottom_right_y], -1)
+
+ return corners_to_center_format(bboxes)
+
+
+class DFineMLP(nn.Module):
+ def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int, act: str = "relu"):
+ super().__init__()
+ self.num_layers = num_layers
+ hidden_dims = [hidden_dim] * (num_layers - 1)
+ input_dims = [input_dim] + hidden_dims
+ output_dims = hidden_dims + [output_dim]
+ self.layers = nn.ModuleList(nn.Linear(in_dim, out_dim) for in_dim, out_dim in zip(input_dims, output_dims))
+ self.act = ACT2CLS[act]()
+
+ def forward(self, stat_features: torch.Tensor) -> torch.Tensor:
+ for i, layer in enumerate(self.layers):
+ stat_features = self.act(layer(stat_features)) if i < self.num_layers - 1 else layer(stat_features)
+ return stat_features
+
+
+class DFineLQE(nn.Module):
+ def __init__(self, config: DFineConfig):
+ super().__init__()
+ self.top_prob_values = config.top_prob_values
+ self.max_num_bins = config.max_num_bins
+ self.reg_conf = DFineMLP(4 * (self.top_prob_values + 1), config.lqe_hidden_dim, 1, config.lqe_layers)
+
+ def forward(self, scores: torch.Tensor, pred_corners: torch.Tensor) -> torch.Tensor:
+ batch_size, length, _ = pred_corners.size()
+ prob = F.softmax(pred_corners.reshape(batch_size, length, 4, self.max_num_bins + 1), dim=-1)
+ prob_topk, _ = prob.topk(self.top_prob_values, dim=-1)
+ stat = torch.cat([prob_topk, prob_topk.mean(dim=-1, keepdim=True)], dim=-1)
+ quality_score = self.reg_conf(stat.reshape(batch_size, length, -1))
+ scores = scores + quality_score
+ return scores
+
+
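+# Note on `DFineLQE` above: for each of the four predicted edge distributions it keeps the
+# `top_prob_values` largest softmax probabilities plus their mean (4 * (top_prob_values + 1)
+# statistics per query), feeds them through a small MLP, and adds the resulting scalar logit to
+# the classification scores; the intent is that sharper, more confident corner distributions
+# raise the detection score.
+
+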
+class DFineConvNormLayer(RTDetrConvNormLayer):
+ def __init__(
+ self,
+ config: DFineConfig,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: int,
+ stride: int,
+ groups: int = 1,
+ padding: Optional[int] = None,
+ activation: Optional[str] = None,
+ ):
+ super().__init__(config, in_channels, out_channels, kernel_size, stride, padding=None, activation=activation)
+ self.conv = nn.Conv2d(
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride,
+ groups=groups,
+ padding=(kernel_size - 1) // 2 if padding is None else padding,
+ bias=False,
+ )
+
+
+class DFineRepVggBlock(RTDetrRepVggBlock):
+ def __init__(self, config: DFineConfig, in_channels: int, out_channels: int):
+ super().__init__(config)
+ hidden_channels = in_channels
+ self.conv1 = DFineConvNormLayer(config, hidden_channels, out_channels, 3, 1, padding=1)
+ self.conv2 = DFineConvNormLayer(config, hidden_channels, out_channels, 1, 1, padding=0)
+
+
+class DFineCSPRepLayer(nn.Module):
+ """
+ Cross Stage Partial (CSP) network layer with RepVGG blocks.
+ """
+
+ def __init__(
+ self, config: DFineConfig, in_channels: int, out_channels: int, num_blocks: int, expansion: float = 1.0
+ ):
+ super().__init__()
+ activation = config.activation_function
+
+ hidden_channels = int(out_channels * expansion)
+ self.conv1 = DFineConvNormLayer(config, in_channels, hidden_channels, 1, 1, activation=activation)
+ self.conv2 = DFineConvNormLayer(config, in_channels, hidden_channels, 1, 1, activation=activation)
+ self.bottlenecks = nn.ModuleList(
+ [DFineRepVggBlock(config, hidden_channels, hidden_channels) for _ in range(num_blocks)]
+ )
+ if hidden_channels != out_channels:
+ self.conv3 = DFineConvNormLayer(config, hidden_channels, out_channels, 1, 1, activation=activation)
+ else:
+ self.conv3 = nn.Identity()
+
+ def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
+ hidden_state_1 = self.conv1(hidden_state)
+ for bottleneck in self.bottlenecks:
+ hidden_state_1 = bottleneck(hidden_state_1)
+ hidden_state_2 = self.conv2(hidden_state)
+ hidden_state_3 = self.conv3(hidden_state_1 + hidden_state_2)
+ return hidden_state_3
+
+
+class DFineRepNCSPELAN4(nn.Module):
+ def __init__(self, config: DFineConfig, act: str = "silu", numb_blocks: int = 3):
+ super().__init__()
+ conv1_dim = config.encoder_hidden_dim * 2
+ conv2_dim = config.encoder_hidden_dim
+ conv3_dim = config.encoder_hidden_dim * 2
+ conv4_dim = round(config.hidden_expansion * config.encoder_hidden_dim // 2)
+ self.conv_dim = conv3_dim // 2
+ self.conv1 = DFineConvNormLayer(config, conv1_dim, conv3_dim, 1, 1, activation=act)
+ self.csp_rep1 = DFineCSPRepLayer(config, conv3_dim // 2, conv4_dim, num_blocks=numb_blocks)
+ self.conv2 = DFineConvNormLayer(config, conv4_dim, conv4_dim, 3, 1, activation=act)
+ self.csp_rep2 = DFineCSPRepLayer(config, conv4_dim, conv4_dim, num_blocks=numb_blocks)
+ self.conv3 = DFineConvNormLayer(config, conv4_dim, conv4_dim, 3, 1, activation=act)
+ self.conv4 = DFineConvNormLayer(config, conv3_dim + (2 * conv4_dim), conv2_dim, 1, 1, activation=act)
+
+ def forward(self, input_features: torch.Tensor) -> torch.Tensor:
+ # Split initial features into two branches after first convolution
+ split_features = list(self.conv1(input_features).split((self.conv_dim, self.conv_dim), 1))
+
+ # Process branches sequentially
+ branch1 = self.csp_rep1(split_features[-1])
+ branch1 = self.conv2(branch1)
+ branch2 = self.csp_rep2(branch1)
+ branch2 = self.conv3(branch2)
+
+ split_features.extend([branch1, branch2])
+ merged_features = torch.cat(split_features, 1)
+ merged_features = self.conv4(merged_features)
+ return merged_features
+
+
+class DFineSCDown(nn.Module):
+ def __init__(self, config: DFineConfig, kernel_size: int, stride: int):
+ super().__init__()
+ self.conv1 = DFineConvNormLayer(config, config.encoder_hidden_dim, config.encoder_hidden_dim, 1, 1)
+ self.conv2 = DFineConvNormLayer(
+ config,
+ config.encoder_hidden_dim,
+ config.encoder_hidden_dim,
+ kernel_size,
+ stride,
+ config.encoder_hidden_dim,
+ )
+
+ def forward(self, input_features: torch.Tensor) -> torch.Tensor:
+ input_features = self.conv1(input_features)
+ input_features = self.conv2(input_features)
+ return input_features
+
+
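+# Note on `DFineSCDown` above: it factorizes a strided convolution into a 1x1 pointwise
+# projection followed by a depthwise strided convolution (groups = config.encoder_hidden_dim),
+# downsampling the feature map at a much lower cost than a dense strided convolution of the same
+# kernel size.
+
+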
+class DFineEncoder(RTDetrEncoder):
+ pass
+
+
+class DFineHybridEncoder(RTDetrHybridEncoder):
+ def __init__(self, config: DFineConfig):
+ nn.Module.__init__(self)
+ self.config = config
+ self.in_channels = config.encoder_in_channels
+ self.num_fpn_stages = len(self.in_channels) - 1
+ self.feat_strides = config.feat_strides
+ self.encoder_hidden_dim = config.encoder_hidden_dim
+ self.encode_proj_layers = config.encode_proj_layers
+ self.positional_encoding_temperature = config.positional_encoding_temperature
+ self.eval_size = config.eval_size
+ self.out_channels = [self.encoder_hidden_dim for _ in self.in_channels]
+ self.out_strides = self.feat_strides
+
+ # encoder transformer
+ self.encoder = nn.ModuleList([DFineEncoder(config) for _ in range(len(self.encode_proj_layers))])
+ # top-down fpn
+ self.lateral_convs = nn.ModuleList()
+ self.fpn_blocks = nn.ModuleList()
+ for _ in range(len(self.in_channels) - 1, 0, -1):
+ lateral_layer = DFineConvNormLayer(config, self.encoder_hidden_dim, self.encoder_hidden_dim, 1, 1)
+ self.lateral_convs.append(lateral_layer)
+ num_blocks = round(3 * config.depth_mult)
+ fpn_layer = DFineRepNCSPELAN4(config, numb_blocks=num_blocks)
+ self.fpn_blocks.append(fpn_layer)
+
+ # bottom-up pan
+ self.downsample_convs = nn.ModuleList()
+ self.pan_blocks = nn.ModuleList()
+ for _ in range(len(self.in_channels) - 1):
+ self.downsample_convs.append(DFineSCDown(config, 3, 2))
+ num_blocks = round(3 * config.depth_mult)
+ self.pan_blocks.append(DFineRepNCSPELAN4(config, numb_blocks=num_blocks))
+
+
+__all__ = [
+ "DFineConfig",
+ "DFineModel",
+ "DFinePreTrainedModel",
+ "DFineForObjectDetection",
+]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/deepseek_v3/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/deepseek_v3/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..298f4c968375e6d90728a3d4e78c5046e4591c01
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/deepseek_v3/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_deepseek_v3 import *
+ from .modeling_deepseek_v3 import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/venv/lib/python3.13/site-packages/transformers/models/deepseek_v3/configuration_deepseek_v3.py b/venv/lib/python3.13/site-packages/transformers/models/deepseek_v3/configuration_deepseek_v3.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b885f8d5ac6e22a1bf3da5e736a3aad26226a21
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/deepseek_v3/configuration_deepseek_v3.py
@@ -0,0 +1,253 @@
+# coding=utf-8
+# Copyright 2025 bzantium and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on the DeepSeekV3 implementations from the DeepSeek AI team. (https://huggingface.co/deepseek-ai/DeepSeek-V3)
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""DeepSeekV3 model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...modeling_rope_utils import rope_config_validation
+
+
+DEEPSEEK_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
+
+
+class DeepseekV3Config(PretrainedConfig):
+ r"""
+    This is the configuration class to store the configuration of a [`DeepseekV3Model`]. It is used to instantiate a DeepSeek
+    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+    defaults will yield a similar configuration to that of DeepSeek-V3, e.g.
+    [bzantium/tiny-deepseek-v3](https://huggingface.co/bzantium/tiny-deepseek-v3).
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 129280):
+            Vocabulary size of the DeepSeek-V3 model. Defines the number of different tokens that can be represented by the
+            `input_ids` passed when calling [`DeepseekV3Model`].
+ hidden_size (`int`, *optional*, defaults to 7168):
+ Dimension of the hidden representations.
+ intermediate_size (`int`, *optional*, defaults to 18432):
+ Dimension of the MLP representations.
+ moe_intermediate_size (`int`, *optional*, defaults to 2048):
+ Dimension of the MoE representations.
+ num_hidden_layers (`int`, *optional*, defaults to 61):
+ Number of hidden layers in the Transformer decoder.
+ num_attention_heads (`int`, *optional*, defaults to 128):
+ Number of attention heads for each attention layer in the Transformer decoder.
+ num_key_value_heads (`int`, *optional*, defaults to 128):
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+            `num_key_value_heads=1`, the model will use Multi Query Attention (MQA), otherwise GQA is used. When
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+ by meanpooling all the original heads within that group. For more details, check out [this
+ paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
+ `num_attention_heads`.
+ n_shared_experts (`int`, *optional*, defaults to 1):
+ Number of shared experts.
+ n_routed_experts (`int`, *optional*, defaults to 256):
+ Number of routed experts.
+ routed_scaling_factor (`float`, *optional*, defaults to 2.5):
+            Scaling factor for routed experts.
+ kv_lora_rank (`int`, *optional*, defaults to 512):
+ Rank of the LoRA matrices for key and value projections.
+ q_lora_rank (`int`, *optional*, defaults to 1536):
+ Rank of the LoRA matrices for query projections.
+ qk_rope_head_dim (`int`, *optional*, defaults to 64):
+ Dimension of the query/key heads that use rotary position embeddings.
+ v_head_dim (`int`, *optional*, defaults to 128):
+ Dimension of the value heads.
+ qk_nope_head_dim (`int`, *optional*, defaults to 128):
+ Dimension of the query/key heads that don't use rotary position embeddings.
+ n_group (`int`, *optional*, defaults to 8):
+ Number of groups for routed experts.
+ topk_group (`int`, *optional*, defaults to 4):
+            Number of selected groups for each token (ensuring that each token's selected experts lie within only `topk_group` groups).
+ num_experts_per_tok (`int`, *optional*, defaults to 8):
+ Number of selected experts, None means dense model.
+ first_k_dense_replace (`int`, *optional*, defaults to 3):
+            Number of dense layers in the shallow layers (embed->dense->dense->...->dense->moe->moe...->lm_head).
+ \--k dense layers--/
+ norm_topk_prob (`bool`, *optional*, defaults to `True`):
+ Whether to normalize the weights of the routed experts.
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+ The non-linear activation function (function or string) in the decoder.
+ max_position_embeddings (`int`, *optional*, defaults to 4096):
+ The maximum sequence length that this model might ever be used with.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
+ The epsilon used by the rms normalization layers.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ pad_token_id (`int`, *optional*):
+ Padding token id.
+ bos_token_id (`int`, *optional*, defaults to 0):
+ Beginning of stream token id.
+ eos_token_id (`int`, *optional*, defaults to 1):
+ End of stream token id.
+ pretraining_tp (`int`, *optional*, defaults to 1):
+ Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
+ document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is
+ necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
+ issue](https://github.com/pytorch/pytorch/issues/76232).
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+            Whether to tie the input and output word embeddings.
+ rope_theta (`float`, *optional*, defaults to 10000.0):
+ The base period of the RoPE embeddings.
+ rope_scaling (`Dict`, *optional*):
+ Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
+ strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
+ `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
+ `max_position_embeddings` to the expected new maximum.
+ rope_interleave (`bool`, *optional*, defaults to `True`):
+ Whether to interleave the rotary position embeddings.
+        attention_bias (`bool`, *optional*, defaults to `False`):
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+
+ ```python
+ >>> from transformers import DeepseekV3Model, DeepseekV3Config
+
+ >>> # Initializing a Deepseek-V3 style configuration
+ >>> configuration = DeepseekV3Config()
+
+    >>> # Initializing a model from the configuration
+    >>> model = DeepseekV3Model(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+ ```"""
+
+ model_type = "deepseek_v3"
+ keys_to_ignore_at_inference = ["past_key_values"]
+ base_model_tp_plan = { # TODO: only replicate attention layers when > first_k_dense_replace
+ "layers.*.mlp.experts.*.gate_proj": "local_colwise",
+ "layers.*.mlp.experts.*.up_proj": "local_colwise",
+ "layers.*.mlp.experts.*.down_proj": "local_rowwise",
+ "layers.*.mlp.experts.*": "local", # each expert is wrapped in a module list
+ "layers.*.mlp.shared_experts.gate_proj": "local_colwise",
+ "layers.*.mlp.shared_experts.up_proj": "local_colwise",
+ "layers.*.mlp.shared_experts.down_proj": "local_rowwise",
+ "layers.*.mlp.shared_experts": "local",
+ "layers.*.mlp.gate_proj": "local_colwise",
+ "layers.*.mlp.up_proj": "local_colwise",
+ "layers.*.mlp.down_proj": "local_rowwise",
+ "layers.*.mlp": "gather", # This is the only moment where results are gathered
+ }
+ base_model_pp_plan = {
+ "embed_tokens": (["input_ids"], ["inputs_embeds"]),
+ "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
+ "norm": (["hidden_states"], ["hidden_states"]),
+ }
+
+ def __init__(
+ self,
+ vocab_size=129280,
+ hidden_size=7168,
+ intermediate_size=18432,
+ moe_intermediate_size=2048,
+ num_hidden_layers=61,
+ num_attention_heads=128,
+ num_key_value_heads=128,
+ n_shared_experts=1,
+ n_routed_experts=256,
+ routed_scaling_factor=2.5,
+ kv_lora_rank=512,
+ q_lora_rank=1536,
+ qk_rope_head_dim=64,
+ v_head_dim=128,
+ qk_nope_head_dim=128,
+ n_group=8,
+ topk_group=4,
+ num_experts_per_tok=8,
+ first_k_dense_replace=3,
+ norm_topk_prob=True,
+ hidden_act="silu",
+ max_position_embeddings=4096,
+ initializer_range=0.02,
+ rms_norm_eps=1e-6,
+ use_cache=True,
+ pad_token_id=None,
+ bos_token_id=0,
+ eos_token_id=1,
+ pretraining_tp=1,
+ tie_word_embeddings=False,
+ rope_theta=10000.0,
+ rope_scaling=None,
+ rope_interleave=True,
+ attention_bias=False,
+ attention_dropout=0.0,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.moe_intermediate_size = moe_intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.n_shared_experts = n_shared_experts
+ self.n_routed_experts = n_routed_experts
+ self.routed_scaling_factor = routed_scaling_factor
+ self.kv_lora_rank = kv_lora_rank
+ self.q_lora_rank = q_lora_rank
+ self.qk_rope_head_dim = qk_rope_head_dim
+ self.v_head_dim = v_head_dim
+ self.qk_nope_head_dim = qk_nope_head_dim
+ self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
+ self.head_dim = qk_rope_head_dim
+ self.n_group = n_group
+ self.topk_group = topk_group
+ self.num_experts_per_tok = num_experts_per_tok
+ self.first_k_dense_replace = first_k_dense_replace
+ self.norm_topk_prob = norm_topk_prob
+ self.rope_interleave = rope_interleave
+
+ # for backward compatibility
+ if num_key_value_heads is None:
+ num_key_value_heads = num_attention_heads
+
+ self.num_key_value_heads = num_key_value_heads
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.rms_norm_eps = rms_norm_eps
+ self.pretraining_tp = pretraining_tp
+ self.use_cache = use_cache
+ self.rope_theta = rope_theta
+ self.rope_scaling = rope_scaling
+ self.attention_bias = attention_bias
+ self.attention_dropout = attention_dropout
+ # Validate the correctness of rotary position embeddings parameters
+        # BC: if there is a 'type' field, copy it to 'rope_type'.
+ if self.rope_scaling is not None and "type" in self.rope_scaling:
+ self.rope_scaling["rope_type"] = self.rope_scaling["type"]
+
+ if self.rope_scaling is not None:
+ for key in ["beta_fast", "beta_slow", "factor"]:
+ if key in self.rope_scaling:
+ self.rope_scaling[key] = float(self.rope_scaling[key])
+
+ rope_config_validation(self)
+
+ super().__init__(
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ tie_word_embeddings=tie_word_embeddings,
+ **kwargs,
+ )
+
+
+__all__ = ["DeepseekV3Config"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/deepseek_v3/modeling_deepseek_v3.py b/venv/lib/python3.13/site-packages/transformers/models/deepseek_v3/modeling_deepseek_v3.py
new file mode 100644
index 0000000000000000000000000000000000000000..eecc6c8b67e110e2dc3d35af881bb500f0e1d56c
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/deepseek_v3/modeling_deepseek_v3.py
@@ -0,0 +1,693 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/deepseek_v3/modular_deepseek_v3.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_deepseek_v3.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+import math
+from typing import Callable, Optional, Union
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache
+from ...generation import GenerationMixin
+from ...integrations import use_kernel_forward_from_hub
+from ...masking_utils import create_causal_mask
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import (
+ GenericForSequenceClassification,
+ GenericForTokenClassification,
+ GradientCheckpointingLayer,
+)
+from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
+from ...utils.deprecation import deprecate_kwarg
+from ...utils.generic import check_model_inputs
+from .configuration_deepseek_v3 import DeepseekV3Config
+
+
+@use_kernel_forward_from_hub("RMSNorm")
+class DeepseekV3RMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ DeepseekV3RMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+ def extra_repr(self):
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+
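+# A minimal usage sketch of `DeepseekV3RMSNorm` (illustrative only; this helper is hypothetical
+# and never called by the model). Each hidden vector is divided by its root mean square and then
+# scaled by the learned per-channel weight.
+def _rmsnorm_usage_sketch() -> torch.Tensor:
+    norm = DeepseekV3RMSNorm(hidden_size=4, eps=1e-6)
+    hidden_states = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
+    # mean of squares = (1 + 4 + 9 + 16) / 4 = 7.5, so the row is divided by sqrt(7.5)
+    return norm(hidden_states)
+
+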
+class DeepseekV3RotaryEmbedding(nn.Module):
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
+
+ def __init__(self, config: DeepseekV3Config, device=None):
+ super().__init__()
+ # BC: "rope_type" was originally "type"
+ if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+ else:
+ self.rope_type = "default"
+ self.max_seq_len_cached = config.max_position_embeddings
+ self.original_max_seq_len = config.max_position_embeddings
+
+ self.config = config
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+ self.original_inv_freq = self.inv_freq
+
+ @torch.no_grad()
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
+ def forward(self, x, position_ids):
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
+ position_ids_expanded = position_ids[:, None, :].float()
+
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos = emb.cos() * self.attention_scaling
+ sin = emb.sin() * self.attention_scaling
+
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+class DeepseekV3MLP(nn.Module):
+ def __init__(self, config, hidden_size=None, intermediate_size=None):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
+ self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size
+
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+ self.act_fn = ACT2FN[config.hidden_act]
+
+ def forward(self, x):
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+ return down_proj
+
+
+class DeepseekV3TopkRouter(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.top_k = config.num_experts_per_tok
+ self.n_routed_experts = config.n_routed_experts
+ self.routed_scaling_factor = config.routed_scaling_factor
+ self.n_group = config.n_group
+ self.topk_group = config.topk_group
+ self.norm_topk_prob = config.norm_topk_prob
+
+ self.weight = nn.Parameter(torch.empty((self.n_routed_experts, config.hidden_size)))
+ self.register_buffer("e_score_correction_bias", torch.zeros(self.n_routed_experts))
+
+ @torch.no_grad()
+ def get_topk_indices(self, scores):
+ scores_for_choice = scores.view(-1, self.n_routed_experts) + self.e_score_correction_bias.unsqueeze(0)
+ group_scores = (
+ scores_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group)
+ .topk(2, dim=-1)[0]
+ .sum(dim=-1)
+ )
+ group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1]
+ group_mask = torch.zeros_like(group_scores)
+ group_mask.scatter_(1, group_idx, 1)
+ score_mask = (
+ group_mask.unsqueeze(-1)
+ .expand(-1, self.n_group, self.n_routed_experts // self.n_group)
+ .reshape(-1, self.n_routed_experts)
+ )
+ scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)
+ topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1]
+ return topk_indices
+
+ def forward(self, hidden_states):
+ hidden_states = hidden_states.view(-1, self.config.hidden_size)
+ router_logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32))
+ scores = router_logits.sigmoid()
+ topk_indices = self.get_topk_indices(scores)
+ topk_weights = scores.gather(1, topk_indices)
+ if self.norm_topk_prob:
+ denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20
+ topk_weights /= denominator
+ topk_weights = topk_weights * self.routed_scaling_factor
+ return topk_indices, topk_weights
+
+
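+# Note on `DeepseekV3TopkRouter` above (group-limited, auxiliary-loss-free routing): the sigmoid
+# scores plus the learned `e_score_correction_bias` are used only to *select* experts. Experts
+# are split into `n_group` groups, each group is ranked by the sum of its two highest biased
+# scores, only the best `topk_group` groups stay eligible, and the final `top_k` experts are
+# picked among them. The mixing weights are then taken from the *unbiased* sigmoid scores,
+# optionally normalized to sum to one, and scaled by `routed_scaling_factor`.
+
+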
+class DeepseekV3MoE(nn.Module):
+ """
+ A mixed expert module containing shared experts.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.experts = nn.ModuleList(
+ [
+ DeepseekV3MLP(config, intermediate_size=config.moe_intermediate_size)
+ for _ in range(config.n_routed_experts)
+ ]
+ )
+ self.gate = DeepseekV3TopkRouter(config)
+ self.shared_experts = DeepseekV3MLP(
+ config=config, intermediate_size=config.moe_intermediate_size * config.n_shared_experts
+ )
+
+ def moe(self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, topk_weights: torch.Tensor):
+ r"""
+        CALL FOR CONTRIBUTION! This expert loop is not optimised yet: the expert weights should be fused so that
+        dispatch does not require a Python loop (DeepSeek-V3 has 256 routed experts, so the loop is costly).
+ """
+ final_hidden_states = torch.zeros_like(hidden_states, dtype=topk_weights.dtype)
+ expert_mask = torch.nn.functional.one_hot(topk_indices, num_classes=len(self.experts))
+ expert_mask = expert_mask.permute(2, 0, 1)
+
+ for expert_idx in range(len(self.experts)):
+ expert = self.experts[expert_idx]
+ mask = expert_mask[expert_idx]
+ token_indices, weight_indices = torch.where(mask)
+
+ if token_indices.numel() > 0:
+ expert_weights = topk_weights[token_indices, weight_indices]
+ expert_input = hidden_states[token_indices]
+ expert_output = expert(expert_input)
+ weighted_output = expert_output * expert_weights.unsqueeze(-1)
+ final_hidden_states.index_add_(0, token_indices, weighted_output)
+
+        # In the original DeepSeek implementation, the expert outputs are gathered once we leave this module,
+        # so the MoE module is itself an IsolatedParallel module,
+        # and all experts are "local", meaning we shard but do not gather.
+ return final_hidden_states.type(hidden_states.dtype)
+
+ def forward(self, hidden_states):
+ residuals = hidden_states
+ orig_shape = hidden_states.shape
+ topk_indices, topk_weights = self.gate(hidden_states)
+ hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
+ hidden_states = self.moe(hidden_states, topk_indices, topk_weights).view(*orig_shape)
+ hidden_states = hidden_states + self.shared_experts(residuals)
+ return hidden_states
+
+
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+ """Applies Rotary Position Embedding to the query and key tensors.
+
+ Args:
+ q (`torch.Tensor`): The query tensor.
+ k (`torch.Tensor`): The key tensor.
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
+ position_ids (`torch.Tensor`, *optional*):
+ Deprecated and unused.
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+ Returns:
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+ """
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+ """
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
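+# Shape sketch for `repeat_kv` (illustrative only; this helper is hypothetical and never called
+# by the model): with 2 key/value heads shared across 8 query heads, `num_key_value_groups` is 4
+# and each KV head is repeated 4 times along the head dimension.
+def _repeat_kv_shape_sketch() -> torch.Size:
+    key_states = torch.randn(1, 2, 5, 16)  # (batch, num_key_value_heads, seq_len, head_dim)
+    return repeat_kv(key_states, n_rep=4).shape  # torch.Size([1, 8, 5, 16])
+
+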
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: float,
+ dropout: float = 0.0,
+ **kwargs: Unpack[TransformersKwargs],
+):
+ key_states = repeat_kv(key, module.num_key_value_groups)
+ value_states = repeat_kv(value, module.num_key_value_groups)
+
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+ if attention_mask is not None:
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
+def apply_rotary_pos_emb_interleave(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+ r"""
+    TODO: use the original freqs_cis computation so that the view + transpose + reshape
+    below is not needed. This path is not optimized!
+ Applies Rotary Position Embedding to the query and key tensors.
+
+ Args:
+ q (`torch.Tensor`): The query tensor.
+ k (`torch.Tensor`): The key tensor.
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
+ position_ids (`torch.Tensor`):
+ The position indices of the tokens corresponding to the query and key tensors. For example, this can be
+ used to pass offsetted position ids when working with a KV-cache.
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+ Returns:
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+ """
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
+
+ b, h, s, d = q.shape
+ q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
+
+ b, h, s, d = k.shape
+ k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
+
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
+
+
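+# Note on `apply_rotary_pos_emb_interleave` above: when `config.rope_interleave` is set, the
+# rotary dimensions of the projected queries/keys come in interleaved order [x0, y0, x1, y1, ...].
+# The view/transpose/reshape de-interleaves them into the [x0, x1, ..., y0, y1, ...] layout that
+# `rotate_half` expects, so the same cos/sin tables can be reused for both conventions.
+
+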
+def yarn_get_mscale(scale=1, mscale=1):
+ if scale <= 1:
+ return 1.0
+ return 0.1 * mscale * math.log(scale) + 1.0
+
+
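+# Note on `yarn_get_mscale` above: for context-extension factors `scale <= 1` no correction is
+# applied; otherwise the attention magnitude is corrected logarithmically, e.g. a 16x extension
+# with `mscale = 1` gives 0.1 * ln(16) + 1 ~= 1.277. When `mscale_all_dim` is set in
+# `rope_scaling`, `DeepseekV3Attention.__init__` multiplies its softmax scaling by mscale ** 2.
+
+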
+class DeepseekV3Attention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config: DeepseekV3Config, layer_idx: int):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
+ self.attention_dropout = config.attention_dropout
+ self.num_heads = config.num_attention_heads
+ self.rope_theta = config.rope_theta
+ self.q_lora_rank = config.q_lora_rank
+ self.qk_rope_head_dim = config.qk_rope_head_dim
+ self.kv_lora_rank = config.kv_lora_rank
+ self.v_head_dim = config.v_head_dim
+ self.qk_nope_head_dim = config.qk_nope_head_dim
+ self.qk_head_dim = config.qk_head_dim
+
+ self.is_causal = True
+ if self.q_lora_rank is None:
+ self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.qk_head_dim, bias=False)
+ else:
+ self.q_a_proj = nn.Linear(config.hidden_size, config.q_lora_rank, bias=config.attention_bias)
+ self.q_a_layernorm = DeepseekV3RMSNorm(config.q_lora_rank)
+ self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.qk_head_dim, bias=False)
+
+ self.kv_a_proj_with_mqa = nn.Linear(
+ config.hidden_size,
+ self.kv_lora_rank + self.qk_rope_head_dim,
+ bias=config.attention_bias,
+ )
+ self.kv_a_layernorm = DeepseekV3RMSNorm(self.kv_lora_rank)
+ self.kv_b_proj = nn.Linear(
+ self.kv_lora_rank,
+ self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
+ bias=False,
+ )
+
+ self.o_proj = nn.Linear(
+ self.num_heads * self.v_head_dim,
+ config.hidden_size,
+ bias=config.attention_bias,
+ )
+
+ self.scaling = self.qk_head_dim ** (-0.5)
+ if self.config.rope_scaling is not None:
+ mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0)
+ scaling_factor = self.config.rope_scaling["factor"]
+ if mscale_all_dim:
+ mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
+ self.scaling = self.scaling * mscale * mscale
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
+ attention_mask: Optional[torch.Tensor],
+ past_key_values: Optional[Cache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
+ batch_size, seq_length = hidden_states.shape[:-1]
+ query_shape = (batch_size, seq_length, -1, self.qk_head_dim)
+ key_shape = (batch_size, seq_length, -1, self.qk_nope_head_dim + self.v_head_dim)
+
+ if self.q_lora_rank is None:
+ q_states = self.q_proj(hidden_states)
+ else:
+ q_states = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
+ q_states = q_states.view(query_shape).transpose(1, 2)
+ q_pass, q_rot = torch.split(q_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+
+ compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
+ k_pass, k_rot = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
+
+ k_pass = self.kv_b_proj(self.kv_a_layernorm(k_pass)).view(key_shape).transpose(1, 2)
+ k_pass, value_states = torch.split(k_pass, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
+
+ k_rot = k_rot.view(batch_size, 1, seq_length, self.qk_rope_head_dim)
+
+ cos, sin = position_embeddings
+ if self.config.rope_interleave: # support using interleaved weights for efficiency
+ q_rot, k_rot = apply_rotary_pos_emb_interleave(q_rot, k_rot, cos, sin)
+ else:
+ q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin)
+ k_rot = k_rot.expand(*k_pass.shape[:-1], -1)
+
+ query_states = torch.cat((q_pass, q_rot), dim=-1)
+ key_states = torch.cat((k_pass, k_rot), dim=-1)
+
+ if past_key_values is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ if self.config._attn_implementation == "flash_attention_2" and self.qk_head_dim != self.v_head_dim:
+ value_states = F.pad(value_states, [0, self.qk_head_dim - self.v_head_dim])
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ scaling=self.scaling,
+ **kwargs,
+ )
+
+ if self.config._attn_implementation == "flash_attention_2" and self.qk_head_dim != self.v_head_dim:
+ attn_output = attn_output[:, :, :, : self.v_head_dim]
+
+ attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+ return attn_output, attn_weights
+
+
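+# Summary of the Multi-head Latent Attention (MLA) block above: keys and values are compressed
+# into a `kv_lora_rank` latent via `kv_a_proj_with_mqa`, alongside a single shared
+# `qk_rope_head_dim` rotary key; `kv_b_proj` then expands the latent into per-head no-RoPE keys
+# and values. Queries optionally go through the analogous low-rank path (`q_a_proj` ->
+# `q_b_proj`) when `q_lora_rank` is set. RoPE and no-RoPE parts are concatenated into
+# `qk_head_dim`-wide queries/keys, and for flash attention the values are zero-padded to
+# `qk_head_dim` and cropped back to `v_head_dim` after the attention call.
+
+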
+class DeepseekV3DecoderLayer(GradientCheckpointingLayer):
+ def __init__(self, config: DeepseekV3Config, layer_idx: int):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+
+ self.self_attn = DeepseekV3Attention(config=config, layer_idx=layer_idx)
+
+ if layer_idx >= config.first_k_dense_replace:
+ self.mlp = DeepseekV3MoE(config)
+ else:
+ self.mlp = DeepseekV3MLP(config)
+
+ self.input_layernorm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_attention_layernorm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> torch.Tensor:
+ residual = hidden_states
+ hidden_states = self.input_layernorm(hidden_states)
+ # Self Attention
+ hidden_states, _ = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+ hidden_states = residual + hidden_states
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+ return hidden_states
+
+
+@auto_docstring
+class DeepseekV3PreTrainedModel(PreTrainedModel):
+ config: DeepseekV3Config
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["DeepseekV3DecoderLayer"]
+ _skip_keys_device_placement = ["past_key_values"]
+ _supports_flash_attn = True
+ _supports_sdpa = True
+ _supports_flex_attn = True
+ _can_compile_fullgraph = False
+ _supports_attention_backend = True
+ _can_record_outputs = {
+ "hidden_states": DeepseekV3DecoderLayer,
+ "attentions": DeepseekV3Attention,
+ }
+
+ def _init_weights(self, module):
+ super()._init_weights(module)
+ if isinstance(module, DeepseekV3TopkRouter):
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+
+
+@auto_docstring
+class DeepseekV3Model(DeepseekV3PreTrainedModel):
+ _keys_to_ignore_on_load_unexpected = [r"model\.layers\.61.*"]
+
+ def __init__(self, config: DeepseekV3Config):
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+ self.layers = nn.ModuleList(
+ [DeepseekV3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+ self.norm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.rotary_emb = DeepseekV3RotaryEmbedding(config=config)
+ self.gradient_checkpointing = False
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @check_model_inputs()
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> BaseModelOutputWithPast:
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if inputs_embeds is None:
+ inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)
+
+ if use_cache and past_key_values is None:
+ past_key_values = DynamicCache(config=self.config)
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position: torch.Tensor = torch.arange(
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+ )
+
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0)
+
+ causal_mask = create_causal_mask(
+ config=self.config,
+ input_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ cache_position=cache_position,
+ past_key_values=past_key_values,
+ position_ids=position_ids,
+ )
+
+ hidden_states = inputs_embeds
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+ for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+ hidden_states = decoder_layer(
+ hidden_states,
+ attention_mask=causal_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+
+ hidden_states = self.norm(hidden_states)
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values,
+ )
+
+
+@auto_docstring
+class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel, GenerationMixin):
+ _tied_weights_keys = ["lm_head.weight"]
+ _tp_plan = {"lm_head": "colwise_rep"}
+ _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = DeepseekV3Model(config)
+ self.vocab_size = config.vocab_size
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> CausalLMOutputWithPast:
+ r"""
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, DeepseekV3ForCausalLM
+
+ >>> model = DeepseekV3ForCausalLM.from_pretrained("meta-deepseek_v3/DeepseekV3-2-7b-hf")
+ >>> tokenizer = AutoTokenizer.from_pretrained("meta-deepseek_v3/DeepseekV3-2-7b-hf")
+
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+ ```"""
+ outputs: BaseModelOutputWithPast = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ hidden_states = outputs.last_hidden_state
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+class DeepseekV3ForSequenceClassification(GenericForSequenceClassification, DeepseekV3PreTrainedModel):
+ pass
+
+
+class DeepseekV3ForTokenClassification(GenericForTokenClassification, DeepseekV3PreTrainedModel):
+ pass
+
+
+__all__ = [
+ "DeepseekV3PreTrainedModel",
+ "DeepseekV3Model",
+ "DeepseekV3ForCausalLM",
+ "DeepseekV3ForSequenceClassification",
+ "DeepseekV3ForTokenClassification",
+]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/deepseek_v3/modular_deepseek_v3.py b/venv/lib/python3.13/site-packages/transformers/models/deepseek_v3/modular_deepseek_v3.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc3dc0c4ce3b382213917f215171d1bffaf914bc
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/deepseek_v3/modular_deepseek_v3.py
@@ -0,0 +1,373 @@
+import math
+from typing import Callable, Optional
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import GenericForSequenceClassification, GenericForTokenClassification
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import logging
+from ...utils.deprecation import deprecate_kwarg
+from ..llama.modeling_llama import (
+ LlamaDecoderLayer,
+ LlamaForCausalLM,
+ LlamaModel,
+ LlamaPreTrainedModel,
+ LlamaRMSNorm,
+ LlamaRotaryEmbedding,
+ apply_rotary_pos_emb,
+ eager_attention_forward,
+ rotate_half,
+)
+from .configuration_deepseek_v3 import DeepseekV3Config
+
+
+logger = logging.get_logger(__name__)
+
+
+class DeepseekV3RMSNorm(LlamaRMSNorm):
+ pass
+
+
+class DeepseekV3RotaryEmbedding(LlamaRotaryEmbedding):
+ pass
+
+
+def apply_rotary_pos_emb_interleave(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+ r"""
+    TODO: use the original freqs_cis computation so that the view + transpose + reshape
+    below is not needed. This path is not optimized!
+ Applies Rotary Position Embedding to the query and key tensors.
+
+ Args:
+ q (`torch.Tensor`): The query tensor.
+ k (`torch.Tensor`): The key tensor.
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
+ position_ids (`torch.Tensor`):
+ The position indices of the tokens corresponding to the query and key tensors. For example, this can be
+ used to pass offsetted position ids when working with a KV-cache.
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+ Returns:
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+ """
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
+
+ b, h, s, d = q.shape
+ q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
+
+ b, h, s, d = k.shape
+ k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
+
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
+
+
+def yarn_get_mscale(scale=1, mscale=1):
+ if scale <= 1:
+ return 1.0
+ return 0.1 * mscale * math.log(scale) + 1.0
+
+
+class DeepseekV3MLP(nn.Module):
+ def __init__(self, config, hidden_size=None, intermediate_size=None):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
+ self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size
+
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+ self.act_fn = ACT2FN[config.hidden_act]
+
+ def forward(self, x):
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+ return down_proj
+
+
+class DeepseekV3TopkRouter(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.top_k = config.num_experts_per_tok
+ self.n_routed_experts = config.n_routed_experts
+ self.routed_scaling_factor = config.routed_scaling_factor
+ self.n_group = config.n_group
+ self.topk_group = config.topk_group
+ self.norm_topk_prob = config.norm_topk_prob
+
+ self.weight = nn.Parameter(torch.empty((self.n_routed_experts, config.hidden_size)))
+ self.register_buffer("e_score_correction_bias", torch.zeros(self.n_routed_experts))
+
+ @torch.no_grad()
+ def get_topk_indices(self, scores):
+ scores_for_choice = scores.view(-1, self.n_routed_experts) + self.e_score_correction_bias.unsqueeze(0)
+ group_scores = (
+ scores_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group)
+ .topk(2, dim=-1)[0]
+ .sum(dim=-1)
+ )
+ group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1]
+ group_mask = torch.zeros_like(group_scores)
+ group_mask.scatter_(1, group_idx, 1)
+ score_mask = (
+ group_mask.unsqueeze(-1)
+ .expand(-1, self.n_group, self.n_routed_experts // self.n_group)
+ .reshape(-1, self.n_routed_experts)
+ )
+ scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)
+ topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1]
+ return topk_indices
+
+ def forward(self, hidden_states):
+ hidden_states = hidden_states.view(-1, self.config.hidden_size)
+ router_logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32))
+ scores = router_logits.sigmoid()
+ topk_indices = self.get_topk_indices(scores)
+ topk_weights = scores.gather(1, topk_indices)
+ if self.norm_topk_prob:
+ denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20
+ topk_weights /= denominator
+ topk_weights = topk_weights * self.routed_scaling_factor
+ return topk_indices, topk_weights
+
+
+class DeepseekV3MoE(nn.Module):
+ """
+ A mixed expert module containing shared experts.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.experts = nn.ModuleList(
+ [
+ DeepseekV3MLP(config, intermediate_size=config.moe_intermediate_size)
+ for _ in range(config.n_routed_experts)
+ ]
+ )
+ self.gate = DeepseekV3TopkRouter(config)
+ self.shared_experts = DeepseekV3MLP(
+ config=config, intermediate_size=config.moe_intermediate_size * config.n_shared_experts
+ )
+
+ def moe(self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, topk_weights: torch.Tensor):
+ r"""
+        CALL FOR CONTRIBUTION! This expert loop is not optimised yet: the expert weights should be fused so that
+        dispatch does not require a Python loop (DeepSeek-V3 has 256 routed experts, so the loop is costly).
+ """
+ final_hidden_states = torch.zeros_like(hidden_states, dtype=topk_weights.dtype)
+ expert_mask = torch.nn.functional.one_hot(topk_indices, num_classes=len(self.experts))
+ expert_mask = expert_mask.permute(2, 0, 1)
+
+ for expert_idx in range(len(self.experts)):
+ expert = self.experts[expert_idx]
+ mask = expert_mask[expert_idx]
+ token_indices, weight_indices = torch.where(mask)
+
+ if token_indices.numel() > 0:
+ expert_weights = topk_weights[token_indices, weight_indices]
+ expert_input = hidden_states[token_indices]
+ expert_output = expert(expert_input)
+ weighted_output = expert_output * expert_weights.unsqueeze(-1)
+ final_hidden_states.index_add_(0, token_indices, weighted_output)
+
+        # In the original DeepSeek implementation, the expert outputs are gathered once we leave this module,
+        # so the MoE module is itself an IsolatedParallel module,
+        # and all experts are "local", meaning we shard but do not gather.
+ return final_hidden_states.type(hidden_states.dtype)
+
+ def forward(self, hidden_states):
+ residuals = hidden_states
+ orig_shape = hidden_states.shape
+ topk_indices, topk_weights = self.gate(hidden_states)
+ hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
+ hidden_states = self.moe(hidden_states, topk_indices, topk_weights).view(*orig_shape)
+ hidden_states = hidden_states + self.shared_experts(residuals)
+ return hidden_states
+
+
+class DeepseekV3Attention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config: DeepseekV3Config, layer_idx: int):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
+ self.attention_dropout = config.attention_dropout
+ self.num_heads = config.num_attention_heads
+ self.rope_theta = config.rope_theta
+ self.q_lora_rank = config.q_lora_rank
+ self.qk_rope_head_dim = config.qk_rope_head_dim
+ self.kv_lora_rank = config.kv_lora_rank
+ self.v_head_dim = config.v_head_dim
+ self.qk_nope_head_dim = config.qk_nope_head_dim
+ self.qk_head_dim = config.qk_head_dim
+
+ self.is_causal = True
+ if self.q_lora_rank is None:
+ self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.qk_head_dim, bias=False)
+ else:
+ self.q_a_proj = nn.Linear(config.hidden_size, config.q_lora_rank, bias=config.attention_bias)
+ self.q_a_layernorm = DeepseekV3RMSNorm(config.q_lora_rank)
+ self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.qk_head_dim, bias=False)
+
+ self.kv_a_proj_with_mqa = nn.Linear(
+ config.hidden_size,
+ self.kv_lora_rank + self.qk_rope_head_dim,
+ bias=config.attention_bias,
+ )
+ self.kv_a_layernorm = DeepseekV3RMSNorm(self.kv_lora_rank)
+ self.kv_b_proj = nn.Linear(
+ self.kv_lora_rank,
+ self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
+ bias=False,
+ )
+
+ self.o_proj = nn.Linear(
+ self.num_heads * self.v_head_dim,
+ config.hidden_size,
+ bias=config.attention_bias,
+ )
+
+ self.scaling = self.qk_head_dim ** (-0.5)
+ if self.config.rope_scaling is not None:
+ mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0)
+ scaling_factor = self.config.rope_scaling["factor"]
+ if mscale_all_dim:
+ mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
+ self.scaling = self.scaling * mscale * mscale
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
+ attention_mask: Optional[torch.Tensor],
+ past_key_values: Optional[Cache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
+ batch_size, seq_length = hidden_states.shape[:-1]
+ query_shape = (batch_size, seq_length, -1, self.qk_head_dim)
+ key_shape = (batch_size, seq_length, -1, self.qk_nope_head_dim + self.v_head_dim)
+
+ if self.q_lora_rank is None:
+ q_states = self.q_proj(hidden_states)
+ else:
+ q_states = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
+ q_states = q_states.view(query_shape).transpose(1, 2)
+ q_pass, q_rot = torch.split(q_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+
+ compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
+ k_pass, k_rot = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
+
+ k_pass = self.kv_b_proj(self.kv_a_layernorm(k_pass)).view(key_shape).transpose(1, 2)
+ k_pass, value_states = torch.split(k_pass, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
+
+ k_rot = k_rot.view(batch_size, 1, seq_length, self.qk_rope_head_dim)
+
+ cos, sin = position_embeddings
+ if self.config.rope_interleave: # support using interleaved weights for efficiency
+ q_rot, k_rot = apply_rotary_pos_emb_interleave(q_rot, k_rot, cos, sin)
+ else:
+ q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin)
+ k_rot = k_rot.expand(*k_pass.shape[:-1], -1)
+
+ query_states = torch.cat((q_pass, q_rot), dim=-1)
+ key_states = torch.cat((k_pass, k_rot), dim=-1)
+
+ if past_key_values is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ if self.config._attn_implementation == "flash_attention_2" and self.qk_head_dim != self.v_head_dim:
+ value_states = F.pad(value_states, [0, self.qk_head_dim - self.v_head_dim])
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ scaling=self.scaling,
+ **kwargs,
+ )
+
+ if self.config._attn_implementation == "flash_attention_2" and self.qk_head_dim != self.v_head_dim:
+ attn_output = attn_output[:, :, :, : self.v_head_dim]
+
+ attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+ return attn_output, attn_weights
+
+
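+# Shape sketch of the low-rank query path in `DeepseekV3Attention` above: hidden states are
+# compressed to `q_lora_rank`, expanded back to per-head queries, and split into a non-rotary
+# part and a rotary part. All sizes below are toy values, not real config values, and the
+# RMSNorm between the two projections is omitted for brevity.
+def _mla_query_shapes_sketch():
+    hidden_size, q_lora_rank, num_heads, qk_nope, qk_rope = 64, 16, 4, 8, 4
+    x = torch.randn(2, 5, hidden_size)                               # (batch, seq, hidden)
+    q_compressed = nn.Linear(hidden_size, q_lora_rank, bias=False)(x)
+    q = nn.Linear(q_lora_rank, num_heads * (qk_nope + qk_rope), bias=False)(q_compressed)
+    q = q.view(2, 5, num_heads, qk_nope + qk_rope).transpose(1, 2)   # (batch, heads, seq, head_dim)
+    q_pass, q_rot = torch.split(q, [qk_nope, qk_rope], dim=-1)
+    return q_pass.shape, q_rot.shape                                 # (2, 4, 5, 8) and (2, 4, 5, 4)
+
+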
+class DeepseekV3DecoderLayer(LlamaDecoderLayer):
+ def __init__(self, config: DeepseekV3Config, layer_idx: int):
+ nn.Module.__init__(self)
+ self.hidden_size = config.hidden_size
+
+ self.self_attn = DeepseekV3Attention(config=config, layer_idx=layer_idx)
+
+ if layer_idx >= config.first_k_dense_replace:
+ self.mlp = DeepseekV3MoE(config)
+ else:
+ self.mlp = DeepseekV3MLP(config)
+
+ self.input_layernorm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_attention_layernorm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+
+class DeepseekV3PreTrainedModel(LlamaPreTrainedModel):
+ _can_compile_fullgraph = False
+
+ def _init_weights(self, module):
+ PreTrainedModel._init_weights(self, module)
+ if isinstance(module, DeepseekV3TopkRouter):
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+
+
+class DeepseekV3Model(LlamaModel):
+ _keys_to_ignore_on_load_unexpected = [r"model\.layers\.61.*"]
+
+
+class DeepseekV3ForCausalLM(LlamaForCausalLM):
+ pass
+
+
+class DeepseekV3ForSequenceClassification(GenericForSequenceClassification, DeepseekV3PreTrainedModel):
+ pass
+
+
+class DeepseekV3ForTokenClassification(GenericForTokenClassification, DeepseekV3PreTrainedModel):
+ pass
+
+
+__all__ = [
+ "DeepseekV3PreTrainedModel",
+ "DeepseekV3Model",
+ "DeepseekV3ForCausalLM",
+ "DeepseekV3ForSequenceClassification",
+ "DeepseekV3ForTokenClassification",
+]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/deepseek_vl/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/deepseek_vl/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2422b31e31050b378f7373dd6ed73830ee1b9f0d
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/deepseek_vl/__init__.py
@@ -0,0 +1,30 @@
+# Copyright 2025 Deepseek AI and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_deepseek_vl import *
+ from .image_processing_deepseek_vl import *
+ from .image_processing_deepseek_vl_fast import *
+ from .modeling_deepseek_vl import *
+ from .processing_deepseek_vl import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/venv/lib/python3.13/site-packages/transformers/models/deepseek_vl/configuration_deepseek_vl.py b/venv/lib/python3.13/site-packages/transformers/models/deepseek_vl/configuration_deepseek_vl.py
new file mode 100644
index 0000000000000000000000000000000000000000..b3abae5af0a7f7613432e71ef13a1c3b9774176c
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/deepseek_vl/configuration_deepseek_vl.py
@@ -0,0 +1,97 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/deepseek_vl/modular_deepseek_vl.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_deepseek_vl.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# Copyright 2025 Deepseek AI and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+from ..auto import CONFIG_MAPPING, AutoConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+class DeepseekVLConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`DeepseekVLModel`]. It is used to instantiate a
+ DeepseekVL model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of the DeepseekVL
+ [deepseek-community/deepseek-vl-1.3b-chat](https://huggingface.co/deepseek-community/deepseek-vl-1.3b-chat) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `LlamaConfig`):
+ The config object or dictionary of the text backbone.
+ vision_config (`Union[AutoConfig, dict]`, *optional*, defaults to `SiglipVisionConfig`):
+ The config object or dictionary of the vision backbone.
+ image_token_id (`int`, *optional*, defaults to 100015):
+ The index representing image tokens in the model's token vocabulary.
+
+ Example:
+
+ ```python
+ >>> from transformers import DeepseekVLConfig, DeepseekVLModel
+
+ >>> # Initializing a DeepseekVL deepseek-community/deepseek-vl-1.3b-chat style configuration
+ >>> configuration = DeepseekVLConfig()
+
+ >>> # Initializing a model (with random weights) from the deepseek-community/deepseek-vl-1.3b-chat style configuration
+ >>> model = DeepseekVLModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "deepseek_vl"
+ sub_configs = {"text_config": AutoConfig, "vision_config": AutoConfig}
+
+ def __init__(
+ self,
+ text_config: Optional[AutoConfig] = None,
+ vision_config: Optional[AutoConfig] = None,
+ image_token_id: int = 100015,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ if text_config is None:
+ text_config = {}
+ logger.info("`text_config` is `None`. Initializing the `LlamaConfig` with default values.")
+
+ if vision_config is None:
+ vision_config = {}
+ logger.info("`vision_config` is `None`. Initializing the `SiglipVisionConfig` with default values.")
+
+ if isinstance(text_config, dict):
+ text_config["model_type"] = text_config.get("model_type", "llama")
+ text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
+
+ if isinstance(vision_config, dict):
+ vision_config["model_type"] = vision_config.get("model_type", "siglip_vision_model")
+ vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config)
+
+ self.text_config = text_config
+ self.vision_config = vision_config
+ self.image_token_id = image_token_id
+
+
+__all__ = ["DeepseekVLConfig"]
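+
+
+# Sketch of passing the sub-configs as plain dicts; `__init__` above falls back to the "llama"
+# and "siglip_vision_model" model types. The sizes are small toy values, not real checkpoints.
+def _deepseek_vl_config_from_dicts_sketch():
+    config = DeepseekVLConfig(
+        text_config={"model_type": "llama", "hidden_size": 64, "num_hidden_layers": 2},
+        vision_config={"model_type": "siglip_vision_model", "hidden_size": 64},
+    )
+    return type(config.text_config).__name__, type(config.vision_config).__name__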
diff --git a/venv/lib/python3.13/site-packages/transformers/models/deepseek_vl/image_processing_deepseek_vl.py b/venv/lib/python3.13/site-packages/transformers/models/deepseek_vl/image_processing_deepseek_vl.py
new file mode 100644
index 0000000000000000000000000000000000000000..12aa7caf892e52262c5d612551a19e8f9bd4f58d
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/deepseek_vl/image_processing_deepseek_vl.py
@@ -0,0 +1,425 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/deepseek_vl/modular_deepseek_vl.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_deepseek_vl.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# Copyright 2025 Deepseek AI and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Union
+
+import numpy as np
+
+from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
+from ...image_transforms import convert_to_rgb, resize, to_channel_dimension_format
+from ...image_utils import (
+ OPENAI_CLIP_MEAN,
+ OPENAI_CLIP_STD,
+ ChannelDimension,
+ ImageInput,
+ PILImageResampling,
+ get_image_size,
+ infer_channel_dimension_format,
+ is_scaled_image,
+ make_flat_list_of_images,
+ to_numpy_array,
+ valid_images,
+ validate_preprocess_arguments,
+)
+from ...utils import TensorType, filter_out_non_signature_kwargs, is_vision_available, logging
+
+
+if is_vision_available():
+ import PIL
+
+
+logger = logging.get_logger(__name__)
+
+
+class DeepseekVLImageProcessor(BaseImageProcessor):
+ r"""
+    Constructs a DeepseekVL image processor.
+
+ Args:
+ do_resize (`bool`, *optional*, defaults to `True`):
+ Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the
+ `do_resize` parameter in the `preprocess` method.
+ size (`dict`, *optional*, defaults to `{"height": 384, "width": 384}`):
+ Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess`
+ method.
+ min_size (`int`, *optional*, defaults to 14):
+ The minimum allowed size for the resized image. Ensures that neither the height nor width
+ falls below this value after resizing.
+ resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
+ Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`. Can be
+ overridden by the `resample` parameter in the `preprocess` method.
+ do_rescale (`bool`, *optional*, defaults to `True`):
+ Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the
+ `do_rescale` parameter in the `preprocess` method.
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
+ Scale factor to use if rescaling the image. Only has an effect if `do_rescale` is set to `True`. Can be
+ overridden by the `rescale_factor` parameter in the `preprocess` method.
+ do_normalize (`bool`, *optional*, defaults to `True`):
+            Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
+            method.
+        image_mean (`float` or `list[float]`, *optional*, defaults to `OPENAI_CLIP_MEAN`):
+            Mean to use if normalizing the image. This is a float or list of floats the length of the number of
+            channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
+        image_std (`float` or `list[float]`, *optional*, defaults to `OPENAI_CLIP_STD`):
+            Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
+            number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
+ do_convert_rgb (`bool`, *optional*, defaults to `True`):
+ Whether to convert the image to RGB.
+ do_pad (`bool`, *optional*, defaults to `True`):
+ Whether to pad the image to square or not.
+ """
+
+ model_input_names = ["pixel_values"]
+
+ def __init__(
+ self,
+ do_resize: bool = True,
+ size: Optional[dict[str, int]] = None,
+ min_size: int = 14,
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
+ do_rescale: bool = True,
+ rescale_factor: Union[int, float] = 1 / 255,
+ do_normalize: bool = True,
+ image_mean: Optional[Union[float, list[float]]] = None,
+ image_std: Optional[Union[float, list[float]]] = None,
+ do_convert_rgb: Optional[bool] = None,
+ do_pad: Optional[bool] = True,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ size = size if size is not None else {"height": 384, "width": 384}
+ size = get_size_dict(size, default_to_square=True)
+
+ self.do_resize = do_resize
+ self.size = size
+ self.resample = resample
+ self.do_rescale = do_rescale
+ self.rescale_factor = rescale_factor
+ self.do_normalize = do_normalize
+ self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
+ self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
+ self.do_convert_rgb = do_convert_rgb
+
+ self.do_pad = do_pad
+ self.min_size = min_size
+ if image_mean is None:
+ self.background_color = (127, 127, 127)
+ else:
+ self.background_color = tuple(int(x * 255) for x in image_mean)
+
+ def resize(
+ self,
+ image: np.ndarray,
+ size: Union[dict[str, int], int],
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs,
+ ) -> np.ndarray:
+ """
+ Resize an image to dynamically calculated size.
+
+ Args:
+ image (`np.ndarray`):
+ Image to resize.
+ size (`dict[str, int]` or `int`):
+ The size to resize the image to. If a dictionary, it should have the keys `"height"` and `"width"`.
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
+ `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BICUBIC`.
+ data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the output image. If unset, the channel dimension format of the input
+ image is used. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - `None`: will be inferred from input
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
+ from the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+
+ Returns:
+ `np.ndarray`: The resized image.
+ """
+ if input_data_format is None:
+ input_data_format = infer_channel_dimension_format(image)
+
+ height, width = get_image_size(image, input_data_format)
+ max_size = max(height, width)
+
+ size = get_size_dict(size, default_to_square=True)
+ if size["height"] != size["width"]:
+ raise ValueError(
+ f"Output height and width must be the same. Got height={size['height']} and width={size['width']}"
+ )
+ size = size["height"]
+
+ delta = size / max_size
+ # Largest side becomes `size` and the other side is scaled according to the aspect ratio.
+ output_size_nonpadded = [
+ max(int(height * delta), self.min_size),
+ max(int(width * delta), self.min_size),
+ ]
+
+ image = resize(
+ image,
+ size=output_size_nonpadded,
+ resample=resample,
+ data_format=data_format,
+ input_data_format=input_data_format,
+ **kwargs,
+ )
+ return image
+
+ @filter_out_non_signature_kwargs()
+ def preprocess(
+ self,
+ images: ImageInput,
+ do_resize: Optional[bool] = None,
+ size: Optional[dict[str, int]] = None,
+ resample: Optional[PILImageResampling] = None,
+ do_rescale: Optional[bool] = None,
+ rescale_factor: Optional[float] = None,
+ do_normalize: Optional[bool] = None,
+ image_mean: Optional[Union[float, list[float]]] = None,
+ image_std: Optional[Union[float, list[float]]] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ do_convert_rgb: Optional[bool] = None,
+ background_color: Optional[Union[int, tuple[int, int, int]]] = None,
+ do_pad: Optional[bool] = None,
+ data_format: ChannelDimension = ChannelDimension.FIRST,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ ) -> PIL.Image.Image:
+ """
+ Preprocess an image or batch of images.
+
+ Args:
+ images (`ImageInput`):
+ Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
+ passing in images with pixel values between 0 and 1, set `do_rescale=False`.
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+ Whether to resize the image.
+            size (`dict[str, int]`, *optional*, defaults to `self.size`):
+                Controls the size of the image after `resize`. Height and width must be equal. The longest edge of the
+                image is resized to this value and the shorter edge is scaled to preserve the aspect ratio (but never
+                below `min_size`). If `do_pad` is set, the image is then padded to a square of this size.
+ resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
+ Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`.
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+ Whether to rescale the image values between [0 - 1].
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
+ Whether to normalize the image.
+ image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`):
+ Image mean to normalize the image by if `do_normalize` is set to `True`.
+ image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`):
+ Image standard deviation to normalize the image by if `do_normalize` is set to `True`.
+ do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
+ Whether to convert the image to RGB.
+            background_color (`int` or `tuple[int, int, int]`, *optional*, defaults to `self.background_color`):
+                The background color to use for the padding when `do_pad` is set.
+ do_pad (`bool`, *optional*, defaults to `self.do_pad`):
+ Whether to pad the image to square or not.
+ return_tensors (`str` or `TensorType`, *optional*):
+ The type of tensors to return. Can be one of:
+ - Unset: Return a list of `np.ndarray`.
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
+ The channel dimension format for the output image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - Unset: Use the channel dimension format of the input image.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
+ from the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+ """
+ do_resize = do_resize if do_resize is not None else self.do_resize
+ resample = resample if resample is not None else self.resample
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
+ image_mean = image_mean if image_mean is not None else self.image_mean
+ image_std = image_std if image_std is not None else self.image_std
+ do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
+ do_pad = do_pad if do_pad is not None else self.do_pad
+ background_color = background_color if background_color is not None else self.background_color
+
+ size = size if size is not None else self.size
+ size = get_size_dict(size, default_to_square=False)
+ images = self.fetch_images(images)
+ images = make_flat_list_of_images(images)
+
+ if not valid_images(images):
+ raise ValueError(
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
+ "torch.Tensor, tf.Tensor or jax.ndarray."
+ )
+
+ validate_preprocess_arguments(
+ do_rescale=do_rescale,
+ rescale_factor=rescale_factor,
+ do_normalize=do_normalize,
+ image_mean=image_mean,
+ image_std=image_std,
+ do_resize=do_resize,
+ size=size,
+ resample=resample,
+ )
+ # PIL RGBA images are converted to RGB
+ if do_convert_rgb:
+ images = [convert_to_rgb(image) for image in images]
+
+ # All transformations expect numpy arrays.
+ images = [to_numpy_array(image) for image in images]
+
+ if do_rescale and is_scaled_image(images[0]):
+ logger.warning_once(
+ "It looks like you are trying to rescale already rescaled images. If the input"
+ " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
+ )
+
+ if input_data_format is None:
+ # We assume that all images have the same channel dimension format.
+ input_data_format = infer_channel_dimension_format(images[0])
+
+ if do_resize:
+ images = [
+ self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
+ for image in images
+ ]
+
+ if do_pad:
+ # Expand and pad the images to obtain a square image of dimensions `size x size`
+ images = [
+ self.pad_to_square(
+ image=image,
+ background_color=background_color,
+ input_data_format=input_data_format,
+ )
+ for image in images
+ ]
+
+ if do_rescale:
+ images = [
+ self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
+ for image in images
+ ]
+
+ if do_normalize:
+ images = [
+ self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
+ for image in images
+ ]
+
+ images = [
+ to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
+ ]
+
+ encoded_outputs = BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors)
+
+ return encoded_outputs
+
+ def pad_to_square(
+ self,
+ image: np.ndarray,
+ background_color: Union[int, tuple[int, int, int]] = 0,
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ ) -> np.ndarray:
+ """
+ Pads an image to a square based on the longest edge.
+
+ Args:
+ image (`np.ndarray`):
+ The image to pad.
+ background_color (`int` or `tuple[int, int, int]`, *optional*, defaults to 0):
+                The color to use for the padding. Can be an integer for single-channel images or a
+                tuple of integers for multi-channel images. If passed as an integer for a
+                multi-channel image, the remaining channels are filled with `0`.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format for the output image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ If unset, will use same as the input image.
+ input_data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format for the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+
+ Returns:
+ `np.ndarray`: The padded image.
+ """
+ height, width = get_image_size(image, input_data_format)
+ num_channels = image.shape[0] if input_data_format == ChannelDimension.FIRST else image.shape[-1]
+
+ if height == width:
+ image = (
+ to_channel_dimension_format(image, data_format, input_data_format)
+ if data_format is not None
+ else image
+ )
+ return image
+
+ max_dim = max(height, width)
+
+ # Ensure background_color is the correct shape
+ if isinstance(background_color, int):
+ background_color = [background_color]
+ elif len(background_color) != num_channels:
+ raise ValueError(
+ f"background_color must have no more than {num_channels} elements to match the number of channels"
+ )
+
+ if input_data_format == ChannelDimension.FIRST:
+ result = np.zeros((num_channels, max_dim, max_dim), dtype=image.dtype)
+ for i, color in enumerate(background_color):
+ result[i, :, :] = color
+ if width > height:
+ start = (max_dim - height) // 2
+ result[:, start : start + height, :] = image
+ else:
+ start = (max_dim - width) // 2
+ result[:, :, start : start + width] = image
+ else:
+ result = np.zeros((max_dim, max_dim, num_channels), dtype=image.dtype)
+ for i, color in enumerate(background_color):
+ result[:, :, i] = color
+ if width > height:
+ start = (max_dim - height) // 2
+ result[start : start + height, :, :] = image
+ else:
+ start = (max_dim - width) // 2
+ result[:, start : start + width, :] = image
+
+ return result
+
+
+__all__ = ["DeepseekVLImageProcessor"]
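+
+
+# Minimal sketch of the two geometric steps implemented above: the longest side is resized to
+# `size` and the result is padded to a square with the background color. The input image is a
+# synthetic array and Pillow is assumed to be installed for the resampling step.
+def _deepseek_vl_resize_and_pad_sketch():
+    processor = DeepseekVLImageProcessor()
+    image = np.full((360, 640, 3), 200, dtype=np.uint8)          # (height, width, channels)
+    resized = processor.resize(image, size={"height": 384, "width": 384})
+    padded = processor.pad_to_square(resized, background_color=processor.background_color)
+    return resized.shape, padded.shape                            # (216, 384, 3) and (384, 384, 3)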
diff --git a/venv/lib/python3.13/site-packages/transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py b/venv/lib/python3.13/site-packages/transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py
new file mode 100644
index 0000000000000000000000000000000000000000..896e91f0692c13f99d53ca475ddbebb60dd8affe
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py
@@ -0,0 +1,193 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/deepseek_vl/modular_deepseek_vl.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_deepseek_vl.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# Copyright 2025 Deepseek AI and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Union
+
+import torch
+import torch.nn.functional as F
+
+from ...image_processing_utils import BatchFeature
+from ...image_processing_utils_fast import (
+ BaseImageProcessorFast,
+ DefaultFastImageProcessorKwargs,
+ group_images_by_shape,
+ reorder_images,
+)
+from ...image_utils import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD, PILImageResampling, SizeDict
+from ...processing_utils import Unpack
+from ...utils import TensorType, auto_docstring
+
+
+class DeepseekVLFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
+ r"""
+ min_size (`int`, *optional*, defaults to 14):
+ The minimum allowed size for the resized image. Ensures that neither the height nor width
+ falls below this value after resizing.
+ """
+
+ min_size: int
+
+
+@auto_docstring
+class DeepseekVLImageProcessorFast(BaseImageProcessorFast):
+ resample = PILImageResampling.BICUBIC
+ image_mean = OPENAI_CLIP_MEAN
+ image_std = OPENAI_CLIP_STD
+ size = {"height": 384, "width": 384}
+ min_size = 14
+ do_resize = True
+ do_rescale = True
+ do_normalize = True
+ do_pad = True
+ valid_kwargs = DeepseekVLFastImageProcessorKwargs
+
+ def __init__(self, **kwargs: Unpack[DeepseekVLFastImageProcessorKwargs]):
+ super().__init__(**kwargs)
+ if kwargs.get("image_mean") is None:
+ background_color = (127, 127, 127)
+ else:
+ background_color = tuple(int(x * 255) for x in kwargs.get("image_mean"))
+ self.background_color = tuple(background_color)
+
+ def resize(
+ self,
+ image: "torch.Tensor",
+ size: SizeDict,
+ min_size: int,
+ interpolation: Optional["F.InterpolationMode"] = None,
+ antialias: bool = True,
+ **kwargs,
+ ) -> "torch.Tensor":
+ if size.height is None or size.width is None or size.height != size.width:
+ raise ValueError(
+ f"Output height and width must be the same. Got height={size['height']} and width={size['width']}"
+ )
+ size = size.height
+
+ height, width = image.shape[-2:]
+ max_size = max(height, width)
+
+ delta = size / max_size
+ # Largest side becomes `size` and the other side is scaled according to the aspect ratio.
+ output_size_nonpadded = SizeDict(
+ height=max(int(height * delta), min_size),
+ width=max(int(width * delta), min_size),
+ )
+
+ return super().resize(image, size=output_size_nonpadded, interpolation=interpolation, antialias=antialias)
+
+ def pad_to_square(
+ self,
+ images: "torch.Tensor",
+ background_color: Union[int, tuple[int, int, int]] = 0,
+ ) -> "torch.Tensor":
+ """
+ Pads an image to a square based on the longest edge.
+
+ Args:
+ images (`torch.Tensor`):
+ The images to pad.
+ background_color (`int` or `tuple[int, int, int]`, *optional*, defaults to 0):
+                The color to use for the padding. Can be an integer for single-channel images or a
+                tuple of integers for multi-channel images. If passed as an integer for a
+                multi-channel image, the remaining channels are filled with `0`.
+
+ Returns:
+ `torch.Tensor`: The padded images.
+ """
+ height, width = images.shape[-2:]
+ num_channels = images.shape[1]
+ batch_size = images.shape[0]
+
+ if height == width:
+ return images
+
+ max_dim = max(height, width)
+
+ # Ensure background_color is the correct shape
+ if isinstance(background_color, int):
+ background_color = [background_color]
+ elif len(background_color) != num_channels:
+ raise ValueError(
+ f"background_color must have no more than {num_channels} elements to match the number of channels"
+ )
+
+ padded_images = torch.zeros(
+ (batch_size, num_channels, max_dim, max_dim), dtype=images.dtype, device=images.device
+ )
+ for i, color in enumerate(background_color):
+ padded_images[:, i, :, :] = color
+ if width > height:
+ start = (max_dim - height) // 2
+ padded_images[:, :, start : start + height, :] = images
+ else:
+ start = (max_dim - width) // 2
+ padded_images[:, :, :, start : start + width] = images
+
+ return padded_images
+
+ def _preprocess(
+ self,
+ images: list["torch.Tensor"],
+ do_resize: bool,
+ size: SizeDict,
+ min_size: int,
+ interpolation: Optional["F.InterpolationMode"],
+ do_rescale: bool,
+ rescale_factor: float,
+ do_normalize: bool,
+ image_mean: Optional[Union[float, list[float]]],
+ image_std: Optional[Union[float, list[float]]],
+ disable_grouping: Optional[bool],
+ return_tensors: Optional[Union[str, TensorType]],
+ do_pad: bool = True,
+ **kwargs,
+ ) -> BatchFeature:
+ # Group images by size for batched resizing
+ grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
+ resized_images_grouped = {}
+ for shape, stacked_images in grouped_images.items():
+ if do_resize:
+ stacked_images = self.resize(
+ image=stacked_images, size=size, min_size=min_size, interpolation=interpolation
+ )
+ resized_images_grouped[shape] = stacked_images
+ resized_images = reorder_images(resized_images_grouped, grouped_images_index)
+
+ # Group images by size for further processing
+ # Needed in case do_resize is False, or resize returns images with different sizes
+ grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
+ processed_images_grouped = {}
+ for shape, stacked_images in grouped_images.items():
+ if do_pad:
+ stacked_images = self.pad_to_square(stacked_images, background_color=self.background_color)
+ # Fused rescale and normalize
+ stacked_images = self.rescale_and_normalize(
+ stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
+ )
+ processed_images_grouped[shape] = stacked_images
+
+ processed_images = reorder_images(processed_images_grouped, grouped_images_index)
+ processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
+
+ return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)
+
+
+__all__ = ["DeepseekVLImageProcessorFast"]
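+
+
+# Sketch of the batched square-padding above; torchvision is assumed to be installed (the fast
+# processors depend on it) and the tensor sizes are toy values chosen for illustration.
+def _fast_pad_to_square_sketch():
+    processor = DeepseekVLImageProcessorFast()
+    images = torch.full((1, 3, 216, 384), 0.5)                    # (batch, channels, height, width)
+    padded = processor.pad_to_square(images, background_color=processor.background_color)
+    return padded.shape                                           # torch.Size([1, 3, 384, 384])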
diff --git a/venv/lib/python3.13/site-packages/transformers/models/deepseek_vl/modeling_deepseek_vl.py b/venv/lib/python3.13/site-packages/transformers/models/deepseek_vl/modeling_deepseek_vl.py
new file mode 100644
index 0000000000000000000000000000000000000000..ce884da8d08bb8680152bd2a58e537f876919f0f
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/deepseek_vl/modeling_deepseek_vl.py
@@ -0,0 +1,350 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/deepseek_vl/modular_deepseek_vl.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_deepseek_vl.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# Copyright 2025 Deepseek AI and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import Optional, Union
+
+import torch
+import torch.nn as nn
+
+from ...cache_utils import Cache
+from ...generation import GenerationMixin
+from ...modeling_outputs import ModelOutput
+from ...modeling_utils import PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
+from ..auto import AutoModel
+from .configuration_deepseek_vl import DeepseekVLConfig
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for DeepseekVL model's outputs that may also contain a past key/values (to speed up sequential decoding).
+ """
+)
+class DeepseekVLBaseModelOutputWithPast(ModelOutput):
+ r"""
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+
+ If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
+ hidden_size)` is output.
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
+ `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
+ input) to speed up sequential decoding.
+ image_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
+        Tuple of `torch.FloatTensor` (one for the output of the image embeddings) of shape `(batch_size, num_images,
+        sequence_length, hidden_size)`.
+
+        Image hidden states of the model, produced by the vision encoder and projected by the aligner.
+ """
+
+ last_hidden_state: Optional[torch.FloatTensor] = None
+ past_key_values: Optional[Cache] = None
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
+ attentions: Optional[tuple[torch.FloatTensor]] = None
+ image_hidden_states: Optional[tuple[torch.FloatTensor]] = None
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for DeepseekVL causal language model (or autoregressive) outputs.
+ """
+)
+class DeepseekVLCausalLMOutputWithPast(ModelOutput):
+ r"""
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Language modeling loss (for next-token prediction).
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ image_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
+        Tuple of `torch.FloatTensor` (one for the output of the image embeddings) of shape `(batch_size, num_images,
+        sequence_length, hidden_size)`.
+
+        Image hidden states of the model, produced by the vision encoder and projected by the aligner.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: Optional[torch.FloatTensor] = None
+ past_key_values: Optional[Cache] = None
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
+ attentions: Optional[tuple[torch.FloatTensor]] = None
+ image_hidden_states: Optional[tuple[torch.FloatTensor]] = None
+
+
+class DeepseekVLAligner(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+
+ in_features = config.vision_config.hidden_size
+ out_features = config.text_config.hidden_size
+
+ self.linear1 = nn.Linear(in_features, out_features)
+ self.activation = nn.GELU()
+ self.linear2 = nn.Linear(out_features, out_features)
+
+ def forward(self, vision_encodings: torch.Tensor) -> torch.Tensor:
+ x = self.linear1(vision_encodings)
+ x = self.activation(x)
+ x = self.linear2(x)
+ return x
+
+
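+# Shape sketch for `DeepseekVLAligner` above: vision features are projected into the text
+# embedding space by a two-layer GELU MLP. Dimensions are toy values, not real config sizes.
+def _aligner_shapes_sketch():
+    vision_hidden, text_hidden = 8, 16
+    aligner = nn.Sequential(nn.Linear(vision_hidden, text_hidden), nn.GELU(), nn.Linear(text_hidden, text_hidden))
+    return aligner(torch.randn(1, 3, vision_hidden)).shape        # torch.Size([1, 3, 16])
+
+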
+@auto_docstring
+class DeepseekVLPreTrainedModel(PreTrainedModel):
+ config: DeepseekVLConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["LlamaDecoderLayer"]
+ _skip_keys_device_placement = ["past_key_values", "causal_mask"]
+ _supports_flash_attn = True
+ _supports_sdpa = True
+
+ _can_compile_fullgraph = True
+ _supports_param_buffer_assignment = False
+
+ def _init_weights(self, module):
+ """Initialize the weights"""
+ # Required only for Linear layer in DeepseekVLAligner
+ if isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=self.config.text_config.initializer_range)
+ if module.bias is not None:
+ module.bias.data.zero_()
+
+
+@auto_docstring
+class DeepseekVLModel(DeepseekVLPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.config = config
+
+ self.vision_model = AutoModel.from_config(config.vision_config)
+ self.aligner = DeepseekVLAligner(config)
+
+ self.language_model = AutoModel.from_config(config=config.text_config)
+
+ self.gradient_checkpointing = False
+ # Initialize weights and apply final processing.
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.language_model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.language_model.set_input_embeddings(value)
+
+ def get_image_features(self, pixel_values):
+ image_embeds = self.vision_model(pixel_values)
+ image_embeds = self.aligner(image_embeds.last_hidden_state)
+ return image_embeds
+
+ def get_placeholder_mask(
+ self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor
+ ):
+ """
+ Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
+ equal to the length of multimodal features. If the lengths are different, an error is raised.
+ """
+ if input_ids is None:
+ special_image_mask = inputs_embeds == self.get_input_embeddings()(
+ torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ special_image_mask = special_image_mask.all(-1)
+ else:
+ special_image_mask = input_ids == self.config.image_token_id
+
+ n_image_tokens = special_image_mask.sum()
+ special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+ if inputs_embeds[special_image_mask].numel() != image_features.numel():
+ n_image_features = image_features.shape[0] * image_features.shape[1]
+ raise ValueError(
+ f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+ )
+ return special_image_mask
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ **kwargs,
+ ):
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError(
+ "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
+ )
+ if inputs_embeds is None:
+ inputs_embeds = self.get_input_embeddings()(input_ids)
+
+ if pixel_values is not None:
+ image_embeds = self.get_image_features(pixel_values)
+ image_features = image_embeds.reshape(-1, inputs_embeds.shape[-1])
+ image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+ image_attention_mask = self.get_placeholder_mask(
+ input_ids, inputs_embeds=inputs_embeds, image_features=image_features
+ )
+ inputs_embeds = inputs_embeds.masked_scatter(image_attention_mask, image_features)
+
+ lm_output = self.language_model(
+ inputs_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ logits_to_keep=logits_to_keep,
+ **kwargs,
+ )
+
+ return DeepseekVLBaseModelOutputWithPast(
+ last_hidden_state=lm_output.last_hidden_state,
+ past_key_values=lm_output.past_key_values,
+ hidden_states=lm_output.hidden_states,
+ attentions=lm_output.attentions,
+ image_hidden_states=image_embeds if pixel_values is not None else None,
+ )
+
+
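+# Minimal sketch of the placeholder scatter performed in `DeepseekVLModel.forward` above:
+# positions holding the image token id are replaced by rows of the projected image features.
+# The token ids, the image token id (9) and the hidden size are toy values.
+def _placeholder_scatter_sketch():
+    image_token_id, hidden = 9, 4
+    input_ids = torch.tensor([[1, 9, 9, 2]])                      # two image placeholders
+    inputs_embeds = torch.zeros(1, 4, hidden)
+    image_features = torch.ones(2, hidden)                        # one row per placeholder token
+    mask = (input_ids == image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
+    return inputs_embeds.masked_scatter(mask, image_features)
+
+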
+class DeepseekVLForConditionalGeneration(DeepseekVLPreTrainedModel, GenerationMixin):
+ _tied_weights_keys = ["model.language_model.embed_tokens.weight", "lm_head.weight"]
+ _can_compile_fullgraph = True
+
+ def __init__(self, config: DeepseekVLConfig):
+ super().__init__(config)
+ self.config = config
+ self.model = DeepseekVLModel(config)
+ self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
+
+ # Initialize weights and apply final processing.
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.language_model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.model.language_model.set_input_embeddings(value)
+
+ def prepare_embeddings_for_image_generation(self) -> torch.Tensor:
+ raise AttributeError("Not needed for DeepseekVL")
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ **kwargs: Unpack[TransformersKwargs],
+ ):
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+ """
+ outputs = self.model(
+ input_ids=input_ids,
+ pixel_values=pixel_values,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ **kwargs,
+ )
+ hidden_states = outputs.last_hidden_state
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(
+ logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
+ )
+
+ return DeepseekVLCausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ image_hidden_states=outputs.image_hidden_states,
+ )
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids,
+ pixel_values=None,
+ past_key_values=None,
+ attention_mask=None,
+ inputs_embeds=None,
+ cache_position=None,
+ logits_to_keep=None,
+ **kwargs,
+ ):
+ # Overwritten -- extra custom processing
+
+ model_inputs = super().prepare_inputs_for_generation(
+ input_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ cache_position=cache_position,
+ logits_to_keep=logits_to_keep,
+ **kwargs,
+ )
+
+        # If we're in the cached decoding stage, pixel values should be None because the input ids no longer contain the special image tokens
+ # Otherwise we need pixel values to be passed to model
+ if cache_position[0] == 0:
+ model_inputs["pixel_values"] = pixel_values
+
+ return model_inputs
+
+
+__all__ = ["DeepseekVLPreTrainedModel", "DeepseekVLModel", "DeepseekVLForConditionalGeneration"]
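+
+
+# End-to-end usage sketch for the classes above. The checkpoint id is the one referenced in the
+# DeepseekVL config docstring; the prompt format, Pillow dependency and network access are
+# assumptions, so consult the model card or chat template for the exact expected prompt.
+def _deepseek_vl_generate_sketch():
+    from PIL import Image
+
+    from transformers import AutoProcessor
+
+    model_id = "deepseek-community/deepseek-vl-1.3b-chat"
+    processor = AutoProcessor.from_pretrained(model_id)
+    model = DeepseekVLForConditionalGeneration.from_pretrained(model_id)
+    image = Image.new("RGB", (640, 480))
+    inputs = processor(text=f"{processor.image_token} Describe this image.", images=image, return_tensors="pt")
+    output_ids = model.generate(**inputs, max_new_tokens=20)
+    return processor.batch_decode(output_ids, skip_special_tokens=True)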
diff --git a/venv/lib/python3.13/site-packages/transformers/models/deepseek_vl/modular_deepseek_vl.py b/venv/lib/python3.13/site-packages/transformers/models/deepseek_vl/modular_deepseek_vl.py
new file mode 100644
index 0000000000000000000000000000000000000000..5bfc0ae7d74c3b48f9a009edcc7daba60ccc3382
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/deepseek_vl/modular_deepseek_vl.py
@@ -0,0 +1,333 @@
+# Copyright 2025 Deepseek AI and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Union
+
+import torch
+import torch.nn as nn
+
+from ...configuration_utils import PretrainedConfig
+from ...image_processing_utils import BatchFeature
+from ...image_utils import ImageInput
+from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
+from ...tokenization_utils_base import (
+ PreTokenizedInput,
+ TextInput,
+)
+from ...utils import (
+ auto_docstring,
+ logging,
+)
+from ..auto import CONFIG_MAPPING, AutoConfig, AutoModel
+from ..idefics.modeling_idefics import IdeficsBaseModelOutputWithPast, IdeficsCausalLMOutputWithPast
+from ..janus.image_processing_janus import JanusImageProcessor
+from ..janus.image_processing_janus_fast import JanusImageProcessorFast
+from ..janus.modeling_janus import JanusForConditionalGeneration, JanusModel, JanusPreTrainedModel
+
+
+logger = logging.get_logger(__name__)
+
+
+class DeepseekVLConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`DeepseekVLModel`]. It is used to instantiate a
+ DeepseekVL model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of the DeepseekVL
+ [deepseek-community/deepseek-vl-1.3b-chat](https://huggingface.co/deepseek-community/deepseek-vl-1.3b-chat) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `LlamaConfig`):
+ The config object or dictionary of the text backbone.
+ vision_config (`Union[AutoConfig, dict]`, *optional*, defaults to `SiglipVisionConfig`):
+ The config object or dictionary of the vision backbone.
+ image_token_id (`int`, *optional*, defaults to 100015):
+ The index representing image tokens in the model's token vocabulary.
+
+ Example:
+
+ ```python
+ >>> from transformers import DeepseekVLConfig, DeepseekVLModel
+
+ >>> # Initializing a DeepseekVL deepseek-community/deepseek-vl-1.3b-chat style configuration
+ >>> configuration = DeepseekVLConfig()
+
+ >>> # Initializing a model (with random weights) from the deepseek-community/deepseek-vl-1.3b-chat style configuration
+ >>> model = DeepseekVLModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "deepseek_vl"
+ sub_configs = {"text_config": AutoConfig, "vision_config": AutoConfig}
+
+ def __init__(
+ self,
+ text_config: Optional[AutoConfig] = None,
+ vision_config: Optional[AutoConfig] = None,
+ image_token_id: int = 100015,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ if text_config is None:
+ text_config = {}
+ logger.info("`text_config` is `None`. Initializing the `LlamaConfig` with default values.")
+
+ if vision_config is None:
+ vision_config = {}
+ logger.info("`vision_config` is `None`. Initializing the `SiglipVisionConfig` with default values.")
+
+ if isinstance(text_config, dict):
+ text_config["model_type"] = text_config.get("model_type", "llama")
+ text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
+
+ if isinstance(vision_config, dict):
+ vision_config["model_type"] = vision_config.get("model_type", "siglip_vision_model")
+ vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config)
+
+ self.text_config = text_config
+ self.vision_config = vision_config
+ self.image_token_id = image_token_id
+
+
+class DeepseekVLBaseModelOutputWithPast(IdeficsBaseModelOutputWithPast):
+ pass
+
+
+class DeepseekVLCausalLMOutputWithPast(IdeficsCausalLMOutputWithPast):
+ pass
+
+
+class DeepseekVLAligner(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+
+ in_features = config.vision_config.hidden_size
+ out_features = config.text_config.hidden_size
+
+ self.linear1 = nn.Linear(in_features, out_features)
+ self.activation = nn.GELU()
+ self.linear2 = nn.Linear(out_features, out_features)
+
+ def forward(self, vision_encodings: torch.Tensor) -> torch.Tensor:
+ x = self.linear1(vision_encodings)
+ x = self.activation(x)
+ x = self.linear2(x)
+ return x
+
+
+class DeepseekVLPreTrainedModel(JanusPreTrainedModel):
+ _no_split_modules = ["LlamaDecoderLayer"]
+
+ def _init_weights(self, module):
+ """Initialize the weights"""
+ # Required only for Linear layer in DeepseekVLAligner
+ if isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=self.config.text_config.initializer_range)
+ if module.bias is not None:
+ module.bias.data.zero_()
+
+
+@auto_docstring
+class DeepseekVLModel(JanusModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.config = config
+
+ self.vision_model = AutoModel.from_config(config.vision_config)
+ self.aligner = DeepseekVLAligner(config)
+
+ self.language_model = AutoModel.from_config(config=config.text_config)
+
+ self.gradient_checkpointing = False
+ # Initialize weights and apply final processing.
+ self.post_init()
+
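+        # Drop the inherited Janus components that exist only for image generation;
+        # DeepseekVL is an image-understanding model and never decodes image tokens.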
+ del self.vqmodel
+ del self.generation_embeddings
+ del self.generation_aligner
+ del self.generation_head
+
+
+class DeepseekVLForConditionalGeneration(JanusForConditionalGeneration):
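+    # These overrides raise AttributeError so the modular converter drops the inherited Janus
+    # image-generation methods from the generated class; DeepseekVL does not generate images.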
+ def prepare_embeddings_for_image_generation(self):
+ raise AttributeError("Not needed for DeepseekVL")
+
+ def decode_image_tokens(self):
+ raise AttributeError("Not needed for DeepseekVL")
+
+ def generate(self):
+ raise AttributeError("Not needed for DeepseekVL")
+
+
+class DeepseekVLImageProcessor(JanusImageProcessor):
+ def __init__(self, **super_kwargs):
+ super().__init__(**super_kwargs)
+
+ def postprocess(self):
+ raise AttributeError("Not needed for DeepseekVL")
+
+ def unnormalize(self):
+ raise AttributeError("Not needed for DeepseekVL")
+
+
+class DeepseekVLImageProcessorFast(JanusImageProcessorFast):
+ def __init__(self, **super_kwargs):
+ super().__init__(**super_kwargs)
+
+ def postprocess(self):
+ raise AttributeError("Not needed for DeepseekVL")
+
+
+class DeepseekVLProcessorKwargs(ProcessingKwargs, total=False):
+ _defaults = {
+ "text_kwargs": {"padding": False},
+ "common_kwargs": {"return_tensors": "pt"},
+ }
+
+
+class DeepseekVLProcessor(ProcessorMixin):
+ r"""
+ Constructs a DeepseekVL processor which wraps a DeepseekVL Image Processor and a Llama tokenizer into a single processor.
+
+ [`DeepseekVLProcessor`] offers all the functionalities of [`DeepseekVLImageProcessor`] and [`LlamaTokenizerFast`]. See the
+ [`~DeepseekVLProcessor.__call__`] and [`~DeepseekVLProcessor.decode`] for more information.
+
+ Args:
+ image_processor ([`DeepseekVLImageProcessor`]):
+ The image processor is a required input.
+ tokenizer ([`LlamaTokenizerFast`]):
+ The tokenizer is a required input.
+ chat_template (`str`, *optional*):
+ A Jinja template which will be used to convert lists of messages
+ in a chat into a tokenizable string.
+ num_image_tokens (`int`, *optional*, defaults to 576):
+ The number of special image tokens used as placeholders for visual content in text sequences.
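+
+    Example (an illustrative sketch; it assumes the `deepseek-community/deepseek-vl-1.3b-chat` checkpoint
+    ships processor files and that a local image file such as `example.jpg` is available):
+
+    ```python
+    >>> from PIL import Image
+    >>> from transformers import AutoProcessor
+
+    >>> processor = AutoProcessor.from_pretrained("deepseek-community/deepseek-vl-1.3b-chat")
+    >>> image = Image.open("example.jpg")  # hypothetical local file
+    >>> # The image token is expanded to `num_image_tokens` copies inside `__call__`.
+    >>> prompt = f"{processor.image_token} Describe this image."
+    >>> inputs = processor(text=prompt, images=image, return_tensors="pt")
+    ```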
+ """
+
+ attributes = ["image_processor", "tokenizer"]
+ valid_kwargs = ["chat_template", "num_image_tokens"]
+ image_processor_class = "AutoImageProcessor"
+ tokenizer_class = "AutoTokenizer"
+
+ def __init__(
+ self,
+ image_processor,
+ tokenizer,
+ chat_template=None,
+ num_image_tokens=576,
+ ):
+ self.image_token = tokenizer.image_token
+ self.num_image_tokens = num_image_tokens
+
+ super().__init__(image_processor, tokenizer, chat_template=chat_template)
+
+ def __call__(
+ self,
+ text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None,
+ images: Optional[ImageInput] = None,
+ **kwargs: Unpack[DeepseekVLProcessorKwargs],
+ ) -> BatchFeature:
+ """
+        Main method to prepare for the model one or several sequence(s) and image(s). This method forwards the `text`
+ and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`] if `text` is not `None` to encode
+ the text. To prepare the image(s), this method forwards the `images` and `kwargs` arguments to
+        DeepseekVLImageProcessor's [`~DeepseekVLImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring
+ of the above two methods for more information.
+
+ Args:
+ text (`str`, `List[str]`, `List[List[str]]`):
+ The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
+ (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
+ `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
+ images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
+ The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
+ tensor. Both channels-first and channels-last formats are supported.
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
+ If set, will return tensors of a particular framework. Acceptable values are:
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
+ - `'np'`: Return NumPy `np.ndarray` objects.
+ - `'jax'`: Return JAX `jnp.ndarray` objects.
+
+ Returns:
+ [`BatchFeature`]: A [`BatchFeature`] with the following fields:
+
+ - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
+ - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
+ `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
+ `None`).
+ - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
+ """
+ output_kwargs = self._merge_kwargs(
+ DeepseekVLProcessorKwargs, tokenizer_init_kwargs=self.tokenizer.init_kwargs, **kwargs
+ )
+ if text is None and images is None:
+ raise ValueError("You must specify either text or images.")
+
+ if text is not None:
+ if isinstance(text, str):
+ text = [text]
+ elif not (isinstance(text, (list, tuple)) and all(isinstance(t, str) for t in text)):
+ raise ValueError("Invalid input text. Please provide a string, or a list of strings")
+
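+            # Expand each image placeholder into `num_image_tokens` copies so the tokenized prompt
+            # reserves one position for every visual embedding inserted by the model.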
+ prompt_strings = []
+ one_img_tokens = self.image_token * self.num_image_tokens
+ for prompt in text:
+ prompt = prompt.replace(self.image_token, one_img_tokens)
+ prompt_strings.append(prompt)
+
+ data = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"])
+
+        # process images if provided
+ if images is not None:
+ data["pixel_values"] = self.image_processor(images, **output_kwargs["images_kwargs"])["pixel_values"]
+
+ return BatchFeature(data=data)
+
+ def batch_decode(self, *args, **kwargs):
+ """
+ This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
+ refer to the docstring of this method for more information.
+ """
+ return self.tokenizer.batch_decode(*args, **kwargs)
+
+ def decode(self, *args, **kwargs):
+ """
+ This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
+ the docstring of this method for more information.
+ """
+ return self.tokenizer.decode(*args, **kwargs)
+
+ @property
+ def model_input_names(self):
+ tokenizer_input_names = self.tokenizer.model_input_names
+ image_processor_input_names = self.image_processor.model_input_names
+ return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
+
+
+__all__ = [
+ "DeepseekVLConfig",
+ "DeepseekVLPreTrainedModel",
+ "DeepseekVLModel",
+ "DeepseekVLForConditionalGeneration",
+ "DeepseekVLImageProcessor",
+ "DeepseekVLImageProcessorFast",
+ "DeepseekVLProcessor",
+]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/deepseek_vl/processing_deepseek_vl.py b/venv/lib/python3.13/site-packages/transformers/models/deepseek_vl/processing_deepseek_vl.py
new file mode 100644
index 0000000000000000000000000000000000000000..26d59d85a2955022c2a428509dcd29eab5c7c5d4
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/deepseek_vl/processing_deepseek_vl.py
@@ -0,0 +1,156 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/deepseek_vl/modular_deepseek_vl.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_deepseek_vl.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# Copyright 2025 Deepseek AI and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Union
+
+from ...image_processing_utils import BatchFeature
+from ...image_utils import ImageInput
+from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
+from ...tokenization_utils_base import PreTokenizedInput, TextInput
+
+
+class DeepseekVLProcessorKwargs(ProcessingKwargs, total=False):
+ _defaults = {
+ "text_kwargs": {"padding": False},
+ "common_kwargs": {"return_tensors": "pt"},
+ }
+
+
+class DeepseekVLProcessor(ProcessorMixin):
+ r"""
+ Constructs a DeepseekVL processor which wraps a DeepseekVL Image Processor and a Llama tokenizer into a single processor.
+
+ [`DeepseekVLProcessor`] offers all the functionalities of [`DeepseekVLImageProcessor`] and [`LlamaTokenizerFast`]. See the
+ [`~DeepseekVLProcessor.__call__`] and [`~DeepseekVLProcessor.decode`] for more information.
+
+ Args:
+ image_processor ([`DeepseekVLImageProcessor`]):
+ The image processor is a required input.
+ tokenizer ([`LlamaTokenizerFast`]):
+ The tokenizer is a required input.
+ chat_template (`str`, *optional*):
+ A Jinja template which will be used to convert lists of messages
+ in a chat into a tokenizable string.
+ num_image_tokens (`int`, *optional*, defaults to 576):
+ The number of special image tokens used as placeholders for visual content in text sequences.
+ """
+
+ attributes = ["image_processor", "tokenizer"]
+ valid_kwargs = ["chat_template", "num_image_tokens"]
+ image_processor_class = "AutoImageProcessor"
+ tokenizer_class = "AutoTokenizer"
+
+ def __init__(
+ self,
+ image_processor,
+ tokenizer,
+ chat_template=None,
+ num_image_tokens=576,
+ ):
+ self.image_token = tokenizer.image_token
+ self.num_image_tokens = num_image_tokens
+
+ super().__init__(image_processor, tokenizer, chat_template=chat_template)
+
+ def __call__(
+ self,
+ text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None,
+ images: Optional[ImageInput] = None,
+ **kwargs: Unpack[DeepseekVLProcessorKwargs],
+ ) -> BatchFeature:
+ """
+        Main method to prepare for the model one or several sequence(s) and image(s). This method forwards the `text`
+ and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`] if `text` is not `None` to encode
+ the text. To prepare the image(s), this method forwards the `images` and `kwargs` arguments to
+        DeepseekVLImageProcessor's [`~DeepseekVLImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring
+ of the above two methods for more information.
+
+ Args:
+ text (`str`, `List[str]`, `List[List[str]]`):
+ The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
+ (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
+ `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
+ images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
+ The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
+ tensor. Both channels-first and channels-last formats are supported.
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
+ If set, will return tensors of a particular framework. Acceptable values are:
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
+ - `'np'`: Return NumPy `np.ndarray` objects.
+ - `'jax'`: Return JAX `jnp.ndarray` objects.
+
+ Returns:
+ [`BatchFeature`]: A [`BatchFeature`] with the following fields:
+
+ - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
+ - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
+ `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
+ `None`).
+ - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
+ """
+ output_kwargs = self._merge_kwargs(
+ DeepseekVLProcessorKwargs, tokenizer_init_kwargs=self.tokenizer.init_kwargs, **kwargs
+ )
+ if text is None and images is None:
+ raise ValueError("You must specify either text or images.")
+
+ if text is not None:
+ if isinstance(text, str):
+ text = [text]
+ elif not (isinstance(text, (list, tuple)) and all(isinstance(t, str) for t in text)):
+ raise ValueError("Invalid input text. Please provide a string, or a list of strings")
+
+ prompt_strings = []
+ one_img_tokens = self.image_token * self.num_image_tokens
+ for prompt in text:
+ prompt = prompt.replace(self.image_token, one_img_tokens)
+ prompt_strings.append(prompt)
+
+ data = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"])
+
+        # process images if provided
+ if images is not None:
+ data["pixel_values"] = self.image_processor(images, **output_kwargs["images_kwargs"])["pixel_values"]
+
+ return BatchFeature(data=data)
+
+ def batch_decode(self, *args, **kwargs):
+ """
+ This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
+ refer to the docstring of this method for more information.
+ """
+ return self.tokenizer.batch_decode(*args, **kwargs)
+
+ def decode(self, *args, **kwargs):
+ """
+ This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
+ the docstring of this method for more information.
+ """
+ return self.tokenizer.decode(*args, **kwargs)
+
+ @property
+ def model_input_names(self):
+ tokenizer_input_names = self.tokenizer.model_input_names
+ image_processor_input_names = self.image_processor.model_input_names
+ return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
+
+
+__all__ = ["DeepseekVLProcessor"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/deepseek_vl_hybrid/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/deepseek_vl_hybrid/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..da85178ccc84d9385fd3473a25eb51757dc9d34b
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/deepseek_vl_hybrid/__init__.py
@@ -0,0 +1,31 @@
+# Copyright 2025 Deepseek AI and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_deepseek_vl_hybrid import *
+ from .image_processing_deepseek_vl_hybrid import *
+ from .image_processing_deepseek_vl_hybrid_fast import *
+ from .modeling_deepseek_vl_hybrid import *
+ from .processing_deepseek_vl_hybrid import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/venv/lib/python3.13/site-packages/transformers/models/deepseek_vl_hybrid/configuration_deepseek_vl_hybrid.py b/venv/lib/python3.13/site-packages/transformers/models/deepseek_vl_hybrid/configuration_deepseek_vl_hybrid.py
new file mode 100644
index 0000000000000000000000000000000000000000..e8c6e2df6ea31b295fa4fd277a9e7257e976ca8f
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/deepseek_vl_hybrid/configuration_deepseek_vl_hybrid.py
@@ -0,0 +1,110 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_deepseek_vl_hybrid.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# Copyright 2025 Deepseek AI and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+from ..auto import CONFIG_MAPPING, AutoConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+class DeepseekVLHybridConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`DeepseekVLHybridModel`]. It is used to instantiate a
+ DeepseekVLHybrid model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of the DeepseekVLHybrid
+ [deepseek-community/deepseek-vl-7b-chat](https://huggingface.co/deepseek-community/deepseek-vl-7b-chat) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `LlamaConfig`):
+ The config object or dictionary of the text backbone.
+ vision_config (`Union[AutoConfig, dict]`, *optional*, defaults to `SiglipVisionConfig`):
+ The config object or dictionary of the vision backbone.
+ high_res_vision_config (`Union[AutoConfig, dict]`, *optional*, defaults to `SamVisionConfig`):
+ The config object or dictionary of the high resolution vision backbone.
+ image_token_id (`int`, *optional*, defaults to 100015):
+ The index representing image tokens in the model's token vocabulary.
+
+ Example:
+
+ ```python
+ >>> from transformers import DeepseekVLHybridConfig, DeepseekVLHybridModel
+
+ >>> # Initializing a DeepseekVLHybrid deepseek-community/deepseek-vl-7b-chat style configuration
+ >>> configuration = DeepseekVLHybridConfig()
+
+ >>> # Initializing a model (with random weights) from the deepseek-community/deepseek-vl-7b-chat style configuration
+ >>> model = DeepseekVLHybridModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "deepseek_vl_hybrid"
+ sub_configs = {"text_config": AutoConfig, "vision_config": AutoConfig, "high_res_vision_config": AutoConfig}
+
+ def __init__(
+ self,
+ text_config: Optional[AutoConfig] = None,
+ vision_config: Optional[AutoConfig] = None,
+ high_res_vision_config: Optional[AutoConfig] = None,
+ image_token_id: int = 100015,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ if text_config is None:
+ text_config = {}
+ logger.info("`text_config` is `None`. Initializing the `LlamaConfig` with default values.")
+
+ if vision_config is None:
+ vision_config = {}
+ logger.info("`vision_config` is `None`. Initializing the `SiglipVisionConfig` with default values.")
+
+ if isinstance(text_config, dict):
+ text_config["model_type"] = text_config.get("model_type", "llama")
+ text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
+
+ if isinstance(vision_config, dict):
+ vision_config["model_type"] = vision_config.get("model_type", "siglip_vision_model")
+ vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config)
+
+ self.text_config = text_config
+ self.vision_config = vision_config
+ self.image_token_id = image_token_id
+
+ if high_res_vision_config is None:
+ high_res_vision_config = {}
+ logger.info("`high_res_vision_config` is `None`. Initializing the `SamVisionConfig` with default values.")
+
+ if isinstance(high_res_vision_config, dict):
+ high_res_vision_config["model_type"] = high_res_vision_config.get("model_type", "sam_vision_model")
+ high_res_vision_config = CONFIG_MAPPING[high_res_vision_config["model_type"]](**high_res_vision_config)
+
+ self.high_res_vision_config = high_res_vision_config
+
+
+__all__ = ["DeepseekVLHybridConfig"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid.py b/venv/lib/python3.13/site-packages/transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid.py
new file mode 100644
index 0000000000000000000000000000000000000000..865e13fa964fac4cc3d630ca80ada376a3556cf1
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid.py
@@ -0,0 +1,498 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_deepseek_vl_hybrid.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# Copyright 2025 Deepseek AI and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Union
+
+import numpy as np
+
+from ...image_processing_utils import BaseImageProcessor
+from ...image_processing_utils_fast import BatchFeature, get_size_dict
+from ...image_transforms import convert_to_rgb, resize, to_channel_dimension_format
+from ...image_utils import (
+ OPENAI_CLIP_MEAN,
+ OPENAI_CLIP_STD,
+ ChannelDimension,
+ ImageInput,
+ PILImageResampling,
+ get_image_size,
+ infer_channel_dimension_format,
+ is_scaled_image,
+ make_flat_list_of_images,
+ to_numpy_array,
+ valid_images,
+ validate_preprocess_arguments,
+)
+from ...utils import TensorType, filter_out_non_signature_kwargs, is_vision_available, logging
+
+
+if is_vision_available():
+ import PIL
+
+
+logger = logging.get_logger(__name__)
+
+
+class DeepseekVLHybridImageProcessor(BaseImageProcessor):
+ r"""
+    Constructs a DeepseekVLHybrid image processor.
+
+ Args:
+ do_resize (`bool`, *optional*, defaults to `True`):
+ Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the
+ `do_resize` parameter in the `preprocess` method.
+ size (`dict`, *optional*, defaults to `{"height": 384, "width": 384}`):
+ Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess`
+ method.
+ high_res_size (`dict`, *optional*, defaults to `{"height": 1024, "width": 1024}`):
+ Size of the high resolution output image after resizing. Can be overridden by the `high_res_size` parameter in the `preprocess`
+ method.
+ min_size (`int`, *optional*, defaults to 14):
+ The minimum allowed size for the resized image. Ensures that neither the height nor width
+ falls below this value after resizing.
+ resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
+ Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`. Can be
+ overridden by the `resample` parameter in the `preprocess` method.
+ high_res_resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
+ Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`. Can be
+ overridden by the `high_res_resample` parameter in the `preprocess` method.
+ do_rescale (`bool`, *optional*, defaults to `True`):
+ Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the
+ `do_rescale` parameter in the `preprocess` method.
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
+ Scale factor to use if rescaling the image. Only has an effect if `do_rescale` is set to `True`. Can be
+ overridden by the `rescale_factor` parameter in the `preprocess` method.
+        do_normalize (`bool`, *optional*, defaults to `True`):
+            Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
+            method.
+        image_mean (`float` or `list[float]`, *optional*, defaults to `OPENAI_CLIP_MEAN`):
+            Mean to use if normalizing the image. This is a float or list of floats the length of the number of
+            channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
+        image_std (`float` or `list[float]`, *optional*, defaults to `OPENAI_CLIP_STD`):
+            Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
+            number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
+ high_res_image_mean (`float` or `list[float]`, *optional*, defaults to `OPENAI_CLIP_MEAN`):
+ Mean to use if normalizing the high resolution image. This is a float or list of floats the length of the number of
+ channels in the image. Can be overridden by the `high_res_image_mean` parameter in the `preprocess` method.
+ high_res_image_std (`float` or `list[float]`, *optional*, defaults to `OPENAI_CLIP_STD`):
+ Standard deviation to use if normalizing the high resolution image. This is a float or list of floats the length of the
+ number of channels in the image. Can be overridden by the `high_res_image_std` parameter in the `preprocess` method.
+ do_convert_rgb (`bool`, *optional*, defaults to `True`):
+ Whether to convert the image to RGB.
+ do_pad (`bool`, *optional*, defaults to `True`):
+ Whether to pad the image to square or not.
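+
+    Example (an illustrative sketch; it assumes the `deepseek-community/deepseek-vl-7b-chat` checkpoint
+    ships an image processor config and that a local image file such as `example.jpg` is available):
+
+    ```python
+    >>> from PIL import Image
+    >>> from transformers import AutoImageProcessor
+
+    >>> image_processor = AutoImageProcessor.from_pretrained("deepseek-community/deepseek-vl-7b-chat")
+    >>> image = Image.open("example.jpg")  # hypothetical local file
+    >>> outputs = image_processor(images=image, return_tensors="pt")
+    >>> # With the defaults above, pixel_values has shape (1, 3, 384, 384) and
+    >>> # high_res_pixel_values has shape (1, 3, 1024, 1024).
+    ```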
+ """
+
+ model_input_names = ["pixel_values", "high_res_pixel_values"]
+
+ def __init__(
+ self,
+ do_resize: bool = True,
+ size: Optional[dict[str, int]] = None,
+ high_res_size: Optional[dict[str, int]] = None,
+ min_size: int = 14,
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
+ high_res_resample: PILImageResampling = PILImageResampling.BICUBIC,
+ do_rescale: bool = True,
+ rescale_factor: Union[int, float] = 1 / 255,
+ do_normalize: bool = True,
+ image_mean: Optional[Union[float, list[float]]] = None,
+ image_std: Optional[Union[float, list[float]]] = None,
+ high_res_image_mean: Optional[Union[float, list[float]]] = None,
+ high_res_image_std: Optional[Union[float, list[float]]] = None,
+ do_convert_rgb: Optional[bool] = None,
+ do_pad: bool = True,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ high_res_size = high_res_size if high_res_size is not None else {"height": 1024, "width": 1024}
+ high_res_size = get_size_dict(high_res_size, default_to_square=True)
+
+ self.high_res_size = high_res_size
+ self.high_res_image_mean = high_res_image_mean if high_res_image_mean is not None else OPENAI_CLIP_MEAN
+ self.high_res_image_std = high_res_image_std if high_res_image_std is not None else OPENAI_CLIP_STD
+
+ self.resample = resample
+ self.high_res_resample = high_res_resample
+ size = size if size is not None else {"height": 384, "width": 384}
+ size = get_size_dict(size, default_to_square=True)
+
+ self.do_resize = do_resize
+ self.size = size
+ self.resample = resample
+ self.do_rescale = do_rescale
+ self.rescale_factor = rescale_factor
+ self.do_normalize = do_normalize
+ self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
+ self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
+ self.do_convert_rgb = do_convert_rgb
+
+ self.do_pad = do_pad
+ self.min_size = min_size
+ if image_mean is None:
+ self.background_color = (127, 127, 127)
+ else:
+ self.background_color = tuple(int(x * 255) for x in image_mean)
+
+ if high_res_image_mean is None:
+ self.high_res_background_color = (127, 127, 127)
+ else:
+ self.high_res_background_color = tuple(int(x * 255) for x in high_res_image_mean)
+
+ def resize(
+ self,
+ image: np.ndarray,
+ size: Union[dict[str, int], int],
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs,
+ ) -> np.ndarray:
+ """
+ Resize an image to dynamically calculated size.
+
+ Args:
+ image (`np.ndarray`):
+ Image to resize.
+ size (`dict[str, int]` or `int`):
+ The size to resize the image to. If a dictionary, it should have the keys `"height"` and `"width"`.
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
+ `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BICUBIC`.
+ data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the output image. If unset, the channel dimension format of the input
+ image is used. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - `None`: will be inferred from input
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
+ from the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+
+ Returns:
+ `np.ndarray`: The resized image.
+ """
+ if input_data_format is None:
+ input_data_format = infer_channel_dimension_format(image)
+
+ height, width = get_image_size(image, input_data_format)
+ max_size = max(height, width)
+
+ size = get_size_dict(size, default_to_square=True)
+ if size["height"] != size["width"]:
+ raise ValueError(
+ f"Output height and width must be the same. Got height={size['height']} and width={size['width']}"
+ )
+ size = size["height"]
+
+ delta = size / max_size
+ # Largest side becomes `size` and the other side is scaled according to the aspect ratio.
+ output_size_nonpadded = [
+ max(int(height * delta), self.min_size),
+ max(int(width * delta), self.min_size),
+ ]
+
+ image = resize(
+ image,
+ size=output_size_nonpadded,
+ resample=resample,
+ data_format=data_format,
+ input_data_format=input_data_format,
+ **kwargs,
+ )
+ return image
+
+ @filter_out_non_signature_kwargs()
+ def preprocess(
+ self,
+ images: ImageInput,
+ do_resize: Optional[bool] = None,
+ size: Optional[dict[str, int]] = None,
+ high_res_size: Optional[dict[str, int]] = None,
+ resample: Optional[PILImageResampling] = None,
+ high_res_resample: Optional[PILImageResampling] = None,
+ do_rescale: Optional[bool] = None,
+ rescale_factor: Optional[float] = None,
+ do_normalize: Optional[bool] = None,
+ image_mean: Optional[Union[float, list[float]]] = None,
+ image_std: Optional[Union[float, list[float]]] = None,
+ high_res_image_mean: Optional[Union[float, list[float]]] = None,
+ high_res_image_std: Optional[Union[float, list[float]]] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ do_convert_rgb: Optional[bool] = None,
+ do_pad: Optional[bool] = None,
+ background_color: Optional[tuple[int, int, int]] = None,
+ ) -> PIL.Image.Image:
+ """
+ Preprocess an image or batch of images.
+
+ Args:
+ images (`ImageInput`):
+ Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
+ passing in images with pixel values between 0 and 1, set `do_rescale=False`.
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+ Whether to resize the image.
+ size (`Dict[str, int]`, *optional*, defaults to `self.size`):
+ Dictionary in the format `{"height": h, "width": w}` specifying the size of the output image after
+ resizing.
+ high_res_size (`Dict[str, int]`, *optional*, defaults to `self.high_res_size`):
+ Dictionary in the format `{"height": h, "width": w}` specifying the size of the high resolution output image after
+ resizing.
+ resample (`PILImageResampling` filter, *optional*, defaults to `self.resample`):
+ `PILImageResampling` filter to use if resizing the image e.g. `PILImageResampling.BILINEAR`. Only has
+ an effect if `do_resize` is set to `True`.
+            high_res_resample (`PILImageResampling` filter, *optional*, defaults to `self.high_res_resample`):
+ `PILImageResampling` filter to use if resizing the image e.g. `PILImageResampling.BICUBIC`. Only has
+ an effect if `do_resize` is set to `True`.
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+ Whether to rescale the image values between [0 - 1].
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
+ Whether to normalize the image.
+ image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
+ Image mean to use if `do_normalize` is set to `True`.
+ image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
+ Image standard deviation to use if `do_normalize` is set to `True`.
+            high_res_image_mean (`float` or `List[float]`, *optional*, defaults to `self.high_res_image_mean`):
+                Image mean to use for the high resolution image if `do_normalize` is set to `True`.
+            high_res_image_std (`float` or `List[float]`, *optional*, defaults to `self.high_res_image_std`):
+                Image standard deviation to use for the high resolution image if `do_normalize` is set to `True`.
+ return_tensors (`str` or `TensorType`, *optional*):
+ The type of tensors to return. Can be one of:
+ - Unset: Return a list of `np.ndarray`.
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
+ The channel dimension format for the output image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - Unset: Use the channel dimension format of the input image.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
+ from the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+ do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
+ Whether to convert the image to RGB.
+ do_pad (`bool`, *optional*, defaults to `self.do_pad`):
+ Whether to pad the image to square or not.
+ background_color (`tuple[int, int, int]`):
+ The background color to use for the padding.
+ """
+ do_resize = do_resize if do_resize is not None else self.do_resize
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
+ resample = resample if resample is not None else self.resample
+ high_res_resample = high_res_resample if high_res_resample is not None else self.high_res_resample
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
+ image_mean = image_mean if image_mean is not None else self.image_mean
+ image_std = image_std if image_std is not None else self.image_std
+ high_res_image_mean = high_res_image_mean if high_res_image_mean is not None else self.high_res_image_mean
+ high_res_image_std = high_res_image_std if high_res_image_std is not None else self.high_res_image_std
+ do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
+ do_pad = do_pad if do_pad is not None else self.do_pad
+ background_color = background_color if background_color is not None else self.background_color
+
+ size = size if size is not None else self.size
+ size_dict = get_size_dict(size)
+ high_res_size = high_res_size if high_res_size is not None else self.high_res_size
+ high_res_size_dict = get_size_dict(high_res_size)
+
+ images = self.fetch_images(images)
+ images = make_flat_list_of_images(images)
+
+ if not valid_images(images):
+ raise ValueError(
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
+ "torch.Tensor, tf.Tensor or jax.ndarray."
+ )
+ validate_preprocess_arguments(
+ do_rescale=do_rescale,
+ rescale_factor=rescale_factor,
+ do_normalize=do_normalize,
+ image_mean=image_mean,
+ image_std=image_std,
+ do_resize=do_resize,
+ size=size,
+ resample=resample,
+ )
+
+ if do_convert_rgb:
+ images = [convert_to_rgb(image) for image in images]
+
+ # All transformations expect numpy arrays.
+ images = [to_numpy_array(image) for image in images]
+
+ if do_rescale and is_scaled_image(images[0]):
+ logger.warning_once(
+ "It looks like you are trying to rescale already rescaled images. If the input"
+ " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
+ )
+
+ if input_data_format is None:
+ # We assume that all images have the same channel dimension format.
+ input_data_format = infer_channel_dimension_format(images[0])
+
+ all_images = []
+ all_high_res_images = []
+ for image in images:
+            # high_res_image: resize (high) -> pad -> rescale -> normalize (high)
+            # low_res_image: resize (high) -> pad -> resize (low) -> pad -> rescale -> normalize (low)
+ high_res_image = image
+ if do_resize:
+ high_res_image = self.resize(
+ image=high_res_image,
+ size=high_res_size_dict,
+ resample=high_res_resample,
+ input_data_format=input_data_format,
+ )
+ if do_pad:
+ # Expand and pad the images to obtain a square image of dimensions `size x size`
+ high_res_image = self.pad_to_square(
+ image=high_res_image,
+ background_color=background_color,
+ input_data_format=input_data_format,
+ )
+ image = self.resize(
+ image=high_res_image,
+ size=size_dict,
+ resample=resample,
+ input_data_format=input_data_format,
+ )
+ if do_pad:
+ image = self.pad_to_square(
+ image=image,
+ background_color=background_color,
+ input_data_format=input_data_format,
+ )
+
+ if do_rescale:
+ image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
+ high_res_image = self.rescale(
+ image=high_res_image, scale=rescale_factor, input_data_format=input_data_format
+ )
+
+ if do_normalize:
+ image = self.normalize(
+ image=image, mean=image_mean, std=image_std, input_data_format=input_data_format
+ )
+ high_res_image = self.normalize(
+ image=high_res_image,
+ mean=high_res_image_mean,
+ std=high_res_image_std,
+ input_data_format=input_data_format,
+ )
+
+ image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
+ high_res_image = to_channel_dimension_format(
+ high_res_image, data_format, input_channel_dim=input_data_format
+ )
+
+ all_images.append(image)
+ all_high_res_images.append(high_res_image)
+
+ data = {"pixel_values": all_images, "high_res_pixel_values": all_high_res_images}
+ return BatchFeature(data=data, tensor_type=return_tensors)
+
+ def pad_to_square(
+ self,
+ image: np.ndarray,
+ background_color: Union[int, tuple[int, int, int]] = 0,
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ ) -> np.ndarray:
+ """
+ Pads an image to a square based on the longest edge.
+
+ Args:
+ image (`np.ndarray`):
+ The image to pad.
+ background_color (`int` or `tuple[int, int, int]`, *optional*, defaults to 0):
+                The color to use for the padding. Can be an integer for single-channel images or a
+                tuple of integers for multi-channel images. If passed as an integer for a multi-channel
+                image, channels after the first default to `0`.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format for the output image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ If unset, will use same as the input image.
+ input_data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format for the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+
+ Returns:
+ `np.ndarray`: The padded image.
+ """
+ height, width = get_image_size(image, input_data_format)
+ num_channels = image.shape[0] if input_data_format == ChannelDimension.FIRST else image.shape[-1]
+
+ if height == width:
+ image = (
+ to_channel_dimension_format(image, data_format, input_data_format)
+ if data_format is not None
+ else image
+ )
+ return image
+
+ max_dim = max(height, width)
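+        # For example, a channels-first (3, 480, 640) image becomes (3, 640, 640), with the original
+        # content centered along the padded height axis.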
+
+ # Ensure background_color is the correct shape
+ if isinstance(background_color, int):
+ background_color = [background_color]
+ elif len(background_color) != num_channels:
+ raise ValueError(
+ f"background_color must have no more than {num_channels} elements to match the number of channels"
+ )
+
+ if input_data_format == ChannelDimension.FIRST:
+ result = np.zeros((num_channels, max_dim, max_dim), dtype=image.dtype)
+ for i, color in enumerate(background_color):
+ result[i, :, :] = color
+ if width > height:
+ start = (max_dim - height) // 2
+ result[:, start : start + height, :] = image
+ else:
+ start = (max_dim - width) // 2
+ result[:, :, start : start + width] = image
+ else:
+ result = np.zeros((max_dim, max_dim, num_channels), dtype=image.dtype)
+ for i, color in enumerate(background_color):
+ result[:, :, i] = color
+ if width > height:
+ start = (max_dim - height) // 2
+ result[start : start + height, :, :] = image
+ else:
+ start = (max_dim - width) // 2
+ result[:, start : start + width, :] = image
+
+ return result
+
+
+__all__ = ["DeepseekVLHybridImageProcessor"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py b/venv/lib/python3.13/site-packages/transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py
new file mode 100644
index 0000000000000000000000000000000000000000..c04e006e358d5fcc157cedabc3ba18137cf2eea6
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py
@@ -0,0 +1,327 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_deepseek_vl_hybrid.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# Copyright 2025 Deepseek AI and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Union
+
+import torch
+from torchvision.transforms.v2 import functional as F
+
+from ...image_processing_utils_fast import (
+ BaseImageProcessorFast,
+ BatchFeature,
+ DefaultFastImageProcessorKwargs,
+ get_size_dict,
+ group_images_by_shape,
+ reorder_images,
+)
+from ...image_utils import (
+ OPENAI_CLIP_MEAN,
+ OPENAI_CLIP_STD,
+ ChannelDimension,
+ PILImageResampling,
+ SizeDict,
+ pil_torch_interpolation_mapping,
+)
+from ...processing_utils import Unpack
+from ...utils import TensorType, auto_docstring
+
+
+class DeepseekVLHybridFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
+ r"""
+ min_size (`int`, *optional*, defaults to 14):
+ The minimum allowed size for the resized image. Ensures that neither the height nor width
+ falls below this value after resizing.
+ high_res_size (`dict`, *optional*, defaults to `{"height": 1024, "width": 1024}`):
+ Size of the high resolution output image after resizing. Can be overridden by the `high_res_size` parameter in the `preprocess`
+ method.
+ high_res_resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
+ Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`. Can be
+ overridden by the `high_res_resample` parameter in the `preprocess` method.
+ high_res_image_mean (`float` or `list[float]`, *optional*, defaults to `OPENAI_CLIP_MEAN`):
+ Mean to use if normalizing the high resolution image. This is a float or list of floats the length of the number of
+ channels in the image. Can be overridden by the `high_res_image_mean` parameter in the `preprocess` method.
+ high_res_image_std (`float` or `list[float]`, *optional*, defaults to `OPENAI_CLIP_STD`):
+ Standard deviation to use if normalizing the high resolution image. This is a float or list of floats the length of the
+ number of channels in the image. Can be overridden by the `high_res_image_std` parameter in the `preprocess` method.
+ """
+
+ min_size: int
+ high_res_size: dict
+ high_res_resample: "PILImageResampling"
+ high_res_image_mean: list[float]
+ high_res_image_std: list[float]
+
+
+@auto_docstring
+class DeepseekVLHybridImageProcessorFast(BaseImageProcessorFast):
+ resample = PILImageResampling.BICUBIC
+ image_mean = OPENAI_CLIP_MEAN
+ image_std = OPENAI_CLIP_STD
+ size = {"height": 384, "width": 384}
+ min_size = 14
+ do_resize = True
+ do_rescale = True
+ do_normalize = True
+ do_pad = True
+ valid_kwargs = DeepseekVLHybridFastImageProcessorKwargs
+ high_res_image_mean = OPENAI_CLIP_MEAN
+ high_res_image_std = OPENAI_CLIP_STD
+ high_res_size = {"height": 1024, "width": 1024}
+ high_res_resample = PILImageResampling.BICUBIC
+ model_input_names = ["pixel_values", "high_res_pixel_values"]
+
+ def __init__(self, **kwargs: Unpack[DeepseekVLHybridFastImageProcessorKwargs]):
+ if kwargs.get("image_mean") is None:
+ background_color = (127, 127, 127)
+ else:
+ background_color = tuple([int(x * 255) for x in kwargs.get("image_mean")])
+ if kwargs.get("high_res_image_mean") is None:
+ high_res_background_color = (127, 127, 127)
+ else:
+ high_res_background_color = tuple(int(x * 255) for x in kwargs.get("high_res_image_mean"))
+ super().__init__(**kwargs)
+ self.background_color = tuple(background_color)
+ self.high_res_background_color = tuple(high_res_background_color)
+
+ def resize(
+ self,
+ image: "torch.Tensor",
+ size: SizeDict,
+ min_size: int,
+ interpolation: Optional["F.InterpolationMode"] = None,
+ antialias: bool = True,
+ **kwargs,
+ ) -> "torch.Tensor":
+ if size.height is None or size.width is None or size.height != size.width:
+ raise ValueError(
+ f"Output height and width must be the same. Got height={size['height']} and width={size['width']}"
+ )
+ size = size.height
+
+ height, width = image.shape[-2:]
+ max_size = max(height, width)
+
+ delta = size / max_size
+ # Largest side becomes `size` and the other side is scaled according to the aspect ratio.
+ output_size_nonpadded = SizeDict(
+ height=max(int(height * delta), min_size),
+ width=max(int(width * delta), min_size),
+ )
+
+ return super().resize(image, size=output_size_nonpadded, interpolation=interpolation, antialias=antialias)
+
+ def pad_to_square(
+ self,
+ images: "torch.Tensor",
+ background_color: Union[int, tuple[int, int, int]] = 0,
+ ) -> "torch.Tensor":
+ """
+ Pads an image to a square based on the longest edge.
+
+ Args:
+ images (`torch.Tensor`):
+ The images to pad.
+ background_color (`int` or `tuple[int, int, int]`, *optional*, defaults to 0):
+                The color to use for the padding. Can be an integer for single-channel images or a
+                tuple of integers for multi-channel images. If passed as an integer for a multi-channel
+                image, channels after the first default to `0`.
+
+ Returns:
+ `torch.Tensor`: The padded images.
+ """
+ height, width = images.shape[-2:]
+ num_channels = images.shape[1]
+ batch_size = images.shape[0]
+
+ if height == width:
+ return images
+
+ max_dim = max(height, width)
+
+ # Ensure background_color is the correct shape
+ if isinstance(background_color, int):
+ background_color = [background_color]
+ elif len(background_color) != num_channels:
+ raise ValueError(
+ f"background_color must have no more than {num_channels} elements to match the number of channels"
+ )
+
+ padded_images = torch.zeros(
+ (batch_size, num_channels, max_dim, max_dim), dtype=images.dtype, device=images.device
+ )
+ for i, color in enumerate(background_color):
+ padded_images[:, i, :, :] = color
+ if width > height:
+ start = (max_dim - height) // 2
+ padded_images[:, :, start : start + height, :] = images
+ else:
+ start = (max_dim - width) // 2
+ padded_images[:, :, :, start : start + width] = images
+
+ return padded_images
+
+ def _preprocess(
+ self,
+ images: list["torch.Tensor"],
+ do_resize: bool,
+ size: SizeDict,
+ high_res_size: SizeDict,
+ min_size: int,
+ interpolation: Optional["F.InterpolationMode"],
+ high_res_interpolation: Optional["F.InterpolationMode"],
+ do_rescale: bool,
+ rescale_factor: float,
+ do_normalize: bool,
+ image_mean: Optional[Union[float, list[float]]],
+ image_std: Optional[Union[float, list[float]]],
+ high_res_image_mean: Optional[Union[float, list[float]]],
+ high_res_image_std: Optional[Union[float, list[float]]],
+ disable_grouping: Optional[bool],
+ return_tensors: Optional[Union[str, TensorType]],
+ do_pad: bool = True,
+ **kwargs,
+ ) -> BatchFeature:
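+        # Two outputs are produced per image and share the initial high resolution resize/padding:
+        # `high_res_pixel_values` (normalized with the high-res mean/std) and `pixel_values`, obtained by
+        # resizing the padded high resolution image back down to `size` before its own padding and normalization.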
+ # Group images by size for batched resizing
+ grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
+ high_res_resized_images_grouped = {}
+ for shape, stacked_images in grouped_images.items():
+ if do_resize:
+ stacked_high_res_images = self.resize(
+ image=stacked_images, size=high_res_size, min_size=min_size, interpolation=high_res_interpolation
+ )
+ high_res_resized_images_grouped[shape] = stacked_high_res_images
+ high_res_resized_images = reorder_images(high_res_resized_images_grouped, grouped_images_index)
+
+ # Group images by size for further processing
+ # Needed in case do_resize is False, or resize returns images with different sizes
+ grouped_high_res_images, grouped_high_res_images_index = group_images_by_shape(
+ high_res_resized_images, disable_grouping=disable_grouping
+ )
+ high_res_padded_images = {}
+ high_res_processed_images_grouped = {}
+ for shape, stacked_high_res_images in grouped_high_res_images.items():
+ if do_pad:
+ stacked_high_res_images = self.pad_to_square(
+ stacked_high_res_images, background_color=self.high_res_background_color
+ )
+ high_res_padded_images[shape] = stacked_high_res_images
+ # Fused rescale and normalize
+ stacked_high_res_images = self.rescale_and_normalize(
+ stacked_high_res_images,
+ do_rescale,
+ rescale_factor,
+ do_normalize,
+ high_res_image_mean,
+ high_res_image_std,
+ )
+ high_res_processed_images_grouped[shape] = stacked_high_res_images
+ high_res_processed_images = reorder_images(high_res_processed_images_grouped, grouped_high_res_images_index)
+ high_res_processed_images = (
+ torch.stack(high_res_processed_images, dim=0) if return_tensors else high_res_processed_images
+ )
+
+ resized_images_grouped = {}
+ for shape, stacked_high_res_padded_images in high_res_padded_images.items():
+ if do_resize:
+ stacked_images = self.resize(
+ image=stacked_high_res_padded_images, size=size, min_size=min_size, interpolation=interpolation
+ )
+ resized_images_grouped[shape] = stacked_images
+ resized_images = reorder_images(resized_images_grouped, grouped_high_res_images_index)
+
+ grouped_resized_images, grouped_resized_images_index = group_images_by_shape(
+ resized_images, disable_grouping=disable_grouping
+ )
+ processed_images_grouped = {}
+ for shape, stacked_images in grouped_resized_images.items():
+ if do_pad:
+ stacked_images = self.pad_to_square(stacked_images, background_color=self.background_color)
+ # Fused rescale and normalize
+ stacked_images = self.rescale_and_normalize(
+ stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
+ )
+ processed_images_grouped[shape] = stacked_images
+ processed_images = reorder_images(processed_images_grouped, grouped_resized_images_index)
+ processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
+
+ return BatchFeature(
+ data={"pixel_values": processed_images, "high_res_pixel_values": high_res_processed_images},
+ tensor_type=return_tensors,
+ )
+
+ def _further_process_kwargs(
+ self,
+ size: Optional[SizeDict] = None,
+ high_res_size: Optional[SizeDict] = None,
+ default_to_square: Optional[bool] = None,
+ image_mean: Optional[Union[float, list[float]]] = None,
+ image_std: Optional[Union[float, list[float]]] = None,
+ high_res_image_mean: Optional[Union[float, list[float]]] = None,
+ high_res_image_std: Optional[Union[float, list[float]]] = None,
+ data_format: Optional[ChannelDimension] = None,
+ **kwargs,
+ ) -> dict:
+ """
+ Update kwargs that need further processing before being validated
+ Can be overridden by subclasses to customize the processing of kwargs.
+ """
+ if kwargs is None:
+ kwargs = {}
+ if size is not None:
+ size = SizeDict(**get_size_dict(size=size, default_to_square=default_to_square))
+ if high_res_size is not None:
+ high_res_size = SizeDict(**get_size_dict(size=high_res_size, default_to_square=default_to_square))
+ if isinstance(image_mean, list):
+ image_mean = tuple(image_mean)
+ if isinstance(image_std, list):
+ image_std = tuple(image_std)
+ if isinstance(high_res_image_mean, list):
+ high_res_image_mean = tuple(high_res_image_mean)
+ if isinstance(high_res_image_std, list):
+ high_res_image_std = tuple(high_res_image_std)
+ if data_format is None:
+ data_format = ChannelDimension.FIRST
+
+ high_res_resample = kwargs.pop("high_res_resample")
+ kwargs["high_res_interpolation"] = (
+ pil_torch_interpolation_mapping[high_res_resample]
+ if isinstance(high_res_resample, (int, PILImageResampling))
+ else high_res_resample
+ )
+
+ low_res_resample = kwargs.pop("resample")
+ kwargs["interpolation"] = (
+ pil_torch_interpolation_mapping[low_res_resample]
+ if isinstance(low_res_resample, (int, PILImageResampling))
+ else low_res_resample
+ )
+
+ kwargs["size"] = size
+ kwargs["high_res_size"] = high_res_size
+ kwargs["image_mean"] = image_mean
+ kwargs["image_std"] = image_std
+ kwargs["high_res_image_mean"] = high_res_image_mean
+ kwargs["high_res_image_std"] = high_res_image_std
+ kwargs["data_format"] = data_format
+
+ return kwargs
+
+
+__all__ = ["DeepseekVLHybridImageProcessorFast"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py b/venv/lib/python3.13/site-packages/transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py
new file mode 100644
index 0000000000000000000000000000000000000000..d9a85654e901576c61c83c6b3ce3c3ea54d3aeaf
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py
@@ -0,0 +1,497 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_deepseek_vl_hybrid.py file directly. This is enforced by our CI.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# Copyright 2025 Deepseek AI and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import Optional, Union
+
+import torch
+import torch.nn as nn
+
+from ...cache_utils import Cache
+from ...generation import GenerationMixin
+from ...modeling_outputs import ModelOutput
+from ...modeling_utils import PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
+from ..auto import AutoModel
+from .configuration_deepseek_vl_hybrid import DeepseekVLHybridConfig
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for DeepseekVLHybrid model's outputs that may also contain a past key/values (to speed up sequential decoding).
+ """
+)
+class DeepseekVLHybridBaseModelOutputWithPast(ModelOutput):
+ r"""
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+
+        If `past_key_values` is used, only the last hidden-state of the sequences of shape `(batch_size, 1,
+        hidden_size)` is output.
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
+ `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
+ input) to speed up sequential decoding.
+ image_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
+        Tuple of `torch.FloatTensor` (one for the output of the image embeddings) of shape `(batch_size, num_images,
+        sequence_length, hidden_size)`.
+
+        Image hidden states of the model produced by the vision encoder.
+ """
+
+ last_hidden_state: Optional[torch.FloatTensor] = None
+ past_key_values: Optional[Cache] = None
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
+ attentions: Optional[tuple[torch.FloatTensor]] = None
+ image_hidden_states: Optional[tuple[torch.FloatTensor]] = None
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for DeepseekVLHybrid causal language model (or autoregressive) outputs.
+ """
+)
+class DeepseekVLHybridCausalLMOutputWithPast(ModelOutput):
+ r"""
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Language modeling loss (for next-token prediction).
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ image_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
+        Tuple of `torch.FloatTensor` (one for the output of the image embeddings) of shape `(batch_size, num_images,
+        sequence_length, hidden_size)`.
+
+        Image hidden states of the model produced by the vision encoder.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: Optional[torch.FloatTensor] = None
+ past_key_values: Optional[Cache] = None
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
+ attentions: Optional[tuple[torch.FloatTensor]] = None
+ image_hidden_states: Optional[tuple[torch.FloatTensor]] = None
+
+
+class DeepseekVLHybridLayerNorm(nn.LayerNorm):
+ r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
+ The ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch_size, height,
+ width, channels) while channels_first corresponds to inputs with shape (batch_size, channels, height, width).
+ """
+
+ def __init__(self, normalized_shape, *, eps=1e-6, data_format="channels_last", **kwargs):
+ super().__init__(normalized_shape, eps=eps, **kwargs)
+ if data_format not in ["channels_last", "channels_first"]:
+ raise NotImplementedError(f"Unsupported data format: {data_format}")
+ self.data_format = data_format
+
+ def forward(self, features: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ features: Tensor of shape (batch_size, channels, height, width) OR (batch_size, height, width, channels)
+ """
+ if self.data_format == "channels_first":
+ features = features.permute(0, 2, 3, 1)
+ features = super().forward(features)
+ features = features.permute(0, 3, 1, 2)
+ else:
+ features = super().forward(features)
+ return features
+
+
+class DeepseekVLSamVisionNeck(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+
+ self.conv1 = nn.Conv2d(config.hidden_size, config.output_channels, kernel_size=1, bias=False)
+ self.layer_norm1 = DeepseekVLHybridLayerNorm(config.output_channels, data_format="channels_first")
+ self.conv2 = nn.Conv2d(config.output_channels, config.output_channels, kernel_size=3, padding=1, bias=False)
+ self.layer_norm2 = DeepseekVLHybridLayerNorm(config.output_channels, data_format="channels_first")
+
+ def forward(self, hidden_states):
+ hidden_states = hidden_states.permute(0, 3, 1, 2)
+ hidden_states = self.conv1(hidden_states)
+ hidden_states = self.layer_norm1(hidden_states)
+
+ hidden_states = self.conv2(hidden_states)
+ hidden_states = self.layer_norm2(hidden_states)
+ return hidden_states
+
+
+class DeepseekVLSamVisionProj(nn.Module):
+ def __init__(self, config, output_size: int = 24):
+ super().__init__()
+ self.config = config
+ self.output_size = output_size
+
+ self.conv1 = nn.Conv2d(
+ config.output_channels, config.output_channels * 2, kernel_size=3, stride=2, padding=1, bias=False
+ )
+ self.conv2 = nn.Conv2d(
+ config.output_channels * 2, config.output_channels * 4, kernel_size=3, stride=2, padding=1, bias=False
+ )
+
+ def forward(self, features: torch.Tensor) -> torch.Tensor:
+ # interpolate Sam encodings to match Siglip encodings
+ features = torch.nn.functional.interpolate(
+ features,
+ size=(4 * self.output_size, 4 * self.output_size),
+ mode="bilinear",
+ align_corners=False,
+ )
+ features = self.conv1(features)
+ features = self.conv2(features)
+ return features
+
+
+class DeepseekVLHybridAligner(nn.Module):
+ def __init__(self, config: DeepseekVLHybridConfig):
+ super().__init__()
+
+ in_channels = config.vision_config.hidden_size
+ high_res_in_channels = config.high_res_vision_config.output_channels * 4
+ out_channels = config.text_config.hidden_size
+
+ self.vision_proj = nn.Linear(in_channels, out_channels // 2)
+ self.high_res_vision_proj = nn.Linear(high_res_in_channels, out_channels // 2)
+
+ self.act = nn.GELU()
+ self.proj = nn.Linear(out_channels, out_channels)
+
+ def forward(
+ self,
+ vision_encodings: torch.Tensor,
+ high_res_vision_encodings: torch.Tensor,
+ ) -> torch.Tensor:
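+        # Each vision stream is projected to half of the text hidden size; the concatenation is then mapped into the language model's embedding space.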
+ vision_encodings = self.vision_proj(vision_encodings)
+ high_res_vision_encodings = self.high_res_vision_proj(high_res_vision_encodings)
+
+ encodings = torch.concat([high_res_vision_encodings, vision_encodings], dim=-1)
+ encodings = self.act(encodings)
+ encodings = self.proj(encodings)
+
+ return encodings
+
+
+@auto_docstring
+class DeepseekVLHybridPreTrainedModel(PreTrainedModel):
+ config: DeepseekVLHybridConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["LlamaDecoderLayer"]
+ _skip_keys_device_placement = ["past_key_values", "causal_mask"]
+ _supports_flash_attn = True
+ _supports_sdpa = True
+
+ _can_compile_fullgraph = True
+ _supports_param_buffer_assignment = False
+
+ def _init_weights(self, module):
+ """Initialize the weights"""
+ if isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=self.config.text_config.initializer_range)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Conv2d):
+ nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, DeepseekVLHybridLayerNorm):
+ module.weight.data.fill_(1.0)
+ module.bias.data.zero_()
+ elif isinstance(module, DeepseekVLHybridModel):
+ module.high_res_vision_alpha.data.zero_()
+
+
+DEEPSEEK_VL_COMMON_CUSTOM_ARGS = r"""
+    high_res_pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`, *optional*):
+ The tensors corresponding to the input images. Pixel values can be obtained using
+ [`AutoImageProcessor`].
+"""
+
+
+@auto_docstring
+class DeepseekVLHybridModel(DeepseekVLHybridPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.output_size = config.vision_config.image_size // config.vision_config.patch_size
+ self.global_attn_index = config.high_res_vision_config.global_attn_indexes[0]
+
+ self.high_res_vision_model = AutoModel.from_config(config.high_res_vision_config)
+ self.high_res_vision_neck = DeepseekVLSamVisionNeck(config.high_res_vision_config)
+ self.high_res_vision_proj = DeepseekVLSamVisionProj(
+ config.high_res_vision_config, output_size=self.output_size
+ )
+ self.high_res_vision_alpha = nn.Parameter(torch.zeros(1))
+ self.config = config
+
+ self.vision_model = AutoModel.from_config(config.vision_config)
+ self.aligner = DeepseekVLHybridAligner(config)
+
+ self.language_model = AutoModel.from_config(config=config.text_config)
+
+ self.gradient_checkpointing = False
+ # Initialize weights and apply final processing.
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.language_model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.language_model.set_input_embeddings(value)
+
+ def get_image_features(self, pixel_values, high_res_pixel_values):
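+        # Fuse the low-resolution and high-resolution vision features (SigLIP and SAM backbones by default) through the aligner.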
+ vision_encodings = self.get_low_res_image_features(pixel_values)
+ high_res_vision_encodings = self.get_high_res_image_features(high_res_pixel_values)
+ images_embeds = self.aligner(vision_encodings, high_res_vision_encodings)
+ return images_embeds
+
+ def get_placeholder_mask(
+ self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor
+ ):
+ """
+ Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
+ equal to the length of multimodal features. If the lengths are different, an error is raised.
+ """
+ if input_ids is None:
+ special_image_mask = inputs_embeds == self.get_input_embeddings()(
+ torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ special_image_mask = special_image_mask.all(-1)
+ else:
+ special_image_mask = input_ids == self.config.image_token_id
+
+ n_image_tokens = special_image_mask.sum()
+ special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+ if inputs_embeds[special_image_mask].numel() != image_features.numel():
+ n_image_features = image_features.shape[0] * image_features.shape[1]
+ raise ValueError(
+ f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+ )
+ return special_image_mask
+
+ @can_return_tuple
+ @auto_docstring(custom_args=DEEPSEEK_VL_COMMON_CUSTOM_ARGS)
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ high_res_pixel_values: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ **kwargs,
+ ):
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError(
+ "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
+ )
+
+ if pixel_values is not None and high_res_pixel_values is None:
+ raise ValueError("Both pixel_values and high_res_pixel_values should be specified at the same time")
+
+ if inputs_embeds is None:
+ inputs_embeds = self.get_input_embeddings()(input_ids)
+
+ if pixel_values is not None:
+ if input_ids is None:
+ image_attention_mask = inputs_embeds == self.get_input_embeddings()(
+ torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ image_attention_mask = image_attention_mask.all(-1)
+ else:
+ image_attention_mask = input_ids == self.config.image_token_id
+
+ image_attention_mask = image_attention_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+ image_embeds = self.get_image_features(pixel_values, high_res_pixel_values)
+ image_features = image_embeds.reshape(-1, inputs_embeds.shape[-1])
+ image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
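+            # Scatter the aligned image features into the text embeddings at the image-token positions.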
+ inputs_embeds = inputs_embeds.masked_scatter(image_attention_mask, image_features)
+
+ lm_output = self.language_model(
+ inputs_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ logits_to_keep=logits_to_keep,
+ **kwargs,
+ )
+
+ return DeepseekVLHybridBaseModelOutputWithPast(
+ last_hidden_state=lm_output.last_hidden_state,
+ past_key_values=lm_output.past_key_values,
+ hidden_states=lm_output.hidden_states,
+ attentions=lm_output.attentions,
+ image_hidden_states=image_embeds if pixel_values is not None else None,
+ )
+
+ def get_low_res_image_features(self, pixel_values):
+ output = self.vision_model(pixel_values)
+ output = output[0]
+ return output
+
+ def get_high_res_image_features(self, pixel_values):
+ output = self.high_res_vision_model(
+ pixel_values=pixel_values,
+ output_hidden_states=True,
+ return_dict=True,
+ )
+ last_hidden_state = output.last_hidden_state
+ last_hidden_state = self.high_res_vision_proj(last_hidden_state)
+
+ hidden_states = output.hidden_states
+ global_hidden_state = hidden_states[self.global_attn_index + 1] # +1 for embedding layer
+ global_hidden_state = self.high_res_vision_neck(global_hidden_state)
+ global_hidden_state = self.high_res_vision_proj(global_hidden_state)
+
+ output = last_hidden_state + global_hidden_state * self.high_res_vision_alpha
+
+ # batch_size, hidden_size, height, width -> batch_size, seq_len, hidden_size
+ output = output.permute(0, 2, 3, 1)
+ output = output.reshape(output.shape[0], -1, output.shape[-1])
+
+ return output
+
+
+class DeepseekVLHybridForConditionalGeneration(DeepseekVLHybridPreTrainedModel, GenerationMixin):
+ _tied_weights_keys = ["model.language_model.embed_tokens.weight", "lm_head.weight"]
+ _can_compile_fullgraph = True
+
+ def __init__(self, config: DeepseekVLHybridConfig):
+ super().__init__(config)
+ self.config = config
+ self.model = DeepseekVLHybridModel(config)
+ self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
+
+ # Initialize weights and apply final processing.
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.language_model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.model.language_model.set_input_embeddings(value)
+
+ def prepare_embeddings_for_image_generation(self) -> torch.Tensor:
+ raise AttributeError("Not needed for DeepseekVLHybrid")
+
+ @can_return_tuple
+ @auto_docstring(custom_args=DEEPSEEK_VL_COMMON_CUSTOM_ARGS)
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ high_res_pixel_values: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ **kwargs: Unpack[TransformersKwargs],
+ ):
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+ """
+ outputs = self.model(
+ input_ids=input_ids,
+ pixel_values=pixel_values,
+ high_res_pixel_values=high_res_pixel_values,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ **kwargs,
+ )
+ hidden_states = outputs.last_hidden_state
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(
+ logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
+ )
+
+ return DeepseekVLHybridCausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ image_hidden_states=outputs.image_hidden_states,
+ )
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids,
+ past_key_values=None,
+ inputs_embeds=None,
+ pixel_values=None,
+ high_res_pixel_values=None,
+ attention_mask=None,
+ cache_position=None,
+ logits_to_keep=None,
+ **kwargs,
+ ):
+ model_inputs = super().prepare_inputs_for_generation(
+ input_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ cache_position=cache_position,
+ logits_to_keep=logits_to_keep,
+ **kwargs,
+ )
+
+ if cache_position[0] == 0:
+            # Pixel values are only needed on the first (prefill) forward pass; during cached decoding the
+            # input ids no longer contain image tokens, so the image inputs are dropped.
+ model_inputs["pixel_values"] = pixel_values
+ model_inputs["high_res_pixel_values"] = high_res_pixel_values
+
+ return model_inputs
+
+
+__all__ = ["DeepseekVLHybridPreTrainedModel", "DeepseekVLHybridModel", "DeepseekVLHybridForConditionalGeneration"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py b/venv/lib/python3.13/site-packages/transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py
new file mode 100644
index 0000000000000000000000000000000000000000..0da40603c2e939c0b9b85af929ff88a57240d5d5
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py
@@ -0,0 +1,1007 @@
+# Copyright 2025 Deepseek AI and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Union
+
+import torch
+import torch.nn as nn
+from torchvision.transforms.v2 import functional as F
+
+from ...cache_utils import Cache
+from ...image_processing_utils_fast import (
+ BaseImageProcessorFast,
+ BatchFeature,
+ DefaultFastImageProcessorKwargs,
+ get_size_dict,
+ group_images_by_shape,
+ reorder_images,
+)
+from ...image_transforms import convert_to_rgb, to_channel_dimension_format
+from ...image_utils import (
+ OPENAI_CLIP_MEAN,
+ OPENAI_CLIP_STD,
+ ChannelDimension,
+ ImageInput,
+ PILImageResampling,
+ SizeDict,
+ infer_channel_dimension_format,
+ is_scaled_image,
+ make_flat_list_of_images,
+ pil_torch_interpolation_mapping,
+ to_numpy_array,
+ valid_images,
+ validate_preprocess_arguments,
+)
+from ...processing_utils import Unpack
+from ...tokenization_utils_base import (
+ PreTokenizedInput,
+ TextInput,
+)
+from ...utils import (
+ TensorType,
+ TransformersKwargs,
+ auto_docstring,
+ can_return_tuple,
+ filter_out_non_signature_kwargs,
+ logging,
+)
+from ..auto import CONFIG_MAPPING, AutoConfig, AutoModel
+from ..deepseek_vl.configuration_deepseek_vl import DeepseekVLConfig
+from ..deepseek_vl.image_processing_deepseek_vl import DeepseekVLImageProcessor
+from ..deepseek_vl.image_processing_deepseek_vl_fast import DeepseekVLImageProcessorFast
+from ..deepseek_vl.modeling_deepseek_vl import (
+ DeepseekVLForConditionalGeneration,
+ DeepseekVLModel,
+ DeepseekVLPreTrainedModel,
+)
+from ..deepseek_vl.processing_deepseek_vl import DeepseekVLProcessor, DeepseekVLProcessorKwargs
+from ..idefics.modeling_idefics import IdeficsBaseModelOutputWithPast, IdeficsCausalLMOutputWithPast
+from ..sam.modeling_sam import SamLayerNorm, SamVisionNeck
+
+
+logger = logging.get_logger(__name__)
+
+
+DEEPSEEK_VL_COMMON_CUSTOM_ARGS = r"""
+    high_res_pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`, *optional*):
+ The tensors corresponding to the input images. Pixel values can be obtained using
+ [`AutoImageProcessor`].
+"""
+
+
+class DeepseekVLHybridConfig(DeepseekVLConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`DeepseekVLHybridModel`]. It is used to instantiate a
+ DeepseekVLHybrid model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of the DeepseekVLHybrid
+ [deepseek-community/deepseek-vl-7b-chat](https://huggingface.co/deepseek-community/deepseek-vl-7b-chat) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `LlamaConfig`):
+ The config object or dictionary of the text backbone.
+ vision_config (`Union[AutoConfig, dict]`, *optional*, defaults to `SiglipVisionConfig`):
+ The config object or dictionary of the vision backbone.
+ high_res_vision_config (`Union[AutoConfig, dict]`, *optional*, defaults to `SamVisionConfig`):
+ The config object or dictionary of the high resolution vision backbone.
+ image_token_id (`int`, *optional*, defaults to 100015):
+ The index representing image tokens in the model's token vocabulary.
+
+ Example:
+
+ ```python
+ >>> from transformers import DeepseekVLHybridConfig, DeepseekVLHybridModel
+
+ >>> # Initializing a DeepseekVLHybrid deepseek-community/deepseek-vl-7b-chat style configuration
+ >>> configuration = DeepseekVLHybridConfig()
+
+ >>> # Initializing a model (with random weights) from the deepseek-community/deepseek-vl-7b-chat style configuration
+ >>> model = DeepseekVLHybridModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "deepseek_vl_hybrid"
+ sub_configs = {"text_config": AutoConfig, "vision_config": AutoConfig, "high_res_vision_config": AutoConfig}
+
+ def __init__(
+ self,
+ text_config: Optional[AutoConfig] = None,
+ vision_config: Optional[AutoConfig] = None,
+ high_res_vision_config: Optional[AutoConfig] = None,
+ image_token_id: int = 100015,
+ **kwargs,
+ ):
+ super().__init__(
+ text_config=text_config,
+ vision_config=vision_config,
+ image_token_id=image_token_id,
+ **kwargs,
+ )
+
+ if high_res_vision_config is None:
+ high_res_vision_config = {}
+ logger.info("`high_res_vision_config` is `None`. Initializing the `SamVisionConfig` with default values.")
+
+ if isinstance(high_res_vision_config, dict):
+ high_res_vision_config["model_type"] = high_res_vision_config.get("model_type", "sam_vision_model")
+ high_res_vision_config = CONFIG_MAPPING[high_res_vision_config["model_type"]](**high_res_vision_config)
+
+ self.high_res_vision_config = high_res_vision_config
+
+
+class DeepseekVLHybridBaseModelOutputWithPast(IdeficsBaseModelOutputWithPast):
+ pass
+
+
+class DeepseekVLHybridCausalLMOutputWithPast(IdeficsCausalLMOutputWithPast):
+ pass
+
+
+class DeepseekVLHybridLayerNorm(SamLayerNorm):
+ pass
+
+
+class DeepseekVLSamVisionNeck(SamVisionNeck):
+ def __init__(self, config):
+ super().__init__(config)
+
+
+class DeepseekVLSamVisionProj(nn.Module):
+ def __init__(self, config, output_size: int = 24):
+ super().__init__()
+ self.config = config
+ self.output_size = output_size
+
+ self.conv1 = nn.Conv2d(
+ config.output_channels, config.output_channels * 2, kernel_size=3, stride=2, padding=1, bias=False
+ )
+ self.conv2 = nn.Conv2d(
+ config.output_channels * 2, config.output_channels * 4, kernel_size=3, stride=2, padding=1, bias=False
+ )
+
+ def forward(self, features: torch.Tensor) -> torch.Tensor:
+ # interpolate Sam encodings to match Siglip encodings
+ features = torch.nn.functional.interpolate(
+ features,
+ size=(4 * self.output_size, 4 * self.output_size),
+ mode="bilinear",
+ align_corners=False,
+ )
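+        # Two stride-2 convolutions reduce the spatial size from 4 * output_size back to output_size while expanding channels to output_channels * 4.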
+ features = self.conv1(features)
+ features = self.conv2(features)
+ return features
+
+
+class DeepseekVLHybridAligner(nn.Module):
+ def __init__(self, config: DeepseekVLHybridConfig):
+ super().__init__()
+
+ in_channels = config.vision_config.hidden_size
+ high_res_in_channels = config.high_res_vision_config.output_channels * 4
+ out_channels = config.text_config.hidden_size
+
+ self.vision_proj = nn.Linear(in_channels, out_channels // 2)
+ self.high_res_vision_proj = nn.Linear(high_res_in_channels, out_channels // 2)
+
+ self.act = nn.GELU()
+ self.proj = nn.Linear(out_channels, out_channels)
+
+ def forward(
+ self,
+ vision_encodings: torch.Tensor,
+ high_res_vision_encodings: torch.Tensor,
+ ) -> torch.Tensor:
+ vision_encodings = self.vision_proj(vision_encodings)
+ high_res_vision_encodings = self.high_res_vision_proj(high_res_vision_encodings)
+
+ encodings = torch.concat([high_res_vision_encodings, vision_encodings], dim=-1)
+ encodings = self.act(encodings)
+ encodings = self.proj(encodings)
+
+ return encodings
+
+
+class DeepseekVLHybridPreTrainedModel(DeepseekVLPreTrainedModel):
+ def _init_weights(self, module):
+ """Initialize the weights"""
+ if isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=self.config.text_config.initializer_range)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Conv2d):
+ nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, DeepseekVLHybridLayerNorm):
+ module.weight.data.fill_(1.0)
+ module.bias.data.zero_()
+ elif isinstance(module, DeepseekVLHybridModel):
+ module.high_res_vision_alpha.data.zero_()
+
+
+class DeepseekVLHybridModel(DeepseekVLModel):
+ def __init__(self, config):
+ self.output_size = config.vision_config.image_size // config.vision_config.patch_size
+ self.global_attn_index = config.high_res_vision_config.global_attn_indexes[0]
+
+ self.high_res_vision_model = AutoModel.from_config(config.high_res_vision_config)
+ self.high_res_vision_neck = DeepseekVLSamVisionNeck(config.high_res_vision_config)
+ self.high_res_vision_proj = DeepseekVLSamVisionProj(
+ config.high_res_vision_config, output_size=self.output_size
+ )
+ self.high_res_vision_alpha = nn.Parameter(torch.zeros(1))
+
+ super().__init__(config)
+
+ def get_low_res_image_features(self, pixel_values):
+ output = self.vision_model(pixel_values)
+ output = output[0]
+ return output
+
+ def get_high_res_image_features(self, pixel_values):
+ output = self.high_res_vision_model(
+ pixel_values=pixel_values,
+ output_hidden_states=True,
+ return_dict=True,
+ )
+ last_hidden_state = output.last_hidden_state
+ last_hidden_state = self.high_res_vision_proj(last_hidden_state)
+
+ hidden_states = output.hidden_states
+ global_hidden_state = hidden_states[self.global_attn_index + 1] # +1 for embedding layer
+ global_hidden_state = self.high_res_vision_neck(global_hidden_state)
+ global_hidden_state = self.high_res_vision_proj(global_hidden_state)
+
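+        # Blend the projected SAM output with the neck-processed global-attention feature map; the learnable alpha starts at zero so the high-res contribution is phased in during training.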
+ output = last_hidden_state + global_hidden_state * self.high_res_vision_alpha
+
+ # batch_size, hidden_size, height, width -> batch_size, seq_len, hidden_size
+ output = output.permute(0, 2, 3, 1)
+ output = output.reshape(output.shape[0], -1, output.shape[-1])
+
+ return output
+
+ def get_image_features(self, pixel_values, high_res_pixel_values):
+ vision_encodings = self.get_low_res_image_features(pixel_values)
+ high_res_vision_encodings = self.get_high_res_image_features(high_res_pixel_values)
+ images_embeds = self.aligner(vision_encodings, high_res_vision_encodings)
+ return images_embeds
+
+ @can_return_tuple
+ @auto_docstring(custom_args=DEEPSEEK_VL_COMMON_CUSTOM_ARGS)
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ high_res_pixel_values: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ **kwargs,
+ ):
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError(
+ "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
+ )
+
+ if pixel_values is not None and high_res_pixel_values is None:
+ raise ValueError("Both pixel_values and high_res_pixel_values should be specified at the same time")
+
+ if inputs_embeds is None:
+ inputs_embeds = self.get_input_embeddings()(input_ids)
+
+ if pixel_values is not None:
+ if input_ids is None:
+ image_attention_mask = inputs_embeds == self.get_input_embeddings()(
+ torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ image_attention_mask = image_attention_mask.all(-1)
+ else:
+ image_attention_mask = input_ids == self.config.image_token_id
+
+ image_attention_mask = image_attention_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+ image_embeds = self.get_image_features(pixel_values, high_res_pixel_values)
+ image_features = image_embeds.reshape(-1, inputs_embeds.shape[-1])
+ image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+ inputs_embeds = inputs_embeds.masked_scatter(image_attention_mask, image_features)
+
+ lm_output = self.language_model(
+ inputs_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ logits_to_keep=logits_to_keep,
+ **kwargs,
+ )
+
+ return DeepseekVLHybridBaseModelOutputWithPast(
+ last_hidden_state=lm_output.last_hidden_state,
+ past_key_values=lm_output.past_key_values,
+ hidden_states=lm_output.hidden_states,
+ attentions=lm_output.attentions,
+ image_hidden_states=image_embeds if pixel_values is not None else None,
+ )
+
+
+class DeepseekVLHybridForConditionalGeneration(DeepseekVLForConditionalGeneration):
+ @can_return_tuple
+ @auto_docstring(custom_args=DEEPSEEK_VL_COMMON_CUSTOM_ARGS)
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ high_res_pixel_values: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ **kwargs: Unpack[TransformersKwargs],
+ ):
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+ """
+ outputs = self.model(
+ input_ids=input_ids,
+ pixel_values=pixel_values,
+ high_res_pixel_values=high_res_pixel_values,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ **kwargs,
+ )
+ hidden_states = outputs.last_hidden_state
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(
+ logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
+ )
+
+ return DeepseekVLHybridCausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ image_hidden_states=outputs.image_hidden_states,
+ )
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids,
+ past_key_values=None,
+ inputs_embeds=None,
+ pixel_values=None,
+ high_res_pixel_values=None,
+ attention_mask=None,
+ cache_position=None,
+ logits_to_keep=None,
+ **kwargs,
+ ):
+ model_inputs = super().prepare_inputs_for_generation(
+ input_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ cache_position=cache_position,
+ logits_to_keep=logits_to_keep,
+ **kwargs,
+ )
+
+ if cache_position[0] == 0:
+            # Pixel values are only needed on the first (prefill) forward pass; during cached decoding the
+            # input ids no longer contain image tokens, so the image inputs are dropped.
+ model_inputs["pixel_values"] = pixel_values
+ model_inputs["high_res_pixel_values"] = high_res_pixel_values
+
+ return model_inputs
+
+
+class DeepseekVLHybridImageProcessor(DeepseekVLImageProcessor):
+ r"""
+    Constructs a DeepseekVLHybrid image processor.
+
+ Args:
+ do_resize (`bool`, *optional*, defaults to `True`):
+ Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the
+ `do_resize` parameter in the `preprocess` method.
+ size (`dict`, *optional*, defaults to `{"height": 384, "width": 384}`):
+ Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess`
+ method.
+ high_res_size (`dict`, *optional*, defaults to `{"height": 1024, "width": 1024}`):
+ Size of the high resolution output image after resizing. Can be overridden by the `high_res_size` parameter in the `preprocess`
+ method.
+ min_size (`int`, *optional*, defaults to 14):
+ The minimum allowed size for the resized image. Ensures that neither the height nor width
+ falls below this value after resizing.
+ resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
+ Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`. Can be
+ overridden by the `resample` parameter in the `preprocess` method.
+ high_res_resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
+ Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`. Can be
+ overridden by the `high_res_resample` parameter in the `preprocess` method.
+ do_rescale (`bool`, *optional*, defaults to `True`):
+ Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the
+ `do_rescale` parameter in the `preprocess` method.
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
+ Scale factor to use if rescaling the image. Only has an effect if `do_rescale` is set to `True`. Can be
+ overridden by the `rescale_factor` parameter in the `preprocess` method.
+ do_normalize (`bool`, *optional*, defaults to `True`):
+            Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
+            method.
+        image_mean (`float` or `list[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
+            Mean to use if normalizing the image. This is a float or list of floats the length of the number of
+            channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
+        image_std (`float` or `list[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
+            Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
+            number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
+ high_res_image_mean (`float` or `list[float]`, *optional*, defaults to `OPENAI_CLIP_MEAN`):
+ Mean to use if normalizing the high resolution image. This is a float or list of floats the length of the number of
+ channels in the image. Can be overridden by the `high_res_image_mean` parameter in the `preprocess` method.
+ high_res_image_std (`float` or `list[float]`, *optional*, defaults to `OPENAI_CLIP_STD`):
+ Standard deviation to use if normalizing the high resolution image. This is a float or list of floats the length of the
+ number of channels in the image. Can be overridden by the `high_res_image_std` parameter in the `preprocess` method.
+ do_convert_rgb (`bool`, *optional*, defaults to `True`):
+ Whether to convert the image to RGB.
+ do_pad (`bool`, *optional*, defaults to `True`):
+ Whether to pad the image to square or not.
+ """
+
+ model_input_names = ["pixel_values", "high_res_pixel_values"]
+
+ def __init__(
+ self,
+ do_resize: bool = True,
+ size: Optional[dict[str, int]] = None,
+ high_res_size: Optional[dict[str, int]] = None,
+ min_size: int = 14,
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
+ high_res_resample: PILImageResampling = PILImageResampling.BICUBIC,
+ do_rescale: bool = True,
+ rescale_factor: Union[int, float] = 1 / 255,
+ do_normalize: bool = True,
+ image_mean: Optional[Union[float, list[float]]] = None,
+ image_std: Optional[Union[float, list[float]]] = None,
+ high_res_image_mean: Optional[Union[float, list[float]]] = None,
+ high_res_image_std: Optional[Union[float, list[float]]] = None,
+ do_convert_rgb: Optional[bool] = None,
+ do_pad: bool = True,
+ **kwargs,
+ ) -> None:
+ high_res_size = high_res_size if high_res_size is not None else {"height": 1024, "width": 1024}
+ high_res_size = get_size_dict(high_res_size, default_to_square=True)
+
+ self.high_res_size = high_res_size
+ self.high_res_image_mean = high_res_image_mean if high_res_image_mean is not None else OPENAI_CLIP_MEAN
+ self.high_res_image_std = high_res_image_std if high_res_image_std is not None else OPENAI_CLIP_STD
+
+ self.resample = resample
+ self.high_res_resample = high_res_resample
+
+ super().__init__(
+ do_resize=do_resize,
+ size=size,
+ min_size=min_size,
+ resample=resample,
+ do_rescale=do_rescale,
+ rescale_factor=rescale_factor,
+ do_normalize=do_normalize,
+ image_mean=image_mean,
+ image_std=image_std,
+ do_convert_rgb=do_convert_rgb,
+ do_pad=do_pad,
+ **kwargs,
+ )
+
+ if high_res_image_mean is None:
+ self.high_res_background_color = (127, 127, 127)
+ else:
+ self.high_res_background_color = tuple(int(x * 255) for x in high_res_image_mean)
+
+ @filter_out_non_signature_kwargs()
+ def preprocess(
+ self,
+ images: ImageInput,
+ do_resize: Optional[bool] = None,
+ size: Optional[dict[str, int]] = None,
+ high_res_size: Optional[dict[str, int]] = None,
+ resample: Optional[PILImageResampling] = None,
+ high_res_resample: Optional[PILImageResampling] = None,
+ do_rescale: Optional[bool] = None,
+ rescale_factor: Optional[float] = None,
+ do_normalize: Optional[bool] = None,
+ image_mean: Optional[Union[float, list[float]]] = None,
+ image_std: Optional[Union[float, list[float]]] = None,
+ high_res_image_mean: Optional[Union[float, list[float]]] = None,
+ high_res_image_std: Optional[Union[float, list[float]]] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ do_convert_rgb: Optional[bool] = None,
+ do_pad: Optional[bool] = None,
+ background_color: Optional[tuple[int, int, int]] = None,
+ ):
+ """
+ Preprocess an image or batch of images.
+
+ Args:
+ images (`ImageInput`):
+ Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
+ passing in images with pixel values between 0 and 1, set `do_rescale=False`.
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+ Whether to resize the image.
+ size (`Dict[str, int]`, *optional*, defaults to `self.size`):
+ Dictionary in the format `{"height": h, "width": w}` specifying the size of the output image after
+ resizing.
+ high_res_size (`Dict[str, int]`, *optional*, defaults to `self.high_res_size`):
+ Dictionary in the format `{"height": h, "width": w}` specifying the size of the high resolution output image after
+ resizing.
+ resample (`PILImageResampling` filter, *optional*, defaults to `self.resample`):
+ `PILImageResampling` filter to use if resizing the image e.g. `PILImageResampling.BILINEAR`. Only has
+ an effect if `do_resize` is set to `True`.
+            high_res_resample (`PILImageResampling` filter, *optional*, defaults to `self.high_res_resample`):
+ `PILImageResampling` filter to use if resizing the image e.g. `PILImageResampling.BICUBIC`. Only has
+ an effect if `do_resize` is set to `True`.
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+ Whether to rescale the image values between [0 - 1].
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
+ Whether to normalize the image.
+ image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
+ Image mean to use if `do_normalize` is set to `True`.
+ image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
+ Image standard deviation to use if `do_normalize` is set to `True`.
+            high_res_image_mean (`float` or `List[float]`, *optional*, defaults to `self.high_res_image_mean`):
+                Image mean to use for the high resolution image if `do_normalize` is set to `True`.
+            high_res_image_std (`float` or `List[float]`, *optional*, defaults to `self.high_res_image_std`):
+                Image standard deviation to use for the high resolution image if `do_normalize` is set to `True`.
+ return_tensors (`str` or `TensorType`, *optional*):
+ The type of tensors to return. Can be one of:
+ - Unset: Return a list of `np.ndarray`.
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
+ The channel dimension format for the output image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - Unset: Use the channel dimension format of the input image.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
+ from the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+ do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
+ Whether to convert the image to RGB.
+ do_pad (`bool`, *optional*, defaults to `self.do_pad`):
+ Whether to pad the image to square or not.
+            background_color (`tuple[int, int, int]`, *optional*, defaults to `self.background_color`):
+                The background color to use for the padding.
+ """
+ do_resize = do_resize if do_resize is not None else self.do_resize
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
+ resample = resample if resample is not None else self.resample
+ high_res_resample = high_res_resample if high_res_resample is not None else self.high_res_resample
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
+ image_mean = image_mean if image_mean is not None else self.image_mean
+ image_std = image_std if image_std is not None else self.image_std
+ high_res_image_mean = high_res_image_mean if high_res_image_mean is not None else self.high_res_image_mean
+ high_res_image_std = high_res_image_std if high_res_image_std is not None else self.high_res_image_std
+ do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
+ do_pad = do_pad if do_pad is not None else self.do_pad
+ background_color = background_color if background_color is not None else self.background_color
+
+ size = size if size is not None else self.size
+ size_dict = get_size_dict(size)
+ high_res_size = high_res_size if high_res_size is not None else self.high_res_size
+ high_res_size_dict = get_size_dict(high_res_size)
+
+ images = self.fetch_images(images)
+ images = make_flat_list_of_images(images)
+
+ if not valid_images(images):
+ raise ValueError(
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
+ "torch.Tensor, tf.Tensor or jax.ndarray."
+ )
+ validate_preprocess_arguments(
+ do_rescale=do_rescale,
+ rescale_factor=rescale_factor,
+ do_normalize=do_normalize,
+ image_mean=image_mean,
+ image_std=image_std,
+ do_resize=do_resize,
+ size=size,
+ resample=resample,
+ )
+
+ if do_convert_rgb:
+ images = [convert_to_rgb(image) for image in images]
+
+ # All transformations expect numpy arrays.
+ images = [to_numpy_array(image) for image in images]
+
+ if do_rescale and is_scaled_image(images[0]):
+ logger.warning_once(
+ "It looks like you are trying to rescale already rescaled images. If the input"
+ " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
+ )
+
+ if input_data_format is None:
+ # We assume that all images have the same channel dimension format.
+ input_data_format = infer_channel_dimension_format(images[0])
+
+ all_images = []
+ all_high_res_images = []
+ for image in images:
+ # high_res_image: resize (high) -> rescale -> normalize (high)
+ # low_res_image: resize (high) -> rescale -> resize (low) -> normalize (low)
+ high_res_image = image
+ if do_resize:
+ high_res_image = self.resize(
+ image=high_res_image,
+ size=high_res_size_dict,
+ resample=high_res_resample,
+ input_data_format=input_data_format,
+ )
+ if do_pad:
+ # Expand and pad the images to obtain a square image of dimensions `size x size`
+ high_res_image = self.pad_to_square(
+ image=high_res_image,
+ background_color=background_color,
+ input_data_format=input_data_format,
+ )
+ image = self.resize(
+ image=high_res_image,
+ size=size_dict,
+ resample=resample,
+ input_data_format=input_data_format,
+ )
+ if do_pad:
+ image = self.pad_to_square(
+ image=image,
+ background_color=background_color,
+ input_data_format=input_data_format,
+ )
+
+ if do_rescale:
+ image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
+ high_res_image = self.rescale(
+ image=high_res_image, scale=rescale_factor, input_data_format=input_data_format
+ )
+
+ if do_normalize:
+ image = self.normalize(
+ image=image, mean=image_mean, std=image_std, input_data_format=input_data_format
+ )
+ high_res_image = self.normalize(
+ image=high_res_image,
+ mean=high_res_image_mean,
+ std=high_res_image_std,
+ input_data_format=input_data_format,
+ )
+
+ image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
+ high_res_image = to_channel_dimension_format(
+ high_res_image, data_format, input_channel_dim=input_data_format
+ )
+
+ all_images.append(image)
+ all_high_res_images.append(high_res_image)
+
+ data = {"pixel_values": all_images, "high_res_pixel_values": all_high_res_images}
+ return BatchFeature(data=data, tensor_type=return_tensors)
+
+
+class DeepseekVLHybridFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
+ r"""
+ min_size (`int`, *optional*, defaults to 14):
+ The minimum allowed size for the resized image. Ensures that neither the height nor width
+ falls below this value after resizing.
+ high_res_size (`dict`, *optional*, defaults to `{"height": 1024, "width": 1024}`):
+ Size of the high resolution output image after resizing. Can be overridden by the `high_res_size` parameter in the `preprocess`
+ method.
+ high_res_resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
+ Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`. Can be
+ overridden by the `high_res_resample` parameter in the `preprocess` method.
+ high_res_image_mean (`float` or `list[float]`, *optional*, defaults to `OPENAI_CLIP_MEAN`):
+ Mean to use if normalizing the high resolution image. This is a float or list of floats the length of the number of
+ channels in the image. Can be overridden by the `high_res_image_mean` parameter in the `preprocess` method.
+ high_res_image_std (`float` or `list[float]`, *optional*, defaults to `OPENAI_CLIP_STD`):
+ Standard deviation to use if normalizing the high resolution image. This is a float or list of floats the length of the
+ number of channels in the image. Can be overridden by the `high_res_image_std` parameter in the `preprocess` method.
+ """
+
+ min_size: int
+ high_res_size: dict
+ high_res_resample: "PILImageResampling"
+ high_res_image_mean: list[float]
+ high_res_image_std: list[float]
+
+
+class DeepseekVLHybridImageProcessorFast(DeepseekVLImageProcessorFast):
+ high_res_image_mean = OPENAI_CLIP_MEAN
+ high_res_image_std = OPENAI_CLIP_STD
+ high_res_size = {"height": 1024, "width": 1024}
+ high_res_resample = PILImageResampling.BICUBIC
+ model_input_names = ["pixel_values", "high_res_pixel_values"]
+
+ def __init__(self, **kwargs: Unpack[DeepseekVLHybridFastImageProcessorKwargs]):
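+        # Padding colors default to mid-gray; when a mean is given, the fill color is derived from it so that padded pixels normalize to (approximately) zero.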
+ if kwargs.get("image_mean") is None:
+ background_color = (127, 127, 127)
+ else:
+ background_color = tuple([int(x * 255) for x in kwargs.get("image_mean")])
+ if kwargs.get("high_res_image_mean") is None:
+ high_res_background_color = (127, 127, 127)
+ else:
+ high_res_background_color = tuple(int(x * 255) for x in kwargs.get("high_res_image_mean"))
+ BaseImageProcessorFast.__init__(self, **kwargs)
+ self.background_color = tuple(background_color)
+ self.high_res_background_color = tuple(high_res_background_color)
+
+ def _further_process_kwargs(
+ self,
+ size: Optional[SizeDict] = None,
+ high_res_size: Optional[SizeDict] = None,
+ default_to_square: Optional[bool] = None,
+ image_mean: Optional[Union[float, list[float]]] = None,
+ image_std: Optional[Union[float, list[float]]] = None,
+ high_res_image_mean: Optional[Union[float, list[float]]] = None,
+ high_res_image_std: Optional[Union[float, list[float]]] = None,
+ data_format: Optional[ChannelDimension] = None,
+ **kwargs,
+ ) -> dict:
+ """
+ Update kwargs that need further processing before being validated
+ Can be overridden by subclasses to customize the processing of kwargs.
+ """
+ if kwargs is None:
+ kwargs = {}
+ if size is not None:
+ size = SizeDict(**get_size_dict(size=size, default_to_square=default_to_square))
+ if high_res_size is not None:
+ high_res_size = SizeDict(**get_size_dict(size=high_res_size, default_to_square=default_to_square))
+ if isinstance(image_mean, list):
+ image_mean = tuple(image_mean)
+ if isinstance(image_std, list):
+ image_std = tuple(image_std)
+ if isinstance(high_res_image_mean, list):
+ high_res_image_mean = tuple(high_res_image_mean)
+ if isinstance(high_res_image_std, list):
+ high_res_image_std = tuple(high_res_image_std)
+ if data_format is None:
+ data_format = ChannelDimension.FIRST
+
+ high_res_resample = kwargs.pop("high_res_resample")
+ kwargs["high_res_interpolation"] = (
+ pil_torch_interpolation_mapping[high_res_resample]
+ if isinstance(high_res_resample, (int, PILImageResampling))
+ else high_res_resample
+ )
+
+ low_res_resample = kwargs.pop("resample")
+ kwargs["interpolation"] = (
+ pil_torch_interpolation_mapping[low_res_resample]
+ if isinstance(low_res_resample, (int, PILImageResampling))
+ else low_res_resample
+ )
+
+ kwargs["size"] = size
+ kwargs["high_res_size"] = high_res_size
+ kwargs["image_mean"] = image_mean
+ kwargs["image_std"] = image_std
+ kwargs["high_res_image_mean"] = high_res_image_mean
+ kwargs["high_res_image_std"] = high_res_image_std
+ kwargs["data_format"] = data_format
+
+ return kwargs
+
+ def _preprocess(
+ self,
+ images: list["torch.Tensor"],
+ do_resize: bool,
+ size: SizeDict,
+ high_res_size: SizeDict,
+ min_size: int,
+ interpolation: Optional["F.InterpolationMode"],
+ high_res_interpolation: Optional["F.InterpolationMode"],
+ do_rescale: bool,
+ rescale_factor: float,
+ do_normalize: bool,
+ image_mean: Optional[Union[float, list[float]]],
+ image_std: Optional[Union[float, list[float]]],
+ high_res_image_mean: Optional[Union[float, list[float]]],
+ high_res_image_std: Optional[Union[float, list[float]]],
+ disable_grouping: Optional[bool],
+ return_tensors: Optional[Union[str, TensorType]],
+ do_pad: bool = True,
+ **kwargs,
+ ) -> BatchFeature:
+ # Group images by size for batched resizing
+ grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
+ high_res_resized_images_grouped = {}
+ for shape, stacked_images in grouped_images.items():
+ if do_resize:
+ stacked_high_res_images = self.resize(
+ image=stacked_images, size=high_res_size, min_size=min_size, interpolation=high_res_interpolation
+ )
+ high_res_resized_images_grouped[shape] = stacked_high_res_images
+ high_res_resized_images = reorder_images(high_res_resized_images_grouped, grouped_images_index)
+
+ # Group images by size for further processing
+ # Needed in case do_resize is False, or resize returns images with different sizes
+ grouped_high_res_images, grouped_high_res_images_index = group_images_by_shape(
+ high_res_resized_images, disable_grouping=disable_grouping
+ )
+ high_res_padded_images = {}
+ high_res_processed_images_grouped = {}
+ for shape, stacked_high_res_images in grouped_high_res_images.items():
+ if do_pad:
+ stacked_high_res_images = self.pad_to_square(
+ stacked_high_res_images, background_color=self.high_res_background_color
+ )
+ high_res_padded_images[shape] = stacked_high_res_images
+ # Fused rescale and normalize
+ stacked_high_res_images = self.rescale_and_normalize(
+ stacked_high_res_images,
+ do_rescale,
+ rescale_factor,
+ do_normalize,
+ high_res_image_mean,
+ high_res_image_std,
+ )
+ high_res_processed_images_grouped[shape] = stacked_high_res_images
+ high_res_processed_images = reorder_images(high_res_processed_images_grouped, grouped_high_res_images_index)
+ high_res_processed_images = (
+ torch.stack(high_res_processed_images, dim=0) if return_tensors else high_res_processed_images
+ )
+
+ resized_images_grouped = {}
+ for shape, stacked_high_res_padded_images in high_res_padded_images.items():
+ if do_resize:
+ stacked_images = self.resize(
+ image=stacked_high_res_padded_images, size=size, min_size=min_size, interpolation=interpolation
+ )
+ resized_images_grouped[shape] = stacked_images
+ resized_images = reorder_images(resized_images_grouped, grouped_high_res_images_index)
+
+ grouped_resized_images, grouped_resized_images_index = group_images_by_shape(
+ resized_images, disable_grouping=disable_grouping
+ )
+ processed_images_grouped = {}
+ for shape, stacked_images in grouped_resized_images.items():
+ if do_pad:
+ stacked_images = self.pad_to_square(stacked_images, background_color=self.background_color)
+ # Fused rescale and normalize
+ stacked_images = self.rescale_and_normalize(
+ stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
+ )
+ processed_images_grouped[shape] = stacked_images
+ processed_images = reorder_images(processed_images_grouped, grouped_resized_images_index)
+ processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
+
+ return BatchFeature(
+ data={"pixel_values": processed_images, "high_res_pixel_values": high_res_processed_images},
+ tensor_type=return_tensors,
+ )
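+
+
+# Illustrative usage sketch (not part of the library code): the fast hybrid image processor
+# returns two views of every image, a low-resolution `pixel_values` tensor and a
+# high-resolution `high_res_pixel_values` tensor, each padded to a square and normalized
+# with its own mean/std. The checkpoint id below is a placeholder, not a verified repo name.
+#
+# >>> from transformers import AutoImageProcessor
+# >>> from PIL import Image
+# >>> import numpy as np
+# >>> image_processor = AutoImageProcessor.from_pretrained("<deepseek-vl-hybrid-checkpoint>")
+# >>> image = Image.fromarray(np.zeros((480, 640, 3), dtype=np.uint8))
+# >>> outputs = image_processor(images=image, return_tensors="pt")
+# >>> sorted(outputs.keys())
+# ['high_res_pixel_values', 'pixel_values']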
+
+
+class DeepseekVLHybridProcessorKwargs(DeepseekVLProcessorKwargs):
+ pass
+
+
+class DeepseekVLHybridProcessor(DeepseekVLProcessor):
+ def __call__(
+ self,
+ text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None,
+ images: Optional[ImageInput] = None,
+ **kwargs: Unpack[DeepseekVLHybridProcessorKwargs],
+ ) -> BatchFeature:
+ """
+ Main method to prepare one or several sequence(s) and image(s) for the model. This method forwards the `text`
+ and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`] if `text` is not `None` to encode
+ the text. To prepare the image(s), this method forwards the `images` and `kwargs` arguments to
+ DeepseekVLHybridImageProcessor's [`~DeepseekVLHybridImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring
+ of the above two methods for more information.
+
+ Args:
+ text (`str`, `List[str]`, `List[List[str]]`):
+ The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
+ (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
+ `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
+ images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
+ The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
+ tensor. Both channels-first and channels-last formats are supported.
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
+ If set, will return tensors of a particular framework. Acceptable values are:
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
+ - `'np'`: Return NumPy `np.ndarray` objects.
+ - `'jax'`: Return JAX `jnp.ndarray` objects.
+
+ Returns:
+ [`BatchFeature`]: A [`BatchFeature`] with the following fields:
+
+ - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
+ - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
+ `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
+ `None`).
+ - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
+ """
+ output_kwargs = self._merge_kwargs(
+ DeepseekVLHybridProcessorKwargs, tokenizer_init_kwargs=self.tokenizer.init_kwargs, **kwargs
+ )
+ if text is None and images is None:
+ raise ValueError("You must specify either text or images.")
+
+ if text is not None:
+ if isinstance(text, str):
+ text = [text]
+ elif not (isinstance(text, (list, tuple)) and all(isinstance(t, str) for t in text)):
+ raise ValueError("Invalid input text. Please provide a string, or a list of strings")
+
+ prompt_strings = []
+ one_img_tokens = self.image_token * self.num_image_tokens
+ for prompt in text:
+ prompt = prompt.replace(self.image_token, one_img_tokens)
+ prompt_strings.append(prompt)
+
+ data = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"])
+
+ # process images if provided
+ if images is not None:
+ inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
+ data["pixel_values"] = inputs["pixel_values"]
+ data["high_res_pixel_values"] = inputs["high_res_pixel_values"]
+
+ return BatchFeature(data=data)
+
+
+__all__ = [
+ "DeepseekVLHybridConfig",
+ "DeepseekVLHybridPreTrainedModel",
+ "DeepseekVLHybridModel",
+ "DeepseekVLHybridForConditionalGeneration",
+ "DeepseekVLHybridImageProcessor",
+ "DeepseekVLHybridImageProcessorFast",
+ "DeepseekVLHybridProcessor",
+]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/deepseek_vl_hybrid/processing_deepseek_vl_hybrid.py b/venv/lib/python3.13/site-packages/transformers/models/deepseek_vl_hybrid/processing_deepseek_vl_hybrid.py
new file mode 100644
index 0000000000000000000000000000000000000000..538fea5a6b322ddf0b8a0902915e726eab7420f0
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/deepseek_vl_hybrid/processing_deepseek_vl_hybrid.py
@@ -0,0 +1,158 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change is needed, please apply it to the
+# modular_deepseek_vl_hybrid.py file directly. One of our CI checks enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# Copyright 2025 Deepseek AI and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Union
+
+from ...image_processing_utils_fast import BatchFeature
+from ...image_utils import ImageInput
+from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
+from ...tokenization_utils_base import PreTokenizedInput, TextInput
+
+
+class DeepseekVLHybridProcessorKwargs(ProcessingKwargs, total=False):
+ _defaults = {
+ "text_kwargs": {"padding": False},
+ "common_kwargs": {"return_tensors": "pt"},
+ }
+
+
+class DeepseekVLHybridProcessor(ProcessorMixin):
+ r"""
+ Constructs a DeepseekVLHybrid processor which wraps a DeepseekVLHybrid Image Processor and a Llama tokenizer into a single processor.
+
+ [`DeepseekVLHybridProcessor`] offers all the functionalities of [`DeepseekVLHybridImageProcessor`] and [`LlamaTokenizerFast`]. See the
+ [`~DeepseekVLHybridProcessor.__call__`] and [`~DeepseekVLHybridProcessor.decode`] for more information.
+
+ Args:
+ image_processor ([`DeepseekVLHybridImageProcessor`]):
+ The image processor is a required input.
+ tokenizer ([`LlamaTokenizerFast`]):
+ The tokenizer is a required input.
+ chat_template (`str`, *optional*):
+ A Jinja template which will be used to convert lists of messages
+ in a chat into a tokenizable string.
+ num_image_tokens (`int`, *optional*, defaults to 576):
+ The number of special image tokens used as placeholders for visual content in text sequences.
+ """
+
+ attributes = ["image_processor", "tokenizer"]
+ valid_kwargs = ["chat_template", "num_image_tokens"]
+ image_processor_class = "AutoImageProcessor"
+ tokenizer_class = "AutoTokenizer"
+
+ def __init__(
+ self,
+ image_processor,
+ tokenizer,
+ chat_template=None,
+ num_image_tokens=576,
+ ):
+ self.image_token = tokenizer.image_token
+ self.num_image_tokens = num_image_tokens
+
+ super().__init__(image_processor, tokenizer, chat_template=chat_template)
+
+ def __call__(
+ self,
+ text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None,
+ images: Optional[ImageInput] = None,
+ **kwargs: Unpack[DeepseekVLHybridProcessorKwargs],
+ ) -> BatchFeature:
+ """
+ Main method to prepare one or several sequence(s) and image(s) for the model. This method forwards the `text`
+ and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`] if `text` is not `None` to encode
+ the text. To prepare the image(s), this method forwards the `images` and `kwargs` arguments to
+ DeepseekVLHybridImageProcessor's [`~DeepseekVLHybridImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring
+ of the above two methods for more information.
+
+ Args:
+ text (`str`, `List[str]`, `List[List[str]]`):
+ The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
+ (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
+ `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
+ images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
+ The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
+ tensor. Both channels-first and channels-last formats are supported.
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
+ If set, will return tensors of a particular framework. Acceptable values are:
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
+ - `'np'`: Return NumPy `np.ndarray` objects.
+ - `'jax'`: Return JAX `jnp.ndarray` objects.
+
+ Returns:
+ [`BatchFeature`]: A [`BatchFeature`] with the following fields:
+
+ - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
+ - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
+ `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
+ `None`).
+ - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
+ """
+ output_kwargs = self._merge_kwargs(
+ DeepseekVLHybridProcessorKwargs, tokenizer_init_kwargs=self.tokenizer.init_kwargs, **kwargs
+ )
+ if text is None and images is None:
+ raise ValueError("You must specify either text or images.")
+
+ if text is not None:
+ if isinstance(text, str):
+ text = [text]
+ elif not (isinstance(text, (list, tuple)) and all(isinstance(t, str) for t in text)):
+ raise ValueError("Invalid input text. Please provide a string, or a list of strings")
+
+ prompt_strings = []
+ one_img_tokens = self.image_token * self.num_image_tokens
+ for prompt in text:
+ prompt = prompt.replace(self.image_token, one_img_tokens)
+ prompt_strings.append(prompt)
+
+ data = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"])
+
+ # process images if provided
+ if images is not None:
+ inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
+ data["pixel_values"] = inputs["pixel_values"]
+ data["high_res_pixel_values"] = inputs["high_res_pixel_values"]
+
+ return BatchFeature(data=data)
+
+ def batch_decode(self, *args, **kwargs):
+ """
+ This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
+ refer to the docstring of this method for more information.
+ """
+ return self.tokenizer.batch_decode(*args, **kwargs)
+
+ def decode(self, *args, **kwargs):
+ """
+ This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
+ the docstring of this method for more information.
+ """
+ return self.tokenizer.decode(*args, **kwargs)
+
+ @property
+ def model_input_names(self):
+ tokenizer_input_names = self.tokenizer.model_input_names
+ image_processor_input_names = self.image_processor.model_input_names
+ return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
+
+
+__all__ = ["DeepseekVLHybridProcessor"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/deformable_detr/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/deformable_detr/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..16a1959c30ff5474ed82a6b0a2e22896514d6a44
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/deformable_detr/__init__.py
@@ -0,0 +1,32 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_deformable_detr import *
+ from .feature_extraction_deformable_detr import *
+ from .image_processing_deformable_detr import *
+ from .image_processing_deformable_detr_fast import *
+ from .modeling_deformable_detr import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/venv/lib/python3.13/site-packages/transformers/models/deformable_detr/configuration_deformable_detr.py b/venv/lib/python3.13/site-packages/transformers/models/deformable_detr/configuration_deformable_detr.py
new file mode 100644
index 0000000000000000000000000000000000000000..b85a7399908d4f25ce220aa1edab4b0e74c9bdb3
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/deformable_detr/configuration_deformable_detr.py
@@ -0,0 +1,290 @@
+# coding=utf-8
+# Copyright 2022 SenseTime and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Deformable DETR model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+from ...utils.backbone_utils import verify_backbone_config_arguments
+from ..auto import CONFIG_MAPPING
+
+
+logger = logging.get_logger(__name__)
+
+
+class DeformableDetrConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`DeformableDetrModel`]. It is used to instantiate
+ a Deformable DETR model according to the specified arguments, defining the model architecture. Instantiating a
+ configuration with the defaults will yield a similar configuration to that of the Deformable DETR
+ [SenseTime/deformable-detr](https://huggingface.co/SenseTime/deformable-detr) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ use_timm_backbone (`bool`, *optional*, defaults to `True`):
+ Whether or not to use the `timm` library for the backbone. If set to `False`, will use the [`AutoBackbone`]
+ API.
+ backbone_config (`PretrainedConfig` or `dict`, *optional*):
+ The configuration of the backbone model. Only used in case `use_timm_backbone` is set to `False` in which
+ case it will default to `ResNetConfig()`.
+ num_channels (`int`, *optional*, defaults to 3):
+ The number of input channels.
+ num_queries (`int`, *optional*, defaults to 300):
+ Number of object queries, i.e. detection slots. This is the maximal number of objects
+ [`DeformableDetrModel`] can detect in a single image. In case `two_stage` is set to `True`, we use
+ `two_stage_num_proposals` instead.
+ d_model (`int`, *optional*, defaults to 256):
+ Dimension of the layers.
+ encoder_layers (`int`, *optional*, defaults to 6):
+ Number of encoder layers.
+ decoder_layers (`int`, *optional*, defaults to 6):
+ Number of decoder layers.
+ encoder_attention_heads (`int`, *optional*, defaults to 8):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ decoder_attention_heads (`int`, *optional*, defaults to 8):
+ Number of attention heads for each attention layer in the Transformer decoder.
+ decoder_ffn_dim (`int`, *optional*, defaults to 1024):
+ Dimension of the "intermediate" (often named feed-forward) layer in decoder.
+ encoder_ffn_dim (`int`, *optional*, defaults to 1024):
+ Dimension of the "intermediate" (often named feed-forward) layer in encoder.
+ activation_function (`str` or `function`, *optional*, defaults to `"relu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
+ dropout (`float`, *optional*, defaults to 0.1):
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ activation_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for activations inside the fully connected layer.
+ init_std (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ init_xavier_std (`float`, *optional*, defaults to 1):
+ The scaling factor used for the Xavier initialization gain in the HM Attention map module.
+ encoder_layerdrop (`float`, *optional*, defaults to 0.0):
+ The LayerDrop probability for the encoder. See the [LayerDrop paper](https://huggingface.co/papers/1909.11556)
+ for more details.
+ auxiliary_loss (`bool`, *optional*, defaults to `False`):
+ Whether auxiliary decoding losses (loss at each decoder layer) are to be used.
+ position_embedding_type (`str`, *optional*, defaults to `"sine"`):
+ Type of position embeddings to be used on top of the image features. One of `"sine"` or `"learned"`.
+ backbone (`str`, *optional*, defaults to `"resnet50"`):
+ Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
+ will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone`
+ is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
+ use_pretrained_backbone (`bool`, *optional*, defaults to `True`):
+ Whether to use pretrained weights for the backbone.
+ backbone_kwargs (`dict`, *optional*):
+ Keyword arguments to be passed to AutoBackbone when loading from a checkpoint
+ e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set.
+ dilation (`bool`, *optional*, defaults to `False`):
+ Whether to replace stride with dilation in the last convolutional block (DC5). Only supported when
+ `use_timm_backbone` = `True`.
+ class_cost (`float`, *optional*, defaults to 1):
+ Relative weight of the classification error in the Hungarian matching cost.
+ bbox_cost (`float`, *optional*, defaults to 5):
+ Relative weight of the L1 error of the bounding box coordinates in the Hungarian matching cost.
+ giou_cost (`float`, *optional*, defaults to 2):
+ Relative weight of the generalized IoU loss of the bounding box in the Hungarian matching cost.
+ mask_loss_coefficient (`float`, *optional*, defaults to 1):
+ Relative weight of the Focal loss in the panoptic segmentation loss.
+ dice_loss_coefficient (`float`, *optional*, defaults to 1):
+ Relative weight of the DICE/F-1 loss in the panoptic segmentation loss.
+ bbox_loss_coefficient (`float`, *optional*, defaults to 5):
+ Relative weight of the L1 bounding box loss in the object detection loss.
+ giou_loss_coefficient (`float`, *optional*, defaults to 2):
+ Relative weight of the generalized IoU loss in the object detection loss.
+ eos_coefficient (`float`, *optional*, defaults to 0.1):
+ Relative classification weight of the 'no-object' class in the object detection loss.
+ num_feature_levels (`int`, *optional*, defaults to 4):
+ The number of input feature levels.
+ encoder_n_points (`int`, *optional*, defaults to 4):
+ The number of sampled keys in each feature level for each attention head in the encoder.
+ decoder_n_points (`int`, *optional*, defaults to 4):
+ The number of sampled keys in each feature level for each attention head in the decoder.
+ two_stage (`bool`, *optional*, defaults to `False`):
+ Whether to apply a two-stage deformable DETR, where the region proposals are also generated by a variant of
+ Deformable DETR, which are further fed into the decoder for iterative bounding box refinement.
+ two_stage_num_proposals (`int`, *optional*, defaults to 300):
+ The number of region proposals to be generated, in case `two_stage` is set to `True`.
+ with_box_refine (`bool`, *optional*, defaults to `False`):
+ Whether to apply iterative bounding box refinement, where each decoder layer refines the bounding boxes
+ based on the predictions from the previous layer.
+ focal_alpha (`float`, *optional*, defaults to 0.25):
+ Alpha parameter in the focal loss.
+ disable_custom_kernels (`bool`, *optional*, defaults to `False`):
+ Disable the use of custom CUDA and CPU kernels. This option is necessary for the ONNX export, as custom
+ kernels are not supported by PyTorch ONNX export.
+
+ Examples:
+
+ ```python
+ >>> from transformers import DeformableDetrConfig, DeformableDetrModel
+
+ >>> # Initializing a Deformable DETR SenseTime/deformable-detr style configuration
+ >>> configuration = DeformableDetrConfig()
+
+ >>> # Initializing a model (with random weights) from the SenseTime/deformable-detr style configuration
+ >>> model = DeformableDetrModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "deformable_detr"
+ attribute_map = {
+ "hidden_size": "d_model",
+ "num_attention_heads": "encoder_attention_heads",
+ }
+
+ def __init__(
+ self,
+ use_timm_backbone=True,
+ backbone_config=None,
+ num_channels=3,
+ num_queries=300,
+ max_position_embeddings=1024,
+ encoder_layers=6,
+ encoder_ffn_dim=1024,
+ encoder_attention_heads=8,
+ decoder_layers=6,
+ decoder_ffn_dim=1024,
+ decoder_attention_heads=8,
+ encoder_layerdrop=0.0,
+ is_encoder_decoder=True,
+ activation_function="relu",
+ d_model=256,
+ dropout=0.1,
+ attention_dropout=0.0,
+ activation_dropout=0.0,
+ init_std=0.02,
+ init_xavier_std=1.0,
+ return_intermediate=True,
+ auxiliary_loss=False,
+ position_embedding_type="sine",
+ backbone="resnet50",
+ use_pretrained_backbone=True,
+ backbone_kwargs=None,
+ dilation=False,
+ num_feature_levels=4,
+ encoder_n_points=4,
+ decoder_n_points=4,
+ two_stage=False,
+ two_stage_num_proposals=300,
+ with_box_refine=False,
+ class_cost=1,
+ bbox_cost=5,
+ giou_cost=2,
+ mask_loss_coefficient=1,
+ dice_loss_coefficient=1,
+ bbox_loss_coefficient=5,
+ giou_loss_coefficient=2,
+ eos_coefficient=0.1,
+ focal_alpha=0.25,
+ disable_custom_kernels=False,
+ **kwargs,
+ ):
+ # We default to values which were previously hard-coded in the model. This enables configurability of the config
+ # while keeping the default behavior the same.
+ if use_timm_backbone and backbone_kwargs is None:
+ backbone_kwargs = {}
+ if dilation:
+ backbone_kwargs["output_stride"] = 16
+ backbone_kwargs["out_indices"] = [2, 3, 4] if num_feature_levels > 1 else [4]
+ backbone_kwargs["in_chans"] = num_channels
+ # Backwards compatibility
+ elif not use_timm_backbone and backbone in (None, "resnet50"):
+ if backbone_config is None:
+ logger.info("`backbone_config` is `None`. Initializing the config with the default `ResNet` backbone.")
+ backbone_config = CONFIG_MAPPING["resnet"](out_features=["stage4"])
+ elif isinstance(backbone_config, dict):
+ backbone_model_type = backbone_config.get("model_type")
+ config_class = CONFIG_MAPPING[backbone_model_type]
+ backbone_config = config_class.from_dict(backbone_config)
+
+ verify_backbone_config_arguments(
+ use_timm_backbone=use_timm_backbone,
+ use_pretrained_backbone=use_pretrained_backbone,
+ backbone=backbone,
+ backbone_config=backbone_config,
+ backbone_kwargs=backbone_kwargs,
+ )
+
+ self.use_timm_backbone = use_timm_backbone
+ self.backbone_config = backbone_config
+ self.num_channels = num_channels
+ self.num_queries = num_queries
+ self.max_position_embeddings = max_position_embeddings
+ self.d_model = d_model
+ self.encoder_ffn_dim = encoder_ffn_dim
+ self.encoder_layers = encoder_layers
+ self.encoder_attention_heads = encoder_attention_heads
+ self.decoder_ffn_dim = decoder_ffn_dim
+ self.decoder_layers = decoder_layers
+ self.decoder_attention_heads = decoder_attention_heads
+ self.dropout = dropout
+ self.attention_dropout = attention_dropout
+ self.activation_dropout = activation_dropout
+ self.activation_function = activation_function
+ self.init_std = init_std
+ self.init_xavier_std = init_xavier_std
+ self.encoder_layerdrop = encoder_layerdrop
+ self.auxiliary_loss = auxiliary_loss
+ self.position_embedding_type = position_embedding_type
+ self.backbone = backbone
+ self.use_pretrained_backbone = use_pretrained_backbone
+ self.backbone_kwargs = backbone_kwargs
+ self.dilation = dilation
+ # deformable attributes
+ self.num_feature_levels = num_feature_levels
+ self.encoder_n_points = encoder_n_points
+ self.decoder_n_points = decoder_n_points
+ self.two_stage = two_stage
+ self.two_stage_num_proposals = two_stage_num_proposals
+ self.with_box_refine = with_box_refine
+ if two_stage is True and with_box_refine is False:
+ raise ValueError("If two_stage is True, with_box_refine must be True.")
+ # Hungarian matcher
+ self.class_cost = class_cost
+ self.bbox_cost = bbox_cost
+ self.giou_cost = giou_cost
+ # Loss coefficients
+ self.mask_loss_coefficient = mask_loss_coefficient
+ self.dice_loss_coefficient = dice_loss_coefficient
+ self.bbox_loss_coefficient = bbox_loss_coefficient
+ self.giou_loss_coefficient = giou_loss_coefficient
+ self.eos_coefficient = eos_coefficient
+ self.focal_alpha = focal_alpha
+ self.disable_custom_kernels = disable_custom_kernels
+ super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs)
+
+ @property
+ def num_attention_heads(self) -> int:
+ return self.encoder_attention_heads
+
+ @property
+ def hidden_size(self) -> int:
+ return self.d_model
+
+ @property
+ def sub_configs(self):
+ return (
+ {"backbone_config": type(self.backbone_config)}
+ if getattr(self, "backbone_config", None) is not None
+ else {}
+ )
+
+
+__all__ = ["DeformableDetrConfig"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/deformable_detr/feature_extraction_deformable_detr.py b/venv/lib/python3.13/site-packages/transformers/models/deformable_detr/feature_extraction_deformable_detr.py
new file mode 100644
index 0000000000000000000000000000000000000000..e349ca3db0ca92b349752484877ad8dd64d153b1
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/deformable_detr/feature_extraction_deformable_detr.py
@@ -0,0 +1,48 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Feature extractor class for Deformable DETR."""
+
+import warnings
+
+from ...image_transforms import rgb_to_id as _rgb_to_id
+from ...utils import logging
+from ...utils.import_utils import requires
+from .image_processing_deformable_detr import DeformableDetrImageProcessor
+
+
+logger = logging.get_logger(__name__)
+
+
+def rgb_to_id(x):
+ warnings.warn(
+ "rgb_to_id has moved and will not be importable from this module from v5. "
+ "Please import from transformers.image_transforms instead.",
+ FutureWarning,
+ )
+ return _rgb_to_id(x)
+
+
+@requires(backends=("vision",))
+class DeformableDetrFeatureExtractor(DeformableDetrImageProcessor):
+ def __init__(self, *args, **kwargs) -> None:
+ warnings.warn(
+ "The class DeformableDetrFeatureExtractor is deprecated and will be removed in version 5 of Transformers."
+ " Please use DeformableDetrImageProcessor instead.",
+ FutureWarning,
+ )
+ super().__init__(*args, **kwargs)
+
+
+__all__ = ["DeformableDetrFeatureExtractor"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/deformable_detr/image_processing_deformable_detr.py b/venv/lib/python3.13/site-packages/transformers/models/deformable_detr/image_processing_deformable_detr.py
new file mode 100644
index 0000000000000000000000000000000000000000..c6875eb9b8f80e152c0e310e9bdd051e6b2200e6
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/deformable_detr/image_processing_deformable_detr.py
@@ -0,0 +1,1634 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Image processor class for Deformable DETR."""
+
+import io
+import pathlib
+from collections import defaultdict
+from collections.abc import Iterable
+from typing import Any, Callable, Optional, Union
+
+import numpy as np
+
+from ...feature_extraction_utils import BatchFeature
+from ...image_processing_utils import BaseImageProcessor, get_size_dict
+from ...image_transforms import (
+ PaddingMode,
+ center_to_corners_format,
+ corners_to_center_format,
+ id_to_rgb,
+ pad,
+ rescale,
+ resize,
+ rgb_to_id,
+ to_channel_dimension_format,
+)
+from ...image_utils import (
+ IMAGENET_DEFAULT_MEAN,
+ IMAGENET_DEFAULT_STD,
+ AnnotationFormat,
+ AnnotationType,
+ ChannelDimension,
+ ImageInput,
+ PILImageResampling,
+ get_image_size,
+ infer_channel_dimension_format,
+ is_scaled_image,
+ make_flat_list_of_images,
+ to_numpy_array,
+ valid_images,
+ validate_annotations,
+ validate_kwargs,
+ validate_preprocess_arguments,
+)
+from ...utils import (
+ TensorType,
+ is_flax_available,
+ is_jax_tensor,
+ is_scipy_available,
+ is_tf_available,
+ is_tf_tensor,
+ is_torch_available,
+ is_torch_tensor,
+ is_vision_available,
+ logging,
+)
+from ...utils.import_utils import requires
+
+
+if is_torch_available():
+ import torch
+ from torch import nn
+
+
+if is_vision_available():
+ import PIL
+
+if is_scipy_available():
+ import scipy.special
+ import scipy.stats
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+SUPPORTED_ANNOTATION_FORMATS = (AnnotationFormat.COCO_DETECTION, AnnotationFormat.COCO_PANOPTIC)
+
+
+# Copied from transformers.models.detr.image_processing_detr.get_size_with_aspect_ratio
+def get_size_with_aspect_ratio(image_size, size, max_size=None) -> tuple[int, int]:
+ """
+ Computes the output image size given the input image size and the desired output size.
+
+ Args:
+ image_size (`tuple[int, int]`):
+ The input image size.
+ size (`int`):
+ The desired output size.
+ max_size (`int`, *optional*):
+ The maximum allowed output size.
+ """
+ height, width = image_size
+ raw_size = None
+ if max_size is not None:
+ min_original_size = float(min((height, width)))
+ max_original_size = float(max((height, width)))
+ if max_original_size / min_original_size * size > max_size:
+ raw_size = max_size * min_original_size / max_original_size
+ size = int(round(raw_size))
+
+ if (height <= width and height == size) or (width <= height and width == size):
+ oh, ow = height, width
+ elif width < height:
+ ow = size
+ if max_size is not None and raw_size is not None:
+ oh = int(raw_size * height / width)
+ else:
+ oh = int(size * height / width)
+ else:
+ oh = size
+ if max_size is not None and raw_size is not None:
+ ow = int(raw_size * width / height)
+ else:
+ ow = int(size * width / height)
+
+ return (oh, ow)
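+
+
+# Worked examples (illustrative, based on the logic above):
+# - image_size=(480, 640), size=800, max_size=1333: the scaled longer edge
+#   (640 / 480 * 800 ≈ 1067) stays below max_size, so the shorter edge becomes 800 and
+#   the output is (800, 1066).
+# - image_size=(480, 640), size=800, max_size=1000: the longer edge would exceed max_size,
+#   so size is capped at 1000 * 480 / 640 = 750 and the output is (750, 1000).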
+
+
+# Copied from transformers.models.detr.image_processing_detr.get_resize_output_image_size
+def get_resize_output_image_size(
+ input_image: np.ndarray,
+ size: Union[int, tuple[int, int], list[int]],
+ max_size: Optional[int] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+) -> tuple[int, int]:
+ """
+ Computes the output image size given the input image size and the desired output size. If the desired output size
+ is a tuple or list, the output image size is returned as is. If the desired output size is an integer, the output
+ image size is computed by keeping the aspect ratio of the input image size.
+
+ Args:
+ input_image (`np.ndarray`):
+ The image to resize.
+ size (`int` or `tuple[int, int]` or `list[int]`):
+ The desired output size.
+ max_size (`int`, *optional*):
+ The maximum allowed output size.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format of the input image. If not provided, it will be inferred from the input image.
+ """
+ image_size = get_image_size(input_image, input_data_format)
+ if isinstance(size, (list, tuple)):
+ return size
+
+ return get_size_with_aspect_ratio(image_size, size, max_size)
+
+
+# Copied from transformers.models.detr.image_processing_detr.get_image_size_for_max_height_width
+def get_image_size_for_max_height_width(
+ input_image: np.ndarray,
+ max_height: int,
+ max_width: int,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+) -> tuple[int, int]:
+ """
+ Computes the output image size given the input image and the maximum allowed height and width, keeping the aspect ratio.
+ Important: even if image_height < max_height and image_width < max_width, the image will be resized
+ so that at least one of its edges equals max_height or max_width.
+
+ For example:
+ - input_size: (100, 200), max_height: 50, max_width: 50 -> output_size: (25, 50)
+ - input_size: (100, 200), max_height: 200, max_width: 500 -> output_size: (200, 400)
+
+ Args:
+ input_image (`np.ndarray`):
+ The image to resize.
+ max_height (`int`):
+ The maximum allowed height.
+ max_width (`int`):
+ The maximum allowed width.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format of the input image. If not provided, it will be inferred from the input image.
+ """
+ image_size = get_image_size(input_image, input_data_format)
+ height, width = image_size
+ height_scale = max_height / height
+ width_scale = max_width / width
+ min_scale = min(height_scale, width_scale)
+ new_height = int(height * min_scale)
+ new_width = int(width * min_scale)
+ return new_height, new_width
+
+
+# Copied from transformers.models.detr.image_processing_detr.get_numpy_to_framework_fn
+def get_numpy_to_framework_fn(arr) -> Callable:
+ """
+ Returns a function that converts a numpy array to the framework of the input array.
+
+ Args:
+ arr (`np.ndarray`): The array to convert.
+ """
+ if isinstance(arr, np.ndarray):
+ return np.array
+ if is_tf_available() and is_tf_tensor(arr):
+ import tensorflow as tf
+
+ return tf.convert_to_tensor
+ if is_torch_available() and is_torch_tensor(arr):
+ import torch
+
+ return torch.tensor
+ if is_flax_available() and is_jax_tensor(arr):
+ import jax.numpy as jnp
+
+ return jnp.array
+ raise ValueError(f"Cannot convert arrays of type {type(arr)}")
+
+
+# Copied from transformers.models.detr.image_processing_detr.safe_squeeze
+def safe_squeeze(arr: np.ndarray, axis: Optional[int] = None) -> np.ndarray:
+ """
+ Squeezes an array, but only if the axis specified has dim 1.
+ """
+ if axis is None:
+ return arr.squeeze()
+
+ try:
+ return arr.squeeze(axis=axis)
+ except ValueError:
+ return arr
+
+
+# Copied from transformers.models.detr.image_processing_detr.normalize_annotation
+def normalize_annotation(annotation: dict, image_size: tuple[int, int]) -> dict:
+ image_height, image_width = image_size
+ norm_annotation = {}
+ for key, value in annotation.items():
+ if key == "boxes":
+ boxes = value
+ boxes = corners_to_center_format(boxes)
+ boxes /= np.asarray([image_width, image_height, image_width, image_height], dtype=np.float32)
+ norm_annotation[key] = boxes
+ else:
+ norm_annotation[key] = value
+ return norm_annotation
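+
+
+# Worked example (illustrative): for an image of size (height=100, width=200), a box
+# [20, 40, 60, 80] in corners (x0, y0, x1, y1) format becomes [40, 60, 40, 40] in
+# (center_x, center_y, width, height) format and is then normalized to [0.2, 0.6, 0.2, 0.4].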
+
+
+# Copied from transformers.models.detr.image_processing_detr.max_across_indices
+def max_across_indices(values: Iterable[Any]) -> list[Any]:
+ """
+ Return the maximum value at each index across an iterable of value sequences.
+ """
+ return [max(values_i) for values_i in zip(*values)]
+
+
+# Copied from transformers.models.detr.image_processing_detr.get_max_height_width
+def get_max_height_width(
+ images: list[np.ndarray], input_data_format: Optional[Union[str, ChannelDimension]] = None
+) -> list[int]:
+ """
+ Get the maximum height and width across all images in a batch.
+ """
+ if input_data_format is None:
+ input_data_format = infer_channel_dimension_format(images[0])
+
+ if input_data_format == ChannelDimension.FIRST:
+ _, max_height, max_width = max_across_indices([img.shape for img in images])
+ elif input_data_format == ChannelDimension.LAST:
+ max_height, max_width, _ = max_across_indices([img.shape for img in images])
+ else:
+ raise ValueError(f"Invalid channel dimension format: {input_data_format}")
+ return (max_height, max_width)
+
+
+# Copied from transformers.models.detr.image_processing_detr.make_pixel_mask
+def make_pixel_mask(
+ image: np.ndarray, output_size: tuple[int, int], input_data_format: Optional[Union[str, ChannelDimension]] = None
+) -> np.ndarray:
+ """
+ Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding.
+
+ Args:
+ image (`np.ndarray`):
+ Image to make the pixel mask for.
+ output_size (`tuple[int, int]`):
+ Output size of the mask.
+ """
+ input_height, input_width = get_image_size(image, channel_dim=input_data_format)
+ mask = np.zeros(output_size, dtype=np.int64)
+ mask[:input_height, :input_width] = 1
+ return mask
+
+
+# Copied from transformers.models.detr.image_processing_detr.convert_coco_poly_to_mask
+def convert_coco_poly_to_mask(segmentations, height: int, width: int) -> np.ndarray:
+ """
+ Convert a COCO polygon annotation to a mask.
+
+ Args:
+ segmentations (`list[list[float]]`):
+ List of polygons, each polygon represented by a list of x-y coordinates.
+ height (`int`):
+ Height of the mask.
+ width (`int`):
+ Width of the mask.
+ """
+ try:
+ from pycocotools import mask as coco_mask
+ except ImportError:
+ raise ImportError("Pycocotools is not installed in your environment.")
+
+ masks = []
+ for polygons in segmentations:
+ rles = coco_mask.frPyObjects(polygons, height, width)
+ mask = coco_mask.decode(rles)
+ if len(mask.shape) < 3:
+ mask = mask[..., None]
+ mask = np.asarray(mask, dtype=np.uint8)
+ mask = np.any(mask, axis=2)
+ masks.append(mask)
+ if masks:
+ masks = np.stack(masks, axis=0)
+ else:
+ masks = np.zeros((0, height, width), dtype=np.uint8)
+
+ return masks
+
+
+# Copied from transformers.models.detr.image_processing_detr.prepare_coco_detection_annotation with DETR->DeformableDetr
+def prepare_coco_detection_annotation(
+ image,
+ target,
+ return_segmentation_masks: bool = False,
+ input_data_format: Optional[Union[ChannelDimension, str]] = None,
+):
+ """
+ Convert the target in COCO format into the format expected by DeformableDetr.
+ """
+ image_height, image_width = get_image_size(image, channel_dim=input_data_format)
+
+ image_id = target["image_id"]
+ image_id = np.asarray([image_id], dtype=np.int64)
+
+ # Get all COCO annotations for the given image.
+ annotations = target["annotations"]
+ annotations = [obj for obj in annotations if "iscrowd" not in obj or obj["iscrowd"] == 0]
+
+ classes = [obj["category_id"] for obj in annotations]
+ classes = np.asarray(classes, dtype=np.int64)
+
+ # for conversion to coco api
+ area = np.asarray([obj["area"] for obj in annotations], dtype=np.float32)
+ iscrowd = np.asarray([obj.get("iscrowd", 0) for obj in annotations], dtype=np.int64)
+
+ boxes = [obj["bbox"] for obj in annotations]
+ # guard against no boxes via resizing
+ boxes = np.asarray(boxes, dtype=np.float32).reshape(-1, 4)
+ boxes[:, 2:] += boxes[:, :2]
+ boxes[:, 0::2] = boxes[:, 0::2].clip(min=0, max=image_width)
+ boxes[:, 1::2] = boxes[:, 1::2].clip(min=0, max=image_height)
+
+ keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
+
+ new_target = {}
+ new_target["image_id"] = image_id
+ new_target["class_labels"] = classes[keep]
+ new_target["boxes"] = boxes[keep]
+ new_target["area"] = area[keep]
+ new_target["iscrowd"] = iscrowd[keep]
+ new_target["orig_size"] = np.asarray([int(image_height), int(image_width)], dtype=np.int64)
+
+ if annotations and "keypoints" in annotations[0]:
+ keypoints = [obj["keypoints"] for obj in annotations]
+ # Converting the filtered keypoints list to a numpy array
+ keypoints = np.asarray(keypoints, dtype=np.float32)
+ # Apply the keep mask here to filter the relevant annotations
+ keypoints = keypoints[keep]
+ num_keypoints = keypoints.shape[0]
+ keypoints = keypoints.reshape((-1, 3)) if num_keypoints else keypoints
+ new_target["keypoints"] = keypoints
+
+ if return_segmentation_masks:
+ segmentation_masks = [obj["segmentation"] for obj in annotations]
+ masks = convert_coco_poly_to_mask(segmentation_masks, image_height, image_width)
+ new_target["masks"] = masks[keep]
+
+ return new_target
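+
+
+# Illustrative sketch of the expected COCO detection target (values are made up):
+#
+# target = {
+#     "image_id": 42,
+#     "annotations": [
+#         {"bbox": [10.0, 20.0, 30.0, 40.0], "category_id": 1, "area": 1200.0, "iscrowd": 0},
+#     ],
+# }
+#
+# The `bbox` entries use the COCO (x, y, width, height) convention; they are converted above to
+# (x_min, y_min, x_max, y_max), clipped to the image, and returned under "boxes" together with
+# "class_labels", "area", "iscrowd" and "orig_size".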
+
+
+# Copied from transformers.models.detr.image_processing_detr.masks_to_boxes
+def masks_to_boxes(masks: np.ndarray) -> np.ndarray:
+ """
+ Compute the bounding boxes around the provided panoptic segmentation masks.
+
+ Args:
+ masks: masks in format `[number_masks, height, width]`, where `number_masks` is the number of masks
+
+ Returns:
+ boxes: bounding boxes in format `[number_masks, 4]` in xyxy format
+ """
+ if masks.size == 0:
+ return np.zeros((0, 4))
+
+ h, w = masks.shape[-2:]
+ y = np.arange(0, h, dtype=np.float32)
+ x = np.arange(0, w, dtype=np.float32)
+ # see https://github.com/pytorch/pytorch/issues/50276
+ y, x = np.meshgrid(y, x, indexing="ij")
+
+ x_mask = masks * np.expand_dims(x, axis=0)
+ x_max = x_mask.reshape(x_mask.shape[0], -1).max(-1)
+ x = np.ma.array(x_mask, mask=~(np.array(masks, dtype=bool)))
+ x_min = x.filled(fill_value=1e8)
+ x_min = x_min.reshape(x_min.shape[0], -1).min(-1)
+
+ y_mask = masks * np.expand_dims(y, axis=0)
+ y_max = y_mask.reshape(x_mask.shape[0], -1).max(-1)
+ y = np.ma.array(y_mask, mask=~(np.array(masks, dtype=bool)))
+ y_min = y.filled(fill_value=1e8)
+ y_min = y_min.reshape(y_min.shape[0], -1).min(-1)
+
+ return np.stack([x_min, y_min, x_max, y_max], 1)
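+
+
+# Worked example (illustrative): a single 5x6 mask with ones at rows 1-2 and columns 2-4
+# yields the box [2.0, 1.0, 4.0, 2.0], i.e. (x_min, y_min, x_max, y_max) in pixel-index coordinates.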
+
+
+# Copied from transformers.models.detr.image_processing_detr.prepare_coco_panoptic_annotation with DETR->DeformableDetr
+def prepare_coco_panoptic_annotation(
+ image: np.ndarray,
+ target: dict,
+ masks_path: Union[str, pathlib.Path],
+ return_masks: bool = True,
+ input_data_format: Union[ChannelDimension, str] = None,
+) -> dict:
+ """
+ Prepare a coco panoptic annotation for DeformableDetr.
+ """
+ image_height, image_width = get_image_size(image, channel_dim=input_data_format)
+ annotation_path = pathlib.Path(masks_path) / target["file_name"]
+
+ new_target = {}
+ new_target["image_id"] = np.asarray([target["image_id"] if "image_id" in target else target["id"]], dtype=np.int64)
+ new_target["size"] = np.asarray([image_height, image_width], dtype=np.int64)
+ new_target["orig_size"] = np.asarray([image_height, image_width], dtype=np.int64)
+
+ if "segments_info" in target:
+ masks = np.asarray(PIL.Image.open(annotation_path), dtype=np.uint32)
+ masks = rgb_to_id(masks)
+
+ ids = np.array([segment_info["id"] for segment_info in target["segments_info"]])
+ masks = masks == ids[:, None, None]
+ masks = masks.astype(np.uint8)
+ if return_masks:
+ new_target["masks"] = masks
+ new_target["boxes"] = masks_to_boxes(masks)
+ new_target["class_labels"] = np.array(
+ [segment_info["category_id"] for segment_info in target["segments_info"]], dtype=np.int64
+ )
+ new_target["iscrowd"] = np.asarray(
+ [segment_info["iscrowd"] for segment_info in target["segments_info"]], dtype=np.int64
+ )
+ new_target["area"] = np.asarray(
+ [segment_info["area"] for segment_info in target["segments_info"]], dtype=np.float32
+ )
+
+ return new_target
+
+
+# Copied from transformers.models.detr.image_processing_detr.get_segmentation_image
+def get_segmentation_image(
+ masks: np.ndarray, input_size: tuple, target_size: tuple, stuff_equiv_classes, deduplicate=False
+):
+ h, w = input_size
+ final_h, final_w = target_size
+
+ m_id = scipy.special.softmax(masks.transpose(0, 1), -1)
+
+ if m_id.shape[-1] == 0:
+ # We didn't detect any mask :(
+ m_id = np.zeros((h, w), dtype=np.int64)
+ else:
+ m_id = m_id.argmax(-1).reshape(h, w)
+
+ if deduplicate:
+ # Merge the masks corresponding to the same stuff class
+ for equiv in stuff_equiv_classes.values():
+ for eq_id in equiv:
+ m_id[m_id == eq_id] = equiv[0]
+
+ seg_img = id_to_rgb(m_id)
+ seg_img = resize(seg_img, (final_w, final_h), resample=PILImageResampling.NEAREST)
+ return seg_img
+
+
+# Copied from transformers.models.detr.image_processing_detr.get_mask_area
+def get_mask_area(seg_img: np.ndarray, target_size: tuple[int, int], n_classes: int) -> np.ndarray:
+ final_h, final_w = target_size
+ np_seg_img = seg_img.astype(np.uint8)
+ np_seg_img = np_seg_img.reshape(final_h, final_w, 3)
+ m_id = rgb_to_id(np_seg_img)
+ area = [(m_id == i).sum() for i in range(n_classes)]
+ return area
+
+
+# Copied from transformers.models.detr.image_processing_detr.score_labels_from_class_probabilities
+def score_labels_from_class_probabilities(logits: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
+ probs = scipy.special.softmax(logits, axis=-1)
+ labels = probs.argmax(-1, keepdims=True)
+ scores = np.take_along_axis(probs, labels, axis=-1)
+ scores, labels = scores.squeeze(-1), labels.squeeze(-1)
+ return scores, labels
+
+
+# Copied from transformers.models.detr.image_processing_detr.post_process_panoptic_sample
+def post_process_panoptic_sample(
+ out_logits: np.ndarray,
+ masks: np.ndarray,
+ boxes: np.ndarray,
+ processed_size: tuple[int, int],
+ target_size: tuple[int, int],
+ is_thing_map: dict,
+ threshold=0.85,
+) -> dict:
+ """
+ Converts the output of [`DetrForSegmentation`] into panoptic segmentation predictions for a single sample.
+
+ Args:
+ out_logits (`torch.Tensor`):
+ The logits for this sample.
+ masks (`torch.Tensor`):
+ The predicted segmentation masks for this sample.
+ boxes (`torch.Tensor`):
+ The predicted bounding boxes for this sample. The boxes are in the normalized format `(center_x, center_y,
+ width, height)` and values between `[0, 1]`, relative to the size of the image (disregarding padding).
+ processed_size (`tuple[int, int]`):
+ The processed size of the image `(height, width)`, as returned by the preprocessing step i.e. the size
+ after data augmentation but before batching.
+ target_size (`tuple[int, int]`):
+ The target size of the image, `(height, width)` corresponding to the requested final size of the
+ prediction.
+ is_thing_map (`Dict`):
+ A dictionary mapping class indices to a boolean value indicating whether the class is a thing or not.
+ threshold (`float`, *optional*, defaults to 0.85):
+ The threshold used to binarize the segmentation masks.
+ """
+ # we filter empty queries and detections below the threshold
+ scores, labels = score_labels_from_class_probabilities(out_logits)
+ keep = (labels != out_logits.shape[-1] - 1) & (scores > threshold)
+
+ cur_scores = scores[keep]
+ cur_classes = labels[keep]
+ cur_boxes = center_to_corners_format(boxes[keep])
+
+ if len(cur_boxes) != len(cur_classes):
+ raise ValueError("Not as many boxes as there are classes")
+
+ cur_masks = masks[keep]
+ cur_masks = resize(cur_masks[:, None], processed_size, resample=PILImageResampling.BILINEAR)
+ cur_masks = safe_squeeze(cur_masks, 1)
+ b, h, w = cur_masks.shape
+
+ # It may be that we have several predicted masks for the same stuff class.
+ # In the following, we track the list of masks ids for each stuff class (they are merged later on)
+ cur_masks = cur_masks.reshape(b, -1)
+ stuff_equiv_classes = defaultdict(list)
+ for k, label in enumerate(cur_classes):
+ if not is_thing_map[label]:
+ stuff_equiv_classes[label].append(k)
+
+ seg_img = get_segmentation_image(cur_masks, processed_size, target_size, stuff_equiv_classes, deduplicate=True)
+ area = get_mask_area(cur_masks, processed_size, n_classes=len(cur_scores))
+
+ # We filter out any mask that is too small
+ if cur_classes.size > 0:
+ # We now filter empty masks as long as we find some
+ filtered_small = np.array([a <= 4 for a in area], dtype=bool)
+ while filtered_small.any():
+ cur_masks = cur_masks[~filtered_small]
+ cur_scores = cur_scores[~filtered_small]
+ cur_classes = cur_classes[~filtered_small]
+ seg_img = get_segmentation_image(cur_masks, (h, w), target_size, stuff_equiv_classes, deduplicate=True)
+ area = get_mask_area(seg_img, target_size, n_classes=len(cur_scores))
+ filtered_small = np.array([a <= 4 for a in area], dtype=bool)
+ else:
+ cur_classes = np.ones((1, 1), dtype=np.int64)
+
+ segments_info = [
+ {"id": i, "isthing": is_thing_map[cat], "category_id": int(cat), "area": a}
+ for i, (cat, a) in enumerate(zip(cur_classes, area))
+ ]
+ del cur_classes
+
+ with io.BytesIO() as out:
+ PIL.Image.fromarray(seg_img).save(out, format="PNG")
+ predictions = {"png_string": out.getvalue(), "segments_info": segments_info}
+
+ return predictions
+
+
+# Copied from transformers.models.detr.image_processing_detr.resize_annotation
+def resize_annotation(
+ annotation: dict[str, Any],
+ orig_size: tuple[int, int],
+ target_size: tuple[int, int],
+ threshold: float = 0.5,
+ resample: PILImageResampling = PILImageResampling.NEAREST,
+):
+ """
+ Resizes an annotation to a target size.
+
+ Args:
+ annotation (`dict[str, Any]`):
+ The annotation dictionary.
+ orig_size (`tuple[int, int]`):
+ The original size of the input image.
+ target_size (`tuple[int, int]`):
+ The target size of the image, as returned by the preprocessing `resize` step.
+ threshold (`float`, *optional*, defaults to 0.5):
+ The threshold used to binarize the segmentation masks.
+ resample (`PILImageResampling`, defaults to `PILImageResampling.NEAREST`):
+ The resampling filter to use when resizing the masks.
+ """
+ ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(target_size, orig_size))
+ ratio_height, ratio_width = ratios
+
+ new_annotation = {}
+ new_annotation["size"] = target_size
+
+ for key, value in annotation.items():
+ if key == "boxes":
+ boxes = value
+ scaled_boxes = boxes * np.asarray([ratio_width, ratio_height, ratio_width, ratio_height], dtype=np.float32)
+ new_annotation["boxes"] = scaled_boxes
+ elif key == "area":
+ area = value
+ scaled_area = area * (ratio_width * ratio_height)
+ new_annotation["area"] = scaled_area
+ elif key == "masks":
+ masks = value[:, None]
+ masks = np.array([resize(mask, target_size, resample=resample) for mask in masks])
+ masks = masks.astype(np.float32)
+ masks = masks[:, 0] > threshold
+ new_annotation["masks"] = masks
+ elif key == "size":
+ new_annotation["size"] = target_size
+ else:
+ new_annotation[key] = value
+
+ return new_annotation
+
+
+# Copied from transformers.models.detr.image_processing_detr.binary_mask_to_rle
+def binary_mask_to_rle(mask):
+ """
+ Converts given binary mask of shape `(height, width)` to the run-length encoding (RLE) format.
+
+ Args:
+ mask (`torch.Tensor` or `numpy.array`):
+ A binary mask tensor of shape `(height, width)` where 0 denotes background and 1 denotes the target
+ segment_id or class_id.
+ Returns:
+ `List`: Run-length encoded list of the binary mask. Refer to COCO API for more information about the RLE
+ format.
+ """
+ if is_torch_tensor(mask):
+ mask = mask.numpy()
+
+ pixels = mask.flatten()
+ pixels = np.concatenate([[0], pixels, [0]])
+ runs = np.where(pixels[1:] != pixels[:-1])[0] + 1
+ runs[1::2] -= runs[::2]
+ return list(runs)
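+
+
+# Worked example (illustrative): the mask [[0, 1, 1], [0, 1, 0]] flattens to
+# [0, 1, 1, 0, 1, 0]; the runs of ones start at (1-indexed) pixels 2 and 5 with lengths 2
+# and 1, so the returned run-length encoding is [2, 2, 5, 1].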
+
+
+# Copied from transformers.models.detr.image_processing_detr.convert_segmentation_to_rle
+def convert_segmentation_to_rle(segmentation):
+ """
+ Converts given segmentation map of shape `(height, width)` to the run-length encoding (RLE) format.
+
+ Args:
+ segmentation (`torch.Tensor` or `numpy.array`):
+ A segmentation map of shape `(height, width)` where each value denotes a segment or class id.
+ Returns:
+ `list[List]`: A list of lists, where each list is the run-length encoding of a segment / class id.
+ """
+ segment_ids = torch.unique(segmentation)
+
+ run_length_encodings = []
+ for idx in segment_ids:
+ mask = torch.where(segmentation == idx, 1, 0)
+ rle = binary_mask_to_rle(mask)
+ run_length_encodings.append(rle)
+
+ return run_length_encodings
+
+
+# Copied from transformers.models.detr.image_processing_detr.remove_low_and_no_objects
+def remove_low_and_no_objects(masks, scores, labels, object_mask_threshold, num_labels):
+ """
+ Binarize the given masks using `object_mask_threshold` and return the associated values of `masks`, `scores` and
+ `labels`.
+
+ Args:
+ masks (`torch.Tensor`):
+ A tensor of shape `(num_queries, height, width)`.
+ scores (`torch.Tensor`):
+ A tensor of shape `(num_queries)`.
+ labels (`torch.Tensor`):
+ A tensor of shape `(num_queries)`.
+        object_mask_threshold (`float`):
+            A number between 0 and 1 used to filter the masks by score.
+        num_labels (`int`):
+            The index of the null ("no object") class; queries predicted as this class are removed.
+ Raises:
+ `ValueError`: Raised when the first dimension doesn't match in all input tensors.
+ Returns:
+        `tuple[torch.Tensor, torch.Tensor, torch.Tensor]`: The `masks`, `scores` and `labels` for the queries with a
+        score above `object_mask_threshold` and a label different from `num_labels`.
+ """
+ if not (masks.shape[0] == scores.shape[0] == labels.shape[0]):
+        raise ValueError("masks, scores and labels must have the same first dimension!")
+
+ to_keep = labels.ne(num_labels) & (scores > object_mask_threshold)
+
+ return masks[to_keep], scores[to_keep], labels[to_keep]
+
+
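+# Illustrative call (assumed tensor names, not part of the upstream code): keep only the
+# queries that are confident enough and not predicted as the null class `num_labels`.
+#
+#     masks, scores, labels = remove_low_and_no_objects(
+#         masks, scores, labels, object_mask_threshold=0.85, num_labels=num_classes
+#     )
+
+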
+# Copied from transformers.models.detr.image_processing_detr.check_segment_validity
+def check_segment_validity(mask_labels, mask_probs, k, mask_threshold=0.5, overlap_mask_area_threshold=0.8):
+ # Get the mask associated with the k class
+ mask_k = mask_labels == k
+ mask_k_area = mask_k.sum()
+
+ # Compute the area of all the stuff in query k
+ original_area = (mask_probs[k] >= mask_threshold).sum()
+ mask_exists = mask_k_area > 0 and original_area > 0
+
+ # Eliminate disconnected tiny segments
+ if mask_exists:
+ area_ratio = mask_k_area / original_area
+        if area_ratio.item() <= overlap_mask_area_threshold:
+ mask_exists = False
+
+ return mask_exists, mask_k
+
+
+# Copied from transformers.models.detr.image_processing_detr.compute_segments
+def compute_segments(
+ mask_probs,
+ pred_scores,
+ pred_labels,
+ mask_threshold: float = 0.5,
+ overlap_mask_area_threshold: float = 0.8,
+ label_ids_to_fuse: Optional[set[int]] = None,
+ target_size: Optional[tuple[int, int]] = None,
+):
+ height = mask_probs.shape[1] if target_size is None else target_size[0]
+ width = mask_probs.shape[2] if target_size is None else target_size[1]
+
+ segmentation = torch.zeros((height, width), dtype=torch.int32, device=mask_probs.device)
+ segments: list[dict] = []
+
+ if target_size is not None:
+ mask_probs = nn.functional.interpolate(
+ mask_probs.unsqueeze(0), size=target_size, mode="bilinear", align_corners=False
+ )[0]
+
+ current_segment_id = 0
+
+ # Weigh each mask by its prediction score
+ mask_probs *= pred_scores.view(-1, 1, 1)
+ mask_labels = mask_probs.argmax(0) # [height, width]
+
+ # Keep track of instances of each class
+ stuff_memory_list: dict[str, int] = {}
+ for k in range(pred_labels.shape[0]):
+ pred_class = pred_labels[k].item()
+ should_fuse = pred_class in label_ids_to_fuse
+
+ # Check if mask exists and large enough to be a segment
+ mask_exists, mask_k = check_segment_validity(
+ mask_labels, mask_probs, k, mask_threshold, overlap_mask_area_threshold
+ )
+
+ if mask_exists:
+ if pred_class in stuff_memory_list:
+ current_segment_id = stuff_memory_list[pred_class]
+ else:
+ current_segment_id += 1
+
+ # Add current object segment to final segmentation map
+ segmentation[mask_k] = current_segment_id
+ segment_score = round(pred_scores[k].item(), 6)
+ segments.append(
+ {
+ "id": current_segment_id,
+ "label_id": pred_class,
+ "was_fused": should_fuse,
+ "score": segment_score,
+ }
+ )
+ if should_fuse:
+ stuff_memory_list[pred_class] = current_segment_id
+
+ return segmentation, segments
+
+
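+# Illustrative output note (not part of the upstream code): `compute_segments` returns an integer
+# `segmentation` map of shape (height, width) whose values are segment ids (0 meaning "no segment"),
+# together with a list of dicts such as {"id": 1, "label_id": 17, "was_fused": False, "score": 0.998}.
+
+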
+@requires(backends=("torch", "vision"))
+class DeformableDetrImageProcessor(BaseImageProcessor):
+ r"""
+ Constructs a Deformable DETR image processor.
+
+ Args:
+ format (`str`, *optional*, defaults to `"coco_detection"`):
+ Data format of the annotations. One of "coco_detection" or "coco_panoptic".
+ do_resize (`bool`, *optional*, defaults to `True`):
+ Controls whether to resize the image's (height, width) dimensions to the specified `size`. Can be
+ overridden by the `do_resize` parameter in the `preprocess` method.
+        size (`dict[str, int]`, *optional*, defaults to `{"shortest_edge": 800, "longest_edge": 1333}`):
+ Size of the image's `(height, width)` dimensions after resizing. Can be overridden by the `size` parameter
+ in the `preprocess` method. Available options are:
+ - `{"height": int, "width": int}`: The image will be resized to the exact size `(height, width)`.
+ Do NOT keep the aspect ratio.
+ - `{"shortest_edge": int, "longest_edge": int}`: The image will be resized to a maximum size respecting
+ the aspect ratio and keeping the shortest edge less or equal to `shortest_edge` and the longest edge
+ less or equal to `longest_edge`.
+ - `{"max_height": int, "max_width": int}`: The image will be resized to the maximum size respecting the
+ aspect ratio and keeping the height less or equal to `max_height` and the width less or equal to
+ `max_width`.
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
+ Resampling filter to use if resizing the image.
+ do_rescale (`bool`, *optional*, defaults to `True`):
+ Controls whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the
+ `do_rescale` parameter in the `preprocess` method.
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
+ Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the
+ `preprocess` method.
+        do_normalize (`bool`, *optional*, defaults to `True`):
+ Controls whether to normalize the image. Can be overridden by the `do_normalize` parameter in the
+ `preprocess` method.
+ image_mean (`float` or `list[float]`, *optional*, defaults to `IMAGENET_DEFAULT_MEAN`):
+ Mean values to use when normalizing the image. Can be a single value or a list of values, one for each
+ channel. Can be overridden by the `image_mean` parameter in the `preprocess` method.
+ image_std (`float` or `list[float]`, *optional*, defaults to `IMAGENET_DEFAULT_STD`):
+ Standard deviation values to use when normalizing the image. Can be a single value or a list of values, one
+ for each channel. Can be overridden by the `image_std` parameter in the `preprocess` method.
+ do_convert_annotations (`bool`, *optional*, defaults to `True`):
+ Controls whether to convert the annotations to the format expected by the DETR model. Converts the
+ bounding boxes to the format `(center_x, center_y, width, height)` and in the range `[0, 1]`.
+ Can be overridden by the `do_convert_annotations` parameter in the `preprocess` method.
+ do_pad (`bool`, *optional*, defaults to `True`):
+ Controls whether to pad the image. Can be overridden by the `do_pad` parameter in the `preprocess`
+ method. If `True`, padding will be applied to the bottom and right of the image with zeros.
+ If `pad_size` is provided, the image will be padded to the specified dimensions.
+ Otherwise, the image will be padded to the maximum height and width of the batch.
+ pad_size (`dict[str, int]`, *optional*):
+            The size `{"height": int, "width": int}` to pad the images to. Must be larger than any image size
+ provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest
+ height and width in the batch.
+ """
+
+ model_input_names = ["pixel_values", "pixel_mask"]
+
+ # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.__init__
+ def __init__(
+ self,
+ format: Union[str, AnnotationFormat] = AnnotationFormat.COCO_DETECTION,
+ do_resize: bool = True,
+ size: Optional[dict[str, int]] = None,
+ resample: PILImageResampling = PILImageResampling.BILINEAR,
+ do_rescale: bool = True,
+ rescale_factor: Union[int, float] = 1 / 255,
+ do_normalize: bool = True,
+ image_mean: Optional[Union[float, list[float]]] = None,
+ image_std: Optional[Union[float, list[float]]] = None,
+ do_convert_annotations: Optional[bool] = None,
+ do_pad: bool = True,
+ pad_size: Optional[dict[str, int]] = None,
+ **kwargs,
+ ) -> None:
+ if "pad_and_return_pixel_mask" in kwargs:
+ do_pad = kwargs.pop("pad_and_return_pixel_mask")
+
+ if "max_size" in kwargs:
+ logger.warning_once(
+ "The `max_size` parameter is deprecated and will be removed in v4.26. "
+ "Please specify in `size['longest_edge'] instead`.",
+ )
+ max_size = kwargs.pop("max_size")
+ else:
+ max_size = None if size is None else 1333
+
+ size = size if size is not None else {"shortest_edge": 800, "longest_edge": 1333}
+ size = get_size_dict(size, max_size=max_size, default_to_square=False)
+
+ # Backwards compatibility
+ if do_convert_annotations is None:
+ do_convert_annotations = do_normalize
+
+ super().__init__(**kwargs)
+ self.format = format
+ self.do_resize = do_resize
+ self.size = size
+ self.resample = resample
+ self.do_rescale = do_rescale
+ self.rescale_factor = rescale_factor
+ self.do_normalize = do_normalize
+ self.do_convert_annotations = do_convert_annotations
+ self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
+ self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
+ self.do_pad = do_pad
+ self.pad_size = pad_size
+ self._valid_processor_keys = [
+ "images",
+ "annotations",
+ "return_segmentation_masks",
+ "masks_path",
+ "do_resize",
+ "size",
+ "resample",
+ "do_rescale",
+ "rescale_factor",
+ "do_normalize",
+ "do_convert_annotations",
+ "image_mean",
+ "image_std",
+ "do_pad",
+ "pad_size",
+ "format",
+ "return_tensors",
+ "data_format",
+ "input_data_format",
+ ]
+
+ @classmethod
+ # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.from_dict with Detr->DeformableDetr
+ def from_dict(cls, image_processor_dict: dict[str, Any], **kwargs):
+ """
+ Overrides the `from_dict` method from the base class to make sure parameters are updated if image processor is
+ created using from_dict and kwargs e.g. `DeformableDetrImageProcessor.from_pretrained(checkpoint, size=600,
+ max_size=800)`
+ """
+ image_processor_dict = image_processor_dict.copy()
+ if "max_size" in kwargs:
+ image_processor_dict["max_size"] = kwargs.pop("max_size")
+ if "pad_and_return_pixel_mask" in kwargs:
+ image_processor_dict["pad_and_return_pixel_mask"] = kwargs.pop("pad_and_return_pixel_mask")
+ return super().from_dict(image_processor_dict, **kwargs)
+
+ # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.prepare_annotation with DETR->DeformableDetr
+ def prepare_annotation(
+ self,
+ image: np.ndarray,
+ target: dict,
+ format: Optional[AnnotationFormat] = None,
+ return_segmentation_masks: Optional[bool] = None,
+ masks_path: Optional[Union[str, pathlib.Path]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ ) -> dict:
+ """
+ Prepare an annotation for feeding into DeformableDetr model.
+ """
+ format = format if format is not None else self.format
+
+ if format == AnnotationFormat.COCO_DETECTION:
+ return_segmentation_masks = False if return_segmentation_masks is None else return_segmentation_masks
+ target = prepare_coco_detection_annotation(
+ image, target, return_segmentation_masks, input_data_format=input_data_format
+ )
+ elif format == AnnotationFormat.COCO_PANOPTIC:
+ return_segmentation_masks = True if return_segmentation_masks is None else return_segmentation_masks
+ target = prepare_coco_panoptic_annotation(
+ image,
+ target,
+ masks_path=masks_path,
+ return_masks=return_segmentation_masks,
+ input_data_format=input_data_format,
+ )
+ else:
+ raise ValueError(f"Format {format} is not supported.")
+ return target
+
+ # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.resize
+ def resize(
+ self,
+ image: np.ndarray,
+ size: dict[str, int],
+ resample: PILImageResampling = PILImageResampling.BILINEAR,
+ data_format: Optional[ChannelDimension] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs,
+ ) -> np.ndarray:
+ """
+        Resize the image according to the given `size` dictionary. Depending on the keys provided, the image is
+        either resized to an exact `(height, width)` or resized while preserving its aspect ratio (see `size` below).
+
+ Args:
+ image (`np.ndarray`):
+ Image to resize.
+ size (`dict[str, int]`):
+ Size of the image's `(height, width)` dimensions after resizing. Available options are:
+ - `{"height": int, "width": int}`: The image will be resized to the exact size `(height, width)`.
+ Do NOT keep the aspect ratio.
+ - `{"shortest_edge": int, "longest_edge": int}`: The image will be resized to a maximum size respecting
+ the aspect ratio and keeping the shortest edge less or equal to `shortest_edge` and the longest edge
+ less or equal to `longest_edge`.
+ - `{"max_height": int, "max_width": int}`: The image will be resized to the maximum size respecting the
+ aspect ratio and keeping the height less or equal to `max_height` and the width less or equal to
+ `max_width`.
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
+ Resampling filter to use if resizing the image.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format for the output image. If unset, the channel dimension format of the input
+ image is used.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format of the input image. If not provided, it will be inferred.
+ """
+ if "max_size" in kwargs:
+ logger.warning_once(
+ "The `max_size` parameter is deprecated and will be removed in v4.26. "
+ "Please specify in `size['longest_edge'] instead`.",
+ )
+ max_size = kwargs.pop("max_size")
+ else:
+ max_size = None
+ size = get_size_dict(size, max_size=max_size, default_to_square=False)
+ if "shortest_edge" in size and "longest_edge" in size:
+ new_size = get_resize_output_image_size(
+ image, size["shortest_edge"], size["longest_edge"], input_data_format=input_data_format
+ )
+ elif "max_height" in size and "max_width" in size:
+ new_size = get_image_size_for_max_height_width(
+ image, size["max_height"], size["max_width"], input_data_format=input_data_format
+ )
+ elif "height" in size and "width" in size:
+ new_size = (size["height"], size["width"])
+ else:
+ raise ValueError(
+ "Size must contain 'height' and 'width' keys or 'shortest_edge' and 'longest_edge' keys. Got"
+ f" {size.keys()}."
+ )
+ image = resize(
+ image,
+ size=new_size,
+ resample=resample,
+ data_format=data_format,
+ input_data_format=input_data_format,
+ **kwargs,
+ )
+ return image
+
+ # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.resize_annotation
+ def resize_annotation(
+ self,
+ annotation,
+ orig_size,
+ size,
+ resample: PILImageResampling = PILImageResampling.NEAREST,
+ ) -> dict:
+ """
+        Resize the annotation to match the resized image. `size` is the `(height, width)` of the image after the
+        preprocessing `resize` step.
+ """
+ return resize_annotation(annotation, orig_size=orig_size, target_size=size, resample=resample)
+
+ # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.rescale
+ def rescale(
+ self,
+ image: np.ndarray,
+ rescale_factor: float,
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ ) -> np.ndarray:
+ """
+ Rescale the image by the given factor. image = image * rescale_factor.
+
+ Args:
+ image (`np.ndarray`):
+ Image to rescale.
+ rescale_factor (`float`):
+ The value to use for rescaling.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format for the output image. If unset, the channel dimension format of the input
+ image is used. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ input_data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format for the input image. If unset, is inferred from the input image. Can be
+ one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ """
+ return rescale(image, rescale_factor, data_format=data_format, input_data_format=input_data_format)
+
+ # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.normalize_annotation
+ def normalize_annotation(self, annotation: dict, image_size: tuple[int, int]) -> dict:
+ """
+ Normalize the boxes in the annotation from `[top_left_x, top_left_y, bottom_right_x, bottom_right_y]` to
+ `[center_x, center_y, width, height]` format and from absolute to relative pixel values.
+ """
+ return normalize_annotation(annotation, image_size=image_size)
+
+ # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor._update_annotation_for_padded_image
+ def _update_annotation_for_padded_image(
+ self,
+ annotation: dict,
+ input_image_size: tuple[int, int],
+ output_image_size: tuple[int, int],
+ padding,
+ update_bboxes,
+ ) -> dict:
+ """
+ Update the annotation for a padded image.
+ """
+ new_annotation = {}
+ new_annotation["size"] = output_image_size
+
+ for key, value in annotation.items():
+ if key == "masks":
+ masks = value
+ masks = pad(
+ masks,
+ padding,
+ mode=PaddingMode.CONSTANT,
+ constant_values=0,
+ input_data_format=ChannelDimension.FIRST,
+ )
+ masks = safe_squeeze(masks, 1)
+ new_annotation["masks"] = masks
+ elif key == "boxes" and update_bboxes:
+ boxes = value
+ boxes *= np.asarray(
+ [
+ input_image_size[1] / output_image_size[1],
+ input_image_size[0] / output_image_size[0],
+ input_image_size[1] / output_image_size[1],
+ input_image_size[0] / output_image_size[0],
+ ]
+ )
+ new_annotation["boxes"] = boxes
+ elif key == "size":
+ new_annotation["size"] = output_image_size
+ else:
+ new_annotation[key] = value
+ return new_annotation
+
+ # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor._pad_image
+ def _pad_image(
+ self,
+ image: np.ndarray,
+ output_size: tuple[int, int],
+ annotation: Optional[dict[str, Any]] = None,
+ constant_values: Union[float, Iterable[float]] = 0,
+ data_format: Optional[ChannelDimension] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ update_bboxes: bool = True,
+ ) -> np.ndarray:
+ """
+ Pad an image with zeros to the given size.
+ """
+ input_height, input_width = get_image_size(image, channel_dim=input_data_format)
+ output_height, output_width = output_size
+
+ pad_bottom = output_height - input_height
+ pad_right = output_width - input_width
+ padding = ((0, pad_bottom), (0, pad_right))
+ padded_image = pad(
+ image,
+ padding,
+ mode=PaddingMode.CONSTANT,
+ constant_values=constant_values,
+ data_format=data_format,
+ input_data_format=input_data_format,
+ )
+ if annotation is not None:
+ annotation = self._update_annotation_for_padded_image(
+ annotation, (input_height, input_width), (output_height, output_width), padding, update_bboxes
+ )
+ return padded_image, annotation
+
+ # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.pad
+ def pad(
+ self,
+ images: list[np.ndarray],
+ annotations: Optional[Union[AnnotationType, list[AnnotationType]]] = None,
+ constant_values: Union[float, Iterable[float]] = 0,
+ return_pixel_mask: bool = True,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ data_format: Optional[ChannelDimension] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ update_bboxes: bool = True,
+ pad_size: Optional[dict[str, int]] = None,
+ ) -> BatchFeature:
+ """
+        Pads a batch of images at the bottom and right with zeros, up to the largest height and width in the batch
+        (or to `pad_size` if provided), and optionally returns the corresponding pixel masks.
+
+ Args:
+ images (list[`np.ndarray`]):
+ Images to pad.
+ annotations (`AnnotationType` or `list[AnnotationType]`, *optional*):
+ Annotations to transform according to the padding that is applied to the images.
+ constant_values (`float` or `Iterable[float]`, *optional*):
+ The value to use for the padding if `mode` is `"constant"`.
+ return_pixel_mask (`bool`, *optional*, defaults to `True`):
+ Whether to return a pixel mask.
+ return_tensors (`str` or `TensorType`, *optional*):
+ The type of tensors to return. Can be one of:
+ - Unset: Return a list of `np.ndarray`.
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format of the input image. If not provided, it will be inferred.
+ update_bboxes (`bool`, *optional*, defaults to `True`):
+ Whether to update the bounding boxes in the annotations to match the padded images. If the
+                bounding boxes have not been converted to relative coordinates and `(center_x, center_y, width, height)`
+ format, the bounding boxes will not be updated.
+ pad_size (`dict[str, int]`, *optional*):
+                The size `{"height": int, "width": int}` to pad the images to. Must be larger than any image size
+ provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest
+ height and width in the batch.
+ """
+ pad_size = pad_size if pad_size is not None else self.pad_size
+ if pad_size is not None:
+ padded_size = (pad_size["height"], pad_size["width"])
+ else:
+ padded_size = get_max_height_width(images, input_data_format=input_data_format)
+
+ annotation_list = annotations if annotations is not None else [None] * len(images)
+ padded_images = []
+ padded_annotations = []
+ for image, annotation in zip(images, annotation_list):
+ padded_image, padded_annotation = self._pad_image(
+ image,
+ padded_size,
+ annotation,
+ constant_values=constant_values,
+ data_format=data_format,
+ input_data_format=input_data_format,
+ update_bboxes=update_bboxes,
+ )
+ padded_images.append(padded_image)
+ padded_annotations.append(padded_annotation)
+
+ data = {"pixel_values": padded_images}
+
+ if return_pixel_mask:
+ masks = [
+ make_pixel_mask(image=image, output_size=padded_size, input_data_format=input_data_format)
+ for image in images
+ ]
+ data["pixel_mask"] = masks
+
+ encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
+
+ if annotations is not None:
+ encoded_inputs["labels"] = [
+ BatchFeature(annotation, tensor_type=return_tensors) for annotation in padded_annotations
+ ]
+
+ return encoded_inputs
+
+ # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.preprocess
+ def preprocess(
+ self,
+ images: ImageInput,
+ annotations: Optional[Union[AnnotationType, list[AnnotationType]]] = None,
+ return_segmentation_masks: Optional[bool] = None,
+ masks_path: Optional[Union[str, pathlib.Path]] = None,
+ do_resize: Optional[bool] = None,
+ size: Optional[dict[str, int]] = None,
+ resample=None, # PILImageResampling
+ do_rescale: Optional[bool] = None,
+ rescale_factor: Optional[Union[int, float]] = None,
+ do_normalize: Optional[bool] = None,
+ do_convert_annotations: Optional[bool] = None,
+ image_mean: Optional[Union[float, list[float]]] = None,
+ image_std: Optional[Union[float, list[float]]] = None,
+ do_pad: Optional[bool] = None,
+ format: Optional[Union[str, AnnotationFormat]] = None,
+ return_tensors: Optional[Union[TensorType, str]] = None,
+ data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ pad_size: Optional[dict[str, int]] = None,
+ **kwargs,
+ ) -> BatchFeature:
+ """
+ Preprocess an image or a batch of images so that it can be used by the model.
+
+ Args:
+ images (`ImageInput`):
+ Image or batch of images to preprocess. Expects a single or batch of images with pixel values ranging
+ from 0 to 255. If passing in images with pixel values between 0 and 1, set `do_rescale=False`.
+ annotations (`AnnotationType` or `list[AnnotationType]`, *optional*):
+ List of annotations associated with the image or batch of images. If annotation is for object
+ detection, the annotations should be a dictionary with the following keys:
+ - "image_id" (`int`): The image id.
+ - "annotations" (`list[Dict]`): List of annotations for an image. Each annotation should be a
+ dictionary. An image can have no annotations, in which case the list should be empty.
+ If annotation is for segmentation, the annotations should be a dictionary with the following keys:
+ - "image_id" (`int`): The image id.
+ - "segments_info" (`list[Dict]`): List of segments for an image. Each segment should be a dictionary.
+ An image can have no segments, in which case the list should be empty.
+ - "file_name" (`str`): The file name of the image.
+ return_segmentation_masks (`bool`, *optional*, defaults to self.return_segmentation_masks):
+ Whether to return segmentation masks.
+ masks_path (`str` or `pathlib.Path`, *optional*):
+ Path to the directory containing the segmentation masks.
+ do_resize (`bool`, *optional*, defaults to self.do_resize):
+ Whether to resize the image.
+ size (`dict[str, int]`, *optional*, defaults to self.size):
+ Size of the image's `(height, width)` dimensions after resizing. Available options are:
+ - `{"height": int, "width": int}`: The image will be resized to the exact size `(height, width)`.
+ Do NOT keep the aspect ratio.
+ - `{"shortest_edge": int, "longest_edge": int}`: The image will be resized to a maximum size respecting
+ the aspect ratio and keeping the shortest edge less or equal to `shortest_edge` and the longest edge
+ less or equal to `longest_edge`.
+ - `{"max_height": int, "max_width": int}`: The image will be resized to the maximum size respecting the
+ aspect ratio and keeping the height less or equal to `max_height` and the width less or equal to
+ `max_width`.
+ resample (`PILImageResampling`, *optional*, defaults to self.resample):
+ Resampling filter to use when resizing the image.
+ do_rescale (`bool`, *optional*, defaults to self.do_rescale):
+ Whether to rescale the image.
+ rescale_factor (`float`, *optional*, defaults to self.rescale_factor):
+ Rescale factor to use when rescaling the image.
+ do_normalize (`bool`, *optional*, defaults to self.do_normalize):
+ Whether to normalize the image.
+ do_convert_annotations (`bool`, *optional*, defaults to self.do_convert_annotations):
+ Whether to convert the annotations to the format expected by the model. Converts the bounding
+ boxes from the format `(top_left_x, top_left_y, width, height)` to `(center_x, center_y, width, height)`
+ and in relative coordinates.
+ image_mean (`float` or `list[float]`, *optional*, defaults to self.image_mean):
+ Mean to use when normalizing the image.
+ image_std (`float` or `list[float]`, *optional*, defaults to self.image_std):
+ Standard deviation to use when normalizing the image.
+ do_pad (`bool`, *optional*, defaults to self.do_pad):
+ Whether to pad the image. If `True`, padding will be applied to the bottom and right of
+ the image with zeros. If `pad_size` is provided, the image will be padded to the specified
+ dimensions. Otherwise, the image will be padded to the maximum height and width of the batch.
+ format (`str` or `AnnotationFormat`, *optional*, defaults to self.format):
+ Format of the annotations.
+ return_tensors (`str` or `TensorType`, *optional*, defaults to self.return_tensors):
+ Type of tensors to return. If `None`, will return the list of images.
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
+ The channel dimension format for the output image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - Unset: Use the channel dimension format of the input image.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
+ from the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+ pad_size (`dict[str, int]`, *optional*):
+                The size `{"height": int, "width": int}` to pad the images to. Must be larger than any image size
+ provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest
+ height and width in the batch.
+ """
+ if "pad_and_return_pixel_mask" in kwargs:
+ logger.warning_once(
+ "The `pad_and_return_pixel_mask` argument is deprecated and will be removed in a future version, "
+ "use `do_pad` instead."
+ )
+ do_pad = kwargs.pop("pad_and_return_pixel_mask")
+
+ if "max_size" in kwargs:
+ logger.warning_once(
+ "The `max_size` argument is deprecated and will be removed in a future version, use"
+ " `size['longest_edge']` instead."
+ )
+ size = kwargs.pop("max_size")
+
+ do_resize = self.do_resize if do_resize is None else do_resize
+ size = self.size if size is None else size
+ size = get_size_dict(size=size, default_to_square=False)
+ resample = self.resample if resample is None else resample
+ do_rescale = self.do_rescale if do_rescale is None else do_rescale
+ rescale_factor = self.rescale_factor if rescale_factor is None else rescale_factor
+ do_normalize = self.do_normalize if do_normalize is None else do_normalize
+ image_mean = self.image_mean if image_mean is None else image_mean
+ image_std = self.image_std if image_std is None else image_std
+ do_convert_annotations = (
+ self.do_convert_annotations if do_convert_annotations is None else do_convert_annotations
+ )
+ do_pad = self.do_pad if do_pad is None else do_pad
+ pad_size = self.pad_size if pad_size is None else pad_size
+ format = self.format if format is None else format
+
+ images = make_flat_list_of_images(images)
+
+ if not valid_images(images):
+ raise ValueError(
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
+ "torch.Tensor, tf.Tensor or jax.ndarray."
+ )
+ validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_processor_keys)
+
+        # `pad()` pads to the maximum (height, width) of the batch, so it does not need to be validated here.
+ validate_preprocess_arguments(
+ do_rescale=do_rescale,
+ rescale_factor=rescale_factor,
+ do_normalize=do_normalize,
+ image_mean=image_mean,
+ image_std=image_std,
+ do_resize=do_resize,
+ size=size,
+ resample=resample,
+ )
+
+ if annotations is not None and isinstance(annotations, dict):
+ annotations = [annotations]
+
+ if annotations is not None and len(images) != len(annotations):
+ raise ValueError(
+ f"The number of images ({len(images)}) and annotations ({len(annotations)}) do not match."
+ )
+
+ format = AnnotationFormat(format)
+ if annotations is not None:
+ validate_annotations(format, SUPPORTED_ANNOTATION_FORMATS, annotations)
+
+ if (
+ masks_path is not None
+ and format == AnnotationFormat.COCO_PANOPTIC
+ and not isinstance(masks_path, (pathlib.Path, str))
+ ):
+ raise ValueError(
+ "The path to the directory containing the mask PNG files should be provided as a"
+ f" `pathlib.Path` or string object, but is {type(masks_path)} instead."
+ )
+
+ # All transformations expect numpy arrays
+ images = [to_numpy_array(image) for image in images]
+
+ if do_rescale and is_scaled_image(images[0]):
+ logger.warning_once(
+ "It looks like you are trying to rescale already rescaled images. If the input"
+ " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
+ )
+
+ if input_data_format is None:
+ # We assume that all images have the same channel dimension format.
+ input_data_format = infer_channel_dimension_format(images[0])
+
+ # prepare (COCO annotations as a list of Dict -> DETR target as a single Dict per image)
+ if annotations is not None:
+ prepared_images = []
+ prepared_annotations = []
+ for image, target in zip(images, annotations):
+ target = self.prepare_annotation(
+ image,
+ target,
+ format,
+ return_segmentation_masks=return_segmentation_masks,
+ masks_path=masks_path,
+ input_data_format=input_data_format,
+ )
+ prepared_images.append(image)
+ prepared_annotations.append(target)
+ images = prepared_images
+ annotations = prepared_annotations
+ del prepared_images, prepared_annotations
+
+ # transformations
+ if do_resize:
+ if annotations is not None:
+ resized_images, resized_annotations = [], []
+ for image, target in zip(images, annotations):
+ orig_size = get_image_size(image, input_data_format)
+ resized_image = self.resize(
+ image, size=size, resample=resample, input_data_format=input_data_format
+ )
+ resized_annotation = self.resize_annotation(
+ target, orig_size, get_image_size(resized_image, input_data_format)
+ )
+ resized_images.append(resized_image)
+ resized_annotations.append(resized_annotation)
+ images = resized_images
+ annotations = resized_annotations
+ del resized_images, resized_annotations
+ else:
+ images = [
+ self.resize(image, size=size, resample=resample, input_data_format=input_data_format)
+ for image in images
+ ]
+
+ if do_rescale:
+ images = [self.rescale(image, rescale_factor, input_data_format=input_data_format) for image in images]
+
+ if do_normalize:
+ images = [
+ self.normalize(image, image_mean, image_std, input_data_format=input_data_format) for image in images
+ ]
+
+ if do_convert_annotations and annotations is not None:
+ annotations = [
+ self.normalize_annotation(annotation, get_image_size(image, input_data_format))
+ for annotation, image in zip(annotations, images)
+ ]
+
+ if do_pad:
+ # Pads images and returns their mask: {'pixel_values': ..., 'pixel_mask': ...}
+ encoded_inputs = self.pad(
+ images,
+ annotations=annotations,
+ return_pixel_mask=True,
+ data_format=data_format,
+ input_data_format=input_data_format,
+ update_bboxes=do_convert_annotations,
+ return_tensors=return_tensors,
+ pad_size=pad_size,
+ )
+ else:
+ images = [
+ to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
+ for image in images
+ ]
+ encoded_inputs = BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors)
+ if annotations is not None:
+ encoded_inputs["labels"] = [
+ BatchFeature(annotation, tensor_type=return_tensors) for annotation in annotations
+ ]
+
+ return encoded_inputs
+
+ # POSTPROCESSING METHODS - TODO: add support for other frameworks
+ def post_process(self, outputs, target_sizes):
+ """
+ Converts the raw output of [`DeformableDetrForObjectDetection`] into final bounding boxes in (top_left_x,
+ top_left_y, bottom_right_x, bottom_right_y) format. Only supports PyTorch.
+
+ Args:
+ outputs ([`DeformableDetrObjectDetectionOutput`]):
+ Raw outputs of the model.
+ target_sizes (`torch.Tensor` of shape `(batch_size, 2)`):
+ Tensor containing the size (height, width) of each image of the batch. For evaluation, this must be the
+ original image size (before any data augmentation). For visualization, this should be the image size
+ after data augment, but before padding.
+ Returns:
+ `list[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image
+ in the batch as predicted by the model.
+ """
+ logger.warning_once(
+ "`post_process` is deprecated and will be removed in v5 of Transformers, please use"
+ " `post_process_object_detection` instead, with `threshold=0.` for equivalent results.",
+ )
+
+ out_logits, out_bbox = outputs.logits, outputs.pred_boxes
+
+ if len(out_logits) != len(target_sizes):
+ raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the logits")
+ if target_sizes.shape[1] != 2:
+ raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch")
+
+ prob = out_logits.sigmoid()
+ topk_values, topk_indexes = torch.topk(prob.view(out_logits.shape[0], -1), 100, dim=1)
+ scores = topk_values
+ topk_boxes = torch.div(topk_indexes, out_logits.shape[2], rounding_mode="floor")
+ labels = topk_indexes % out_logits.shape[2]
+ boxes = center_to_corners_format(out_bbox)
+ boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1, 1, 4))
+
+ # and from relative [0, 1] to absolute [0, height] coordinates
+ img_h, img_w = target_sizes.unbind(1)
+ scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
+ boxes = boxes * scale_fct[:, None, :]
+
+ results = [{"scores": s, "labels": l, "boxes": b} for s, l, b in zip(scores, labels, boxes)]
+
+ return results
+
+ def post_process_object_detection(
+        self,
+        outputs,
+        threshold: float = 0.5,
+        target_sizes: Optional[Union[TensorType, list[tuple]]] = None,
+        top_k: int = 100,
+ ):
+ """
+ Converts the raw output of [`DeformableDetrForObjectDetection`] into final bounding boxes in (top_left_x,
+ top_left_y, bottom_right_x, bottom_right_y) format. Only supports PyTorch.
+
+ Args:
+            outputs ([`DeformableDetrObjectDetectionOutput`]):
+ Raw outputs of the model.
+ threshold (`float`, *optional*):
+ Score threshold to keep object detection predictions.
+ target_sizes (`torch.Tensor` or `list[tuple[int, int]]`, *optional*):
+ Tensor of shape `(batch_size, 2)` or list of tuples (`tuple[int, int]`) containing the target size
+ (height, width) of each image in the batch. If left to None, predictions will not be resized.
+ top_k (`int`, *optional*, defaults to 100):
+ Keep only top k bounding boxes before filtering by thresholding.
+
+ Returns:
+ `list[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image
+ in the batch as predicted by the model.
+ """
+ out_logits, out_bbox = outputs.logits, outputs.pred_boxes
+
+ if target_sizes is not None:
+ if len(out_logits) != len(target_sizes):
+ raise ValueError(
+ "Make sure that you pass in as many target sizes as the batch dimension of the logits"
+ )
+
+ prob = out_logits.sigmoid()
+ prob = prob.view(out_logits.shape[0], -1)
+ k_value = min(top_k, prob.size(1))
+ topk_values, topk_indexes = torch.topk(prob, k_value, dim=1)
+ scores = topk_values
+ topk_boxes = torch.div(topk_indexes, out_logits.shape[2], rounding_mode="floor")
+ labels = topk_indexes % out_logits.shape[2]
+ boxes = center_to_corners_format(out_bbox)
+ boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1, 1, 4))
+
+ # and from relative [0, 1] to absolute [0, height] coordinates
+ if target_sizes is not None:
+ if isinstance(target_sizes, list):
+ img_h = torch.Tensor([i[0] for i in target_sizes])
+ img_w = torch.Tensor([i[1] for i in target_sizes])
+ else:
+ img_h, img_w = target_sizes.unbind(1)
+ scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(boxes.device)
+ boxes = boxes * scale_fct[:, None, :]
+
+ results = []
+ for s, l, b in zip(scores, labels, boxes):
+ score = s[s > threshold]
+ label = l[s > threshold]
+ box = b[s > threshold]
+ results.append({"scores": score, "labels": label, "boxes": box})
+
+ return results
+
+
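+# Illustrative end-to-end sketch (assumed checkpoint and variable names, not part of this module):
+# the processor is typically paired with `DeformableDetrForObjectDetection` along these lines.
+#
+#     processor = DeformableDetrImageProcessor.from_pretrained("SenseTime/deformable-detr")
+#     inputs = processor(images=image, return_tensors="pt")
+#     outputs = model(**inputs)
+#     results = processor.post_process_object_detection(
+#         outputs, threshold=0.5, target_sizes=[(image.height, image.width)]
+#     )
+
+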
+__all__ = ["DeformableDetrImageProcessor"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/deformable_detr/image_processing_deformable_detr_fast.py b/venv/lib/python3.13/site-packages/transformers/models/deformable_detr/image_processing_deformable_detr_fast.py
new file mode 100644
index 0000000000000000000000000000000000000000..8458d02d58a52635f38361e6da05c22294c0a0f6
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/deformable_detr/image_processing_deformable_detr_fast.py
@@ -0,0 +1,795 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/deformable_detr/modular_deformable_detr.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_deformable_detr.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+import pathlib
+from typing import Any, Optional, Union
+
+import torch
+from torchvision.io import read_image
+from torchvision.transforms.v2 import functional as F
+
+from ...image_processing_utils import BatchFeature, get_size_dict
+from ...image_processing_utils_fast import (
+ BaseImageProcessorFast,
+ DefaultFastImageProcessorKwargs,
+ SizeDict,
+ get_image_size_for_max_height_width,
+ get_max_height_width,
+ safe_squeeze,
+)
+from ...image_transforms import center_to_corners_format, corners_to_center_format
+from ...image_utils import (
+ IMAGENET_DEFAULT_MEAN,
+ IMAGENET_DEFAULT_STD,
+ AnnotationFormat,
+ AnnotationType,
+ ChannelDimension,
+ ImageInput,
+ PILImageResampling,
+ get_image_size,
+ validate_annotations,
+)
+from ...processing_utils import Unpack
+from ...utils import TensorType, auto_docstring, logging
+from ...utils.import_utils import requires
+from .image_processing_deformable_detr import get_size_with_aspect_ratio
+
+
+logger = logging.get_logger(__name__)
+
+
+class DeformableDetrFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
+ r"""
+ format (`str`, *optional*, defaults to `AnnotationFormat.COCO_DETECTION`):
+ Data format of the annotations. One of "coco_detection" or "coco_panoptic".
+ do_convert_annotations (`bool`, *optional*, defaults to `True`):
+ Controls whether to convert the annotations to the format expected by the DEFORMABLE_DETR model. Converts the
+ bounding boxes to the format `(center_x, center_y, width, height)` and in the range `[0, 1]`.
+ Can be overridden by the `do_convert_annotations` parameter in the `preprocess` method.
+ return_segmentation_masks (`bool`, *optional*, defaults to `False`):
+ Whether to return segmentation masks.
+ """
+
+ format: Optional[Union[str, AnnotationFormat]]
+ do_convert_annotations: Optional[bool]
+ return_segmentation_masks: Optional[bool]
+
+
+SUPPORTED_ANNOTATION_FORMATS = (AnnotationFormat.COCO_DETECTION, AnnotationFormat.COCO_PANOPTIC)
+
+
+# inspired by https://github.com/facebookresearch/deformable_detr/blob/master/datasets/coco.py#L33
+def convert_coco_poly_to_mask(segmentations, height: int, width: int, device: torch.device) -> torch.Tensor:
+ """
+ Convert a COCO polygon annotation to a mask.
+
+ Args:
+ segmentations (`list[list[float]]`):
+ List of polygons, each polygon represented by a list of x-y coordinates.
+ height (`int`):
+ Height of the mask.
+ width (`int`):
+            Width of the mask.
+        device (`torch.device`):
+            Device on which to create the masks.
+ """
+ try:
+ from pycocotools import mask as coco_mask
+ except ImportError:
+ raise ImportError("Pycocotools is not installed in your environment.")
+
+ masks = []
+ for polygons in segmentations:
+ rles = coco_mask.frPyObjects(polygons, height, width)
+ mask = coco_mask.decode(rles)
+ if len(mask.shape) < 3:
+ mask = mask[..., None]
+ mask = torch.as_tensor(mask, dtype=torch.uint8, device=device)
+ mask = torch.any(mask, axis=2)
+ masks.append(mask)
+ if masks:
+ masks = torch.stack(masks, axis=0)
+ else:
+ masks = torch.zeros((0, height, width), dtype=torch.uint8, device=device)
+
+ return masks
+
+
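+# Illustrative note (not part of the upstream code): the helper above requires the optional
+# `pycocotools` dependency and returns one binary mask per input polygon list, stacked into a
+# tensor of shape (number_of_segmentations, height, width).
+
+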
+# inspired by https://github.com/facebookresearch/deformable_detr/blob/master/datasets/coco.py#L50
+def prepare_coco_detection_annotation(
+ image,
+ target,
+ return_segmentation_masks: bool = False,
+ input_data_format: Optional[Union[ChannelDimension, str]] = None,
+):
+ """
+ Convert the target in COCO format into the format expected by DEFORMABLE_DETR.
+ """
+ image_height, image_width = image.size()[-2:]
+
+ image_id = target["image_id"]
+ image_id = torch.as_tensor([image_id], dtype=torch.int64, device=image.device)
+
+ # Get all COCO annotations for the given image.
+ annotations = target["annotations"]
+ classes = []
+ area = []
+ boxes = []
+ keypoints = []
+ for obj in annotations:
+ if "iscrowd" not in obj or obj["iscrowd"] == 0:
+ classes.append(obj["category_id"])
+ area.append(obj["area"])
+ boxes.append(obj["bbox"])
+ if "keypoints" in obj:
+ keypoints.append(obj["keypoints"])
+
+ classes = torch.as_tensor(classes, dtype=torch.int64, device=image.device)
+ area = torch.as_tensor(area, dtype=torch.float32, device=image.device)
+ iscrowd = torch.zeros_like(classes, dtype=torch.int64, device=image.device)
+ # guard against no boxes via resizing
+ boxes = torch.as_tensor(boxes, dtype=torch.float32, device=image.device).reshape(-1, 4)
+ boxes[:, 2:] += boxes[:, :2]
+ boxes[:, 0::2] = boxes[:, 0::2].clip(min=0, max=image_width)
+ boxes[:, 1::2] = boxes[:, 1::2].clip(min=0, max=image_height)
+
+ keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
+
+ new_target = {
+ "image_id": image_id,
+ "class_labels": classes[keep],
+ "boxes": boxes[keep],
+ "area": area[keep],
+ "iscrowd": iscrowd[keep],
+ "orig_size": torch.as_tensor([int(image_height), int(image_width)], dtype=torch.int64, device=image.device),
+ }
+
+ if keypoints:
+ keypoints = torch.as_tensor(keypoints, dtype=torch.float32, device=image.device)
+ # Apply the keep mask here to filter the relevant annotations
+ keypoints = keypoints[keep]
+ num_keypoints = keypoints.shape[0]
+ keypoints = keypoints.reshape((-1, 3)) if num_keypoints else keypoints
+ new_target["keypoints"] = keypoints
+
+ if return_segmentation_masks:
+ segmentation_masks = [obj["segmentation"] for obj in annotations]
+ masks = convert_coco_poly_to_mask(segmentation_masks, image_height, image_width, device=image.device)
+ new_target["masks"] = masks[keep]
+
+ return new_target
+
+
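+# Illustrative input sketch (not part of the upstream code): the expected COCO-style target looks like
+# {"image_id": 42, "annotations": [{"bbox": [x, y, w, h], "category_id": 3, "area": 150.0, "iscrowd": 0}]};
+# boxes are converted from xywh to xyxy and clipped to the image size.
+
+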
+def masks_to_boxes(masks: torch.Tensor) -> torch.Tensor:
+ """
+ Compute the bounding boxes around the provided panoptic segmentation masks.
+
+ Args:
+        masks: masks in format `[number_masks, height, width]`, where `number_masks` is the number of masks
+
+ Returns:
+ boxes: bounding boxes in format `[number_masks, 4]` in xyxy format
+ """
+ if masks.numel() == 0:
+ return torch.zeros((0, 4), device=masks.device)
+
+ h, w = masks.shape[-2:]
+ y = torch.arange(0, h, dtype=torch.float32, device=masks.device)
+ x = torch.arange(0, w, dtype=torch.float32, device=masks.device)
+ # see https://github.com/pytorch/pytorch/issues/50276
+ y, x = torch.meshgrid(y, x, indexing="ij")
+
+ x_mask = masks * torch.unsqueeze(x, 0)
+ x_max = x_mask.view(x_mask.shape[0], -1).max(-1)[0]
+ x_min = (
+ torch.where(masks, x.unsqueeze(0), torch.tensor(1e8, device=masks.device)).view(masks.shape[0], -1).min(-1)[0]
+ )
+
+ y_mask = masks * torch.unsqueeze(y, 0)
+ y_max = y_mask.view(y_mask.shape[0], -1).max(-1)[0]
+ y_min = (
+ torch.where(masks, y.unsqueeze(0), torch.tensor(1e8, device=masks.device)).view(masks.shape[0], -1).min(-1)[0]
+ )
+
+ return torch.stack([x_min, y_min, x_max, y_max], 1)
+
+
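+# Illustrative example (not part of the upstream code): a mask whose foreground covers rows 1-2 and
+# columns 2-3 of a 4x4 grid yields the xyxy box [2., 1., 3., 2.] (corners are foreground pixel indices):
+#
+#     m = torch.zeros(1, 4, 4, dtype=torch.bool)
+#     m[0, 1:3, 2:4] = True
+#     masks_to_boxes(m)  # values: [[2., 1., 3., 2.]]
+
+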
+# 2 functions below adapted from https://github.com/cocodataset/panopticapi/blob/master/panopticapi/utils.py
+# Copyright (c) 2018, Alexander Kirillov
+# All rights reserved.
+def rgb_to_id(color):
+ """
+ Converts RGB color to unique ID.
+ """
+ if isinstance(color, torch.Tensor) and len(color.shape) == 3:
+ if color.dtype == torch.uint8:
+ color = color.to(torch.int32)
+ return color[:, :, 0] + 256 * color[:, :, 1] + 256 * 256 * color[:, :, 2]
+ return int(color[0] + 256 * color[1] + 256 * 256 * color[2])
+
+
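+# Illustrative mapping (not part of the upstream code): the id is little-endian in the RGB channels,
+# e.g. rgb_to_id([1, 0, 0]) == 1, rgb_to_id([0, 1, 0]) == 256 and rgb_to_id([0, 0, 1]) == 65536.
+
+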
+def prepare_coco_panoptic_annotation(
+ image: torch.Tensor,
+ target: dict,
+ masks_path: Union[str, pathlib.Path],
+ return_masks: bool = True,
+ input_data_format: Union[ChannelDimension, str] = None,
+) -> dict:
+ """
+ Prepare a coco panoptic annotation for DEFORMABLE_DETR.
+ """
+ image_height, image_width = get_image_size(image, channel_dim=input_data_format)
+ annotation_path = pathlib.Path(masks_path) / target["file_name"]
+
+ new_target = {}
+ new_target["image_id"] = torch.as_tensor(
+ [target["image_id"] if "image_id" in target else target["id"]], dtype=torch.int64, device=image.device
+ )
+ new_target["size"] = torch.as_tensor([image_height, image_width], dtype=torch.int64, device=image.device)
+ new_target["orig_size"] = torch.as_tensor([image_height, image_width], dtype=torch.int64, device=image.device)
+
+ if "segments_info" in target:
+ masks = read_image(annotation_path).permute(1, 2, 0).to(dtype=torch.int32, device=image.device)
+ masks = rgb_to_id(masks)
+
+ ids = torch.as_tensor([segment_info["id"] for segment_info in target["segments_info"]], device=image.device)
+ masks = masks == ids[:, None, None]
+ masks = masks.to(torch.bool)
+ if return_masks:
+ new_target["masks"] = masks
+ new_target["boxes"] = masks_to_boxes(masks)
+ new_target["class_labels"] = torch.as_tensor(
+ [segment_info["category_id"] for segment_info in target["segments_info"]],
+ dtype=torch.int64,
+ device=image.device,
+ )
+ new_target["iscrowd"] = torch.as_tensor(
+ [segment_info["iscrowd"] for segment_info in target["segments_info"]],
+ dtype=torch.int64,
+ device=image.device,
+ )
+ new_target["area"] = torch.as_tensor(
+ [segment_info["area"] for segment_info in target["segments_info"]],
+ dtype=torch.float32,
+ device=image.device,
+ )
+
+ return new_target
+
+
+@auto_docstring
+@requires(backends=("torchvision", "torch"))
+class DeformableDetrImageProcessorFast(BaseImageProcessorFast):
+ resample = PILImageResampling.BILINEAR
+ image_mean = IMAGENET_DEFAULT_MEAN
+ image_std = IMAGENET_DEFAULT_STD
+ format = AnnotationFormat.COCO_DETECTION
+ do_resize = True
+ do_rescale = True
+ do_normalize = True
+ do_pad = True
+ size = {"shortest_edge": 800, "longest_edge": 1333}
+ default_to_square = False
+ model_input_names = ["pixel_values", "pixel_mask"]
+ valid_kwargs = DeformableDetrFastImageProcessorKwargs
+
+ def __init__(self, **kwargs: Unpack[DeformableDetrFastImageProcessorKwargs]) -> None:
+ if "pad_and_return_pixel_mask" in kwargs:
+ kwargs["do_pad"] = kwargs.pop("pad_and_return_pixel_mask")
+
+ size = kwargs.pop("size", None)
+ if "max_size" in kwargs:
+ logger.warning_once(
+ "The `max_size` parameter is deprecated and will be removed in v4.26. "
+ "Please specify in `size['longest_edge'] instead`.",
+ )
+ max_size = kwargs.pop("max_size")
+ else:
+ max_size = None if size is None else 1333
+
+ size = size if size is not None else {"shortest_edge": 800, "longest_edge": 1333}
+ self.size = get_size_dict(size, max_size=max_size, default_to_square=False)
+
+ # Backwards compatibility
+ do_convert_annotations = kwargs.get("do_convert_annotations")
+ do_normalize = kwargs.get("do_normalize")
+ if do_convert_annotations is None and getattr(self, "do_convert_annotations", None) is None:
+ self.do_convert_annotations = do_normalize if do_normalize is not None else self.do_normalize
+
+ super().__init__(**kwargs)
+
+ @classmethod
+ def from_dict(cls, image_processor_dict: dict[str, Any], **kwargs):
+ """
+ Overrides the `from_dict` method from the base class to make sure parameters are updated if image processor is
+ created using from_dict and kwargs e.g. `DeformableDetrImageProcessorFast.from_pretrained(checkpoint, size=600,
+ max_size=800)`
+ """
+ image_processor_dict = image_processor_dict.copy()
+ if "max_size" in kwargs:
+ image_processor_dict["max_size"] = kwargs.pop("max_size")
+ if "pad_and_return_pixel_mask" in kwargs:
+ image_processor_dict["pad_and_return_pixel_mask"] = kwargs.pop("pad_and_return_pixel_mask")
+ return super().from_dict(image_processor_dict, **kwargs)
+
+ def prepare_annotation(
+ self,
+ image: torch.Tensor,
+ target: dict,
+ format: Optional[AnnotationFormat] = None,
+ return_segmentation_masks: Optional[bool] = None,
+ masks_path: Optional[Union[str, pathlib.Path]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ ) -> dict:
+ """
+ Prepare an annotation for feeding into DEFORMABLE_DETR model.
+ """
+ format = format if format is not None else self.format
+
+ if format == AnnotationFormat.COCO_DETECTION:
+ return_segmentation_masks = False if return_segmentation_masks is None else return_segmentation_masks
+ target = prepare_coco_detection_annotation(
+ image, target, return_segmentation_masks, input_data_format=input_data_format
+ )
+ elif format == AnnotationFormat.COCO_PANOPTIC:
+ return_segmentation_masks = True if return_segmentation_masks is None else return_segmentation_masks
+ target = prepare_coco_panoptic_annotation(
+ image,
+ target,
+ masks_path=masks_path,
+ return_masks=return_segmentation_masks,
+ input_data_format=input_data_format,
+ )
+ else:
+ raise ValueError(f"Format {format} is not supported.")
+ return target
+
+ def resize(
+ self,
+ image: torch.Tensor,
+ size: SizeDict,
+ interpolation: Optional["F.InterpolationMode"] = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ """
+        Resize the image according to the given `size` dictionary. Depending on the keys provided, the image is
+        either resized to an exact `(height, width)` or resized while preserving its aspect ratio (see `size` below).
+
+ Args:
+ image (`torch.Tensor`):
+ Image to resize.
+ size (`SizeDict`):
+ Size of the image's `(height, width)` dimensions after resizing. Available options are:
+ - `{"height": int, "width": int}`: The image will be resized to the exact size `(height, width)`.
+ Do NOT keep the aspect ratio.
+ - `{"shortest_edge": int, "longest_edge": int}`: The image will be resized to a maximum size respecting
+ the aspect ratio and keeping the shortest edge less or equal to `shortest_edge` and the longest edge
+ less or equal to `longest_edge`.
+ - `{"max_height": int, "max_width": int}`: The image will be resized to the maximum size respecting the
+ aspect ratio and keeping the height less or equal to `max_height` and the width less or equal to
+ `max_width`.
+ interpolation (`InterpolationMode`, *optional*, defaults to `InterpolationMode.BILINEAR`):
+ Resampling filter to use if resizing the image.
+ """
+ interpolation = interpolation if interpolation is not None else F.InterpolationMode.BILINEAR
+ if size.shortest_edge and size.longest_edge:
+ # Resize the image so that the shortest edge or the longest edge is of the given size
+ # while maintaining the aspect ratio of the original image.
+ new_size = get_size_with_aspect_ratio(
+ image.size()[-2:],
+ size["shortest_edge"],
+ size["longest_edge"],
+ )
+ elif size.max_height and size.max_width:
+ new_size = get_image_size_for_max_height_width(image.size()[-2:], size["max_height"], size["max_width"])
+ elif size.height and size.width:
+ new_size = (size["height"], size["width"])
+ else:
+ raise ValueError(
+ "Size must contain 'height' and 'width' keys or 'shortest_edge' and 'longest_edge' keys. Got"
+ f" {size.keys()}."
+ )
+
+ image = F.resize(
+ image,
+ size=new_size,
+ interpolation=interpolation,
+ **kwargs,
+ )
+ return image
+
+ def resize_annotation(
+ self,
+ annotation: dict[str, Any],
+ orig_size: tuple[int, int],
+ target_size: tuple[int, int],
+ threshold: float = 0.5,
+ interpolation: Optional["F.InterpolationMode"] = None,
+ ):
+ """
+ Resizes an annotation to a target size.
+
+ Args:
+ annotation (`dict[str, Any]`):
+ The annotation dictionary.
+ orig_size (`tuple[int, int]`):
+ The original size of the input image.
+ target_size (`tuple[int, int]`):
+ The target size of the image, as returned by the preprocessing `resize` step.
+ threshold (`float`, *optional*, defaults to 0.5):
+ The threshold used to binarize the segmentation masks.
+            interpolation (`InterpolationMode`, *optional*, defaults to `F.InterpolationMode.NEAREST_EXACT`):
+                The interpolation mode to use when resizing the masks.
+ """
+ interpolation = interpolation if interpolation is not None else F.InterpolationMode.NEAREST_EXACT
+ ratio_height, ratio_width = [target / orig for target, orig in zip(target_size, orig_size)]
+
+ new_annotation = {}
+ new_annotation["size"] = target_size
+
+ for key, value in annotation.items():
+ if key == "boxes":
+ boxes = value
+ scaled_boxes = boxes * torch.as_tensor(
+ [ratio_width, ratio_height, ratio_width, ratio_height], dtype=torch.float32, device=boxes.device
+ )
+ new_annotation["boxes"] = scaled_boxes
+ elif key == "area":
+ area = value
+ scaled_area = area * (ratio_width * ratio_height)
+ new_annotation["area"] = scaled_area
+ elif key == "masks":
+ masks = value[:, None]
+ masks = [F.resize(mask, target_size, interpolation=interpolation) for mask in masks]
+ masks = torch.stack(masks).to(torch.float32)
+ masks = masks[:, 0] > threshold
+ new_annotation["masks"] = masks
+ elif key == "size":
+ new_annotation["size"] = target_size
+ else:
+ new_annotation[key] = value
+
+ return new_annotation
+
+ def normalize_annotation(self, annotation: dict, image_size: tuple[int, int]) -> dict:
+ image_height, image_width = image_size
+ norm_annotation = {}
+ for key, value in annotation.items():
+ if key == "boxes":
+ boxes = value
+ boxes = corners_to_center_format(boxes)
+ boxes /= torch.as_tensor(
+ [image_width, image_height, image_width, image_height], dtype=torch.float32, device=boxes.device
+ )
+ norm_annotation[key] = boxes
+ else:
+ norm_annotation[key] = value
+ return norm_annotation
+
+ def _update_annotation_for_padded_image(
+ self,
+ annotation: dict,
+ input_image_size: tuple[int, int],
+ output_image_size: tuple[int, int],
+ padding,
+ update_bboxes,
+ ) -> dict:
+ """
+ Update the annotation for a padded image.
+ """
+ new_annotation = {}
+ new_annotation["size"] = output_image_size
+ ratio_height, ratio_width = (input / output for output, input in zip(output_image_size, input_image_size))
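+ # Padding only extends the canvas, so normalized (center_x, center_y, width, height) boxes are
+ # rescaled by input_size / output_size along each axis; the pixel content itself does not move.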
+
+ for key, value in annotation.items():
+ if key == "masks":
+ masks = value
+ masks = F.pad(
+ masks,
+ padding,
+ fill=0,
+ )
+ masks = safe_squeeze(masks, 1)
+ new_annotation["masks"] = masks
+ elif key == "boxes" and update_bboxes:
+ boxes = value
+ boxes *= torch.as_tensor([ratio_width, ratio_height, ratio_width, ratio_height], device=boxes.device)
+ new_annotation["boxes"] = boxes
+ elif key == "size":
+ new_annotation["size"] = output_image_size
+ else:
+ new_annotation[key] = value
+ return new_annotation
+
+ def pad(
+ self,
+ image: torch.Tensor,
+ padded_size: tuple[int, int],
+ annotation: Optional[dict[str, Any]] = None,
+ update_bboxes: bool = True,
+ fill: int = 0,
+ ):
+ original_size = image.size()[-2:]
+ padding_bottom = padded_size[0] - original_size[0]
+ padding_right = padded_size[1] - original_size[1]
+ if padding_bottom < 0 or padding_right < 0:
+ raise ValueError(
+ f"Padding dimensions are negative. Please make sure that the padded size is larger than the "
+ f"original size. Got padded size: {padded_size}, original size: {original_size}."
+ )
+ if original_size != padded_size:
+ padding = [0, 0, padding_right, padding_bottom]
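+ # torchvision's pad convention is [left, top, right, bottom], so the image is only extended on the
+ # bottom and right edges and the original content stays anchored at the top-left corner.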
+ image = F.pad(image, padding, fill=fill)
+ if annotation is not None:
+ annotation = self._update_annotation_for_padded_image(
+ annotation, original_size, padded_size, padding, update_bboxes
+ )
+
+ # Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding.
+ pixel_mask = torch.zeros(padded_size, dtype=torch.int64, device=image.device)
+ pixel_mask[: original_size[0], : original_size[1]] = 1
+
+ return image, pixel_mask, annotation
+
+ @auto_docstring
+ def preprocess(
+ self,
+ images: ImageInput,
+ annotations: Optional[Union[AnnotationType, list[AnnotationType]]] = None,
+ masks_path: Optional[Union[str, pathlib.Path]] = None,
+ **kwargs: Unpack[DeformableDetrFastImageProcessorKwargs],
+ ) -> BatchFeature:
+ r"""
+ annotations (`AnnotationType` or `list[AnnotationType]`, *optional*):
+ List of annotations associated with the image or batch of images. If annotation is for object
+ detection, the annotations should be a dictionary with the following keys:
+ - "image_id" (`int`): The image id.
+ - "annotations" (`list[Dict]`): List of annotations for an image. Each annotation should be a
+ dictionary. An image can have no annotations, in which case the list should be empty.
+ If annotation is for segmentation, the annotations should be a dictionary with the following keys:
+ - "image_id" (`int`): The image id.
+ - "segments_info" (`list[Dict]`): List of segments for an image. Each segment should be a dictionary.
+ An image can have no segments, in which case the list should be empty.
+ - "file_name" (`str`): The file name of the image.
+ masks_path (`str` or `pathlib.Path`, *optional*):
+ Path to the directory containing the segmentation masks.
+ """
+ if "pad_and_return_pixel_mask" in kwargs:
+ kwargs["do_pad"] = kwargs.pop("pad_and_return_pixel_mask")
+ logger.warning_once(
+ "The `pad_and_return_pixel_mask` argument is deprecated and will be removed in a future version, "
+ "use `do_pad` instead."
+ )
+
+ if "max_size" in kwargs:
+ logger.warning_once(
+ "The `max_size` argument is deprecated and will be removed in a future version, use"
+ " `size['longest_edge']` instead."
+ )
+ kwargs["size"] = kwargs.pop("max_size")
+
+ return super().preprocess(images, annotations, masks_path, **kwargs)
+
+ def _preprocess(
+ self,
+ images: list["torch.Tensor"],
+ annotations: Optional[Union[AnnotationType, list[AnnotationType]]],
+ masks_path: Optional[Union[str, pathlib.Path]],
+ return_segmentation_masks: bool,
+ do_resize: bool,
+ size: SizeDict,
+ interpolation: Optional["F.InterpolationMode"],
+ do_rescale: bool,
+ rescale_factor: float,
+ do_normalize: bool,
+ do_convert_annotations: bool,
+ image_mean: Optional[Union[float, list[float]]],
+ image_std: Optional[Union[float, list[float]]],
+ do_pad: bool,
+ pad_size: Optional[SizeDict],
+ format: Optional[Union[str, AnnotationFormat]],
+ return_tensors: Optional[Union[str, TensorType]],
+ **kwargs,
+ ) -> BatchFeature:
+ """
+ Preprocess an image or a batch of images so that it can be used by the model.
+ """
+ if annotations is not None and isinstance(annotations, dict):
+ annotations = [annotations]
+
+ if annotations is not None and len(images) != len(annotations):
+ raise ValueError(
+ f"The number of images ({len(images)}) and annotations ({len(annotations)}) do not match."
+ )
+
+ format = AnnotationFormat(format)
+ if annotations is not None:
+ validate_annotations(format, SUPPORTED_ANNOTATION_FORMATS, annotations)
+
+ if (
+ masks_path is not None
+ and format == AnnotationFormat.COCO_PANOPTIC
+ and not isinstance(masks_path, (pathlib.Path, str))
+ ):
+ raise ValueError(
+ "The path to the directory containing the mask PNG files should be provided as a"
+ f" `pathlib.Path` or string object, but is {type(masks_path)} instead."
+ )
+
+ data = {}
+
+ processed_images = []
+ processed_annotations = []
+ pixel_masks = [] # Initialize pixel_masks here
+ for image, annotation in zip(images, annotations if annotations is not None else [None] * len(images)):
+ # prepare (COCO annotations as a list of Dict -> DEFORMABLE_DETR target as a single Dict per image)
+ if annotations is not None:
+ annotation = self.prepare_annotation(
+ image,
+ annotation,
+ format,
+ return_segmentation_masks=return_segmentation_masks,
+ masks_path=masks_path,
+ input_data_format=ChannelDimension.FIRST,
+ )
+
+ if do_resize:
+ resized_image = self.resize(image, size=size, interpolation=interpolation)
+ if annotations is not None:
+ annotation = self.resize_annotation(
+ annotation,
+ orig_size=image.size()[-2:],
+ target_size=resized_image.size()[-2:],
+ )
+ image = resized_image
+ # Fused rescale and normalize
+ image = self.rescale_and_normalize(image, do_rescale, rescale_factor, do_normalize, image_mean, image_std)
+ if do_convert_annotations and annotations is not None:
+ annotation = self.normalize_annotation(annotation, get_image_size(image, ChannelDimension.FIRST))
+
+ processed_images.append(image)
+ processed_annotations.append(annotation)
+ images = processed_images
+ annotations = processed_annotations if annotations is not None else None
+
+ if do_pad:
+ # depends on all resized image shapes so we need another loop
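+ # Every image in the batch is padded to one common (height, width): either the user-provided
+ # pad_size or the maximum height and width found across the resized images.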
+ if pad_size is not None:
+ padded_size = (pad_size.height, pad_size.width)
+ else:
+ padded_size = get_max_height_width(images)
+
+ padded_images = []
+ padded_annotations = []
+ for image, annotation in zip(images, annotations if annotations is not None else [None] * len(images)):
+ # Pads images and returns their mask: {'pixel_values': ..., 'pixel_mask': ...}
+ if padded_size == image.size()[-2:]:
+ padded_images.append(image)
+ pixel_masks.append(torch.ones(padded_size, dtype=torch.int64, device=image.device))
+ padded_annotations.append(annotation)
+ continue
+ image, pixel_mask, annotation = self.pad(
+ image, padded_size, annotation=annotation, update_bboxes=do_convert_annotations
+ )
+ padded_images.append(image)
+ padded_annotations.append(annotation)
+ pixel_masks.append(pixel_mask)
+ images = padded_images
+ annotations = padded_annotations if annotations is not None else None
+ data.update({"pixel_mask": torch.stack(pixel_masks, dim=0)})
+
+ data.update({"pixel_values": torch.stack(images, dim=0)})
+ encoded_inputs = BatchFeature(data, tensor_type=return_tensors)
+ if annotations is not None:
+ encoded_inputs["labels"] = [
+ BatchFeature(annotation, tensor_type=return_tensors) for annotation in annotations
+ ]
+ return encoded_inputs
+
+ def post_process(self, outputs, target_sizes):
+ """
+ Converts the raw output of [`DeformableDetrForObjectDetection`] into final bounding boxes in (top_left_x,
+ top_left_y, bottom_right_x, bottom_right_y) format. Only supports PyTorch.
+
+ Args:
+ outputs ([`DeformableDetrObjectDetectionOutput`]):
+ Raw outputs of the model.
+ target_sizes (`torch.Tensor` of shape `(batch_size, 2)`):
+ Tensor containing the size (height, width) of each image of the batch. For evaluation, this must be the
+ original image size (before any data augmentation). For visualization, this should be the image size
+ after data augmentation, but before padding.
+ Returns:
+ `list[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image
+ in the batch as predicted by the model.
+ """
+ logger.warning_once(
+ "`post_process` is deprecated and will be removed in v5 of Transformers, please use"
+ " `post_process_object_detection` instead, with `threshold=0.` for equivalent results.",
+ )
+
+ out_logits, out_bbox = outputs.logits, outputs.pred_boxes
+
+ if len(out_logits) != len(target_sizes):
+ raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the logits")
+ if target_sizes.shape[1] != 2:
+ raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch")
+
+ prob = out_logits.sigmoid()
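+ # Scores are flattened over (num_queries * num_classes), so a flat top-k index encodes both the
+ # query (index // num_classes) and the class label (index % num_classes).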
+ topk_values, topk_indexes = torch.topk(prob.view(out_logits.shape[0], -1), 100, dim=1)
+ scores = topk_values
+ topk_boxes = torch.div(topk_indexes, out_logits.shape[2], rounding_mode="floor")
+ labels = topk_indexes % out_logits.shape[2]
+ boxes = center_to_corners_format(out_bbox)
+ boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1, 1, 4))
+
+ # and from relative [0, 1] to absolute [0, height] coordinates
+ img_h, img_w = target_sizes.unbind(1)
+ scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
+ boxes = boxes * scale_fct[:, None, :]
+
+ results = [{"scores": s, "labels": l, "boxes": b} for s, l, b in zip(scores, labels, boxes)]
+
+ return results
+
+ def post_process_object_detection(
+ self, outputs, threshold: float = 0.5, target_sizes: Union[TensorType, list[tuple]] = None, top_k: int = 100
+ ):
+ """
+ Converts the raw output of [`DeformableDetrForObjectDetection`] into final bounding boxes in (top_left_x,
+ top_left_y, bottom_right_x, bottom_right_y) format. Only supports PyTorch.
+
+ Args:
+ outputs ([`DeformableDetrObjectDetectionOutput`]):
+ Raw outputs of the model.
+ threshold (`float`, *optional*, defaults to 0.5):
+ Score threshold to keep object detection predictions.
+ target_sizes (`torch.Tensor` or `list[tuple[int, int]]`, *optional*):
+ Tensor of shape `(batch_size, 2)` or list of tuples (`tuple[int, int]`) containing the target size
+ (height, width) of each image in the batch. If left to None, predictions will not be resized.
+ top_k (`int`, *optional*, defaults to 100):
+ Keep only top k bounding boxes before filtering by thresholding.
+
+ Returns:
+ `list[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image
+ in the batch as predicted by the model.
+ """
+ out_logits, out_bbox = outputs.logits, outputs.pred_boxes
+
+ if target_sizes is not None:
+ if len(out_logits) != len(target_sizes):
+ raise ValueError(
+ "Make sure that you pass in as many target sizes as the batch dimension of the logits"
+ )
+
+ prob = out_logits.sigmoid()
+ prob = prob.view(out_logits.shape[0], -1)
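+ # Cap top_k at the number of available candidates (num_queries * num_classes) so the topk call
+ # never requests more elements than exist.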
+ k_value = min(top_k, prob.size(1))
+ topk_values, topk_indexes = torch.topk(prob, k_value, dim=1)
+ scores = topk_values
+ topk_boxes = torch.div(topk_indexes, out_logits.shape[2], rounding_mode="floor")
+ labels = topk_indexes % out_logits.shape[2]
+ boxes = center_to_corners_format(out_bbox)
+ boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1, 1, 4))
+
+ # and from relative [0, 1] to absolute [0, height] coordinates
+ if target_sizes is not None:
+ if isinstance(target_sizes, list):
+ img_h = torch.Tensor([i[0] for i in target_sizes])
+ img_w = torch.Tensor([i[1] for i in target_sizes])
+ else:
+ img_h, img_w = target_sizes.unbind(1)
+ scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(boxes.device)
+ boxes = boxes * scale_fct[:, None, :]
+
+ results = []
+ for s, l, b in zip(scores, labels, boxes):
+ score = s[s > threshold]
+ label = l[s > threshold]
+ box = b[s > threshold]
+ results.append({"scores": score, "labels": label, "boxes": box})
+
+ return results
+
+
+__all__ = ["DeformableDetrImageProcessorFast"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/deformable_detr/modeling_deformable_detr.py b/venv/lib/python3.13/site-packages/transformers/models/deformable_detr/modeling_deformable_detr.py
new file mode 100644
index 0000000000000000000000000000000000000000..34f5bce7a5c4f6bc12d8fdf670a0bc381f7f6850
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/deformable_detr/modeling_deformable_detr.py
@@ -0,0 +1,1897 @@
+# coding=utf-8
+# Copyright 2022 SenseTime and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch Deformable DETR model."""
+
+import copy
+import math
+import warnings
+from dataclasses import dataclass
+from typing import Any, Optional, Union
+
+import torch
+import torch.nn.functional as F
+from torch import Tensor, nn
+
+from ...activations import ACT2FN
+from ...integrations import use_kernel_forward_from_hub
+from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import BaseModelOutput
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import meshgrid
+from ...utils import (
+ ModelOutput,
+ auto_docstring,
+ is_timm_available,
+ logging,
+ requires_backends,
+)
+from ...utils.backbone_utils import load_backbone
+from .configuration_deformable_detr import DeformableDetrConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+if is_timm_available():
+ from timm import create_model
+
+
+@use_kernel_forward_from_hub("MultiScaleDeformableAttention")
+class MultiScaleDeformableAttention(nn.Module):
+ def forward(
+ self,
+ value: Tensor,
+ value_spatial_shapes: Tensor,
+ value_spatial_shapes_list: list[tuple],
+ level_start_index: Tensor,
+ sampling_locations: Tensor,
+ attention_weights: Tensor,
+ im2col_step: int,
+ ):
+ batch_size, _, num_heads, hidden_dim = value.shape
+ _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
+ value_list = value.split([height * width for height, width in value_spatial_shapes_list], dim=1)
+ sampling_grids = 2 * sampling_locations - 1
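+ # grid_sample expects sampling coordinates in [-1, 1], while the sampling locations are expressed
+ # in [0, 1] relative to each feature map, hence the affine remapping above.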
+ sampling_value_list = []
+ for level_id, (height, width) in enumerate(value_spatial_shapes_list):
+ # batch_size, height*width, num_heads, hidden_dim
+ # -> batch_size, height*width, num_heads*hidden_dim
+ # -> batch_size, num_heads*hidden_dim, height*width
+ # -> batch_size*num_heads, hidden_dim, height, width
+ value_l_ = (
+ value_list[level_id]
+ .flatten(2)
+ .transpose(1, 2)
+ .reshape(batch_size * num_heads, hidden_dim, height, width)
+ )
+ # batch_size, num_queries, num_heads, num_points, 2
+ # -> batch_size, num_heads, num_queries, num_points, 2
+ # -> batch_size*num_heads, num_queries, num_points, 2
+ sampling_grid_l_ = sampling_grids[:, :, :, level_id].transpose(1, 2).flatten(0, 1)
+ # batch_size*num_heads, hidden_dim, num_queries, num_points
+ sampling_value_l_ = nn.functional.grid_sample(
+ value_l_,
+ sampling_grid_l_,
+ mode="bilinear",
+ padding_mode="zeros",
+ align_corners=False,
+ )
+ sampling_value_list.append(sampling_value_l_)
+ # (batch_size, num_queries, num_heads, num_levels, num_points)
+ # -> (batch_size, num_heads, num_queries, num_levels, num_points)
+ # -> (batch_size*num_heads, 1, num_queries, num_levels*num_points)
+ attention_weights = attention_weights.transpose(1, 2).reshape(
+ batch_size * num_heads, 1, num_queries, num_levels * num_points
+ )
+ output = (
+ (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights)
+ .sum(-1)
+ .view(batch_size, num_heads * hidden_dim, num_queries)
+ )
+ return output.transpose(1, 2).contiguous()
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for outputs of the DeformableDetrDecoder. This class adds two attributes to
+ BaseModelOutputWithCrossAttentions, namely:
+ - a stacked tensor of intermediate decoder hidden states (i.e. the output of each decoder layer)
+ - a stacked tensor of intermediate reference points.
+ """
+)
+class DeformableDetrDecoderOutput(ModelOutput):
+ r"""
+ intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`):
+ Stacked intermediate hidden states (output of each layer of the decoder).
+ intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
+ Stacked intermediate reference points (reference points of each layer of the decoder).
+ cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`. Attentions weights of the decoder's cross-attention layer, after the attention softmax,
+ used to compute the weighted average in the cross-attention heads.
+ """
+
+ last_hidden_state: Optional[torch.FloatTensor] = None
+ intermediate_hidden_states: Optional[torch.FloatTensor] = None
+ intermediate_reference_points: Optional[torch.FloatTensor] = None
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
+ attentions: Optional[tuple[torch.FloatTensor]] = None
+ cross_attentions: Optional[tuple[torch.FloatTensor]] = None
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for outputs of the Deformable DETR encoder-decoder model.
+ """
+)
+class DeformableDetrModelOutput(ModelOutput):
+ r"""
+ init_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
+ Initial reference points sent through the Transformer decoder.
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the decoder of the model.
+ intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`):
+ Stacked intermediate hidden states (output of each layer of the decoder).
+ intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
+ Stacked intermediate reference points (reference points of each layer of the decoder).
+ enc_outputs_class (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
+ Predicted bounding boxes scores where the top `config.two_stage_num_proposals` scoring bounding boxes are
+ picked as region proposals in the first stage. Output of bounding box binary classification (i.e.
+ foreground and background).
+ enc_outputs_coord_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
+ Logits of predicted bounding boxes coordinates in the first stage.
+ """
+
+ init_reference_points: Optional[torch.FloatTensor] = None
+ last_hidden_state: Optional[torch.FloatTensor] = None
+ intermediate_hidden_states: Optional[torch.FloatTensor] = None
+ intermediate_reference_points: Optional[torch.FloatTensor] = None
+ decoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None
+ decoder_attentions: Optional[tuple[torch.FloatTensor]] = None
+ cross_attentions: Optional[tuple[torch.FloatTensor]] = None
+ encoder_last_hidden_state: Optional[torch.FloatTensor] = None
+ encoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None
+ encoder_attentions: Optional[tuple[torch.FloatTensor]] = None
+ enc_outputs_class: Optional[torch.FloatTensor] = None
+ enc_outputs_coord_logits: Optional[torch.FloatTensor] = None
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Output type of [`DeformableDetrForObjectDetection`].
+ """
+)
+class DeformableDetrObjectDetectionOutput(ModelOutput):
+ r"""
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` are provided):
+ Total loss as a linear combination of a negative log-likelihood (cross-entropy) for class prediction and a
+ bounding box loss. The latter is defined as a linear combination of the L1 loss and the generalized
+ scale-invariant IoU loss.
+ loss_dict (`Dict`, *optional*):
+ A dictionary containing the individual losses. Useful for logging.
+ logits (`torch.FloatTensor` of shape `(batch_size, num_queries, num_classes + 1)`):
+ Classification logits (including no-object) for all queries.
+ pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
+ Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These
+ values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding
+ possible padding). You can use [`~DeformableDetrProcessor.post_process_object_detection`] to retrieve the
+ unnormalized bounding boxes.
+ auxiliary_outputs (`list[Dict]`, *optional*):
+ Optional, only returned when auxiliary losses are activated (i.e. `config.auxiliary_loss` is set to `True`)
+ and labels are provided. It is a list of dictionaries containing the two above keys (`logits` and
+ `pred_boxes`) for each decoder layer.
+ init_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
+ Initial reference points sent through the Transformer decoder.
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
+ Sequence of hidden-states at the output of the last layer of the decoder of the model.
+ intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`):
+ Stacked intermediate hidden states (output of each layer of the decoder).
+ intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
+ Stacked intermediate reference points (reference points of each layer of the decoder).
+ enc_outputs_class (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
+ Predicted bounding boxes scores where the top `config.two_stage_num_proposals` scoring bounding boxes are
+ picked as region proposals in the first stage. Output of bounding box binary classification (i.e.
+ foreground and background).
+ enc_outputs_coord_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
+ Logits of predicted bounding boxes coordinates in the first stage.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ loss_dict: Optional[dict] = None
+ logits: Optional[torch.FloatTensor] = None
+ pred_boxes: Optional[torch.FloatTensor] = None
+ auxiliary_outputs: Optional[list[dict]] = None
+ init_reference_points: Optional[torch.FloatTensor] = None
+ last_hidden_state: Optional[torch.FloatTensor] = None
+ intermediate_hidden_states: Optional[torch.FloatTensor] = None
+ intermediate_reference_points: Optional[torch.FloatTensor] = None
+ decoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None
+ decoder_attentions: Optional[tuple[torch.FloatTensor]] = None
+ cross_attentions: Optional[tuple[torch.FloatTensor]] = None
+ encoder_last_hidden_state: Optional[torch.FloatTensor] = None
+ encoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None
+ encoder_attentions: Optional[tuple[torch.FloatTensor]] = None
+ enc_outputs_class: Optional[torch.FloatTensor] = None
+ enc_outputs_coord_logits: Optional[torch.FloatTensor] = None
+
+
+def _get_clones(module, N):
+ return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
+
+
+def inverse_sigmoid(x, eps=1e-5):
+ x = x.clamp(min=0, max=1)
+ x1 = x.clamp(min=eps)
+ x2 = (1 - x).clamp(min=eps)
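+ # Numerically stable logit: the clamps keep the ratio away from 0 and infinity so the log never
+ # returns +/-inf for inputs that are exactly 0 or 1.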
+ return torch.log(x1 / x2)
+
+
+# Copied from transformers.models.detr.modeling_detr.DetrFrozenBatchNorm2d with Detr->DeformableDetr
+class DeformableDetrFrozenBatchNorm2d(nn.Module):
+ """
+ BatchNorm2d where the batch statistics and the affine parameters are fixed.
+
+ Copy-paste from torchvision.misc.ops with added eps before rsqrt, without which models other than
+ torchvision.models.resnet[18,34,50,101] produce NaNs.
+ """
+
+ def __init__(self, n):
+ super().__init__()
+ self.register_buffer("weight", torch.ones(n))
+ self.register_buffer("bias", torch.zeros(n))
+ self.register_buffer("running_mean", torch.zeros(n))
+ self.register_buffer("running_var", torch.ones(n))
+
+ def _load_from_state_dict(
+ self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+ ):
+ num_batches_tracked_key = prefix + "num_batches_tracked"
+ if num_batches_tracked_key in state_dict:
+ del state_dict[num_batches_tracked_key]
+
+ super()._load_from_state_dict(
+ state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+ )
+
+ def forward(self, x):
+ # move reshapes to the beginning
+ # to make it user-friendly
+ weight = self.weight.reshape(1, -1, 1, 1)
+ bias = self.bias.reshape(1, -1, 1, 1)
+ running_var = self.running_var.reshape(1, -1, 1, 1)
+ running_mean = self.running_mean.reshape(1, -1, 1, 1)
+ epsilon = 1e-5
+ scale = weight * (running_var + epsilon).rsqrt()
+ bias = bias - running_mean * scale
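+ # Equivalent to weight * (x - running_mean) / sqrt(running_var + eps) + bias, expressed as a single
+ # fused multiply-add with precomputed per-channel scale and bias.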
+ return x * scale + bias
+
+
+# Copied from transformers.models.detr.modeling_detr.replace_batch_norm with Detr->DeformableDetr
+def replace_batch_norm(model):
+ r"""
+ Recursively replace all `torch.nn.BatchNorm2d` with `DeformableDetrFrozenBatchNorm2d`.
+
+ Args:
+ model (torch.nn.Module):
+ input model
+ """
+ for name, module in model.named_children():
+ if isinstance(module, nn.BatchNorm2d):
+ new_module = DeformableDetrFrozenBatchNorm2d(module.num_features)
+
+ if module.weight.device != torch.device("meta"):
+ new_module.weight.data.copy_(module.weight)
+ new_module.bias.data.copy_(module.bias)
+ new_module.running_mean.data.copy_(module.running_mean)
+ new_module.running_var.data.copy_(module.running_var)
+
+ model._modules[name] = new_module
+
+ if len(list(module.children())) > 0:
+ replace_batch_norm(module)
+
+
+class DeformableDetrConvEncoder(nn.Module):
+ """
+ Convolutional backbone, using either the AutoBackbone API or one from the timm library.
+
+ nn.BatchNorm2d layers are replaced by DeformableDetrFrozenBatchNorm2d as defined above.
+
+ """
+
+ def __init__(self, config):
+ super().__init__()
+
+ self.config = config
+
+ # For backwards compatibility we have to use the timm library directly instead of the AutoBackbone API
+ if config.use_timm_backbone:
+ # We default to values which were previously hard-coded. This enables configurability from the config
+ # using backbone arguments, while keeping the default behavior the same.
+ requires_backends(self, ["timm"])
+ kwargs = getattr(config, "backbone_kwargs", {})
+ kwargs = {} if kwargs is None else kwargs.copy()
+ out_indices = kwargs.pop("out_indices", (2, 3, 4) if config.num_feature_levels > 1 else (4,))
+ num_channels = kwargs.pop("in_chans", config.num_channels)
+ if config.dilation:
+ kwargs["output_stride"] = kwargs.get("output_stride", 16)
+ backbone = create_model(
+ config.backbone,
+ pretrained=config.use_pretrained_backbone,
+ features_only=True,
+ out_indices=out_indices,
+ in_chans=num_channels,
+ **kwargs,
+ )
+ else:
+ backbone = load_backbone(config)
+
+ # replace batch norm by frozen batch norm
+ with torch.no_grad():
+ replace_batch_norm(backbone)
+ self.model = backbone
+ self.intermediate_channel_sizes = (
+ self.model.feature_info.channels() if config.use_timm_backbone else self.model.channels
+ )
+
+ backbone_model_type = None
+ if config.backbone is not None:
+ backbone_model_type = config.backbone
+ elif config.backbone_config is not None:
+ backbone_model_type = config.backbone_config.model_type
+ else:
+ raise ValueError("Either `backbone` or `backbone_config` should be provided in the config")
+
+ if "resnet" in backbone_model_type:
+ for name, parameter in self.model.named_parameters():
+ if config.use_timm_backbone:
+ if "layer2" not in name and "layer3" not in name and "layer4" not in name:
+ parameter.requires_grad_(False)
+ else:
+ if "stage.1" not in name and "stage.2" not in name and "stage.3" not in name:
+ parameter.requires_grad_(False)
+
+ # Copied from transformers.models.detr.modeling_detr.DetrConvEncoder.forward with Detr->DeformableDetr
+ def forward(self, pixel_values: torch.Tensor, pixel_mask: torch.Tensor):
+ # send pixel_values through the model to get list of feature maps
+ features = self.model(pixel_values) if self.config.use_timm_backbone else self.model(pixel_values).feature_maps
+
+ out = []
+ for feature_map in features:
+ # downsample pixel_mask to match shape of corresponding feature_map
+ mask = nn.functional.interpolate(pixel_mask[None].float(), size=feature_map.shape[-2:]).to(torch.bool)[0]
+ out.append((feature_map, mask))
+ return out
+
+
+# Copied from transformers.models.detr.modeling_detr.DetrConvModel with Detr->DeformableDetr
+class DeformableDetrConvModel(nn.Module):
+ """
+ This module adds 2D position embeddings to all intermediate feature maps of the convolutional encoder.
+ """
+
+ def __init__(self, conv_encoder, position_embedding):
+ super().__init__()
+ self.conv_encoder = conv_encoder
+ self.position_embedding = position_embedding
+
+ def forward(self, pixel_values, pixel_mask):
+ # send pixel_values and pixel_mask through backbone to get list of (feature_map, pixel_mask) tuples
+ out = self.conv_encoder(pixel_values, pixel_mask)
+ pos = []
+ for feature_map, mask in out:
+ # position encoding
+ pos.append(self.position_embedding(feature_map, mask).to(feature_map.dtype))
+
+ return out, pos
+
+
+class DeformableDetrSinePositionEmbedding(nn.Module):
+ """
+ This is a more standard version of the position embedding, very similar to the one used by the Attention is all you
+ need paper, generalized to work on images.
+ """
+
+ def __init__(self, embedding_dim=64, temperature=10000, normalize=False, scale=None):
+ super().__init__()
+ self.embedding_dim = embedding_dim
+ self.temperature = temperature
+ self.normalize = normalize
+ if scale is not None and normalize is False:
+ raise ValueError("normalize should be True if scale is passed")
+ if scale is None:
+ scale = 2 * math.pi
+ self.scale = scale
+
+ def forward(self, pixel_values, pixel_mask):
+ if pixel_mask is None:
+ raise ValueError("No pixel mask provided")
+ y_embed = pixel_mask.cumsum(1, dtype=pixel_values.dtype)
+ x_embed = pixel_mask.cumsum(2, dtype=pixel_values.dtype)
+ if self.normalize:
+ eps = 1e-6
+ y_embed = (y_embed - 0.5) / (y_embed[:, -1:, :] + eps) * self.scale
+ x_embed = (x_embed - 0.5) / (x_embed[:, :, -1:] + eps) * self.scale
+
+ dim_t = torch.arange(self.embedding_dim, dtype=pixel_values.dtype, device=pixel_values.device)
+ dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.embedding_dim)
+
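+ # Interleave sine and cosine over the embedding dimension, following the sinusoidal encoding of
+ # "Attention Is All You Need", applied independently to the x and y coordinates.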
+ pos_x = x_embed[:, :, :, None] / dim_t
+ pos_y = y_embed[:, :, :, None] / dim_t
+ pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
+ pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
+ return pos
+
+
+# Copied from transformers.models.detr.modeling_detr.DetrLearnedPositionEmbedding
+class DeformableDetrLearnedPositionEmbedding(nn.Module):
+ """
+ This module learns positional embeddings up to a fixed maximum size.
+ """
+
+ def __init__(self, embedding_dim=256):
+ super().__init__()
+ self.row_embeddings = nn.Embedding(50, embedding_dim)
+ self.column_embeddings = nn.Embedding(50, embedding_dim)
+
+ def forward(self, pixel_values, pixel_mask=None):
+ height, width = pixel_values.shape[-2:]
+ width_values = torch.arange(width, device=pixel_values.device)
+ height_values = torch.arange(height, device=pixel_values.device)
+ x_emb = self.column_embeddings(width_values)
+ y_emb = self.row_embeddings(height_values)
+ pos = torch.cat([x_emb.unsqueeze(0).repeat(height, 1, 1), y_emb.unsqueeze(1).repeat(1, width, 1)], dim=-1)
+ pos = pos.permute(2, 0, 1)
+ pos = pos.unsqueeze(0)
+ pos = pos.repeat(pixel_values.shape[0], 1, 1, 1)
+ return pos
+
+
+# Copied from transformers.models.detr.modeling_detr.build_position_encoding with Detr->DeformableDetr
+def build_position_encoding(config):
+ n_steps = config.d_model // 2
+ if config.position_embedding_type == "sine":
+ # TODO find a better way of exposing other arguments
+ position_embedding = DeformableDetrSinePositionEmbedding(n_steps, normalize=True)
+ elif config.position_embedding_type == "learned":
+ position_embedding = DeformableDetrLearnedPositionEmbedding(n_steps)
+ else:
+ raise ValueError(f"Not supported {config.position_embedding_type}")
+
+ return position_embedding
+
+
+class DeformableDetrMultiscaleDeformableAttention(nn.Module):
+ """
+ Multiscale deformable attention as proposed in Deformable DETR.
+ """
+
+ def __init__(self, config: DeformableDetrConfig, num_heads: int, n_points: int):
+ super().__init__()
+
+ self.attn = MultiScaleDeformableAttention()
+
+ if config.d_model % num_heads != 0:
+ raise ValueError(
+ f"embed_dim (d_model) must be divisible by num_heads, but got {config.d_model} and {num_heads}"
+ )
+ dim_per_head = config.d_model // num_heads
+ # check if dim_per_head is power of 2
+ if not ((dim_per_head & (dim_per_head - 1) == 0) and dim_per_head != 0):
+ warnings.warn(
+ "You'd better set embed_dim (d_model) in DeformableDetrMultiscaleDeformableAttention to make the"
+ " dimension of each attention head a power of 2 which is more efficient in the authors' CUDA"
+ " implementation."
+ )
+
+ self.im2col_step = 64
+
+ self.d_model = config.d_model
+ self.n_levels = config.num_feature_levels
+ self.n_heads = num_heads
+ self.n_points = n_points
+
+ self.sampling_offsets = nn.Linear(config.d_model, num_heads * self.n_levels * n_points * 2)
+ self.attention_weights = nn.Linear(config.d_model, num_heads * self.n_levels * n_points)
+ self.value_proj = nn.Linear(config.d_model, config.d_model)
+ self.output_proj = nn.Linear(config.d_model, config.d_model)
+
+ self.disable_custom_kernels = config.disable_custom_kernels
+
+ def with_pos_embed(self, tensor: torch.Tensor, position_embeddings: Optional[Tensor]):
+ return tensor if position_embeddings is None else tensor + position_embeddings
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ position_embeddings: Optional[torch.Tensor] = None,
+ reference_points=None,
+ spatial_shapes=None,
+ spatial_shapes_list=None,
+ level_start_index=None,
+ output_attentions: bool = False,
+ ):
+ # add position embeddings to the hidden states before projecting to queries and keys
+ if position_embeddings is not None:
+ hidden_states = self.with_pos_embed(hidden_states, position_embeddings)
+
+ batch_size, num_queries, _ = hidden_states.shape
+ batch_size, sequence_length, _ = encoder_hidden_states.shape
+ total_elements = sum(height * width for height, width in spatial_shapes_list)
+ if total_elements != sequence_length:
+ raise ValueError(
+ "Make sure to align the spatial shapes with the sequence length of the encoder hidden states"
+ )
+
+ value = self.value_proj(encoder_hidden_states)
+ if attention_mask is not None:
+ # we invert the attention_mask
+ value = value.masked_fill(~attention_mask[..., None], float(0))
+ value = value.view(batch_size, sequence_length, self.n_heads, self.d_model // self.n_heads)
+ sampling_offsets = self.sampling_offsets(hidden_states).view(
+ batch_size, num_queries, self.n_heads, self.n_levels, self.n_points, 2
+ )
+ attention_weights = self.attention_weights(hidden_states).view(
+ batch_size, num_queries, self.n_heads, self.n_levels * self.n_points
+ )
+ attention_weights = F.softmax(attention_weights, -1).view(
+ batch_size, num_queries, self.n_heads, self.n_levels, self.n_points
+ )
+ # batch_size, num_queries, n_heads, n_levels, n_points, 2
+ num_coordinates = reference_points.shape[-1]
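+ # 2-d reference points are (x, y) centers in normalized [0, 1] coordinates and offsets are divided
+ # by each level's (width, height); 4-d reference points also carry (width, height), so offsets are
+ # scaled relative to the box size instead of the feature map.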
+ if num_coordinates == 2:
+ offset_normalizer = torch.stack([spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
+ sampling_locations = (
+ reference_points[:, :, None, :, None, :]
+ + sampling_offsets / offset_normalizer[None, None, None, :, None, :]
+ )
+ elif num_coordinates == 4:
+ sampling_locations = (
+ reference_points[:, :, None, :, None, :2]
+ + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5
+ )
+ else:
+ raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}")
+
+ output = self.attn(
+ value,
+ spatial_shapes,
+ spatial_shapes_list,
+ level_start_index,
+ sampling_locations,
+ attention_weights,
+ self.im2col_step,
+ )
+
+ output = self.output_proj(output)
+
+ return output, attention_weights
+
+
+class DeformableDetrMultiheadAttention(nn.Module):
+ """
+ Multi-headed attention from 'Attention Is All You Need' paper.
+
+ Here, we add position embeddings to the queries and keys (as explained in the Deformable DETR paper).
+ """
+
+ def __init__(
+ self,
+ embed_dim: int,
+ num_heads: int,
+ dropout: float = 0.0,
+ bias: bool = True,
+ ):
+ super().__init__()
+ self.embed_dim = embed_dim
+ self.num_heads = num_heads
+ self.dropout = dropout
+ self.head_dim = embed_dim // num_heads
+ if self.head_dim * num_heads != self.embed_dim:
+ raise ValueError(
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
+ f" {num_heads})."
+ )
+ self.scaling = self.head_dim**-0.5
+
+ self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+ self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+ self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+
+ def _shape(self, tensor: torch.Tensor, seq_len: int, batch_size: int):
+ return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+
+ def with_pos_embed(self, tensor: torch.Tensor, position_embeddings: Optional[Tensor]):
+ return tensor if position_embeddings is None else tensor + position_embeddings
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_embeddings: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
+ """Input shape: Batch x Time x Channel"""
+
+ batch_size, target_len, embed_dim = hidden_states.size()
+ # add position embeddings to the hidden states before projecting to queries and keys
+ hidden_states_original = hidden_states
+ if position_embeddings is not None:
+ hidden_states = self.with_pos_embed(hidden_states, position_embeddings)
+
+ # get queries, keys and values
+ query_states = self.q_proj(hidden_states) * self.scaling
+ key_states = self._shape(self.k_proj(hidden_states), -1, batch_size)
+ value_states = self._shape(self.v_proj(hidden_states_original), -1, batch_size)
+
+ proj_shape = (batch_size * self.num_heads, -1, self.head_dim)
+ query_states = self._shape(query_states, target_len, batch_size).view(*proj_shape)
+ key_states = key_states.view(*proj_shape)
+ value_states = value_states.view(*proj_shape)
+
+ source_len = key_states.size(1)
+
+ attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
+
+ if attn_weights.size() != (batch_size * self.num_heads, target_len, source_len):
+ raise ValueError(
+ f"Attention weights should be of size {(batch_size * self.num_heads, target_len, source_len)}, but is"
+ f" {attn_weights.size()}"
+ )
+
+ # expand attention_mask
+ if attention_mask is not None:
+ # [batch_size, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len]
+ attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
+
+ if attention_mask is not None:
+ if attention_mask.size() != (batch_size, 1, target_len, source_len):
+ raise ValueError(
+ f"Attention mask should be of size {(batch_size, 1, target_len, source_len)}, but is"
+ f" {attention_mask.size()}"
+ )
+ if attention_mask.dtype == torch.bool:
+ attention_mask = torch.zeros_like(attention_mask, dtype=attn_weights.dtype).masked_fill_(
+ attention_mask, -torch.inf
+ )
+ attn_weights = attn_weights.view(batch_size, self.num_heads, target_len, source_len) + attention_mask
+ attn_weights = attn_weights.view(batch_size * self.num_heads, target_len, source_len)
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+ if output_attentions:
+ # this operation is a bit awkward, but it's required to
+ # make sure that attn_weights keeps its gradient.
+ # In order to do so, attn_weights has to be reshaped
+ # twice and reused in the following
+ attn_weights_reshaped = attn_weights.view(batch_size, self.num_heads, target_len, source_len)
+ attn_weights = attn_weights_reshaped.view(batch_size * self.num_heads, target_len, source_len)
+ else:
+ attn_weights_reshaped = None
+
+ attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+
+ attn_output = torch.bmm(attn_probs, value_states)
+
+ if attn_output.size() != (
+ batch_size * self.num_heads,
+ target_len,
+ self.head_dim,
+ ):
+ raise ValueError(
+ f"`attn_output` should be of size {(batch_size, self.num_heads, target_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.view(batch_size, self.num_heads, target_len, self.head_dim)
+ attn_output = attn_output.transpose(1, 2)
+ attn_output = attn_output.reshape(batch_size, target_len, embed_dim)
+
+ attn_output = self.out_proj(attn_output)
+
+ return attn_output, attn_weights_reshaped
+
+
+class DeformableDetrEncoderLayer(GradientCheckpointingLayer):
+ def __init__(self, config: DeformableDetrConfig):
+ super().__init__()
+ self.embed_dim = config.d_model
+ self.self_attn = DeformableDetrMultiscaleDeformableAttention(
+ config,
+ num_heads=config.encoder_attention_heads,
+ n_points=config.encoder_n_points,
+ )
+ self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
+ self.dropout = config.dropout
+ self.activation_fn = ACT2FN[config.activation_function]
+ self.activation_dropout = config.activation_dropout
+ self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
+ self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
+ self.final_layer_norm = nn.LayerNorm(self.embed_dim)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: torch.Tensor,
+ position_embeddings: Optional[torch.Tensor] = None,
+ reference_points=None,
+ spatial_shapes=None,
+ spatial_shapes_list=None,
+ level_start_index=None,
+ output_attentions: bool = False,
+ ):
+ """
+ Args:
+ hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Input to the layer.
+ attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
+ Attention mask.
+ position_embeddings (`torch.FloatTensor`, *optional*):
+ Position embeddings, to be added to `hidden_states`.
+ reference_points (`torch.FloatTensor`, *optional*):
+ Reference points.
+ spatial_shapes (`torch.LongTensor`, *optional*):
+ Spatial shapes of the backbone feature maps.
+ level_start_index (`torch.LongTensor`, *optional*):
+ Level start index.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ """
+ residual = hidden_states
+
+ # Apply Multi-scale Deformable Attention Module on the multi-scale feature maps.
+ hidden_states, attn_weights = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ encoder_hidden_states=hidden_states,
+ encoder_attention_mask=attention_mask,
+ position_embeddings=position_embeddings,
+ reference_points=reference_points,
+ spatial_shapes=spatial_shapes,
+ spatial_shapes_list=spatial_shapes_list,
+ level_start_index=level_start_index,
+ output_attentions=output_attentions,
+ )
+
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+ hidden_states = residual + hidden_states
+ hidden_states = self.self_attn_layer_norm(hidden_states)
+
+ residual = hidden_states
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
+ hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
+
+ hidden_states = self.fc2(hidden_states)
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+
+ hidden_states = residual + hidden_states
+ hidden_states = self.final_layer_norm(hidden_states)
+
+ if self.training:
+ if torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any():
+ clamp_value = torch.finfo(hidden_states.dtype).max - 1000
+ hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (attn_weights,)
+
+ return outputs
+
+
+class DeformableDetrDecoderLayer(GradientCheckpointingLayer):
+ def __init__(self, config: DeformableDetrConfig):
+ super().__init__()
+ self.embed_dim = config.d_model
+
+ # self-attention
+ self.self_attn = DeformableDetrMultiheadAttention(
+ embed_dim=self.embed_dim,
+ num_heads=config.decoder_attention_heads,
+ dropout=config.attention_dropout,
+ )
+ self.dropout = config.dropout
+ self.activation_fn = ACT2FN[config.activation_function]
+ self.activation_dropout = config.activation_dropout
+
+ self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
+ # cross-attention
+ self.encoder_attn = DeformableDetrMultiscaleDeformableAttention(
+ config,
+ num_heads=config.decoder_attention_heads,
+ n_points=config.decoder_n_points,
+ )
+ self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
+ # feedforward neural networks
+ self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
+ self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
+ self.final_layer_norm = nn.LayerNorm(self.embed_dim)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: Optional[torch.Tensor] = None,
+ reference_points=None,
+ spatial_shapes=None,
+ spatial_shapes_list=None,
+ level_start_index=None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = False,
+ ):
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`):
+ Input to the layer of shape `(seq_len, batch, embed_dim)`.
+ position_embeddings (`torch.FloatTensor`, *optional*):
+ Position embeddings that are added to the queries and keys in the self-attention layer.
+ reference_points (`torch.FloatTensor`, *optional*):
+ Reference points.
+ spatial_shapes (`torch.LongTensor`, *optional*):
+ Spatial shapes.
+ level_start_index (`torch.LongTensor`, *optional*):
+ Level start index.
+ encoder_hidden_states (`torch.FloatTensor`):
+ Cross-attention input to the layer, of shape `(seq_len, batch, embed_dim)`.
+ encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
+ `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
+ values.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ """
+ residual = hidden_states
+
+ # Self Attention
+ hidden_states, self_attn_weights = self.self_attn(
+ hidden_states=hidden_states,
+ position_embeddings=position_embeddings,
+ output_attentions=output_attentions,
+ )
+
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+ hidden_states = residual + hidden_states
+ hidden_states = self.self_attn_layer_norm(hidden_states)
+
+ second_residual = hidden_states
+
+ # Cross-Attention
+ cross_attn_weights = None
+ hidden_states, cross_attn_weights = self.encoder_attn(
+ hidden_states=hidden_states,
+ attention_mask=encoder_attention_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ position_embeddings=position_embeddings,
+ reference_points=reference_points,
+ spatial_shapes=spatial_shapes,
+ spatial_shapes_list=spatial_shapes_list,
+ level_start_index=level_start_index,
+ output_attentions=output_attentions,
+ )
+
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+ hidden_states = second_residual + hidden_states
+
+ hidden_states = self.encoder_attn_layer_norm(hidden_states)
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
+ hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
+ hidden_states = self.fc2(hidden_states)
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+ hidden_states = residual + hidden_states
+ hidden_states = self.final_layer_norm(hidden_states)
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights, cross_attn_weights)
+
+ return outputs
+
+
+@auto_docstring
+class DeformableDetrPreTrainedModel(PreTrainedModel):
+ config: DeformableDetrConfig
+ base_model_prefix = "model"
+ main_input_name = "pixel_values"
+ supports_gradient_checkpointing = True
+ _no_split_modules = [
+ r"DeformableDetrConvEncoder",
+ r"DeformableDetrEncoderLayer",
+ r"DeformableDetrDecoderLayer",
+ ]
+
+ def _init_weights(self, module):
+ std = self.config.init_std
+
+ if isinstance(module, DeformableDetrLearnedPositionEmbedding):
+ nn.init.uniform_(module.row_embeddings.weight)
+ nn.init.uniform_(module.column_embeddings.weight)
+ elif isinstance(module, DeformableDetrMultiscaleDeformableAttention):
+ nn.init.constant_(module.sampling_offsets.weight.data, 0.0)
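+ # The offset weights start at zero; the bias below places each head's initial sampling points in a
+ # distinct direction, with the radius growing with the point index (i + 1).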
+ default_dtype = torch.get_default_dtype()
+ thetas = torch.arange(module.n_heads, dtype=torch.int64).to(default_dtype) * (
+ 2.0 * math.pi / module.n_heads
+ )
+ grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
+ grid_init = (
+ (grid_init / grid_init.abs().max(-1, keepdim=True)[0])
+ .view(module.n_heads, 1, 1, 2)
+ .repeat(1, module.n_levels, module.n_points, 1)
+ )
+ for i in range(module.n_points):
+ grid_init[:, :, i, :] *= i + 1
+ with torch.no_grad():
+ module.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
+ nn.init.constant_(module.attention_weights.weight.data, 0.0)
+ nn.init.constant_(module.attention_weights.bias.data, 0.0)
+ nn.init.xavier_uniform_(module.value_proj.weight.data)
+ nn.init.constant_(module.value_proj.bias.data, 0.0)
+ nn.init.xavier_uniform_(module.output_proj.weight.data)
+ nn.init.constant_(module.output_proj.bias.data, 0.0)
+ elif isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)):
+ # Slightly different from the TF version which uses truncated_normal for initialization
+ # cf https://github.com/pytorch/pytorch/pull/5617
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+ if hasattr(module, "reference_points") and not self.config.two_stage:
+ nn.init.xavier_uniform_(module.reference_points.weight.data, gain=1.0)
+ nn.init.constant_(module.reference_points.bias.data, 0.0)
+ if hasattr(module, "level_embed"):
+ nn.init.normal_(module.level_embed)
+
+
+class DeformableDetrEncoder(DeformableDetrPreTrainedModel):
+ """
+ Transformer encoder consisting of *config.encoder_layers* deformable attention layers. Each layer is a
+ [`DeformableDetrEncoderLayer`].
+
+ The encoder updates the flattened multi-scale feature maps through multiple deformable attention layers.
+
+ Args:
+ config: DeformableDetrConfig
+ """
+
+ def __init__(self, config: DeformableDetrConfig):
+ super().__init__(config)
+ self.gradient_checkpointing = False
+
+ self.dropout = config.dropout
+ self.layers = nn.ModuleList([DeformableDetrEncoderLayer(config) for _ in range(config.encoder_layers)])
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @staticmethod
+ def get_reference_points(spatial_shapes, valid_ratios, device):
+ """
+ Get reference points for each feature map. Used in decoder.
+
+ Args:
+ spatial_shapes (`torch.LongTensor` of shape `(num_feature_levels, 2)`):
+ Spatial shapes of each feature map.
+ valid_ratios (`torch.FloatTensor` of shape `(batch_size, num_feature_levels, 2)`):
+ Valid ratios of each feature map.
+ device (`torch.device`):
+ Device on which to create the tensors.
+ Returns:
+ `torch.FloatTensor` of shape `(batch_size, num_queries, num_feature_levels, 2)`
+ """
+ reference_points_list = []
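+ # For every feature level, build a grid of pixel-center coordinates (the 0.5 offsets) and normalize
+ # them by the valid (unpadded) extent of that level, giving reference points in roughly [0, 1].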
+ for level, (height, width) in enumerate(spatial_shapes):
+ ref_y, ref_x = meshgrid(
+ torch.linspace(0.5, height - 0.5, height, dtype=valid_ratios.dtype, device=device),
+ torch.linspace(0.5, width - 0.5, width, dtype=valid_ratios.dtype, device=device),
+ indexing="ij",
+ )
+ # TODO: valid_ratios could be useless here. check https://github.com/fundamentalvision/Deformable-DETR/issues/36
+ ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, level, 1] * height)
+ ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, level, 0] * width)
+ ref = torch.stack((ref_x, ref_y), -1)
+ reference_points_list.append(ref)
+ reference_points = torch.cat(reference_points_list, 1)
+ reference_points = reference_points[:, :, None] * valid_ratios[:, None]
+ return reference_points
+
+ def forward(
+ self,
+ inputs_embeds=None,
+ attention_mask=None,
+ position_embeddings=None,
+ spatial_shapes=None,
+ spatial_shapes_list=None,
+ level_start_index=None,
+ valid_ratios=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ ):
+ r"""
+ Args:
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Flattened feature map (output of the backbone + projection layer) that is passed to the encoder.
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding pixel features. Mask values selected in `[0, 1]`:
+ - 1 for pixel features that are real (i.e. **not masked**),
+ - 0 for pixel features that are padding (i.e. **masked**).
+ [What are attention masks?](../glossary#attention-mask)
+ position_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Position embeddings that are added to the queries and keys in each self-attention layer.
+ spatial_shapes (`torch.LongTensor` of shape `(num_feature_levels, 2)`):
+ Spatial shapes of each feature map.
+ level_start_index (`torch.LongTensor` of shape `(num_feature_levels)`):
+ Starting index of each feature map.
+ valid_ratios (`torch.FloatTensor` of shape `(batch_size, num_feature_levels, 2)`):
+ Ratio of valid area in each feature level.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+ for more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ hidden_states = inputs_embeds
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+
+ spatial_shapes_tuple = tuple(spatial_shapes_list)
+ reference_points = self.get_reference_points(spatial_shapes_tuple, valid_ratios, device=inputs_embeds.device)
+
+ encoder_states = () if output_hidden_states else None
+ all_attentions = () if output_attentions else None
+ for i, encoder_layer in enumerate(self.layers):
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+ layer_outputs = encoder_layer(
+ hidden_states,
+ attention_mask,
+ position_embeddings=position_embeddings,
+ reference_points=reference_points,
+ spatial_shapes=spatial_shapes,
+ spatial_shapes_list=spatial_shapes_list,
+ level_start_index=level_start_index,
+ output_attentions=output_attentions,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_attentions = all_attentions + (layer_outputs[1],)
+
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
+ return BaseModelOutput(
+ last_hidden_state=hidden_states,
+ hidden_states=encoder_states,
+ attentions=all_attentions,
+ )
+
+
+class DeformableDetrDecoder(DeformableDetrPreTrainedModel):
+ """
+ Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`DeformableDetrDecoderLayer`].
+
+ The decoder updates the query embeddings through multiple self-attention and cross-attention layers.
+
+ Some tweaks for Deformable DETR:
+
+ - `position_embeddings`, `reference_points`, `spatial_shapes` and `valid_ratios` are added to the forward pass.
+ - it also returns a stack of intermediate outputs and reference points from all decoding layers.
+
+ Args:
+ config: DeformableDetrConfig
+ """
+
+ def __init__(self, config: DeformableDetrConfig):
+ super().__init__(config)
+
+ self.dropout = config.dropout
+ self.layers = nn.ModuleList([DeformableDetrDecoderLayer(config) for _ in range(config.decoder_layers)])
+ self.gradient_checkpointing = False
+
+ # hack implementation for iterative bounding box refinement and two-stage Deformable DETR
+ self.bbox_embed = None
+ self.class_embed = None
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def forward(
+ self,
+ inputs_embeds=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ position_embeddings=None,
+ reference_points=None,
+ spatial_shapes=None,
+ spatial_shapes_list=None,
+ level_start_index=None,
+ valid_ratios=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ ):
+ r"""
+ Args:
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`):
+ The query embeddings that are passed into the decoder.
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
+ of the decoder.
+ encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing cross-attention on padding pixel_values of the encoder. Mask values selected
+ in `[0, 1]`:
+ - 1 for pixels that are real (i.e. **not masked**),
+ - 0 for pixels that are padding (i.e. **masked**).
+ position_embeddings (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
+ Position embeddings that are added to the queries and keys in each self-attention layer.
+            reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)` if `as_two_stage` else `(batch_size, num_queries, 2)`, *optional*):
+                Reference points in range `[0, 1]`, top-left (0, 0), bottom-right (1, 1), including padding area.
+ spatial_shapes (`torch.FloatTensor` of shape `(num_feature_levels, 2)`):
+ Spatial shapes of the feature maps.
+ level_start_index (`torch.LongTensor` of shape `(num_feature_levels)`, *optional*):
+ Indexes for the start of each feature level. In range `[0, sequence_length]`.
+ valid_ratios (`torch.FloatTensor` of shape `(batch_size, num_feature_levels, 2)`, *optional*):
+ Ratio of valid area in each feature level.
+
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+ for more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if inputs_embeds is not None:
+ hidden_states = inputs_embeds
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
+ intermediate = ()
+ intermediate_reference_points = ()
+
+ for idx, decoder_layer in enumerate(self.layers):
+ num_coordinates = reference_points.shape[-1]
+ if num_coordinates == 4:
+ reference_points_input = (
+ reference_points[:, :, None] * torch.cat([valid_ratios, valid_ratios], -1)[:, None]
+ )
+            elif num_coordinates == 2:
+                reference_points_input = reference_points[:, :, None] * valid_ratios[:, None]
+            else:
+                raise ValueError(
+                    f"Last dim of reference_points must be 2 or 4, but got {num_coordinates}"
+                )
+
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ layer_outputs = decoder_layer(
+ hidden_states,
+ position_embeddings,
+ reference_points_input,
+ spatial_shapes,
+ spatial_shapes_list,
+ level_start_index,
+ encoder_hidden_states, # as a positional argument for gradient checkpointing
+ encoder_attention_mask,
+ output_attentions,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ # hack implementation for iterative bounding box refinement
+ if self.bbox_embed is not None:
+ tmp = self.bbox_embed[idx](hidden_states)
+ num_coordinates = reference_points.shape[-1]
+ if num_coordinates == 4:
+ new_reference_points = tmp + inverse_sigmoid(reference_points)
+ new_reference_points = new_reference_points.sigmoid()
+ elif num_coordinates == 2:
+ new_reference_points = tmp
+ new_reference_points[..., :2] = tmp[..., :2] + inverse_sigmoid(reference_points)
+ new_reference_points = new_reference_points.sigmoid()
+ else:
+ raise ValueError(
+ f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}"
+ )
+ reference_points = new_reference_points.detach()
+
+ intermediate += (hidden_states,)
+ intermediate_reference_points += (reference_points,)
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ if encoder_hidden_states is not None:
+ all_cross_attentions += (layer_outputs[2],)
+
+ # Keep batch_size as first dimension
+ intermediate = torch.stack(intermediate, dim=1)
+ intermediate_reference_points = torch.stack(intermediate_reference_points, dim=1)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ if not return_dict:
+ return tuple(
+ v
+ for v in [
+ hidden_states,
+ intermediate,
+ intermediate_reference_points,
+ all_hidden_states,
+ all_self_attns,
+ all_cross_attentions,
+ ]
+ if v is not None
+ )
+ return DeformableDetrDecoderOutput(
+ last_hidden_state=hidden_states,
+ intermediate_hidden_states=intermediate,
+ intermediate_reference_points=intermediate_reference_points,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ cross_attentions=all_cross_attentions,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ The bare Deformable DETR Model (consisting of a backbone and encoder-decoder Transformer) outputting raw
+ hidden-states without any specific head on top.
+ """
+)
+class DeformableDetrModel(DeformableDetrPreTrainedModel):
+ def __init__(self, config: DeformableDetrConfig):
+ super().__init__(config)
+
+ # Create backbone + positional encoding
+ backbone = DeformableDetrConvEncoder(config)
+ position_embeddings = build_position_encoding(config)
+ self.backbone = DeformableDetrConvModel(backbone, position_embeddings)
+
+ # Create input projection layers
+ if config.num_feature_levels > 1:
+ num_backbone_outs = len(backbone.intermediate_channel_sizes)
+ input_proj_list = []
+            for level in range(num_backbone_outs):
+                in_channels = backbone.intermediate_channel_sizes[level]
+ input_proj_list.append(
+ nn.Sequential(
+ nn.Conv2d(in_channels, config.d_model, kernel_size=1),
+ nn.GroupNorm(32, config.d_model),
+ )
+ )
+ for _ in range(config.num_feature_levels - num_backbone_outs):
+ input_proj_list.append(
+ nn.Sequential(
+ nn.Conv2d(
+ in_channels,
+ config.d_model,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ ),
+ nn.GroupNorm(32, config.d_model),
+ )
+ )
+ in_channels = config.d_model
+ self.input_proj = nn.ModuleList(input_proj_list)
+ else:
+ self.input_proj = nn.ModuleList(
+ [
+ nn.Sequential(
+ nn.Conv2d(
+ backbone.intermediate_channel_sizes[-1],
+ config.d_model,
+ kernel_size=1,
+ ),
+ nn.GroupNorm(32, config.d_model),
+ )
+ ]
+ )
+
+ if not config.two_stage:
+ self.query_position_embeddings = nn.Embedding(config.num_queries, config.d_model * 2)
+
+ self.encoder = DeformableDetrEncoder(config)
+ self.decoder = DeformableDetrDecoder(config)
+
+ self.level_embed = nn.Parameter(torch.Tensor(config.num_feature_levels, config.d_model))
+
+ if config.two_stage:
+ self.enc_output = nn.Linear(config.d_model, config.d_model)
+ self.enc_output_norm = nn.LayerNorm(config.d_model)
+ self.pos_trans = nn.Linear(config.d_model * 2, config.d_model * 2)
+ self.pos_trans_norm = nn.LayerNorm(config.d_model * 2)
+ else:
+ self.reference_points = nn.Linear(config.d_model, 2)
+
+ self.post_init()
+
+ def get_encoder(self):
+ return self.encoder
+
+ def freeze_backbone(self):
+ for name, param in self.backbone.conv_encoder.model.named_parameters():
+ param.requires_grad_(False)
+
+ def unfreeze_backbone(self):
+ for name, param in self.backbone.conv_encoder.model.named_parameters():
+ param.requires_grad_(True)
+
+ def get_valid_ratio(self, mask, dtype=torch.float32):
+ """Get the valid ratio of all feature maps."""
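+        # Added note: `mask` is 1 for real pixels and 0 for padding, so the valid ratio is the fraction
+        # of each (padded) feature map that corresponds to the actual image, measured separately along
+        # the width and height axes.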
+
+ _, height, width = mask.shape
+ valid_height = torch.sum(mask[:, :, 0], 1)
+ valid_width = torch.sum(mask[:, 0, :], 1)
+ valid_ratio_height = valid_height.to(dtype) / height
+ valid_ratio_width = valid_width.to(dtype) / width
+ valid_ratio = torch.stack([valid_ratio_width, valid_ratio_height], -1)
+ return valid_ratio
+
+ def get_proposal_pos_embed(self, proposals):
+ """Get the position embedding of the proposals."""
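+        # Added note: each of the 4 proposal coordinates is embedded with `d_model // 2` sinusoidal
+        # features, so the output has shape (batch_size, num_queries, 2 * d_model), matching the
+        # `config.d_model * 2` input size of `self.pos_trans`.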
+
+ num_pos_feats = self.config.d_model // 2
+ temperature = 10000
+ scale = 2 * math.pi
+
+ dim_t = torch.arange(num_pos_feats, dtype=proposals.dtype, device=proposals.device)
+ dim_t = temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / num_pos_feats)
+ # batch_size, num_queries, 4
+ proposals = proposals.sigmoid() * scale
+ # batch_size, num_queries, 4, 128
+ pos = proposals[:, :, :, None] / dim_t
+ # batch_size, num_queries, 4, 64, 2 -> batch_size, num_queries, 512
+ pos = torch.stack((pos[:, :, :, 0::2].sin(), pos[:, :, :, 1::2].cos()), dim=4).flatten(2)
+ return pos
+
+ def gen_encoder_output_proposals(self, enc_output, padding_mask, spatial_shapes):
+ """Generate the encoder output proposals from encoded enc_output.
+
+ Args:
+ enc_output (Tensor[batch_size, sequence_length, hidden_size]): Output of the encoder.
+ padding_mask (Tensor[batch_size, sequence_length]): Padding mask for `enc_output`.
+ spatial_shapes (list[tuple[int, int]]): Spatial shapes of the feature maps.
+
+ Returns:
+ `tuple(torch.FloatTensor)`: A tuple of feature map and bbox prediction.
+ - object_query (Tensor[batch_size, sequence_length, hidden_size]): Object query features. Later used to
+                directly predict a bounding box (without the need of a decoder).
+ - output_proposals (Tensor[batch_size, sequence_length, 4]): Normalized proposals, after an inverse
+ sigmoid.
+ """
+ batch_size = enc_output.shape[0]
+ proposals = []
+ _cur = 0
+ for level, (height, width) in enumerate(spatial_shapes):
+ mask_flatten_ = padding_mask[:, _cur : (_cur + height * width)].view(batch_size, height, width, 1)
+ valid_height = torch.sum(~mask_flatten_[:, :, 0, 0], 1)
+ valid_width = torch.sum(~mask_flatten_[:, 0, :, 0], 1)
+
+ grid_y, grid_x = meshgrid(
+ torch.linspace(
+ 0,
+ height - 1,
+ height,
+ dtype=enc_output.dtype,
+ device=enc_output.device,
+ ),
+ torch.linspace(
+ 0,
+ width - 1,
+ width,
+ dtype=enc_output.dtype,
+ device=enc_output.device,
+ ),
+ indexing="ij",
+ )
+ grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1)
+
+ scale = torch.cat([valid_width.unsqueeze(-1), valid_height.unsqueeze(-1)], 1).view(batch_size, 1, 1, 2)
+ grid = (grid.unsqueeze(0).expand(batch_size, -1, -1, -1) + 0.5) / scale
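+            # initial proposal width/height in normalized coordinates: 0.05 * 2**level, i.e. it doubles at each level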
+ width_height = torch.ones_like(grid) * 0.05 * (2.0**level)
+ proposal = torch.cat((grid, width_height), -1).view(batch_size, -1, 4)
+ proposals.append(proposal)
+ _cur += height * width
+ output_proposals = torch.cat(proposals, 1)
+ output_proposals_valid = ((output_proposals > 0.01) & (output_proposals < 0.99)).all(-1, keepdim=True)
+ output_proposals = torch.log(output_proposals / (1 - output_proposals)) # inverse sigmoid
+ output_proposals = output_proposals.masked_fill(padding_mask.unsqueeze(-1), float("inf"))
+ output_proposals = output_proposals.masked_fill(~output_proposals_valid, float("inf"))
+
+ # assign each pixel as an object query
+ object_query = enc_output
+ object_query = object_query.masked_fill(padding_mask.unsqueeze(-1), float(0))
+ object_query = object_query.masked_fill(~output_proposals_valid, float(0))
+ object_query = self.enc_output_norm(self.enc_output(object_query))
+ return object_query, output_proposals
+
+ @auto_docstring
+ def forward(
+ self,
+ pixel_values: torch.FloatTensor,
+ pixel_mask: Optional[torch.LongTensor] = None,
+ decoder_attention_mask: Optional[torch.FloatTensor] = None,
+ encoder_outputs: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple[torch.FloatTensor], DeformableDetrModelOutput]:
+ r"""
+ decoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*):
+ Not used by default. Can be used to mask object queries.
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing the flattened feature map (output of the backbone + projection layer), you
+ can choose to directly pass a flattened representation of an image.
+ decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
+ Optionally, instead of initializing the queries with a tensor of zeros, you can choose to directly pass an
+ embedded representation.
+
+ Examples:
+
+ ```python
+ >>> from transformers import AutoImageProcessor, DeformableDetrModel
+ >>> from PIL import Image
+ >>> import requests
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> image_processor = AutoImageProcessor.from_pretrained("SenseTime/deformable-detr")
+ >>> model = DeformableDetrModel.from_pretrained("SenseTime/deformable-detr")
+
+ >>> inputs = image_processor(images=image, return_tensors="pt")
+
+ >>> outputs = model(**inputs)
+
+ >>> last_hidden_states = outputs.last_hidden_state
+ >>> list(last_hidden_states.shape)
+ [1, 300, 256]
+ ```"""
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ batch_size, num_channels, height, width = pixel_values.shape
+ device = pixel_values.device
+
+ if pixel_mask is None:
+ pixel_mask = torch.ones(((batch_size, height, width)), dtype=torch.long, device=device)
+
+        # Extract multi-scale feature maps and project them to the same `config.d_model` dimension (cf. Figure 4 in the paper)
+        # First, send pixel_values + pixel_mask through the backbone to obtain the features,
+        # which is a list of (feature_map, mask) tuples
+ features, position_embeddings_list = self.backbone(pixel_values, pixel_mask)
+
+ # Then, apply 1x1 convolution to reduce the channel dimension to d_model (256 by default)
+ sources = []
+ masks = []
+ for level, (source, mask) in enumerate(features):
+ sources.append(self.input_proj[level](source))
+ masks.append(mask)
+ if mask is None:
+ raise ValueError("No attention mask was provided")
+
+ # Lowest resolution feature maps are obtained via 3x3 stride 2 convolutions on the final stage
+ if self.config.num_feature_levels > len(sources):
+ _len_sources = len(sources)
+ for level in range(_len_sources, self.config.num_feature_levels):
+ if level == _len_sources:
+ source = self.input_proj[level](features[-1][0])
+ else:
+ source = self.input_proj[level](sources[-1])
+ mask = nn.functional.interpolate(pixel_mask[None].to(pixel_values.dtype), size=source.shape[-2:]).to(
+ torch.bool
+ )[0]
+ pos_l = self.backbone.position_embedding(source, mask).to(source.dtype)
+ sources.append(source)
+ masks.append(mask)
+ position_embeddings_list.append(pos_l)
+
+ # Create queries
+ query_embeds = None
+ if not self.config.two_stage:
+ query_embeds = self.query_position_embeddings.weight
+
+ # Prepare encoder inputs (by flattening)
+ source_flatten = []
+ mask_flatten = []
+ lvl_pos_embed_flatten = []
+ spatial_shapes_list = []
+ for level, (source, mask, pos_embed) in enumerate(zip(sources, masks, position_embeddings_list)):
+ batch_size, num_channels, height, width = source.shape
+ spatial_shape = (height, width)
+ spatial_shapes_list.append(spatial_shape)
+ source = source.flatten(2).transpose(1, 2)
+ mask = mask.flatten(1)
+ pos_embed = pos_embed.flatten(2).transpose(1, 2)
+ lvl_pos_embed = pos_embed + self.level_embed[level].view(1, 1, -1)
+ lvl_pos_embed_flatten.append(lvl_pos_embed)
+ source_flatten.append(source)
+ mask_flatten.append(mask)
+ source_flatten = torch.cat(source_flatten, 1)
+ mask_flatten = torch.cat(mask_flatten, 1)
+ lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
+ spatial_shapes = torch.as_tensor(spatial_shapes_list, dtype=torch.long, device=source_flatten.device)
+ level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
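+        # level_start_index[i] is the offset at which feature level i starts in the flattened sequence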
+ valid_ratios = torch.stack([self.get_valid_ratio(m, dtype=source_flatten.dtype) for m in masks], 1)
+
+        # Fourth, send source_flatten + mask_flatten + lvl_pos_embed_flatten (backbone + proj layer output) through encoder
+ # Also provide spatial_shapes, level_start_index and valid_ratios
+ if encoder_outputs is None:
+ encoder_outputs = self.encoder(
+ inputs_embeds=source_flatten,
+ attention_mask=mask_flatten,
+ position_embeddings=lvl_pos_embed_flatten,
+ spatial_shapes=spatial_shapes,
+ spatial_shapes_list=spatial_shapes_list,
+ level_start_index=level_start_index,
+ valid_ratios=valid_ratios,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
+ elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
+ encoder_outputs = BaseModelOutput(
+ last_hidden_state=encoder_outputs[0],
+ hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
+ attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
+ )
+
+ # Fifth, prepare decoder inputs
+ batch_size, _, num_channels = encoder_outputs[0].shape
+ enc_outputs_class = None
+ enc_outputs_coord_logits = None
+ if self.config.two_stage:
+ object_query_embedding, output_proposals = self.gen_encoder_output_proposals(
+ encoder_outputs[0], ~mask_flatten, spatial_shapes_list
+ )
+
+ # hack implementation for two-stage Deformable DETR
+ # apply a detection head to each pixel (A.4 in paper)
+ # linear projection for bounding box binary classification (i.e. foreground and background)
+ enc_outputs_class = self.decoder.class_embed[-1](object_query_embedding)
+ # 3-layer FFN to predict bounding boxes coordinates (bbox regression branch)
+ delta_bbox = self.decoder.bbox_embed[-1](object_query_embedding)
+ enc_outputs_coord_logits = delta_bbox + output_proposals
+
+ # only keep top scoring `config.two_stage_num_proposals` proposals
+ topk = self.config.two_stage_num_proposals
+ topk_proposals = torch.topk(enc_outputs_class[..., 0], topk, dim=1)[1]
+ topk_coords_logits = torch.gather(
+ enc_outputs_coord_logits,
+ 1,
+ topk_proposals.unsqueeze(-1).repeat(1, 1, 4),
+ )
+
+ topk_coords_logits = topk_coords_logits.detach()
+ reference_points = topk_coords_logits.sigmoid()
+ init_reference_points = reference_points
+ pos_trans_out = self.pos_trans_norm(self.pos_trans(self.get_proposal_pos_embed(topk_coords_logits)))
+ query_embed, target = torch.split(pos_trans_out, num_channels, dim=2)
+ else:
+ query_embed, target = torch.split(query_embeds, num_channels, dim=1)
+ query_embed = query_embed.unsqueeze(0).expand(batch_size, -1, -1)
+ target = target.unsqueeze(0).expand(batch_size, -1, -1)
+ reference_points = self.reference_points(query_embed).sigmoid()
+ init_reference_points = reference_points
+
+ decoder_outputs = self.decoder(
+ inputs_embeds=target,
+ position_embeddings=query_embed,
+ encoder_hidden_states=encoder_outputs[0],
+ encoder_attention_mask=mask_flatten,
+ reference_points=reference_points,
+ spatial_shapes=spatial_shapes,
+ spatial_shapes_list=spatial_shapes_list,
+ level_start_index=level_start_index,
+ valid_ratios=valid_ratios,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ if not return_dict:
+ enc_outputs = tuple(value for value in [enc_outputs_class, enc_outputs_coord_logits] if value is not None)
+ tuple_outputs = (init_reference_points,) + decoder_outputs + encoder_outputs + enc_outputs
+
+ return tuple_outputs
+
+ return DeformableDetrModelOutput(
+ init_reference_points=init_reference_points,
+ last_hidden_state=decoder_outputs.last_hidden_state,
+ intermediate_hidden_states=decoder_outputs.intermediate_hidden_states,
+ intermediate_reference_points=decoder_outputs.intermediate_reference_points,
+ decoder_hidden_states=decoder_outputs.hidden_states,
+ decoder_attentions=decoder_outputs.attentions,
+ cross_attentions=decoder_outputs.cross_attentions,
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
+ encoder_hidden_states=encoder_outputs.hidden_states,
+ encoder_attentions=encoder_outputs.attentions,
+ enc_outputs_class=enc_outputs_class,
+ enc_outputs_coord_logits=enc_outputs_coord_logits,
+ )
+
+
+# Copied from transformers.models.detr.modeling_detr.DetrMLPPredictionHead
+class DeformableDetrMLPPredictionHead(nn.Module):
+ """
+ Very simple multi-layer perceptron (MLP, also called FFN), used to predict the normalized center coordinates,
+ height and width of a bounding box w.r.t. an image.
+
+ Copied from https://github.com/facebookresearch/detr/blob/master/models/detr.py
+
+ """
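+    # For reference (added comment): `DeformableDetrForObjectDetection` instantiates this head with
+    # input_dim=hidden_dim=config.d_model, output_dim=4 and num_layers=3, mapping each decoder query
+    # to 4 bounding box values.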
+
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
+ super().__init__()
+ self.num_layers = num_layers
+ h = [hidden_dim] * (num_layers - 1)
+ self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
+
+ def forward(self, x):
+ for i, layer in enumerate(self.layers):
+ x = nn.functional.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
+ return x
+
+
+@auto_docstring(
+ custom_intro="""
+ Deformable DETR Model (consisting of a backbone and encoder-decoder Transformer) with object detection heads on
+ top, for tasks such as COCO detection.
+ """
+)
+class DeformableDetrForObjectDetection(DeformableDetrPreTrainedModel):
+ # When using clones, all layers > 0 will be clones, but layer 0 *is* required
+ _tied_weights_keys = [r"bbox_embed\.[1-9]\d*", r"class_embed\.[1-9]\d*"]
+ # We can't initialize the model on meta device as some weights are modified during the initialization
+ _no_split_modules = None
+
+ def __init__(self, config: DeformableDetrConfig):
+ super().__init__(config)
+
+ # Deformable DETR encoder-decoder model
+ self.model = DeformableDetrModel(config)
+ # Detection heads on top
+ self.class_embed = nn.Linear(config.d_model, config.num_labels)
+ self.bbox_embed = DeformableDetrMLPPredictionHead(
+ input_dim=config.d_model,
+ hidden_dim=config.d_model,
+ output_dim=4,
+ num_layers=3,
+ )
+
+ # if two-stage, the last class_embed and bbox_embed is for region proposal generation
+ num_pred = (config.decoder_layers + 1) if config.two_stage else config.decoder_layers
+ if config.with_box_refine:
+ self.class_embed = _get_clones(self.class_embed, num_pred)
+ self.bbox_embed = _get_clones(self.bbox_embed, num_pred)
+ # hack implementation for iterative bounding box refinement
+ self.model.decoder.bbox_embed = self.bbox_embed
+ else:
+ self.class_embed = nn.ModuleList([self.class_embed for _ in range(num_pred)])
+ self.bbox_embed = nn.ModuleList([self.bbox_embed for _ in range(num_pred)])
+ self.model.decoder.bbox_embed = None
+ if config.two_stage:
+ # hack implementation for two-stage
+ self.model.decoder.class_embed = self.class_embed
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ pixel_values: torch.FloatTensor,
+ pixel_mask: Optional[torch.LongTensor] = None,
+ decoder_attention_mask: Optional[torch.FloatTensor] = None,
+ encoder_outputs: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[list[dict]] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple[torch.FloatTensor], DeformableDetrObjectDetectionOutput]:
+ r"""
+ decoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*):
+ Not used by default. Can be used to mask object queries.
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing the flattened feature map (output of the backbone + projection layer), you
+ can choose to directly pass a flattened representation of an image.
+ decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
+ Optionally, instead of initializing the queries with a tensor of zeros, you can choose to directly pass an
+ embedded representation.
+ labels (`list[Dict]` of len `(batch_size,)`, *optional*):
+ Labels for computing the bipartite matching loss. List of dicts, each dictionary containing at least the
+ following 2 keys: 'class_labels' and 'boxes' (the class labels and bounding boxes of an image in the batch
+ respectively). The class labels themselves should be a `torch.LongTensor` of len `(number of bounding boxes
+ in the image,)` and the boxes a `torch.FloatTensor` of shape `(number of bounding boxes in the image, 4)`.
+
+ Examples:
+
+ ```python
+ >>> from transformers import AutoImageProcessor, DeformableDetrForObjectDetection
+ >>> from PIL import Image
+        >>> import requests
+        >>> import torch
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> image_processor = AutoImageProcessor.from_pretrained("SenseTime/deformable-detr")
+ >>> model = DeformableDetrForObjectDetection.from_pretrained("SenseTime/deformable-detr")
+
+ >>> inputs = image_processor(images=image, return_tensors="pt")
+ >>> outputs = model(**inputs)
+
+ >>> # convert outputs (bounding boxes and class logits) to Pascal VOC format (xmin, ymin, xmax, ymax)
+ >>> target_sizes = torch.tensor([image.size[::-1]])
+ >>> results = image_processor.post_process_object_detection(outputs, threshold=0.5, target_sizes=target_sizes)[
+ ... 0
+ ... ]
+ >>> for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
+ ... box = [round(i, 2) for i in box.tolist()]
+ ... print(
+ ... f"Detected {model.config.id2label[label.item()]} with confidence "
+ ... f"{round(score.item(), 3)} at location {box}"
+ ... )
+ Detected cat with confidence 0.8 at location [16.5, 52.84, 318.25, 470.78]
+ Detected cat with confidence 0.789 at location [342.19, 24.3, 640.02, 372.25]
+ Detected remote with confidence 0.633 at location [40.79, 72.78, 176.76, 117.25]
+ ```"""
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # First, send images through the DETR base model to obtain encoder + decoder outputs
+ outputs = self.model(
+ pixel_values,
+ pixel_mask=pixel_mask,
+ decoder_attention_mask=decoder_attention_mask,
+ encoder_outputs=encoder_outputs,
+ inputs_embeds=inputs_embeds,
+ decoder_inputs_embeds=decoder_inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = outputs.intermediate_hidden_states if return_dict else outputs[2]
+ init_reference = outputs.init_reference_points if return_dict else outputs[0]
+ inter_references = outputs.intermediate_reference_points if return_dict else outputs[3]
+
+ # class logits + predicted bounding boxes
+ outputs_classes = []
+ outputs_coords = []
+
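+        # One prediction per decoder layer: each level's head adds a predicted delta to the
+        # (inverse-sigmoid) reference points of the previous layer and maps back through a sigmoid.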
+ for level in range(hidden_states.shape[1]):
+ if level == 0:
+ reference = init_reference
+ else:
+ reference = inter_references[:, level - 1]
+ reference = inverse_sigmoid(reference)
+ outputs_class = self.class_embed[level](hidden_states[:, level])
+ delta_bbox = self.bbox_embed[level](hidden_states[:, level])
+ if reference.shape[-1] == 4:
+ outputs_coord_logits = delta_bbox + reference
+ elif reference.shape[-1] == 2:
+ delta_bbox[..., :2] += reference
+ outputs_coord_logits = delta_bbox
+ else:
+ raise ValueError(f"reference.shape[-1] should be 4 or 2, but got {reference.shape[-1]}")
+ outputs_coord = outputs_coord_logits.sigmoid()
+ outputs_classes.append(outputs_class)
+ outputs_coords.append(outputs_coord)
+ outputs_class = torch.stack(outputs_classes)
+ outputs_coord = torch.stack(outputs_coords)
+
+ logits = outputs_class[-1]
+ pred_boxes = outputs_coord[-1]
+
+ loss, loss_dict, auxiliary_outputs = None, None, None
+ if labels is not None:
+ loss, loss_dict, auxiliary_outputs = self.loss_function(
+ logits,
+ labels,
+ self.device,
+ pred_boxes,
+ self.config,
+ outputs_class,
+ outputs_coord,
+ )
+ if not return_dict:
+ if auxiliary_outputs is not None:
+ output = (logits, pred_boxes) + auxiliary_outputs + outputs
+ else:
+ output = (logits, pred_boxes) + outputs
+ tuple_outputs = ((loss, loss_dict) + output) if loss is not None else output
+
+ return tuple_outputs
+
+ dict_outputs = DeformableDetrObjectDetectionOutput(
+ loss=loss,
+ loss_dict=loss_dict,
+ logits=logits,
+ pred_boxes=pred_boxes,
+ auxiliary_outputs=auxiliary_outputs,
+ last_hidden_state=outputs.last_hidden_state,
+ decoder_hidden_states=outputs.decoder_hidden_states,
+ decoder_attentions=outputs.decoder_attentions,
+ cross_attentions=outputs.cross_attentions,
+ encoder_last_hidden_state=outputs.encoder_last_hidden_state,
+ encoder_hidden_states=outputs.encoder_hidden_states,
+ encoder_attentions=outputs.encoder_attentions,
+ intermediate_hidden_states=outputs.intermediate_hidden_states,
+ intermediate_reference_points=outputs.intermediate_reference_points,
+ init_reference_points=outputs.init_reference_points,
+ enc_outputs_class=outputs.enc_outputs_class,
+ enc_outputs_coord_logits=outputs.enc_outputs_coord_logits,
+ )
+
+ return dict_outputs
+
+
+__all__ = [
+ "DeformableDetrForObjectDetection",
+ "DeformableDetrModel",
+ "DeformableDetrPreTrainedModel",
+]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/deformable_detr/modular_deformable_detr.py b/venv/lib/python3.13/site-packages/transformers/models/deformable_detr/modular_deformable_detr.py
new file mode 100644
index 0000000000000000000000000000000000000000..2e38df7845a2b656475b9b20b12c9198cc1a9ca6
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/deformable_detr/modular_deformable_detr.py
@@ -0,0 +1,141 @@
+from typing import Union
+
+import torch
+
+from transformers.models.detr.image_processing_detr_fast import DetrImageProcessorFast
+
+from ...image_transforms import center_to_corners_format
+from ...utils import (
+ TensorType,
+ logging,
+)
+
+
+logger = logging.get_logger(__name__)
+
+
+class DeformableDetrImageProcessorFast(DetrImageProcessorFast):
+ def post_process(self, outputs, target_sizes):
+ """
+ Converts the raw output of [`DeformableDetrForObjectDetection`] into final bounding boxes in (top_left_x,
+ top_left_y, bottom_right_x, bottom_right_y) format. Only supports PyTorch.
+
+ Args:
+ outputs ([`DeformableDetrObjectDetectionOutput`]):
+ Raw outputs of the model.
+ target_sizes (`torch.Tensor` of shape `(batch_size, 2)`):
+ Tensor containing the size (height, width) of each image of the batch. For evaluation, this must be the
+ original image size (before any data augmentation). For visualization, this should be the image size
+                after data augmentation, but before padding.
+ Returns:
+ `list[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image
+ in the batch as predicted by the model.
+ """
+ logger.warning_once(
+ "`post_process` is deprecated and will be removed in v5 of Transformers, please use"
+ " `post_process_object_detection` instead, with `threshold=0.` for equivalent results.",
+ )
+
+ out_logits, out_bbox = outputs.logits, outputs.pred_boxes
+
+ if len(out_logits) != len(target_sizes):
+ raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the logits")
+ if target_sizes.shape[1] != 2:
+ raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch")
+
+ prob = out_logits.sigmoid()
+ topk_values, topk_indexes = torch.topk(prob.view(out_logits.shape[0], -1), 100, dim=1)
+ scores = topk_values
+ topk_boxes = torch.div(topk_indexes, out_logits.shape[2], rounding_mode="floor")
+ labels = topk_indexes % out_logits.shape[2]
+ boxes = center_to_corners_format(out_bbox)
+ boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1, 1, 4))
+
+ # and from relative [0, 1] to absolute [0, height] coordinates
+ img_h, img_w = target_sizes.unbind(1)
+ scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
+ boxes = boxes * scale_fct[:, None, :]
+
+ results = [{"scores": s, "labels": l, "boxes": b} for s, l, b in zip(scores, labels, boxes)]
+
+ return results
+
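+    # Illustrative usage sketch (added comment; `image_processor` and the size variables are hypothetical):
+    #   results = image_processor.post_process_object_detection(
+    #       outputs, threshold=0.5, target_sizes=[(image_height, image_width)]
+    #   )
+    #   results[0]["boxes"]  # (num_kept, 4) in absolute (top_left_x, top_left_y, bottom_right_x, bottom_right_y)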
+ def post_process_object_detection(
+ self, outputs, threshold: float = 0.5, target_sizes: Union[TensorType, list[tuple]] = None, top_k: int = 100
+ ):
+ """
+ Converts the raw output of [`DeformableDetrForObjectDetection`] into final bounding boxes in (top_left_x,
+ top_left_y, bottom_right_x, bottom_right_y) format. Only supports PyTorch.
+
+ Args:
+            outputs ([`DeformableDetrObjectDetectionOutput`]):
+ Raw outputs of the model.
+            threshold (`float`, *optional*, defaults to 0.5):
+ Score threshold to keep object detection predictions.
+ target_sizes (`torch.Tensor` or `list[tuple[int, int]]`, *optional*):
+ Tensor of shape `(batch_size, 2)` or list of tuples (`tuple[int, int]`) containing the target size
+ (height, width) of each image in the batch. If left to None, predictions will not be resized.
+ top_k (`int`, *optional*, defaults to 100):
+ Keep only top k bounding boxes before filtering by thresholding.
+
+ Returns:
+ `list[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image
+ in the batch as predicted by the model.
+ """
+ out_logits, out_bbox = outputs.logits, outputs.pred_boxes
+
+ if target_sizes is not None:
+ if len(out_logits) != len(target_sizes):
+ raise ValueError(
+ "Make sure that you pass in as many target sizes as the batch dimension of the logits"
+ )
+
+ prob = out_logits.sigmoid()
+ prob = prob.view(out_logits.shape[0], -1)
+ k_value = min(top_k, prob.size(1))
+ topk_values, topk_indexes = torch.topk(prob, k_value, dim=1)
+ scores = topk_values
+ topk_boxes = torch.div(topk_indexes, out_logits.shape[2], rounding_mode="floor")
+ labels = topk_indexes % out_logits.shape[2]
+ boxes = center_to_corners_format(out_bbox)
+ boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1, 1, 4))
+
+ # and from relative [0, 1] to absolute [0, height] coordinates
+ if target_sizes is not None:
+ if isinstance(target_sizes, list):
+ img_h = torch.Tensor([i[0] for i in target_sizes])
+ img_w = torch.Tensor([i[1] for i in target_sizes])
+ else:
+ img_h, img_w = target_sizes.unbind(1)
+ scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(boxes.device)
+ boxes = boxes * scale_fct[:, None, :]
+
+ results = []
+ for s, l, b in zip(scores, labels, boxes):
+ score = s[s > threshold]
+ label = l[s > threshold]
+ box = b[s > threshold]
+ results.append({"scores": score, "labels": label, "boxes": box})
+
+ return results
+
+ def post_process_segmentation(self):
+ raise NotImplementedError("Segmentation post-processing is not implemented for Deformable DETR yet.")
+
+ def post_process_instance(self):
+ raise NotImplementedError("Instance post-processing is not implemented for Deformable DETR yet.")
+
+ def post_process_panoptic(self):
+ raise NotImplementedError("Panoptic post-processing is not implemented for Deformable DETR yet.")
+
+ def post_process_instance_segmentation(self):
+ raise NotImplementedError("Segmentation post-processing is not implemented for Deformable DETR yet.")
+
+ def post_process_semantic_segmentation(self):
+ raise NotImplementedError("Semantic segmentation post-processing is not implemented for Deformable DETR yet.")
+
+ def post_process_panoptic_segmentation(self):
+ raise NotImplementedError("Panoptic segmentation post-processing is not implemented for Deformable DETR yet.")
+
+
+__all__ = ["DeformableDetrImageProcessorFast"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/deit/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/deit/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..98236a86d7a1e8b4ff16b53fb3ff37befbf1d7ac
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/deit/__init__.py
@@ -0,0 +1,31 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_deit import *
+ from .feature_extraction_deit import *
+ from .image_processing_deit import *
+ from .image_processing_deit_fast import *
+ from .modeling_deit import *
+ from .modeling_tf_deit import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/venv/lib/python3.13/site-packages/transformers/models/deit/configuration_deit.py b/venv/lib/python3.13/site-packages/transformers/models/deit/configuration_deit.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a321ebe293e191e7bbce29b528dfa2f6b00d141
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/deit/configuration_deit.py
@@ -0,0 +1,152 @@
+# coding=utf-8
+# Copyright 2021 Facebook AI Research (FAIR) and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""DeiT model configuration"""
+
+from collections import OrderedDict
+from collections.abc import Mapping
+
+from packaging import version
+
+from ...configuration_utils import PretrainedConfig
+from ...onnx import OnnxConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class DeiTConfig(PretrainedConfig):
+ r"""
+    This is the configuration class to store the configuration of a [`DeiTModel`]. It is used to instantiate a DeiT
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+ defaults will yield a similar configuration to that of the DeiT
+ [facebook/deit-base-distilled-patch16-224](https://huggingface.co/facebook/deit-base-distilled-patch16-224)
+ architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ hidden_size (`int`, *optional*, defaults to 768):
+ Dimensionality of the encoder layers and the pooler layer.
+ num_hidden_layers (`int`, *optional*, defaults to 12):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 12):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ intermediate_size (`int`, *optional*, defaults to 3072):
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"selu"` and `"gelu_new"` are supported.
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12):
+ The epsilon used by the layer normalization layers.
+ image_size (`int`, *optional*, defaults to 224):
+ The size (resolution) of each image.
+ patch_size (`int`, *optional*, defaults to 16):
+ The size (resolution) of each patch.
+ num_channels (`int`, *optional*, defaults to 3):
+ The number of input channels.
+ qkv_bias (`bool`, *optional*, defaults to `True`):
+ Whether to add a bias to the queries, keys and values.
+ encoder_stride (`int`, *optional*, defaults to 16):
+ Factor to increase the spatial resolution by in the decoder head for masked image modeling.
+ pooler_output_size (`int`, *optional*):
+ Dimensionality of the pooler layer. If None, defaults to `hidden_size`.
+ pooler_act (`str`, *optional*, defaults to `"tanh"`):
+ The activation function to be used by the pooler. Keys of ACT2FN are supported for Flax and
+            PyTorch, and elements of https://www.tensorflow.org/api_docs/python/tf/keras/activations are
+            supported for TensorFlow.
+
+ Example:
+
+ ```python
+ >>> from transformers import DeiTConfig, DeiTModel
+
+ >>> # Initializing a DeiT deit-base-distilled-patch16-224 style configuration
+ >>> configuration = DeiTConfig()
+
+ >>> # Initializing a model (with random weights) from the deit-base-distilled-patch16-224 style configuration
+ >>> model = DeiTModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "deit"
+
+ def __init__(
+ self,
+ hidden_size=768,
+ num_hidden_layers=12,
+ num_attention_heads=12,
+ intermediate_size=3072,
+ hidden_act="gelu",
+ hidden_dropout_prob=0.0,
+ attention_probs_dropout_prob=0.0,
+ initializer_range=0.02,
+ layer_norm_eps=1e-12,
+ image_size=224,
+ patch_size=16,
+ num_channels=3,
+ qkv_bias=True,
+ encoder_stride=16,
+ pooler_output_size=None,
+ pooler_act="tanh",
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.intermediate_size = intermediate_size
+ self.hidden_act = hidden_act
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.initializer_range = initializer_range
+ self.layer_norm_eps = layer_norm_eps
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.num_channels = num_channels
+ self.qkv_bias = qkv_bias
+ self.encoder_stride = encoder_stride
+ self.pooler_output_size = pooler_output_size if pooler_output_size else hidden_size
+ self.pooler_act = pooler_act
+
+
+class DeiTOnnxConfig(OnnxConfig):
+ torch_onnx_minimum_version = version.parse("1.11")
+
+ @property
+ def inputs(self) -> Mapping[str, Mapping[int, str]]:
+ return OrderedDict(
+ [
+ ("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}),
+ ]
+ )
+
+ @property
+ def atol_for_validation(self) -> float:
+ return 1e-4
+
+
+__all__ = ["DeiTConfig", "DeiTOnnxConfig"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/deit/feature_extraction_deit.py b/venv/lib/python3.13/site-packages/transformers/models/deit/feature_extraction_deit.py
new file mode 100644
index 0000000000000000000000000000000000000000..d040fd08395f8e921ec688228d7d5faa8963ab81
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/deit/feature_extraction_deit.py
@@ -0,0 +1,38 @@
+# coding=utf-8
+# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Feature extractor class for DeiT."""
+
+import warnings
+
+from ...utils import logging
+from ...utils.import_utils import requires
+from .image_processing_deit import DeiTImageProcessor
+
+
+logger = logging.get_logger(__name__)
+
+
+@requires(backends=("vision",))
+class DeiTFeatureExtractor(DeiTImageProcessor):
+ def __init__(self, *args, **kwargs) -> None:
+ warnings.warn(
+ "The class DeiTFeatureExtractor is deprecated and will be removed in version 5 of Transformers. Please"
+ " use DeiTImageProcessor instead.",
+ FutureWarning,
+ )
+ super().__init__(*args, **kwargs)
+
+
+__all__ = ["DeiTFeatureExtractor"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/deit/image_processing_deit.py b/venv/lib/python3.13/site-packages/transformers/models/deit/image_processing_deit.py
new file mode 100644
index 0000000000000000000000000000000000000000..1e2f6c3b5ae5f0f1cf2eb1727d2e3235443b81b9
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/deit/image_processing_deit.py
@@ -0,0 +1,301 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Image processor class for DeiT."""
+
+from typing import Optional, Union
+
+import numpy as np
+
+from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
+from ...image_transforms import resize, to_channel_dimension_format
+from ...image_utils import (
+ IMAGENET_STANDARD_MEAN,
+ IMAGENET_STANDARD_STD,
+ ChannelDimension,
+ ImageInput,
+ PILImageResampling,
+ infer_channel_dimension_format,
+ is_scaled_image,
+ make_flat_list_of_images,
+ to_numpy_array,
+ valid_images,
+ validate_preprocess_arguments,
+)
+from ...utils import TensorType, filter_out_non_signature_kwargs, is_vision_available, logging
+from ...utils.import_utils import requires
+
+
+if is_vision_available():
+ import PIL
+
+
+logger = logging.get_logger(__name__)
+
+
+@requires(backends=("vision",))
+class DeiTImageProcessor(BaseImageProcessor):
+ r"""
+ Constructs a DeiT image processor.
+
+ Args:
+ do_resize (`bool`, *optional*, defaults to `True`):
+ Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by
+ `do_resize` in `preprocess`.
+        size (`dict[str, int]`, *optional*, defaults to `{"height": 256, "width": 256}`):
+ Size of the image after `resize`. Can be overridden by `size` in `preprocess`.
+ resample (`PILImageResampling` filter, *optional*, defaults to `Resampling.BICUBIC`):
+ Resampling filter to use if resizing the image. Can be overridden by `resample` in `preprocess`.
+ do_center_crop (`bool`, *optional*, defaults to `True`):
+ Whether to center crop the image. If the input size is smaller than `crop_size` along any edge, the image
+ is padded with 0's and then center cropped. Can be overridden by `do_center_crop` in `preprocess`.
+ crop_size (`dict[str, int]`, *optional*, defaults to `{"height": 224, "width": 224}`):
+ Desired output size when applying center-cropping. Can be overridden by `crop_size` in `preprocess`.
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
+ Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the
+ `preprocess` method.
+ do_rescale (`bool`, *optional*, defaults to `True`):
+ Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale`
+ parameter in the `preprocess` method.
+ do_normalize (`bool`, *optional*, defaults to `True`):
+ Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
+ method.
+ image_mean (`float` or `list[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
+ Mean to use if normalizing the image. This is a float or list of floats the length of the number of
+ channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
+ image_std (`float` or `list[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
+ Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
+ number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
+ """
+
+ model_input_names = ["pixel_values"]
+
+ def __init__(
+ self,
+ do_resize: bool = True,
+ size: Optional[dict[str, int]] = None,
+ resample: PILImageResampling = PIL.Image.BICUBIC,
+ do_center_crop: bool = True,
+ crop_size: Optional[dict[str, int]] = None,
+ rescale_factor: Union[int, float] = 1 / 255,
+ do_rescale: bool = True,
+ do_normalize: bool = True,
+ image_mean: Optional[Union[float, list[float]]] = None,
+ image_std: Optional[Union[float, list[float]]] = None,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ size = size if size is not None else {"height": 256, "width": 256}
+ size = get_size_dict(size)
+ crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
+ crop_size = get_size_dict(crop_size, param_name="crop_size")
+
+ self.do_resize = do_resize
+ self.size = size
+ self.resample = resample
+ self.do_center_crop = do_center_crop
+ self.crop_size = crop_size
+ self.do_rescale = do_rescale
+ self.rescale_factor = rescale_factor
+ self.do_normalize = do_normalize
+ self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
+ self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
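+        # With the defaults above, preprocessing resizes to 256x256, center-crops to 224x224, rescales
+        # pixel values by 1/255 and normalizes with the ImageNet-standard mean and std.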
+
+ # Copied from transformers.models.vit.image_processing_vit.ViTImageProcessor.resize with PILImageResampling.BILINEAR->PILImageResampling.BICUBIC
+ def resize(
+ self,
+ image: np.ndarray,
+ size: dict[str, int],
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs,
+ ) -> np.ndarray:
+ """
+ Resize an image to `(size["height"], size["width"])`.
+
+ Args:
+ image (`np.ndarray`):
+ Image to resize.
+ size (`dict[str, int]`):
+ Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
+ `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BICUBIC`.
+ data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the output image. If unset, the channel dimension format of the input
+ image is used. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
+ from the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+
+ Returns:
+ `np.ndarray`: The resized image.
+ """
+ size = get_size_dict(size)
+ if "height" not in size or "width" not in size:
+ raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}")
+ output_size = (size["height"], size["width"])
+ return resize(
+ image,
+ size=output_size,
+ resample=resample,
+ data_format=data_format,
+ input_data_format=input_data_format,
+ **kwargs,
+ )
+
+ @filter_out_non_signature_kwargs()
+ def preprocess(
+ self,
+ images: ImageInput,
+ do_resize: Optional[bool] = None,
+ size: Optional[dict[str, int]] = None,
+ resample=None,
+ do_center_crop: Optional[bool] = None,
+ crop_size: Optional[dict[str, int]] = None,
+ do_rescale: Optional[bool] = None,
+ rescale_factor: Optional[float] = None,
+ do_normalize: Optional[bool] = None,
+ image_mean: Optional[Union[float, list[float]]] = None,
+ image_std: Optional[Union[float, list[float]]] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ data_format: ChannelDimension = ChannelDimension.FIRST,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ ) -> PIL.Image.Image:
+ """
+ Preprocess an image or batch of images.
+
+ Args:
+ images (`ImageInput`):
+ Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
+ passing in images with pixel values between 0 and 1, set `do_rescale=False`.
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+ Whether to resize the image.
+ size (`dict[str, int]`, *optional*, defaults to `self.size`):
+ Size of the image after `resize`.
+ resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
+ `PILImageResampling` filter to use when resizing the image. Only has an effect if `do_resize` is set to
+ `True`.
+ do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`):
+ Whether to center crop the image.
+ crop_size (`dict[str, int]`, *optional*, defaults to `self.crop_size`):
+ Size of the image after center crop. If one edge of the image is smaller than `crop_size`, it will be
+ padded with zeros and then cropped.
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+ Whether to rescale the image values to the [0, 1] range.
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
+ Whether to normalize the image.
+ image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`):
+ Image mean to use when normalizing the image. Only has an effect if `do_normalize` is set to `True`.
+ image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`):
+ Image standard deviation to use when normalizing the image. Only has an effect if `do_normalize` is set to `True`.
+ return_tensors (`str` or `TensorType`, *optional*):
+ The type of tensors to return. Can be one of:
+ - `None`: Return a list of `np.ndarray`.
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
+ The channel dimension format for the output image. Can be one of:
+ - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
+ from the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+ """
+ do_resize = do_resize if do_resize is not None else self.do_resize
+ resample = resample if resample is not None else self.resample
+ do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
+ image_mean = image_mean if image_mean is not None else self.image_mean
+ image_std = image_std if image_std is not None else self.image_std
+
+ size = size if size is not None else self.size
+ size = get_size_dict(size)
+ crop_size = crop_size if crop_size is not None else self.crop_size
+ crop_size = get_size_dict(crop_size, param_name="crop_size")
+
+ images = make_flat_list_of_images(images)
+
+ if not valid_images(images):
+ raise ValueError(
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
+ "torch.Tensor, tf.Tensor or jax.ndarray."
+ )
+ validate_preprocess_arguments(
+ do_rescale=do_rescale,
+ rescale_factor=rescale_factor,
+ do_normalize=do_normalize,
+ image_mean=image_mean,
+ image_std=image_std,
+ do_center_crop=do_center_crop,
+ crop_size=crop_size,
+ do_resize=do_resize,
+ size=size,
+ resample=resample,
+ )
+ # All transformations expect numpy arrays.
+ images = [to_numpy_array(image) for image in images]
+
+ if do_rescale and is_scaled_image(images[0]):
+ logger.warning_once(
+ "It looks like you are trying to rescale already rescaled images. If the input"
+ " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
+ )
+
+ if input_data_format is None:
+ # We assume that all images have the same channel dimension format.
+ input_data_format = infer_channel_dimension_format(images[0])
+
+ all_images = []
+ for image in images:
+ if do_resize:
+ image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
+
+ if do_center_crop:
+ image = self.center_crop(image=image, size=crop_size, input_data_format=input_data_format)
+
+ if do_rescale:
+ image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
+
+ if do_normalize:
+ image = self.normalize(
+ image=image, mean=image_mean, std=image_std, input_data_format=input_data_format
+ )
+
+ all_images.append(image)
+ images = [
+ to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
+ for image in all_images
+ ]
+
+ data = {"pixel_values": images}
+ return BatchFeature(data=data, tensor_type=return_tensors)
+
+
+__all__ = ["DeiTImageProcessor"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/deit/image_processing_deit_fast.py b/venv/lib/python3.13/site-packages/transformers/models/deit/image_processing_deit_fast.py
new file mode 100644
index 0000000000000000000000000000000000000000..3aafeaf50c09455cffeecb3776eb3598c8ceccf2
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/deit/image_processing_deit_fast.py
@@ -0,0 +1,41 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Fast Image processor class for DeiT."""
+
+from ...image_processing_utils_fast import BaseImageProcessorFast
+from ...image_utils import (
+ IMAGENET_STANDARD_MEAN,
+ IMAGENET_STANDARD_STD,
+ PILImageResampling,
+)
+from ...utils import auto_docstring
+
+
+@auto_docstring
+class DeiTImageProcessorFast(BaseImageProcessorFast):
+ # To be checked against the slow image processor
+ # None values left after checking can be removed
+ resample = PILImageResampling.BICUBIC
+ image_mean = IMAGENET_STANDARD_MEAN
+ image_std = IMAGENET_STANDARD_STD
+ size = {"height": 256, "width": 256}
+ crop_size = {"height": 224, "width": 224}
+ do_resize = True
+ do_center_crop = True
+ do_rescale = True
+ do_normalize = True
+
+
+__all__ = ["DeiTImageProcessorFast"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/deit/modeling_deit.py b/venv/lib/python3.13/site-packages/transformers/models/deit/modeling_deit.py
new file mode 100644
index 0000000000000000000000000000000000000000..ddb03c053f1ee08f011e650daad794821205ff33
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/deit/modeling_deit.py
@@ -0,0 +1,791 @@
+# coding=utf-8
+# Copyright 2021 Facebook AI Research (FAIR), Ross Wightman, The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch DeiT model."""
+
+import collections.abc
+from dataclasses import dataclass
+from typing import Callable, Optional, Union
+
+import torch
+from torch import nn
+
+from ...activations import ACT2FN
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import (
+ BaseModelOutput,
+ BaseModelOutputWithPooling,
+ ImageClassifierOutput,
+ MaskedImageModelingOutput,
+)
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
+from ...utils import ModelOutput, TransformersKwargs, auto_docstring, logging, torch_int
+from ...utils.generic import can_return_tuple, check_model_inputs
+from .configuration_deit import DeiTConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+class DeiTEmbeddings(nn.Module):
+ """
+ Construct the CLS token, distillation token, position and patch embeddings. Optionally, also the mask token.
+ """
+
+ def __init__(self, config: DeiTConfig, use_mask_token: bool = False) -> None:
+ super().__init__()
+
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
+ self.distillation_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
+ self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None
+ self.patch_embeddings = DeiTPatchEmbeddings(config)
+ num_patches = self.patch_embeddings.num_patches
+ self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 2, config.hidden_size))
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+ self.patch_size = config.patch_size
+
+ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
+ """
+ This method interpolates the pre-trained position encodings so that the model can be used on higher-resolution
+ images. It is also adapted to support torch.jit tracing and the two leading class embeddings (CLS and distillation).
+
+ Adapted from:
+ - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
+ - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
+ """
+
+ num_patches = embeddings.shape[1] - 2
+ num_positions = self.position_embeddings.shape[1] - 2
+
+ # always interpolate when tracing to ensure the exported model works for dynamic input shapes
+ if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
+ return self.position_embeddings
+
+ class_and_dist_pos_embed = self.position_embeddings[:, :2]
+ patch_pos_embed = self.position_embeddings[:, 2:]
+
+ dim = embeddings.shape[-1]
+
+ new_height = height // self.patch_size
+ new_width = width // self.patch_size
+
+ sqrt_num_positions = torch_int(num_positions**0.5)
+ patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
+ patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
+
+ patch_pos_embed = nn.functional.interpolate(
+ patch_pos_embed,
+ size=(new_height, new_width),
+ mode="bicubic",
+ align_corners=False,
+ )
+
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+
+ return torch.cat((class_and_dist_pos_embed, patch_pos_embed), dim=1)
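+
+ # Worked example (comment only): a checkpoint pre-trained at 224x224 with patch_size=16 stores a
+ # 14x14 grid of patch position embeddings plus the CLS and distillation slots. For a 384x384 input,
+ # the grid is bicubically resized to 24x24, so the returned tensor holds 2 + 576 = 578 position
+ # vectors instead of the original 2 + 196 = 198.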
+
+ def forward(
+ self,
+ pixel_values: torch.Tensor,
+ bool_masked_pos: Optional[torch.BoolTensor] = None,
+ interpolate_pos_encoding: bool = False,
+ ) -> torch.Tensor:
+ _, _, height, width = pixel_values.shape
+ embeddings = self.patch_embeddings(pixel_values)
+
+ batch_size, seq_length, _ = embeddings.size()
+
+ if bool_masked_pos is not None:
+ mask_tokens = self.mask_token.expand(batch_size, seq_length, -1)
+ # replace the masked visual tokens by mask_tokens
+ mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
+ embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
+
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1)
+
+ distillation_tokens = self.distillation_token.expand(batch_size, -1, -1)
+
+ embeddings = torch.cat((cls_tokens, distillation_tokens, embeddings), dim=1)
+ position_embedding = self.position_embeddings
+
+ if interpolate_pos_encoding:
+ position_embedding = self.interpolate_pos_encoding(embeddings, height, width)
+
+ embeddings = embeddings + position_embedding
+ embeddings = self.dropout(embeddings)
+ return embeddings
+
+
+class DeiTPatchEmbeddings(nn.Module):
+ """
+ This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
+ `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
+ Transformer.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ image_size, patch_size = config.image_size, config.patch_size
+ num_channels, hidden_size = config.num_channels, config.hidden_size
+
+ image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
+ patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
+ num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.num_channels = num_channels
+ self.num_patches = num_patches
+
+ self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
+
+ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
+ batch_size, num_channels, height, width = pixel_values.shape
+ if num_channels != self.num_channels:
+ raise ValueError(
+ "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
+ )
+ x = self.projection(pixel_values).flatten(2).transpose(1, 2)
+ return x
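+
+ # Shape sketch (comment only): with the default 224x224 images and 16x16 patches, the Conv2d
+ # projection produces a (batch_size, hidden_size, 14, 14) feature map, and flatten(2).transpose(1, 2)
+ # turns it into (batch_size, 196, hidden_size) patch tokens.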
+
+
+# Copied from transformers.models.vit.modeling_vit.eager_attention_forward
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: float,
+ dropout: float = 0.0,
+ **kwargs,
+):
+ # Take the dot product between "query" and "key" to get the raw attention scores.
+ attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
+
+ # Normalize the attention scores to probabilities.
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+
+ # Mask heads if we want to
+ if attention_mask is not None:
+ attn_weights = attn_weights * attention_mask
+
+ attn_output = torch.matmul(attn_weights, value)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
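+
+# Note (comment only): `scaling` is passed as attention_head_size**-0.5 by DeiTSelfAttention below,
+# so this is standard scaled dot-product attention; the final transpose returns the context in
+# (batch_size, seq_len, num_heads, head_size) layout, which the caller reshapes back to hidden_size.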
+
+
+# Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->DeiT
+class DeiTSelfAttention(nn.Module):
+ def __init__(self, config: DeiTConfig):
+ super().__init__()
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
+ raise ValueError(
+ f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
+ f"heads {config.num_attention_heads}."
+ )
+
+ self.config = config
+ self.num_attention_heads = config.num_attention_heads
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
+ self.dropout_prob = config.attention_probs_dropout_prob
+ self.scaling = self.attention_head_size**-0.5
+ self.is_causal = False
+
+ self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
+ self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
+ self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
+
+ def forward(
+ self, hidden_states: torch.Tensor, head_mask: Optional[torch.Tensor] = None
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ batch_size = hidden_states.shape[0]
+ new_shape = batch_size, -1, self.num_attention_heads, self.attention_head_size
+
+ key_layer = self.key(hidden_states).view(*new_shape).transpose(1, 2)
+ value_layer = self.value(hidden_states).view(*new_shape).transpose(1, 2)
+ query_layer = self.query(hidden_states).view(*new_shape).transpose(1, 2)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ context_layer, attention_probs = attention_interface(
+ self,
+ query_layer,
+ key_layer,
+ value_layer,
+ head_mask,
+ is_causal=self.is_causal,
+ scaling=self.scaling,
+ dropout=0.0 if not self.training else self.dropout_prob,
+ )
+
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+ context_layer = context_layer.reshape(new_context_layer_shape)
+
+ return context_layer, attention_probs
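+
+ # Note (comment only): when `config._attn_implementation` is not "eager", the matching backend from
+ # ALL_ATTENTION_FUNCTIONS (e.g. "sdpa" or "flash_attention_2") replaces the eager kernel above; in
+ # either case `head_mask` is forwarded as the backend's attention-mask argument.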
+
+
+# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->DeiT
+class DeiTSelfOutput(nn.Module):
+ """
+ The residual connection is defined in DeiTLayer instead of here (as is the case with other models), due to the
+ layernorm applied before each block.
+ """
+
+ def __init__(self, config: DeiTConfig):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ return hidden_states
+
+
+# Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->DeiT
+class DeiTAttention(nn.Module):
+ def __init__(self, config: DeiTConfig):
+ super().__init__()
+ self.attention = DeiTSelfAttention(config)
+ self.output = DeiTSelfOutput(config)
+ self.pruned_heads = set()
+
+ def prune_heads(self, heads: set[int]):
+ if len(heads) == 0:
+ return
+ heads, index = find_pruneable_heads_and_indices(
+ heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
+ )
+
+ # Prune linear layers
+ self.attention.query = prune_linear_layer(self.attention.query, index)
+ self.attention.key = prune_linear_layer(self.attention.key, index)
+ self.attention.value = prune_linear_layer(self.attention.value, index)
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
+
+ # Update hyper params and store pruned heads
+ self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
+ self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
+ self.pruned_heads = self.pruned_heads.union(heads)
+
+ def forward(self, hidden_states: torch.Tensor, head_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+ self_attn_output, _ = self.attention(hidden_states, head_mask)
+ output = self.output(self_attn_output, hidden_states)
+ return output
+
+
+# Copied from transformers.models.vit.modeling_vit.ViTIntermediate with ViT->DeiT
+class DeiTIntermediate(nn.Module):
+ def __init__(self, config: DeiTConfig):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
+ if isinstance(config.hidden_act, str):
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.intermediate_act_fn = config.hidden_act
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.intermediate_act_fn(hidden_states)
+ return hidden_states
+
+
+# Copied from transformers.models.vit.modeling_vit.ViTOutput with ViT->DeiT
+class DeiTOutput(nn.Module):
+ def __init__(self, config: DeiTConfig):
+ super().__init__()
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = hidden_states + input_tensor
+ return hidden_states
+
+
+# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->DeiT,VIT->DEIT
+class DeiTLayer(GradientCheckpointingLayer):
+ """This corresponds to the Block class in the timm implementation."""
+
+ def __init__(self, config: DeiTConfig):
+ super().__init__()
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
+ self.seq_len_dim = 1
+ self.attention = DeiTAttention(config)
+ self.intermediate = DeiTIntermediate(config)
+ self.output = DeiTOutput(config)
+ self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+ def forward(self, hidden_states: torch.Tensor, head_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+ hidden_states_norm = self.layernorm_before(hidden_states)
+ attention_output = self.attention(hidden_states_norm, head_mask)
+
+ # first residual connection
+ hidden_states = attention_output + hidden_states
+
+ # in DeiT, layernorm is also applied after self-attention
+ layer_output = self.layernorm_after(hidden_states)
+ layer_output = self.intermediate(layer_output)
+
+ # second residual connection is done here
+ layer_output = self.output(layer_output, hidden_states)
+
+ return layer_output
+
+
+# Copied from transformers.models.vit.modeling_vit.ViTEncoder with ViT->DeiT
+class DeiTEncoder(nn.Module):
+ def __init__(self, config: DeiTConfig):
+ super().__init__()
+ self.config = config
+ self.layer = nn.ModuleList([DeiTLayer(config) for _ in range(config.num_hidden_layers)])
+ self.gradient_checkpointing = False
+
+ def forward(self, hidden_states: torch.Tensor, head_mask: Optional[torch.Tensor] = None) -> BaseModelOutput:
+ for i, layer_module in enumerate(self.layer):
+ layer_head_mask = head_mask[i] if head_mask is not None else None
+ hidden_states = layer_module(hidden_states, layer_head_mask)
+
+ return BaseModelOutput(last_hidden_state=hidden_states)
+
+
+@auto_docstring
+class DeiTPreTrainedModel(PreTrainedModel):
+ config: DeiTConfig
+ base_model_prefix = "deit"
+ main_input_name = "pixel_values"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["DeiTLayer"]
+ _supports_sdpa = True
+ _supports_flash_attn = True
+ _supports_flex_attn = True
+ _supports_attention_backend = True
+ _can_record_outputs = {
+ "hidden_states": DeiTLayer,
+ "attentions": DeiTSelfAttention,
+ }
+
+ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
+ """Initialize the weights"""
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
+ # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
+ # `trunc_normal_cpu` not implemented in `half` issues
+ module.weight.data = nn.init.trunc_normal_(
+ module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range
+ ).to(module.weight.dtype)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+ elif isinstance(module, DeiTEmbeddings):
+ module.cls_token.data.zero_()
+ module.position_embeddings.data.zero_()
+ module.distillation_token.data.zero_()
+ if module.mask_token is not None:
+ module.mask_token.data.zero_()
+
+
+@auto_docstring
+class DeiTModel(DeiTPreTrainedModel):
+ def __init__(self, config: DeiTConfig, add_pooling_layer: bool = True, use_mask_token: bool = False) -> None:
+ r"""
+ add_pooling_layer (`bool`, *optional*, defaults to `True`):
+ Whether to add a pooling layer.
+ use_mask_token (`bool`, *optional*, defaults to `False`):
+ Whether to use a mask token for masked image modeling.
+ """
+ super().__init__(config)
+ self.config = config
+
+ self.embeddings = DeiTEmbeddings(config, use_mask_token=use_mask_token)
+ self.encoder = DeiTEncoder(config)
+
+ self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.pooler = DeiTPooler(config) if add_pooling_layer else None
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self) -> DeiTPatchEmbeddings:
+ return self.embeddings.patch_embeddings
+
+ def _prune_heads(self, heads_to_prune):
+ """
+ Prunes heads of the model. `heads_to_prune` is a dict of {layer_num: list of heads to prune in this layer}.
+ See the base class `PreTrainedModel`.
+ """
+ for layer, heads in heads_to_prune.items():
+ self.encoder.layer[layer].attention.prune_heads(heads)
+
+ @check_model_inputs(tie_last_hidden_states=False)
+ @auto_docstring
+ def forward(
+ self,
+ pixel_values: Optional[torch.Tensor] = None,
+ bool_masked_pos: Optional[torch.BoolTensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ interpolate_pos_encoding: bool = False,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> BaseModelOutputWithPooling:
+ r"""
+ bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
+ Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
+ """
+
+ if pixel_values is None:
+ raise ValueError("You have to specify pixel_values")
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+ # TODO: maybe have a cleaner way to cast the input (from `ImageProcessor` side?)
+ expected_dtype = self.embeddings.patch_embeddings.projection.weight.dtype
+ if pixel_values.dtype != expected_dtype:
+ pixel_values = pixel_values.to(expected_dtype)
+
+ embedding_output = self.embeddings(
+ pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
+ )
+
+ encoder_outputs: BaseModelOutput = self.encoder(embedding_output, head_mask=head_mask)
+ sequence_output = encoder_outputs.last_hidden_state
+ sequence_output = self.layernorm(sequence_output)
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
+
+ return BaseModelOutputWithPooling(
+ last_hidden_state=sequence_output,
+ pooler_output=pooled_output,
+ )
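+
+ # Usage sketch (comment only): inputs larger than the pre-training resolution can be handled by
+ # turning on position-embedding interpolation, e.g.
+ #   outputs = model(larger_pixel_values, interpolate_pos_encoding=True)
+ # where `larger_pixel_values` is assumed here to be, say, a (1, 3, 384, 384) float tensor.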
+
+
+# Copied from transformers.models.vit.modeling_vit.ViTPooler with ViT->DeiT
+class DeiTPooler(nn.Module):
+ def __init__(self, config: DeiTConfig):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.pooler_output_size)
+ self.activation = ACT2FN[config.pooler_act]
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ # We "pool" the model by simply taking the hidden state corresponding
+ # to the first token.
+ first_token_tensor = hidden_states[:, 0]
+ pooled_output = self.dense(first_token_tensor)
+ pooled_output = self.activation(pooled_output)
+ return pooled_output
+
+
+@auto_docstring(
+ custom_intro="""
+ DeiT Model with a decoder on top for masked image modeling, as proposed in [SimMIM](https://huggingface.co/papers/2111.09886).
+
+ Note that we provide a script to pre-train this model on custom data in our [examples
+ directory](https://github.com/huggingface/transformers/tree/main/examples/pytorch/image-pretraining).
+
+ """
+)
+class DeiTForMaskedImageModeling(DeiTPreTrainedModel):
+ def __init__(self, config: DeiTConfig) -> None:
+ super().__init__(config)
+
+ self.deit = DeiTModel(config, add_pooling_layer=False, use_mask_token=True)
+
+ self.decoder = nn.Sequential(
+ nn.Conv2d(
+ in_channels=config.hidden_size,
+ out_channels=config.encoder_stride**2 * config.num_channels,
+ kernel_size=1,
+ ),
+ nn.PixelShuffle(config.encoder_stride),
+ )
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ pixel_values: Optional[torch.Tensor] = None,
+ bool_masked_pos: Optional[torch.BoolTensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ interpolate_pos_encoding: bool = False,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> MaskedImageModelingOutput:
+ r"""
+ bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
+ Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
+
+ Examples:
+ ```python
+ >>> from transformers import AutoImageProcessor, DeiTForMaskedImageModeling
+ >>> import torch
+ >>> from PIL import Image
+ >>> import requests
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> image_processor = AutoImageProcessor.from_pretrained("facebook/deit-base-distilled-patch16-224")
+ >>> model = DeiTForMaskedImageModeling.from_pretrained("facebook/deit-base-distilled-patch16-224")
+
+ >>> num_patches = (model.config.image_size // model.config.patch_size) ** 2
+ >>> pixel_values = image_processor(images=image, return_tensors="pt").pixel_values
+ >>> # create random boolean mask of shape (batch_size, num_patches)
+ >>> bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool()
+
+ >>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)
+ >>> loss, reconstructed_pixel_values = outputs.loss, outputs.reconstruction
+ >>> list(reconstructed_pixel_values.shape)
+ [1, 3, 224, 224]
+ ```"""
+
+ outputs: BaseModelOutputWithPooling = self.deit(
+ pixel_values,
+ bool_masked_pos=bool_masked_pos,
+ head_mask=head_mask,
+ interpolate_pos_encoding=interpolate_pos_encoding,
+ **kwargs,
+ )
+
+ sequence_output = outputs.last_hidden_state
+
+ # Reshape to (batch_size, num_channels, height, width)
+ sequence_output = sequence_output[:, 1:-1]
+ batch_size, sequence_length, num_channels = sequence_output.shape
+ height = width = int(sequence_length**0.5)
+ sequence_output = sequence_output.permute(0, 2, 1).reshape(batch_size, num_channels, height, width)
+
+ # Reconstruct pixel values
+ reconstructed_pixel_values = self.decoder(sequence_output)
+
+ masked_im_loss = None
+ if bool_masked_pos is not None:
+ size = self.config.image_size // self.config.patch_size
+ bool_masked_pos = bool_masked_pos.reshape(-1, size, size)
+ mask = (
+ bool_masked_pos.repeat_interleave(self.config.patch_size, 1)
+ .repeat_interleave(self.config.patch_size, 2)
+ .unsqueeze(1)
+ .contiguous()
+ )
+ reconstruction_loss = nn.functional.l1_loss(pixel_values, reconstructed_pixel_values, reduction="none")
+ masked_im_loss = (reconstruction_loss * mask).sum() / (mask.sum() + 1e-5) / self.config.num_channels
+
+ return MaskedImageModelingOutput(
+ loss=masked_im_loss,
+ reconstruction=reconstructed_pixel_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
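+
+ # Worked example (comment only): with the default config (image_size=224, patch_size=16, num_channels=3),
+ # the (batch_size, 196) boolean mask is reshaped to 14x14 and repeat_interleave'd by 16 along both spatial
+ # axes, giving a (batch_size, 1, 224, 224) pixel mask; the per-pixel L1 loss is then averaged over the
+ # masked pixels only (with a 1e-5 epsilon) and divided by the number of channels.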
+
+
+@auto_docstring(
+ custom_intro="""
+ DeiT Model transformer with an image classification head on top (a linear layer on top of the final hidden state of
+ the [CLS] token) e.g. for ImageNet.
+ """
+)
+class DeiTForImageClassification(DeiTPreTrainedModel):
+ def __init__(self, config: DeiTConfig) -> None:
+ super().__init__(config)
+
+ self.num_labels = config.num_labels
+ self.deit = DeiTModel(config, add_pooling_layer=False)
+
+ # Classifier head
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ pixel_values: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ interpolate_pos_encoding: bool = False,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> ImageClassifierOutput:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+
+ Examples:
+
+ ```python
+ >>> from transformers import AutoImageProcessor, DeiTForImageClassification
+ >>> import torch
+ >>> from PIL import Image
+ >>> import requests
+
+ >>> torch.manual_seed(3) # doctest: +IGNORE_RESULT
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> # note: we are loading a DeiTForImageClassificationWithTeacher from the hub here,
+ >>> # so the head will be randomly initialized, hence the predictions will be random
+ >>> image_processor = AutoImageProcessor.from_pretrained("facebook/deit-base-distilled-patch16-224")
+ >>> model = DeiTForImageClassification.from_pretrained("facebook/deit-base-distilled-patch16-224")
+
+ >>> inputs = image_processor(images=image, return_tensors="pt")
+ >>> outputs = model(**inputs)
+ >>> logits = outputs.logits
+ >>> # model predicts one of the 1000 ImageNet classes
+ >>> predicted_class_idx = logits.argmax(-1).item()
+ >>> print("Predicted class:", model.config.id2label[predicted_class_idx])
+ Predicted class: Polaroid camera, Polaroid Land camera
+ ```"""
+
+ outputs: BaseModelOutputWithPooling = self.deit(
+ pixel_values,
+ head_mask=head_mask,
+ interpolate_pos_encoding=interpolate_pos_encoding,
+ **kwargs,
+ )
+
+ sequence_output = outputs.last_hidden_state
+
+ logits = self.classifier(sequence_output[:, 0, :])
+ # we don't use the distillation token
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(labels, logits, self.config, **kwargs)
+
+ return ImageClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Output type of [`DeiTForImageClassificationWithTeacher`].
+ """
+)
+class DeiTForImageClassificationWithTeacherOutput(ModelOutput):
+ r"""
+ logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
+ Prediction scores as the average of the cls_logits and distillation logits.
+ cls_logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
+ Prediction scores of the classification head (i.e. the linear layer on top of the final hidden state of the
+ class token).
+ distillation_logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
+ Prediction scores of the distillation head (i.e. the linear layer on top of the final hidden state of the
+ distillation token).
+ """
+
+ logits: Optional[torch.FloatTensor] = None
+ cls_logits: Optional[torch.FloatTensor] = None
+ distillation_logits: Optional[torch.FloatTensor] = None
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
+ attentions: Optional[tuple[torch.FloatTensor]] = None
+
+
+@auto_docstring(
+ custom_intro="""
+ DeiT Model transformer with image classification heads on top (a linear layer on top of the final hidden state of
+ the [CLS] token and a linear layer on top of the final hidden state of the distillation token) e.g. for ImageNet.
+
+ .. warning::
+
+ This model supports inference only. Fine-tuning with distillation (i.e. with a teacher) is not yet
+ supported.
+ """
+)
+class DeiTForImageClassificationWithTeacher(DeiTPreTrainedModel):
+ def __init__(self, config: DeiTConfig) -> None:
+ super().__init__(config)
+
+ self.num_labels = config.num_labels
+ self.deit = DeiTModel(config, add_pooling_layer=False)
+
+ # Classifier heads
+ self.cls_classifier = (
+ nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()
+ )
+ self.distillation_classifier = (
+ nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()
+ )
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ pixel_values: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ interpolate_pos_encoding: bool = False,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> DeiTForImageClassificationWithTeacherOutput:
+ outputs: BaseModelOutputWithPooling = self.deit(
+ pixel_values,
+ head_mask=head_mask,
+ interpolate_pos_encoding=interpolate_pos_encoding,
+ **kwargs,
+ )
+
+ sequence_output = outputs.last_hidden_state
+
+ cls_logits = self.cls_classifier(sequence_output[:, 0, :])
+ distillation_logits = self.distillation_classifier(sequence_output[:, 1, :])
+
+ # during inference, return the average of both classifier predictions
+ logits = (cls_logits + distillation_logits) / 2
+
+ return DeiTForImageClassificationWithTeacherOutput(
+ logits=logits,
+ cls_logits=cls_logits,
+ distillation_logits=distillation_logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
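+# Usage sketch (comment only), mirroring the doctest examples above: the hub checkpoint referenced there
+# was trained with a teacher, so both classification heads come pre-trained when loaded into this class.
+#   from transformers import AutoImageProcessor, DeiTForImageClassificationWithTeacher
+#   processor = AutoImageProcessor.from_pretrained("facebook/deit-base-distilled-patch16-224")
+#   model = DeiTForImageClassificationWithTeacher.from_pretrained("facebook/deit-base-distilled-patch16-224")
+#   outputs = model(**processor(images=image, return_tensors="pt"))
+#   predicted_class = model.config.id2label[outputs.logits.argmax(-1).item()]
+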
+__all__ = [
+ "DeiTForImageClassification",
+ "DeiTForImageClassificationWithTeacher",
+ "DeiTForMaskedImageModeling",
+ "DeiTModel",
+ "DeiTPreTrainedModel",
+]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/deit/modeling_tf_deit.py b/venv/lib/python3.13/site-packages/transformers/models/deit/modeling_tf_deit.py
new file mode 100644
index 0000000000000000000000000000000000000000..3c56eee87911edc445641e0bbc14f094e1c5efa7
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/deit/modeling_tf_deit.py
@@ -0,0 +1,1232 @@
+# coding=utf-8
+# Copyright 2022 Facebook AI Research (FAIR) and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""TensorFlow DeiT model."""
+
+from __future__ import annotations
+
+import collections.abc
+import math
+from dataclasses import dataclass
+
+import tensorflow as tf
+
+from ...activations_tf import get_tf_activation
+from ...modeling_tf_outputs import (
+ TFBaseModelOutput,
+ TFBaseModelOutputWithPooling,
+ TFImageClassifierOutput,
+ TFMaskedImageModelingOutput,
+)
+from ...modeling_tf_utils import (
+ TFPreTrainedModel,
+ TFSequenceClassificationLoss,
+ get_initializer,
+ keras,
+ keras_serializable,
+ unpack_inputs,
+)
+from ...tf_utils import shape_list, stable_softmax
+from ...utils import (
+ ModelOutput,
+ add_code_sample_docstrings,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+ replace_return_docstrings,
+)
+from .configuration_deit import DeiTConfig
+
+
+logger = logging.get_logger(__name__)
+
+# General docstring
+_CONFIG_FOR_DOC = "DeiTConfig"
+
+# Base docstring
+_CHECKPOINT_FOR_DOC = "facebook/deit-base-distilled-patch16-224"
+_EXPECTED_OUTPUT_SHAPE = [1, 198, 768]
+
+# Image classification docstring
+_IMAGE_CLASS_CHECKPOINT = "facebook/deit-base-distilled-patch16-224"
+_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat"
+
+
+@dataclass
+class TFDeiTForImageClassificationWithTeacherOutput(ModelOutput):
+ """
+ Output type of [`DeiTForImageClassificationWithTeacher`].
+
+ Args:
+ logits (`tf.Tensor` of shape `(batch_size, config.num_labels)`):
+ Prediction scores as the average of the cls_logits and distillation logits.
+ cls_logits (`tf.Tensor` of shape `(batch_size, config.num_labels)`):
+ Prediction scores of the classification head (i.e. the linear layer on top of the final hidden state of the
+ class token).
+ distillation_logits (`tf.Tensor` of shape `(batch_size, config.num_labels)`):
+ Prediction scores of the distillation head (i.e. the linear layer on top of the final hidden state of the
+ distillation token).
+ hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
+ `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus
+ the initial embedding outputs.
+ attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
+ the self-attention heads.
+ """
+
+ logits: tf.Tensor | None = None
+ cls_logits: tf.Tensor | None = None
+ distillation_logits: tf.Tensor | None = None
+ hidden_states: tuple[tf.Tensor] | None = None
+ attentions: tuple[tf.Tensor] | None = None
+
+
+class TFDeiTEmbeddings(keras.layers.Layer):
+ """
+ Construct the CLS token, distillation token, position and patch embeddings. Optionally, also the mask token.
+ """
+
+ def __init__(self, config: DeiTConfig, use_mask_token: bool = False, **kwargs) -> None:
+ super().__init__(**kwargs)
+ self.config = config
+ self.use_mask_token = use_mask_token
+ self.patch_embeddings = TFDeiTPatchEmbeddings(config=config, name="patch_embeddings")
+ self.dropout = keras.layers.Dropout(config.hidden_dropout_prob, name="dropout")
+
+ def build(self, input_shape=None):
+ self.cls_token = self.add_weight(
+ shape=(1, 1, self.config.hidden_size),
+ initializer=keras.initializers.zeros(),
+ trainable=True,
+ name="cls_token",
+ )
+ self.distillation_token = self.add_weight(
+ shape=(1, 1, self.config.hidden_size),
+ initializer=keras.initializers.zeros(),
+ trainable=True,
+ name="distillation_token",
+ )
+ self.mask_token = None
+ if self.use_mask_token:
+ self.mask_token = self.add_weight(
+ shape=(1, 1, self.config.hidden_size),
+ initializer=keras.initializers.zeros(),
+ trainable=True,
+ name="mask_token",
+ )
+ num_patches = self.patch_embeddings.num_patches
+ self.position_embeddings = self.add_weight(
+ shape=(1, num_patches + 2, self.config.hidden_size),
+ initializer=keras.initializers.zeros(),
+ trainable=True,
+ name="position_embeddings",
+ )
+
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "patch_embeddings", None) is not None:
+ with tf.name_scope(self.patch_embeddings.name):
+ self.patch_embeddings.build(None)
+ if getattr(self, "dropout", None) is not None:
+ with tf.name_scope(self.dropout.name):
+ self.dropout.build(None)
+
+ def interpolate_pos_encoding(self, embeddings: tf.Tensor, height: int, width: int) -> tf.Tensor:
+ num_patches = embeddings.shape[1] - 2
+ num_positions = self.position_embeddings.shape[1] - 2
+
+ if num_patches == num_positions and height == width:
+ return self.position_embeddings
+
+ class_pos_embed = self.position_embeddings[:, 0, :]
+ dist_pos_embed = self.position_embeddings[:, 1, :]
+ patch_pos_embed = self.position_embeddings[:, 2:, :]
+ dim = embeddings.shape[-1]
+ h0 = height // self.config.patch_size
+ w0 = width // self.config.patch_size
+ # we add a small number to avoid floating point error in the interpolation
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
+ h0, w0 = h0 + 0.1, w0 + 0.1
+ patch_pos_embed = tf.reshape(
+ patch_pos_embed, (1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
+ )
+ patch_pos_embed = tf.image.resize(patch_pos_embed, size=(int(h0), int(w0)), method="bicubic")
+ patch_pos_embed = tf.transpose(patch_pos_embed, perm=[0, 2, 3, 1])
+ patch_pos_embed = tf.reshape(patch_pos_embed, (1, -1, dim))
+
+ return tf.concat(
+ [tf.expand_dims(class_pos_embed, axis=0), tf.expand_dims(dist_pos_embed, axis=0), patch_pos_embed], axis=1
+ )
+
+ def call(
+ self,
+ pixel_values: tf.Tensor,
+ bool_masked_pos: tf.Tensor | None = None,
+ training: bool = False,
+ interpolate_pos_encoding: bool = False,
+ ) -> tf.Tensor:
+ _, height, width, _ = pixel_values.shape
+
+ embeddings = self.patch_embeddings(pixel_values)
+ batch_size, seq_length, _ = shape_list(embeddings)
+
+ if bool_masked_pos is not None:
+ mask_tokens = tf.tile(self.mask_token, [batch_size, seq_length, 1])
+ # replace the masked visual tokens by mask_tokens
+ mask = tf.expand_dims(bool_masked_pos, axis=-1)
+ mask = tf.cast(mask, dtype=mask_tokens.dtype)
+ embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
+
+ cls_tokens = tf.repeat(self.cls_token, repeats=batch_size, axis=0)
+ distillation_tokens = tf.repeat(self.distillation_token, repeats=batch_size, axis=0)
+ embeddings = tf.concat((cls_tokens, distillation_tokens, embeddings), axis=1)
+ position_embedding = self.position_embeddings
+ if interpolate_pos_encoding:
+ position_embedding = self.interpolate_pos_encoding(embeddings, height, width)
+
+ embeddings = embeddings + position_embedding
+ embeddings = self.dropout(embeddings, training=training)
+ return embeddings
+
+
+class TFDeiTPatchEmbeddings(keras.layers.Layer):
+ """
+ This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
+ `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
+ Transformer.
+ """
+
+ def __init__(self, config: DeiTConfig, **kwargs) -> None:
+ super().__init__(**kwargs)
+ image_size, patch_size = config.image_size, config.patch_size
+ num_channels, hidden_size = config.num_channels, config.hidden_size
+
+ image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
+ patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
+ num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.num_channels = num_channels
+ self.num_patches = num_patches
+
+ self.projection = keras.layers.Conv2D(
+ hidden_size, kernel_size=patch_size, strides=patch_size, name="projection"
+ )
+
+ def call(self, pixel_values: tf.Tensor) -> tf.Tensor:
+ batch_size, height, width, num_channels = shape_list(pixel_values)
+ if tf.executing_eagerly() and num_channels != self.num_channels:
+ raise ValueError(
+ "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
+ )
+
+ x = self.projection(pixel_values)
+ batch_size, height, width, num_channels = shape_list(x)
+ x = tf.reshape(x, (batch_size, height * width, num_channels))
+ return x
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "projection", None) is not None:
+ with tf.name_scope(self.projection.name):
+ self.projection.build([None, None, None, self.num_channels])
+
+
+# Copied from transformers.models.vit.modeling_tf_vit.TFViTSelfAttention with ViT->DeiT
+class TFDeiTSelfAttention(keras.layers.Layer):
+ def __init__(self, config: DeiTConfig, **kwargs):
+ super().__init__(**kwargs)
+
+ if config.hidden_size % config.num_attention_heads != 0:
+ raise ValueError(
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number "
+ f"of attention heads ({config.num_attention_heads})"
+ )
+
+ self.num_attention_heads = config.num_attention_heads
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
+ self.sqrt_att_head_size = math.sqrt(self.attention_head_size)
+
+ self.query = keras.layers.Dense(
+ units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query"
+ )
+ self.key = keras.layers.Dense(
+ units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key"
+ )
+ self.value = keras.layers.Dense(
+ units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value"
+ )
+ self.dropout = keras.layers.Dropout(rate=config.attention_probs_dropout_prob)
+ self.config = config
+
+ def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor:
+ # Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size]
+ tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size))
+
+ # Transpose the tensor from [batch_size, seq_length, num_attention_heads, attention_head_size] to [batch_size, num_attention_heads, seq_length, attention_head_size]
+ return tf.transpose(tensor, perm=[0, 2, 1, 3])
+
+ def call(
+ self,
+ hidden_states: tf.Tensor,
+ head_mask: tf.Tensor,
+ output_attentions: bool,
+ training: bool = False,
+ ) -> tuple[tf.Tensor]:
+ batch_size = shape_list(hidden_states)[0]
+ mixed_query_layer = self.query(inputs=hidden_states)
+ mixed_key_layer = self.key(inputs=hidden_states)
+ mixed_value_layer = self.value(inputs=hidden_states)
+ query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)
+ key_layer = self.transpose_for_scores(mixed_key_layer, batch_size)
+ value_layer = self.transpose_for_scores(mixed_value_layer, batch_size)
+
+ # Take the dot product between "query" and "key" to get the raw attention scores.
+ # (batch size, num_heads, seq_len_q, seq_len_k)
+ attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
+ dk = tf.cast(self.sqrt_att_head_size, dtype=attention_scores.dtype)
+ attention_scores = tf.divide(attention_scores, dk)
+
+ # Normalize the attention scores to probabilities.
+ attention_probs = stable_softmax(logits=attention_scores, axis=-1)
+
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ attention_probs = self.dropout(inputs=attention_probs, training=training)
+
+ # Mask heads if we want to
+ if head_mask is not None:
+ attention_probs = tf.multiply(attention_probs, head_mask)
+
+ attention_output = tf.matmul(attention_probs, value_layer)
+ attention_output = tf.transpose(attention_output, perm=[0, 2, 1, 3])
+
+ # (batch_size, seq_len_q, all_head_size)
+ attention_output = tf.reshape(tensor=attention_output, shape=(batch_size, -1, self.all_head_size))
+ outputs = (attention_output, attention_probs) if output_attentions else (attention_output,)
+
+ return outputs
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "query", None) is not None:
+ with tf.name_scope(self.query.name):
+ self.query.build([None, None, self.config.hidden_size])
+ if getattr(self, "key", None) is not None:
+ with tf.name_scope(self.key.name):
+ self.key.build([None, None, self.config.hidden_size])
+ if getattr(self, "value", None) is not None:
+ with tf.name_scope(self.value.name):
+ self.value.build([None, None, self.config.hidden_size])
+
+
+# Copied from transformers.models.vit.modeling_tf_vit.TFViTSelfOutput with ViT->DeiT
+class TFDeiTSelfOutput(keras.layers.Layer):
+ """
+ The residual connection is defined in TFDeiTLayer instead of here (as is the case with other models), due to the
+ layernorm applied before each block.
+ """
+
+ def __init__(self, config: DeiTConfig, **kwargs):
+ super().__init__(**kwargs)
+
+ self.dense = keras.layers.Dense(
+ units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
+ )
+ self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
+ self.config = config
+
+ def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
+ hidden_states = self.dense(inputs=hidden_states)
+ hidden_states = self.dropout(inputs=hidden_states, training=training)
+
+ return hidden_states
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "dense", None) is not None:
+ with tf.name_scope(self.dense.name):
+ self.dense.build([None, None, self.config.hidden_size])
+
+
+# Copied from transformers.models.vit.modeling_tf_vit.TFViTAttention with ViT->DeiT
+class TFDeiTAttention(keras.layers.Layer):
+ def __init__(self, config: DeiTConfig, **kwargs):
+ super().__init__(**kwargs)
+
+ self.self_attention = TFDeiTSelfAttention(config, name="attention")
+ self.dense_output = TFDeiTSelfOutput(config, name="output")
+
+ def prune_heads(self, heads):
+ raise NotImplementedError
+
+ def call(
+ self,
+ input_tensor: tf.Tensor,
+ head_mask: tf.Tensor,
+ output_attentions: bool,
+ training: bool = False,
+ ) -> tuple[tf.Tensor]:
+ self_outputs = self.self_attention(
+ hidden_states=input_tensor, head_mask=head_mask, output_attentions=output_attentions, training=training
+ )
+ attention_output = self.dense_output(
+ hidden_states=self_outputs[0], input_tensor=input_tensor, training=training
+ )
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
+
+ return outputs
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "self_attention", None) is not None:
+ with tf.name_scope(self.self_attention.name):
+ self.self_attention.build(None)
+ if getattr(self, "dense_output", None) is not None:
+ with tf.name_scope(self.dense_output.name):
+ self.dense_output.build(None)
+
+
+# Copied from transformers.models.vit.modeling_tf_vit.TFViTIntermediate with ViT->DeiT
+class TFDeiTIntermediate(keras.layers.Layer):
+ def __init__(self, config: DeiTConfig, **kwargs):
+ super().__init__(**kwargs)
+
+ self.dense = keras.layers.Dense(
+ units=config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
+ )
+
+ if isinstance(config.hidden_act, str):
+ self.intermediate_act_fn = get_tf_activation(config.hidden_act)
+ else:
+ self.intermediate_act_fn = config.hidden_act
+ self.config = config
+
+ def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
+ hidden_states = self.dense(inputs=hidden_states)
+ hidden_states = self.intermediate_act_fn(hidden_states)
+
+ return hidden_states
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "dense", None) is not None:
+ with tf.name_scope(self.dense.name):
+ self.dense.build([None, None, self.config.hidden_size])
+
+
+# Copied from transformers.models.vit.modeling_tf_vit.TFViTOutput with ViT->DeiT
+class TFDeiTOutput(keras.layers.Layer):
+ def __init__(self, config: DeiTConfig, **kwargs):
+ super().__init__(**kwargs)
+
+ self.dense = keras.layers.Dense(
+ units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
+ )
+ self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
+ self.config = config
+
+ def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
+ hidden_states = self.dense(inputs=hidden_states)
+ hidden_states = self.dropout(inputs=hidden_states, training=training)
+ hidden_states = hidden_states + input_tensor
+
+ return hidden_states
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "dense", None) is not None:
+ with tf.name_scope(self.dense.name):
+ self.dense.build([None, None, self.config.intermediate_size])
+
+
+class TFDeiTLayer(keras.layers.Layer):
+ """This corresponds to the Block class in the timm implementation."""
+
+ def __init__(self, config: DeiTConfig, **kwargs):
+ super().__init__(**kwargs)
+
+ self.attention = TFDeiTAttention(config, name="attention")
+ self.intermediate = TFDeiTIntermediate(config, name="intermediate")
+ self.deit_output = TFDeiTOutput(config, name="output")
+
+ self.layernorm_before = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm_before")
+ self.layernorm_after = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm_after")
+ self.config = config
+
+ def call(
+ self,
+ hidden_states: tf.Tensor,
+ head_mask: tf.Tensor,
+ output_attentions: bool,
+ training: bool = False,
+ ) -> tuple[tf.Tensor]:
+ attention_outputs = self.attention(
+ # in DeiT, layernorm is applied before self-attention
+ input_tensor=self.layernorm_before(inputs=hidden_states, training=training),
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ training=training,
+ )
+ attention_output = attention_outputs[0]
+
+ # first residual connection
+ hidden_states = attention_output + hidden_states
+
+ # in DeiT, layernorm is also applied after self-attention
+ layer_output = self.layernorm_after(inputs=hidden_states, training=training)
+
+ intermediate_output = self.intermediate(hidden_states=layer_output, training=training)
+
+ # second residual connection is done here
+ layer_output = self.deit_output(
+ hidden_states=intermediate_output, input_tensor=hidden_states, training=training
+ )
+ outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them
+
+ return outputs
+
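+ # Summary of the pre-norm residual structure in call(): first, x = x + Attention(LayerNorm_before(x)); then the
+ # MLP branch computes Output(Intermediate(LayerNorm_after(x)), x), where TFDeiTOutput adds the second residual
+ # connection internally.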
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "attention", None) is not None:
+ with tf.name_scope(self.attention.name):
+ self.attention.build(None)
+ if getattr(self, "intermediate", None) is not None:
+ with tf.name_scope(self.intermediate.name):
+ self.intermediate.build(None)
+ if getattr(self, "deit_output", None) is not None:
+ with tf.name_scope(self.deit_output.name):
+ self.deit_output.build(None)
+ if getattr(self, "layernorm_before", None) is not None:
+ with tf.name_scope(self.layernorm_before.name):
+ self.layernorm_before.build([None, None, self.config.hidden_size])
+ if getattr(self, "layernorm_after", None) is not None:
+ with tf.name_scope(self.layernorm_after.name):
+ self.layernorm_after.build([None, None, self.config.hidden_size])
+
+
+# Copied from transformers.models.vit.modeling_tf_vit.TFViTEncoder with ViT->DeiT
+class TFDeiTEncoder(keras.layers.Layer):
+ def __init__(self, config: DeiTConfig, **kwargs):
+ super().__init__(**kwargs)
+
+ self.layer = [TFDeiTLayer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)]
+
+ def call(
+ self,
+ hidden_states: tf.Tensor,
+ head_mask: tf.Tensor,
+ output_attentions: bool,
+ output_hidden_states: bool,
+ return_dict: bool,
+ training: bool = False,
+ ) -> TFBaseModelOutput | tuple[tf.Tensor]:
+ all_hidden_states = () if output_hidden_states else None
+ all_attentions = () if output_attentions else None
+
+ for i, layer_module in enumerate(self.layer):
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ layer_outputs = layer_module(
+ hidden_states=hidden_states,
+ head_mask=head_mask[i],
+ output_attentions=output_attentions,
+ training=training,
+ )
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_attentions = all_attentions + (layer_outputs[1],)
+
+ # Add last layer
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
+
+ return TFBaseModelOutput(
+ last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "layer", None) is not None:
+ for layer in self.layer:
+ with tf.name_scope(layer.name):
+ layer.build(None)
+
+
+@keras_serializable
+class TFDeiTMainLayer(keras.layers.Layer):
+ config_class = DeiTConfig
+
+ def __init__(
+ self, config: DeiTConfig, add_pooling_layer: bool = True, use_mask_token: bool = False, **kwargs
+ ) -> None:
+ super().__init__(**kwargs)
+ self.config = config
+
+ self.embeddings = TFDeiTEmbeddings(config, use_mask_token=use_mask_token, name="embeddings")
+ self.encoder = TFDeiTEncoder(config, name="encoder")
+
+ self.layernorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm")
+ self.pooler = TFDeiTPooler(config, name="pooler") if add_pooling_layer else None
+
+ def get_input_embeddings(self) -> TFDeiTPatchEmbeddings:
+ return self.embeddings.patch_embeddings
+
+ def _prune_heads(self, heads_to_prune):
+ """
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}. See the
+ base class PreTrainedModel.
+ """
+ raise NotImplementedError
+
+ def get_head_mask(self, head_mask):
+ if head_mask is not None:
+ raise NotImplementedError
+ else:
+ head_mask = [None] * self.config.num_hidden_layers
+
+ return head_mask
+
+ @unpack_inputs
+ def call(
+ self,
+ pixel_values: tf.Tensor | None = None,
+ bool_masked_pos: tf.Tensor | None = None,
+ head_mask: tf.Tensor | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ interpolate_pos_encoding: bool = False,
+ training: bool = False,
+ ) -> TFBaseModelOutputWithPooling | tuple[tf.Tensor, ...]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if pixel_values is None:
+ raise ValueError("You have to specify pixel_values")
+
+ # TF 2.0 image layers can't use NCHW format when running on CPU.
+ # (batch_size, num_channels, height, width) -> (batch_size, height, width, num_channels)
+ pixel_values = tf.transpose(pixel_values, (0, 2, 3, 1))
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicates we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+ head_mask = self.get_head_mask(head_mask)
+
+ embedding_output = self.embeddings(
+ pixel_values,
+ bool_masked_pos=bool_masked_pos,
+ training=training,
+ interpolate_pos_encoding=interpolate_pos_encoding,
+ )
+
+ encoder_outputs = self.encoder(
+ embedding_output,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+ sequence_output = encoder_outputs[0]
+ sequence_output = self.layernorm(sequence_output, training=training)
+ pooled_output = self.pooler(sequence_output, training=training) if self.pooler is not None else None
+
+ if not return_dict:
+ head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,)
+ return head_outputs + encoder_outputs[1:]
+
+ return TFBaseModelOutputWithPooling(
+ last_hidden_state=sequence_output,
+ pooler_output=pooled_output,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "embeddings", None) is not None:
+ with tf.name_scope(self.embeddings.name):
+ self.embeddings.build(None)
+ if getattr(self, "encoder", None) is not None:
+ with tf.name_scope(self.encoder.name):
+ self.encoder.build(None)
+ if getattr(self, "layernorm", None) is not None:
+ with tf.name_scope(self.layernorm.name):
+ self.layernorm.build([None, None, self.config.hidden_size])
+ if getattr(self, "pooler", None) is not None:
+ with tf.name_scope(self.pooler.name):
+ self.pooler.build(None)
+
+
+# Copied from transformers.models.vit.modeling_tf_vit.TFViTPreTrainedModel with ViT->DeiT all-casing
+class TFDeiTPreTrainedModel(TFPreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = DeiTConfig
+ base_model_prefix = "deit"
+ main_input_name = "pixel_values"
+
+
+DEIT_START_DOCSTRING = r"""
+ This model is a TensorFlow
+ [keras.layers.Layer](https://www.tensorflow.org/api_docs/python/tf/keras/layers/Layer). Use it as a regular
+ TensorFlow Module and refer to the TensorFlow documentation for all matters related to general usage and behavior.
+
+ Parameters:
+ config ([`DeiTConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+DEIT_INPUTS_DOCSTRING = r"""
+ Args:
+ pixel_values (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`):
+ Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
+ [`DeiTImageProcessor.__call__`] for details.
+
+ head_mask (`tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
+ Whether to interpolate the pre-trained position encodings.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+ "The bare DeiT Model transformer outputting raw hidden-states without any specific head on top.",
+ DEIT_START_DOCSTRING,
+)
+class TFDeiTModel(TFDeiTPreTrainedModel):
+ def __init__(
+ self, config: DeiTConfig, add_pooling_layer: bool = True, use_mask_token: bool = False, **kwargs
+ ) -> None:
+ super().__init__(config, **kwargs)
+
+ self.deit = TFDeiTMainLayer(
+ config, add_pooling_layer=add_pooling_layer, use_mask_token=use_mask_token, name="deit"
+ )
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(DEIT_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=TFBaseModelOutputWithPooling,
+ config_class=_CONFIG_FOR_DOC,
+ modality="vision",
+ expected_output=_EXPECTED_OUTPUT_SHAPE,
+ )
+ def call(
+ self,
+ pixel_values: tf.Tensor | None = None,
+ bool_masked_pos: tf.Tensor | None = None,
+ head_mask: tf.Tensor | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ interpolate_pos_encoding: bool = False,
+ training: bool = False,
+ ) -> tuple | TFBaseModelOutputWithPooling:
+ outputs = self.deit(
+ pixel_values=pixel_values,
+ bool_masked_pos=bool_masked_pos,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ interpolate_pos_encoding=interpolate_pos_encoding,
+ training=training,
+ )
+ return outputs
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "deit", None) is not None:
+ with tf.name_scope(self.deit.name):
+ self.deit.build(None)
+
+
+# Copied from transformers.models.vit.modeling_tf_vit.TFViTPooler with ViT->DeiT
+class TFDeiTPooler(keras.layers.Layer):
+ def __init__(self, config: DeiTConfig, **kwargs):
+ super().__init__(**kwargs)
+
+ self.dense = keras.layers.Dense(
+ units=config.pooler_output_size,
+ kernel_initializer=get_initializer(config.initializer_range),
+ activation=config.pooler_act,
+ name="dense",
+ )
+ self.config = config
+
+ def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
+ # We "pool" the model by simply taking the hidden state corresponding
+ # to the first token.
+ first_token_tensor = hidden_states[:, 0]
+ pooled_output = self.dense(inputs=first_token_tensor)
+
+ return pooled_output
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "dense", None) is not None:
+ with tf.name_scope(self.dense.name):
+ self.dense.build([None, None, self.config.hidden_size])
+
+
+class TFDeitPixelShuffle(keras.layers.Layer):
+ """TF layer implementation of torch.nn.PixelShuffle"""
+
+ def __init__(self, upscale_factor: int, **kwargs) -> None:
+ super().__init__(**kwargs)
+ if not isinstance(upscale_factor, int) or upscale_factor < 2:
+ raise ValueError(f"upscale_factor must be an integer value >= 2 got {upscale_factor}")
+ self.upscale_factor = upscale_factor
+
+ def call(self, x: tf.Tensor) -> tf.Tensor:
+ hidden_states = x
+ batch_size, _, _, num_input_channels = shape_list(hidden_states)
+ block_size_squared = self.upscale_factor**2
+ output_depth = int(num_input_channels / block_size_squared)
+ # When the number of output channels is >= 2, PyTorch's PixelShuffle and TF's depth_to_space
+ # combine channels in a different order (one ordering is a permutation of the other), so the
+ # channels are permuted first, cf.
+ # https://stackoverflow.com/questions/68272502/tf-depth-to-space-not-same-as-torchs-pixelshuffle-when-output-channels-1
+ permutation = tf.constant(
+ [[i + j * block_size_squared for i in range(block_size_squared) for j in range(output_depth)]]
+ )
+ hidden_states = tf.gather(params=hidden_states, indices=tf.tile(permutation, [batch_size, 1]), batch_dims=-1)
+ hidden_states = tf.nn.depth_to_space(hidden_states, block_size=self.upscale_factor, data_format="NHWC")
+ return hidden_states
+
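+ # Illustrative sketch (not part of the library): the permutation above is needed because PyTorch's
+ # nn.PixelShuffle reads channels as c * r**2 + i * r + j, while tf.nn.depth_to_space expects
+ # (i * r + j) * C + c. Assuming static-enough NHWC shapes, a hypothetical helper that reproduces the
+ # PyTorch ordering directly via reshape/transpose could look like this:
+ def _torch_style_pixel_shuffle_nhwc(x: tf.Tensor, upscale_factor: int) -> tf.Tensor:
+     """Hypothetical helper: rearrange (batch, height, width, channels * r * r) with PyTorch channel ordering into (batch, height * r, width * r, channels)."""
+     batch_size, height, width, channels = shape_list(x)
+     r = upscale_factor
+     output_channels = channels // (r * r)
+     # Split the channel axis into (output_channels, r, r), matching PyTorch's c * r**2 + i * r + j layout
+     x = tf.reshape(x, (batch_size, height, width, output_channels, r, r))
+     # Interleave the two block axes with the spatial axes: (batch, height, r, width, r, output_channels)
+     x = tf.transpose(x, (0, 1, 4, 2, 5, 3))
+     return tf.reshape(x, (batch_size, height * r, width * r, output_channels))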
+
+class TFDeitDecoder(keras.layers.Layer):
+ def __init__(self, config: DeiTConfig, **kwargs) -> None:
+ super().__init__(**kwargs)
+ self.conv2d = keras.layers.Conv2D(
+ filters=config.encoder_stride**2 * config.num_channels, kernel_size=1, name="0"
+ )
+ self.pixel_shuffle = TFDeitPixelShuffle(config.encoder_stride, name="1")
+ self.config = config
+
+ def call(self, inputs: tf.Tensor, training: bool = False) -> tf.Tensor:
+ hidden_states = inputs
+ hidden_states = self.conv2d(hidden_states)
+ hidden_states = self.pixel_shuffle(hidden_states)
+ return hidden_states
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "conv2d", None) is not None:
+ with tf.name_scope(self.conv2d.name):
+ self.conv2d.build([None, None, None, self.config.hidden_size])
+ if getattr(self, "pixel_shuffle", None) is not None:
+ with tf.name_scope(self.pixel_shuffle.name):
+ self.pixel_shuffle.build(None)
+
+
+@add_start_docstrings(
+ "DeiT Model with a decoder on top for masked image modeling, as proposed in"
+ " [SimMIM](https://huggingface.co/papers/2111.09886).",
+ DEIT_START_DOCSTRING,
+)
+class TFDeiTForMaskedImageModeling(TFDeiTPreTrainedModel):
+ def __init__(self, config: DeiTConfig) -> None:
+ super().__init__(config)
+
+ self.deit = TFDeiTMainLayer(config, add_pooling_layer=False, use_mask_token=True, name="deit")
+ self.decoder = TFDeitDecoder(config, name="decoder")
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(DEIT_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=TFMaskedImageModelingOutput, config_class=_CONFIG_FOR_DOC)
+ def call(
+ self,
+ pixel_values: tf.Tensor | None = None,
+ bool_masked_pos: tf.Tensor | None = None,
+ head_mask: tf.Tensor | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ interpolate_pos_encoding: bool = False,
+ training: bool = False,
+ ) -> tuple | TFMaskedImageModelingOutput:
+ r"""
+ bool_masked_pos (`tf.Tensor` of type bool and shape `(batch_size, num_patches)`):
+ Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
+
+ Returns:
+
+ Examples:
+ ```python
+ >>> from transformers import AutoImageProcessor, TFDeiTForMaskedImageModeling
+ >>> import tensorflow as tf
+ >>> from PIL import Image
+ >>> import requests
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> image_processor = AutoImageProcessor.from_pretrained("facebook/deit-base-distilled-patch16-224")
+ >>> model = TFDeiTForMaskedImageModeling.from_pretrained("facebook/deit-base-distilled-patch16-224")
+
+ >>> num_patches = (model.config.image_size // model.config.patch_size) ** 2
+ >>> pixel_values = image_processor(images=image, return_tensors="tf").pixel_values
+ >>> # create random boolean mask of shape (batch_size, num_patches)
+ >>> bool_masked_pos = tf.cast(tf.random.uniform((1, num_patches), minval=0, maxval=2, dtype=tf.int32), tf.bool)
+
+ >>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)
+ >>> loss, reconstructed_pixel_values = outputs.loss, outputs.reconstruction
+ >>> list(reconstructed_pixel_values.shape)
+ [1, 3, 224, 224]
+ ```"""
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.deit(
+ pixel_values,
+ bool_masked_pos=bool_masked_pos,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ interpolate_pos_encoding=interpolate_pos_encoding,
+ training=training,
+ )
+
+ sequence_output = outputs[0]
+
+ # Drop the [CLS] and distillation tokens and reshape to (batch_size, height, width, num_channels)
+ sequence_output = sequence_output[:, 1:-1]
+ batch_size, sequence_length, num_channels = shape_list(sequence_output)
+ height = width = int(sequence_length**0.5)
+ sequence_output = tf.reshape(sequence_output, (batch_size, height, width, num_channels))
+
+ # Reconstruct pixel values
+ reconstructed_pixel_values = self.decoder(sequence_output, training=training)
+ # TF 2.0 image layers can't use NCHW format when running on CPU, so intermediate layers use NHWC,
+ # including the decoder. We transpose to compute the loss against the pixel values
+ # (batch_size, height, width, num_channels) -> (batch_size, num_channels, height, width)
+ reconstructed_pixel_values = tf.transpose(reconstructed_pixel_values, (0, 3, 1, 2))
+
+ masked_im_loss = None
+ if bool_masked_pos is not None:
+ size = self.config.image_size // self.config.patch_size
+ bool_masked_pos = tf.reshape(bool_masked_pos, (-1, size, size))
+ mask = tf.repeat(bool_masked_pos, self.config.patch_size, 1)
+ mask = tf.repeat(mask, self.config.patch_size, 2)
+ mask = tf.expand_dims(mask, 1)
+ mask = tf.cast(mask, tf.float32)
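+ # Shape note: bool_masked_pos (batch_size, num_patches) is reshaped to the (batch_size, size, size) patch grid,
+ # upsampled by patch_size along both spatial axes, and expanded to (batch_size, 1, height, width) so that it
+ # broadcasts against the per-pixel reconstruction loss computed below.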
+
+ reconstruction_loss = keras.losses.mean_absolute_error(
+ # Swap axes as metric calculation reduces over the final dimension
+ tf.transpose(pixel_values, (1, 2, 3, 0)),
+ tf.transpose(reconstructed_pixel_values, (1, 2, 3, 0)),
+ )
+ reconstruction_loss = tf.expand_dims(reconstruction_loss, 0)
+ total_loss = tf.reduce_sum(reconstruction_loss * mask)
+ num_masked_pixels = (tf.reduce_sum(mask) + 1e-5) * self.config.num_channels
+ masked_im_loss = total_loss / num_masked_pixels
+ masked_im_loss = tf.reshape(masked_im_loss, (1,))
+
+ if not return_dict:
+ output = (reconstructed_pixel_values,) + outputs[1:]
+ return ((masked_im_loss,) + output) if masked_im_loss is not None else output
+
+ return TFMaskedImageModelingOutput(
+ loss=masked_im_loss,
+ reconstruction=reconstructed_pixel_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "deit", None) is not None:
+ with tf.name_scope(self.deit.name):
+ self.deit.build(None)
+ if getattr(self, "decoder", None) is not None:
+ with tf.name_scope(self.decoder.name):
+ self.decoder.build(None)
+
+
+@add_start_docstrings(
+ """
+ DeiT Model transformer with an image classification head on top (a linear layer on top of the final hidden state of
+ the [CLS] token) e.g. for ImageNet.
+ """,
+ DEIT_START_DOCSTRING,
+)
+class TFDeiTForImageClassification(TFDeiTPreTrainedModel, TFSequenceClassificationLoss):
+ def __init__(self, config: DeiTConfig):
+ super().__init__(config)
+
+ self.num_labels = config.num_labels
+ self.deit = TFDeiTMainLayer(config, add_pooling_layer=False, name="deit")
+
+ # Classifier head
+ self.classifier = (
+ keras.layers.Dense(config.num_labels, name="classifier")
+ if config.num_labels > 0
+ else keras.layers.Activation("linear", name="classifier")
+ )
+ self.config = config
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(DEIT_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=TFImageClassifierOutput, config_class=_CONFIG_FOR_DOC)
+ def call(
+ self,
+ pixel_values: tf.Tensor | None = None,
+ head_mask: tf.Tensor | None = None,
+ labels: tf.Tensor | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ interpolate_pos_encoding: bool = False,
+ training: bool = False,
+ ) -> tf.Tensor | TFImageClassifierOutput:
+ r"""
+ labels (`tf.Tensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+
+ Returns:
+
+ Examples:
+
+ ```python
+ >>> from transformers import AutoImageProcessor, TFDeiTForImageClassification
+ >>> import tensorflow as tf
+ >>> from PIL import Image
+ >>> import requests
+
+ >>> keras.utils.set_random_seed(3) # doctest: +IGNORE_RESULT
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> # note: the checkpoint on the hub was trained as a TFDeiTForImageClassificationWithTeacher,
+ >>> # so the classification head of this model will be randomly initialized, hence the predictions will be random
+ >>> image_processor = AutoImageProcessor.from_pretrained("facebook/deit-base-distilled-patch16-224")
+ >>> model = TFDeiTForImageClassification.from_pretrained("facebook/deit-base-distilled-patch16-224")
+
+ >>> inputs = image_processor(images=image, return_tensors="tf")
+ >>> outputs = model(**inputs)
+ >>> logits = outputs.logits
+ >>> # model predicts one of the 1000 ImageNet classes
+ >>> predicted_class_idx = tf.math.argmax(logits, axis=-1)[0]
+ >>> print("Predicted class:", model.config.id2label[int(predicted_class_idx)])
+ Predicted class: little blue heron, Egretta caerulea
+ ```"""
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.deit(
+ pixel_values,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ interpolate_pos_encoding=interpolate_pos_encoding,
+ training=training,
+ )
+
+ sequence_output = outputs[0]
+
+ logits = self.classifier(sequence_output[:, 0, :])
+ # we don't use the distillation token
+
+ loss = None if labels is None else self.hf_compute_loss(labels, logits)
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return TFImageClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "deit", None) is not None:
+ with tf.name_scope(self.deit.name):
+ self.deit.build(None)
+ if getattr(self, "classifier", None) is not None:
+ with tf.name_scope(self.classifier.name):
+ self.classifier.build([None, None, self.config.hidden_size])
+
+
+@add_start_docstrings(
+ """
+ DeiT Model transformer with image classification heads on top (a linear layer on top of the final hidden state of
+ the [CLS] token and a linear layer on top of the final hidden state of the distillation token) e.g. for ImageNet.
+
+ .. warning::
+
+ This model supports inference only. Fine-tuning with distillation (i.e. with a teacher) is not yet
+ supported.
+ """,
+ DEIT_START_DOCSTRING,
+)
+class TFDeiTForImageClassificationWithTeacher(TFDeiTPreTrainedModel):
+ def __init__(self, config: DeiTConfig) -> None:
+ super().__init__(config)
+
+ self.num_labels = config.num_labels
+ self.deit = TFDeiTMainLayer(config, add_pooling_layer=False, name="deit")
+
+ # Classifier heads
+ self.cls_classifier = (
+ keras.layers.Dense(config.num_labels, name="cls_classifier")
+ if config.num_labels > 0
+ else keras.layers.Activation("linear", name="cls_classifier")
+ )
+ self.distillation_classifier = (
+ keras.layers.Dense(config.num_labels, name="distillation_classifier")
+ if config.num_labels > 0
+ else keras.layers.Activation("linear", name="distillation_classifier")
+ )
+ self.config = config
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(DEIT_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ checkpoint=_IMAGE_CLASS_CHECKPOINT,
+ output_type=TFDeiTForImageClassificationWithTeacherOutput,
+ config_class=_CONFIG_FOR_DOC,
+ expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
+ )
+ def call(
+ self,
+ pixel_values: tf.Tensor | None = None,
+ head_mask: tf.Tensor | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ interpolate_pos_encoding: bool = False,
+ training: bool = False,
+ ) -> tuple | TFDeiTForImageClassificationWithTeacherOutput:
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.deit(
+ pixel_values,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ interpolate_pos_encoding=interpolate_pos_encoding,
+ training=training,
+ )
+
+ sequence_output = outputs[0]
+
+ cls_logits = self.cls_classifier(sequence_output[:, 0, :])
+ distillation_logits = self.distillation_classifier(sequence_output[:, 1, :])
+
+ # during inference, return the average of both classifier predictions
+ logits = (cls_logits + distillation_logits) / 2
+
+ if not return_dict:
+ output = (logits, cls_logits, distillation_logits) + outputs[1:]
+ return output
+
+ return TFDeiTForImageClassificationWithTeacherOutput(
+ logits=logits,
+ cls_logits=cls_logits,
+ distillation_logits=distillation_logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "deit", None) is not None:
+ with tf.name_scope(self.deit.name):
+ self.deit.build(None)
+ if getattr(self, "cls_classifier", None) is not None:
+ with tf.name_scope(self.cls_classifier.name):
+ self.cls_classifier.build([None, None, self.config.hidden_size])
+ if getattr(self, "distillation_classifier", None) is not None:
+ with tf.name_scope(self.distillation_classifier.name):
+ self.distillation_classifier.build([None, None, self.config.hidden_size])
+
+
+__all__ = [
+ "TFDeiTForImageClassification",
+ "TFDeiTForImageClassificationWithTeacher",
+ "TFDeiTForMaskedImageModeling",
+ "TFDeiTModel",
+ "TFDeiTPreTrainedModel",
+]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/deprecated/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/deprecated/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e293c354e1e92a431a601da77d7555f2ecfe29ef
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/deprecated/__init__.py
@@ -0,0 +1,49 @@
+# Copyright 2020 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .bort import *
+ from .deta import *
+ from .efficientformer import *
+ from .ernie_m import *
+ from .gptsan_japanese import *
+ from .graphormer import *
+ from .jukebox import *
+ from .mctct import *
+ from .mega import *
+ from .mmbt import *
+ from .nat import *
+ from .nezha import *
+ from .open_llama import *
+ from .qdqbert import *
+ from .realm import *
+ from .retribert import *
+ from .speech_to_text_2 import *
+ from .tapex import *
+ from .trajectory_transformer import *
+ from .transfo_xl import *
+ from .tvlt import *
+ from .van import *
+ from .vit_hybrid import *
+ from .xlm_prophetnet import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/venv/lib/python3.13/site-packages/transformers/models/dinov2_with_registers/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/dinov2_with_registers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d10027b6a3b6375235a6785df044e8f0ce5fb33
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/dinov2_with_registers/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_dinov2_with_registers import *
+ from .modeling_dinov2_with_registers import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/venv/lib/python3.13/site-packages/transformers/models/dinov2_with_registers/configuration_dinov2_with_registers.py b/venv/lib/python3.13/site-packages/transformers/models/dinov2_with_registers/configuration_dinov2_with_registers.py
new file mode 100644
index 0000000000000000000000000000000000000000..ec4f446fc684f40d634927c1e7a52b64c5732b12
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/dinov2_with_registers/configuration_dinov2_with_registers.py
@@ -0,0 +1,159 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_dinov2_with_registers.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2024 Meta Inc. and the HuggingFace Inc. team. All rights reserved.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ...configuration_utils import PretrainedConfig
+from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices
+
+
+class Dinov2WithRegistersConfig(BackboneConfigMixin, PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`Dinov2WithRegistersModel`]. It is used to instantiate a
+ Dinov2WithRegisters model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of the DINOv2 with Registers
+ [facebook/dinov2-with-registers-base](https://huggingface.co/facebook/dinov2-with-registers-base) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ hidden_size (`int`, *optional*, defaults to 768):
+ Dimensionality of the encoder layers and the pooler layer.
+ num_hidden_layers (`int`, *optional*, defaults to 12):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 12):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ mlp_ratio (`int`, *optional*, defaults to 4):
+ Ratio of the hidden size of the MLPs relative to the `hidden_size`.
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"selu"` and `"gelu_new"` are supported.
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-06):
+ The epsilon used by the layer normalization layers.
+ image_size (`int`, *optional*, defaults to 224):
+ The size (resolution) of each image.
+ patch_size (`int`, *optional*, defaults to 16):
+ The size (resolution) of each patch.
+ num_channels (`int`, *optional*, defaults to 3):
+ The number of input channels.
+ qkv_bias (`bool`, *optional*, defaults to `True`):
+ Whether to add a bias to the queries, keys and values.
+ layerscale_value (`float`, *optional*, defaults to 1.0):
+ Initial value to use for layer scale.
+ drop_path_rate (`float`, *optional*, defaults to 0.0):
+ Stochastic depth rate per sample (when applied in the main path of residual layers).
+ use_swiglu_ffn (`bool`, *optional*, defaults to `False`):
+ Whether to use the SwiGLU feedforward neural network.
+ num_register_tokens (`int`, *optional*, defaults to 4):
+ Number of register tokens to use.
+ out_features (`list[str]`, *optional*):
+ If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
+ (depending on how many stages the model has). If unset and `out_indices` is set, will default to the
+ corresponding stages. If unset and `out_indices` is unset, will default to the last stage. Must be in the
+ same order as defined in the `stage_names` attribute.
+ out_indices (`list[int]`, *optional*):
+ If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how
+ many stages the model has). If unset and `out_features` is set, will default to the corresponding stages.
+ If unset and `out_features` is unset, will default to the last stage. Must be in the
+ same order as defined in the `stage_names` attribute.
+ apply_layernorm (`bool`, *optional*, defaults to `True`):
+ Whether to apply layer normalization to the feature maps in case the model is used as backbone.
+ reshape_hidden_states (`bool`, *optional*, defaults to `True`):
+ Whether to reshape the feature maps to 4D tensors of shape `(batch_size, hidden_size, height, width)` in
+ case the model is used as backbone. If `False`, the feature maps will be 3D tensors of shape `(batch_size,
+ seq_len, hidden_size)`.
+
+ Example:
+
+ ```python
+ >>> from transformers import Dinov2WithRegistersConfig, Dinov2WithRegistersModel
+
+ >>> # Initializing a Dinov2WithRegisters base style configuration
+ >>> configuration = Dinov2WithRegistersConfig()
+
+ >>> # Initializing a model (with random weights) from the base style configuration
+ >>> model = Dinov2WithRegistersModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "dinov2_with_registers"
+
+ def __init__(
+ self,
+ hidden_size=768,
+ num_hidden_layers=12,
+ num_attention_heads=12,
+ mlp_ratio=4,
+ hidden_act="gelu",
+ hidden_dropout_prob=0.0,
+ attention_probs_dropout_prob=0.0,
+ initializer_range=0.02,
+ layer_norm_eps=1e-6,
+ image_size=224,
+ patch_size=16,
+ num_channels=3,
+ qkv_bias=True,
+ layerscale_value=1.0,
+ drop_path_rate=0.0,
+ use_swiglu_ffn=False,
+ num_register_tokens=4,
+ out_features=None,
+ out_indices=None,
+ apply_layernorm=True,
+ reshape_hidden_states=True,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.mlp_ratio = mlp_ratio
+ self.hidden_act = hidden_act
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.initializer_range = initializer_range
+ self.layer_norm_eps = layer_norm_eps
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.num_channels = num_channels
+ self.qkv_bias = qkv_bias
+ self.layerscale_value = layerscale_value
+ self.drop_path_rate = drop_path_rate
+ self.use_swiglu_ffn = use_swiglu_ffn
+ self.num_register_tokens = num_register_tokens
+ self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, num_hidden_layers + 1)]
+ self._out_features, self._out_indices = get_aligned_output_features_output_indices(
+ out_features=out_features, out_indices=out_indices, stage_names=self.stage_names
+ )
+ self.apply_layernorm = apply_layernorm
+ self.reshape_hidden_states = reshape_hidden_states
+
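+ # Illustrative note (not part of the generated file): with the default num_hidden_layers=12, `stage_names` is
+ # ["stem", "stage1", ..., "stage12"], so a hypothetical multi-scale backbone setup could pass e.g.
+ # out_features=["stage3", "stage6", "stage9", "stage12"]; if neither `out_features` nor `out_indices` is given,
+ # only the last stage ("stage12") is used, as documented above.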
+
+__all__ = ["Dinov2WithRegistersConfig"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py b/venv/lib/python3.13/site-packages/transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py
new file mode 100644
index 0000000000000000000000000000000000000000..d2c9b9174bc719f3da1f5f496f2656b883baef70
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py
@@ -0,0 +1,712 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_dinov2_with_registers.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2024 Meta Inc. and the HuggingFace Inc. team. All rights reserved.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import collections.abc
+from typing import Callable, Optional, Union
+
+import torch
+from torch import nn
+
+from ...activations import ACT2FN
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import BackboneOutput, BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
+from ...utils import TransformersKwargs, auto_docstring, torch_int
+from ...utils.backbone_utils import BackboneMixin
+from ...utils.generic import can_return_tuple, check_model_inputs
+from .configuration_dinov2_with_registers import Dinov2WithRegistersConfig
+
+
+class Dinov2WithRegistersPatchEmbeddings(nn.Module):
+ """
+ This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
+ `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
+ Transformer.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ image_size, patch_size = config.image_size, config.patch_size
+ num_channels, hidden_size = config.num_channels, config.hidden_size
+
+ image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
+ patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
+ num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.num_channels = num_channels
+ self.num_patches = num_patches
+
+ self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
+
+ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
+ num_channels = pixel_values.shape[1]
+ if num_channels != self.num_channels:
+ raise ValueError(
+ "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
+ f" Expected {self.num_channels} but got {num_channels}."
+ )
+ embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
+ return embeddings
+
+
+class Dinov2WithRegistersEmbeddings(nn.Module):
+ """
+ Construct the CLS token, mask token, register tokens, position and patch embeddings.
+ """
+
+ def __init__(self, config: Dinov2WithRegistersConfig) -> None:
+ super().__init__()
+
+ self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))
+ self.mask_token = nn.Parameter(torch.zeros(1, config.hidden_size))
+ self.register_tokens = nn.Parameter(torch.zeros(1, config.num_register_tokens, config.hidden_size))
+ self.patch_embeddings = Dinov2WithRegistersPatchEmbeddings(config)
+ num_patches = self.patch_embeddings.num_patches
+ self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, config.hidden_size))
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+ self.patch_size = config.patch_size
+ self.config = config
+
+ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
+ """
+ This method interpolates the pre-trained position encodings so that the model can be used on
+ higher-resolution images. This implementation supports torch.jit tracing while maintaining backwards compatibility
+ with the original implementation.
+
+ Adapted from:
+ - https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
+ - https://github.com/facebookresearch/dinov2/blob/main/dinov2/models/vision_transformer.py
+ """
+ num_patches = embeddings.shape[1] - 1
+ num_positions = self.position_embeddings.shape[1] - 1
+
+ # Skip interpolation for matching dimensions (unless tracing)
+ if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
+ return self.position_embeddings
+
+ # Handle class token and patch embeddings separately
+ class_pos_embed = self.position_embeddings[:, 0]
+ patch_pos_embed = self.position_embeddings[:, 1:]
+ dim = embeddings.shape[-1]
+
+ # Calculate new dimensions
+ height = height // self.config.patch_size
+ width = width // self.config.patch_size
+
+ # Reshape for interpolation
+ sqrt_num_positions = torch_int(num_positions**0.5)
+ patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
+ patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
+
+ # Store original dtype for restoration after interpolation
+ target_dtype = patch_pos_embed.dtype
+
+ # Interpolate at float32 precision
+ patch_pos_embed = nn.functional.interpolate(
+ patch_pos_embed.to(dtype=torch.float32),
+ size=(torch_int(height), torch_int(width)), # Explicit size instead of scale_factor
+ mode="bicubic",
+ align_corners=False,
+ antialias=True,
+ ).to(dtype=target_dtype)
+
+ # Validate output dimensions if not tracing
+ if not torch.jit.is_tracing():
+ if int(height) != patch_pos_embed.shape[-2] or int(width) != patch_pos_embed.shape[-1]:
+ raise ValueError("Width or height does not match with the interpolated position embeddings")
+
+ # Reshape back to original format
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+
+ # Combine class and patch embeddings
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
+
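+ # Worked example, assuming the default image_size=224 and patch_size=16: pre-training uses a 14x14 grid of patch
+ # position embeddings (196, plus one for [CLS]); for a 336x336 input the patch grid becomes 21x21, so the stored
+ # 14x14 grid is bicubically resized to 21x21 before being added to the patch embeddings.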
+ def forward(self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.Tensor] = None) -> torch.Tensor:
+ batch_size, _, height, width = pixel_values.shape
+ target_dtype = self.patch_embeddings.projection.weight.dtype
+ embeddings = self.patch_embeddings(pixel_values.to(dtype=target_dtype))
+
+ if bool_masked_pos is not None:
+ embeddings = torch.where(
+ bool_masked_pos.unsqueeze(-1), self.mask_token.to(embeddings.dtype).unsqueeze(0), embeddings
+ )
+
+ # add the [CLS] token to the embedded patch tokens
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1)
+ embeddings = torch.cat((cls_tokens, embeddings), dim=1)
+
+ # add positional encoding to each token
+ embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
+
+ # add register tokens
+ embeddings = torch.cat(
+ (embeddings[:, :1], self.register_tokens.expand(embeddings.shape[0], -1, -1), embeddings[:, 1:]), dim=1
+ )
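+ # Resulting token order: [CLS], register_1 ... register_{num_register_tokens}, patch_1 ... patch_N. Note that the
+ # position embeddings added above only cover the [CLS] and patch tokens; the register tokens receive no
+ # positional encoding.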
+
+ embeddings = self.dropout(embeddings)
+
+ return embeddings
+
+
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: float,
+ dropout: float = 0.0,
+ **kwargs,
+):
+ # Take the dot product between "query" and "key" to get the raw attention scores.
+ attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
+
+ # Normalize the attention scores to probabilities.
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+
+ # Mask heads if we want to
+ if attention_mask is not None:
+ attn_weights = attn_weights * attention_mask
+
+ attn_output = torch.matmul(attn_weights, value)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
+class Dinov2WithRegistersSelfAttention(nn.Module):
+ def __init__(self, config: Dinov2WithRegistersConfig):
+ super().__init__()
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
+ raise ValueError(
+ f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
+ f"heads {config.num_attention_heads}."
+ )
+
+ self.config = config
+ self.num_attention_heads = config.num_attention_heads
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
+ self.dropout_prob = config.attention_probs_dropout_prob
+ self.scaling = self.attention_head_size**-0.5
+ self.is_causal = False
+
+ self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
+ self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
+ self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
+
+ def forward(
+ self, hidden_states: torch.Tensor, head_mask: Optional[torch.Tensor] = None
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ batch_size = hidden_states.shape[0]
+ new_shape = batch_size, -1, self.num_attention_heads, self.attention_head_size
+
+ key_layer = self.key(hidden_states).view(*new_shape).transpose(1, 2)
+ value_layer = self.value(hidden_states).view(*new_shape).transpose(1, 2)
+ query_layer = self.query(hidden_states).view(*new_shape).transpose(1, 2)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ context_layer, attention_probs = attention_interface(
+ self,
+ query_layer,
+ key_layer,
+ value_layer,
+ head_mask,
+ is_causal=self.is_causal,
+ scaling=self.scaling,
+ dropout=0.0 if not self.training else self.dropout_prob,
+ )
+
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+ context_layer = context_layer.reshape(new_context_layer_shape)
+
+ return context_layer, attention_probs
+
+
+class Dinov2WithRegistersSelfOutput(nn.Module):
+ """
+ The residual connection is defined in Dinov2WithRegistersLayer instead of here (as is the case with other models), due to the
+ layernorm applied before each block.
+ """
+
+ def __init__(self, config: Dinov2WithRegistersConfig):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ return hidden_states
+
+
+class Dinov2WithRegistersAttention(nn.Module):
+ def __init__(self, config: Dinov2WithRegistersConfig):
+ super().__init__()
+ self.attention = Dinov2WithRegistersSelfAttention(config)
+ self.output = Dinov2WithRegistersSelfOutput(config)
+ self.pruned_heads = set()
+
+ def prune_heads(self, heads: set[int]):
+ if len(heads) == 0:
+ return
+ heads, index = find_pruneable_heads_and_indices(
+ heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
+ )
+
+ # Prune linear layers
+ self.attention.query = prune_linear_layer(self.attention.query, index)
+ self.attention.key = prune_linear_layer(self.attention.key, index)
+ self.attention.value = prune_linear_layer(self.attention.value, index)
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
+
+ # Update hyper params and store pruned heads
+ self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
+ self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
+ self.pruned_heads = self.pruned_heads.union(heads)
+
+ def forward(self, hidden_states: torch.Tensor, head_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+ self_attn_output, _ = self.attention(hidden_states, head_mask)
+ output = self.output(self_attn_output, hidden_states)
+ return output
+
+
+class Dinov2WithRegistersLayerScale(nn.Module):
+ def __init__(self, config) -> None:
+ super().__init__()
+ self.lambda1 = nn.Parameter(config.layerscale_value * torch.ones(config.hidden_size))
+
+ def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
+ return hidden_state * self.lambda1
+
+
+def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
+ """
+ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+
+ Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
+ however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
+ layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
+ argument.
+ """
+ if drop_prob == 0.0 or not training:
+ return input
+ keep_prob = 1 - drop_prob
+ shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
+ random_tensor.floor_() # binarize
+ output = input.div(keep_prob) * random_tensor
+ return output
+
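+ # Worked example: with drop_prob=0.1 and training=True, each sample's residual branch is dropped (zeroed) with
+ # probability 0.1 and surviving samples are scaled by 1 / 0.9, keeping the expected value of the branch unchanged.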
+
+class Dinov2WithRegistersDropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+ def __init__(self, drop_prob: Optional[float] = None) -> None:
+ super().__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ return drop_path(hidden_states, self.drop_prob, self.training)
+
+ def extra_repr(self) -> str:
+ return f"p={self.drop_prob}"
+
+
+class Dinov2WithRegistersMLP(nn.Module):
+ def __init__(self, config) -> None:
+ super().__init__()
+ in_features = out_features = config.hidden_size
+ hidden_features = int(config.hidden_size * config.mlp_ratio)
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=True)
+ if isinstance(config.hidden_act, str):
+ self.activation = ACT2FN[config.hidden_act]
+ else:
+ self.activation = config.hidden_act
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=True)
+
+ def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
+ hidden_state = self.fc1(hidden_state)
+ hidden_state = self.activation(hidden_state)
+ hidden_state = self.fc2(hidden_state)
+ return hidden_state
+
+
+class Dinov2WithRegistersSwiGLUFFN(nn.Module):
+ def __init__(self, config) -> None:
+ super().__init__()
+ in_features = out_features = config.hidden_size
+ hidden_features = int(config.hidden_size * config.mlp_ratio)
+ hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
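+ # e.g. with the default hidden_size=768 and mlp_ratio=4: 768 * 4 = 3072, then 3072 * 2 / 3 = 2048, which is
+ # already a multiple of 8, so the SwiGLU hidden width is 2048.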
+
+ self.weights_in = nn.Linear(in_features, 2 * hidden_features, bias=True)
+ self.weights_out = nn.Linear(hidden_features, out_features, bias=True)
+
+ def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
+ hidden_state = self.weights_in(hidden_state)
+ x1, x2 = hidden_state.chunk(2, dim=-1)
+ hidden = nn.functional.silu(x1) * x2
+ return self.weights_out(hidden)
+
+
+class Dinov2WithRegistersLayer(GradientCheckpointingLayer):
+ """This corresponds to the Block class in the original implementation."""
+
+ def __init__(self, config: Dinov2WithRegistersConfig) -> None:
+ super().__init__()
+
+ self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.attention = Dinov2WithRegistersAttention(config)
+ self.layer_scale1 = Dinov2WithRegistersLayerScale(config)
+ self.drop_path = (
+ Dinov2WithRegistersDropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity()
+ )
+
+ self.norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+ if config.use_swiglu_ffn:
+ self.mlp = Dinov2WithRegistersSwiGLUFFN(config)
+ else:
+ self.mlp = Dinov2WithRegistersMLP(config)
+ self.layer_scale2 = Dinov2WithRegistersLayerScale(config)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ head_mask: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ hidden_states_norm = self.norm1(hidden_states)
+ self_attention_output = self.attention(hidden_states_norm, head_mask)
+ self_attention_output = self.layer_scale1(self_attention_output)
+
+ # first residual connection
+ hidden_states = self.drop_path(self_attention_output) + hidden_states
+
+ # in Dinov2WithRegisters, layernorm is also applied after self-attention
+ layer_output = self.norm2(hidden_states)
+ layer_output = self.mlp(layer_output)
+ layer_output = self.layer_scale2(layer_output)
+
+ # second residual connection
+ layer_output = self.drop_path(layer_output) + hidden_states
+
+ return layer_output
+
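+ # Summary of the pre-norm residual structure in forward():
+ #   x = x + DropPath(LayerScale1(Attention(LayerNorm1(x))))
+ #   x = x + DropPath(LayerScale2(MLP_or_SwiGLU(LayerNorm2(x))))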
+
+class Dinov2WithRegistersEncoder(nn.Module):
+ def __init__(self, config: Dinov2WithRegistersConfig):
+ super().__init__()
+ self.config = config
+ self.layer = nn.ModuleList([Dinov2WithRegistersLayer(config) for _ in range(config.num_hidden_layers)])
+ self.gradient_checkpointing = False
+
+ def forward(
+ self, hidden_states: torch.Tensor, head_mask: Optional[torch.Tensor] = None, output_hidden_states: bool = False
+ ) -> BaseModelOutput:
+ all_hidden_states = [hidden_states] if output_hidden_states else None
+ for i, layer_module in enumerate(self.layer):
+ layer_head_mask = head_mask[i] if head_mask is not None else None
+ hidden_states = layer_module(hidden_states, layer_head_mask)
+ if all_hidden_states:
+ all_hidden_states.append(hidden_states)
+
+ return BaseModelOutput(
+ last_hidden_state=hidden_states,
+ hidden_states=tuple(all_hidden_states) if all_hidden_states else None,
+ )
+
+
+@auto_docstring
+class Dinov2WithRegistersPreTrainedModel(PreTrainedModel):
+ config: Dinov2WithRegistersConfig
+ base_model_prefix = "dinov2_with_registers"
+ main_input_name = "pixel_values"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["Dinov2WithRegistersLayer"]
+ _supports_sdpa = True
+ _supports_flash_attn = True
+ _supports_flex_attn = True
+ _supports_attention_backend = True
+ _can_record_outputs = {
+ "attentions": Dinov2WithRegistersSelfAttention,
+ }
+
+ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
+ """Initialize the weights"""
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
+ # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
+ # `trunc_normal_cpu` not implemented in `half` issues
+ module.weight.data = nn.init.trunc_normal_(
+ module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range
+ ).to(module.weight.dtype)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+ elif isinstance(module, Dinov2WithRegistersEmbeddings):
+ module.position_embeddings.data = nn.init.trunc_normal_(
+ module.position_embeddings.data.to(torch.float32),
+ mean=0.0,
+ std=self.config.initializer_range,
+ ).to(module.position_embeddings.dtype)
+
+ module.cls_token.data = nn.init.trunc_normal_(
+ module.cls_token.data.to(torch.float32),
+ mean=0.0,
+ std=self.config.initializer_range,
+ ).to(module.cls_token.dtype)
+
+ module.mask_token.data.zero_()
+ module.register_tokens.data.zero_()
+ elif isinstance(module, Dinov2WithRegistersLayerScale): # noqa: F821
+ module.lambda1.data.fill_(self.config.layerscale_value)
+
+
+@auto_docstring
+class Dinov2WithRegistersModel(Dinov2WithRegistersPreTrainedModel):
+ def __init__(self, config: Dinov2WithRegistersConfig):
+ super().__init__(config)
+ self.config = config
+
+ self.embeddings = Dinov2WithRegistersEmbeddings(config)
+ self.encoder = Dinov2WithRegistersEncoder(config)
+
+ self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self) -> Dinov2WithRegistersPatchEmbeddings:
+ return self.embeddings.patch_embeddings
+
+ def _prune_heads(self, heads_to_prune: dict[int, list[int]]) -> None:
+ """
+ Prunes heads of the model. `heads_to_prune` is a dict of {layer_num: list of heads to prune in this layer}.
+ See the base class `PreTrainedModel`.
+ """
+ for layer, heads in heads_to_prune.items():
+ self.encoder.layer[layer].attention.prune_heads(heads)
+
+ @check_model_inputs(tie_last_hidden_states=False)
+ @auto_docstring
+ def forward(
+ self,
+ pixel_values: Optional[torch.Tensor] = None,
+ bool_masked_pos: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ output_hidden_states: Optional[bool] = None,
+ **kwargs,
+ ) -> BaseModelOutputWithPooling:
+ r"""
+ bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, sequence_length)`):
+ Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). Only relevant for
+ pre-training.
+ """
+ if output_hidden_states is None:
+ output_hidden_states = self.config.output_hidden_states
+
+ if pixel_values is None:
+ raise ValueError("You have to specify pixel_values")
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+ embedding_output = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
+
+ encoder_outputs: BaseModelOutput = self.encoder(
+ embedding_output, head_mask=head_mask, output_hidden_states=output_hidden_states
+ )
+ sequence_output = encoder_outputs.last_hidden_state
+ sequence_output = self.layernorm(sequence_output)
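+ # pooled output: final hidden state of the [CLS] token (first position in the sequence)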
+ pooled_output = sequence_output[:, 0, :]
+
+ return BaseModelOutputWithPooling(
+ last_hidden_state=sequence_output,
+ pooler_output=pooled_output,
+ hidden_states=encoder_outputs.hidden_states,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ Dinov2WithRegisters Model transformer with an image classification head on top (a linear layer on top of the final hidden state
+ of the [CLS] token) e.g. for ImageNet.
+ """
+)
+class Dinov2WithRegistersForImageClassification(Dinov2WithRegistersPreTrainedModel):
+ def __init__(self, config: Dinov2WithRegistersConfig) -> None:
+ super().__init__(config)
+
+ self.num_labels = config.num_labels
+ self.dinov2_with_registers = Dinov2WithRegistersModel(config)
+
+ # Classifier head
+ self.classifier = (
+ nn.Linear(config.hidden_size * 2, config.num_labels) if config.num_labels > 0 else nn.Identity()
+ )
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ pixel_values: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> ImageClassifierOutput:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1`, a regression loss is computed (Mean-Square loss). If
+ `config.num_labels > 1`, a classification loss is computed (Cross-Entropy).
+ """
+
+ outputs: BaseModelOutputWithPooling = self.dinov2_with_registers(pixel_values, head_mask=head_mask, **kwargs)
+ sequence_output = outputs.last_hidden_state # batch_size, sequence_length, hidden_size
+
+ cls_token = sequence_output[:, 0]
+ # cls and register tokens should not be included in patch tokens variable
+ patch_tokens = sequence_output[:, 1 + self.config.num_register_tokens :]
+
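+ # the classifier consumes the [CLS] embedding concatenated with the mean patch embedding (matching the hidden_size * 2 classifier input)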
+ linear_input = torch.cat([cls_token, patch_tokens.mean(dim=1)], dim=1)
+ logits = self.classifier(linear_input)
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(labels, logits, self.config, **kwargs)
+
+ return ImageClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ Dinov2WithRegisters backbone, to be used with frameworks like DETR and MaskFormer.
+ """
+)
+class Dinov2WithRegistersBackbone(Dinov2WithRegistersPreTrainedModel, BackboneMixin):
+ def __init__(self, config):
+ super().__init__(config)
+ super()._init_backbone(config)
+ self.num_features = [config.hidden_size for _ in range(config.num_hidden_layers + 1)]
+ self.embeddings = Dinov2WithRegistersEmbeddings(config)
+ self.encoder = Dinov2WithRegistersEncoder(config)
+
+ self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+ self.num_register_tokens = config.num_register_tokens
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self) -> Dinov2WithRegistersPatchEmbeddings:
+ return self.embeddings.patch_embeddings
+
+ @check_model_inputs()
+ @auto_docstring
+ def forward(
+ self,
+ pixel_values: torch.Tensor,
+ output_hidden_states: Optional[bool] = None,
+ **kwargs,
+ ) -> BackboneOutput:
+ r"""
+ Examples:
+
+ ```python
+ >>> from transformers import AutoImageProcessor, AutoBackbone
+ >>> import torch
+ >>> from PIL import Image
+ >>> import requests
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> processor = AutoImageProcessor.from_pretrained("facebook/dinov2-with-registers-base")
+ >>> model = AutoBackbone.from_pretrained(
+ ... "facebook/dinov2-with-registers-base", out_features=["stage2", "stage5", "stage8", "stage11"]
+ ... )
+
+ >>> inputs = processor(image, return_tensors="pt")
+
+ >>> outputs = model(**inputs)
+ >>> feature_maps = outputs.feature_maps
+ >>> list(feature_maps[-1].shape)
+ [1, 768, 16, 16]
+ ```"""
+ if output_hidden_states is None:
+ output_hidden_states = self.config.output_hidden_states
+
+ embedding_output = self.embeddings(pixel_values)
+ output: BaseModelOutput = self.encoder(embedding_output, output_hidden_states=True)
+ hidden_states = output.hidden_states
+
+ feature_maps = []
+ for stage, hidden_state in zip(self.stage_names, hidden_states):
+ if stage in self.out_features:
+ if self.config.apply_layernorm:
+ hidden_state = self.layernorm(hidden_state)
+ if self.config.reshape_hidden_states:
+ hidden_state = hidden_state[:, 1 + self.num_register_tokens :]
+ # this was actually a bug in the original implementation that we copied here,
+ # because normally the order is height, width
+ batch_size, _, height, width = pixel_values.shape
+ patch_size = self.config.patch_size
+ hidden_state = hidden_state.reshape(batch_size, height // patch_size, width // patch_size, -1)
+ hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
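+ # resulting feature map shape: (batch_size, hidden_size, height // patch_size, width // patch_size)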
+ feature_maps.append(hidden_state)
+
+ return BackboneOutput(
+ feature_maps=tuple(feature_maps),
+ hidden_states=hidden_states if output_hidden_states else None,
+ )
+
+
+__all__ = [
+ "Dinov2WithRegistersPreTrainedModel",
+ "Dinov2WithRegistersModel",
+ "Dinov2WithRegistersForImageClassification",
+ "Dinov2WithRegistersBackbone",
+]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py b/venv/lib/python3.13/site-packages/transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py
new file mode 100644
index 0000000000000000000000000000000000000000..686528002b09c9689d66a057ed55eb1a43b0d256
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py
@@ -0,0 +1,435 @@
+# coding=utf-8
+# Copyright 2024 Meta Inc. and the HuggingFace Inc. team. All rights reserved.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Union
+
+import torch
+from torch import nn
+
+from ....transformers.models.dinov2.modeling_dinov2 import (
+ Dinov2Backbone,
+ Dinov2Encoder,
+ Dinov2ForImageClassification,
+ Dinov2Model,
+ Dinov2PatchEmbeddings,
+ Dinov2PreTrainedModel,
+)
+from ...configuration_utils import PretrainedConfig
+from ...modeling_outputs import BackboneOutput, BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
+from ...processing_utils import Unpack
+from ...utils import TransformersKwargs, logging, torch_int
+from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices
+
+
+logger = logging.get_logger(__name__)
+
+
+class Dinov2WithRegistersConfig(BackboneConfigMixin, PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`Dinov2WithRegistersModel`]. It is used to instantiate a
+ Dinov2WithRegisters model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of the DINOv2 with Registers
+ [facebook/dinov2-with-registers-base](https://huggingface.co/facebook/dinov2-with-registers-base) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ hidden_size (`int`, *optional*, defaults to 768):
+ Dimensionality of the encoder layers and the pooler layer.
+ num_hidden_layers (`int`, *optional*, defaults to 12):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 12):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ mlp_ratio (`int`, *optional*, defaults to 4):
+ Ratio of the hidden size of the MLPs relative to the `hidden_size`.
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"selu"` and `"gelu_new"` are supported.
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-06):
+ The epsilon used by the layer normalization layers.
+ image_size (`int`, *optional*, defaults to 224):
+ The size (resolution) of each image.
+ patch_size (`int`, *optional*, defaults to 16):
+ The size (resolution) of each patch.
+ num_channels (`int`, *optional*, defaults to 3):
+ The number of input channels.
+ qkv_bias (`bool`, *optional*, defaults to `True`):
+ Whether to add a bias to the queries, keys and values.
+ layerscale_value (`float`, *optional*, defaults to 1.0):
+ Initial value to use for layer scale.
+ drop_path_rate (`float`, *optional*, defaults to 0.0):
+ Stochastic depth rate per sample (when applied in the main path of residual layers).
+ use_swiglu_ffn (`bool`, *optional*, defaults to `False`):
+ Whether to use the SwiGLU feedforward neural network.
+ num_register_tokens (`int`, *optional*, defaults to 4):
+ Number of register tokens to use.
+ out_features (`list[str]`, *optional*):
+ If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
+ (depending on how many stages the model has). If unset and `out_indices` is set, will default to the
+ corresponding stages. If unset and `out_indices` is unset, will default to the last stage. Must be in the
+ same order as defined in the `stage_names` attribute.
+ out_indices (`list[int]`, *optional*):
+ If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how
+ many stages the model has). If unset and `out_features` is set, will default to the corresponding stages.
+ If unset and `out_features` is unset, will default to the last stage. Must be in the
+ same order as defined in the `stage_names` attribute.
+ apply_layernorm (`bool`, *optional*, defaults to `True`):
+ Whether to apply layer normalization to the feature maps in case the model is used as backbone.
+ reshape_hidden_states (`bool`, *optional*, defaults to `True`):
+ Whether to reshape the feature maps to 4D tensors of shape `(batch_size, hidden_size, height, width)` in
+ case the model is used as backbone. If `False`, the feature maps will be 3D tensors of shape `(batch_size,
+ seq_len, hidden_size)`.
+
+ Example:
+
+ ```python
+ >>> from transformers import Dinov2WithRegistersConfig, Dinov2WithRegistersModel
+
+ >>> # Initializing a Dinov2WithRegisters base style configuration
+ >>> configuration = Dinov2WithRegistersConfig()
+
+ >>> # Initializing a model (with random weights) from the base style configuration
+ >>> model = Dinov2WithRegistersModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "dinov2_with_registers"
+
+ def __init__(
+ self,
+ hidden_size=768,
+ num_hidden_layers=12,
+ num_attention_heads=12,
+ mlp_ratio=4,
+ hidden_act="gelu",
+ hidden_dropout_prob=0.0,
+ attention_probs_dropout_prob=0.0,
+ initializer_range=0.02,
+ layer_norm_eps=1e-6,
+ image_size=224,
+ patch_size=16,
+ num_channels=3,
+ qkv_bias=True,
+ layerscale_value=1.0,
+ drop_path_rate=0.0,
+ use_swiglu_ffn=False,
+ num_register_tokens=4,
+ out_features=None,
+ out_indices=None,
+ apply_layernorm=True,
+ reshape_hidden_states=True,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.mlp_ratio = mlp_ratio
+ self.hidden_act = hidden_act
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.initializer_range = initializer_range
+ self.layer_norm_eps = layer_norm_eps
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.num_channels = num_channels
+ self.qkv_bias = qkv_bias
+ self.layerscale_value = layerscale_value
+ self.drop_path_rate = drop_path_rate
+ self.use_swiglu_ffn = use_swiglu_ffn
+ self.num_register_tokens = num_register_tokens
+ self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, num_hidden_layers + 1)]
+ self._out_features, self._out_indices = get_aligned_output_features_output_indices(
+ out_features=out_features, out_indices=out_indices, stage_names=self.stage_names
+ )
+ self.apply_layernorm = apply_layernorm
+ self.reshape_hidden_states = reshape_hidden_states
+
+
+class Dinov2WithRegistersPatchEmbeddings(Dinov2PatchEmbeddings):
+ pass
+
+
+class Dinov2WithRegistersEmbeddings(nn.Module):
+ """
+ Construct the CLS token, mask token, register tokens, position and patch embeddings.
+ """
+
+ def __init__(self, config: Dinov2WithRegistersConfig) -> None:
+ super().__init__()
+
+ self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))
+ self.mask_token = nn.Parameter(torch.zeros(1, config.hidden_size))
+ self.register_tokens = nn.Parameter(torch.zeros(1, config.num_register_tokens, config.hidden_size))
+ self.patch_embeddings = Dinov2WithRegistersPatchEmbeddings(config)
+ num_patches = self.patch_embeddings.num_patches
+ self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, config.hidden_size))
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+ self.patch_size = config.patch_size
+ self.config = config
+
+ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
+ """
+ This method interpolates the pre-trained position encodings so that the model can be used on higher-resolution
+ images. This implementation supports torch.jit tracing while maintaining backwards compatibility
+ with the original implementation.
+
+ Adapted from:
+ - https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
+ - https://github.com/facebookresearch/dinov2/blob/main/dinov2/models/vision_transformer.py
+ """
+ num_patches = embeddings.shape[1] - 1
+ num_positions = self.position_embeddings.shape[1] - 1
+
+ # Skip interpolation for matching dimensions (unless tracing)
+ if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
+ return self.position_embeddings
+
+ # Handle class token and patch embeddings separately
+ class_pos_embed = self.position_embeddings[:, 0]
+ patch_pos_embed = self.position_embeddings[:, 1:]
+ dim = embeddings.shape[-1]
+
+ # Calculate new dimensions
+ height = height // self.config.patch_size
+ width = width // self.config.patch_size
+
+ # Reshape for interpolation
+ sqrt_num_positions = torch_int(num_positions**0.5)
+ patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
+ patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
+
+ # Store original dtype for restoration after interpolation
+ target_dtype = patch_pos_embed.dtype
+
+ # Interpolate at float32 precision
+ patch_pos_embed = nn.functional.interpolate(
+ patch_pos_embed.to(dtype=torch.float32),
+ size=(torch_int(height), torch_int(width)), # Explicit size instead of scale_factor
+ mode="bicubic",
+ align_corners=False,
+ antialias=True,
+ ).to(dtype=target_dtype)
+
+ # Validate output dimensions if not tracing
+ if not torch.jit.is_tracing():
+ if int(height) != patch_pos_embed.shape[-2] or int(width) != patch_pos_embed.shape[-1]:
+ raise ValueError("Width or height does not match with the interpolated position embeddings")
+
+ # Reshape back to original format
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+
+ # Combine class and patch embeddings
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
+
+ def forward(self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.Tensor] = None) -> torch.Tensor:
+ batch_size, _, height, width = pixel_values.shape
+ target_dtype = self.patch_embeddings.projection.weight.dtype
+ embeddings = self.patch_embeddings(pixel_values.to(dtype=target_dtype))
+
+ if bool_masked_pos is not None:
+ embeddings = torch.where(
+ bool_masked_pos.unsqueeze(-1), self.mask_token.to(embeddings.dtype).unsqueeze(0), embeddings
+ )
+
+ # add the [CLS] token to the embedded patch tokens
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1)
+ embeddings = torch.cat((cls_tokens, embeddings), dim=1)
+
+ # add positional encoding to each token
+ embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
+
+ # add register tokens
+ embeddings = torch.cat(
+ (embeddings[:, :1], self.register_tokens.expand(embeddings.shape[0], -1, -1), embeddings[:, 1:]), dim=1
+ )
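+ # token order is now: [CLS], register tokens, patch tokens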
+
+ embeddings = self.dropout(embeddings)
+
+ return embeddings
+
+
+class Dinov2WithRegistersEncoder(Dinov2Encoder):
+ pass
+
+
+class Dinov2WithRegistersPreTrainedModel(Dinov2PreTrainedModel):
+ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
+ """Initialize the weights"""
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
+ # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
+ # `trunc_normal_cpu` not implemented in `half` issues
+ module.weight.data = nn.init.trunc_normal_(
+ module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range
+ ).to(module.weight.dtype)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+ elif isinstance(module, Dinov2WithRegistersEmbeddings):
+ module.position_embeddings.data = nn.init.trunc_normal_(
+ module.position_embeddings.data.to(torch.float32),
+ mean=0.0,
+ std=self.config.initializer_range,
+ ).to(module.position_embeddings.dtype)
+
+ module.cls_token.data = nn.init.trunc_normal_(
+ module.cls_token.data.to(torch.float32),
+ mean=0.0,
+ std=self.config.initializer_range,
+ ).to(module.cls_token.dtype)
+
+ module.mask_token.data.zero_()
+ module.register_tokens.data.zero_()
+ elif isinstance(module, Dinov2WithRegistersLayerScale): # noqa: F821
+ module.lambda1.data.fill_(self.config.layerscale_value)
+
+
+class Dinov2WithRegistersModel(Dinov2Model):
+ pass
+
+
+class Dinov2WithRegistersForImageClassification(Dinov2ForImageClassification):
+ def forward(
+ self,
+ pixel_values: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> ImageClassifierOutput:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1`, a regression loss is computed (Mean-Square loss). If
+ `config.num_labels > 1`, a classification loss is computed (Cross-Entropy).
+ """
+
+ outputs: BaseModelOutputWithPooling = self.dinov2_with_registers(pixel_values, head_mask=head_mask, **kwargs)
+ sequence_output = outputs.last_hidden_state # batch_size, sequence_length, hidden_size
+
+ cls_token = sequence_output[:, 0]
+ # cls and register tokens should not be included in patch tokens variable
+ patch_tokens = sequence_output[:, 1 + self.config.num_register_tokens :]
+
+ linear_input = torch.cat([cls_token, patch_tokens.mean(dim=1)], dim=1)
+ logits = self.classifier(linear_input)
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(labels, logits, self.config, **kwargs)
+
+ return ImageClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+class Dinov2WithRegistersBackbone(Dinov2Backbone):
+ def __init__(self, config):
+ super().__init__(config)
+ super()._init_backbone(config)
+
+ self.num_register_tokens = config.num_register_tokens
+ self.num_features = [config.hidden_size for _ in range(config.num_hidden_layers + 1)]
+ self.embeddings = Dinov2WithRegistersEmbeddings(config)
+ self.encoder = Dinov2WithRegistersEncoder(config)
+
+ self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self) -> Dinov2WithRegistersPatchEmbeddings:
+ return self.embeddings.patch_embeddings
+
+ def forward(
+ self,
+ pixel_values: torch.Tensor,
+ output_hidden_states: Optional[bool] = None,
+ **kwargs,
+ ) -> BackboneOutput:
+ r"""
+ Examples:
+
+ ```python
+ >>> from transformers import AutoImageProcessor, AutoBackbone
+ >>> import torch
+ >>> from PIL import Image
+ >>> import requests
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> processor = AutoImageProcessor.from_pretrained("facebook/dinov2-with-registers-base")
+ >>> model = AutoBackbone.from_pretrained(
+ ... "facebook/dinov2-with-registers-base", out_features=["stage2", "stage5", "stage8", "stage11"]
+ ... )
+
+ >>> inputs = processor(image, return_tensors="pt")
+
+ >>> outputs = model(**inputs)
+ >>> feature_maps = outputs.feature_maps
+ >>> list(feature_maps[-1].shape)
+ [1, 768, 16, 16]
+ ```"""
+ if output_hidden_states is None:
+ output_hidden_states = self.config.output_hidden_states
+
+ embedding_output = self.embeddings(pixel_values)
+ output: BaseModelOutput = self.encoder(embedding_output, output_hidden_states=True)
+ hidden_states = output.hidden_states
+
+ feature_maps = []
+ for stage, hidden_state in zip(self.stage_names, hidden_states):
+ if stage in self.out_features:
+ if self.config.apply_layernorm:
+ hidden_state = self.layernorm(hidden_state)
+ if self.config.reshape_hidden_states:
+ hidden_state = hidden_state[:, 1 + self.num_register_tokens :]
+ # this was actually a bug in the original implementation that we copied here,
+ # because normally the order is height, width
+ batch_size, _, height, width = pixel_values.shape
+ patch_size = self.config.patch_size
+ hidden_state = hidden_state.reshape(batch_size, height // patch_size, width // patch_size, -1)
+ hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
+ feature_maps.append(hidden_state)
+
+ return BackboneOutput(
+ feature_maps=tuple(feature_maps),
+ hidden_states=hidden_states if output_hidden_states else None,
+ )
+
+
+__all__ = [
+ "Dinov2WithRegistersConfig",
+ "Dinov2WithRegistersPreTrainedModel",
+ "Dinov2WithRegistersModel",
+ "Dinov2WithRegistersForImageClassification",
+ "Dinov2WithRegistersBackbone",
+]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/dinov3_vit/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/dinov3_vit/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a74878b2053cf43fabe19a7fd72e020a0879f8e6
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/dinov3_vit/__init__.py
@@ -0,0 +1,28 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_dinov3_vit import *
+ from .image_processing_dinov3_vit_fast import *
+ from .modeling_dinov3_vit import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/venv/lib/python3.13/site-packages/transformers/models/dinov3_vit/configuration_dinov3_vit.py b/venv/lib/python3.13/site-packages/transformers/models/dinov3_vit/configuration_dinov3_vit.py
new file mode 100644
index 0000000000000000000000000000000000000000..78cbd200ce612e6c778392c85d2f8c97a7d19c82
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/dinov3_vit/configuration_dinov3_vit.py
@@ -0,0 +1,166 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""DINOv3 model configuration"""
+
+from typing import Optional
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class DINOv3ViTConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`DINOv3ViTModel`]. It is used to instantiate a
+ DINOv3 model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of the DINOv3
+ [facebook/dinov3-vits16-pretrain-lvd1689m](https://huggingface.co/facebook/dinov3-vits16-pretrain-lvd1689m) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ patch_size (`int`, *optional*, defaults to 16):
+ The size (resolution) of each patch.
+ hidden_size (`int`, *optional*, defaults to 384):
+ Dimensionality of the encoder layers and the pooler layer.
+ intermediate_size (`int`, *optional*, defaults to 1536):
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer.
+ num_hidden_layers (`int`, *optional*, defaults to 12):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 6):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"selu"` and `"gelu_new"` are supported.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-05):
+ The epsilon used by the layer normalization layers.
+ rope_theta (`float`, *optional*, defaults to 100.0):
+ The base period of the RoPE embeddings.
+ image_size (`int`, *optional*, defaults to 224):
+ The size (resolution) of each image.
+ num_channels (`int`, *optional*, defaults to 3):
+ The number of input channels.
+ query_bias (`bool`, *optional*, defaults to `True`):
+ Whether to add a bias to the query projection.
+ key_bias (`bool`, *optional*, defaults to `False`):
+ Whether to add a bias to the key projection.
+ value_bias (`bool`, *optional*, defaults to `True`):
+ Whether to add a bias to the value projection.
+ proj_bias (`bool`, *optional*, defaults to `True`):
+ Whether to add a bias to the output projection.
+ mlp_bias (`bool`, *optional*, defaults to `True`):
+ Whether to add a bias to the MLP layers.
+ layerscale_value (`float`, *optional*, defaults to 1.0):
+ Initial value to use for layer scale.
+ drop_path_rate (`float`, *optional*, defaults to 0.0):
+ Stochastic depth rate per sample (when applied in the main path of residual layers).
+ use_gated_mlp (`bool`, *optional*, defaults to `False`):
+ Whether to use a gated (SwiGLU-style) feedforward network instead of a standard MLP.
+ num_register_tokens (`int`, *optional*, defaults to 0):
+ The number of register tokens.
+ pos_embed_shift (`float`, *optional*):
+ Amount to randomly shift position embedding coordinates in [-shift, shift],
+ applied only in training mode if not `None`.
+ pos_embed_jitter (`float`, *optional*):
+ Amount to randomly jitter position embedding coordinates by a log-uniform factor in [1/jitter, jitter],
+ applied only in training mode if not `None`.
+ pos_embed_rescale (`float`, *optional*, defaults to 2.0):
+ Amount to randomly rescale position embedding coordinates by a log-uniform factor in [1/rescale, rescale],
+ applied only in training mode if not `None`.
+
+ Example:
+
+ ```python
+ >>> from transformers import DINOv3ViTConfig, DINOv3ViTModel
+
+ >>> # Initializing a DINOv3 ViT-small style configuration
+ >>> config = DINOv3ViTConfig()
+
+ >>> # Initializing a model (with random weights) from the config
+ >>> model = DINOv3ViTModel(config)
+
+ >>> # Accessing the model config
+ >>> config = model.config
+ ```"""
+
+ model_type = "dinov3_vit"
+
+ def __init__(
+ self,
+ patch_size: int = 16,
+ hidden_size: int = 384,
+ intermediate_size: int = 1536,
+ num_hidden_layers: int = 12,
+ num_attention_heads: int = 6,
+ hidden_act: str = "gelu",
+ attention_dropout: float = 0.0,
+ initializer_range: float = 0.02,
+ layer_norm_eps: float = 1e-5,
+ rope_theta: float = 100.0,
+ image_size: int = 224,
+ num_channels: int = 3,
+ query_bias: bool = True,
+ key_bias: bool = False,
+ value_bias: bool = True,
+ proj_bias: bool = True,
+ mlp_bias: bool = True,
+ layerscale_value: float = 1.0,
+ drop_path_rate: float = 0.0,
+ use_gated_mlp: bool = False,
+ num_register_tokens: int = 0,
+ # train augs
+ pos_embed_shift: Optional[float] = None,
+ pos_embed_jitter: Optional[float] = None,
+ pos_embed_rescale: Optional[float] = 2.0,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.num_channels = num_channels
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.hidden_act = hidden_act
+ self.attention_dropout = attention_dropout
+ self.initializer_range = initializer_range
+ self.layer_norm_eps = layer_norm_eps
+ self.layerscale_value = layerscale_value
+ self.drop_path_rate = drop_path_rate
+ self.use_gated_mlp = use_gated_mlp
+ self.rope_theta = rope_theta
+ self.query_bias = query_bias
+ self.key_bias = key_bias
+ self.value_bias = value_bias
+ self.proj_bias = proj_bias
+ self.mlp_bias = mlp_bias
+ self.num_register_tokens = num_register_tokens
+
+ # train augs
+ self.pos_embed_shift = pos_embed_shift
+ self.pos_embed_jitter = pos_embed_jitter
+ self.pos_embed_rescale = pos_embed_rescale
+
+
+__all__ = ["DINOv3ViTConfig"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py b/venv/lib/python3.13/site-packages/transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c080485ed008bc8bfa78e393e6b408fe86d172f
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py
@@ -0,0 +1,96 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Fast Image processor class for DINOv3."""
+
+from typing import Optional, Union
+
+import torch
+from torchvision.transforms.v2 import functional as F
+
+from transformers.image_processing_base import BatchFeature
+from transformers.image_processing_utils_fast import BaseImageProcessorFast, group_images_by_shape, reorder_images
+from transformers.image_utils import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, PILImageResampling, SizeDict
+from transformers.utils import (
+ TensorType,
+ auto_docstring,
+ logging,
+)
+from transformers.utils.import_utils import requires
+
+
+logger = logging.get_logger(__name__)
+
+
+@auto_docstring
+@requires(backends=("torchvision", "torch"))
+class DINOv3ViTImageProcessorFast(BaseImageProcessorFast):
+ resample = PILImageResampling.BILINEAR
+ image_mean = IMAGENET_DEFAULT_MEAN
+ image_std = IMAGENET_DEFAULT_STD
+ size = {"height": 224, "width": 224}
+ do_resize = True
+ do_rescale = True
+ do_normalize = True
+
+ # Overridden for DINOv3 to preserve order of transforms
+ # rescale -> resize -> normalize
+ def _preprocess(
+ self,
+ images: list["torch.Tensor"],
+ do_resize: bool,
+ size: SizeDict,
+ interpolation: Optional["F.InterpolationMode"],
+ do_center_crop: bool,
+ crop_size: SizeDict,
+ do_rescale: bool,
+ rescale_factor: float,
+ do_normalize: bool,
+ image_mean: Optional[Union[float, list[float]]],
+ image_std: Optional[Union[float, list[float]]],
+ disable_grouping: Optional[bool],
+ return_tensors: Optional[Union[str, TensorType]],
+ **kwargs,
+ ) -> BatchFeature:
+ # Group images by size for batched resizing
+ grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
+ resized_images_grouped = {}
+ for shape, stacked_images in grouped_images.items():
+ if do_rescale:
+ stacked_images = self.rescale(stacked_images, rescale_factor)
+ if do_resize:
+ stacked_images = self.resize(
+ image=stacked_images, size=size, interpolation=interpolation, antialias=True
+ )
+ resized_images_grouped[shape] = stacked_images
+ resized_images = reorder_images(resized_images_grouped, grouped_images_index)
+
+ # Group images by size for further processing
+ # Needed in case do_resize is False, or resize returns images with different sizes
+ grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
+ processed_images_grouped = {}
+ for shape, stacked_images in grouped_images.items():
+ if do_center_crop:
+ stacked_images = self.center_crop(stacked_images, crop_size)
+ if do_normalize:
+ stacked_images = self.normalize(stacked_images, image_mean, image_std)
+ processed_images_grouped[shape] = stacked_images
+
+ processed_images = reorder_images(processed_images_grouped, grouped_images_index)
+ processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
+
+ return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)
+
+
+__all__ = ["DINOv3ViTImageProcessorFast"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/dinov3_vit/modeling_dinov3_vit.py b/venv/lib/python3.13/site-packages/transformers/models/dinov3_vit/modeling_dinov3_vit.py
new file mode 100644
index 0000000000000000000000000000000000000000..bca25dcc1c2b9e3ffa271e13f60d2a27bf8ad409
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/dinov3_vit/modeling_dinov3_vit.py
@@ -0,0 +1,538 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/dinov3_vit/modular_dinov3_vit.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_dinov3_vit.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2025 Meta AI and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import Callable, Optional
+
+import numpy as np
+import torch
+from torch import nn
+
+from ...activations import ACT2FN
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import BaseModelOutputWithPooling
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...pytorch_utils import compile_compatible_method_lru_cache
+from ...utils import TransformersKwargs, auto_docstring
+from ...utils.generic import check_model_inputs
+from .configuration_dinov3_vit import DINOv3ViTConfig
+
+
+class DINOv3ViTEmbeddings(nn.Module):
+ """
+ Construct the CLS token, mask token, position and patch embeddings.
+ """
+
+ def __init__(self, config: DINOv3ViTConfig):
+ super().__init__()
+ self.config = config
+ self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))
+ self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
+ self.register_tokens = nn.Parameter(torch.empty(1, config.num_register_tokens, config.hidden_size))
+ self.patch_embeddings = nn.Conv2d(
+ config.num_channels, config.hidden_size, kernel_size=config.patch_size, stride=config.patch_size
+ )
+
+ def forward(self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.Tensor] = None) -> torch.Tensor:
+ batch_size = pixel_values.shape[0]
+ target_dtype = self.patch_embeddings.weight.dtype
+
+ # (batch_size, num_channels, height, width) -> (batch_size, num_patches, hidden_size)
+ patch_embeddings = self.patch_embeddings(pixel_values.to(dtype=target_dtype))
+ patch_embeddings = patch_embeddings.flatten(2).transpose(1, 2)
+
+ if bool_masked_pos is not None:
+ mask_token = self.mask_token.to(patch_embeddings.dtype)
+ patch_embeddings = torch.where(bool_masked_pos.unsqueeze(-1), mask_token, patch_embeddings)
+
+ # Add CLS and register tokens
+ cls_token = self.cls_token.expand(batch_size, -1, -1)
+ register_tokens = self.register_tokens.expand(batch_size, -1, -1)
+ embeddings = torch.cat([cls_token, register_tokens, patch_embeddings], dim=1)
+
+ return embeddings
+
+
+@compile_compatible_method_lru_cache(maxsize=32)
+def get_patches_center_coordinates(
+ num_patches_h: int, num_patches_w: int, dtype: torch.dtype, device: torch.device
+) -> torch.Tensor:
+ """
+ Computes the 2D coordinates of the centers of image patches, normalized to the range [-1, +1].
+ The center of each patch is exactly halfway between its top-left and bottom-right corners.
+
+ Args:
+ num_patches_h (int): Number of patches along the vertical (height) axis.
+ num_patches_w (int): Number of patches along the horizontal (width) axis.
+ dtype (torch.dtype): The desired data type of the returned tensor.
+ device (torch.device): The device on which to create the returned tensor.
+
+ Returns:
+ torch.Tensor: A tensor of shape (num_patches_h * num_patches_w, 2), where each row contains the (y, x)
+ coordinates of a patch center, normalized to [-1, +1].
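+
+ Example (worked values): with `num_patches_h = 2`, the row centers sit at 0.5 and 1.5 patch units,
+ i.e. 0.25 and 0.75 after dividing by the grid size, which map to -0.5 and +0.5 in [-1, +1].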
+ """
+ coords_h = torch.arange(0.5, num_patches_h, dtype=dtype, device=device)
+ coords_w = torch.arange(0.5, num_patches_w, dtype=dtype, device=device)
+ coords_h = coords_h / num_patches_h
+ coords_w = coords_w / num_patches_w
+ # (height, width, 2) -> (height * width, 2)
+ coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"), dim=-1)
+ coords = coords.flatten(0, 1)
+ # Shift range [0, 1] to [-1, +1]
+ coords = 2.0 * coords - 1.0
+ return coords
+
+
+def augment_patches_center_coordinates(
+ coords: torch.Tensor,
+ shift: Optional[float] = None,
+ jitter: Optional[float] = None,
+ rescale: Optional[float] = None,
+) -> torch.Tensor:
+ # Shift coords by adding a uniform value in [-shift, shift]
+ if shift is not None:
+ shift_hw = torch.empty((1, 2), device=coords.device, dtype=coords.dtype)
+ shift_hw = shift_hw.uniform_(-shift, shift)
+ coords = coords + shift_hw
+
+ # Jitter coords by multiplying the range [-1, 1] by a log-uniform value in [1/jitter, jitter]
+ if jitter is not None:
+ jitter_range = np.log(jitter)
+ jitter_hw = torch.empty((1, 2), device=coords.device, dtype=coords.dtype)
+ jitter_hw = jitter_hw.uniform_(-jitter_range, jitter_range).exp()
+ coords = coords * jitter_hw
+
+ # Rescale coords by multiplying the range [-1, 1] by a log-uniform value in [1/rescale, rescale]
+ if rescale is not None:
+ rescale_range = np.log(rescale)
+ rescale_hw = torch.empty(1, device=coords.device, dtype=coords.dtype)
+ rescale_hw = rescale_hw.uniform_(-rescale_range, rescale_range).exp()
+ coords = coords * rescale_hw
+
+ return coords
+
+
+class DINOv3ViTRopePositionEmbedding(nn.Module):
+ inv_freq: torch.Tensor
+
+ def __init__(self, config: DINOv3ViTConfig):
+ super().__init__()
+
+ self.config = config
+ self.base = config.rope_theta
+ self.head_dim = config.hidden_size // config.num_attention_heads
+ self.num_patches_h = config.image_size // config.patch_size
+ self.num_patches_w = config.image_size // config.patch_size
+
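+ # head_dim / 4 base frequencies: each spatial axis (y, x) gets head_dim / 4 rotary pairs, i.e. head_dim / 2 rotated dims per axis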
+ inv_freq = 1 / self.base ** torch.arange(0, 1, 4 / self.head_dim, dtype=torch.float32) # (head_dim / 4,)
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+ def forward(self, pixel_values: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+ _, _, height, width = pixel_values.shape
+ num_patches_h = height // self.config.patch_size
+ num_patches_w = width // self.config.patch_size
+
+ device = pixel_values.device
+ device_type = device.type if isinstance(device.type, str) and device.type != "mps" else "cpu"
+
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
+ # Although we could precompute static patch_coords from image_size and patch_size in the config,
+ # the model was trained with random_scale, so it can process images of varying sizes.
+ # Therefore, it's better to compute patch_coords dynamically (with lru_cache).
+ patch_coords = get_patches_center_coordinates(
+ num_patches_h, num_patches_w, dtype=torch.float32, device=device
+ )
+ if self.training:
+ patch_coords = augment_patches_center_coordinates(
+ patch_coords,
+ shift=self.config.pos_embed_shift,
+ jitter=self.config.pos_embed_jitter,
+ rescale=self.config.pos_embed_rescale,
+ )
+
+ # (height * width, 2, head_dim / 4) -> (height * width, head_dim / 2) -> (height * width, head_dim)
+ angles = 2 * math.pi * patch_coords[:, :, None] * self.inv_freq[None, None, :]
+ angles = angles.flatten(1, 2)
+ angles = angles.tile(2)
+
+ cos = torch.cos(angles)
+ sin = torch.sin(angles)
+
+ dtype = pixel_values.dtype
+ return cos.to(dtype=dtype), sin.to(dtype=dtype)
+
+
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: float,
+ dropout: float = 0.0,
+ **kwargs,
+):
+ # Take the dot product between "query" and "key" to get the raw attention scores.
+ attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
+
+ # Normalize the attention scores to probabilities.
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+
+ # Mask heads if we want to
+ if attention_mask is not None:
+ attn_weights = attn_weights * attention_mask
+
+ attn_output = torch.matmul(attn_weights, value)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
+def apply_rotary_pos_emb(
+ q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, **kwargs
+) -> tuple[torch.Tensor, torch.Tensor]:
+ """Applies Rotary Position Embedding to the query and key tensors, but only to the patch tokens,
+ ignoring the prefix tokens (cls token and register tokens).
+
+ Args:
+ q (`torch.Tensor`): The query tensor.
+ k (`torch.Tensor`): The key tensor.
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
+
+ Returns:
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+ """
+
+ num_tokens = q.shape[-2]
+ num_patches = sin.shape[-2]
+ num_prefix_tokens = num_tokens - num_patches # cls token + register tokens
+
+ q_prefix_tokens, q_patches = q.split((num_prefix_tokens, num_patches), dim=-2)
+ k_prefix_tokens, k_patches = k.split((num_prefix_tokens, num_patches), dim=-2)
+
+ # apply rope only to patch tokens
+ q_patches = (q_patches * cos) + (rotate_half(q_patches) * sin)
+ k_patches = (k_patches * cos) + (rotate_half(k_patches) * sin)
+
+ q = torch.cat((q_prefix_tokens, q_patches), dim=-2)
+ k = torch.cat((k_prefix_tokens, k_patches), dim=-2)
+
+ return q, k
+
+
+class DINOv3ViTAttention(nn.Module):
+ """
+ Multi-headed attention compatible with ALL_ATTENTION_FUNCTIONS.
+ """
+
+ def __init__(self, config: DINOv3ViTConfig):
+ super().__init__()
+ self.config = config
+ self.embed_dim = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.embed_dim // self.num_heads
+ self.is_causal = False
+
+ self.scaling = self.head_dim**-0.5
+
+ self.dropout = config.attention_dropout
+ self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.key_bias)
+ self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.value_bias)
+
+ self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.query_bias)
+ self.o_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.proj_bias)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+ """Input shape: Batch x Time x Channel"""
+
+ batch_size, patches, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)
+
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.dropout,
+ scaling=self.scaling,
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(batch_size, patches, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+
+ return attn_output, attn_weights
+
+
+class DINOv3ViTLayerScale(nn.Module):
+ def __init__(self, config) -> None:
+ super().__init__()
+ self.lambda1 = nn.Parameter(config.layerscale_value * torch.ones(config.hidden_size))
+
+ def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
+ return hidden_state * self.lambda1
+
+
+def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
+ """
+ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+
+ Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
+ however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
+ layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
+ argument.
+ """
+ if drop_prob == 0.0 or not training:
+ return input
+ keep_prob = 1 - drop_prob
+ shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
+ random_tensor.floor_() # binarize
+ output = input.div(keep_prob) * random_tensor
+ return output
+
+
+class DINOv3ViTDropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+ def __init__(self, drop_prob: Optional[float] = None) -> None:
+ super().__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ return drop_path(hidden_states, self.drop_prob, self.training)
+
+ def extra_repr(self) -> str:
+ return f"p={self.drop_prob}"
+
+
+class DINOv3ViTMLP(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.intermediate_size = config.intermediate_size
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
+ self.act_fn = ACT2FN[config.hidden_act]
+
+ def forward(self, x):
+ return self.down_proj(self.act_fn(self.up_proj(x)))
+
+
+class DINOv3ViTGatedMLP(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.intermediate_size = config.intermediate_size
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
+ self.act_fn = ACT2FN[config.hidden_act]
+
+ def forward(self, x):
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+ return down_proj
+
+
+class DINOv3ViTLayer(GradientCheckpointingLayer):
+ """This corresponds to the Block class in the original implementation."""
+
+ def __init__(self, config: DINOv3ViTConfig):
+ super().__init__()
+
+ self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.attention = DINOv3ViTAttention(config)
+ self.layer_scale1 = DINOv3ViTLayerScale(config)
+ self.drop_path = DINOv3ViTDropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity()
+
+ self.norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+ if config.use_gated_mlp:
+ self.mlp = DINOv3ViTGatedMLP(config)
+ else:
+ self.mlp = DINOv3ViTMLP(config)
+ self.layer_scale2 = DINOv3ViTLayerScale(config)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+ ) -> torch.Tensor:
+ # Attention with residual connection
+ residual = hidden_states
+ hidden_states = self.norm1(hidden_states)
+ hidden_states, _ = self.attention(
+ hidden_states,
+ attention_mask=attention_mask,
+ position_embeddings=position_embeddings,
+ )
+ hidden_states = self.layer_scale1(hidden_states)
+ hidden_states = self.drop_path(hidden_states) + residual
+
+ # MLP with residual connection
+ residual = hidden_states
+ hidden_states = self.norm2(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = self.layer_scale2(hidden_states)
+ hidden_states = self.drop_path(hidden_states) + residual
+
+ return hidden_states
+
+
+@auto_docstring
+class DINOv3ViTPreTrainedModel(PreTrainedModel):
+ config: DINOv3ViTConfig
+ base_model_prefix = "dinov3_vit"
+ main_input_name = "pixel_values"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["DINOv3ViTLayer"]
+ _supports_sdpa = True
+ _supports_flash_attn = True
+ _supports_flex_attn = True
+ _supports_attention_backend = True
+ _can_record_outputs = {
+ "hidden_states": DINOv3ViTLayer,
+ "attentions": DINOv3ViTAttention,
+ }
+
+ def _init_weights(self, module) -> None:
+ """Initialize the weights"""
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
+ # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
+ # `trunc_normal_cpu` not implemented in `half` issues
+ module.weight.data = nn.init.trunc_normal_(
+ module.weight.data.to(torch.float32),
+ mean=0.0,
+ std=self.config.initializer_range,
+ ).to(module.weight.dtype)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+ elif isinstance(module, DINOv3ViTEmbeddings):
+ module.cls_token.data = nn.init.trunc_normal_(
+ module.cls_token.data.to(torch.float32),
+ mean=0.0,
+ std=self.config.initializer_range,
+ ).to(module.cls_token.dtype)
+ if module.config.num_register_tokens > 0:
+ module.register_tokens.data = nn.init.trunc_normal_(
+ module.register_tokens.data.to(torch.float32),
+ mean=0.0,
+ std=self.config.initializer_range,
+ ).to(module.register_tokens.dtype)
+ module.mask_token.data.zero_()
+ elif isinstance(module, DINOv3ViTLayerScale):
+ module.lambda1.data.fill_(self.config.layerscale_value)
+
+
+@auto_docstring
+class DINOv3ViTModel(DINOv3ViTPreTrainedModel):
+ def __init__(self, config: DINOv3ViTConfig):
+ super().__init__(config)
+ self.config = config
+ self.embeddings = DINOv3ViTEmbeddings(config)
+ self.rope_embeddings = DINOv3ViTRopePositionEmbedding(config)
+ self.layer = nn.ModuleList([DINOv3ViTLayer(config) for _ in range(config.num_hidden_layers)])
+ self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.gradient_checkpointing = False
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embeddings.patch_embeddings
+
+ @check_model_inputs(tie_last_hidden_states=False)
+ @auto_docstring
+ def forward(
+ self,
+ pixel_values: torch.Tensor,
+ bool_masked_pos: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> BaseModelOutputWithPooling:
+ r"""
+ bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, sequence_length)`):
+ Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). Only relevant for
+ pre-training.
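+
+        Example (a minimal sketch with randomly initialized weights and default `DINOv3ViTConfig`
+        values, not a released checkpoint):
+
+        ```python
+        >>> import torch
+        >>> from transformers import DINOv3ViTConfig, DINOv3ViTModel
+
+        >>> config = DINOv3ViTConfig()
+        >>> model = DINOv3ViTModel(config)
+
+        >>> pixel_values = torch.randn(1, config.num_channels, config.image_size, config.image_size)
+        >>> outputs = model(pixel_values)
+        >>> # shape: (batch, 1 + num_register_tokens + num_patches, hidden_size)
+        >>> last_hidden_state = outputs.last_hidden_state
+        ```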
+ """
+
+ pixel_values = pixel_values.to(self.embeddings.patch_embeddings.weight.dtype)
+ hidden_states = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
+ position_embeddings = self.rope_embeddings(pixel_values)
+
+ for i, layer_module in enumerate(self.layer):
+ layer_head_mask = head_mask[i] if head_mask is not None else None
+ hidden_states = layer_module(
+ hidden_states,
+ attention_mask=layer_head_mask,
+ position_embeddings=position_embeddings,
+ )
+
+ sequence_output = self.norm(hidden_states)
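+        # pool by taking the CLS token (first position) as the image-level representation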
+ pooled_output = sequence_output[:, 0, :]
+
+ return BaseModelOutputWithPooling(
+ last_hidden_state=sequence_output,
+ pooler_output=pooled_output,
+ )
+
+
+__all__ = ["DINOv3ViTModel", "DINOv3ViTPreTrainedModel"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/dinov3_vit/modular_dinov3_vit.py b/venv/lib/python3.13/site-packages/transformers/models/dinov3_vit/modular_dinov3_vit.py
new file mode 100644
index 0000000000000000000000000000000000000000..88597336fb191be44ff2825d152f9a3551b47b28
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/dinov3_vit/modular_dinov3_vit.py
@@ -0,0 +1,428 @@
+# coding=utf-8
+# Copyright 2025 Meta AI and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch DINOv3 model."""
+
+import math
+from typing import Callable, Optional
+
+import numpy as np
+import torch
+from torch import nn
+
+from transformers.models.arcee.modeling_arcee import ArceeMLP
+from transformers.models.dinov2.modeling_dinov2 import (
+ Dinov2DropPath,
+ Dinov2LayerScale,
+ Dinov2PreTrainedModel,
+ eager_attention_forward,
+)
+from transformers.models.llama.modeling_llama import LlamaMLP
+from transformers.models.pixtral.modeling_pixtral import PixtralAttention, rotate_half
+
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import BaseModelOutputWithPooling
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
+from ...processing_utils import Unpack
+from ...pytorch_utils import compile_compatible_method_lru_cache
+from ...utils import TransformersKwargs, auto_docstring, logging
+from ...utils.generic import check_model_inputs
+from .configuration_dinov3_vit import DINOv3ViTConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+class DINOv3ViTEmbeddings(nn.Module):
+ """
+    Construct the CLS token, mask token, register tokens and patch embeddings.
+ """
+
+ def __init__(self, config: DINOv3ViTConfig):
+ super().__init__()
+ self.config = config
+ self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))
+ self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
+ self.register_tokens = nn.Parameter(torch.empty(1, config.num_register_tokens, config.hidden_size))
+ self.patch_embeddings = nn.Conv2d(
+ config.num_channels, config.hidden_size, kernel_size=config.patch_size, stride=config.patch_size
+ )
+
+ def forward(self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.Tensor] = None) -> torch.Tensor:
+ batch_size = pixel_values.shape[0]
+ target_dtype = self.patch_embeddings.weight.dtype
+
+ # (batch_size, num_channels, height, width) -> (batch_size, num_patches, hidden_size)
+ patch_embeddings = self.patch_embeddings(pixel_values.to(dtype=target_dtype))
+ patch_embeddings = patch_embeddings.flatten(2).transpose(1, 2)
+
+ if bool_masked_pos is not None:
+ mask_token = self.mask_token.to(patch_embeddings.dtype)
+ patch_embeddings = torch.where(bool_masked_pos.unsqueeze(-1), mask_token, patch_embeddings)
+
+ # Add CLS and register tokens
+ cls_token = self.cls_token.expand(batch_size, -1, -1)
+ register_tokens = self.register_tokens.expand(batch_size, -1, -1)
+ embeddings = torch.cat([cls_token, register_tokens, patch_embeddings], dim=1)
+
+ return embeddings
+
+
+@compile_compatible_method_lru_cache(maxsize=32)
+def get_patches_center_coordinates(
+ num_patches_h: int, num_patches_w: int, dtype: torch.dtype, device: torch.device
+) -> torch.Tensor:
+ """
+ Computes the 2D coordinates of the centers of image patches, normalized to the range [-1, +1].
+ The center of each patch is exactly halfway between its top-left and bottom-right corners.
+
+ Args:
+ num_patches_h (int): Number of patches along the vertical (height) axis.
+ num_patches_w (int): Number of patches along the horizontal (width) axis.
+        dtype (torch.dtype): The desired data type of the returned tensor.
+        device (torch.device): The device on which to create the returned tensor.
+
+ Returns:
+ torch.Tensor: A tensor of shape (height * width, 2), where each row contains the (y, x)
+ coordinates of a patch center, normalized to [-1, +1].
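+
+    Example:
+        With `num_patches_h = num_patches_w = 2`, the per-axis centers 0.5 and 1.5 are first scaled
+        to 0.25 and 0.75 and then mapped to -0.5 and +0.5, so the four patch centers are (±0.5, ±0.5).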
+ """
+ coords_h = torch.arange(0.5, num_patches_h, dtype=dtype, device=device)
+ coords_w = torch.arange(0.5, num_patches_w, dtype=dtype, device=device)
+ coords_h = coords_h / num_patches_h
+ coords_w = coords_w / num_patches_w
+ # (height, width, 2) -> (height * width, 2)
+ coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"), dim=-1)
+ coords = coords.flatten(0, 1)
+ # Shift range [0, 1] to [-1, +1]
+ coords = 2.0 * coords - 1.0
+ return coords
+
+
+def augment_patches_center_coordinates(
+ coords: torch.Tensor,
+ shift: Optional[float] = None,
+ jitter: Optional[float] = None,
+ rescale: Optional[float] = None,
+) -> torch.Tensor:
+ # Shift coords by adding a uniform value in [-shift, shift]
+ if shift is not None:
+ shift_hw = torch.empty((1, 2), device=coords.device, dtype=coords.dtype)
+ shift_hw = shift_hw.uniform_(-shift, shift)
+ coords = coords + shift_hw
+
+ # Jitter coords by multiplying the range [-1, 1] by a log-uniform value in [1/jitter, jitter]
+ if jitter is not None:
+ jitter_range = np.log(jitter)
+ jitter_hw = torch.empty((1, 2), device=coords.device, dtype=coords.dtype)
+ jitter_hw = jitter_hw.uniform_(-jitter_range, jitter_range).exp()
+ coords = coords * jitter_hw
+
+ # Rescale coords by multiplying the range [-1, 1] by a log-uniform value in [1/rescale, rescale]
+ if rescale is not None:
+ rescale_range = np.log(rescale)
+ rescale_hw = torch.empty(1, device=coords.device, dtype=coords.dtype)
+ rescale_hw = rescale_hw.uniform_(-rescale_range, rescale_range).exp()
+ coords = coords * rescale_hw
+
+ return coords
+
+
+class DINOv3ViTRopePositionEmbedding(nn.Module):
+ inv_freq: torch.Tensor
+
+ def __init__(self, config: DINOv3ViTConfig):
+ super().__init__()
+
+ self.config = config
+ self.base = config.rope_theta
+ self.head_dim = config.hidden_size // config.num_attention_heads
+ self.num_patches_h = config.image_size // config.patch_size
+ self.num_patches_w = config.image_size // config.patch_size
+
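+        # Each spatial axis (y, x) gets head_dim / 4 frequencies; in `forward` the two axes are combined
+        # and the angles are tiled so both halves of the head dimension share them (rotate_half convention).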
+        inv_freq = 1.0 / (self.base ** torch.arange(0, 1, 4 / self.head_dim, dtype=torch.float32))  # (head_dim / 4,)
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+ def forward(self, pixel_values: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+ _, _, height, width = pixel_values.shape
+ num_patches_h = height // self.config.patch_size
+ num_patches_w = width // self.config.patch_size
+
+ device = pixel_values.device
+ device_type = device.type if isinstance(device.type, str) and device.type != "mps" else "cpu"
+
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
+ # Although we could precompute static patch_coords from image_size and patch_size in the config,
+ # the model was trained with random_scale, so it can process images of varying sizes.
+ # Therefore, it's better to compute patch_coords dynamically (with lru_cache).
+ patch_coords = get_patches_center_coordinates(
+ num_patches_h, num_patches_w, dtype=torch.float32, device=device
+ )
+ if self.training:
+ patch_coords = augment_patches_center_coordinates(
+ patch_coords,
+ shift=self.config.pos_embed_shift,
+ jitter=self.config.pos_embed_jitter,
+ rescale=self.config.pos_embed_rescale,
+ )
+
+ # (height * width, 2, head_dim / 4) -> (height * width, head_dim / 2) -> (height * width, head_dim)
+ angles = 2 * math.pi * patch_coords[:, :, None] * self.inv_freq[None, None, :]
+ angles = angles.flatten(1, 2)
+ angles = angles.tile(2)
+
+ cos = torch.cos(angles)
+ sin = torch.sin(angles)
+
+ dtype = pixel_values.dtype
+ return cos.to(dtype=dtype), sin.to(dtype=dtype)
+
+
+def apply_rotary_pos_emb(
+ q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, **kwargs
+) -> tuple[torch.Tensor, torch.Tensor]:
+ """Applies Rotary Position Embedding to the query and key tensors, but only to the patch tokens,
+ ignoring the prefix tokens (cls token and register tokens).
+
+ Args:
+ q (`torch.Tensor`): The query tensor.
+ k (`torch.Tensor`): The key tensor.
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
+
+ Returns:
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+ """
+
+ num_tokens = q.shape[-2]
+ num_patches = sin.shape[-2]
+ num_prefix_tokens = num_tokens - num_patches # cls token + register tokens
+
+ q_prefix_tokens, q_patches = q.split((num_prefix_tokens, num_patches), dim=-2)
+ k_prefix_tokens, k_patches = k.split((num_prefix_tokens, num_patches), dim=-2)
+
+ # apply rope only to patch tokens
+ q_patches = (q_patches * cos) + (rotate_half(q_patches) * sin)
+ k_patches = (k_patches * cos) + (rotate_half(k_patches) * sin)
+
+ q = torch.cat((q_prefix_tokens, q_patches), dim=-2)
+ k = torch.cat((k_prefix_tokens, k_patches), dim=-2)
+
+ return q, k
+
+
+class DINOv3ViTAttention(PixtralAttention):
+ def __init__(self, config: DINOv3ViTConfig):
+ super().__init__(config)
+
+ self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.query_bias)
+ self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.key_bias)
+ self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.value_bias)
+ self.o_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.proj_bias)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+ """Input shape: Batch x Time x Channel"""
+
+ batch_size, patches, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)
+
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.dropout,
+ scaling=self.scaling,
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(batch_size, patches, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+
+ return attn_output, attn_weights
+
+
+class DINOv3ViTLayerScale(Dinov2LayerScale):
+ pass
+
+
+class DINOv3ViTDropPath(Dinov2DropPath):
+ pass
+
+
+class DINOv3ViTMLP(ArceeMLP):
+ pass
+
+
+class DINOv3ViTGatedMLP(LlamaMLP):
+ pass
+
+
+class DINOv3ViTLayer(GradientCheckpointingLayer):
+ """This corresponds to the Block class in the original implementation."""
+
+ def __init__(self, config: DINOv3ViTConfig):
+ super().__init__()
+
+ self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.attention = DINOv3ViTAttention(config)
+ self.layer_scale1 = DINOv3ViTLayerScale(config)
+ self.drop_path = DINOv3ViTDropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity()
+
+ self.norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+ if config.use_gated_mlp:
+ self.mlp = DINOv3ViTGatedMLP(config)
+ else:
+ self.mlp = DINOv3ViTMLP(config)
+ self.layer_scale2 = DINOv3ViTLayerScale(config)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+ ) -> torch.Tensor:
+ # Attention with residual connection
+ residual = hidden_states
+ hidden_states = self.norm1(hidden_states)
+ hidden_states, _ = self.attention(
+ hidden_states,
+ attention_mask=attention_mask,
+ position_embeddings=position_embeddings,
+ )
+ hidden_states = self.layer_scale1(hidden_states)
+ hidden_states = self.drop_path(hidden_states) + residual
+
+ # MLP with residual connection
+ residual = hidden_states
+ hidden_states = self.norm2(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = self.layer_scale2(hidden_states)
+ hidden_states = self.drop_path(hidden_states) + residual
+
+ return hidden_states
+
+
+@auto_docstring
+class DINOv3ViTPreTrainedModel(Dinov2PreTrainedModel):
+ _can_record_outputs = {
+ "hidden_states": DINOv3ViTLayer,
+ "attentions": DINOv3ViTAttention,
+ }
+
+ def _init_weights(self, module):
+ """Initialize the weights"""
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
+ # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
+ # `trunc_normal_cpu` not implemented in `half` issues
+ module.weight.data = nn.init.trunc_normal_(
+ module.weight.data.to(torch.float32),
+ mean=0.0,
+ std=self.config.initializer_range,
+ ).to(module.weight.dtype)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+ elif isinstance(module, DINOv3ViTEmbeddings):
+ module.cls_token.data = nn.init.trunc_normal_(
+ module.cls_token.data.to(torch.float32),
+ mean=0.0,
+ std=self.config.initializer_range,
+ ).to(module.cls_token.dtype)
+ if module.config.num_register_tokens > 0:
+ module.register_tokens.data = nn.init.trunc_normal_(
+ module.register_tokens.data.to(torch.float32),
+ mean=0.0,
+ std=self.config.initializer_range,
+ ).to(module.register_tokens.dtype)
+ module.mask_token.data.zero_()
+ elif isinstance(module, DINOv3ViTLayerScale):
+ module.lambda1.data.fill_(self.config.layerscale_value)
+
+
+@auto_docstring
+class DINOv3ViTModel(DINOv3ViTPreTrainedModel):
+ def __init__(self, config: DINOv3ViTConfig):
+ super().__init__(config)
+ self.config = config
+ self.embeddings = DINOv3ViTEmbeddings(config)
+ self.rope_embeddings = DINOv3ViTRopePositionEmbedding(config)
+ self.layer = nn.ModuleList([DINOv3ViTLayer(config) for _ in range(config.num_hidden_layers)])
+ self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.gradient_checkpointing = False
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embeddings.patch_embeddings
+
+ @check_model_inputs(tie_last_hidden_states=False)
+ @auto_docstring
+ def forward(
+ self,
+ pixel_values: torch.Tensor,
+ bool_masked_pos: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> BaseModelOutputWithPooling:
+ r"""
+ bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, sequence_length)`):
+ Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). Only relevant for
+ pre-training.
+ """
+
+ pixel_values = pixel_values.to(self.embeddings.patch_embeddings.weight.dtype)
+ hidden_states = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
+ position_embeddings = self.rope_embeddings(pixel_values)
+
+ for i, layer_module in enumerate(self.layer):
+ layer_head_mask = head_mask[i] if head_mask is not None else None
+ hidden_states = layer_module(
+ hidden_states,
+ attention_mask=layer_head_mask,
+ position_embeddings=position_embeddings,
+ )
+
+ sequence_output = self.norm(hidden_states)
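+        # pool by taking the CLS token (first position) as the image-level representation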
+ pooled_output = sequence_output[:, 0, :]
+
+ return BaseModelOutputWithPooling(
+ last_hidden_state=sequence_output,
+ pooler_output=pooled_output,
+ )
+
+
+__all__ = ["DINOv3ViTModel", "DINOv3ViTPreTrainedModel"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/dpr/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/dpr/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9aeadbeaf416575570c280a3e15a52422a007103
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/dpr/__init__.py
@@ -0,0 +1,30 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_dpr import *
+ from .modeling_dpr import *
+ from .modeling_tf_dpr import *
+ from .tokenization_dpr import *
+ from .tokenization_dpr_fast import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/venv/lib/python3.13/site-packages/transformers/models/dpr/configuration_dpr.py b/venv/lib/python3.13/site-packages/transformers/models/dpr/configuration_dpr.py
new file mode 100644
index 0000000000000000000000000000000000000000..03b16900249329ad867ae6b13b58b89d7722a25a
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/dpr/configuration_dpr.py
@@ -0,0 +1,131 @@
+# coding=utf-8
+# Copyright 2010, DPR authors, The Hugging Face Team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""DPR model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class DPRConfig(PretrainedConfig):
+ r"""
+ [`DPRConfig`] is the configuration class to store the configuration of a *DPRModel*.
+
+ This is the configuration class to store the configuration of a [`DPRContextEncoder`], [`DPRQuestionEncoder`], or a
+ [`DPRReader`]. It is used to instantiate the components of the DPR model according to the specified arguments,
+ defining the model component architectures. Instantiating a configuration with the defaults will yield a similar
+ configuration to that of the DPRContextEncoder
+ [facebook/dpr-ctx_encoder-single-nq-base](https://huggingface.co/facebook/dpr-ctx_encoder-single-nq-base)
+ architecture.
+
+    This class is a subclass of [`PretrainedConfig`]. Please check the superclass for the documentation of all kwargs.
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 30522):
+ Vocabulary size of the DPR model. Defines the different tokens that can be represented by the *inputs_ids*
+ passed to the forward method of [`BertModel`].
+ hidden_size (`int`, *optional*, defaults to 768):
+ Dimensionality of the encoder layers and the pooler layer.
+ num_hidden_layers (`int`, *optional*, defaults to 12):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 12):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ intermediate_size (`int`, *optional*, defaults to 3072):
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
+ The dropout ratio for the attention probabilities.
+ max_position_embeddings (`int`, *optional*, defaults to 512):
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
+ just in case (e.g., 512 or 1024 or 2048).
+ type_vocab_size (`int`, *optional*, defaults to 2):
+ The vocabulary size of the *token_type_ids* passed into [`BertModel`].
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12):
+ The epsilon used by the layer normalization layers.
+ pad_token_id (`int`, *optional*, defaults to 0):
+ Padding token id.
+ position_embedding_type (`str`, *optional*, defaults to `"absolute"`):
+ Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For
+ positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to
+ [Self-Attention with Relative Position Representations (Shaw et al.)](https://huggingface.co/papers/1803.02155).
+ For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models
+ with Better Relative Position Embeddings (Huang et al.)](https://huggingface.co/papers/2009.13658).
+ projection_dim (`int`, *optional*, defaults to 0):
+ Dimension of the projection for the context and question encoders. If it is set to zero (default), then no
+ projection is done.
+
+ Example:
+
+ ```python
+ >>> from transformers import DPRConfig, DPRContextEncoder
+
+ >>> # Initializing a DPR facebook/dpr-ctx_encoder-single-nq-base style configuration
+ >>> configuration = DPRConfig()
+
+ >>> # Initializing a model (with random weights) from the facebook/dpr-ctx_encoder-single-nq-base style configuration
+ >>> model = DPRContextEncoder(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "dpr"
+
+ def __init__(
+ self,
+ vocab_size=30522,
+ hidden_size=768,
+ num_hidden_layers=12,
+ num_attention_heads=12,
+ intermediate_size=3072,
+ hidden_act="gelu",
+ hidden_dropout_prob=0.1,
+ attention_probs_dropout_prob=0.1,
+ max_position_embeddings=512,
+ type_vocab_size=2,
+ initializer_range=0.02,
+ layer_norm_eps=1e-12,
+ pad_token_id=0,
+ position_embedding_type="absolute",
+ projection_dim: int = 0,
+ **kwargs,
+ ):
+ super().__init__(pad_token_id=pad_token_id, **kwargs)
+
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.hidden_act = hidden_act
+ self.intermediate_size = intermediate_size
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.max_position_embeddings = max_position_embeddings
+ self.type_vocab_size = type_vocab_size
+ self.initializer_range = initializer_range
+ self.layer_norm_eps = layer_norm_eps
+ self.projection_dim = projection_dim
+ self.position_embedding_type = position_embedding_type
+
+
+__all__ = ["DPRConfig"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/dpr/modeling_dpr.py b/venv/lib/python3.13/site-packages/transformers/models/dpr/modeling_dpr.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1ae00a02e07a0f3ee4c0ca064e7e9818568e605
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/dpr/modeling_dpr.py
@@ -0,0 +1,592 @@
+# coding=utf-8
+# Copyright 2018 DPR Authors, The Hugging Face Team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch DPR model for Open Domain Question Answering."""
+
+from dataclasses import dataclass
+from typing import Optional, Union
+
+import torch
+from torch import Tensor, nn
+
+from ...modeling_outputs import BaseModelOutputWithPooling
+from ...modeling_utils import PreTrainedModel
+from ...utils import (
+ ModelOutput,
+ auto_docstring,
+ logging,
+)
+from ..bert.modeling_bert import BertModel
+from .configuration_dpr import DPRConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+##########
+# Outputs
+##########
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+    Class for outputs of [`DPRContextEncoder`].
+ """
+)
+class DPRContextEncoderOutput(ModelOutput):
+ r"""
+ pooler_output (`torch.FloatTensor` of shape `(batch_size, embeddings_size)`):
+ The DPR encoder outputs the *pooler_output* that corresponds to the context representation. Last layer
+ hidden-state of the first token of the sequence (classification token) further processed by a Linear layer.
+ This output is to be used to embed contexts for nearest neighbors queries with questions embeddings.
+ """
+
+ pooler_output: torch.FloatTensor
+ hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Class for outputs of [`DPRQuestionEncoder`].
+ """
+)
+class DPRQuestionEncoderOutput(ModelOutput):
+ r"""
+ pooler_output (`torch.FloatTensor` of shape `(batch_size, embeddings_size)`):
+ The DPR encoder outputs the *pooler_output* that corresponds to the question representation. Last layer
+ hidden-state of the first token of the sequence (classification token) further processed by a Linear layer.
+ This output is to be used to embed questions for nearest neighbors queries with context embeddings.
+ """
+
+ pooler_output: torch.FloatTensor
+ hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+    Class for outputs of [`DPRReader`].
+ """
+)
+class DPRReaderOutput(ModelOutput):
+ r"""
+ start_logits (`torch.FloatTensor` of shape `(n_passages, sequence_length)`):
+ Logits of the start index of the span for each passage.
+ end_logits (`torch.FloatTensor` of shape `(n_passages, sequence_length)`):
+ Logits of the end index of the span for each passage.
+ relevance_logits (`torch.FloatTensor` of shape `(n_passages, )`):
+ Outputs of the QA classifier of the DPRReader that corresponds to the scores of each passage to answer the
+ question, compared to all the other passages.
+ """
+
+ start_logits: torch.FloatTensor
+ end_logits: Optional[torch.FloatTensor] = None
+ relevance_logits: Optional[torch.FloatTensor] = None
+ hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+
+
+@auto_docstring
+class DPRPreTrainedModel(PreTrainedModel):
+ _supports_sdpa = True
+
+ def _init_weights(self, module):
+ """Initialize the weights"""
+ if isinstance(module, nn.Linear):
+ # Slightly different from the TF version which uses truncated_normal for initialization
+ # cf https://github.com/pytorch/pytorch/pull/5617
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+
+
+class DPREncoder(DPRPreTrainedModel):
+ base_model_prefix = "bert_model"
+
+ def __init__(self, config: DPRConfig):
+ super().__init__(config)
+ self.bert_model = BertModel(config, add_pooling_layer=False)
+ if self.bert_model.config.hidden_size <= 0:
+ raise ValueError("Encoder hidden_size can't be zero")
+ self.projection_dim = config.projection_dim
+ if self.projection_dim > 0:
+ self.encode_proj = nn.Linear(self.bert_model.config.hidden_size, config.projection_dim)
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def forward(
+ self,
+ input_ids: Tensor,
+ attention_mask: Optional[Tensor] = None,
+ token_type_ids: Optional[Tensor] = None,
+ inputs_embeds: Optional[Tensor] = None,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ return_dict: bool = False,
+ ) -> Union[BaseModelOutputWithPooling, tuple[Tensor, ...]]:
+ outputs = self.bert_model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ sequence_output = outputs[0]
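+        # the [CLS] hidden state is used as the sequence-level embedding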
+ pooled_output = sequence_output[:, 0, :]
+
+ if self.projection_dim > 0:
+ pooled_output = self.encode_proj(pooled_output)
+
+ if not return_dict:
+ return (sequence_output, pooled_output) + outputs[2:]
+
+ return BaseModelOutputWithPooling(
+ last_hidden_state=sequence_output,
+ pooler_output=pooled_output,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ @property
+ def embeddings_size(self) -> int:
+ if self.projection_dim > 0:
+ return self.encode_proj.out_features
+ return self.bert_model.config.hidden_size
+
+
+class DPRSpanPredictor(DPRPreTrainedModel):
+ base_model_prefix = "encoder"
+
+ def __init__(self, config: DPRConfig):
+ super().__init__(config)
+ self.encoder = DPREncoder(config)
+ self.qa_outputs = nn.Linear(self.encoder.embeddings_size, 2)
+ self.qa_classifier = nn.Linear(self.encoder.embeddings_size, 1)
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def forward(
+ self,
+ input_ids: Tensor,
+ attention_mask: Tensor,
+ inputs_embeds: Optional[Tensor] = None,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ return_dict: bool = False,
+ ) -> Union[DPRReaderOutput, tuple[Tensor, ...]]:
+        # notation: N - number of questions in a batch, M - number of passages per question, L - sequence length
+ n_passages, sequence_length = input_ids.size() if input_ids is not None else inputs_embeds.size()[:2]
+ # feed encoder
+ outputs = self.encoder(
+ input_ids,
+ attention_mask=attention_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ sequence_output = outputs[0]
+
+ # compute logits
+ logits = self.qa_outputs(sequence_output)
+ start_logits, end_logits = logits.split(1, dim=-1)
+ start_logits = start_logits.squeeze(-1).contiguous()
+ end_logits = end_logits.squeeze(-1).contiguous()
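+        # passage relevance is scored from the [CLS] token of each passage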
+ relevance_logits = self.qa_classifier(sequence_output[:, 0, :])
+
+ # resize
+ start_logits = start_logits.view(n_passages, sequence_length)
+ end_logits = end_logits.view(n_passages, sequence_length)
+ relevance_logits = relevance_logits.view(n_passages)
+
+ if not return_dict:
+ return (start_logits, end_logits, relevance_logits) + outputs[2:]
+
+ return DPRReaderOutput(
+ start_logits=start_logits,
+ end_logits=end_logits,
+ relevance_logits=relevance_logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+##################
+# PreTrainedModel
+##################
+
+
+class DPRPretrainedContextEncoder(DPRPreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config: DPRConfig
+ load_tf_weights = None
+ base_model_prefix = "ctx_encoder"
+
+
+class DPRPretrainedQuestionEncoder(DPRPreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config: DPRConfig
+ load_tf_weights = None
+ base_model_prefix = "question_encoder"
+
+
+class DPRPretrainedReader(DPRPreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config: DPRConfig
+ load_tf_weights = None
+ base_model_prefix = "span_predictor"
+
+
+###############
+# Actual Models
+###############
+
+
+@auto_docstring(
+ custom_intro="""
+ The bare DPRContextEncoder transformer outputting pooler outputs as context representations.
+ """
+)
+class DPRContextEncoder(DPRPretrainedContextEncoder):
+ def __init__(self, config: DPRConfig):
+ super().__init__(config)
+ self.config = config
+ self.ctx_encoder = DPREncoder(config)
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[Tensor] = None,
+ attention_mask: Optional[Tensor] = None,
+ token_type_ids: Optional[Tensor] = None,
+ inputs_embeds: Optional[Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[DPRContextEncoderOutput, tuple[Tensor, ...]]:
+ r"""
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. To match pretraining, DPR input sequence should be
+ formatted with [CLS] and [SEP] tokens as follows:
+
+ (a) For sequence pairs (for a pair title+text for example):
+
+ ```
+ tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
+ token_type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1
+ ```
+
+ (b) For single sequences (for a question for example):
+
+ ```
+ tokens: [CLS] the dog is hairy . [SEP]
+ token_type_ids: 0 0 0 0 0 0 0
+ ```
+
+ DPR is a model with absolute position embeddings so it's usually advised to pad the inputs on the right
+ rather than the left.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+
+ Examples:
+
+ ```python
+ >>> from transformers import DPRContextEncoder, DPRContextEncoderTokenizer
+
+ >>> tokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
+ >>> model = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
+ >>> input_ids = tokenizer("Hello, is my dog cute ?", return_tensors="pt")["input_ids"]
+ >>> embeddings = model(input_ids).pooler_output
+ ```"""
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ input_shape = input_ids.size()
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+
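+        # default mask: attend to all positions when only embeddings are given, otherwise mask out padding tokens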
+ if attention_mask is None:
+ attention_mask = (
+ torch.ones(input_shape, device=device)
+ if input_ids is None
+ else (input_ids != self.config.pad_token_id)
+ )
+ if token_type_ids is None:
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
+
+ outputs = self.ctx_encoder(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ if not return_dict:
+ return outputs[1:]
+ return DPRContextEncoderOutput(
+ pooler_output=outputs.pooler_output, hidden_states=outputs.hidden_states, attentions=outputs.attentions
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ The bare DPRQuestionEncoder transformer outputting pooler outputs as question representations.
+ """
+)
+class DPRQuestionEncoder(DPRPretrainedQuestionEncoder):
+ def __init__(self, config: DPRConfig):
+ super().__init__(config)
+ self.config = config
+ self.question_encoder = DPREncoder(config)
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[Tensor] = None,
+ attention_mask: Optional[Tensor] = None,
+ token_type_ids: Optional[Tensor] = None,
+ inputs_embeds: Optional[Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[DPRQuestionEncoderOutput, tuple[Tensor, ...]]:
+ r"""
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. To match pretraining, DPR input sequence should be
+ formatted with [CLS] and [SEP] tokens as follows:
+
+ (a) For sequence pairs (for a pair title+text for example):
+
+ ```
+ tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
+ token_type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1
+ ```
+
+ (b) For single sequences (for a question for example):
+
+ ```
+ tokens: [CLS] the dog is hairy . [SEP]
+ token_type_ids: 0 0 0 0 0 0 0
+ ```
+
+ DPR is a model with absolute position embeddings so it's usually advised to pad the inputs on the right
+ rather than the left.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+
+ Examples:
+
+ ```python
+ >>> from transformers import DPRQuestionEncoder, DPRQuestionEncoderTokenizer
+
+ >>> tokenizer = DPRQuestionEncoderTokenizer.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
+ >>> model = DPRQuestionEncoder.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
+ >>> input_ids = tokenizer("Hello, is my dog cute ?", return_tensors="pt")["input_ids"]
+ >>> embeddings = model(input_ids).pooler_output
+ ```
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
+ input_shape = input_ids.size()
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+
+ if attention_mask is None:
+ attention_mask = (
+ torch.ones(input_shape, device=device)
+ if input_ids is None
+ else (input_ids != self.config.pad_token_id)
+ )
+ if token_type_ids is None:
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
+
+ outputs = self.question_encoder(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ if not return_dict:
+ return outputs[1:]
+ return DPRQuestionEncoderOutput(
+ pooler_output=outputs.pooler_output, hidden_states=outputs.hidden_states, attentions=outputs.attentions
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ The bare DPRReader transformer outputting span predictions.
+ """
+)
+class DPRReader(DPRPretrainedReader):
+ def __init__(self, config: DPRConfig):
+ super().__init__(config)
+ self.config = config
+ self.span_predictor = DPRSpanPredictor(config)
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[Tensor] = None,
+ attention_mask: Optional[Tensor] = None,
+ inputs_embeds: Optional[Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[DPRReaderOutput, tuple[Tensor, ...]]:
+ r"""
+ input_ids (`tuple[torch.LongTensor]` of shapes `(n_passages, sequence_length)`):
+            Indices of input sequence tokens in the vocabulary. It has to be a sequence triplet with 1) the question,
+            2) the passages titles and 3) the passages texts. To match pretraining, DPR `input_ids` sequence should
+ be formatted with [CLS] and [SEP] with the format:
+
+            `[CLS] <question token ids> [SEP] <titles ids> [SEP] <texts ids>`
+
+ DPR is a model with absolute position embeddings so it's usually advised to pad the inputs on the right
+ rather than the left.
+
+ Indices can be obtained using [`DPRReaderTokenizer`]. See this class documentation for more details.
+
+ [What are input IDs?](../glossary#input-ids)
+ inputs_embeds (`torch.FloatTensor` of shape `(n_passages, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+ model's internal embedding lookup matrix.
+
+ Examples:
+
+ ```python
+ >>> from transformers import DPRReader, DPRReaderTokenizer
+
+ >>> tokenizer = DPRReaderTokenizer.from_pretrained("facebook/dpr-reader-single-nq-base")
+ >>> model = DPRReader.from_pretrained("facebook/dpr-reader-single-nq-base")
+ >>> encoded_inputs = tokenizer(
+ ... questions=["What is love ?"],
+ ... titles=["Haddaway"],
+ ... texts=["'What Is Love' is a song recorded by the artist Haddaway"],
+ ... return_tensors="pt",
+ ... )
+ >>> outputs = model(**encoded_inputs)
+ >>> start_logits = outputs.start_logits
+ >>> end_logits = outputs.end_logits
+ >>> relevance_logits = outputs.relevance_logits
+ ```
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
+ input_shape = input_ids.size()
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+
+ if attention_mask is None:
+ attention_mask = torch.ones(input_shape, device=device)
+
+ return self.span_predictor(
+ input_ids,
+ attention_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+
+__all__ = [
+ "DPRContextEncoder",
+ "DPRPretrainedContextEncoder",
+ "DPRPreTrainedModel",
+ "DPRPretrainedQuestionEncoder",
+ "DPRPretrainedReader",
+ "DPRQuestionEncoder",
+ "DPRReader",
+]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/dpr/modeling_tf_dpr.py b/venv/lib/python3.13/site-packages/transformers/models/dpr/modeling_tf_dpr.py
new file mode 100644
index 0000000000000000000000000000000000000000..aef83e6c55fbe27ea57e48bf2baca515999010cb
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/dpr/modeling_tf_dpr.py
@@ -0,0 +1,799 @@
+# coding=utf-8
+# Copyright 2018 DPR Authors, The Hugging Face Team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""TensorFlow DPR model for Open Domain Question Answering."""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+
+import tensorflow as tf
+
+from ...modeling_tf_outputs import TFBaseModelOutputWithPooling
+from ...modeling_tf_utils import TFModelInputType, TFPreTrainedModel, get_initializer, keras, shape_list, unpack_inputs
+from ...utils import (
+ ModelOutput,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+ replace_return_docstrings,
+)
+from ..bert.modeling_tf_bert import TFBertMainLayer
+from .configuration_dpr import DPRConfig
+
+
+logger = logging.get_logger(__name__)
+
+_CONFIG_FOR_DOC = "DPRConfig"
+
+
+##########
+# Outputs
+##########
+
+
+@dataclass
+class TFDPRContextEncoderOutput(ModelOutput):
+ r"""
+ Class for outputs of [`TFDPRContextEncoder`].
+
+ Args:
+ pooler_output (`tf.Tensor` of shape `(batch_size, embeddings_size)`):
+ The DPR encoder outputs the *pooler_output* that corresponds to the context representation. Last layer
+ hidden-state of the first token of the sequence (classification token) further processed by a Linear layer.
+ This output is to be used to embed contexts for nearest neighbors queries with questions embeddings.
+ hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
+ `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ pooler_output: tf.Tensor | None = None
+ hidden_states: tuple[tf.Tensor, ...] | None = None
+ attentions: tuple[tf.Tensor, ...] | None = None
+
+
+@dataclass
+class TFDPRQuestionEncoderOutput(ModelOutput):
+ """
+ Class for outputs of [`TFDPRQuestionEncoder`].
+
+ Args:
+ pooler_output (`tf.Tensor` of shape `(batch_size, embeddings_size)`):
+ The DPR encoder outputs the *pooler_output* that corresponds to the question representation. Last layer
+ hidden-state of the first token of the sequence (classification token) further processed by a Linear layer.
+ This output is to be used to embed questions for nearest neighbors queries with context embeddings.
+ hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
+ `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ pooler_output: tf.Tensor | None = None
+ hidden_states: tuple[tf.Tensor, ...] | None = None
+ attentions: tuple[tf.Tensor, ...] | None = None
+
+
+@dataclass
+class TFDPRReaderOutput(ModelOutput):
+ """
+    Class for outputs of [`TFDPRReader`].
+
+ Args:
+ start_logits (`tf.Tensor` of shape `(n_passages, sequence_length)`):
+ Logits of the start index of the span for each passage.
+ end_logits (`tf.Tensor` of shape `(n_passages, sequence_length)`):
+ Logits of the end index of the span for each passage.
+ relevance_logits (`tf.Tensor` of shape `(n_passages, )`):
+ Outputs of the QA classifier of the DPRReader that corresponds to the scores of each passage to answer the
+ question, compared to all the other passages.
+ hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
+ `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ start_logits: tf.Tensor | None = None
+ end_logits: tf.Tensor | None = None
+ relevance_logits: tf.Tensor | None = None
+ hidden_states: tuple[tf.Tensor, ...] | None = None
+ attentions: tuple[tf.Tensor, ...] | None = None
+
+
+class TFDPREncoderLayer(keras.layers.Layer):
+ base_model_prefix = "bert_model"
+
+ def __init__(self, config: DPRConfig, **kwargs):
+ super().__init__(**kwargs)
+
+        # use TFBertMainLayer instead of TFBertModel to avoid a name conflict
+ self.bert_model = TFBertMainLayer(config, add_pooling_layer=False, name="bert_model")
+ self.config = config
+
+ if self.config.hidden_size <= 0:
+ raise ValueError("Encoder hidden_size can't be zero")
+ self.projection_dim = config.projection_dim
+ if self.projection_dim > 0:
+ self.encode_proj = keras.layers.Dense(
+ config.projection_dim, kernel_initializer=get_initializer(config.initializer_range), name="encode_proj"
+ )
+
+ @unpack_inputs
+ def call(
+ self,
+ input_ids: tf.Tensor | None = None,
+ attention_mask: tf.Tensor | None = None,
+ token_type_ids: tf.Tensor | None = None,
+ inputs_embeds: tf.Tensor | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ training: bool = False,
+ ) -> TFBaseModelOutputWithPooling | tuple[tf.Tensor, ...]:
+ outputs = self.bert_model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+
+ sequence_output = outputs[0]
+ pooled_output = sequence_output[:, 0, :]
+ if self.projection_dim > 0:
+ pooled_output = self.encode_proj(pooled_output)
+
+ if not return_dict:
+ return (sequence_output, pooled_output) + outputs[1:]
+
+ return TFBaseModelOutputWithPooling(
+ last_hidden_state=sequence_output,
+ pooler_output=pooled_output,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ @property
+ def embeddings_size(self) -> int:
+ if self.projection_dim > 0:
+ return self.projection_dim
+ return self.bert_model.config.hidden_size
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "bert_model", None) is not None:
+ with tf.name_scope(self.bert_model.name):
+ self.bert_model.build(None)
+ if getattr(self, "encode_proj", None) is not None:
+ with tf.name_scope(self.encode_proj.name):
+ self.encode_proj.build(None)
+
+
+class TFDPRSpanPredictorLayer(keras.layers.Layer):
+ base_model_prefix = "encoder"
+
+ def __init__(self, config: DPRConfig, **kwargs):
+ super().__init__(**kwargs)
+ self.config = config
+ self.encoder = TFDPREncoderLayer(config, name="encoder")
+
+ self.qa_outputs = keras.layers.Dense(
+ 2, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
+ )
+ self.qa_classifier = keras.layers.Dense(
+ 1, kernel_initializer=get_initializer(config.initializer_range), name="qa_classifier"
+ )
+
+ @unpack_inputs
+ def call(
+ self,
+ input_ids: tf.Tensor | None = None,
+ attention_mask: tf.Tensor | None = None,
+ inputs_embeds: tf.Tensor | None = None,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ return_dict: bool = False,
+ training: bool = False,
+ ) -> TFDPRReaderOutput | tuple[tf.Tensor, ...]:
+ # notation: N - number of questions in a batch, M - number of passages per question, L - sequence length
+ n_passages, sequence_length = shape_list(input_ids) if input_ids is not None else shape_list(inputs_embeds)[:2]
+ # feed encoder
+ outputs = self.encoder(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+ sequence_output = outputs[0]
+
+ # compute logits
+ logits = self.qa_outputs(sequence_output)
+ start_logits, end_logits = tf.split(logits, 2, axis=-1)
+ start_logits = tf.squeeze(start_logits, axis=-1)
+ end_logits = tf.squeeze(end_logits, axis=-1)
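+ # passage relevance is predicted from the [CLS] token of each passage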
+ relevance_logits = self.qa_classifier(sequence_output[:, 0, :])
+
+ # resize
+ start_logits = tf.reshape(start_logits, [n_passages, sequence_length])
+ end_logits = tf.reshape(end_logits, [n_passages, sequence_length])
+ relevance_logits = tf.reshape(relevance_logits, [n_passages])
+
+ if not return_dict:
+ return (start_logits, end_logits, relevance_logits) + outputs[2:]
+
+ return TFDPRReaderOutput(
+ start_logits=start_logits,
+ end_logits=end_logits,
+ relevance_logits=relevance_logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "encoder", None) is not None:
+ with tf.name_scope(self.encoder.name):
+ self.encoder.build(None)
+ if getattr(self, "qa_outputs", None) is not None:
+ with tf.name_scope(self.qa_outputs.name):
+ self.qa_outputs.build([None, None, self.encoder.embeddings_size])
+ if getattr(self, "qa_classifier", None) is not None:
+ with tf.name_scope(self.qa_classifier.name):
+ self.qa_classifier.build([None, None, self.encoder.embeddings_size])
+
+
+class TFDPRSpanPredictor(TFPreTrainedModel):
+ base_model_prefix = "encoder"
+
+ def __init__(self, config: DPRConfig, **kwargs):
+ super().__init__(config, **kwargs)
+ self.encoder = TFDPRSpanPredictorLayer(config)
+
+ @unpack_inputs
+ def call(
+ self,
+ input_ids: tf.Tensor | None = None,
+ attention_mask: tf.Tensor | None = None,
+ token_type_ids: tf.Tensor | None = None,
+ inputs_embeds: tf.Tensor | None = None,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ return_dict: bool = False,
+ training: bool = False,
+ ) -> TFDPRReaderOutput | tuple[tf.Tensor, ...]:
+ outputs = self.encoder(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+
+ return outputs
+
+
+class TFDPREncoder(TFPreTrainedModel):
+ base_model_prefix = "encoder"
+
+ def __init__(self, config: DPRConfig, **kwargs):
+ super().__init__(config, **kwargs)
+
+ self.encoder = TFDPREncoderLayer(config)
+
+ @unpack_inputs
+ def call(
+ self,
+ input_ids: tf.Tensor | None = None,
+ attention_mask: tf.Tensor | None = None,
+ token_type_ids: tf.Tensor | None = None,
+ inputs_embeds: tf.Tensor | None = None,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ return_dict: bool = False,
+ training: bool = False,
+ ) -> TFDPRReaderOutput | tuple[tf.Tensor, ...]:
+ outputs = self.encoder(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+ return outputs
+
+
+##################
+# PreTrainedModel
+##################
+
+
+class TFDPRPretrainedContextEncoder(TFPreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = DPRConfig
+ base_model_prefix = "ctx_encoder"
+
+
+class TFDPRPretrainedQuestionEncoder(TFPreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = DPRConfig
+ base_model_prefix = "question_encoder"
+
+
+class TFDPRPretrainedReader(TFPreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = DPRConfig
+ base_model_prefix = "reader"
+
+
+###############
+# Actual Models
+###############
+
+
+TF_DPR_START_DOCSTRING = r"""
+
+ This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads,
+ etc.)
+
+ This model is also a TensorFlow [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model)
+ subclass. Use it as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matters related to
+ general usage and behavior.
+
+
+
+ TensorFlow models and layers in `transformers` accept two formats as input:
+
+ - having all inputs as keyword arguments (like PyTorch models), or
+ - having all inputs as a list, tuple or dict in the first positional argument.
+
+ The reason the second format is supported is that Keras methods prefer this format when passing inputs to models
+ and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just
+ pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second
+ format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with
+ the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first
+ positional argument:
+
+ - a single Tensor with `input_ids` only and nothing else: `model(input_ids)`
+ - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
+ `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])`
+ - a dictionary with one or several input Tensors associated to the input names given in the docstring:
+ `model({"input_ids": input_ids, "token_type_ids": token_type_ids})`
+
+ Note that when creating models and layers with
+ [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry
+ about any of this, as you can just pass inputs like you would to any other Python function!
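+
+ As an illustrative sketch (here `model` stands for any TF DPR model and `input_ids`/`attention_mask` for tensors
+ produced by the matching tokenizer), the three call styles look like this:
+
+ ```python
+ >>> outputs = model(input_ids)  # a single tensor
+ >>> outputs = model([input_ids, attention_mask])  # a list, in the order given in the docstring
+ >>> outputs = model({"input_ids": input_ids, "attention_mask": attention_mask})  # a dict keyed by input names
+ ```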
+
+
+
+ Parameters:
+ config ([`DPRConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+TF_DPR_ENCODERS_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`Numpy array` or `tf.Tensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. To match pretraining, DPR input sequence should be
+ formatted with [CLS] and [SEP] tokens as follows:
+
+ (a) For sequence pairs (for a pair title+text for example):
+
+ ```
+ tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
+ token_type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1
+ ```
+
+ (b) For single sequences (for a question for example):
+
+ ```
+ tokens: [CLS] the dog is hairy . [SEP]
+ token_type_ids: 0 0 0 0 0 0 0
+ ```
+
+ DPR is a model with absolute position embeddings so it's usually advised to pad the inputs on the right
+ rather than the left.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`Numpy array` or `tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ token_type_ids (`Numpy array` or `tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
+ 1]`:
+
+ - 0 corresponds to a *sentence A* token,
+ - 1 corresponds to a *sentence B* token.
+
+ [What are token type IDs?](../glossary#token-type-ids)
+ inputs_embeds (`Numpy array` or `tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+ model's internal embedding lookup matrix.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the
+ config will be used instead.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail. This argument can be used only in eager mode, in graph mode the value in the config will be
+ used instead.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in
+ eager mode, in graph mode the value will always be set to True.
+ training (`bool`, *optional*, defaults to `False`):
+ Whether or not to use the model in training mode (some modules like dropout modules have different
+ behaviors between training and evaluation).
+"""
+
+TF_DPR_READER_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`Numpy array` or `tf.Tensor` of shape `(n_passages, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. It has to be a sequence triplet with 1) the question,
+ 2) the passages titles and 3) the passages texts. To match pretraining, DPR `input_ids` sequences should
+ be formatted with [CLS] and [SEP] with the format:
+
+ `[CLS] <question token ids> [SEP] <titles ids> [SEP] <texts ids>`
+
+ DPR is a model with absolute position embeddings so it's usually advised to pad the inputs on the right
+ rather than the left.
+
+ Indices can be obtained using [`DPRReaderTokenizer`]. See this class documentation for more details.
+ attention_mask (`Numpy array` or `tf.Tensor` of shape `(n_passages, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ inputs_embeds (`Numpy array` or `tf.Tensor` of shape `(n_passages, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+ model's internal embedding lookup matrix.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail. This argument can be used only in eager mode, in graph mode the value in the config will be
+ used instead.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in
+ eager mode, in graph mode the value will always be set to True.
+ training (`bool`, *optional*, defaults to `False`):
+ Whether or not to use the model in training mode (some modules like dropout modules have different
+ behaviors between training and evaluation).
+"""
+
+
+@add_start_docstrings(
+ "The bare DPRContextEncoder transformer outputting pooler outputs as context representations.",
+ TF_DPR_START_DOCSTRING,
+)
+class TFDPRContextEncoder(TFDPRPretrainedContextEncoder):
+ def __init__(self, config: DPRConfig, *args, **kwargs):
+ super().__init__(config, *args, **kwargs)
+ self.ctx_encoder = TFDPREncoderLayer(config, name="ctx_encoder")
+
+ def get_input_embeddings(self):
+ try:
+ return self.ctx_encoder.bert_model.get_input_embeddings()
+ except AttributeError:
+ self.build()
+ return self.ctx_encoder.bert_model.get_input_embeddings()
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(TF_DPR_ENCODERS_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=TFDPRContextEncoderOutput, config_class=_CONFIG_FOR_DOC)
+ def call(
+ self,
+ input_ids: TFModelInputType | None = None,
+ attention_mask: tf.Tensor | None = None,
+ token_type_ids: tf.Tensor | None = None,
+ inputs_embeds: tf.Tensor | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ training: bool = False,
+ ) -> TFDPRContextEncoderOutput | tuple[tf.Tensor, ...]:
+ r"""
+ Return:
+
+ Examples:
+
+ ```python
+ >>> from transformers import TFDPRContextEncoder, DPRContextEncoderTokenizer
+
+ >>> tokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
+ >>> model = TFDPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base", from_pt=True)
+ >>> input_ids = tokenizer("Hello, is my dog cute ?", return_tensors="tf")["input_ids"]
+ >>> embeddings = model(input_ids).pooler_output
+ ```
+ """
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ input_shape = shape_list(input_ids)
+ elif inputs_embeds is not None:
+ input_shape = shape_list(inputs_embeds)[:-1]
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ if attention_mask is None:
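+ # default mask: attend everywhere for embedded inputs, otherwise ignore padding tokens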
+ attention_mask = (
+ tf.ones(input_shape, dtype=tf.dtypes.int32)
+ if input_ids is None
+ else (input_ids != self.config.pad_token_id)
+ )
+ if token_type_ids is None:
+ token_type_ids = tf.zeros(input_shape, dtype=tf.dtypes.int32)
+
+ outputs = self.ctx_encoder(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+
+ if not return_dict:
+ return outputs[1:]
+
+ return TFDPRContextEncoderOutput(
+ pooler_output=outputs.pooler_output, hidden_states=outputs.hidden_states, attentions=outputs.attentions
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "ctx_encoder", None) is not None:
+ with tf.name_scope(self.ctx_encoder.name):
+ self.ctx_encoder.build(None)
+
+
+@add_start_docstrings(
+ "The bare DPRQuestionEncoder transformer outputting pooler outputs as question representations.",
+ TF_DPR_START_DOCSTRING,
+)
+class TFDPRQuestionEncoder(TFDPRPretrainedQuestionEncoder):
+ def __init__(self, config: DPRConfig, *args, **kwargs):
+ super().__init__(config, *args, **kwargs)
+ self.question_encoder = TFDPREncoderLayer(config, name="question_encoder")
+
+ def get_input_embeddings(self):
+ try:
+ return self.question_encoder.bert_model.get_input_embeddings()
+ except AttributeError:
+ self.build()
+ return self.question_encoder.bert_model.get_input_embeddings()
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(TF_DPR_ENCODERS_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=TFDPRQuestionEncoderOutput, config_class=_CONFIG_FOR_DOC)
+ def call(
+ self,
+ input_ids: TFModelInputType | None = None,
+ attention_mask: tf.Tensor | None = None,
+ token_type_ids: tf.Tensor | None = None,
+ inputs_embeds: tf.Tensor | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ training: bool = False,
+ ) -> TFDPRQuestionEncoderOutput | tuple[tf.Tensor, ...]:
+ r"""
+ Return:
+
+ Examples:
+
+ ```python
+ >>> from transformers import TFDPRQuestionEncoder, DPRQuestionEncoderTokenizer
+
+ >>> tokenizer = DPRQuestionEncoderTokenizer.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
+ >>> model = TFDPRQuestionEncoder.from_pretrained("facebook/dpr-question_encoder-single-nq-base", from_pt=True)
+ >>> input_ids = tokenizer("Hello, is my dog cute ?", return_tensors="tf")["input_ids"]
+ >>> embeddings = model(input_ids).pooler_output
+ ```
+ """
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ input_shape = shape_list(input_ids)
+ elif inputs_embeds is not None:
+ input_shape = shape_list(inputs_embeds)[:-1]
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ if attention_mask is None:
+ attention_mask = (
+ tf.ones(input_shape, dtype=tf.dtypes.int32)
+ if input_ids is None
+ else (input_ids != self.config.pad_token_id)
+ )
+ if token_type_ids is None:
+ token_type_ids = tf.zeros(input_shape, dtype=tf.dtypes.int32)
+
+ outputs = self.question_encoder(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+
+ if not return_dict:
+ return outputs[1:]
+ return TFDPRQuestionEncoderOutput(
+ pooler_output=outputs.pooler_output, hidden_states=outputs.hidden_states, attentions=outputs.attentions
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "question_encoder", None) is not None:
+ with tf.name_scope(self.question_encoder.name):
+ self.question_encoder.build(None)
+
+
+@add_start_docstrings(
+ "The bare DPRReader transformer outputting span predictions.",
+ TF_DPR_START_DOCSTRING,
+)
+class TFDPRReader(TFDPRPretrainedReader):
+ def __init__(self, config: DPRConfig, *args, **kwargs):
+ super().__init__(config, *args, **kwargs)
+ self.span_predictor = TFDPRSpanPredictorLayer(config, name="span_predictor")
+
+ def get_input_embeddings(self):
+ try:
+ return self.span_predictor.encoder.bert_model.get_input_embeddings()
+ except AttributeError:
+ self.build()
+ return self.span_predictor.encoder.bert_model.get_input_embeddings()
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(TF_DPR_READER_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=TFDPRReaderOutput, config_class=_CONFIG_FOR_DOC)
+ def call(
+ self,
+ input_ids: TFModelInputType | None = None,
+ attention_mask: tf.Tensor | None = None,
+ inputs_embeds: tf.Tensor | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ training: bool = False,
+ ) -> TFDPRReaderOutput | tuple[tf.Tensor, ...]:
+ r"""
+ Return:
+
+ Examples:
+
+ ```python
+ >>> from transformers import TFDPRReader, DPRReaderTokenizer
+
+ >>> tokenizer = DPRReaderTokenizer.from_pretrained("facebook/dpr-reader-single-nq-base")
+ >>> model = TFDPRReader.from_pretrained("facebook/dpr-reader-single-nq-base", from_pt=True)
+ >>> encoded_inputs = tokenizer(
+ ... questions=["What is love ?"],
+ ... titles=["Haddaway"],
+ ... texts=["'What Is Love' is a song recorded by the artist Haddaway"],
+ ... return_tensors="tf",
+ ... )
+ >>> outputs = model(encoded_inputs)
+ >>> start_logits = outputs.start_logits
+ >>> end_logits = outputs.end_logits
+ >>> relevance_logits = outputs.relevance_logits
+ ```
+ """
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ input_shape = shape_list(input_ids)
+ elif inputs_embeds is not None:
+ input_shape = shape_list(inputs_embeds)[:-1]
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ if attention_mask is None:
+ attention_mask = tf.ones(input_shape, dtype=tf.dtypes.int32)
+
+ return self.span_predictor(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "span_predictor", None) is not None:
+ with tf.name_scope(self.span_predictor.name):
+ self.span_predictor.build(None)
+
+
+__all__ = [
+ "TFDPRContextEncoder",
+ "TFDPRPretrainedContextEncoder",
+ "TFDPRPretrainedQuestionEncoder",
+ "TFDPRPretrainedReader",
+ "TFDPRQuestionEncoder",
+ "TFDPRReader",
+]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/dpr/tokenization_dpr.py b/venv/lib/python3.13/site-packages/transformers/models/dpr/tokenization_dpr.py
new file mode 100644
index 0000000000000000000000000000000000000000..020b235cb6bd97bda74f2e067294a9391617e00f
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/dpr/tokenization_dpr.py
@@ -0,0 +1,321 @@
+# coding=utf-8
+# Copyright 2018 The HuggingFace Inc. team, The Hugging Face Team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenization classes for DPR."""
+
+import collections
+from typing import Optional, Union
+
+from ...tokenization_utils_base import BatchEncoding
+from ...utils import TensorType, add_end_docstrings, add_start_docstrings, logging
+from ..bert.tokenization_bert import BertTokenizer
+
+
+logger = logging.get_logger(__name__)
+
+VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "tokenizer_file": "tokenizer.json"}
+
+
+class DPRContextEncoderTokenizer(BertTokenizer):
+ r"""
+ Construct a DPRContextEncoder tokenizer.
+
+ [`DPRContextEncoderTokenizer`] is identical to [`BertTokenizer`] and runs end-to-end tokenization: punctuation
+ splitting and wordpiece.
+
+ Refer to superclass [`BertTokenizer`] for usage examples and documentation concerning parameters.
+ """
+
+ vocab_files_names = VOCAB_FILES_NAMES
+
+
+class DPRQuestionEncoderTokenizer(BertTokenizer):
+ r"""
+ Constructs a DPRQuestionEncoder tokenizer.
+
+ [`DPRQuestionEncoderTokenizer`] is identical to [`BertTokenizer`] and runs end-to-end tokenization: punctuation
+ splitting and wordpiece.
+
+ Refer to superclass [`BertTokenizer`] for usage examples and documentation concerning parameters.
+ """
+
+ vocab_files_names = VOCAB_FILES_NAMES
+
+
+DPRSpanPrediction = collections.namedtuple(
+ "DPRSpanPrediction", ["span_score", "relevance_score", "doc_id", "start_index", "end_index", "text"]
+)
+
+DPRReaderOutput = collections.namedtuple("DPRReaderOutput", ["start_logits", "end_logits", "relevance_logits"])
+
+
+CUSTOM_DPR_READER_DOCSTRING = r"""
+ Return a dictionary with the token ids of the input strings and other information to give to `.decode_best_spans`.
+ It converts the strings of a question and different passages (title and text) into sequences of IDs (integers),
+ using the tokenizer and vocabulary. The resulting `input_ids` is a matrix of size `(n_passages, sequence_length)`
+ with the format:
+
+ ```
+ [CLS] <question token ids> [SEP] <titles ids> [SEP] <texts ids>
+ ```
+
+ Args:
+ questions (`str` or `list[str]`):
+ The questions to be encoded. You can specify one question for many passages. In this case, the question
+ will be duplicated like `[questions] * n_passages`. Otherwise you have to specify as many questions as in
+ `titles` or `texts`.
+ titles (`str` or `list[str]`):
+ The passages titles to be encoded. This can be a string or a list of strings if there are several passages.
+ texts (`str` or `list[str]`):
+ The passages texts to be encoded. This can be a string or a list of strings if there are several passages.
+ padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
+ Activates and controls padding. Accepts the following values:
+
+ - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single sequence
+ is provided).
+ - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
+ acceptable input length for the model if that argument is not provided.
+ - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
+ lengths).
+ truncation (`bool`, `str` or [`~tokenization_utils_base.TruncationStrategy`], *optional*, defaults to `False`):
+ Activates and controls truncation. Accepts the following values:
+
+ - `True` or `'longest_first'`: Truncate to a maximum length specified with the argument `max_length` or to
+ the maximum acceptable input length for the model if that argument is not provided. This will truncate
+ token by token, removing a token from the longest sequence in the pair if a pair of sequences (or a batch
+ of pairs) is provided.
+ - `'only_first'`: Truncate to a maximum length specified with the argument `max_length` or to the maximum
+ acceptable input length for the model if that argument is not provided. This will only truncate the first
+ sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
+ - `'only_second'`: Truncate to a maximum length specified with the argument `max_length` or to the maximum
+ acceptable input length for the model if that argument is not provided. This will only truncate the
+ second sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
+ - `False` or `'do_not_truncate'` (default): No truncation (i.e., can output batch with sequence lengths
+ greater than the model maximum admissible input size).
+ max_length (`int`, *optional*):
+ Controls the maximum length to use by one of the truncation/padding parameters.
+
+ If left unset or set to `None`, this will use the predefined model maximum length if a maximum length
+ is required by one of the truncation/padding parameters. If the model has no specific maximum input
+ length (like XLNet) truncation/padding to a maximum length will be deactivated.
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
+ If set, will return tensors instead of list of python integers. Acceptable values are:
+
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
+ - `'np'`: Return Numpy `np.ndarray` objects.
+ return_attention_mask (`bool`, *optional*):
+ Whether or not to return the attention mask. If not set, will return the attention mask according to the
+ specific tokenizer's default, defined by the `return_outputs` attribute.
+
+ [What are attention masks?](../glossary#attention-mask)
+
+ Returns:
+ `dict[str, list[list[int]]]`: A dictionary with the following keys:
+
+ - `input_ids`: List of token ids to be fed to a model.
+ - `attention_mask`: List of indices specifying which tokens should be attended to by the model.
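+
+ Example (an illustrative sketch; the second title/text pair is only a placeholder passage):
+
+ ```python
+ >>> from transformers import DPRReaderTokenizer
+
+ >>> tokenizer = DPRReaderTokenizer.from_pretrained("facebook/dpr-reader-single-nq-base")
+ >>> # a single question is duplicated for every (title, text) passage pair
+ >>> encoded_inputs = tokenizer(
+ ...     questions="What is love ?",
+ ...     titles=["Haddaway", "Placeholder title"],
+ ...     texts=["'What Is Love' is a song recorded by the artist Haddaway", "Placeholder passage text"],
+ ...     padding=True,
+ ...     return_tensors="pt",
+ ... )
+ >>> input_ids = encoded_inputs["input_ids"]  # shape (n_passages, sequence_length); here n_passages == 2
+ ```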
+ """
+
+
+@add_start_docstrings(CUSTOM_DPR_READER_DOCSTRING)
+class CustomDPRReaderTokenizerMixin:
+ def __call__(
+ self,
+ questions,
+ titles: Optional[str] = None,
+ texts: Optional[str] = None,
+ padding: Union[bool, str] = False,
+ truncation: Union[bool, str] = False,
+ max_length: Optional[int] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ return_attention_mask: Optional[bool] = None,
+ **kwargs,
+ ) -> BatchEncoding:
+ if titles is None and texts is None:
+ return super().__call__(
+ questions,
+ padding=padding,
+ truncation=truncation,
+ max_length=max_length,
+ return_tensors=return_tensors,
+ return_attention_mask=return_attention_mask,
+ **kwargs,
+ )
+ elif titles is None or texts is None:
+ text_pair = titles if texts is None else texts
+ return super().__call__(
+ questions,
+ text_pair,
+ padding=padding,
+ truncation=truncation,
+ max_length=max_length,
+ return_tensors=return_tensors,
+ return_attention_mask=return_attention_mask,
+ **kwargs,
+ )
+ titles = titles if not isinstance(titles, str) else [titles]
+ texts = texts if not isinstance(texts, str) else [texts]
+ n_passages = len(titles)
+ questions = questions if not isinstance(questions, str) else [questions] * n_passages
+ if len(titles) != len(texts):
+ raise ValueError(
+ f"There should be as many titles than texts but got {len(titles)} titles and {len(texts)} texts."
+ )
+ encoded_question_and_titles = super().__call__(questions, titles, padding=False, truncation=False)["input_ids"]
+ encoded_texts = super().__call__(texts, add_special_tokens=False, padding=False, truncation=False)["input_ids"]
+ encoded_inputs = {
+ "input_ids": [
+ (encoded_question_and_title + encoded_text)[:max_length]
+ if max_length is not None and truncation
+ else encoded_question_and_title + encoded_text
+ for encoded_question_and_title, encoded_text in zip(encoded_question_and_titles, encoded_texts)
+ ]
+ }
+ if return_attention_mask is not False:
+ attention_mask = []
+ for input_ids in encoded_inputs["input_ids"]:
+ attention_mask.append([int(input_id != self.pad_token_id) for input_id in input_ids])
+ encoded_inputs["attention_mask"] = attention_mask
+ return self.pad(encoded_inputs, padding=padding, max_length=max_length, return_tensors=return_tensors)
+
+ def decode_best_spans(
+ self,
+ reader_input: BatchEncoding,
+ reader_output: DPRReaderOutput,
+ num_spans: int = 16,
+ max_answer_length: int = 64,
+ num_spans_per_passage: int = 4,
+ ) -> list[DPRSpanPrediction]:
+ """
+ Get the span predictions for the extractive Q&A model.
+
+ Returns: *List* of *DPRSpanPrediction* sorted by descending *(relevance_score, span_score)*. Each
+ *DPRSpanPrediction* is a *Tuple* with:
+
+ - **span_score**: `float` that corresponds to the score given by the reader for this span compared to other
+ spans in the same passage. It corresponds to the sum of the start and end logits of the span.
+ - **relevance_score**: `float` that corresponds to the score of each passage to answer the question,
+ compared to all the other passages. It corresponds to the output of the QA classifier of the DPRReader.
+ - **doc_id**: `int` the id of the passage.
+ - **start_index**: `int` the start index of the span (inclusive).
+ - **end_index**: `int` the end index of the span (inclusive).
+
+ Examples:
+
+ ```python
+ >>> from transformers import DPRReader, DPRReaderTokenizer
+
+ >>> tokenizer = DPRReaderTokenizer.from_pretrained("facebook/dpr-reader-single-nq-base")
+ >>> model = DPRReader.from_pretrained("facebook/dpr-reader-single-nq-base")
+ >>> encoded_inputs = tokenizer(
+ ... questions=["What is love ?"],
+ ... titles=["Haddaway"],
+ ... texts=["'What Is Love' is a song recorded by the artist Haddaway"],
+ ... return_tensors="pt",
+ ... )
+ >>> outputs = model(**encoded_inputs)
+ >>> predicted_spans = tokenizer.decode_best_spans(encoded_inputs, outputs)
+ >>> print(predicted_spans[0].text) # best span
+ a song
+ ```"""
+ input_ids = reader_input["input_ids"]
+ start_logits, end_logits, relevance_logits = reader_output[:3]
+ n_passages = len(relevance_logits)
+ sorted_docs = sorted(range(n_passages), reverse=True, key=relevance_logits.__getitem__)
+ nbest_spans_predictions: list[DPRSpanPrediction] = []
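+ # iterate over passages from the most to the least relevant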
+ for doc_id in sorted_docs:
+ sequence_ids = list(input_ids[doc_id])
+ # assuming question & title information is at the beginning of the sequence
+ passage_offset = sequence_ids.index(self.sep_token_id, 2) + 1 # second sep id
+ if sequence_ids[-1] == self.pad_token_id:
+ sequence_len = sequence_ids.index(self.pad_token_id)
+ else:
+ sequence_len = len(sequence_ids)
+
+ best_spans = self._get_best_spans(
+ start_logits=start_logits[doc_id][passage_offset:sequence_len],
+ end_logits=end_logits[doc_id][passage_offset:sequence_len],
+ max_answer_length=max_answer_length,
+ top_spans=num_spans_per_passage,
+ )
+ for start_index, end_index in best_spans:
+ start_index += passage_offset
+ end_index += passage_offset
+ nbest_spans_predictions.append(
+ DPRSpanPrediction(
+ span_score=start_logits[doc_id][start_index] + end_logits[doc_id][end_index],
+ relevance_score=relevance_logits[doc_id],
+ doc_id=doc_id,
+ start_index=start_index,
+ end_index=end_index,
+ text=self.decode(sequence_ids[start_index : end_index + 1]),
+ )
+ )
+ if len(nbest_spans_predictions) >= num_spans:
+ break
+ return nbest_spans_predictions[:num_spans]
+
+ def _get_best_spans(
+ self,
+ start_logits: list[int],
+ end_logits: list[int],
+ max_answer_length: int,
+ top_spans: int,
+ ) -> list[DPRSpanPrediction]:
+ """
+ Finds the best answer spans for the extractive Q&A model for one passage. It returns the best spans in descending
+ `span_score` order, keeping at most `top_spans` spans. Spans longer than `max_answer_length` are ignored.
+ """
+ scores = []
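+ # score every candidate span (start, end) with start_logit + end_logit, for spans up to max_answer_length tokens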
+ for start_index, start_score in enumerate(start_logits):
+ for answer_length, end_score in enumerate(end_logits[start_index : start_index + max_answer_length]):
+ scores.append(((start_index, start_index + answer_length), start_score + end_score))
+ scores = sorted(scores, key=lambda x: x[1], reverse=True)
+ chosen_span_intervals = []
+ for (start_index, end_index), score in scores:
+ if start_index > end_index:
+ raise ValueError(f"Wrong span indices: [{start_index}:{end_index}]")
+ length = end_index - start_index + 1
+ if length > max_answer_length:
+ raise ValueError(f"Span is too long: {length} > {max_answer_length}")
+ if any(
+ start_index <= prev_start_index <= prev_end_index <= end_index
+ or prev_start_index <= start_index <= end_index <= prev_end_index
+ for (prev_start_index, prev_end_index) in chosen_span_intervals
+ ):
+ continue
+ chosen_span_intervals.append((start_index, end_index))
+
+ if len(chosen_span_intervals) == top_spans:
+ break
+ return chosen_span_intervals
+
+
+@add_end_docstrings(CUSTOM_DPR_READER_DOCSTRING)
+class DPRReaderTokenizer(CustomDPRReaderTokenizerMixin, BertTokenizer):
+ r"""
+ Construct a DPRReader tokenizer.
+
+ [`DPRReaderTokenizer`] is almost identical to [`BertTokenizer`] and runs end-to-end tokenization: punctuation
+ splitting and wordpiece. The difference is that it has three input strings: question, titles and texts, which are
+ combined and fed to the [`DPRReader`] model.
+
+ Refer to superclass [`BertTokenizer`] for usage examples and documentation concerning parameters.
+ """
+
+ vocab_files_names = VOCAB_FILES_NAMES
+ model_input_names = ["input_ids", "attention_mask"]
+
+
+__all__ = ["DPRContextEncoderTokenizer", "DPRQuestionEncoderTokenizer", "DPRReaderOutput", "DPRReaderTokenizer"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/dpr/tokenization_dpr_fast.py b/venv/lib/python3.13/site-packages/transformers/models/dpr/tokenization_dpr_fast.py
new file mode 100644
index 0000000000000000000000000000000000000000..dbf745291745c3ac29472391822b09ba68d933a4
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/dpr/tokenization_dpr_fast.py
@@ -0,0 +1,321 @@
+# coding=utf-8
+# Copyright 2018 The HuggingFace Inc. team, The Hugging Face Team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenization classes for DPR."""
+
+import collections
+from typing import Optional, Union
+
+from ...tokenization_utils_base import BatchEncoding
+from ...utils import TensorType, add_end_docstrings, add_start_docstrings, logging
+from ..bert.tokenization_bert_fast import BertTokenizerFast
+from .tokenization_dpr import DPRContextEncoderTokenizer, DPRQuestionEncoderTokenizer, DPRReaderTokenizer
+
+
+logger = logging.get_logger(__name__)
+
+VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "tokenizer_file": "tokenizer.json"}
+
+
+class DPRContextEncoderTokenizerFast(BertTokenizerFast):
+ r"""
+ Construct a "fast" DPRContextEncoder tokenizer (backed by HuggingFace's *tokenizers* library).
+
+ [`DPRContextEncoderTokenizerFast`] is identical to [`BertTokenizerFast`] and runs end-to-end tokenization:
+ punctuation splitting and wordpiece.
+
+ Refer to superclass [`BertTokenizerFast`] for usage examples and documentation concerning parameters.
+ """
+
+ vocab_files_names = VOCAB_FILES_NAMES
+ slow_tokenizer_class = DPRContextEncoderTokenizer
+
+
+class DPRQuestionEncoderTokenizerFast(BertTokenizerFast):
+ r"""
+ Constructs a "fast" DPRQuestionEncoder tokenizer (backed by HuggingFace's *tokenizers* library).
+
+ [`DPRQuestionEncoderTokenizerFast`] is identical to [`BertTokenizerFast`] and runs end-to-end tokenization:
+ punctuation splitting and wordpiece.
+
+ Refer to superclass [`BertTokenizerFast`] for usage examples and documentation concerning parameters.
+ """
+
+ vocab_files_names = VOCAB_FILES_NAMES
+ slow_tokenizer_class = DPRQuestionEncoderTokenizer
+
+
+DPRSpanPrediction = collections.namedtuple(
+ "DPRSpanPrediction", ["span_score", "relevance_score", "doc_id", "start_index", "end_index", "text"]
+)
+
+DPRReaderOutput = collections.namedtuple("DPRReaderOutput", ["start_logits", "end_logits", "relevance_logits"])
+
+
+CUSTOM_DPR_READER_DOCSTRING = r"""
+ Return a dictionary with the token ids of the input strings and other information to give to `.decode_best_spans`.
+ It converts the strings of a question and different passages (title and text) into sequences of IDs (integers),
+ using the tokenizer and vocabulary. The resulting `input_ids` is a matrix of size `(n_passages, sequence_length)`
+ with the format:
+
+ [CLS] <question token ids> [SEP] <titles ids> [SEP] <texts ids>
+
+ Args:
+ questions (`str` or `list[str]`):
+ The questions to be encoded. You can specify one question for many passages. In this case, the question
+ will be duplicated like `[questions] * n_passages`. Otherwise you have to specify as many questions as in
+ `titles` or `texts`.
+ titles (`str` or `list[str]`):
+ The passages titles to be encoded. This can be a string or a list of strings if there are several passages.
+ texts (`str` or `list[str]`):
+ The passages texts to be encoded. This can be a string or a list of strings if there are several passages.
+ padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
+ Activates and controls padding. Accepts the following values:
+
+ - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single sequence
+ is provided).
+ - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
+ acceptable input length for the model if that argument is not provided.
+ - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
+ lengths).
+ truncation (`bool`, `str` or [`~tokenization_utils_base.TruncationStrategy`], *optional*, defaults to `False`):
+ Activates and controls truncation. Accepts the following values:
+
+ - `True` or `'longest_first'`: Truncate to a maximum length specified with the argument `max_length` or to
+ the maximum acceptable input length for the model if that argument is not provided. This will truncate
+ token by token, removing a token from the longest sequence in the pair if a pair of sequences (or a batch
+ of pairs) is provided.
+ - `'only_first'`: Truncate to a maximum length specified with the argument `max_length` or to the maximum
+ acceptable input length for the model if that argument is not provided. This will only truncate the first
+ sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
+ - `'only_second'`: Truncate to a maximum length specified with the argument `max_length` or to the maximum
+ acceptable input length for the model if that argument is not provided. This will only truncate the
+ second sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
+ - `False` or `'do_not_truncate'` (default): No truncation (i.e., can output batch with sequence lengths
+ greater than the model maximum admissible input size).
+ max_length (`int`, *optional*):
+ Controls the maximum length to use by one of the truncation/padding parameters.
+
+ If left unset or set to `None`, this will use the predefined model maximum length if a maximum length
+ is required by one of the truncation/padding parameters. If the model has no specific maximum input
+ length (like XLNet) truncation/padding to a maximum length will be deactivated.
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
+ If set, will return tensors instead of list of python integers. Acceptable values are:
+
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
+ - `'np'`: Return Numpy `np.ndarray` objects.
+ return_attention_mask (`bool`, *optional*):
+ Whether or not to return the attention mask. If not set, will return the attention mask according to the
+ specific tokenizer's default, defined by the `return_outputs` attribute.
+
+ [What are attention masks?](../glossary#attention-mask)
+
+ Returns:
+ `dict[str, list[list[int]]]`: A dictionary with the following keys:
+
+ - `input_ids`: List of token ids to be fed to a model.
+ - `attention_mask`: List of indices specifying which tokens should be attended to by the model.
+ """
+
+
+@add_start_docstrings(CUSTOM_DPR_READER_DOCSTRING)
+class CustomDPRReaderTokenizerMixin:
+ def __call__(
+ self,
+ questions,
+ titles: Optional[str] = None,
+ texts: Optional[str] = None,
+ padding: Union[bool, str] = False,
+ truncation: Union[bool, str] = False,
+ max_length: Optional[int] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ return_attention_mask: Optional[bool] = None,
+ **kwargs,
+ ) -> BatchEncoding:
+ if titles is None and texts is None:
+ return super().__call__(
+ questions,
+ padding=padding,
+ truncation=truncation,
+ max_length=max_length,
+ return_tensors=return_tensors,
+ return_attention_mask=return_attention_mask,
+ **kwargs,
+ )
+ elif titles is None or texts is None:
+ text_pair = titles if texts is None else texts
+ return super().__call__(
+ questions,
+ text_pair,
+ padding=padding,
+ truncation=truncation,
+ max_length=max_length,
+ return_tensors=return_tensors,
+ return_attention_mask=return_attention_mask,
+ **kwargs,
+ )
+ titles = titles if not isinstance(titles, str) else [titles]
+ texts = texts if not isinstance(texts, str) else [texts]
+ n_passages = len(titles)
+ questions = questions if not isinstance(questions, str) else [questions] * n_passages
+ assert len(titles) == len(texts), (
+ f"There should be as many titles than texts but got {len(titles)} titles and {len(texts)} texts."
+ )
+ encoded_question_and_titles = super().__call__(questions, titles, padding=False, truncation=False)["input_ids"]
+ encoded_texts = super().__call__(texts, add_special_tokens=False, padding=False, truncation=False)["input_ids"]
+ encoded_inputs = {
+ "input_ids": [
+ (encoded_question_and_title + encoded_text)[:max_length]
+ if max_length is not None and truncation
+ else encoded_question_and_title + encoded_text
+ for encoded_question_and_title, encoded_text in zip(encoded_question_and_titles, encoded_texts)
+ ]
+ }
+ if return_attention_mask is not False:
+ attention_mask = []
+ for input_ids in encoded_inputs["input_ids"]:
+ attention_mask.append([int(input_id != self.pad_token_id) for input_id in input_ids])
+ encoded_inputs["attention_mask"] = attention_mask
+ return self.pad(encoded_inputs, padding=padding, max_length=max_length, return_tensors=return_tensors)
+
+ def decode_best_spans(
+ self,
+ reader_input: BatchEncoding,
+ reader_output: DPRReaderOutput,
+ num_spans: int = 16,
+ max_answer_length: int = 64,
+ num_spans_per_passage: int = 4,
+ ) -> list[DPRSpanPrediction]:
+ """
+ Get the span predictions for the extractive Q&A model.
+
+ Returns: *List* of *DPRSpanPrediction* sorted by descending *(relevance_score, span_score)*. Each
+ *DPRSpanPrediction* is a *Tuple* with:
+
+ - **span_score**: `float` that corresponds to the score given by the reader for this span compared to other
+ spans in the same passage. It corresponds to the sum of the start and end logits of the span.
+ - **relevance_score**: `float` that corresponds to the score of each passage to answer the question,
+ compared to all the other passages. It corresponds to the output of the QA classifier of the DPRReader.
+ - **doc_id**: `int` the id of the passage.
+ - **start_index**: `int` the start index of the span (inclusive).
+ - **end_index**: `int` the end index of the span (inclusive).
+
+ Examples:
+
+ ```python
+ >>> from transformers import DPRReader, DPRReaderTokenizer
+
+ >>> tokenizer = DPRReaderTokenizer.from_pretrained("facebook/dpr-reader-single-nq-base")
+ >>> model = DPRReader.from_pretrained("facebook/dpr-reader-single-nq-base")
+ >>> encoded_inputs = tokenizer(
+ ... questions=["What is love ?"],
+ ... titles=["Haddaway"],
+ ... texts=["'What Is Love' is a song recorded by the artist Haddaway"],
+ ... return_tensors="pt",
+ ... )
+ >>> outputs = model(**encoded_inputs)
+ >>> predicted_spans = tokenizer.decode_best_spans(encoded_inputs, outputs)
+ >>> print(predicted_spans[0].text) # best span
+ a song
+ ```"""
+ input_ids = reader_input["input_ids"]
+ start_logits, end_logits, relevance_logits = reader_output[:3]
+ n_passages = len(relevance_logits)
+ sorted_docs = sorted(range(n_passages), reverse=True, key=relevance_logits.__getitem__)
+ nbest_spans_predictions: list[DPRSpanPrediction] = []
+ for doc_id in sorted_docs:
+ sequence_ids = list(input_ids[doc_id])
+ # assuming question & title information is at the beginning of the sequence
+ passage_offset = sequence_ids.index(self.sep_token_id, 2) + 1 # second sep id
+ if sequence_ids[-1] == self.pad_token_id:
+ sequence_len = sequence_ids.index(self.pad_token_id)
+ else:
+ sequence_len = len(sequence_ids)
+
+ best_spans = self._get_best_spans(
+ start_logits=start_logits[doc_id][passage_offset:sequence_len],
+ end_logits=end_logits[doc_id][passage_offset:sequence_len],
+ max_answer_length=max_answer_length,
+ top_spans=num_spans_per_passage,
+ )
+ for start_index, end_index in best_spans:
+ start_index += passage_offset
+ end_index += passage_offset
+ nbest_spans_predictions.append(
+ DPRSpanPrediction(
+ span_score=start_logits[doc_id][start_index] + end_logits[doc_id][end_index],
+ relevance_score=relevance_logits[doc_id],
+ doc_id=doc_id,
+ start_index=start_index,
+ end_index=end_index,
+ text=self.decode(sequence_ids[start_index : end_index + 1]),
+ )
+ )
+ if len(nbest_spans_predictions) >= num_spans:
+ break
+ return nbest_spans_predictions[:num_spans]
+
+ def _get_best_spans(
+ self,
+ start_logits: list[int],
+ end_logits: list[int],
+ max_answer_length: int,
+ top_spans: int,
+ ) -> list[DPRSpanPrediction]:
+ """
+ Finds the best answer spans for the extractive Q&A model for one passage. It returns the best spans in descending
+ `span_score` order, keeping at most `top_spans` spans. Spans longer than `max_answer_length` are ignored.
+ """
+ scores = []
+ for start_index, start_score in enumerate(start_logits):
+ for answer_length, end_score in enumerate(end_logits[start_index : start_index + max_answer_length]):
+ scores.append(((start_index, start_index + answer_length), start_score + end_score))
+ scores = sorted(scores, key=lambda x: x[1], reverse=True)
+ chosen_span_intervals = []
+ for (start_index, end_index), score in scores:
+ assert start_index <= end_index, f"Wrong span indices: [{start_index}:{end_index}]"
+ length = end_index - start_index + 1
+ assert length <= max_answer_length, f"Span is too long: {length} > {max_answer_length}"
+ if any(
+ start_index <= prev_start_index <= prev_end_index <= end_index
+ or prev_start_index <= start_index <= end_index <= prev_end_index
+ for (prev_start_index, prev_end_index) in chosen_span_intervals
+ ):
+ continue
+ chosen_span_intervals.append((start_index, end_index))
+
+ if len(chosen_span_intervals) == top_spans:
+ break
+ return chosen_span_intervals
+
+
+@add_end_docstrings(CUSTOM_DPR_READER_DOCSTRING)
+class DPRReaderTokenizerFast(CustomDPRReaderTokenizerMixin, BertTokenizerFast):
+ r"""
+ Constructs a "fast" DPRReader tokenizer (backed by HuggingFace's *tokenizers* library).
+
+ [`DPRReaderTokenizerFast`] is almost identical to [`BertTokenizerFast`] and runs end-to-end tokenization:
+ punctuation splitting and wordpiece. The difference is that it has three input strings: question, titles and texts,
+ which are combined and fed to the [`DPRReader`] model.
+
+ Refer to superclass [`BertTokenizerFast`] for usage examples and documentation concerning parameters.
+
+ """
+
+ vocab_files_names = VOCAB_FILES_NAMES
+ model_input_names = ["input_ids", "attention_mask"]
+ slow_tokenizer_class = DPRReaderTokenizer
+
+
+__all__ = ["DPRContextEncoderTokenizerFast", "DPRQuestionEncoderTokenizerFast", "DPRReaderTokenizerFast"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/dpt/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/dpt/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ce0070f270f3604afd0661e0cd8aaa4fa2141217
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/dpt/__init__.py
@@ -0,0 +1,30 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_dpt import *
+ from .feature_extraction_dpt import *
+ from .image_processing_dpt import *
+ from .image_processing_dpt_fast import *
+ from .modeling_dpt import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/venv/lib/python3.13/site-packages/transformers/models/dpt/configuration_dpt.py b/venv/lib/python3.13/site-packages/transformers/models/dpt/configuration_dpt.py
new file mode 100644
index 0000000000000000000000000000000000000000..311425fcda1c88c888171256f36e27d9aeaa7487
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/dpt/configuration_dpt.py
@@ -0,0 +1,302 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""DPT model configuration"""
+
+import copy
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+from ...utils.backbone_utils import verify_backbone_config_arguments
+from ..auto.configuration_auto import CONFIG_MAPPING
+from ..bit import BitConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+class DPTConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`DPTModel`]. It is used to instantiate a DPT
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+ defaults will yield a similar configuration to that of the DPT
+ [Intel/dpt-large](https://huggingface.co/Intel/dpt-large) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ hidden_size (`int`, *optional*, defaults to 768):
+ Dimensionality of the encoder layers and the pooler layer.
+ num_hidden_layers (`int`, *optional*, defaults to 12):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 12):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ intermediate_size (`int`, *optional*, defaults to 3072):
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"selu"` and `"gelu_new"` are supported.
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12):
+ The epsilon used by the layer normalization layers.
+ image_size (`int`, *optional*, defaults to 384):
+ The size (resolution) of each image.
+ patch_size (`int`, *optional*, defaults to 16):
+ The size (resolution) of each patch.
+ num_channels (`int`, *optional*, defaults to 3):
+ The number of input channels.
+ is_hybrid (`bool`, *optional*, defaults to `False`):
+ Whether to use a hybrid backbone. Useful in the context of loading DPT-Hybrid models.
+ qkv_bias (`bool`, *optional*, defaults to `True`):
+ Whether to add a bias to the queries, keys and values.
+ backbone_out_indices (`list[int]`, *optional*, defaults to `[2, 5, 8, 11]`):
+ Indices of the intermediate hidden states to use from backbone.
+ readout_type (`str`, *optional*, defaults to `"project"`):
+ The readout type to use when processing the readout token (CLS token) of the intermediate hidden states of
+ the ViT backbone. Can be one of [`"ignore"`, `"add"`, `"project"`].
+
+ - "ignore" simply ignores the CLS token.
+ - "add" passes the information from the CLS token to all other tokens by adding the representations.
+ - "project" passes information to the other tokens by concatenating the readout to all other tokens before
+ projecting the
+ representation to the original feature dimension D using a linear layer followed by a GELU non-linearity.
+ reassemble_factors (`list[int]`, *optional*, defaults to `[4, 2, 1, 0.5]`):
+ The up/downsampling factors of the reassemble layers.
+ neck_hidden_sizes (`list[int]`, *optional*, defaults to `[96, 192, 384, 768]`):
+ The hidden sizes to project to for the feature maps of the backbone.
+ fusion_hidden_size (`int`, *optional*, defaults to 256):
+ The number of channels before fusion.
+ head_in_index (`int`, *optional*, defaults to -1):
+ The index of the features to use in the heads.
+ use_batch_norm_in_fusion_residual (`bool`, *optional*, defaults to `False`):
+ Whether to use batch normalization in the pre-activate residual units of the fusion blocks.
+ use_bias_in_fusion_residual (`bool`, *optional*, defaults to `True`):
+ Whether to use bias in the pre-activate residual units of the fusion blocks.
+ add_projection (`bool`, *optional*, defaults to `False`):
+ Whether to add a projection layer before the depth estimation head.
+ use_auxiliary_head (`bool`, *optional*, defaults to `True`):
+ Whether to use an auxiliary head during training.
+ auxiliary_loss_weight (`float`, *optional*, defaults to 0.4):
+ Weight of the cross-entropy loss of the auxiliary head.
+ semantic_loss_ignore_index (`int`, *optional*, defaults to 255):
+ The index that is ignored by the loss function of the semantic segmentation model.
+ semantic_classifier_dropout (`float`, *optional*, defaults to 0.1):
+ The dropout ratio for the semantic classification head.
+ backbone_featmap_shape (`list[int]`, *optional*, defaults to `[1, 1024, 24, 24]`):
+ Used only for the `hybrid` embedding type. The shape of the feature maps of the backbone.
+ neck_ignore_stages (`list[int]`, *optional*, defaults to `[0, 1]`):
+ Used only for the `hybrid` embedding type. The stages of the readout layers to ignore.
+ backbone_config (`Union[dict[str, Any], PretrainedConfig]`, *optional*):
+ The configuration of the backbone model. Only used in case `is_hybrid` is `True` or in case you want to
+ leverage the [`AutoBackbone`] API.
+ backbone (`str`, *optional*):
+ Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
+ will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone`
+ is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
+ use_pretrained_backbone (`bool`, *optional*, defaults to `False`):
+ Whether to use pretrained weights for the backbone.
+ use_timm_backbone (`bool`, *optional*, defaults to `False`):
+ Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
+ library.
+ backbone_kwargs (`dict`, *optional*):
+ Keyword arguments to be passed to AutoBackbone when loading from a checkpoint
+ e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set.
+ pooler_output_size (`int`, *optional*):
+ Dimensionality of the pooler layer. If None, defaults to `hidden_size`.
+ pooler_act (`str`, *optional*, defaults to `"tanh"`):
+ The activation function to be used by the pooler. Keys of ACT2FN are supported for Flax and
+ PyTorch, and elements of https://www.tensorflow.org/api_docs/python/tf/keras/activations are
+ supported for TensorFlow.
+
+ Example:
+
+ ```python
+ >>> from transformers import DPTModel, DPTConfig
+
+ >>> # Initializing a DPT dpt-large style configuration
+ >>> configuration = DPTConfig()
+
+ >>> # Initializing a model from the dpt-large style configuration
+ >>> model = DPTModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "dpt"
+
+ def __init__(
+ self,
+ hidden_size=768,
+ num_hidden_layers=12,
+ num_attention_heads=12,
+ intermediate_size=3072,
+ hidden_act="gelu",
+ hidden_dropout_prob=0.0,
+ attention_probs_dropout_prob=0.0,
+ initializer_range=0.02,
+ layer_norm_eps=1e-12,
+ image_size=384,
+ patch_size=16,
+ num_channels=3,
+ is_hybrid=False,
+ qkv_bias=True,
+ backbone_out_indices=[2, 5, 8, 11],
+ readout_type="project",
+ reassemble_factors=[4, 2, 1, 0.5],
+ neck_hidden_sizes=[96, 192, 384, 768],
+ fusion_hidden_size=256,
+ head_in_index=-1,
+ use_batch_norm_in_fusion_residual=False,
+ use_bias_in_fusion_residual=None,
+ add_projection=False,
+ use_auxiliary_head=True,
+ auxiliary_loss_weight=0.4,
+ semantic_loss_ignore_index=255,
+ semantic_classifier_dropout=0.1,
+ backbone_featmap_shape=[1, 1024, 24, 24],
+ neck_ignore_stages=[0, 1],
+ backbone_config=None,
+ backbone=None,
+ use_pretrained_backbone=False,
+ use_timm_backbone=False,
+ backbone_kwargs=None,
+ pooler_output_size=None,
+ pooler_act="tanh",
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ self.hidden_size = hidden_size
+ self.is_hybrid = is_hybrid
+
+ use_autobackbone = False
+ if self.is_hybrid:
+ if backbone_config is None:
+ backbone_config = {
+ "global_padding": "same",
+ "layer_type": "bottleneck",
+ "depths": [3, 4, 9],
+ "out_features": ["stage1", "stage2", "stage3"],
+ "embedding_dynamic_padding": True,
+ }
+
+ if isinstance(backbone_config, dict):
+ logger.info("Initializing the config with a `BiT` backbone.")
+ backbone_config = BitConfig(**backbone_config)
+ elif not isinstance(backbone_config, PretrainedConfig):
+ raise ValueError(
+ f"backbone_config must be a dictionary or a `PretrainedConfig`, got {backbone_config.__class__}."
+ )
+ self.backbone_config = backbone_config
+ self.backbone_featmap_shape = backbone_featmap_shape
+ self.neck_ignore_stages = neck_ignore_stages
+
+ if readout_type != "project":
+ raise ValueError("Readout type must be 'project' when using `DPT-hybrid` mode.")
+
+ elif backbone is not None or backbone_config is not None:
+ use_autobackbone = True
+ if isinstance(backbone_config, dict):
+ backbone_model_type = backbone_config.get("model_type")
+ config_class = CONFIG_MAPPING[backbone_model_type]
+ backbone_config = config_class.from_dict(backbone_config)
+
+ self.backbone_config = backbone_config
+ self.backbone_featmap_shape = None
+ self.neck_ignore_stages = []
+
+ # We only use load_backbone when config.is_hybrid is False
+ verify_backbone_config_arguments(
+ use_timm_backbone=use_timm_backbone,
+ use_pretrained_backbone=use_pretrained_backbone,
+ backbone=backbone,
+ backbone_config=backbone_config,
+ backbone_kwargs=backbone_kwargs,
+ )
+ else:
+ self.backbone_config = None
+ self.backbone_featmap_shape = None
+ self.neck_ignore_stages = []
+
+ self.backbone = backbone
+ self.use_pretrained_backbone = use_pretrained_backbone
+ self.use_timm_backbone = use_timm_backbone
+ self.backbone_kwargs = backbone_kwargs
+
+ # ViT parameters used if not using a hybrid backbone
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.intermediate_size = intermediate_size
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.layer_norm_eps = layer_norm_eps
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.num_channels = num_channels
+ self.qkv_bias = qkv_bias
+ self.use_autobackbone = use_autobackbone
+ self.backbone_out_indices = None if use_autobackbone else backbone_out_indices
+
+ if readout_type not in ["ignore", "add", "project"]:
+ raise ValueError("Readout_type must be one of ['ignore', 'add', 'project']")
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.readout_type = readout_type
+ self.reassemble_factors = reassemble_factors
+ self.neck_hidden_sizes = neck_hidden_sizes
+ self.fusion_hidden_size = fusion_hidden_size
+ self.head_in_index = head_in_index
+ self.use_batch_norm_in_fusion_residual = use_batch_norm_in_fusion_residual
+ self.use_bias_in_fusion_residual = use_bias_in_fusion_residual
+ self.add_projection = add_projection
+
+ # auxiliary head attributes (semantic segmentation)
+ self.use_auxiliary_head = use_auxiliary_head
+ self.auxiliary_loss_weight = auxiliary_loss_weight
+ self.semantic_loss_ignore_index = semantic_loss_ignore_index
+ self.semantic_classifier_dropout = semantic_classifier_dropout
+ self.pooler_output_size = pooler_output_size if pooler_output_size else hidden_size
+ self.pooler_act = pooler_act
+
+ def to_dict(self):
+ """
+ Serializes this instance to a Python dictionary. Overrides the default [`~PretrainedConfig.to_dict`].
+
+ Returns:
+ `dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
+ """
+ output = copy.deepcopy(self.__dict__)
+
+ if output["backbone_config"] is not None:
+ output["backbone_config"] = self.backbone_config.to_dict()
+
+ output["model_type"] = self.__class__.model_type
+ return output
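+
+ # Illustrative sketch (not part of the library source): for a hybrid config the nested
+ # backbone config is serialized back to a plain dict, e.g.
+ #     config = DPTConfig(is_hybrid=True)
+ #     assert isinstance(config.to_dict()["backbone_config"], dict)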
+
+ @property
+ def sub_configs(self):
+ return (
+ {"backbone_config": type(self.backbone_config)}
+ if getattr(self, "backbone_config", None) is not None
+ else {}
+ )
+
+
+__all__ = ["DPTConfig"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/dpt/feature_extraction_dpt.py b/venv/lib/python3.13/site-packages/transformers/models/dpt/feature_extraction_dpt.py
new file mode 100644
index 0000000000000000000000000000000000000000..b6ab8ccbed8d33b1e5b15d429b6cb057ff781f78
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/dpt/feature_extraction_dpt.py
@@ -0,0 +1,38 @@
+# coding=utf-8
+# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Feature extractor class for DPT."""
+
+import warnings
+
+from ...utils import logging
+from ...utils.import_utils import requires
+from .image_processing_dpt import DPTImageProcessor
+
+
+logger = logging.get_logger(__name__)
+
+
+@requires(backends=("vision",))
+class DPTFeatureExtractor(DPTImageProcessor):
+ def __init__(self, *args, **kwargs) -> None:
+ warnings.warn(
+ "The class DPTFeatureExtractor is deprecated and will be removed in version 5 of Transformers. Please"
+ " use DPTImageProcessor instead.",
+ FutureWarning,
+ )
+ super().__init__(*args, **kwargs)
+
+
+__all__ = ["DPTFeatureExtractor"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/dpt/image_processing_dpt.py b/venv/lib/python3.13/site-packages/transformers/models/dpt/image_processing_dpt.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b28950d2ded2d046e6c39d54d246f893a50d122
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/dpt/image_processing_dpt.py
@@ -0,0 +1,677 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Image processor class for DPT."""
+
+import math
+from collections.abc import Iterable
+from typing import TYPE_CHECKING, Optional, Union
+
+from ...utils.import_utils import requires
+
+
+if TYPE_CHECKING:
+ from ...modeling_outputs import DepthEstimatorOutput
+
+import numpy as np
+
+from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
+from ...image_transforms import pad, resize, to_channel_dimension_format
+from ...image_utils import (
+ IMAGENET_STANDARD_MEAN,
+ IMAGENET_STANDARD_STD,
+ ChannelDimension,
+ ImageInput,
+ PILImageResampling,
+ get_image_size,
+ infer_channel_dimension_format,
+ is_scaled_image,
+ is_torch_available,
+ is_torch_tensor,
+ make_flat_list_of_images,
+ to_numpy_array,
+ valid_images,
+ validate_preprocess_arguments,
+)
+from ...utils import (
+ TensorType,
+ filter_out_non_signature_kwargs,
+ is_vision_available,
+ logging,
+ requires_backends,
+)
+
+
+if is_torch_available():
+ import torch
+
+if is_vision_available():
+ import PIL
+
+
+logger = logging.get_logger(__name__)
+
+
+def get_resize_output_image_size(
+ input_image: np.ndarray,
+ output_size: Union[int, Iterable[int]],
+ keep_aspect_ratio: bool,
+ multiple: int,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+) -> tuple[int, int]:
+ def constrain_to_multiple_of(val, multiple, min_val=0, max_val=None):
+ x = round(val / multiple) * multiple
+
+ if max_val is not None and x > max_val:
+ x = math.floor(val / multiple) * multiple
+
+ if x < min_val:
+ x = math.ceil(val / multiple) * multiple
+
+ return x
+
+ output_size = (output_size, output_size) if isinstance(output_size, int) else output_size
+
+ input_height, input_width = get_image_size(input_image, input_data_format)
+ output_height, output_width = output_size
+
+ # determine new height and width
+ scale_height = output_height / input_height
+ scale_width = output_width / input_width
+
+ if keep_aspect_ratio:
+ # scale as little as possible
+ if abs(1 - scale_width) < abs(1 - scale_height):
+ # fit width
+ scale_height = scale_width
+ else:
+ # fit height
+ scale_width = scale_height
+
+ new_height = constrain_to_multiple_of(scale_height * input_height, multiple=multiple)
+ new_width = constrain_to_multiple_of(scale_width * input_width, multiple=multiple)
+
+ return (new_height, new_width)
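+
+ # Worked example (illustrative only): with output_size=(384, 384), keep_aspect_ratio=True
+ # and multiple=32, a 480x640 input keeps the height scale (384/480 = 0.8), since that scale
+ # is closer to 1 than the width scale (384/640 = 0.6); both sides are then rounded to
+ # multiples of 32, giving an output of (384, 512).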
+
+
+@requires(backends=("vision",))
+class DPTImageProcessor(BaseImageProcessor):
+ r"""
+ Constructs a DPT image processor.
+
+ Args:
+ do_resize (`bool`, *optional*, defaults to `True`):
+ Whether to resize the image's (height, width) dimensions. Can be overridden by `do_resize` in `preprocess`.
+ size (`dict[str, int]` *optional*, defaults to `{"height": 384, "width": 384}`):
+ Size of the image after resizing. Can be overridden by `size` in `preprocess`.
+ resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
+ Defines the resampling filter to use if resizing the image. Can be overridden by `resample` in `preprocess`.
+ keep_aspect_ratio (`bool`, *optional*, defaults to `False`):
+ If `True`, the image is resized to the largest possible size such that the aspect ratio is preserved. Can
+ be overridden by `keep_aspect_ratio` in `preprocess`.
+ ensure_multiple_of (`int`, *optional*, defaults to 1):
+ If `do_resize` is `True`, the image is resized to a size that is a multiple of this value. Can be overridden
+ by `ensure_multiple_of` in `preprocess`.
+ do_rescale (`bool`, *optional*, defaults to `True`):
+ Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
+ `preprocess`.
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
+ Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in `preprocess`.
+ do_normalize (`bool`, *optional*, defaults to `True`):
+ Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
+ method.
+ image_mean (`float` or `list[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
+ Mean to use if normalizing the image. This is a float or list of floats the length of the number of
+ channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
+ image_std (`float` or `list[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
+ Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
+ number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
+ do_pad (`bool`, *optional*, defaults to `False`):
+ Whether to apply center padding. This was introduced in the DINOv2 paper, which uses the model in
+ combination with DPT.
+ size_divisor (`int`, *optional*):
+ If `do_pad` is `True`, pads the image dimensions to be divisible by this value. This was introduced in the
+ DINOv2 paper, which uses the model in combination with DPT.
+ do_reduce_labels (`bool`, *optional*, defaults to `False`):
+ Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0 is
+ used for background, and background itself is not included in all classes of a dataset (e.g. ADE20k). The
+ background label will be replaced by 255. Can be overridden by the `do_reduce_labels` parameter in the
+ `preprocess` method.
+ """
+
+ model_input_names = ["pixel_values"]
+
+ def __init__(
+ self,
+ do_resize: bool = True,
+ size: Optional[dict[str, int]] = None,
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
+ keep_aspect_ratio: bool = False,
+ ensure_multiple_of: int = 1,
+ do_rescale: bool = True,
+ rescale_factor: Union[int, float] = 1 / 255,
+ do_normalize: bool = True,
+ image_mean: Optional[Union[float, list[float]]] = None,
+ image_std: Optional[Union[float, list[float]]] = None,
+ do_pad: bool = False,
+ size_divisor: Optional[int] = None,
+ do_reduce_labels: bool = False,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ size = size if size is not None else {"height": 384, "width": 384}
+ size = get_size_dict(size)
+ self.do_resize = do_resize
+ self.size = size
+ self.keep_aspect_ratio = keep_aspect_ratio
+ self.ensure_multiple_of = ensure_multiple_of
+ self.resample = resample
+ self.do_rescale = do_rescale
+ self.rescale_factor = rescale_factor
+ self.do_normalize = do_normalize
+ self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
+ self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
+ self.do_pad = do_pad
+ self.size_divisor = size_divisor
+ self.do_reduce_labels = do_reduce_labels
+
+ def resize(
+ self,
+ image: np.ndarray,
+ size: dict[str, int],
+ keep_aspect_ratio: bool = False,
+ ensure_multiple_of: int = 1,
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs,
+ ) -> np.ndarray:
+ """
+ Resize an image to target size `(size["height"], size["width"])`. If `keep_aspect_ratio` is `True`, the image
+ is resized to the largest possible size such that the aspect ratio is preserved. If `ensure_multiple_of` is
+ set, the image is resized to a size that is a multiple of this value.
+
+ Args:
+ image (`np.ndarray`):
+ Image to resize.
+ size (`dict[str, int]`):
+ Target size of the output image.
+ keep_aspect_ratio (`bool`, *optional*, defaults to `False`):
+ If `True`, the image is resized to the largest possible size such that the aspect ratio is preserved.
+ ensure_multiple_of (`int`, *optional*, defaults to 1):
+ The image is resized to a size that is a multiple of this value.
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
+ Defines the resampling filter to use if resizing the image. Otherwise, the image is resized to size
+ specified in `size`.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
+ input_data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format of the input image. If not provided, it will be inferred.
+ """
+ size = get_size_dict(size)
+ if "height" not in size or "width" not in size:
+ raise ValueError(f"The size dictionary must contain the keys 'height' and 'width'. Got {size.keys()}")
+
+ output_size = get_resize_output_image_size(
+ image,
+ output_size=(size["height"], size["width"]),
+ keep_aspect_ratio=keep_aspect_ratio,
+ multiple=ensure_multiple_of,
+ input_data_format=input_data_format,
+ )
+ return resize(
+ image,
+ size=output_size,
+ resample=resample,
+ data_format=data_format,
+ input_data_format=input_data_format,
+ **kwargs,
+ )
+
+ def pad_image(
+ self,
+ image: np.ndarray,
+ size_divisor: int,
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ ):
+ """
+ Center pad an image so that its height and width are a multiple of `size_divisor`.
+
+ Args:
+ image (`np.ndarray`):
+ Image to pad.
+ size_divisor (`int`):
+ The width and height of the image will be padded to a multiple of this number.
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
+ The channel dimension format for the output image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - Unset: Use the channel dimension format of the input image.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
+ from the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+ """
+
+ def _get_pad(size, size_divisor):
+ new_size = math.ceil(size / size_divisor) * size_divisor
+ pad_size = new_size - size
+ pad_size_left = pad_size // 2
+ pad_size_right = pad_size - pad_size_left
+ return pad_size_left, pad_size_right
+
+ if input_data_format is None:
+ input_data_format = infer_channel_dimension_format(image)
+
+ height, width = get_image_size(image, input_data_format)
+
+ pad_size_top, pad_size_bottom = _get_pad(height, size_divisor)
+ pad_size_left, pad_size_right = _get_pad(width, size_divisor)
+
+ return pad(image, ((pad_size_top, pad_size_bottom), (pad_size_left, pad_size_right)), data_format=data_format, input_data_format=input_data_format)
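+
+ # Quick sanity check (illustrative only): with size_divisor=32, a side of 500 pixels is
+ # padded up to 512, i.e. 6 pixels on each border; a side that is already a multiple of 32
+ # gets no padding.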
+
+ # Copied from transformers.models.beit.image_processing_beit.BeitImageProcessor.reduce_label
+ def reduce_label(self, label: ImageInput) -> np.ndarray:
+ label = to_numpy_array(label)
+ # Avoid using underflow conversion
+ label[label == 0] = 255
+ label = label - 1
+ label[label == 254] = 255
+ return label
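+
+ # Example of the remapping (illustrative only): labels [0, 1, 2] become [255, 0, 1], i.e.
+ # the background class 0 is mapped to the ignore index 255 and all other ids shift down by one.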
+
+ def _preprocess(
+ self,
+ image: ImageInput,
+ do_reduce_labels: Optional[bool] = None,
+ do_resize: Optional[bool] = None,
+ size: Optional[dict[str, int]] = None,
+ resample: Optional[PILImageResampling] = None,
+ keep_aspect_ratio: Optional[bool] = None,
+ ensure_multiple_of: Optional[int] = None,
+ do_rescale: Optional[bool] = None,
+ rescale_factor: Optional[float] = None,
+ do_normalize: Optional[bool] = None,
+ image_mean: Optional[Union[float, list[float]]] = None,
+ image_std: Optional[Union[float, list[float]]] = None,
+ do_pad: Optional[bool] = None,
+ size_divisor: Optional[int] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ ):
+ if do_reduce_labels:
+ image = self.reduce_label(image)
+
+ if do_resize:
+ image = self.resize(
+ image=image,
+ size=size,
+ resample=resample,
+ keep_aspect_ratio=keep_aspect_ratio,
+ ensure_multiple_of=ensure_multiple_of,
+ input_data_format=input_data_format,
+ )
+
+ if do_rescale:
+ image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
+
+ if do_normalize:
+ image = self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
+
+ if do_pad:
+ image = self.pad_image(image=image, size_divisor=size_divisor, input_data_format=input_data_format)
+
+ return image
+
+ def _preprocess_image(
+ self,
+ image: ImageInput,
+ do_resize: Optional[bool] = None,
+ size: Optional[dict[str, int]] = None,
+ resample: Optional[PILImageResampling] = None,
+ keep_aspect_ratio: Optional[bool] = None,
+ ensure_multiple_of: Optional[int] = None,
+ do_rescale: Optional[bool] = None,
+ rescale_factor: Optional[float] = None,
+ do_normalize: Optional[bool] = None,
+ image_mean: Optional[Union[float, list[float]]] = None,
+ image_std: Optional[Union[float, list[float]]] = None,
+ do_pad: Optional[bool] = None,
+ size_divisor: Optional[int] = None,
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ ) -> np.ndarray:
+ """Preprocesses a single image."""
+ # All transformations expect numpy arrays.
+ image = to_numpy_array(image)
+ if do_rescale and is_scaled_image(image):
+ logger.warning_once(
+ "It looks like you are trying to rescale already rescaled images. If the input"
+ " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
+ )
+ if input_data_format is None:
+ # We assume that all images have the same channel dimension format.
+ input_data_format = infer_channel_dimension_format(image)
+
+ image = self._preprocess(
+ image,
+ do_reduce_labels=False,
+ do_resize=do_resize,
+ size=size,
+ resample=resample,
+ keep_aspect_ratio=keep_aspect_ratio,
+ ensure_multiple_of=ensure_multiple_of,
+ do_rescale=do_rescale,
+ rescale_factor=rescale_factor,
+ do_normalize=do_normalize,
+ image_mean=image_mean,
+ image_std=image_std,
+ do_pad=do_pad,
+ size_divisor=size_divisor,
+ input_data_format=input_data_format,
+ )
+ if data_format is not None:
+ image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
+ return image
+
+ def _preprocess_segmentation_map(
+ self,
+ segmentation_map: ImageInput,
+ do_resize: Optional[bool] = None,
+ size: Optional[dict[str, int]] = None,
+ resample: Optional[PILImageResampling] = None,
+ keep_aspect_ratio: Optional[bool] = None,
+ ensure_multiple_of: Optional[int] = None,
+ do_reduce_labels: Optional[bool] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ ):
+ """Preprocesses a single segmentation map."""
+ # All transformations expect numpy arrays.
+ segmentation_map = to_numpy_array(segmentation_map)
+ # Add an axis to the segmentation maps for transformations.
+ if segmentation_map.ndim == 2:
+ segmentation_map = segmentation_map[None, ...]
+ added_dimension = True
+ input_data_format = ChannelDimension.FIRST
+ else:
+ added_dimension = False
+ if input_data_format is None:
+ input_data_format = infer_channel_dimension_format(segmentation_map, num_channels=1)
+ segmentation_map = self._preprocess(
+ image=segmentation_map,
+ do_reduce_labels=do_reduce_labels,
+ do_resize=do_resize,
+ size=size,
+ resample=resample,
+ keep_aspect_ratio=keep_aspect_ratio,
+ ensure_multiple_of=ensure_multiple_of,
+ do_normalize=False,
+ do_rescale=False,
+ input_data_format=input_data_format,
+ )
+ # Remove extra axis if added
+ if added_dimension:
+ segmentation_map = np.squeeze(segmentation_map, axis=0)
+ segmentation_map = segmentation_map.astype(np.int64)
+ return segmentation_map
+
+ # Copied from transformers.models.beit.image_processing_beit.BeitImageProcessor.__call__
+ def __call__(self, images, segmentation_maps=None, **kwargs):
+ # Overrides the `__call__` method of the `Preprocessor` class such that the images and segmentation maps can both
+ # be passed in as positional arguments.
+ return super().__call__(images, segmentation_maps=segmentation_maps, **kwargs)
+
+ @filter_out_non_signature_kwargs()
+ def preprocess(
+ self,
+ images: ImageInput,
+ segmentation_maps: Optional[ImageInput] = None,
+ do_resize: Optional[bool] = None,
+ size: Optional[dict[str, int]] = None,
+ keep_aspect_ratio: Optional[bool] = None,
+ ensure_multiple_of: Optional[int] = None,
+ resample: Optional[PILImageResampling] = None,
+ do_rescale: Optional[bool] = None,
+ rescale_factor: Optional[float] = None,
+ do_normalize: Optional[bool] = None,
+ image_mean: Optional[Union[float, list[float]]] = None,
+ image_std: Optional[Union[float, list[float]]] = None,
+ do_pad: Optional[bool] = None,
+ size_divisor: Optional[int] = None,
+ do_reduce_labels: Optional[bool] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ data_format: ChannelDimension = ChannelDimension.FIRST,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ ) -> PIL.Image.Image:
+ """
+ Preprocess an image or batch of images.
+
+ Args:
+ images (`ImageInput`):
+ Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
+ passing in images with pixel values between 0 and 1, set `do_rescale=False`.
+ segmentation_maps (`ImageInput`, *optional*):
+ Segmentation map to preprocess.
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+ Whether to resize the image.
+ size (`dict[str, int]`, *optional*, defaults to `self.size`):
+ Size of the image after resizing. If `keep_aspect_ratio` is `True`, the image is resized to the largest
+ possible size such that the aspect ratio is preserved. If `ensure_multiple_of` is set, the image is
+ resized to a size that is a multiple of this value.
+ keep_aspect_ratio (`bool`, *optional*, defaults to `self.keep_aspect_ratio`):
+ Whether to keep the aspect ratio of the image. If False, the image will be resized to (size, size). If
+ True, the image will be resized to keep the aspect ratio and the size will be the maximum possible.
+ ensure_multiple_of (`int`, *optional*, defaults to `self.ensure_multiple_of`):
+ Ensure that the image size is a multiple of this value.
+ resample (`int`, *optional*, defaults to `self.resample`):
+ Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
+ has an effect if `do_resize` is set to `True`.
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+ Whether to rescale the image values to the range [0, 1].
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
+ Whether to normalize the image.
+ image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`):
+ Image mean.
+ image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`):
+ Image standard deviation.
+ do_reduce_labels (`bool`, *optional*, defaults to `self.do_reduce_labels`):
+ Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0
+ is used for background, and background itself is not included in all classes of a dataset (e.g.
+ ADE20k). The background label will be replaced by 255.
+ return_tensors (`str` or `TensorType`, *optional*):
+ The type of tensors to return. Can be one of:
+ - Unset: Return a list of `np.ndarray`.
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
+ The channel dimension format for the output image. Can be one of:
+ - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
+ from the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+ """
+ do_resize = do_resize if do_resize is not None else self.do_resize
+ size = size if size is not None else self.size
+ size = get_size_dict(size)
+ keep_aspect_ratio = keep_aspect_ratio if keep_aspect_ratio is not None else self.keep_aspect_ratio
+ ensure_multiple_of = ensure_multiple_of if ensure_multiple_of is not None else self.ensure_multiple_of
+ resample = resample if resample is not None else self.resample
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
+ image_mean = image_mean if image_mean is not None else self.image_mean
+ image_std = image_std if image_std is not None else self.image_std
+ do_pad = do_pad if do_pad is not None else self.do_pad
+ size_divisor = size_divisor if size_divisor is not None else self.size_divisor
+ do_reduce_labels = do_reduce_labels if do_reduce_labels is not None else self.do_reduce_labels
+
+ images = make_flat_list_of_images(images)
+
+ if segmentation_maps is not None:
+ segmentation_maps = make_flat_list_of_images(segmentation_maps, expected_ndims=2)
+
+ if not valid_images(images):
+ raise ValueError(
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
+ "torch.Tensor, tf.Tensor or jax.ndarray."
+ )
+ validate_preprocess_arguments(
+ do_rescale=do_rescale,
+ rescale_factor=rescale_factor,
+ do_normalize=do_normalize,
+ image_mean=image_mean,
+ image_std=image_std,
+ do_resize=do_resize,
+ size=size,
+ resample=resample,
+ )
+
+ images = [
+ self._preprocess_image(
+ image=img,
+ do_resize=do_resize,
+ do_rescale=do_rescale,
+ do_normalize=do_normalize,
+ do_pad=do_pad,
+ size=size,
+ resample=resample,
+ keep_aspect_ratio=keep_aspect_ratio,
+ ensure_multiple_of=ensure_multiple_of,
+ rescale_factor=rescale_factor,
+ image_mean=image_mean,
+ image_std=image_std,
+ size_divisor=size_divisor,
+ data_format=data_format,
+ input_data_format=input_data_format,
+ )
+ for img in images
+ ]
+
+ data = {"pixel_values": images}
+
+ if segmentation_maps is not None:
+ segmentation_maps = [
+ self._preprocess_segmentation_map(
+ segmentation_map=segmentation_map,
+ do_reduce_labels=do_reduce_labels,
+ do_resize=do_resize,
+ size=size,
+ resample=resample,
+ keep_aspect_ratio=keep_aspect_ratio,
+ ensure_multiple_of=ensure_multiple_of,
+ input_data_format=input_data_format,
+ )
+ for segmentation_map in segmentation_maps
+ ]
+
+ data["labels"] = segmentation_maps
+
+ return BatchFeature(data=data, tensor_type=return_tensors)
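+
+ # Minimal usage sketch (illustrative, assuming `image` is a PIL.Image.Image):
+ #     image_processor = DPTImageProcessor()
+ #     inputs = image_processor(images=image, return_tensors="pt")
+ #     inputs["pixel_values"].shape  # torch.Size([1, 3, 384, 384]) with the default settings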
+
+ # Copied from transformers.models.beit.image_processing_beit.BeitImageProcessor.post_process_semantic_segmentation with Beit->DPT
+ def post_process_semantic_segmentation(self, outputs, target_sizes: Optional[list[tuple]] = None):
+ """
+ Converts the output of [`DPTForSemanticSegmentation`] into semantic segmentation maps. Only supports PyTorch.
+
+ Args:
+ outputs ([`DPTForSemanticSegmentation`]):
+ Raw outputs of the model.
+ target_sizes (`list[Tuple]` of length `batch_size`, *optional*):
+ List of tuples corresponding to the requested final size (height, width) of each prediction. If unset,
+ predictions will not be resized.
+
+ Returns:
+ semantic_segmentation: `list[torch.Tensor]` of length `batch_size`, where each item is a semantic
+ segmentation map of shape (height, width) corresponding to the target_sizes entry (if `target_sizes` is
+ specified). Each entry of each `torch.Tensor` corresponds to a semantic class id.
+ """
+ # TODO: add support for other frameworks
+ logits = outputs.logits
+
+ # Resize logits and compute semantic segmentation maps
+ if target_sizes is not None:
+ if len(logits) != len(target_sizes):
+ raise ValueError(
+ "Make sure that you pass in as many target sizes as the batch dimension of the logits"
+ )
+
+ if is_torch_tensor(target_sizes):
+ target_sizes = target_sizes.numpy()
+
+ semantic_segmentation = []
+
+ for idx in range(len(logits)):
+ resized_logits = torch.nn.functional.interpolate(
+ logits[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False
+ )
+ semantic_map = resized_logits[0].argmax(dim=0)
+ semantic_segmentation.append(semantic_map)
+ else:
+ semantic_segmentation = logits.argmax(dim=1)
+ semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])]
+
+ return semantic_segmentation
+
+ def post_process_depth_estimation(
+ self,
+ outputs: "DepthEstimatorOutput",
+ target_sizes: Optional[Union[TensorType, list[tuple[int, int]], None]] = None,
+ ) -> list[dict[str, TensorType]]:
+ """
+ Converts the raw output of [`DepthEstimatorOutput`] into final depth predictions and depth PIL images.
+ Only supports PyTorch.
+
+ Args:
+ outputs ([`DepthEstimatorOutput`]):
+ Raw outputs of the model.
+ target_sizes (`TensorType` or `list[tuple[int, int]]`, *optional*):
+ Tensor of shape `(batch_size, 2)` or list of tuples (`tuple[int, int]`) containing the target size
+ (height, width) of each image in the batch. If left to None, predictions will not be resized.
+
+ Returns:
+ `list[dict[str, TensorType]]`: A list of dictionaries of tensors representing the processed depth
+ predictions.
+ """
+ requires_backends(self, "torch")
+
+ predicted_depth = outputs.predicted_depth
+
+ if (target_sizes is not None) and (len(predicted_depth) != len(target_sizes)):
+ raise ValueError(
+ "Make sure that you pass in as many target sizes as the batch dimension of the predicted depth"
+ )
+
+ results = []
+ target_sizes = [None] * len(predicted_depth) if target_sizes is None else target_sizes
+ for depth, target_size in zip(predicted_depth, target_sizes):
+ if target_size is not None:
+ depth = torch.nn.functional.interpolate(
+ depth.unsqueeze(0).unsqueeze(1), size=target_size, mode="bicubic", align_corners=False
+ ).squeeze()
+
+ results.append({"predicted_depth": depth})
+
+ return results
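+
+ # Illustrative only: given the `DepthEstimatorOutput` of a single 480x640 image,
+ # post_process_depth_estimation(outputs, target_sizes=[(480, 640)]) returns a one-element
+ # list whose "predicted_depth" tensor has shape (480, 640).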
+
+
+__all__ = ["DPTImageProcessor"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/dpt/image_processing_dpt_fast.py b/venv/lib/python3.13/site-packages/transformers/models/dpt/image_processing_dpt_fast.py
new file mode 100644
index 0000000000000000000000000000000000000000..faaddb8023c08aee60e90142bef0cf44048e5ca2
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/dpt/image_processing_dpt_fast.py
@@ -0,0 +1,406 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/dpt/modular_dpt.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_dpt.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2025 HuggingFace Inc. team. All rights reserved.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from collections.abc import Iterable
+from typing import TYPE_CHECKING, Optional, Union
+
+import torch
+from torchvision.transforms.v2 import functional as F
+
+from ...image_processing_base import BatchFeature
+from ...image_processing_utils_fast import BaseImageProcessorFast, DefaultFastImageProcessorKwargs
+from ...image_transforms import group_images_by_shape, reorder_images
+from ...image_utils import (
+ IMAGENET_STANDARD_MEAN,
+ IMAGENET_STANDARD_STD,
+ ChannelDimension,
+ ImageInput,
+ PILImageResampling,
+ SizeDict,
+ is_torch_tensor,
+)
+from ...processing_utils import Unpack
+from ...utils import TensorType, auto_docstring, requires_backends
+
+
+if TYPE_CHECKING:
+ from ...modeling_outputs import DepthEstimatorOutput
+
+
+class DPTFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
+ """
+ ensure_multiple_of (`int`, *optional*, defaults to 1):
+ If `do_resize` is `True`, the image is resized to a size that is a multiple of this value. Can be overridden
+ by `ensure_multiple_of` in `preprocess`.
+ size_divisor (`int`, *optional*):
+ If `do_pad` is `True`, pads the image dimensions to be divisible by this value. This was introduced in the
+ DINOv2 paper, which uses the model in combination with DPT.
+ keep_aspect_ratio (`bool`, *optional*, defaults to `False`):
+ If `True`, the image is resized to the largest possible size such that the aspect ratio is preserved. Can
+ be overridden by `keep_aspect_ratio` in `preprocess`.
+ do_reduce_labels (`bool`, *optional*, defaults to `self.do_reduce_labels`):
+ Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0
+ is used for background, and background itself is not included in all classes of a dataset (e.g.
+ ADE20k). The background label will be replaced by 255.
+ """
+
+ ensure_multiple_of: Optional[int]
+ size_divisor: Optional[int]
+ keep_aspect_ratio: Optional[bool]
+ do_reduce_labels: Optional[bool]
+
+
+def get_resize_output_image_size(
+ input_image: "torch.Tensor",
+ output_size: Union[int, Iterable[int]],
+ keep_aspect_ratio: bool,
+ multiple: int,
+) -> SizeDict:
+ def constrain_to_multiple_of(val, multiple, min_val=0, max_val=None):
+ x = round(val / multiple) * multiple
+
+ if max_val is not None and x > max_val:
+ x = math.floor(val / multiple) * multiple
+
+ if x < min_val:
+ x = math.ceil(val / multiple) * multiple
+
+ return x
+
+ input_height, input_width = input_image.shape[-2:]
+ output_height, output_width = output_size
+
+ # determine new height and width
+ scale_height = output_height / input_height
+ scale_width = output_width / input_width
+
+ if keep_aspect_ratio:
+ # scale as little as possible
+ if abs(1 - scale_width) < abs(1 - scale_height):
+ # fit width
+ scale_height = scale_width
+ else:
+ # fit height
+ scale_width = scale_height
+
+ new_height = constrain_to_multiple_of(scale_height * input_height, multiple=multiple)
+ new_width = constrain_to_multiple_of(scale_width * input_width, multiple=multiple)
+
+ return SizeDict(height=new_height, width=new_width)
+
+
+@auto_docstring
+class DPTImageProcessorFast(BaseImageProcessorFast):
+ resample = PILImageResampling.BICUBIC
+ image_mean = IMAGENET_STANDARD_MEAN
+ image_std = IMAGENET_STANDARD_STD
+ size = {"height": 384, "width": 384}
+ default_to_square = True
+ crop_size = None
+ do_resize = True
+ do_center_crop = None
+ do_rescale = True
+ do_normalize = True
+ do_reduce_labels = None
+
+ valid_kwargs = DPTFastImageProcessorKwargs
+ do_pad = False
+ rescale_factor = 1 / 255
+ ensure_multiple_of = 1
+ keep_aspect_ratio = False
+
+ def __init__(self, **kwargs: Unpack[DPTFastImageProcessorKwargs]):
+ super().__init__(**kwargs)
+
+ def reduce_label(self, labels: list["torch.Tensor"]):
+ for idx in range(len(labels)):
+ label = labels[idx]
+ label = torch.where(label == 0, torch.tensor(255, dtype=label.dtype), label)
+ label = label - 1
+ label = torch.where(label == 254, torch.tensor(255, dtype=label.dtype), label)
+ labels[idx] = label
+
+ return labels
+
+ @auto_docstring
+ def preprocess(
+ self,
+ images: ImageInput,
+ segmentation_maps: Optional[ImageInput] = None,
+ **kwargs: Unpack[DPTFastImageProcessorKwargs],
+ ) -> BatchFeature:
+ r"""
+ segmentation_maps (`ImageInput`, *optional*):
+ The segmentation maps to preprocess.
+ """
+ return super().preprocess(images, segmentation_maps, **kwargs)
+
+ def _preprocess_image_like_inputs(
+ self,
+ images: ImageInput,
+ segmentation_maps: Optional[ImageInput],
+ do_convert_rgb: bool,
+ input_data_format: ChannelDimension,
+ device: Optional[Union[str, "torch.device"]] = None,
+ **kwargs: Unpack[DPTFastImageProcessorKwargs],
+ ) -> BatchFeature:
+ """
+ Preprocess image-like inputs.
+ """
+ images = self._prepare_image_like_inputs(
+ images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device
+ )
+ images_kwargs = kwargs.copy()
+ images_kwargs["do_reduce_labels"] = False
+ batch_feature = self._preprocess(images, **images_kwargs)
+
+ if segmentation_maps is not None:
+ processed_segmentation_maps = self._prepare_image_like_inputs(
+ images=segmentation_maps,
+ expected_ndims=2,
+ do_convert_rgb=False,
+ input_data_format=ChannelDimension.FIRST,
+ )
+
+ segmentation_maps_kwargs = kwargs.copy()
+ segmentation_maps_kwargs.update({"do_normalize": False, "do_rescale": False})
+ processed_segmentation_maps = self._preprocess(
+ images=processed_segmentation_maps, **segmentation_maps_kwargs
+ ).pixel_values
+ batch_feature["labels"] = processed_segmentation_maps.squeeze(1).to(torch.int64)
+
+ return batch_feature
+
+ def _preprocess(
+ self,
+ images: list["torch.Tensor"],
+ do_reduce_labels: bool,
+ do_resize: bool,
+ size: SizeDict,
+ interpolation: Optional["F.InterpolationMode"],
+ do_center_crop: bool,
+ crop_size: SizeDict,
+ do_rescale: bool,
+ rescale_factor: float,
+ do_normalize: bool,
+ image_mean: Optional[Union[float, list[float]]],
+ image_std: Optional[Union[float, list[float]]],
+ keep_aspect_ratio: bool,
+ ensure_multiple_of: Optional[int],
+ do_pad: bool,
+ size_divisor: Optional[int],
+ disable_grouping: Optional[bool],
+ return_tensors: Optional[Union[str, TensorType]],
+ **kwargs,
+ ) -> BatchFeature:
+ if do_reduce_labels:
+ images = self.reduce_label(images)
+
+ # Group images by size for batched resizing
+ grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
+ resized_images_grouped = {}
+ for shape, stacked_images in grouped_images.items():
+ if do_resize:
+ stacked_images = self.resize(
+ image=stacked_images,
+ size=size,
+ interpolation=interpolation,
+ ensure_multiple_of=ensure_multiple_of,
+ keep_aspect_ratio=keep_aspect_ratio,
+ )
+ resized_images_grouped[shape] = stacked_images
+ resized_images = reorder_images(resized_images_grouped, grouped_images_index)
+
+ # Group images by size for further processing
+ # Needed in case do_resize is False, or resize returns images with different sizes
+ grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
+ processed_images_grouped = {}
+ for shape, stacked_images in grouped_images.items():
+ if do_center_crop:
+ stacked_images = self.center_crop(stacked_images, crop_size)
+ if do_pad:
+ stacked_images = self.pad_image(stacked_images, size_divisor)
+ # Fused rescale and normalize
+ stacked_images = self.rescale_and_normalize(
+ stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
+ )
+ processed_images_grouped[shape] = stacked_images
+
+ processed_images = reorder_images(processed_images_grouped, grouped_images_index)
+ processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
+ return BatchFeature(data={"pixel_values": processed_images})
+
+ def post_process_semantic_segmentation(self, outputs, target_sizes: Optional[list[tuple]] = None):
+ """
+ Converts the output of [`DPTForSemanticSegmentation`] into semantic segmentation maps. Only supports PyTorch.
+
+ Args:
+ outputs ([`DPTForSemanticSegmentation`]):
+ Raw outputs of the model.
+ target_sizes (`list[Tuple]` of length `batch_size`, *optional*):
+ List of tuples corresponding to the requested final size (height, width) of each prediction. If unset,
+ predictions will not be resized.
+
+ Returns:
+ semantic_segmentation: `list[torch.Tensor]` of length `batch_size`, where each item is a semantic
+ segmentation map of shape (height, width) corresponding to the target_sizes entry (if `target_sizes` is
+ specified). Each entry of each `torch.Tensor` corresponds to a semantic class id.
+ """
+ # TODO: add support for other frameworks
+ logits = outputs.logits
+
+ # Resize logits and compute semantic segmentation maps
+ if target_sizes is not None:
+ if len(logits) != len(target_sizes):
+ raise ValueError(
+ "Make sure that you pass in as many target sizes as the batch dimension of the logits"
+ )
+
+ if is_torch_tensor(target_sizes):
+ target_sizes = target_sizes.numpy()
+
+ semantic_segmentation = []
+
+ for idx in range(len(logits)):
+ resized_logits = torch.nn.functional.interpolate(
+ logits[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False
+ )
+ semantic_map = resized_logits[0].argmax(dim=0)
+ semantic_segmentation.append(semantic_map)
+ else:
+ semantic_segmentation = logits.argmax(dim=1)
+ semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])]
+
+ return semantic_segmentation
+
+ def resize(
+ self,
+ image: "torch.Tensor",
+ size: SizeDict,
+ interpolation: Optional["F.InterpolationMode"] = None,
+ antialias: bool = True,
+ ensure_multiple_of: Optional[int] = 1,
+ keep_aspect_ratio: bool = False,
+ ) -> "torch.Tensor":
+ """
+ Resize an image to `(size["height"], size["width"])`.
+
+ Args:
+ image (`torch.Tensor`):
+ Image to resize.
+ size (`SizeDict`):
+ Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
+ interpolation (`InterpolationMode`, *optional*, defaults to `InterpolationMode.BILINEAR`):
+ `InterpolationMode` filter to use when resizing the image e.g. `InterpolationMode.BICUBIC`.
+ antialias (`bool`, *optional*, defaults to `True`):
+ Whether to use antialiasing when resizing the image
+ ensure_multiple_of (`int`, *optional*):
+ If `do_resize` is `True`, the image is resized to a size that is a multiple of this value
+ keep_aspect_ratio (`bool`, *optional*, defaults to `False`):
+ If `True`, and `do_resize` is `True`, the image is resized to the largest possible size such that the aspect ratio is preserved.
+
+ Returns:
+ `torch.Tensor`: The resized image.
+ """
+ if not size.height or not size.width:
+ raise ValueError(f"The size dictionary must contain the keys 'height' and 'width'. Got {size.keys()}")
+
+ output_size = get_resize_output_image_size(
+ image,
+ output_size=(size.height, size.width),
+ keep_aspect_ratio=keep_aspect_ratio,
+ multiple=ensure_multiple_of,
+ )
+ return super().resize(image, output_size, interpolation=interpolation, antialias=antialias)
+
+ def pad_image(
+ self,
+ image: "torch.Tensor",
+ size_divisor: int = 1,
+ ) -> "torch.Tensor":
+ r"""
+ Center pad a batch of images to be a multiple of `size_divisor`.
+
+ Args:
+ image (`torch.Tensor`):
+ Image to pad. Can be a batch of images of dimensions (N, C, H, W) or a single image of dimensions (C, H, W).
+ size_divisor (`int`):
+ The width and height of the image will be padded to a multiple of this number.
+ """
+ height, width = image.shape[-2:]
+
+ def _get_pad(size, size_divisor):
+ new_size = math.ceil(size / size_divisor) * size_divisor
+ pad_size = new_size - size
+ pad_size_left = pad_size // 2
+ pad_size_right = pad_size - pad_size_left
+ return pad_size_left, pad_size_right
+
+ pad_top, pad_bottom = _get_pad(height, size_divisor)
+ pad_left, pad_right = _get_pad(width, size_divisor)
+ padding = (pad_left, pad_top, pad_right, pad_bottom)  # torchvision pad order: (left, top, right, bottom)
+ return F.pad(image, padding)
+
+ def post_process_depth_estimation(
+ self,
+ outputs: "DepthEstimatorOutput",
+ target_sizes: Optional[Union[TensorType, list[tuple[int, int]], None]] = None,
+ ) -> list[dict[str, TensorType]]:
+ """
+ Converts the raw output of [`DepthEstimatorOutput`] into final depth predictions and depth PIL images.
+ Only supports PyTorch.
+
+ Args:
+ outputs ([`DepthEstimatorOutput`]):
+ Raw outputs of the model.
+ target_sizes (`TensorType` or `list[tuple[int, int]]`, *optional*):
+ Tensor of shape `(batch_size, 2)` or list of tuples (`tuple[int, int]`) containing the target size
+ (height, width) of each image in the batch. If left to None, predictions will not be resized.
+
+ Returns:
+ `list[dict[str, TensorType]]`: A list of dictionaries of tensors representing the processed depth
+ predictions.
+ """
+ requires_backends(self, "torch")
+
+ predicted_depth = outputs.predicted_depth
+
+ if (target_sizes is not None) and (len(predicted_depth) != len(target_sizes)):
+ raise ValueError(
+ "Make sure that you pass in as many target sizes as the batch dimension of the predicted depth"
+ )
+
+ results = []
+ target_sizes = [None] * len(predicted_depth) if target_sizes is None else target_sizes
+ for depth, target_size in zip(predicted_depth, target_sizes):
+ if target_size is not None:
+ depth = torch.nn.functional.interpolate(
+ depth.unsqueeze(0).unsqueeze(1), size=target_size, mode="bicubic", align_corners=False
+ ).squeeze()
+
+ results.append({"predicted_depth": depth})
+
+ return results
+
+
+__all__ = ["DPTImageProcessorFast"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/dpt/modeling_dpt.py b/venv/lib/python3.13/site-packages/transformers/models/dpt/modeling_dpt.py
new file mode 100644
index 0000000000000000000000000000000000000000..49e1bbcd9f0fc8823a74c64a8e7d9c578133fbaa
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/dpt/modeling_dpt.py
@@ -0,0 +1,1225 @@
+# coding=utf-8
+# Copyright 2022 Intel Labs, OpenMMLab and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch DPT (Dense Prediction Transformers) model.
+
+This implementation is heavily inspired by OpenMMLab's implementation, found here:
+https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/models/decode_heads/dpt_head.py.
+
+"""
+
+import collections.abc
+from dataclasses import dataclass
+from typing import Callable, Optional
+
+import torch
+from torch import nn
+from torch.nn import CrossEntropyLoss
+
+from ...activations import ACT2FN
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import BaseModelOutput, DepthEstimatorOutput, SemanticSegmenterOutput
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
+from ...utils import ModelOutput, auto_docstring, logging, torch_int
+from ...utils.backbone_utils import load_backbone
+from ...utils.generic import can_return_tuple, check_model_inputs
+from .configuration_dpt import DPTConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for model's outputs that also contains intermediate activations that can be used at later stages. Useful
+ in the context of Vision models.:
+ """
+)
+class BaseModelOutputWithIntermediateActivations(ModelOutput):
+ r"""
+ last_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+ intermediate_activations (`tuple(torch.FloatTensor)`, *optional*):
+ Intermediate activations that can be used to compute hidden states of the model at various layers.
+ """
+
+ last_hidden_states: Optional[torch.FloatTensor] = None
+ intermediate_activations: Optional[tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for model's outputs that also contains a pooling of the last hidden states as well as intermediate
+ activations that can be used by the model at later stages.
+ """
+)
+class BaseModelOutputWithPoolingAndIntermediateActivations(ModelOutput):
+ r"""
+ pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
+ Last layer hidden-state of the first token of the sequence (classification token) after further processing
+ through the layers used for the auxiliary pretraining task. E.g. for BERT-family of models, this returns
+ the classification token after processing through a linear layer and a tanh activation function. The linear
+ layer weights are trained from the next sentence prediction (classification) objective during pretraining.
+ intermediate_activations (`tuple(torch.FloatTensor)`, *optional*):
+ Intermediate activations that can be used to compute hidden states of the model at various layers.
+ """
+
+ last_hidden_state: Optional[torch.FloatTensor] = None
+ pooler_output: Optional[torch.FloatTensor] = None
+ hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+ intermediate_activations: Optional[tuple[torch.FloatTensor, ...]] = None
+
+
+class DPTViTHybridEmbeddings(nn.Module):
+ """
+ This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
+ `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
+ Transformer.
+ """
+
+ def __init__(self, config: DPTConfig, feature_size: Optional[tuple[int, int]] = None):
+ super().__init__()
+ image_size, patch_size = config.image_size, config.patch_size
+ num_channels, hidden_size = config.num_channels, config.hidden_size
+
+ image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
+ patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
+ num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
+
+ self.backbone = load_backbone(config)
+ feature_dim = self.backbone.channels[-1]
+ if len(self.backbone.channels) != 3:
+ raise ValueError(f"Expected backbone to have 3 output features, got {len(self.backbone.channels)}")
+ self.residual_feature_map_index = [0, 1] # Always take the output of the first and second backbone stage
+
+ if feature_size is None:
+ feat_map_shape = config.backbone_featmap_shape
+ feature_size = feat_map_shape[-2:]
+ feature_dim = feat_map_shape[1]
+ else:
+ feature_size = (
+ feature_size if isinstance(feature_size, collections.abc.Iterable) else (feature_size, feature_size)
+ )
+ feature_dim = self.backbone.channels[-1]
+
+ self.image_size = image_size
+ self.patch_size = patch_size[0]
+ self.num_channels = num_channels
+
+ self.projection = nn.Conv2d(feature_dim, hidden_size, kernel_size=1)
+
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
+ self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
+
+ def _resize_pos_embed(self, posemb, grid_size_height, grid_size_width, start_index=1):
+ posemb_tok = posemb[:, :start_index]
+ posemb_grid = posemb[0, start_index:]
+
+ old_grid_size = torch_int(len(posemb_grid) ** 0.5)
+
+ posemb_grid = posemb_grid.reshape(1, old_grid_size, old_grid_size, -1).permute(0, 3, 1, 2)
+ posemb_grid = nn.functional.interpolate(posemb_grid, size=(grid_size_height, grid_size_width), mode="bilinear")
+ posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, grid_size_height * grid_size_width, -1)
+
+ posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
+
+ return posemb
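+
+    # Example of the resizing above (illustrative, not executed): a position embedding of shape (1, 577, 768)
+    # learned for a 24x24 patch grid, resized for a 32x32 grid, keeps the CLS entry and bilinearly interpolates
+    # the remaining 576 grid entries to 1024, yielding a tensor of shape (1, 1025, 768).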
+
+ def forward(
+ self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False
+ ) -> BaseModelOutputWithIntermediateActivations:
+ batch_size, num_channels, height, width = pixel_values.shape
+ if num_channels != self.num_channels:
+ raise ValueError(
+ "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
+ )
+ if not interpolate_pos_encoding:
+ if height != self.image_size[0] or width != self.image_size[1]:
+ raise ValueError(
+ f"Input image size ({height}*{width}) doesn't match model"
+ f" ({self.image_size[0]}*{self.image_size[1]})."
+ )
+
+ position_embeddings = self._resize_pos_embed(
+ self.position_embeddings, height // self.patch_size, width // self.patch_size
+ )
+
+ backbone_output = self.backbone(pixel_values)
+
+ features = backbone_output.feature_maps[-1]
+
+ # Retrieve also the intermediate activations to use them at later stages
+ output_hidden_states = [backbone_output.feature_maps[index] for index in self.residual_feature_map_index]
+
+ embeddings = self.projection(features).flatten(2).transpose(1, 2)
+
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1)
+ embeddings = torch.cat((cls_tokens, embeddings), dim=1)
+
+ # add positional encoding to each token
+ embeddings = embeddings + position_embeddings
+
+ # Return hidden states and intermediate activations
+ return BaseModelOutputWithIntermediateActivations(
+ last_hidden_states=embeddings,
+ intermediate_activations=output_hidden_states,
+ )
+
+
+class DPTViTEmbeddings(nn.Module):
+ """
+ Construct the CLS token, position and patch embeddings.
+
+ """
+
+ def __init__(self, config):
+ super().__init__()
+
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
+ self.patch_embeddings = DPTViTPatchEmbeddings(config)
+ num_patches = self.patch_embeddings.num_patches
+ self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+ self.config = config
+
+ def _resize_pos_embed(self, posemb, grid_size_height, grid_size_width, start_index=1):
+ posemb_tok = posemb[:, :start_index]
+ posemb_grid = posemb[0, start_index:]
+
+ old_grid_size = torch_int(posemb_grid.size(0) ** 0.5)
+
+ posemb_grid = posemb_grid.reshape(1, old_grid_size, old_grid_size, -1).permute(0, 3, 1, 2)
+ posemb_grid = nn.functional.interpolate(posemb_grid, size=(grid_size_height, grid_size_width), mode="bilinear")
+ posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, grid_size_height * grid_size_width, -1)
+
+ posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
+
+ return posemb
+
+ def forward(self, pixel_values: torch.Tensor) -> BaseModelOutputWithIntermediateActivations:
+ batch_size, num_channels, height, width = pixel_values.shape
+
+ # possibly interpolate position encodings to handle varying image sizes
+ patch_size = self.config.patch_size
+ position_embeddings = self._resize_pos_embed(
+ self.position_embeddings, height // patch_size, width // patch_size
+ )
+
+ embeddings = self.patch_embeddings(pixel_values)
+
+ batch_size, seq_len, _ = embeddings.size()
+
+ # add the [CLS] token to the embedded patch tokens
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1)
+ embeddings = torch.cat((cls_tokens, embeddings), dim=1)
+
+ # add positional encoding to each token
+ embeddings = embeddings + position_embeddings
+
+ embeddings = self.dropout(embeddings)
+
+ return BaseModelOutputWithIntermediateActivations(last_hidden_states=embeddings)
+
+
+class DPTViTPatchEmbeddings(nn.Module):
+ """
+ Image to Patch Embedding.
+
+ """
+
+ def __init__(self, config: DPTConfig):
+ super().__init__()
+ image_size, patch_size = config.image_size, config.patch_size
+ num_channels, hidden_size = config.num_channels, config.hidden_size
+
+ image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
+ patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
+ num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.num_channels = num_channels
+ self.num_patches = num_patches
+
+ self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
+
+ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
+ batch_size, num_channels, height, width = pixel_values.shape
+ if num_channels != self.num_channels:
+ raise ValueError(
+ "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
+ )
+ embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
+ return embeddings
+
+
+# Copied from transformers.models.vit.modeling_vit.eager_attention_forward
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: float,
+ dropout: float = 0.0,
+ **kwargs,
+):
+ # Take the dot product between "query" and "key" to get the raw attention scores.
+ attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
+
+ # Normalize the attention scores to probabilities.
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+
+ # Mask heads if we want to
+ if attention_mask is not None:
+ attn_weights = attn_weights * attention_mask
+
+ attn_output = torch.matmul(attn_weights, value)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
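+
+# Shape note for the attention helper above (illustrative): `query`, `key` and `value` arrive as
+# (batch, num_heads, seq_len, head_dim); `attn_weights` is (batch, num_heads, seq_len, seq_len) and the returned
+# `attn_output` is transposed back to (batch, seq_len, num_heads, head_dim) so the caller can merge the last two
+# axes into `all_head_size`.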
+
+
+# Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->DPT
+class DPTSelfAttention(nn.Module):
+ def __init__(self, config: DPTConfig):
+ super().__init__()
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
+ raise ValueError(
+ f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
+ f"heads {config.num_attention_heads}."
+ )
+
+ self.config = config
+ self.num_attention_heads = config.num_attention_heads
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
+ self.dropout_prob = config.attention_probs_dropout_prob
+ self.scaling = self.attention_head_size**-0.5
+ self.is_causal = False
+
+ self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
+ self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
+ self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
+
+ def forward(
+ self, hidden_states: torch.Tensor, head_mask: Optional[torch.Tensor] = None
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ batch_size = hidden_states.shape[0]
+ new_shape = batch_size, -1, self.num_attention_heads, self.attention_head_size
+
+ key_layer = self.key(hidden_states).view(*new_shape).transpose(1, 2)
+ value_layer = self.value(hidden_states).view(*new_shape).transpose(1, 2)
+ query_layer = self.query(hidden_states).view(*new_shape).transpose(1, 2)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ context_layer, attention_probs = attention_interface(
+ self,
+ query_layer,
+ key_layer,
+ value_layer,
+ head_mask,
+ is_causal=self.is_causal,
+ scaling=self.scaling,
+ dropout=0.0 if not self.training else self.dropout_prob,
+ )
+
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+ context_layer = context_layer.reshape(new_context_layer_shape)
+
+ return context_layer, attention_probs
+
+
+# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViTConfig->DPTConfig, ViTSelfOutput->DPTViTSelfOutput
+class DPTViTSelfOutput(nn.Module):
+ """
+ The residual connection is defined in ViTLayer instead of here (as is the case with other models), due to the
+ layernorm applied before each block.
+ """
+
+ def __init__(self, config: DPTConfig):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ return hidden_states
+
+
+# Copied from transformers.models.vit.modeling_vit.ViTAttention with ViTConfig->DPTConfig, ViTSelfAttention->DPTSelfAttention, ViTSelfOutput->DPTViTSelfOutput
+class DPTViTAttention(nn.Module):
+ def __init__(self, config: DPTConfig):
+ super().__init__()
+ self.attention = DPTSelfAttention(config)
+ self.output = DPTViTSelfOutput(config)
+ self.pruned_heads = set()
+
+ def prune_heads(self, heads: set[int]):
+ if len(heads) == 0:
+ return
+ heads, index = find_pruneable_heads_and_indices(
+ heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
+ )
+
+ # Prune linear layers
+ self.attention.query = prune_linear_layer(self.attention.query, index)
+ self.attention.key = prune_linear_layer(self.attention.key, index)
+ self.attention.value = prune_linear_layer(self.attention.value, index)
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
+
+ # Update hyper params and store pruned heads
+ self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
+ self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
+ self.pruned_heads = self.pruned_heads.union(heads)
+
+ def forward(self, hidden_states: torch.Tensor, head_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+ self_attn_output, _ = self.attention(hidden_states, head_mask)
+ output = self.output(self_attn_output, hidden_states)
+ return output
+
+
+# Copied from transformers.models.vit.modeling_vit.ViTIntermediate with ViTConfig->DPTConfig, ViTIntermediate->DPTViTIntermediate
+class DPTViTIntermediate(nn.Module):
+ def __init__(self, config: DPTConfig):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
+ if isinstance(config.hidden_act, str):
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.intermediate_act_fn = config.hidden_act
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.intermediate_act_fn(hidden_states)
+ return hidden_states
+
+
+# Copied from transformers.models.vit.modeling_vit.ViTOutput with ViTConfig->DPTConfig, ViTOutput->DPTViTOutput
+class DPTViTOutput(nn.Module):
+ def __init__(self, config: DPTConfig):
+ super().__init__()
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = hidden_states + input_tensor
+ return hidden_states
+
+
+# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViTConfig->DPTConfig, ViTAttention->DPTViTAttention, ViTIntermediate->DPTViTIntermediate, ViTOutput->DPTViTOutput, ViTLayer->DPTViTLayer
+class DPTViTLayer(GradientCheckpointingLayer):
+ """This corresponds to the Block class in the timm implementation."""
+
+ def __init__(self, config: DPTConfig):
+ super().__init__()
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
+ self.seq_len_dim = 1
+ self.attention = DPTViTAttention(config)
+ self.intermediate = DPTViTIntermediate(config)
+ self.output = DPTViTOutput(config)
+ self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+ def forward(self, hidden_states: torch.Tensor, head_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+ hidden_states_norm = self.layernorm_before(hidden_states)
+ attention_output = self.attention(hidden_states_norm, head_mask)
+
+ # first residual connection
+ hidden_states = attention_output + hidden_states
+
+ # in ViT, layernorm is also applied after self-attention
+ layer_output = self.layernorm_after(hidden_states)
+ layer_output = self.intermediate(layer_output)
+
+ # second residual connection is done here
+ layer_output = self.output(layer_output, hidden_states)
+
+ return layer_output
+
+
+# Copied from transformers.models.dinov2.modeling_dinov2.Dinov2Encoder with Dinov2Config->DPTConfig, Dinov2->DPTViT
+class DPTViTEncoder(nn.Module):
+ def __init__(self, config: DPTConfig):
+ super().__init__()
+ self.config = config
+ self.layer = nn.ModuleList([DPTViTLayer(config) for _ in range(config.num_hidden_layers)])
+ self.gradient_checkpointing = False
+
+ def forward(
+ self, hidden_states: torch.Tensor, head_mask: Optional[torch.Tensor] = None, output_hidden_states: bool = False
+ ) -> BaseModelOutput:
+ all_hidden_states = [hidden_states] if output_hidden_states else None
+ for i, layer_module in enumerate(self.layer):
+ layer_head_mask = head_mask[i] if head_mask is not None else None
+ hidden_states = layer_module(hidden_states, layer_head_mask)
+ if all_hidden_states:
+ all_hidden_states.append(hidden_states)
+
+ return BaseModelOutput(
+ last_hidden_state=hidden_states,
+ hidden_states=tuple(all_hidden_states) if all_hidden_states else None,
+ )
+
+
+class DPTReassembleStage(nn.Module):
+ """
+ This class reassembles the hidden states of the backbone into image-like feature representations at various
+ resolutions.
+
+ This happens in 3 stages:
+ 1. Map the N + 1 tokens to a set of N tokens, by taking into account the readout ([CLS]) token according to
+ `config.readout_type`.
+ 2. Project the channel dimension of the hidden states according to `config.neck_hidden_sizes`.
+    3. Resize the spatial dimensions (height, width).
+
+ Args:
+ config (`[DPTConfig]`):
+ Model configuration class defining the model architecture.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+
+ self.config = config
+ self.layers = nn.ModuleList()
+ if config.is_hybrid:
+ self._init_reassemble_dpt_hybrid(config)
+ else:
+ self._init_reassemble_dpt(config)
+
+ self.neck_ignore_stages = config.neck_ignore_stages
+
+ def _init_reassemble_dpt_hybrid(self, config):
+ r""" "
+ For DPT-Hybrid the first 2 reassemble layers are set to `nn.Identity()`, please check the official
+ implementation: https://github.com/isl-org/DPT/blob/f43ef9e08d70a752195028a51be5e1aff227b913/dpt/vit.py#L438
+ for more details.
+ """
+ for i, factor in zip(range(len(config.neck_hidden_sizes)), config.reassemble_factors):
+ if i <= 1:
+ self.layers.append(nn.Identity())
+ elif i > 1:
+ self.layers.append(DPTReassembleLayer(config, channels=config.neck_hidden_sizes[i], factor=factor))
+
+ if config.readout_type != "project":
+ raise ValueError(f"Readout type {config.readout_type} is not supported for DPT-Hybrid.")
+
+        # When using DPT-Hybrid the readout type is always "project"; the check above enforces this.
+ self.readout_projects = nn.ModuleList()
+ hidden_size = _get_backbone_hidden_size(config)
+ for i in range(len(config.neck_hidden_sizes)):
+ if i <= 1:
+ self.readout_projects.append(nn.Sequential(nn.Identity()))
+ elif i > 1:
+ self.readout_projects.append(
+ nn.Sequential(nn.Linear(2 * hidden_size, hidden_size), ACT2FN[config.hidden_act])
+ )
+
+ def _init_reassemble_dpt(self, config):
+ for i, factor in zip(range(len(config.neck_hidden_sizes)), config.reassemble_factors):
+ self.layers.append(DPTReassembleLayer(config, channels=config.neck_hidden_sizes[i], factor=factor))
+
+ if config.readout_type == "project":
+ self.readout_projects = nn.ModuleList()
+ hidden_size = _get_backbone_hidden_size(config)
+ for _ in range(len(config.neck_hidden_sizes)):
+ self.readout_projects.append(
+ nn.Sequential(nn.Linear(2 * hidden_size, hidden_size), ACT2FN[config.hidden_act])
+ )
+
+ def forward(self, hidden_states: list[torch.Tensor], patch_height=None, patch_width=None) -> list[torch.Tensor]:
+ """
+ Args:
+ hidden_states (`list[torch.FloatTensor]`, each of shape `(batch_size, sequence_length + 1, hidden_size)`):
+ List of hidden states from the backbone.
+ """
+ out = []
+
+ for i, hidden_state in enumerate(hidden_states):
+ if i not in self.neck_ignore_stages:
+ # reshape to (batch_size, num_channels, height, width)
+ cls_token, hidden_state = hidden_state[:, 0], hidden_state[:, 1:]
+ batch_size, sequence_length, num_channels = hidden_state.shape
+ if patch_height is not None and patch_width is not None:
+ hidden_state = hidden_state.reshape(batch_size, patch_height, patch_width, num_channels)
+ else:
+ size = torch_int(sequence_length**0.5)
+ hidden_state = hidden_state.reshape(batch_size, size, size, num_channels)
+ hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
+
+ feature_shape = hidden_state.shape
+ if self.config.readout_type == "project":
+ # reshape to (batch_size, height*width, num_channels)
+ hidden_state = hidden_state.flatten(2).permute((0, 2, 1))
+ readout = cls_token.unsqueeze(1).expand_as(hidden_state)
+ # concatenate the readout token to the hidden states and project
+ hidden_state = self.readout_projects[i](torch.cat((hidden_state, readout), -1))
+ # reshape back to (batch_size, num_channels, height, width)
+ hidden_state = hidden_state.permute(0, 2, 1).reshape(feature_shape)
+ elif self.config.readout_type == "add":
+ hidden_state = hidden_state.flatten(2) + cls_token.unsqueeze(-1)
+ hidden_state = hidden_state.reshape(feature_shape)
+ hidden_state = self.layers[i](hidden_state)
+ out.append(hidden_state)
+
+ return out
+
+
+def _get_backbone_hidden_size(config):
+ if config.backbone_config is not None and config.is_hybrid is False:
+ return config.backbone_config.hidden_size
+ else:
+ return config.hidden_size
+
+
+class DPTReassembleLayer(nn.Module):
+ def __init__(self, config: DPTConfig, channels: int, factor: int):
+ super().__init__()
+ # projection
+ hidden_size = _get_backbone_hidden_size(config)
+ self.projection = nn.Conv2d(in_channels=hidden_size, out_channels=channels, kernel_size=1)
+
+ # up/down sampling depending on factor
+ if factor > 1:
+ self.resize = nn.ConvTranspose2d(channels, channels, kernel_size=factor, stride=factor, padding=0)
+ elif factor == 1:
+ self.resize = nn.Identity()
+ elif factor < 1:
+ # so should downsample
+ self.resize = nn.Conv2d(channels, channels, kernel_size=3, stride=int(1 / factor), padding=1)
+
+ def forward(self, hidden_state):
+ hidden_state = self.projection(hidden_state)
+ hidden_state = self.resize(hidden_state)
+ return hidden_state
+
+
+class DPTFeatureFusionStage(nn.Module):
+ def __init__(self, config: DPTConfig):
+ super().__init__()
+ self.layers = nn.ModuleList()
+ for _ in range(len(config.neck_hidden_sizes)):
+ self.layers.append(DPTFeatureFusionLayer(config))
+
+ def forward(self, hidden_states):
+ # reversing the hidden_states, we start from the last
+ hidden_states = hidden_states[::-1]
+
+ fused_hidden_states = []
+ fused_hidden_state = None
+ for hidden_state, layer in zip(hidden_states, self.layers):
+ if fused_hidden_state is None:
+ # first layer only uses the last hidden_state
+ fused_hidden_state = layer(hidden_state)
+ else:
+ fused_hidden_state = layer(fused_hidden_state, hidden_state)
+ fused_hidden_states.append(fused_hidden_state)
+
+ return fused_hidden_states
+
+
+class DPTPreActResidualLayer(nn.Module):
+ """
+ ResidualConvUnit, pre-activate residual unit.
+
+ Args:
+ config (`[DPTConfig]`):
+ Model configuration class defining the model architecture.
+ """
+
+ def __init__(self, config: DPTConfig):
+ super().__init__()
+
+ self.use_batch_norm = config.use_batch_norm_in_fusion_residual
+ use_bias_in_fusion_residual = (
+ config.use_bias_in_fusion_residual
+ if config.use_bias_in_fusion_residual is not None
+ else not self.use_batch_norm
+ )
+
+ self.activation1 = nn.ReLU()
+ self.convolution1 = nn.Conv2d(
+ config.fusion_hidden_size,
+ config.fusion_hidden_size,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=use_bias_in_fusion_residual,
+ )
+
+ self.activation2 = nn.ReLU()
+ self.convolution2 = nn.Conv2d(
+ config.fusion_hidden_size,
+ config.fusion_hidden_size,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=use_bias_in_fusion_residual,
+ )
+
+ if self.use_batch_norm:
+ self.batch_norm1 = nn.BatchNorm2d(config.fusion_hidden_size)
+ self.batch_norm2 = nn.BatchNorm2d(config.fusion_hidden_size)
+
+ def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
+ residual = hidden_state
+ hidden_state = self.activation1(hidden_state)
+
+ hidden_state = self.convolution1(hidden_state)
+
+ if self.use_batch_norm:
+ hidden_state = self.batch_norm1(hidden_state)
+
+ hidden_state = self.activation2(hidden_state)
+ hidden_state = self.convolution2(hidden_state)
+
+ if self.use_batch_norm:
+ hidden_state = self.batch_norm2(hidden_state)
+
+ return hidden_state + residual
+
+
+class DPTFeatureFusionLayer(nn.Module):
+ """Feature fusion layer, merges feature maps from different stages.
+
+ Args:
+ config (`[DPTConfig]`):
+ Model configuration class defining the model architecture.
+ align_corners (`bool`, *optional*, defaults to `True`):
+ The align_corner setting for bilinear upsample.
+ """
+
+ def __init__(self, config: DPTConfig, align_corners: bool = True):
+ super().__init__()
+
+ self.align_corners = align_corners
+
+ self.projection = nn.Conv2d(config.fusion_hidden_size, config.fusion_hidden_size, kernel_size=1, bias=True)
+
+ self.residual_layer1 = DPTPreActResidualLayer(config)
+ self.residual_layer2 = DPTPreActResidualLayer(config)
+
+ def forward(self, hidden_state: torch.Tensor, residual: Optional[torch.Tensor] = None) -> torch.Tensor:
+ if residual is not None:
+ if hidden_state.shape != residual.shape:
+ residual = nn.functional.interpolate(
+ residual, size=(hidden_state.shape[2], hidden_state.shape[3]), mode="bilinear", align_corners=False
+ )
+ hidden_state = hidden_state + self.residual_layer1(residual)
+
+ hidden_state = self.residual_layer2(hidden_state)
+ hidden_state = nn.functional.interpolate(
+ hidden_state, scale_factor=2, mode="bilinear", align_corners=self.align_corners
+ )
+ hidden_state = self.projection(hidden_state)
+
+ return hidden_state
+
+
+@auto_docstring
+class DPTPreTrainedModel(PreTrainedModel):
+ config: DPTConfig
+ base_model_prefix = "dpt"
+ main_input_name = "pixel_values"
+ supports_gradient_checkpointing = True
+ _supports_sdpa = True
+ _supports_flash_attn = True
+ _supports_flex_attn = True
+ _supports_attention_backend = True
+ _can_record_outputs = {
+ "attentions": DPTSelfAttention,
+ }
+
+ def _init_weights(self, module):
+ """Initialize the weights"""
+ if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
+ # Slightly different from the TF version which uses truncated_normal for initialization
+ # cf https://github.com/pytorch/pytorch/pull/5617
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d)):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+ if isinstance(module, (DPTViTEmbeddings, DPTViTHybridEmbeddings)):
+ module.cls_token.data.zero_()
+ module.position_embeddings.data.zero_()
+
+
+@auto_docstring
+class DPTModel(DPTPreTrainedModel):
+ def __init__(self, config: DPTConfig, add_pooling_layer: bool = True):
+ r"""
+ add_pooling_layer (bool, *optional*, defaults to `True`):
+ Whether to add a pooling layer
+ """
+ super().__init__(config)
+ self.config = config
+
+ # vit encoder
+ if config.is_hybrid:
+ self.embeddings = DPTViTHybridEmbeddings(config)
+ else:
+ self.embeddings = DPTViTEmbeddings(config)
+ self.encoder = DPTViTEncoder(config)
+
+ self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.pooler = DPTViTPooler(config) if add_pooling_layer else None
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ if self.config.is_hybrid:
+ return self.embeddings
+ else:
+ return self.embeddings.patch_embeddings
+
+ def _prune_heads(self, heads_to_prune):
+ """
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
+ class PreTrainedModel
+ """
+ for layer, heads in heads_to_prune.items():
+ self.encoder.layer[layer].attention.prune_heads(heads)
+
+ @check_model_inputs(tie_last_hidden_states=False)
+ @auto_docstring
+ def forward(
+ self,
+ pixel_values: torch.FloatTensor,
+ head_mask: Optional[torch.FloatTensor] = None,
+ output_hidden_states: Optional[bool] = None,
+ **kwargs,
+ ) -> BaseModelOutputWithPoolingAndIntermediateActivations:
+ if output_hidden_states is None:
+ output_hidden_states = self.config.output_hidden_states
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+ embedding_output: BaseModelOutputWithIntermediateActivations = self.embeddings(pixel_values)
+ embedding_last_hidden_states = embedding_output.last_hidden_states
+
+ encoder_outputs: BaseModelOutput = self.encoder(
+ embedding_last_hidden_states, head_mask=head_mask, output_hidden_states=output_hidden_states
+ )
+ sequence_output = encoder_outputs.last_hidden_state
+
+ sequence_output = self.layernorm(sequence_output)
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
+
+ return BaseModelOutputWithPoolingAndIntermediateActivations(
+ last_hidden_state=sequence_output,
+ pooler_output=pooled_output,
+ intermediate_activations=embedding_output.intermediate_activations,
+ hidden_states=encoder_outputs.hidden_states,
+ )
+
+
+# Copied from transformers.models.vit.modeling_vit.ViTPooler with ViTConfig->DPTConfig, ViTPooler->DPTViTPooler
+class DPTViTPooler(nn.Module):
+ def __init__(self, config: DPTConfig):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.pooler_output_size)
+ self.activation = ACT2FN[config.pooler_act]
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ # We "pool" the model by simply taking the hidden state corresponding
+ # to the first token.
+ first_token_tensor = hidden_states[:, 0]
+ pooled_output = self.dense(first_token_tensor)
+ pooled_output = self.activation(pooled_output)
+ return pooled_output
+
+
+class DPTNeck(nn.Module):
+ """
+ DPTNeck. A neck is a module that is normally used between the backbone and the head. It takes a list of tensors as
+ input and produces another list of tensors as output. For DPT, it includes 2 stages:
+
+ * DPTReassembleStage
+ * DPTFeatureFusionStage.
+
+ Args:
+        config (`[DPTConfig]`):
+            Model configuration class defining the model architecture.
+ """
+
+ def __init__(self, config: DPTConfig):
+ super().__init__()
+ self.config = config
+
+        # postprocessing: reassembly is only required for non-hierarchical backbones (e.g. ViT, BEiT);
+        # hierarchical backbones such as Swinv2 already provide image-like feature maps
+ if config.backbone_config is not None and config.backbone_config.model_type == "swinv2":
+ self.reassemble_stage = None
+ else:
+ self.reassemble_stage = DPTReassembleStage(config)
+
+ self.convs = nn.ModuleList()
+ for channel in config.neck_hidden_sizes:
+ self.convs.append(nn.Conv2d(channel, config.fusion_hidden_size, kernel_size=3, padding=1, bias=False))
+
+ # fusion
+ self.fusion_stage = DPTFeatureFusionStage(config)
+
+ def forward(
+ self,
+ hidden_states: list[torch.Tensor],
+ patch_height: Optional[int] = None,
+ patch_width: Optional[int] = None,
+ ) -> list[torch.Tensor]:
+ """
+ Args:
+ hidden_states (`list[torch.FloatTensor]`, each of shape `(batch_size, sequence_length, hidden_size)` or `(batch_size, hidden_size, height, width)`):
+ List of hidden states from the backbone.
+ """
+ if not isinstance(hidden_states, (tuple, list)):
+ raise TypeError("hidden_states should be a tuple or list of tensors")
+
+ if len(hidden_states) != len(self.config.neck_hidden_sizes):
+ raise ValueError("The number of hidden states should be equal to the number of neck hidden sizes.")
+
+ # postprocess hidden states
+ if self.reassemble_stage is not None:
+ hidden_states = self.reassemble_stage(hidden_states, patch_height, patch_width)
+
+ features = [self.convs[i](feature) for i, feature in enumerate(hidden_states)]
+
+ # fusion blocks
+ output = self.fusion_stage(features)
+
+ return output
+
+
+class DPTDepthEstimationHead(nn.Module):
+ """
+ Output head consisting of 3 convolutional layers. It progressively halves the feature dimension and upsamples
+ the predictions to the input resolution after the first convolutional layer (details can be found in the paper's
+ supplementary material).
+ """
+
+ def __init__(self, config: DPTConfig):
+ super().__init__()
+
+ self.config = config
+
+ self.projection = None
+ if config.add_projection:
+ self.projection = nn.Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
+
+ features = config.fusion_hidden_size
+ self.head = nn.Sequential(
+ nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
+ nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
+ nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
+ nn.ReLU(),
+ nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
+ nn.ReLU(),
+ )
+
+ def forward(self, hidden_states: list[torch.Tensor]) -> torch.Tensor:
+ # use last features
+ hidden_states = hidden_states[self.config.head_in_index]
+
+ if self.projection is not None:
+ hidden_states = self.projection(hidden_states)
+ hidden_states = nn.ReLU()(hidden_states)
+
+ predicted_depth = self.head(hidden_states)
+ predicted_depth = predicted_depth.squeeze(dim=1)
+
+ return predicted_depth
+
+
+@auto_docstring(
+ custom_intro="""
+ DPT Model with a depth estimation head on top (consisting of 3 convolutional layers) e.g. for KITTI, NYUv2.
+ """
+)
+class DPTForDepthEstimation(DPTPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.backbone = None
+ if config.is_hybrid is False and (config.backbone_config is not None or config.backbone is not None):
+ self.backbone = load_backbone(config)
+ else:
+ self.dpt = DPTModel(config, add_pooling_layer=False)
+
+ # Neck
+ self.neck = DPTNeck(config)
+
+ # Depth estimation head
+ self.head = DPTDepthEstimationHead(config)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ pixel_values: torch.FloatTensor,
+ head_mask: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ output_hidden_states: Optional[bool] = None,
+ **kwargs,
+ ) -> DepthEstimatorOutput:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
+ Ground truth depth estimation maps for computing the loss.
+
+ Examples:
+ ```python
+ >>> from transformers import AutoImageProcessor, DPTForDepthEstimation
+ >>> import torch
+ >>> import numpy as np
+ >>> from PIL import Image
+ >>> import requests
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> image_processor = AutoImageProcessor.from_pretrained("Intel/dpt-large")
+ >>> model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large")
+
+ >>> # prepare image for the model
+ >>> inputs = image_processor(images=image, return_tensors="pt")
+
+ >>> with torch.no_grad():
+ ... outputs = model(**inputs)
+
+ >>> # interpolate to original size
+ >>> post_processed_output = image_processor.post_process_depth_estimation(
+ ... outputs,
+ ... target_sizes=[(image.height, image.width)],
+ ... )
+
+ >>> # visualize the prediction
+ >>> predicted_depth = post_processed_output[0]["predicted_depth"]
+ >>> depth = predicted_depth * 255 / predicted_depth.max()
+ >>> depth = depth.detach().cpu().numpy()
+ >>> depth = Image.fromarray(depth.astype("uint8"))
+ ```"""
+
+ if output_hidden_states is None:
+ output_hidden_states = self.config.output_hidden_states
+
+ loss = None
+ if labels is not None:
+ raise NotImplementedError("Training is not implemented yet")
+
+ if self.backbone is not None:
+ outputs = self.backbone.forward_with_filtered_kwargs(pixel_values, output_hidden_states=True, **kwargs)
+ hidden_states = outputs.feature_maps
+ else:
+ outputs = self.dpt(pixel_values, head_mask=head_mask, output_hidden_states=True, **kwargs)
+ hidden_states = outputs.hidden_states
+ # only keep certain features based on config.backbone_out_indices
+ # note that the hidden_states also include the initial embeddings
+ if not self.config.is_hybrid:
+ hidden_states = [
+ feature for idx, feature in enumerate(hidden_states[1:]) if idx in self.config.backbone_out_indices
+ ]
+ else:
+ backbone_hidden_states = outputs.intermediate_activations
+ backbone_hidden_states.extend(
+ feature
+ for idx, feature in enumerate(hidden_states[1:])
+ if idx in self.config.backbone_out_indices[2:]
+ )
+ hidden_states = backbone_hidden_states
+
+ patch_height, patch_width = None, None
+ if self.config.backbone_config is not None and self.config.is_hybrid is False:
+ _, _, height, width = pixel_values.shape
+ patch_size = self.config.backbone_config.patch_size
+ patch_height = height // patch_size
+ patch_width = width // patch_size
+
+ hidden_states = self.neck(hidden_states, patch_height, patch_width)
+ predicted_depth = self.head(hidden_states)
+
+ return DepthEstimatorOutput(
+ loss=loss,
+ predicted_depth=predicted_depth,
+ hidden_states=outputs.hidden_states if output_hidden_states else None,
+ attentions=outputs.attentions,
+ )
+
+
+class DPTSemanticSegmentationHead(nn.Module):
+ def __init__(self, config: DPTConfig):
+ super().__init__()
+
+ self.config = config
+ features = config.fusion_hidden_size
+ self.head = nn.Sequential(
+ nn.Conv2d(features, features, kernel_size=3, padding=1, bias=False),
+ nn.BatchNorm2d(features),
+ nn.ReLU(),
+ nn.Dropout(config.semantic_classifier_dropout),
+ nn.Conv2d(features, config.num_labels, kernel_size=1),
+ nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
+ )
+
+ def forward(self, hidden_states: list[torch.Tensor]) -> torch.Tensor:
+ # use last features
+ hidden_states = hidden_states[self.config.head_in_index]
+ logits = self.head(hidden_states)
+ return logits
+
+
+class DPTAuxiliaryHead(nn.Module):
+ def __init__(self, config: DPTConfig):
+ super().__init__()
+
+ features = config.fusion_hidden_size
+ self.head = nn.Sequential(
+ nn.Conv2d(features, features, kernel_size=3, padding=1, bias=False),
+ nn.BatchNorm2d(features),
+ nn.ReLU(),
+ nn.Dropout(0.1, False),
+ nn.Conv2d(features, config.num_labels, kernel_size=1),
+ )
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ logits = self.head(hidden_states)
+ return logits
+
+
+@auto_docstring
+class DPTForSemanticSegmentation(DPTPreTrainedModel):
+ def __init__(self, config: DPTConfig):
+ super().__init__(config)
+
+ self.dpt = DPTModel(config, add_pooling_layer=False)
+
+ # Neck
+ self.neck = DPTNeck(config)
+
+ # Segmentation head(s)
+ self.head = DPTSemanticSegmentationHead(config)
+ self.auxiliary_head = DPTAuxiliaryHead(config) if config.use_auxiliary_head else None
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ output_hidden_states: Optional[bool] = None,
+ **kwargs,
+ ) -> SemanticSegmenterOutput:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
+ Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels > 1`, a classification loss is computed (Cross-Entropy).
+
+ Examples:
+ ```python
+ >>> from transformers import AutoImageProcessor, DPTForSemanticSegmentation
+ >>> from PIL import Image
+ >>> import requests
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> image_processor = AutoImageProcessor.from_pretrained("Intel/dpt-large-ade")
+ >>> model = DPTForSemanticSegmentation.from_pretrained("Intel/dpt-large-ade")
+
+ >>> inputs = image_processor(images=image, return_tensors="pt")
+
+ >>> outputs = model(**inputs)
+ >>> logits = outputs.logits
+ ```"""
+ if output_hidden_states is None:
+ output_hidden_states = self.config.output_hidden_states
+
+ if labels is not None and self.config.num_labels == 1:
+ raise ValueError("The number of labels should be greater than one")
+
+ outputs: BaseModelOutputWithPoolingAndIntermediateActivations = self.dpt(
+ pixel_values, head_mask=head_mask, output_hidden_states=True, **kwargs
+ )
+ hidden_states = outputs.hidden_states
+
+ # only keep certain features based on config.backbone_out_indices
+ # note that the hidden_states also include the initial embeddings
+ if not self.config.is_hybrid:
+ hidden_states = [
+ feature for idx, feature in enumerate(hidden_states[1:]) if idx in self.config.backbone_out_indices
+ ]
+ else:
+ backbone_hidden_states = outputs.intermediate_activations
+ backbone_hidden_states.extend(
+ feature for idx, feature in enumerate(hidden_states[1:]) if idx in self.config.backbone_out_indices[2:]
+ )
+
+ hidden_states = backbone_hidden_states
+
+ hidden_states = self.neck(hidden_states=hidden_states)
+ logits = self.head(hidden_states)
+
+ auxiliary_logits = None
+ if self.auxiliary_head is not None:
+ auxiliary_logits = self.auxiliary_head(hidden_states[-1])
+
+ loss = None
+ if labels is not None:
+ # upsample logits to the images' original size
+ upsampled_logits = nn.functional.interpolate(
+ logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
+ )
+ if auxiliary_logits is not None:
+ upsampled_auxiliary_logits = nn.functional.interpolate(
+ auxiliary_logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
+ )
+ # compute weighted loss
+ loss_fct = CrossEntropyLoss(ignore_index=self.config.semantic_loss_ignore_index)
+ main_loss = loss_fct(upsampled_logits, labels)
+ auxiliary_loss = loss_fct(upsampled_auxiliary_logits, labels)
+ loss = main_loss + self.config.auxiliary_loss_weight * auxiliary_loss
+
+ return SemanticSegmenterOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states if output_hidden_states else None,
+ attentions=outputs.attentions,
+ )
+
+
+__all__ = ["DPTForDepthEstimation", "DPTForSemanticSegmentation", "DPTModel", "DPTPreTrainedModel"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/dpt/modular_dpt.py b/venv/lib/python3.13/site-packages/transformers/models/dpt/modular_dpt.py
new file mode 100644
index 0000000000000000000000000000000000000000..34eb08f39b684cfca10624b73a5669b9d9577632
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/dpt/modular_dpt.py
@@ -0,0 +1,299 @@
+# coding=utf-8
+# Copyright 2025 HuggingFace Inc. team. All rights reserved.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from collections.abc import Iterable
+from typing import TYPE_CHECKING, Optional, Union
+
+import torch
+
+from ...image_processing_base import BatchFeature
+from ...image_processing_utils_fast import BaseImageProcessorFast, DefaultFastImageProcessorKwargs
+from ...image_transforms import group_images_by_shape, reorder_images
+from ...image_utils import (
+ IMAGENET_STANDARD_MEAN,
+ IMAGENET_STANDARD_STD,
+ PILImageResampling,
+ SizeDict,
+)
+from ...utils import (
+ TensorType,
+ auto_docstring,
+ requires_backends,
+)
+from ..beit.image_processing_beit_fast import BeitImageProcessorFast
+
+
+if TYPE_CHECKING:
+ from ...modeling_outputs import DepthEstimatorOutput
+
+from torchvision.transforms.v2 import functional as F
+
+
+def get_resize_output_image_size(
+ input_image: "torch.Tensor",
+ output_size: Union[int, Iterable[int]],
+ keep_aspect_ratio: bool,
+ multiple: int,
+) -> SizeDict:
+ def constrain_to_multiple_of(val, multiple, min_val=0, max_val=None):
+ x = round(val / multiple) * multiple
+
+ if max_val is not None and x > max_val:
+ x = math.floor(val / multiple) * multiple
+
+ if x < min_val:
+ x = math.ceil(val / multiple) * multiple
+
+ return x
+
+ input_height, input_width = input_image.shape[-2:]
+ output_height, output_width = output_size
+
+ # determine new height and width
+ scale_height = output_height / input_height
+ scale_width = output_width / input_width
+
+ if keep_aspect_ratio:
+ # scale as little as possible
+ if abs(1 - scale_width) < abs(1 - scale_height):
+ # fit width
+ scale_height = scale_width
+ else:
+ # fit height
+ scale_width = scale_height
+
+ new_height = constrain_to_multiple_of(scale_height * input_height, multiple=multiple)
+ new_width = constrain_to_multiple_of(scale_width * input_width, multiple=multiple)
+
+ return SizeDict(height=new_height, width=new_width)
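+
+# Worked example (illustrative, not executed): for an input of shape (3, 480, 640) with output_size=(384, 384),
+# keep_aspect_ratio=True and multiple=32, the height scale (384/480 = 0.8) deviates less from 1.0 than the width
+# scale (384/640 = 0.6), so both sides are scaled by 0.8 and snapped to multiples of 32, giving
+# SizeDict(height=384, width=512).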
+
+
+class DPTFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
+ """
+ ensure_multiple_of (`int`, *optional*, defaults to 1):
+ If `do_resize` is `True`, the image is resized to a size that is a multiple of this value. Can be overridden
+ by `ensure_multiple_of` in `preprocess`.
+ size_divisor (`int`, *optional*):
+ If `do_pad` is `True`, pads the image dimensions to be divisible by this value. This was introduced in the
+ DINOv2 paper, which uses the model in combination with DPT.
+ keep_aspect_ratio (`bool`, *optional*, defaults to `False`):
+ If `True`, the image is resized to the largest possible size such that the aspect ratio is preserved. Can
+ be overridden by `keep_aspect_ratio` in `preprocess`.
+ do_reduce_labels (`bool`, *optional*, defaults to `self.do_reduce_labels`):
+ Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0
+ is used for background, and background itself is not included in all classes of a dataset (e.g.
+ ADE20k). The background label will be replaced by 255.
+ """
+
+ ensure_multiple_of: Optional[int]
+ size_divisor: Optional[int]
+ keep_aspect_ratio: Optional[bool]
+ do_reduce_labels: Optional[bool]
+
+
+@auto_docstring
+class DPTImageProcessorFast(BeitImageProcessorFast):
+ resample = PILImageResampling.BICUBIC
+ image_mean = IMAGENET_STANDARD_MEAN
+ image_std = IMAGENET_STANDARD_STD
+ size = {"height": 384, "width": 384}
+ do_resize = True
+ do_rescale = True
+ do_normalize = True
+ do_pad = False
+ rescale_factor = 1 / 255
+ ensure_multiple_of = 1
+ keep_aspect_ratio = False
+ do_reduce_labels = False
+ crop_size = None
+ do_center_crop = None
+
+ valid_kwargs = DPTFastImageProcessorKwargs
+
+ def resize(
+ self,
+ image: "torch.Tensor",
+ size: SizeDict,
+ interpolation: Optional["F.InterpolationMode"] = None,
+ antialias: bool = True,
+ ensure_multiple_of: Optional[int] = 1,
+ keep_aspect_ratio: bool = False,
+ ) -> "torch.Tensor":
+ """
+ Resize an image to `(size["height"], size["width"])`.
+
+ Args:
+ image (`torch.Tensor`):
+ Image to resize.
+ size (`SizeDict`):
+ Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
+ interpolation (`InterpolationMode`, *optional*, defaults to `InterpolationMode.BILINEAR`):
+ `InterpolationMode` filter to use when resizing the image e.g. `InterpolationMode.BICUBIC`.
+ antialias (`bool`, *optional*, defaults to `True`):
+ Whether to use antialiasing when resizing the image
+ ensure_multiple_of (`int`, *optional*):
+ If `do_resize` is `True`, the image is resized to a size that is a multiple of this value
+ keep_aspect_ratio (`bool`, *optional*, defaults to `False`):
+ If `True`, and `do_resize` is `True`, the image is resized to the largest possible size such that the aspect ratio is preserved.
+
+ Returns:
+ `torch.Tensor`: The resized image.
+ """
+ if not size.height or not size.width:
+ raise ValueError(f"The size dictionary must contain the keys 'height' and 'width'. Got {size.keys()}")
+
+ output_size = get_resize_output_image_size(
+ image,
+ output_size=(size.height, size.width),
+ keep_aspect_ratio=keep_aspect_ratio,
+ multiple=ensure_multiple_of,
+ )
+ return BaseImageProcessorFast.resize(
+ self, image, output_size, interpolation=interpolation, antialias=antialias
+ )
+
+ def pad_image(
+ self,
+ image: "torch.Tensor",
+ size_divisor: int = 1,
+ ) -> "torch.Tensor":
+ r"""
+ Center pad a batch of images to be a multiple of `size_divisor`.
+
+ Args:
+ image (`torch.Tensor`):
+ Image to pad. Can be a batch of images of dimensions (N, C, H, W) or a single image of dimensions (C, H, W).
+ size_divisor (`int`):
+ The width and height of the image will be padded to a multiple of this number.
+ """
+ height, width = image.shape[-2:]
+
+ def _get_pad(size, size_divisor):
+ new_size = math.ceil(size / size_divisor) * size_divisor
+ pad_size = new_size - size
+ pad_size_left = pad_size // 2
+ pad_size_right = pad_size - pad_size_left
+ return pad_size_left, pad_size_right
+
+ pad_top, pad_bottom = _get_pad(height, size_divisor)
+ pad_left, pad_right = _get_pad(width, size_divisor)
+ padding = (pad_left, pad_top, pad_right, pad_bottom)
+ return F.pad(image, padding)
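+
+    # Padding example (illustrative): with size_divisor=32, a (3, 500, 333) image is padded up to the next
+    # multiples of 32 (512 and 352), split as evenly as possible on each side: 6 px on top, 6 px on the bottom,
+    # 9 px on the left and 10 px on the right, producing a (3, 512, 352) tensor.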
+
+ def _preprocess(
+ self,
+ images: list["torch.Tensor"],
+ do_reduce_labels: bool,
+ do_resize: bool,
+ size: SizeDict,
+ interpolation: Optional["F.InterpolationMode"],
+ do_center_crop: bool,
+ crop_size: SizeDict,
+ do_rescale: bool,
+ rescale_factor: float,
+ do_normalize: bool,
+ image_mean: Optional[Union[float, list[float]]],
+ image_std: Optional[Union[float, list[float]]],
+ keep_aspect_ratio: bool,
+ ensure_multiple_of: Optional[int],
+ do_pad: bool,
+ size_divisor: Optional[int],
+ disable_grouping: Optional[bool],
+ return_tensors: Optional[Union[str, TensorType]],
+ **kwargs,
+ ) -> BatchFeature:
+ if do_reduce_labels:
+ images = self.reduce_label(images)
+
+ # Group images by size for batched resizing
+ grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
+ resized_images_grouped = {}
+ for shape, stacked_images in grouped_images.items():
+ if do_resize:
+ stacked_images = self.resize(
+ image=stacked_images,
+ size=size,
+ interpolation=interpolation,
+ ensure_multiple_of=ensure_multiple_of,
+ keep_aspect_ratio=keep_aspect_ratio,
+ )
+ resized_images_grouped[shape] = stacked_images
+ resized_images = reorder_images(resized_images_grouped, grouped_images_index)
+
+ # Group images by size for further processing
+ # Needed in case do_resize is False, or resize returns images with different sizes
+ grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
+ processed_images_grouped = {}
+ for shape, stacked_images in grouped_images.items():
+ if do_center_crop:
+ stacked_images = self.center_crop(stacked_images, crop_size)
+ if do_pad:
+ stacked_images = self.pad_image(stacked_images, size_divisor)
+ # Fused rescale and normalize
+ stacked_images = self.rescale_and_normalize(
+ stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
+ )
+ processed_images_grouped[shape] = stacked_images
+
+ processed_images = reorder_images(processed_images_grouped, grouped_images_index)
+ processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
+ return BatchFeature(data={"pixel_values": processed_images})
+
+ def post_process_depth_estimation(
+ self,
+ outputs: "DepthEstimatorOutput",
+ target_sizes: Optional[Union[TensorType, list[tuple[int, int]], None]] = None,
+ ) -> list[dict[str, TensorType]]:
+ """
+        Converts the raw output of [`DepthEstimatorOutput`] into final depth predictions, optionally resizing them
+        to the requested target sizes. Only supports PyTorch.
+
+ Args:
+ outputs ([`DepthEstimatorOutput`]):
+ Raw outputs of the model.
+ target_sizes (`TensorType` or `List[Tuple[int, int]]`, *optional*):
+ Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size
+ (height, width) of each image in the batch. If left to None, predictions will not be resized.
+
+ Returns:
+ `List[Dict[str, TensorType]]`: A list of dictionaries of tensors representing the processed depth
+ predictions.
+ """
+ requires_backends(self, "torch")
+
+ predicted_depth = outputs.predicted_depth
+
+ if (target_sizes is not None) and (len(predicted_depth) != len(target_sizes)):
+ raise ValueError(
+ "Make sure that you pass in as many target sizes as the batch dimension of the predicted depth"
+ )
+
+ results = []
+ target_sizes = [None] * len(predicted_depth) if target_sizes is None else target_sizes
+ for depth, target_size in zip(predicted_depth, target_sizes):
+ if target_size is not None:
+ depth = torch.nn.functional.interpolate(
+ depth.unsqueeze(0).unsqueeze(1), size=target_size, mode="bicubic", align_corners=False
+ ).squeeze()
+
+ results.append({"predicted_depth": depth})
+
+ return results
+
+
+__all__ = ["DPTImageProcessorFast"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/electra/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/electra/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a78ed5c42aea51038335efabde5b03e333592ed6
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/electra/__init__.py
@@ -0,0 +1,31 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_electra import *
+ from .modeling_electra import *
+ from .modeling_flax_electra import *
+ from .modeling_tf_electra import *
+ from .tokenization_electra import *
+ from .tokenization_electra_fast import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/venv/lib/python3.13/site-packages/transformers/models/electra/configuration_electra.py b/venv/lib/python3.13/site-packages/transformers/models/electra/configuration_electra.py
new file mode 100644
index 0000000000000000000000000000000000000000..f12756d976b35ee3a4f333483b1b4e6e1a07fb7e
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/electra/configuration_electra.py
@@ -0,0 +1,187 @@
+# coding=utf-8
+# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""ELECTRA model configuration"""
+
+from collections import OrderedDict
+from collections.abc import Mapping
+
+from ...configuration_utils import PretrainedConfig
+from ...onnx import OnnxConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class ElectraConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`ElectraModel`] or a [`TFElectraModel`]. It is
+    used to instantiate an ELECTRA model according to the specified arguments, defining the model architecture.
+ Instantiating a configuration with the defaults will yield a similar configuration to that of the ELECTRA
+ [google/electra-small-discriminator](https://huggingface.co/google/electra-small-discriminator) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 30522):
+ Vocabulary size of the ELECTRA model. Defines the number of different tokens that can be represented by the
+            `input_ids` passed when calling [`ElectraModel`] or [`TFElectraModel`].
+ embedding_size (`int`, *optional*, defaults to 128):
+            Dimensionality of the embedding layer; projected to `hidden_size` inside the model when the two differ.
+ hidden_size (`int`, *optional*, defaults to 256):
+ Dimensionality of the encoder layers and the pooler layer.
+ num_hidden_layers (`int`, *optional*, defaults to 12):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 4):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ intermediate_size (`int`, *optional*, defaults to 1024):
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+ hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
+ The dropout ratio for the attention probabilities.
+ max_position_embeddings (`int`, *optional*, defaults to 512):
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
+ just in case (e.g., 512 or 1024 or 2048).
+ type_vocab_size (`int`, *optional*, defaults to 2):
+ The vocabulary size of the `token_type_ids` passed when calling [`ElectraModel`] or [`TFElectraModel`].
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12):
+ The epsilon used by the layer normalization layers.
+ summary_type (`str`, *optional*, defaults to `"first"`):
+ Argument used when doing sequence summary. Used in the sequence classification and multiple choice models.
+
+ Has to be one of the following options:
+
+ - `"last"`: Take the last token hidden state (like XLNet).
+ - `"first"`: Take the first token hidden state (like BERT).
+ - `"mean"`: Take the mean of all tokens hidden states.
+ - `"cls_index"`: Supply a Tensor of classification token position (like GPT/GPT-2).
+ - `"attn"`: Not implemented now, use multi-head attention.
+ summary_use_proj (`bool`, *optional*, defaults to `True`):
+ Argument used when doing sequence summary. Used in the sequence classification and multiple choice models.
+
+ Whether or not to add a projection after the vector extraction.
+ summary_activation (`str`, *optional*):
+ Argument used when doing sequence summary. Used in the sequence classification and multiple choice models.
+
+ Pass `"gelu"` for a gelu activation to the output, any other value will result in no activation.
+ summary_last_dropout (`float`, *optional*, defaults to 0.1):
+ Argument used when doing sequence summary. Used in the sequence classification and multiple choice models.
+
+ The dropout ratio to be used after the projection and activation.
+ position_embedding_type (`str`, *optional*, defaults to `"absolute"`):
+ Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For
+ positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to
+ [Self-Attention with Relative Position Representations (Shaw et al.)](https://huggingface.co/papers/1803.02155).
+ For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models
+ with Better Relative Position Embeddings (Huang et al.)](https://huggingface.co/papers/2009.13658).
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ classifier_dropout (`float`, *optional*):
+ The dropout ratio for the classification head.
+
+ Examples:
+
+ ```python
+ >>> from transformers import ElectraConfig, ElectraModel
+
+ >>> # Initializing an ELECTRA electra-base-uncased style configuration
+ >>> configuration = ElectraConfig()
+
+ >>> # Initializing a model (with random weights) from the electra-base-uncased style configuration
+ >>> model = ElectraModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "electra"
+
+ def __init__(
+ self,
+ vocab_size=30522,
+ embedding_size=128,
+ hidden_size=256,
+ num_hidden_layers=12,
+ num_attention_heads=4,
+ intermediate_size=1024,
+ hidden_act="gelu",
+ hidden_dropout_prob=0.1,
+ attention_probs_dropout_prob=0.1,
+ max_position_embeddings=512,
+ type_vocab_size=2,
+ initializer_range=0.02,
+ layer_norm_eps=1e-12,
+ summary_type="first",
+ summary_use_proj=True,
+ summary_activation="gelu",
+ summary_last_dropout=0.1,
+ pad_token_id=0,
+ position_embedding_type="absolute",
+ use_cache=True,
+ classifier_dropout=None,
+ **kwargs,
+ ):
+ super().__init__(pad_token_id=pad_token_id, **kwargs)
+
+ self.vocab_size = vocab_size
+ self.embedding_size = embedding_size
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.intermediate_size = intermediate_size
+ self.hidden_act = hidden_act
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.max_position_embeddings = max_position_embeddings
+ self.type_vocab_size = type_vocab_size
+ self.initializer_range = initializer_range
+ self.layer_norm_eps = layer_norm_eps
+
+ self.summary_type = summary_type
+ self.summary_use_proj = summary_use_proj
+ self.summary_activation = summary_activation
+ self.summary_last_dropout = summary_last_dropout
+ self.position_embedding_type = position_embedding_type
+ self.use_cache = use_cache
+ self.classifier_dropout = classifier_dropout
+
+
+class ElectraOnnxConfig(OnnxConfig):
+ @property
+ def inputs(self) -> Mapping[str, Mapping[int, str]]:
+ if self.task == "multiple-choice":
+ dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
+ else:
+ dynamic_axis = {0: "batch", 1: "sequence"}
+ return OrderedDict(
+ [
+ ("input_ids", dynamic_axis),
+ ("attention_mask", dynamic_axis),
+ ("token_type_ids", dynamic_axis),
+ ]
+ )
+
+
+__all__ = ["ElectraConfig", "ElectraOnnxConfig"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/electra/modeling_electra.py b/venv/lib/python3.13/site-packages/transformers/models/electra/modeling_electra.py
new file mode 100644
index 0000000000000000000000000000000000000000..a10b0b6583374d8b620fb58b46d6e8f0ad469c1e
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/electra/modeling_electra.py
@@ -0,0 +1,1586 @@
+# coding=utf-8
+# Copyright 2019 The Google AI Language Team Authors and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch ELECTRA model."""
+
+import math
+import os
+from dataclasses import dataclass
+from typing import Callable, Optional, Union
+
+import torch
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from ...activations import ACT2FN, get_activation
+from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
+from ...generation import GenerationMixin
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import (
+ BaseModelOutputWithCrossAttentions,
+ BaseModelOutputWithPastAndCrossAttentions,
+ CausalLMOutputWithCrossAttentions,
+ MaskedLMOutput,
+ MultipleChoiceModelOutput,
+ QuestionAnsweringModelOutput,
+ SequenceClassifierOutput,
+ TokenClassifierOutput,
+)
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from ...utils import ModelOutput, auto_docstring, logging
+from ...utils.deprecation import deprecate_kwarg
+from .configuration_electra import ElectraConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+def load_tf_weights_in_electra(model, config, tf_checkpoint_path, discriminator_or_generator="discriminator"):
+ """Load tf checkpoints in a pytorch model."""
+ try:
+ import re
+
+ import numpy as np
+ import tensorflow as tf
+ except ImportError:
+ logger.error(
+ "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
+ "https://www.tensorflow.org/install/ for installation instructions."
+ )
+ raise
+ tf_path = os.path.abspath(tf_checkpoint_path)
+ logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
+ # Load weights from TF model
+ init_vars = tf.train.list_variables(tf_path)
+ names = []
+ arrays = []
+ for name, shape in init_vars:
+ logger.info(f"Loading TF weight {name} with shape {shape}")
+ array = tf.train.load_variable(tf_path, name)
+ names.append(name)
+ arrays.append(array)
+ for name, array in zip(names, arrays):
+ original_name: str = name
+
+ try:
+ if isinstance(model, ElectraForMaskedLM):
+ name = name.replace("electra/embeddings/", "generator/embeddings/")
+
+ if discriminator_or_generator == "generator":
+ name = name.replace("electra/", "discriminator/")
+ name = name.replace("generator/", "electra/")
+
+ name = name.replace("dense_1", "dense_prediction")
+ name = name.replace("generator_predictions/output_bias", "generator_lm_head/bias")
+
+ name = name.split("/")
+ # print(original_name, name)
+ # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v,
+ # which are not required for using pretrained model
+ if any(n in ["global_step", "temperature"] for n in name):
+ logger.info(f"Skipping {original_name}")
+ continue
+ pointer = model
+ for m_name in name:
+ if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
+ scope_names = re.split(r"_(\d+)", m_name)
+ else:
+ scope_names = [m_name]
+ if scope_names[0] == "kernel" or scope_names[0] == "gamma":
+ pointer = getattr(pointer, "weight")
+ elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
+ pointer = getattr(pointer, "bias")
+ elif scope_names[0] == "output_weights":
+ pointer = getattr(pointer, "weight")
+ elif scope_names[0] == "squad":
+ pointer = getattr(pointer, "classifier")
+ else:
+ pointer = getattr(pointer, scope_names[0])
+ if len(scope_names) >= 2:
+ num = int(scope_names[1])
+ pointer = pointer[num]
+ if m_name.endswith("_embeddings"):
+ pointer = getattr(pointer, "weight")
+ elif m_name == "kernel":
+ array = np.transpose(array)
+ try:
+ if pointer.shape != array.shape:
+ raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
+ except ValueError as e:
+ e.args += (pointer.shape, array.shape)
+ raise
+ print(f"Initialize PyTorch weight {name}", original_name)
+ pointer.data = torch.from_numpy(array)
+ except AttributeError as e:
+ print(f"Skipping {original_name}", name, e)
+ continue
+ return model
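+
+
+# Hedged usage sketch (the checkpoint path below is a placeholder, not taken from this file):
+# >>> from transformers import ElectraConfig, ElectraForPreTraining
+# >>> config = ElectraConfig()
+# >>> model = ElectraForPreTraining(config)
+# >>> model = load_tf_weights_in_electra(model, config, "/path/to/electra/model.ckpt", "discriminator")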
+
+
+class ElectraEmbeddings(nn.Module):
+ """Construct the embeddings from word, position and token_type embeddings."""
+
+ def __init__(self, config):
+ super().__init__()
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.embedding_size, padding_idx=config.pad_token_id)
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.embedding_size)
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.embedding_size)
+
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
+ # any TensorFlow checkpoint file
+ self.LayerNorm = nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
+ self.register_buffer(
+ "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
+ )
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
+ self.register_buffer(
+ "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
+ )
+
+ # Copied from transformers.models.bert.modeling_bert.BertEmbeddings.forward
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ past_key_values_length: int = 0,
+ ) -> torch.Tensor:
+ if input_ids is not None:
+ input_shape = input_ids.size()
+ else:
+ input_shape = inputs_embeds.size()[:-1]
+
+ seq_length = input_shape[1]
+
+ if position_ids is None:
+ position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
+
+ # If token_type_ids is not provided, fall back to the all-zeros buffer registered in the constructor.
+ # This usually happens when token_type_ids would be auto-generated; the registered buffer lets users
+ # trace the model without passing token_type_ids (solves issue #5664).
+ if token_type_ids is None:
+ if hasattr(self, "token_type_ids"):
+ buffered_token_type_ids = self.token_type_ids[:, :seq_length]
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
+ token_type_ids = buffered_token_type_ids_expanded
+ else:
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
+
+ if inputs_embeds is None:
+ inputs_embeds = self.word_embeddings(input_ids)
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
+
+ embeddings = inputs_embeds + token_type_embeddings
+ if self.position_embedding_type == "absolute":
+ position_embeddings = self.position_embeddings(position_ids)
+ embeddings += position_embeddings
+ embeddings = self.LayerNorm(embeddings)
+ embeddings = self.dropout(embeddings)
+ return embeddings
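+
+
+# Hedged sketch of what the embedding module computes with the default "absolute" position
+# embeddings: word + token-type (+ position) embeddings, then LayerNorm and dropout. For example:
+# >>> emb = ElectraEmbeddings(ElectraConfig())
+# >>> emb(input_ids=torch.tensor([[1, 2, 3]])).shape
+# torch.Size([1, 3, 128])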
+
+
+# Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->Electra
+class ElectraSelfAttention(nn.Module):
+ def __init__(self, config, position_embedding_type=None, layer_idx=None):
+ super().__init__()
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
+ raise ValueError(
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
+ f"heads ({config.num_attention_heads})"
+ )
+
+ self.num_attention_heads = config.num_attention_heads
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
+
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+ self.position_embedding_type = position_embedding_type or getattr(
+ config, "position_embedding_type", "absolute"
+ )
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
+ self.max_position_embeddings = config.max_position_embeddings
+ self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
+
+ self.is_decoder = config.is_decoder
+ self.layer_idx = layer_idx
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ output_attentions: Optional[bool] = False,
+ cache_position: Optional[torch.Tensor] = None,
+ ) -> tuple[torch.Tensor]:
+ batch_size, seq_length, _ = hidden_states.shape
+ query_layer = self.query(hidden_states)
+ query_layer = query_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(
+ 1, 2
+ )
+
+ is_updated = False
+ is_cross_attention = encoder_hidden_states is not None
+ if past_key_values is not None:
+ if isinstance(past_key_values, EncoderDecoderCache):
+ is_updated = past_key_values.is_updated.get(self.layer_idx)
+ if is_cross_attention:
+ # after the first generated id, we can subsequently re-use all key/value_layer from cache
+ curr_past_key_value = past_key_values.cross_attention_cache
+ else:
+ curr_past_key_value = past_key_values.self_attention_cache
+ else:
+ curr_past_key_value = past_key_values
+
+ current_states = encoder_hidden_states if is_cross_attention else hidden_states
+ if is_cross_attention and past_key_values is not None and is_updated:
+ # reuse k,v, cross_attentions
+ key_layer = curr_past_key_value.layers[self.layer_idx].keys
+ value_layer = curr_past_key_value.layers[self.layer_idx].values
+ else:
+ key_layer = self.key(current_states)
+ key_layer = key_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(
+ 1, 2
+ )
+ value_layer = self.value(current_states)
+ value_layer = value_layer.view(
+ batch_size, -1, self.num_attention_heads, self.attention_head_size
+ ).transpose(1, 2)
+
+ if past_key_values is not None:
+ # save all key/value_layer to cache to be re-used for fast auto-regressive generation
+ cache_position = cache_position if not is_cross_attention else None
+ key_layer, value_layer = curr_past_key_value.update(
+ key_layer, value_layer, self.layer_idx, {"cache_position": cache_position}
+ )
+ # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
+ if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache):
+ past_key_values.is_updated[self.layer_idx] = True
+
+ # Take the dot product between "query" and "key" to get the raw attention scores.
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
+ query_length, key_length = query_layer.shape[2], key_layer.shape[2]
+ if past_key_values is not None:
+ position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
+ -1, 1
+ )
+ else:
+ position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
+ position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
+ distance = position_ids_l - position_ids_r
+
+ positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
+ positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
+
+ if self.position_embedding_type == "relative_key":
+ relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
+ attention_scores = attention_scores + relative_position_scores
+ elif self.position_embedding_type == "relative_key_query":
+ relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
+ relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
+ attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
+
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+ if attention_mask is not None:
+ # Apply the attention mask (precomputed for all layers in the ElectraModel forward() function)
+ attention_scores = attention_scores + attention_mask
+
+ # Normalize the attention scores to probabilities.
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
+
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ attention_probs = self.dropout(attention_probs)
+
+ # Mask heads if we want to
+ if head_mask is not None:
+ attention_probs = attention_probs * head_mask
+
+ context_layer = torch.matmul(attention_probs, value_layer)
+
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+ context_layer = context_layer.view(new_context_layer_shape)
+
+ return context_layer, attention_probs
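+
+
+# Hedged sketch of the eager attention above (ignoring caching, head masking and relative position
+# embeddings): per head, scores = Q @ K^T / sqrt(head_size), an additive attention_mask is applied,
+# the probabilities come from softmax(scores) with dropout, and the context is probabilities @ V.
+# >>> attn = ElectraSelfAttention(ElectraConfig(), layer_idx=0)
+# >>> context, probs = attn(torch.zeros(1, 3, 256))
+# >>> context.shape, probs.shape
+# (torch.Size([1, 3, 256]), torch.Size([1, 4, 3, 3]))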
+
+
+# Copied from transformers.models.bert.modeling_bert.BertSelfOutput
+class ElectraSelfOutput(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
+ return hidden_states
+
+
+ELECTRA_SELF_ATTENTION_CLASSES = {
+ "eager": ElectraSelfAttention,
+}
+
+
+# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Electra,BERT->ELECTRA
+class ElectraAttention(nn.Module):
+ def __init__(self, config, position_embedding_type=None, layer_idx=None):
+ super().__init__()
+ self.self = ELECTRA_SELF_ATTENTION_CLASSES[config._attn_implementation](
+ config,
+ position_embedding_type=position_embedding_type,
+ layer_idx=layer_idx,
+ )
+ self.output = ElectraSelfOutput(config)
+ self.pruned_heads = set()
+
+ def prune_heads(self, heads):
+ if len(heads) == 0:
+ return
+ heads, index = find_pruneable_heads_and_indices(
+ heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
+ )
+
+ # Prune linear layers
+ self.self.query = prune_linear_layer(self.self.query, index)
+ self.self.key = prune_linear_layer(self.self.key, index)
+ self.self.value = prune_linear_layer(self.self.value, index)
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
+
+ # Update hyper params and store pruned heads
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
+ self.pruned_heads = self.pruned_heads.union(heads)
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ output_attentions: Optional[bool] = False,
+ cache_position: Optional[torch.Tensor] = None,
+ ) -> tuple[torch.Tensor]:
+ self_outputs = self.self(
+ hidden_states,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ past_key_values=past_key_values,
+ output_attentions=output_attentions,
+ cache_position=cache_position,
+ )
+ attention_output = self.output(self_outputs[0], hidden_states)
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
+ return outputs
+
+
+# Copied from transformers.models.bert.modeling_bert.BertIntermediate
+class ElectraIntermediate(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
+ if isinstance(config.hidden_act, str):
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.intermediate_act_fn = config.hidden_act
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.intermediate_act_fn(hidden_states)
+ return hidden_states
+
+
+# Copied from transformers.models.bert.modeling_bert.BertOutput
+class ElectraOutput(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
+ return hidden_states
+
+
+# Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->Electra
+class ElectraLayer(GradientCheckpointingLayer):
+ def __init__(self, config, layer_idx=None):
+ super().__init__()
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
+ self.seq_len_dim = 1
+ self.attention = ElectraAttention(config, layer_idx=layer_idx)
+ self.is_decoder = config.is_decoder
+ self.add_cross_attention = config.add_cross_attention
+ if self.add_cross_attention:
+ if not self.is_decoder:
+ raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
+ self.crossattention = ElectraAttention(config, position_embedding_type="absolute", layer_idx=layer_idx)
+ self.intermediate = ElectraIntermediate(config)
+ self.output = ElectraOutput(config)
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ output_attentions: Optional[bool] = False,
+ cache_position: Optional[torch.Tensor] = None,
+ ) -> tuple[torch.Tensor]:
+ self_attention_outputs = self.attention(
+ hidden_states,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ past_key_values=past_key_values,
+ cache_position=cache_position,
+ )
+ attention_output = self_attention_outputs[0]
+ outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
+
+ if self.is_decoder and encoder_hidden_states is not None:
+ if not hasattr(self, "crossattention"):
+ raise ValueError(
+ f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
+ " by setting `config.add_cross_attention=True`"
+ )
+
+ cross_attention_outputs = self.crossattention(
+ attention_output,
+ attention_mask=encoder_attention_mask,
+ head_mask=head_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ past_key_values=past_key_values,
+ output_attentions=output_attentions,
+ cache_position=cache_position,
+ )
+ attention_output = cross_attention_outputs[0]
+ outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights
+
+ layer_output = apply_chunking_to_forward(
+ self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
+ )
+ outputs = (layer_output,) + outputs
+
+ return outputs
+
+ def feed_forward_chunk(self, attention_output):
+ intermediate_output = self.intermediate(attention_output)
+ layer_output = self.output(intermediate_output, attention_output)
+ return layer_output
+
+
+# Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->Electra
+class ElectraEncoder(nn.Module):
+ def __init__(self, config, layer_idx=None):
+ super().__init__()
+ self.config = config
+ self.layer = nn.ModuleList([ElectraLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)])
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = False,
+ output_hidden_states: Optional[bool] = False,
+ return_dict: Optional[bool] = True,
+ cache_position: Optional[torch.Tensor] = None,
+ ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attentions = () if output_attentions else None
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ if use_cache and self.config.is_decoder and past_key_values is None:
+ past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
+
+ if use_cache and self.config.is_decoder and isinstance(past_key_values, tuple):
+ logger.warning_once(
+ "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
+ "You should pass an instance of `EncoderDecoderCache` instead, e.g. "
+ "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`."
+ )
+ past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)
+
+ for i, layer_module in enumerate(self.layer):
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ layer_head_mask = head_mask[i] if head_mask is not None else None
+
+ layer_outputs = layer_module(
+ hidden_states,
+ attention_mask,
+ layer_head_mask,
+ encoder_hidden_states, # as a positional argument for gradient checkpointing
+ encoder_attention_mask=encoder_attention_mask,
+ past_key_values=past_key_values,
+ output_attentions=output_attentions,
+ cache_position=cache_position,
+ )
+
+ hidden_states = layer_outputs[0]
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
+ if self.config.add_cross_attention:
+ all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(
+ v
+ for v in [
+ hidden_states,
+ past_key_values,
+ all_hidden_states,
+ all_self_attentions,
+ all_cross_attentions,
+ ]
+ if v is not None
+ )
+ return BaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ cross_attentions=all_cross_attentions,
+ )
+
+
+class ElectraDiscriminatorPredictions(nn.Module):
+ """Prediction module for the discriminator, made up of two dense layers."""
+
+ def __init__(self, config):
+ super().__init__()
+
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.activation = get_activation(config.hidden_act)
+ self.dense_prediction = nn.Linear(config.hidden_size, 1)
+ self.config = config
+
+ def forward(self, discriminator_hidden_states):
+ hidden_states = self.dense(discriminator_hidden_states)
+ hidden_states = self.activation(hidden_states)
+ logits = self.dense_prediction(hidden_states).squeeze(-1)
+
+ return logits
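+
+
+# Hedged sketch: the discriminator head maps every hidden state to one replaced-token logit, so
+# hidden states of shape (batch, seq_len, hidden_size) yield logits of shape (batch, seq_len);
+# a positive logit (sigmoid > 0.5) means the token is predicted to have been replaced.
+# >>> head = ElectraDiscriminatorPredictions(ElectraConfig())
+# >>> head(torch.zeros(2, 5, 256)).shape
+# torch.Size([2, 5])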
+
+
+class ElectraGeneratorPredictions(nn.Module):
+ """Prediction module for the generator, made up of two dense layers."""
+
+ def __init__(self, config):
+ super().__init__()
+
+ self.activation = get_activation("gelu")
+ self.LayerNorm = nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps)
+ self.dense = nn.Linear(config.hidden_size, config.embedding_size)
+
+ def forward(self, generator_hidden_states):
+ hidden_states = self.dense(generator_hidden_states)
+ hidden_states = self.activation(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states)
+
+ return hidden_states
+
+
+@auto_docstring
+class ElectraPreTrainedModel(PreTrainedModel):
+ config: ElectraConfig
+ load_tf_weights = load_tf_weights_in_electra
+ base_model_prefix = "electra"
+ supports_gradient_checkpointing = True
+
+ def _init_weights(self, module):
+ """Initialize the weights"""
+ if isinstance(module, nn.Linear):
+ # Slightly different from the TF version which uses truncated_normal for initialization
+ # cf https://github.com/pytorch/pytorch/pull/5617
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Output type of [`ElectraForPreTraining`].
+ """
+)
+class ElectraForPreTrainingOutput(ModelOutput):
+ r"""
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Total loss of the ELECTRA objective.
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
+ Prediction scores of the head (scores for each token before SoftMax).
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: Optional[torch.FloatTensor] = None
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
+ attentions: Optional[tuple[torch.FloatTensor]] = None
+
+
+@auto_docstring
+class ElectraModel(ElectraPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.embeddings = ElectraEmbeddings(config)
+
+ if config.embedding_size != config.hidden_size:
+ self.embeddings_project = nn.Linear(config.embedding_size, config.hidden_size)
+
+ self.encoder = ElectraEncoder(config)
+ self.config = config
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embeddings.word_embeddings
+
+ def set_input_embeddings(self, value):
+ self.embeddings.word_embeddings = value
+
+ def _prune_heads(self, heads_to_prune):
+ """
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
+ class PreTrainedModel
+ """
+ for layer, heads in heads_to_prune.items():
+ self.encoder.layer[layer].attention.prune_heads(heads)
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[Cache] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple[torch.Tensor], BaseModelOutputWithCrossAttentions]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
+ input_shape = input_ids.size()
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ batch_size, seq_length = input_shape
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+
+ past_key_values_length = 0
+ if past_key_values is not None:
+ past_key_values_length = (
+ past_key_values[0][0].shape[-2]
+ if not isinstance(past_key_values, Cache)
+ else past_key_values.get_seq_length()
+ )
+
+ if attention_mask is None:
+ attention_mask = torch.ones(input_shape, device=device)
+ if token_type_ids is None:
+ if hasattr(self.embeddings, "token_type_ids"):
+ buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
+ token_type_ids = buffered_token_type_ids_expanded
+ else:
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
+
+ extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
+
+ # If a 2D or 3D attention mask is provided for the cross-attention
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
+ if self.config.is_decoder and encoder_hidden_states is not None:
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
+ if encoder_attention_mask is None:
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+ else:
+ encoder_extended_attention_mask = None
+
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+ hidden_states = self.embeddings(
+ input_ids=input_ids,
+ position_ids=position_ids,
+ token_type_ids=token_type_ids,
+ inputs_embeds=inputs_embeds,
+ past_key_values_length=past_key_values_length,
+ )
+
+ if hasattr(self, "embeddings_project"):
+ hidden_states = self.embeddings_project(hidden_states)
+
+ hidden_states = self.encoder(
+ hidden_states,
+ attention_mask=extended_attention_mask,
+ head_mask=head_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_extended_attention_mask,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ return hidden_states
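+
+
+# Hedged usage sketch (the checkpoint name mirrors the one cited in the configuration docstring):
+# >>> from transformers import AutoTokenizer, ElectraModel
+# >>> tokenizer = AutoTokenizer.from_pretrained("google/electra-small-discriminator")
+# >>> model = ElectraModel.from_pretrained("google/electra-small-discriminator")
+# >>> inputs = tokenizer("ELECTRA replaces masked language modeling with replaced token detection.", return_tensors="pt")
+# >>> outputs = model(**inputs)
+# >>> outputs.last_hidden_state.shape[-1] == model.config.hidden_size
+# True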
+
+
+class ElectraClassificationHead(nn.Module):
+ """Head for sentence-level classification tasks."""
+
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ classifier_dropout = (
+ config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
+ )
+ self.activation = get_activation("gelu")
+ self.dropout = nn.Dropout(classifier_dropout)
+ self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
+
+ def forward(self, features, **kwargs):
+ x = features[:, 0, :] # take the first token (equivalent to [CLS])
+ x = self.dropout(x)
+ x = self.dense(x)
+ x = self.activation(x) # although BERT uses tanh here, the ELECTRA authors appear to have used gelu
+ x = self.dropout(x)
+ x = self.out_proj(x)
+ return x
+
+
+# Copied from transformers.models.xlm.modeling_xlm.XLMSequenceSummary with XLM->Electra
+class ElectraSequenceSummary(nn.Module):
+ r"""
+ Compute a single vector summary of a sequence hidden states.
+
+ Args:
+ config ([`ElectraConfig`]):
+ The config used by the model. Relevant arguments in the config class of the model are (refer to the actual
+ config class of your model for the default values it uses):
+
+ - **summary_type** (`str`) -- The method to use to make this summary. Accepted values are:
+
+ - `"last"` -- Take the last token hidden state (like XLNet)
+ - `"first"` -- Take the first token hidden state (like Bert)
+ - `"mean"` -- Take the mean of all tokens hidden states
+ - `"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2)
+ - `"attn"` -- Not implemented now, use multi-head attention
+
+ - **summary_use_proj** (`bool`) -- Add a projection after the vector extraction.
+ - **summary_proj_to_labels** (`bool`) -- If `True`, the projection outputs to `config.num_labels` classes
+ (otherwise to `config.hidden_size`).
+ - **summary_activation** (`Optional[str]`) -- Set to `"tanh"` to add a tanh activation to the output,
+ another string or `None` will add no activation.
+ - **summary_first_dropout** (`float`) -- Optional dropout probability before the projection and activation.
+ - **summary_last_dropout** (`float`) -- Optional dropout probability after the projection and activation.
+ """
+
+ def __init__(self, config: ElectraConfig):
+ super().__init__()
+
+ self.summary_type = getattr(config, "summary_type", "last")
+ if self.summary_type == "attn":
+ # We should use a standard multi-head attention module with absolute positional embedding for that.
+ # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
+ # We can probably just use the multi-head attention module of PyTorch >=1.1.0
+ raise NotImplementedError
+
+ self.summary = nn.Identity()
+ if hasattr(config, "summary_use_proj") and config.summary_use_proj:
+ if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0:
+ num_classes = config.num_labels
+ else:
+ num_classes = config.hidden_size
+ self.summary = nn.Linear(config.hidden_size, num_classes)
+
+ activation_string = getattr(config, "summary_activation", None)
+ self.activation: Callable = get_activation(activation_string) if activation_string else nn.Identity()
+
+ self.first_dropout = nn.Identity()
+ if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0:
+ self.first_dropout = nn.Dropout(config.summary_first_dropout)
+
+ self.last_dropout = nn.Identity()
+ if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0:
+ self.last_dropout = nn.Dropout(config.summary_last_dropout)
+
+ def forward(
+ self, hidden_states: torch.FloatTensor, cls_index: Optional[torch.LongTensor] = None
+ ) -> torch.FloatTensor:
+ """
+ Compute a single vector summary of a sequence hidden states.
+
+ Args:
+ hidden_states (`torch.FloatTensor` of shape `[batch_size, seq_len, hidden_size]`):
+ The hidden states of the last layer.
+ cls_index (`torch.LongTensor` of shape `[batch_size]` or `[batch_size, ...]` where ... are optional leading dimensions of `hidden_states`, *optional*):
+ Used if `summary_type == "cls_index"` and takes the last token of the sequence as classification token.
+
+ Returns:
+ `torch.FloatTensor`: The summary of the sequence hidden states.
+ """
+ if self.summary_type == "last":
+ output = hidden_states[:, -1]
+ elif self.summary_type == "first":
+ output = hidden_states[:, 0]
+ elif self.summary_type == "mean":
+ output = hidden_states.mean(dim=1)
+ elif self.summary_type == "cls_index":
+ if cls_index is None:
+ cls_index = torch.full_like(
+ hidden_states[..., :1, :],
+ hidden_states.shape[-2] - 1,
+ dtype=torch.long,
+ )
+ else:
+ cls_index = cls_index.unsqueeze(-1).unsqueeze(-1)
+ cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),))
+ # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
+ output = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, XX, hidden_size)
+ elif self.summary_type == "attn":
+ raise NotImplementedError
+
+ output = self.first_dropout(output)
+ output = self.summary(output)
+ output = self.activation(output)
+ output = self.last_dropout(output)
+
+ return output
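+
+
+# Hedged sketch: with the ElectraConfig defaults (summary_type="first", summary_use_proj=True,
+# summary_activation="gelu", summary_last_dropout=0.1) the summary is the first token's hidden
+# state, projected, gelu-activated and dropped out, i.e. one vector per sequence:
+# >>> summary = ElectraSequenceSummary(ElectraConfig())
+# >>> summary(torch.zeros(2, 7, 256)).shape
+# torch.Size([2, 256])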
+
+
+@auto_docstring(
+ custom_intro="""
+ ELECTRA Model transformer with a sequence classification/regression head on top (a linear layer on top of the
+ pooled output) e.g. for GLUE tasks.
+ """
+)
+class ElectraForSequenceClassification(ElectraPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.config = config
+ self.electra = ElectraModel(config)
+ self.classifier = ElectraClassificationHead(config)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple[torch.Tensor], SequenceClassifierOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ discriminator_hidden_states = self.electra(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output = discriminator_hidden_states[0]
+ logits = self.classifier(sequence_output)
+
+ loss = None
+ if labels is not None:
+ if self.config.problem_type is None:
+ if self.num_labels == 1:
+ self.config.problem_type = "regression"
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+ self.config.problem_type = "single_label_classification"
+ else:
+ self.config.problem_type = "multi_label_classification"
+
+ if self.config.problem_type == "regression":
+ loss_fct = MSELoss()
+ if self.num_labels == 1:
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
+ else:
+ loss = loss_fct(logits, labels)
+ elif self.config.problem_type == "single_label_classification":
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+ elif self.config.problem_type == "multi_label_classification":
+ loss_fct = BCEWithLogitsLoss()
+ loss = loss_fct(logits, labels)
+
+ if not return_dict:
+ output = (logits,) + discriminator_hidden_states[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return SequenceClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=discriminator_hidden_states.hidden_states,
+ attentions=discriminator_hidden_states.attentions,
+ )
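+
+
+# Hedged usage sketch with randomly initialized weights (labels chosen arbitrarily for illustration):
+# >>> import torch
+# >>> model = ElectraForSequenceClassification(ElectraConfig(num_labels=3))
+# >>> out = model(input_ids=torch.tensor([[1, 2, 3, 4]]), labels=torch.tensor([1]))
+# >>> out.logits.shape
+# torch.Size([1, 3])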
+
+
+@auto_docstring(
+ custom_intro="""
+ Electra model with a binary classification head on top as used during pretraining for identifying generated tokens.
+
+ It is recommended to load the discriminator checkpoint into that model.
+ """
+)
+class ElectraForPreTraining(ElectraPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.electra = ElectraModel(config)
+ self.discriminator_predictions = ElectraDiscriminatorPredictions(config)
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple[torch.Tensor], ElectraForPreTrainingOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the ELECTRA loss. Input should be a sequence of tokens (see `input_ids` docstring)
+ Indices should be in `[0, 1]`:
+
+ - 0 indicates the token is an original token,
+ - 1 indicates the token was replaced.
+
+ Examples:
+
+ ```python
+ >>> from transformers import ElectraForPreTraining, AutoTokenizer
+ >>> import torch
+
+ >>> discriminator = ElectraForPreTraining.from_pretrained("google/electra-base-discriminator")
+ >>> tokenizer = AutoTokenizer.from_pretrained("google/electra-base-discriminator")
+
+ >>> sentence = "The quick brown fox jumps over the lazy dog"
+ >>> fake_sentence = "The quick brown fox fake over the lazy dog"
+
+ >>> fake_tokens = tokenizer.tokenize(fake_sentence, add_special_tokens=True)
+ >>> fake_inputs = tokenizer.encode(fake_sentence, return_tensors="pt")
+ >>> discriminator_outputs = discriminator(fake_inputs)
+ >>> predictions = torch.round((torch.sign(discriminator_outputs[0]) + 1) / 2)
+
+ >>> fake_tokens
+ ['[CLS]', 'the', 'quick', 'brown', 'fox', 'fake', 'over', 'the', 'lazy', 'dog', '[SEP]']
+
+ >>> predictions.squeeze().tolist()
+ [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0]
+ ```"""
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ discriminator_hidden_states = self.electra(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ discriminator_sequence_output = discriminator_hidden_states[0]
+
+ logits = self.discriminator_predictions(discriminator_sequence_output)
+
+ loss = None
+ if labels is not None:
+ loss_fct = nn.BCEWithLogitsLoss()
+ if attention_mask is not None:
+ active_loss = attention_mask.view(-1, discriminator_sequence_output.shape[1]) == 1
+ active_logits = logits.view(-1, discriminator_sequence_output.shape[1])[active_loss]
+ active_labels = labels[active_loss]
+ loss = loss_fct(active_logits, active_labels.float())
+ else:
+ loss = loss_fct(logits.view(-1, discriminator_sequence_output.shape[1]), labels.float())
+
+ if not return_dict:
+ output = (logits,) + discriminator_hidden_states[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return ElectraForPreTrainingOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=discriminator_hidden_states.hidden_states,
+ attentions=discriminator_hidden_states.attentions,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ Electra model with a language modeling head on top.
+
+ Even though both the discriminator and generator may be loaded into this model, the generator is the only model of
+ the two to have been trained for the masked language modeling task.
+ """
+)
+class ElectraForMaskedLM(ElectraPreTrainedModel):
+ _tied_weights_keys = ["generator_lm_head.weight"]
+
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.electra = ElectraModel(config)
+ self.generator_predictions = ElectraGeneratorPredictions(config)
+
+ self.generator_lm_head = nn.Linear(config.embedding_size, config.vocab_size)
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_output_embeddings(self):
+ return self.generator_lm_head
+
+ def set_output_embeddings(self, word_embeddings):
+ self.generator_lm_head = word_embeddings
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple[torch.Tensor], MaskedLMOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
+ config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
+ loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ generator_hidden_states = self.electra(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ generator_sequence_output = generator_hidden_states[0]
+
+ prediction_scores = self.generator_predictions(generator_sequence_output)
+ prediction_scores = self.generator_lm_head(prediction_scores)
+
+ loss = None
+ # Masked language modeling softmax layer
+ if labels is not None:
+ loss_fct = nn.CrossEntropyLoss() # -100 index = padding token
+ loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
+
+ if not return_dict:
+ output = (prediction_scores,) + generator_hidden_states[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return MaskedLMOutput(
+ loss=loss,
+ logits=prediction_scores,
+ hidden_states=generator_hidden_states.hidden_states,
+ attentions=generator_hidden_states.attentions,
+ )
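+
+
+# Hedged usage sketch (the generator checkpoint name is an assumption, paralleling the
+# discriminator checkpoint used in the ElectraForPreTraining example above):
+# >>> from transformers import AutoTokenizer, ElectraForMaskedLM
+# >>> import torch
+# >>> tokenizer = AutoTokenizer.from_pretrained("google/electra-small-generator")
+# >>> model = ElectraForMaskedLM.from_pretrained("google/electra-small-generator")
+# >>> inputs = tokenizer(f"The capital of France is {tokenizer.mask_token}.", return_tensors="pt")
+# >>> logits = model(**inputs).logits
+# >>> mask_index = (inputs.input_ids == tokenizer.mask_token_id).nonzero(as_tuple=True)[1]
+# >>> tokenizer.decode(logits[0, mask_index].argmax(-1))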
+
+
+@auto_docstring(
+ custom_intro="""
+ Electra model with a token classification head on top.
+
+ Both the discriminator and generator may be loaded into this model.
+ """
+)
+class ElectraForTokenClassification(ElectraPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+
+ self.electra = ElectraModel(config)
+ classifier_dropout = (
+ config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
+ )
+ self.dropout = nn.Dropout(classifier_dropout)
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple[torch.Tensor], TokenClassifierOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ discriminator_hidden_states = self.electra(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ discriminator_sequence_output = discriminator_hidden_states[0]
+
+ discriminator_sequence_output = self.dropout(discriminator_sequence_output)
+ logits = self.classifier(discriminator_sequence_output)
+
+ loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+ if not return_dict:
+ output = (logits,) + discriminator_hidden_states[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return TokenClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=discriminator_hidden_states.hidden_states,
+ attentions=discriminator_hidden_states.attentions,
+ )
+
+
+@auto_docstring
+class ElectraForQuestionAnswering(ElectraPreTrainedModel):
+ config: ElectraConfig
+ base_model_prefix = "electra"
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+
+ self.electra = ElectraModel(config)
+ self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ start_positions: Optional[torch.Tensor] = None,
+ end_positions: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple[torch.Tensor], QuestionAnsweringModelOutput]:
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ discriminator_hidden_states = self.electra(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ )
+
+ sequence_output = discriminator_hidden_states[0]
+
+ logits = self.qa_outputs(sequence_output)
+ start_logits, end_logits = logits.split(1, dim=-1)
+ start_logits = start_logits.squeeze(-1).contiguous()
+ end_logits = end_logits.squeeze(-1).contiguous()
+
+ total_loss = None
+ if start_positions is not None and end_positions is not None:
+ # If we are on multi-GPU, split add a dimension
+ if len(start_positions.size()) > 1:
+ start_positions = start_positions.squeeze(-1)
+ if len(end_positions.size()) > 1:
+ end_positions = end_positions.squeeze(-1)
+ # sometimes the start/end positions are outside our model inputs; we ignore these terms
+ ignored_index = start_logits.size(1)
+ start_positions = start_positions.clamp(0, ignored_index)
+ end_positions = end_positions.clamp(0, ignored_index)
+
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
+ start_loss = loss_fct(start_logits, start_positions)
+ end_loss = loss_fct(end_logits, end_positions)
+ total_loss = (start_loss + end_loss) / 2
+
+ if not return_dict:
+ output = (
+ start_logits,
+ end_logits,
+ ) + discriminator_hidden_states[1:]
+ return ((total_loss,) + output) if total_loss is not None else output
+
+ return QuestionAnsweringModelOutput(
+ loss=total_loss,
+ start_logits=start_logits,
+ end_logits=end_logits,
+ hidden_states=discriminator_hidden_states.hidden_states,
+ attentions=discriminator_hidden_states.attentions,
+ )
+
+
+@auto_docstring
+class ElectraForMultipleChoice(ElectraPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.electra = ElectraModel(config)
+ self.sequence_summary = ElectraSequenceSummary(config)
+ self.classifier = nn.Linear(config.hidden_size, 1)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple[torch.Tensor], MultipleChoiceModelOutput]:
+ r"""
+ input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ token_type_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
+ 1]`:
+
+ - 0 corresponds to a *sentence A* token,
+ - 1 corresponds to a *sentence B* token.
+
+ [What are token type IDs?](../glossary#token-type-ids)
+ position_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+ config.max_position_embeddings - 1]`.
+
+ [What are position IDs?](../glossary#position-ids)
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+ model's internal embedding lookup matrix.
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
+ num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
+ `input_ids` above)
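+
+ Example (a minimal sketch; the choice-scoring head of the base discriminator checkpoint below is randomly
+ initialized, so the scores are only meaningful after fine-tuning on a multiple-choice task such as SWAG):
+
+ ```python
+ >>> from transformers import AutoTokenizer, ElectraForMultipleChoice
+ >>> import torch
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("google/electra-small-discriminator")
+ >>> model = ElectraForMultipleChoice.from_pretrained("google/electra-small-discriminator")
+
+ >>> prompt = "The weather today is"
+ >>> choices = ["sunny and warm.", "a kind of fruit."]
+ >>> encoding = tokenizer([prompt, prompt], choices, return_tensors="pt", padding=True)
+ >>> inputs = {k: v.unsqueeze(0) for k, v in encoding.items()}  # add the num_choices dimension
+
+ >>> with torch.no_grad():
+ ...     logits = model(**inputs).logits  # shape: (batch_size, num_choices)
+ ```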
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
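+
+ # Inputs arrive as (batch_size, num_choices, seq_len); flatten them to (batch_size * num_choices, seq_len)
+ # so that each choice is encoded as an independent sequence, then reshape the per-choice scores back below.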
+ num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
+
+ input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
+ attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
+ token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
+ position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
+ inputs_embeds = (
+ inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
+ if inputs_embeds is not None
+ else None
+ )
+
+ discriminator_hidden_states = self.electra(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output = discriminator_hidden_states[0]
+
+ pooled_output = self.sequence_summary(sequence_output)
+ logits = self.classifier(pooled_output)
+ reshaped_logits = logits.view(-1, num_choices)
+
+ loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(reshaped_logits, labels)
+
+ if not return_dict:
+ output = (reshaped_logits,) + discriminator_hidden_states[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return MultipleChoiceModelOutput(
+ loss=loss,
+ logits=reshaped_logits,
+ hidden_states=discriminator_hidden_states.hidden_states,
+ attentions=discriminator_hidden_states.attentions,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ ELECTRA Model with a `language modeling` head on top for CLM fine-tuning.
+ """
+)
+class ElectraForCausalLM(ElectraPreTrainedModel, GenerationMixin):
+ _tied_weights_keys = ["generator_lm_head.weight"]
+
+ def __init__(self, config):
+ super().__init__(config)
+
+ if not config.is_decoder:
+ logger.warning("If you want to use `ElectraForCausalLM` as a standalone, add `is_decoder=True.`")
+
+ self.electra = ElectraModel(config)
+ self.generator_predictions = ElectraGeneratorPredictions(config)
+ self.generator_lm_head = nn.Linear(config.embedding_size, config.vocab_size)
+
+ self.init_weights()
+
+ def get_output_embeddings(self):
+ return self.generator_lm_head
+
+ def set_output_embeddings(self, new_embeddings):
+ self.generator_lm_head = new_embeddings
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ past_key_values: Optional[Cache] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ **kwargs,
+ ) -> Union[tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
+ `[-100, 0, ..., config.vocab_size]` (see the `input_ids` docstring). Tokens with indices set to `-100` are
+ ignored (masked); the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, ElectraForCausalLM, ElectraConfig
+ >>> import torch
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("google/electra-base-generator")
+ >>> config = ElectraConfig.from_pretrained("google/electra-base-generator")
+ >>> config.is_decoder = True
+ >>> model = ElectraForCausalLM.from_pretrained("google/electra-base-generator", config=config)
+
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
+ >>> outputs = model(**inputs)
+
+ >>> prediction_logits = outputs.logits
+ ```"""
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ if labels is not None:
+ use_cache = False
+
+ outputs = self.electra(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output = outputs[0]
+ prediction_scores = self.generator_lm_head(self.generator_predictions(sequence_output))
+
+ lm_loss = None
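+ # `self.loss_function` (the shared causal-LM loss helper) is expected to shift logits and labels internally
+ # for next-token prediction, so `labels` can be passed unshifted here.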
+ if labels is not None:
+ lm_loss = self.loss_function(
+ prediction_scores,
+ labels,
+ vocab_size=self.config.vocab_size,
+ **kwargs,
+ )
+
+ if not return_dict:
+ output = (prediction_scores,) + outputs[1:]
+ return ((lm_loss,) + output) if lm_loss is not None else output
+
+ return CausalLMOutputWithCrossAttentions(
+ loss=lm_loss,
+ logits=prediction_scores,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ cross_attentions=outputs.cross_attentions,
+ )
+
+
+__all__ = [
+ "ElectraForCausalLM",
+ "ElectraForMaskedLM",
+ "ElectraForMultipleChoice",
+ "ElectraForPreTraining",
+ "ElectraForQuestionAnswering",
+ "ElectraForSequenceClassification",
+ "ElectraForTokenClassification",
+ "ElectraModel",
+ "ElectraPreTrainedModel",
+ "load_tf_weights_in_electra",
+]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/electra/modeling_flax_electra.py b/venv/lib/python3.13/site-packages/transformers/models/electra/modeling_flax_electra.py
new file mode 100644
index 0000000000000000000000000000000000000000..14d845476d62f9defb2de4392742037762fb959f
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/electra/modeling_flax_electra.py
@@ -0,0 +1,1614 @@
+# coding=utf-8
+# Copyright 2021 The Google Flax Team Authors and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Callable, Optional
+
+import flax
+import flax.linen as nn
+import jax
+import jax.numpy as jnp
+import numpy as np
+from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
+from flax.linen import combine_masks, make_causal_mask
+from flax.linen import partitioning as nn_partitioning
+from flax.linen.attention import dot_product_attention_weights
+from flax.traverse_util import flatten_dict, unflatten_dict
+from jax import lax
+
+from ...modeling_flax_outputs import (
+ FlaxBaseModelOutput,
+ FlaxBaseModelOutputWithPastAndCrossAttentions,
+ FlaxCausalLMOutputWithCrossAttentions,
+ FlaxMaskedLMOutput,
+ FlaxMultipleChoiceModelOutput,
+ FlaxQuestionAnsweringModelOutput,
+ FlaxSequenceClassifierOutput,
+ FlaxTokenClassifierOutput,
+)
+from ...modeling_flax_utils import (
+ ACT2FN,
+ FlaxPreTrainedModel,
+ append_call_sample_docstring,
+ append_replace_return_docstrings,
+ overwrite_call_docstring,
+)
+from ...utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, logging
+from .configuration_electra import ElectraConfig
+
+
+logger = logging.get_logger(__name__)
+
+_CHECKPOINT_FOR_DOC = "google/electra-small-discriminator"
+_CONFIG_FOR_DOC = "ElectraConfig"
+
+remat = nn_partitioning.remat
+
+
+@flax.struct.dataclass
+class FlaxElectraForPreTrainingOutput(ModelOutput):
+ """
+ Output type of [`FlaxElectraForPreTraining`].
+
+ Args:
+ logits (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
+ Prediction scores of the discriminator head (one replaced-token-detection score per token, before the sigmoid).
+ hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
+ `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ logits: jnp.ndarray = None
+ hidden_states: Optional[tuple[jnp.ndarray]] = None
+ attentions: Optional[tuple[jnp.ndarray]] = None
+
+
+ELECTRA_START_DOCSTRING = r"""
+
+ This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its models (such as downloading, saving and converting weights from PyTorch models).
+
+ This model is also a Flax Linen
+ [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a
+ regular Flax Module and refer to the Flax documentation for all matters related to general usage and behavior.
+
+ Finally, this model supports inherent JAX features such as:
+
+ - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
+ - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
+ - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
+ - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
+
+ Parameters:
+ config ([`ElectraConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+ELECTRA_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`numpy.ndarray` of shape `({0})`):
+ Indices of input sequence tokens in the vocabulary.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`numpy.ndarray` of shape `({0})`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ token_type_ids (`numpy.ndarray` of shape `({0})`, *optional*):
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
+ 1]`:
+
+ - 0 corresponds to a *sentence A* token,
+ - 1 corresponds to a *sentence B* token.
+
+ [What are token type IDs?](../glossary#token-type-ids)
+ position_ids (`numpy.ndarray` of shape `({0})`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+ config.max_position_embeddings - 1]`.
+ head_mask (`numpy.ndarray` of shape `({0})`, *optional*):
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+
+"""
+
+
+class FlaxElectraEmbeddings(nn.Module):
+ """Construct the embeddings from word, position and token_type embeddings."""
+
+ config: ElectraConfig
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
+
+ def setup(self):
+ self.word_embeddings = nn.Embed(
+ self.config.vocab_size,
+ self.config.embedding_size,
+ embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
+ )
+ self.position_embeddings = nn.Embed(
+ self.config.max_position_embeddings,
+ self.config.embedding_size,
+ embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
+ )
+ self.token_type_embeddings = nn.Embed(
+ self.config.type_vocab_size,
+ self.config.embedding_size,
+ embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
+ )
+ self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
+ self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
+
+ # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertEmbeddings.__call__
+ def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True):
+ # Embed
+ inputs_embeds = self.word_embeddings(input_ids.astype("i4"))
+ position_embeds = self.position_embeddings(position_ids.astype("i4"))
+ token_type_embeddings = self.token_type_embeddings(token_type_ids.astype("i4"))
+
+ # Sum all embeddings
+ hidden_states = inputs_embeds + token_type_embeddings + position_embeds
+
+ # Layer Norm
+ hidden_states = self.LayerNorm(hidden_states)
+ hidden_states = self.dropout(hidden_states, deterministic=deterministic)
+ return hidden_states
+
+
+# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertSelfAttention with Bert->Electra
+class FlaxElectraSelfAttention(nn.Module):
+ config: ElectraConfig
+ causal: bool = False
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
+
+ def setup(self):
+ self.head_dim = self.config.hidden_size // self.config.num_attention_heads
+ if self.config.hidden_size % self.config.num_attention_heads != 0:
+ raise ValueError(
+ "`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads` "
+ " : {self.config.num_attention_heads}"
+ )
+
+ self.query = nn.Dense(
+ self.config.hidden_size,
+ dtype=self.dtype,
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
+ )
+ self.key = nn.Dense(
+ self.config.hidden_size,
+ dtype=self.dtype,
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
+ )
+ self.value = nn.Dense(
+ self.config.hidden_size,
+ dtype=self.dtype,
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
+ )
+
+ if self.causal:
+ self.causal_mask = make_causal_mask(
+ jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool"
+ )
+
+ def _split_heads(self, hidden_states):
+ return hidden_states.reshape(hidden_states.shape[:2] + (self.config.num_attention_heads, self.head_dim))
+
+ def _merge_heads(self, hidden_states):
+ return hidden_states.reshape(hidden_states.shape[:2] + (self.config.hidden_size,))
+
+ @nn.compact
+ # Copied from transformers.models.bart.modeling_flax_bart.FlaxBartAttention._concatenate_to_cache
+ def _concatenate_to_cache(self, key, value, query, attention_mask):
+ """
+ This function takes projected key, value states from a single input token and concatenates the states to cached
+ states from previous steps. This function is slightly adapted from the official Flax repository:
+ https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
+ """
+ # detect if we're initializing by absence of existing cache data.
+ is_initialized = self.has_variable("cache", "cached_key")
+ cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
+ cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
+ cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
+
+ if is_initialized:
+ *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
+ # update key, value caches with our new 1d spatial slices
+ cur_index = cache_index.value
+ indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
+ key = lax.dynamic_update_slice(cached_key.value, key, indices)
+ value = lax.dynamic_update_slice(cached_value.value, value, indices)
+ cached_key.value = key
+ cached_value.value = value
+ num_updated_cache_vectors = query.shape[1]
+ cache_index.value = cache_index.value + num_updated_cache_vectors
+ # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements.
+ pad_mask = jnp.broadcast_to(
+ jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
+ tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
+ )
+ attention_mask = combine_masks(pad_mask, attention_mask)
+ return key, value, attention_mask
+
+ def __call__(
+ self,
+ hidden_states,
+ attention_mask,
+ layer_head_mask,
+ key_value_states: Optional[jnp.ndarray] = None,
+ init_cache: bool = False,
+ deterministic=True,
+ output_attentions: bool = False,
+ ):
+ # if key_value_states are provided this layer is used as a cross-attention layer
+ # for the decoder
+ is_cross_attention = key_value_states is not None
+ batch_size = hidden_states.shape[0]
+
+ # get query proj
+ query_states = self.query(hidden_states)
+ # get key, value proj
+ if is_cross_attention:
+ # cross_attentions
+ key_states = self.key(key_value_states)
+ value_states = self.value(key_value_states)
+ else:
+ # self_attention
+ key_states = self.key(hidden_states)
+ value_states = self.value(hidden_states)
+
+ query_states = self._split_heads(query_states)
+ key_states = self._split_heads(key_states)
+ value_states = self._split_heads(value_states)
+
+ # handle cache prepare causal attention mask
+ if self.causal:
+ query_length, key_length = query_states.shape[1], key_states.shape[1]
+ if self.has_variable("cache", "cached_key"):
+ mask_shift = self.variables["cache"]["cache_index"]
+ max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
+ causal_mask = lax.dynamic_slice(
+ self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
+ )
+ else:
+ causal_mask = self.causal_mask[:, :, :query_length, :key_length]
+ causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])
+
+ # combine masks if needed
+ if attention_mask is not None and self.causal:
+ attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
+ attention_mask = combine_masks(attention_mask, causal_mask)
+ elif self.causal:
+ attention_mask = causal_mask
+ elif attention_mask is not None:
+ attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
+
+ # During fast autoregressive decoding, we feed one position at a time,
+ # and cache the keys and values step by step.
+ if self.causal and (self.has_variable("cache", "cached_key") or init_cache):
+ key_states, value_states, attention_mask = self._concatenate_to_cache(
+ key_states, value_states, query_states, attention_mask
+ )
+
+ # Convert the boolean attention mask to an attention bias.
+ if attention_mask is not None:
+ # attention mask in the form of attention bias
+ attention_bias = lax.select(
+ attention_mask > 0,
+ jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
+ jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
+ )
+ else:
+ attention_bias = None
+
+ dropout_rng = None
+ if not deterministic and self.config.attention_probs_dropout_prob > 0.0:
+ dropout_rng = self.make_rng("dropout")
+
+ attn_weights = dot_product_attention_weights(
+ query_states,
+ key_states,
+ bias=attention_bias,
+ dropout_rng=dropout_rng,
+ dropout_rate=self.config.attention_probs_dropout_prob,
+ broadcast_dropout=True,
+ deterministic=deterministic,
+ dtype=self.dtype,
+ precision=None,
+ )
+
+ # Mask heads if we want to
+ if layer_head_mask is not None:
+ attn_weights = jnp.einsum("...hqk,h->...hqk", attn_weights, layer_head_mask)
+
+ attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
+ attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,))
+
+ outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
+ return outputs
+
+
+# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertSelfOutput with Bert->Electra
+class FlaxElectraSelfOutput(nn.Module):
+ config: ElectraConfig
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
+
+ def setup(self):
+ self.dense = nn.Dense(
+ self.config.hidden_size,
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
+ dtype=self.dtype,
+ )
+ self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
+ self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
+
+ def __call__(self, hidden_states, input_tensor, deterministic: bool = True):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states, deterministic=deterministic)
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
+ return hidden_states
+
+
+# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertAttention with Bert->Electra
+class FlaxElectraAttention(nn.Module):
+ config: ElectraConfig
+ causal: bool = False
+ dtype: jnp.dtype = jnp.float32
+
+ def setup(self):
+ self.self = FlaxElectraSelfAttention(self.config, causal=self.causal, dtype=self.dtype)
+ self.output = FlaxElectraSelfOutput(self.config, dtype=self.dtype)
+
+ def __call__(
+ self,
+ hidden_states,
+ attention_mask,
+ layer_head_mask,
+ key_value_states=None,
+ init_cache=False,
+ deterministic=True,
+ output_attentions: bool = False,
+ ):
+ # Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length)
+ # FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable
+ # with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length)
+ attn_outputs = self.self(
+ hidden_states,
+ attention_mask,
+ layer_head_mask=layer_head_mask,
+ key_value_states=key_value_states,
+ init_cache=init_cache,
+ deterministic=deterministic,
+ output_attentions=output_attentions,
+ )
+ attn_output = attn_outputs[0]
+ hidden_states = self.output(attn_output, hidden_states, deterministic=deterministic)
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (attn_outputs[1],)
+
+ return outputs
+
+
+# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertIntermediate with Bert->Electra
+class FlaxElectraIntermediate(nn.Module):
+ config: ElectraConfig
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
+
+ def setup(self):
+ self.dense = nn.Dense(
+ self.config.intermediate_size,
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
+ dtype=self.dtype,
+ )
+ self.activation = ACT2FN[self.config.hidden_act]
+
+ def __call__(self, hidden_states):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.activation(hidden_states)
+ return hidden_states
+
+
+# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertOutput with Bert->Electra
+class FlaxElectraOutput(nn.Module):
+ config: ElectraConfig
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
+
+ def setup(self):
+ self.dense = nn.Dense(
+ self.config.hidden_size,
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
+ dtype=self.dtype,
+ )
+ self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
+ self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
+
+ def __call__(self, hidden_states, attention_output, deterministic: bool = True):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states, deterministic=deterministic)
+ hidden_states = self.LayerNorm(hidden_states + attention_output)
+ return hidden_states
+
+
+# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLayer with Bert->Electra
+class FlaxElectraLayer(nn.Module):
+ config: ElectraConfig
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
+
+ def setup(self):
+ self.attention = FlaxElectraAttention(self.config, causal=self.config.is_decoder, dtype=self.dtype)
+ self.intermediate = FlaxElectraIntermediate(self.config, dtype=self.dtype)
+ self.output = FlaxElectraOutput(self.config, dtype=self.dtype)
+ if self.config.add_cross_attention:
+ self.crossattention = FlaxElectraAttention(self.config, causal=False, dtype=self.dtype)
+
+ def __call__(
+ self,
+ hidden_states,
+ attention_mask,
+ layer_head_mask,
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
+ init_cache: bool = False,
+ deterministic: bool = True,
+ output_attentions: bool = False,
+ ):
+ # Self Attention
+ attention_outputs = self.attention(
+ hidden_states,
+ attention_mask,
+ layer_head_mask=layer_head_mask,
+ init_cache=init_cache,
+ deterministic=deterministic,
+ output_attentions=output_attentions,
+ )
+ attention_output = attention_outputs[0]
+
+ # Cross-Attention Block
+ if encoder_hidden_states is not None:
+ cross_attention_outputs = self.crossattention(
+ attention_output,
+ attention_mask=encoder_attention_mask,
+ layer_head_mask=layer_head_mask,
+ key_value_states=encoder_hidden_states,
+ deterministic=deterministic,
+ output_attentions=output_attentions,
+ )
+ attention_output = cross_attention_outputs[0]
+
+ hidden_states = self.intermediate(attention_output)
+ hidden_states = self.output(hidden_states, attention_output, deterministic=deterministic)
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (attention_outputs[1],)
+ if encoder_hidden_states is not None:
+ outputs += (cross_attention_outputs[1],)
+ return outputs
+
+
+# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLayerCollection with Bert->Electra
+class FlaxElectraLayerCollection(nn.Module):
+ config: ElectraConfig
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
+ gradient_checkpointing: bool = False
+
+ def setup(self):
+ if self.gradient_checkpointing:
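+ # remat (gradient checkpointing) recomputes layer activations during the backward pass to save memory;
+ # argument positions 5-7 (init_cache, deterministic, output_attentions) are marked static because they are
+ # plain Python bools that drive control flow inside the layer.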
+ FlaxElectraCheckpointLayer = remat(FlaxElectraLayer, static_argnums=(5, 6, 7))
+ self.layers = [
+ FlaxElectraCheckpointLayer(self.config, name=str(i), dtype=self.dtype)
+ for i in range(self.config.num_hidden_layers)
+ ]
+ else:
+ self.layers = [
+ FlaxElectraLayer(self.config, name=str(i), dtype=self.dtype)
+ for i in range(self.config.num_hidden_layers)
+ ]
+
+ def __call__(
+ self,
+ hidden_states,
+ attention_mask,
+ head_mask,
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
+ init_cache: bool = False,
+ deterministic: bool = True,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ return_dict: bool = True,
+ ):
+ all_attentions = () if output_attentions else None
+ all_hidden_states = () if output_hidden_states else None
+ all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
+
+ # Check if head_mask has a correct number of layers specified if desired
+ if head_mask is not None:
+ if head_mask.shape[0] != (len(self.layers)):
+ raise ValueError(
+ f"The head_mask should be specified for {len(self.layers)} layers, but it is for "
+ f" {head_mask.shape[0]}."
+ )
+
+ for i, layer in enumerate(self.layers):
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ layer_outputs = layer(
+ hidden_states,
+ attention_mask,
+ head_mask[i] if head_mask is not None else None,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ init_cache,
+ deterministic,
+ output_attentions,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_attentions += (layer_outputs[1],)
+
+ if encoder_hidden_states is not None:
+ all_cross_attentions += (layer_outputs[2],)
+
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ outputs = (hidden_states, all_hidden_states, all_attentions, all_cross_attentions)
+
+ if not return_dict:
+ return tuple(v for v in outputs if v is not None)
+
+ return FlaxBaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ hidden_states=all_hidden_states,
+ attentions=all_attentions,
+ cross_attentions=all_cross_attentions,
+ )
+
+
+# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertEncoder with Bert->Electra
+class FlaxElectraEncoder(nn.Module):
+ config: ElectraConfig
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
+ gradient_checkpointing: bool = False
+
+ def setup(self):
+ self.layer = FlaxElectraLayerCollection(
+ self.config,
+ dtype=self.dtype,
+ gradient_checkpointing=self.gradient_checkpointing,
+ )
+
+ def __call__(
+ self,
+ hidden_states,
+ attention_mask,
+ head_mask,
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
+ init_cache: bool = False,
+ deterministic: bool = True,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ return_dict: bool = True,
+ ):
+ return self.layer(
+ hidden_states,
+ attention_mask,
+ head_mask=head_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ init_cache=init_cache,
+ deterministic=deterministic,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+
+class FlaxElectraGeneratorPredictions(nn.Module):
+ config: ElectraConfig
+ dtype: jnp.dtype = jnp.float32
+
+ def setup(self):
+ self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
+ self.dense = nn.Dense(self.config.embedding_size, dtype=self.dtype)
+
+ def __call__(self, hidden_states):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = ACT2FN[self.config.hidden_act](hidden_states)
+ hidden_states = self.LayerNorm(hidden_states)
+ return hidden_states
+
+
+class FlaxElectraDiscriminatorPredictions(nn.Module):
+ """Prediction module for the discriminator, made up of two dense layers."""
+
+ config: ElectraConfig
+ dtype: jnp.dtype = jnp.float32
+
+ def setup(self):
+ self.dense = nn.Dense(self.config.hidden_size, dtype=self.dtype)
+ self.dense_prediction = nn.Dense(1, dtype=self.dtype)
+
+ def __call__(self, hidden_states):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = ACT2FN[self.config.hidden_act](hidden_states)
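+ # dense_prediction emits a single logit per token: the replaced-token-detection score that ELECTRA is pretrained on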
+ hidden_states = self.dense_prediction(hidden_states).squeeze(-1)
+ return hidden_states
+
+
+class FlaxElectraPreTrainedModel(FlaxPreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = ElectraConfig
+ base_model_prefix = "electra"
+ module_class: nn.Module = None
+
+ def __init__(
+ self,
+ config: ElectraConfig,
+ input_shape: tuple = (1, 1),
+ seed: int = 0,
+ dtype: jnp.dtype = jnp.float32,
+ _do_init: bool = True,
+ gradient_checkpointing: bool = False,
+ **kwargs,
+ ):
+ module = self.module_class(config=config, dtype=dtype, gradient_checkpointing=gradient_checkpointing, **kwargs)
+ super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
+
+ # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.enable_gradient_checkpointing
+ def enable_gradient_checkpointing(self):
+ self._module = self.module_class(
+ config=self.config,
+ dtype=self.dtype,
+ gradient_checkpointing=True,
+ )
+
+ # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.init_weights
+ def init_weights(self, rng: jax.random.PRNGKey, input_shape: tuple, params: FrozenDict = None) -> FrozenDict:
+ # init input tensors
+ input_ids = jnp.zeros(input_shape, dtype="i4")
+ token_type_ids = jnp.zeros_like(input_ids)
+ position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)
+ attention_mask = jnp.ones_like(input_ids)
+ head_mask = jnp.ones((self.config.num_hidden_layers, self.config.num_attention_heads))
+
+ params_rng, dropout_rng = jax.random.split(rng)
+ rngs = {"params": params_rng, "dropout": dropout_rng}
+
+ if self.config.add_cross_attention:
+ encoder_hidden_states = jnp.zeros(input_shape + (self.config.hidden_size,))
+ encoder_attention_mask = attention_mask
+ module_init_outputs = self.module.init(
+ rngs,
+ input_ids,
+ attention_mask,
+ token_type_ids,
+ position_ids,
+ head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ return_dict=False,
+ )
+ else:
+ module_init_outputs = self.module.init(
+ rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False
+ )
+
+ random_params = module_init_outputs["params"]
+
+ if params is not None:
+ random_params = flatten_dict(unfreeze(random_params))
+ params = flatten_dict(unfreeze(params))
+ for missing_key in self._missing_keys:
+ params[missing_key] = random_params[missing_key]
+ self._missing_keys = set()
+ return freeze(unflatten_dict(params))
+ else:
+ return random_params
+
+ # Copied from transformers.models.bart.modeling_flax_bart.FlaxBartDecoderPreTrainedModel.init_cache
+ def init_cache(self, batch_size, max_length):
+ r"""
+ Args:
+ batch_size (`int`):
+ batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
+ max_length (`int`):
+ maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
+ cache.
+ """
+ # init input variables to retrieve cache
+ input_ids = jnp.ones((batch_size, max_length), dtype="i4")
+ attention_mask = jnp.ones_like(input_ids, dtype="i4")
+ position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
+
+ init_variables = self.module.init(
+ jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True
+ )
+ return unfreeze(init_variables["cache"])
+
+ @add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ def __call__(
+ self,
+ input_ids,
+ attention_mask=None,
+ token_type_ids=None,
+ position_ids=None,
+ head_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ params: Optional[dict] = None,
+ dropout_rng: jax.random.PRNGKey = None,
+ train: bool = False,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ past_key_values: Optional[dict] = None,
+ ):
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
+
+ # init input tensors if not passed
+ if token_type_ids is None:
+ token_type_ids = jnp.ones_like(input_ids)
+
+ if position_ids is None:
+ position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
+
+ if attention_mask is None:
+ attention_mask = jnp.ones_like(input_ids)
+
+ if head_mask is None:
+ head_mask = jnp.ones((self.config.num_hidden_layers, self.config.num_attention_heads))
+
+ # Handle any PRNG if needed
+ rngs = {}
+ if dropout_rng is not None:
+ rngs["dropout"] = dropout_rng
+
+ inputs = {"params": params or self.params}
+
+ if self.config.add_cross_attention:
+ # if past_key_values are passed, the cache is already initialized; a private flag `init_cache` has to be
+ # passed down to ensure the cache is used. The cache also has to be marked as mutable so that it can be
+ # changed by the FlaxElectraAttention module
+ if past_key_values:
+ inputs["cache"] = past_key_values
+ mutable = ["cache"]
+ else:
+ mutable = False
+
+ outputs = self.module.apply(
+ inputs,
+ jnp.array(input_ids, dtype="i4"),
+ jnp.array(attention_mask, dtype="i4"),
+ token_type_ids=jnp.array(token_type_ids, dtype="i4"),
+ position_ids=jnp.array(position_ids, dtype="i4"),
+ head_mask=jnp.array(head_mask, dtype="i4"),
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ deterministic=not train,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ rngs=rngs,
+ mutable=mutable,
+ )
+
+ # add updated cache to model output
+ if past_key_values is not None and return_dict:
+ outputs, past_key_values = outputs
+ outputs["past_key_values"] = unfreeze(past_key_values["cache"])
+ return outputs
+ elif past_key_values is not None and not return_dict:
+ outputs, past_key_values = outputs
+ outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:]
+
+ else:
+ outputs = self.module.apply(
+ inputs,
+ jnp.array(input_ids, dtype="i4"),
+ jnp.array(attention_mask, dtype="i4"),
+ token_type_ids=jnp.array(token_type_ids, dtype="i4"),
+ position_ids=jnp.array(position_ids, dtype="i4"),
+ head_mask=jnp.array(head_mask, dtype="i4"),
+ deterministic=not train,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ rngs=rngs,
+ )
+
+ return outputs
+
+
+class FlaxElectraModule(nn.Module):
+ config: ElectraConfig
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
+ gradient_checkpointing: bool = False
+
+ def setup(self):
+ self.embeddings = FlaxElectraEmbeddings(self.config, dtype=self.dtype)
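+ # ELECTRA decouples embedding_size from hidden_size (the small models use narrower embeddings), so the
+ # embedding output is projected up to the encoder width whenever the two sizes differ.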
+ if self.config.embedding_size != self.config.hidden_size:
+ self.embeddings_project = nn.Dense(self.config.hidden_size, dtype=self.dtype)
+ self.encoder = FlaxElectraEncoder(
+ self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
+ )
+
+ def __call__(
+ self,
+ input_ids,
+ attention_mask,
+ token_type_ids,
+ position_ids,
+ head_mask: Optional[np.ndarray] = None,
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
+ init_cache: bool = False,
+ deterministic: bool = True,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ return_dict: bool = True,
+ ):
+ embeddings = self.embeddings(
+ input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic
+ )
+ if hasattr(self, "embeddings_project"):
+ embeddings = self.embeddings_project(embeddings)
+
+ return self.encoder(
+ embeddings,
+ attention_mask,
+ head_mask=head_mask,
+ deterministic=deterministic,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ init_cache=init_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+
+@add_start_docstrings(
+ "The bare Electra Model transformer outputting raw hidden-states without any specific head on top.",
+ ELECTRA_START_DOCSTRING,
+)
+class FlaxElectraModel(FlaxElectraPreTrainedModel):
+ module_class = FlaxElectraModule
+
+
+append_call_sample_docstring(FlaxElectraModel, _CHECKPOINT_FOR_DOC, FlaxBaseModelOutput, _CONFIG_FOR_DOC)
+
+
+class FlaxElectraTiedDense(nn.Module):
+ embedding_size: int
+ dtype: jnp.dtype = jnp.float32
+ precision = None
+ bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros
+
+ def setup(self):
+ self.bias = self.param("bias", self.bias_init, (self.embedding_size,))
+
+ def __call__(self, x, kernel):
+ x = jnp.asarray(x, self.dtype)
+ kernel = jnp.asarray(kernel, self.dtype)
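+ # The dimension_numbers below contract the last axis of `x` with the first axis of `kernel`, i.e. a plain
+ # `x @ kernel` matmul; callers pass the transposed word-embedding table as `kernel`, which keeps the LM head
+ # tied to the input embeddings.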
+ y = lax.dot_general(
+ x,
+ kernel,
+ (((x.ndim - 1,), (0,)), ((), ())),
+ precision=self.precision,
+ )
+ bias = jnp.asarray(self.bias, self.dtype)
+ return y + bias
+
+
+class FlaxElectraForMaskedLMModule(nn.Module):
+ config: ElectraConfig
+ dtype: jnp.dtype = jnp.float32
+ gradient_checkpointing: bool = False
+
+ def setup(self):
+ self.electra = FlaxElectraModule(
+ config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
+ )
+ self.generator_predictions = FlaxElectraGeneratorPredictions(config=self.config, dtype=self.dtype)
+ if self.config.tie_word_embeddings:
+ self.generator_lm_head = FlaxElectraTiedDense(self.config.vocab_size, dtype=self.dtype)
+ else:
+ self.generator_lm_head = nn.Dense(self.config.vocab_size, dtype=self.dtype)
+
+ def __call__(
+ self,
+ input_ids,
+ attention_mask=None,
+ token_type_ids=None,
+ position_ids=None,
+ head_mask=None,
+ deterministic: bool = True,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ return_dict: bool = True,
+ ):
+ outputs = self.electra(
+ input_ids,
+ attention_mask,
+ token_type_ids,
+ position_ids,
+ head_mask,
+ deterministic=deterministic,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ hidden_states = outputs[0]
+ prediction_scores = self.generator_predictions(hidden_states)
+
+ if self.config.tie_word_embeddings:
+ shared_embedding = self.electra.variables["params"]["embeddings"]["word_embeddings"]["embedding"]
+ prediction_scores = self.generator_lm_head(prediction_scores, shared_embedding.T)
+ else:
+ prediction_scores = self.generator_lm_head(prediction_scores)
+
+ if not return_dict:
+ return (prediction_scores,) + outputs[1:]
+
+ return FlaxMaskedLMOutput(
+ logits=prediction_scores,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@add_start_docstrings("""Electra Model with a `language modeling` head on top.""", ELECTRA_START_DOCSTRING)
+class FlaxElectraForMaskedLM(FlaxElectraPreTrainedModel):
+ module_class = FlaxElectraForMaskedLMModule
+
+
+append_call_sample_docstring(FlaxElectraForMaskedLM, _CHECKPOINT_FOR_DOC, FlaxMaskedLMOutput, _CONFIG_FOR_DOC)
+
+
+class FlaxElectraForPreTrainingModule(nn.Module):
+ config: ElectraConfig
+ dtype: jnp.dtype = jnp.float32
+ gradient_checkpointing: bool = False
+
+ def setup(self):
+ self.electra = FlaxElectraModule(
+ config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
+ )
+ self.discriminator_predictions = FlaxElectraDiscriminatorPredictions(config=self.config, dtype=self.dtype)
+
+ def __call__(
+ self,
+ input_ids,
+ attention_mask=None,
+ token_type_ids=None,
+ position_ids=None,
+ head_mask=None,
+ deterministic: bool = True,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ return_dict: bool = True,
+ ):
+ # Model
+ outputs = self.electra(
+ input_ids,
+ attention_mask,
+ token_type_ids,
+ position_ids,
+ head_mask,
+ deterministic=deterministic,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ hidden_states = outputs[0]
+
+ logits = self.discriminator_predictions(hidden_states)
+
+ if not return_dict:
+ return (logits,) + outputs[1:]
+
+ return FlaxElectraForPreTrainingOutput(
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@add_start_docstrings(
+ """
+ Electra model with a binary classification head on top as used during pretraining for identifying generated tokens.
+
+ It is recommended to load the discriminator checkpoint into that model.
+ """,
+ ELECTRA_START_DOCSTRING,
+)
+class FlaxElectraForPreTraining(FlaxElectraPreTrainedModel):
+ module_class = FlaxElectraForPreTrainingModule
+
+
+FLAX_ELECTRA_FOR_PRETRAINING_DOCSTRING = """
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, FlaxElectraForPreTraining
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("google/electra-small-discriminator")
+ >>> model = FlaxElectraForPreTraining.from_pretrained("google/electra-small-discriminator")
+
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="np")
+ >>> outputs = model(**inputs)
+
+ >>> prediction_logits = outputs.logits
+ ```
+"""
+
+overwrite_call_docstring(
+ FlaxElectraForPreTraining,
+ ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length") + FLAX_ELECTRA_FOR_PRETRAINING_DOCSTRING,
+)
+append_replace_return_docstrings(
+ FlaxElectraForPreTraining, output_type=FlaxElectraForPreTrainingOutput, config_class=_CONFIG_FOR_DOC
+)
+
+
+class FlaxElectraForTokenClassificationModule(nn.Module):
+ config: ElectraConfig
+ dtype: jnp.dtype = jnp.float32
+ gradient_checkpointing: bool = False
+
+ def setup(self):
+ self.electra = FlaxElectraModule(
+ config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
+ )
+ classifier_dropout = (
+ self.config.classifier_dropout
+ if self.config.classifier_dropout is not None
+ else self.config.hidden_dropout_prob
+ )
+ self.dropout = nn.Dropout(classifier_dropout)
+ self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype)
+
+ def __call__(
+ self,
+ input_ids,
+ attention_mask=None,
+ token_type_ids=None,
+ position_ids=None,
+ head_mask=None,
+ deterministic: bool = True,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ return_dict: bool = True,
+ ):
+ # Model
+ outputs = self.electra(
+ input_ids,
+ attention_mask,
+ token_type_ids,
+ position_ids,
+ head_mask,
+ deterministic=deterministic,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ hidden_states = outputs[0]
+
+ hidden_states = self.dropout(hidden_states, deterministic=deterministic)
+ logits = self.classifier(hidden_states)
+
+ if not return_dict:
+ return (logits,) + outputs[1:]
+
+ return FlaxTokenClassifierOutput(
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@add_start_docstrings(
+ """
+ Electra model with a token classification head on top.
+
+ Both the discriminator and generator may be loaded into this model.
+ """,
+ ELECTRA_START_DOCSTRING,
+)
+class FlaxElectraForTokenClassification(FlaxElectraPreTrainedModel):
+ module_class = FlaxElectraForTokenClassificationModule
+
+
+append_call_sample_docstring(
+ FlaxElectraForTokenClassification,
+ _CHECKPOINT_FOR_DOC,
+ FlaxTokenClassifierOutput,
+ _CONFIG_FOR_DOC,
+)
+
+
+def identity(x, **kwargs):
+ return x
+
+
+class FlaxElectraSequenceSummary(nn.Module):
+ r"""
+ Compute a single vector summary of a sequence hidden states.
+
+ Args:
+ config ([`PretrainedConfig`]):
+ The config used by the model. Relevant arguments in the config class of the model are (refer to the actual
+ config class of your model for the default values it uses):
+
+ - **summary_use_proj** (`bool`) -- Add a projection after the vector extraction.
+ - **summary_proj_to_labels** (`bool`) -- If `True`, the projection outputs to `config.num_labels` classes
+ (otherwise to `config.hidden_size`).
+ - **summary_activation** (`Optional[str]`) -- Set to `"tanh"` to add a tanh activation to the output,
+ another string or `None` will add no activation.
+ - **summary_first_dropout** (`float`) -- Optional dropout probability before the projection and activation.
+ - **summary_last_dropout** (`float`) -- Optional dropout probability after the projection and activation.
+ """
+
+ config: ElectraConfig
+ dtype: jnp.dtype = jnp.float32
+
+ def setup(self):
+ self.summary = identity
+ if hasattr(self.config, "summary_use_proj") and self.config.summary_use_proj:
+ if (
+ hasattr(self.config, "summary_proj_to_labels")
+ and self.config.summary_proj_to_labels
+ and self.config.num_labels > 0
+ ):
+ num_classes = self.config.num_labels
+ else:
+ num_classes = self.config.hidden_size
+ self.summary = nn.Dense(num_classes, dtype=self.dtype)
+
+ activation_string = getattr(self.config, "summary_activation", None)
+ self.activation = ACT2FN[activation_string] if activation_string else lambda x: x # noqa F407
+
+ self.first_dropout = identity
+ if hasattr(self.config, "summary_first_dropout") and self.config.summary_first_dropout > 0:
+ self.first_dropout = nn.Dropout(self.config.summary_first_dropout)
+
+ self.last_dropout = identity
+ if hasattr(self.config, "summary_last_dropout") and self.config.summary_last_dropout > 0:
+ self.last_dropout = nn.Dropout(self.config.summary_last_dropout)
+
+ def __call__(self, hidden_states, cls_index=None, deterministic: bool = True):
+ """
+ Compute a single vector summary of a sequence hidden states.
+
+ Args:
+ hidden_states (`jnp.ndarray` of shape `[batch_size, seq_len, hidden_size]`):
+ The hidden states of the last layer.
+ cls_index (`jnp.ndarray` of shape `[batch_size]` or `[batch_size, ...]` where ... are optional leading dimensions of `hidden_states`, *optional*):
+ Used if `summary_type == "cls_index"` and takes the last token of the sequence as classification token.
+
+ Returns:
+ `jnp.ndarray`: The summary of the sequence hidden states.
+ """
+ # NOTE: this does "first" type summary always
+ output = hidden_states[:, 0]
+ output = self.first_dropout(output, deterministic=deterministic)
+ output = self.summary(output)
+ output = self.activation(output)
+ output = self.last_dropout(output, deterministic=deterministic)
+ return output
+
+
+class FlaxElectraForMultipleChoiceModule(nn.Module):
+ config: ElectraConfig
+ dtype: jnp.dtype = jnp.float32
+ gradient_checkpointing: bool = False
+
+ def setup(self):
+ self.electra = FlaxElectraModule(
+ config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
+ )
+ self.sequence_summary = FlaxElectraSequenceSummary(config=self.config, dtype=self.dtype)
+ self.classifier = nn.Dense(1, dtype=self.dtype)
+
+ def __call__(
+ self,
+ input_ids,
+ attention_mask=None,
+ token_type_ids=None,
+ position_ids=None,
+ head_mask=None,
+ deterministic: bool = True,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ return_dict: bool = True,
+ ):
+ num_choices = input_ids.shape[1]
+ input_ids = input_ids.reshape(-1, input_ids.shape[-1]) if input_ids is not None else None
+ attention_mask = attention_mask.reshape(-1, attention_mask.shape[-1]) if attention_mask is not None else None
+ token_type_ids = token_type_ids.reshape(-1, token_type_ids.shape[-1]) if token_type_ids is not None else None
+ position_ids = position_ids.reshape(-1, position_ids.shape[-1]) if position_ids is not None else None
+
+ # Model
+ outputs = self.electra(
+ input_ids,
+ attention_mask,
+ token_type_ids,
+ position_ids,
+ head_mask,
+ deterministic=deterministic,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ hidden_states = outputs[0]
+ pooled_output = self.sequence_summary(hidden_states, deterministic=deterministic)
+ logits = self.classifier(pooled_output)
+
+ reshaped_logits = logits.reshape(-1, num_choices)
+
+ if not return_dict:
+ return (reshaped_logits,) + outputs[1:]
+
+ return FlaxMultipleChoiceModelOutput(
+ logits=reshaped_logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@add_start_docstrings(
+ """
+ ELECTRA Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
+ softmax) e.g. for RocStories/SWAG tasks.
+ """,
+ ELECTRA_START_DOCSTRING,
+)
+class FlaxElectraForMultipleChoice(FlaxElectraPreTrainedModel):
+ module_class = FlaxElectraForMultipleChoiceModule
+
+
+# adapt docstring slightly for FlaxElectraForMultipleChoice
+overwrite_call_docstring(
+ FlaxElectraForMultipleChoice, ELECTRA_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
+)
+append_call_sample_docstring(
+ FlaxElectraForMultipleChoice,
+ _CHECKPOINT_FOR_DOC,
+ FlaxMultipleChoiceModelOutput,
+ _CONFIG_FOR_DOC,
+)
+
+
+class FlaxElectraForQuestionAnsweringModule(nn.Module):
+ config: ElectraConfig
+ dtype: jnp.dtype = jnp.float32
+ gradient_checkpointing: bool = False
+
+ def setup(self):
+ self.electra = FlaxElectraModule(
+ config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
+ )
+ self.qa_outputs = nn.Dense(self.config.num_labels, dtype=self.dtype)
+
+ def __call__(
+ self,
+ input_ids,
+ attention_mask=None,
+ token_type_ids=None,
+ position_ids=None,
+ head_mask=None,
+ deterministic: bool = True,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ return_dict: bool = True,
+ ):
+ # Model
+ outputs = self.electra(
+ input_ids,
+ attention_mask,
+ token_type_ids,
+ position_ids,
+ head_mask,
+ deterministic=deterministic,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ hidden_states = outputs[0]
+ logits = self.qa_outputs(hidden_states)
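+ # qa_outputs maps every token to config.num_labels (2) values, split below into start/end span scores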
+ start_logits, end_logits = jnp.split(logits, self.config.num_labels, axis=-1)
+ start_logits = start_logits.squeeze(-1)
+ end_logits = end_logits.squeeze(-1)
+
+ if not return_dict:
+ return (start_logits, end_logits) + outputs[1:]
+
+ return FlaxQuestionAnsweringModelOutput(
+ start_logits=start_logits,
+ end_logits=end_logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@add_start_docstrings(
+ """
+ ELECTRA Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
+ layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
+ """,
+ ELECTRA_START_DOCSTRING,
+)
+class FlaxElectraForQuestionAnswering(FlaxElectraPreTrainedModel):
+ module_class = FlaxElectraForQuestionAnsweringModule
+
+
+append_call_sample_docstring(
+ FlaxElectraForQuestionAnswering,
+ _CHECKPOINT_FOR_DOC,
+ FlaxQuestionAnsweringModelOutput,
+ _CONFIG_FOR_DOC,
+)
+
+
+class FlaxElectraClassificationHead(nn.Module):
+ """Head for sentence-level classification tasks."""
+
+ config: ElectraConfig
+ dtype: jnp.dtype = jnp.float32
+
+ def setup(self):
+ self.dense = nn.Dense(self.config.hidden_size, dtype=self.dtype)
+ classifier_dropout = (
+ self.config.classifier_dropout
+ if self.config.classifier_dropout is not None
+ else self.config.hidden_dropout_prob
+ )
+ self.dropout = nn.Dropout(classifier_dropout)
+ self.out_proj = nn.Dense(self.config.num_labels, dtype=self.dtype)
+
+ def __call__(self, hidden_states, deterministic: bool = True):
+ x = hidden_states[:, 0, :] # take <s> token (equiv. to [CLS])
+ x = self.dropout(x, deterministic=deterministic)
+ x = self.dense(x)
+ x = ACT2FN["gelu"](x) # although BERT uses tanh here, it seems Electra authors used gelu
+ x = self.dropout(x, deterministic=deterministic)
+ x = self.out_proj(x)
+ return x
+
+
+class FlaxElectraForSequenceClassificationModule(nn.Module):
+ config: ElectraConfig
+ dtype: jnp.dtype = jnp.float32
+ gradient_checkpointing: bool = False
+
+ def setup(self):
+ self.electra = FlaxElectraModule(
+ config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
+ )
+ self.classifier = FlaxElectraClassificationHead(config=self.config, dtype=self.dtype)
+
+ def __call__(
+ self,
+ input_ids,
+ attention_mask=None,
+ token_type_ids=None,
+ position_ids=None,
+ head_mask=None,
+ deterministic: bool = True,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ return_dict: bool = True,
+ ):
+ # Model
+ outputs = self.electra(
+ input_ids,
+ attention_mask,
+ token_type_ids,
+ position_ids,
+ head_mask,
+ deterministic=deterministic,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ hidden_states = outputs[0]
+ logits = self.classifier(hidden_states, deterministic=deterministic)
+
+ if not return_dict:
+ return (logits,) + outputs[1:]
+
+ return FlaxSequenceClassifierOutput(
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@add_start_docstrings(
+ """
+ Electra Model transformer with a sequence classification/regression head on top (a linear layer on top of the
+ pooled output) e.g. for GLUE tasks.
+ """,
+ ELECTRA_START_DOCSTRING,
+)
+class FlaxElectraForSequenceClassification(FlaxElectraPreTrainedModel):
+ module_class = FlaxElectraForSequenceClassificationModule
+
+
+append_call_sample_docstring(
+ FlaxElectraForSequenceClassification,
+ _CHECKPOINT_FOR_DOC,
+ FlaxSequenceClassifierOutput,
+ _CONFIG_FOR_DOC,
+)
+
+
+class FlaxElectraForCausalLMModule(nn.Module):
+ config: ElectraConfig
+ dtype: jnp.dtype = jnp.float32
+ gradient_checkpointing: bool = False
+
+ def setup(self):
+ self.electra = FlaxElectraModule(
+ config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
+ )
+ self.generator_predictions = FlaxElectraGeneratorPredictions(config=self.config, dtype=self.dtype)
+ if self.config.tie_word_embeddings:
+ self.generator_lm_head = FlaxElectraTiedDense(self.config.vocab_size, dtype=self.dtype)
+ else:
+ self.generator_lm_head = nn.Dense(self.config.vocab_size, dtype=self.dtype)
+
+ def __call__(
+ self,
+ input_ids,
+ attention_mask: Optional[jnp.ndarray] = None,
+ token_type_ids: Optional[jnp.ndarray] = None,
+ position_ids: Optional[jnp.ndarray] = None,
+ head_mask: Optional[jnp.ndarray] = None,
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
+ init_cache: bool = False,
+ deterministic: bool = True,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ return_dict: bool = True,
+ ):
+ outputs = self.electra(
+ input_ids,
+ attention_mask,
+ token_type_ids,
+ position_ids,
+ head_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ init_cache=init_cache,
+ deterministic=deterministic,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ hidden_states = outputs[0]
+ prediction_scores = self.generator_predictions(hidden_states)
+
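+ # Weight tying: project back to the vocabulary by reusing the input word-embedding matrix (transposed).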
+ if self.config.tie_word_embeddings:
+ shared_embedding = self.electra.variables["params"]["embeddings"]["word_embeddings"]["embedding"]
+ prediction_scores = self.generator_lm_head(prediction_scores, shared_embedding.T)
+ else:
+ prediction_scores = self.generator_lm_head(prediction_scores)
+
+ if not return_dict:
+ return (prediction_scores,) + outputs[1:]
+
+ return FlaxCausalLMOutputWithCrossAttentions(
+ logits=prediction_scores,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ cross_attentions=outputs.cross_attentions,
+ )
+
+
+@add_start_docstrings(
+ """
+ Electra Model with a language modeling head on top (a linear layer on top of the hidden-states output) e.g. for
+ autoregressive tasks.
+ """,
+ ELECTRA_START_DOCSTRING,
+)
+# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertForCausalLM with Bert->Electra
+class FlaxElectraForCausalLM(FlaxElectraPreTrainedModel):
+ module_class = FlaxElectraForCausalLMModule
+
+ def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
+ # initializing the cache
+ batch_size, seq_length = input_ids.shape
+
+ past_key_values = self.init_cache(batch_size, max_length)
+ # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
+ # But since the decoder uses a causal mask, those positions are masked anyway.
+ # Thus, we can create a single static attention_mask here, which is more efficient for compilation
+ extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
+ if attention_mask is not None:
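+ # Positions are derived from the cumulative sum of the mask, so padded positions repeat the last valid
+ # index, e.g. attention_mask [[1, 1, 1, 0]] -> position_ids [[0, 1, 2, 2]].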
+ position_ids = attention_mask.cumsum(axis=-1) - 1
+ extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))
+ else:
+ position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))
+
+ return {
+ "past_key_values": past_key_values,
+ "attention_mask": extended_attention_mask,
+ "position_ids": position_ids,
+ }
+
+ def update_inputs_for_generation(self, model_outputs, model_kwargs):
+ model_kwargs["past_key_values"] = model_outputs.past_key_values
+ model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
+ return model_kwargs
+
+
+append_call_sample_docstring(
+ FlaxElectraForCausalLM,
+ _CHECKPOINT_FOR_DOC,
+ FlaxCausalLMOutputWithCrossAttentions,
+ _CONFIG_FOR_DOC,
+)
+
+
+__all__ = [
+ "FlaxElectraForCausalLM",
+ "FlaxElectraForMaskedLM",
+ "FlaxElectraForMultipleChoice",
+ "FlaxElectraForPreTraining",
+ "FlaxElectraForQuestionAnswering",
+ "FlaxElectraForSequenceClassification",
+ "FlaxElectraForTokenClassification",
+ "FlaxElectraModel",
+ "FlaxElectraPreTrainedModel",
+]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/electra/modeling_tf_electra.py b/venv/lib/python3.13/site-packages/transformers/models/electra/modeling_tf_electra.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a5c33e503d7386df5c2be0fc10a079ee4fe014a
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/electra/modeling_tf_electra.py
@@ -0,0 +1,1775 @@
+# coding=utf-8
+# Copyright 2019 The Google AI Language Team Authors and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""TF Electra model."""
+
+from __future__ import annotations
+
+import math
+import warnings
+from dataclasses import dataclass
+
+import numpy as np
+import tensorflow as tf
+
+from ...activations_tf import get_tf_activation
+from ...modeling_tf_outputs import (
+ TFBaseModelOutputWithPastAndCrossAttentions,
+ TFMaskedLMOutput,
+ TFMultipleChoiceModelOutput,
+ TFQuestionAnsweringModelOutput,
+ TFSequenceClassifierOutput,
+ TFTokenClassifierOutput,
+)
+from ...modeling_tf_utils import (
+ TFMaskedLanguageModelingLoss,
+ TFModelInputType,
+ TFMultipleChoiceLoss,
+ TFPreTrainedModel,
+ TFQuestionAnsweringLoss,
+ TFSequenceClassificationLoss,
+ TFSequenceSummary,
+ TFTokenClassificationLoss,
+ get_initializer,
+ keras,
+ keras_serializable,
+ unpack_inputs,
+)
+from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax
+from ...utils import (
+ ModelOutput,
+ add_code_sample_docstrings,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+ replace_return_docstrings,
+)
+from .configuration_electra import ElectraConfig
+
+
+logger = logging.get_logger(__name__)
+
+_CHECKPOINT_FOR_DOC = "google/electra-small-discriminator"
+_CONFIG_FOR_DOC = "ElectraConfig"
+
+
+# Copied from transformers.models.bert.modeling_tf_bert.TFBertSelfAttention with Bert->Electra
+class TFElectraSelfAttention(keras.layers.Layer):
+ def __init__(self, config: ElectraConfig, **kwargs):
+ super().__init__(**kwargs)
+
+ if config.hidden_size % config.num_attention_heads != 0:
+ raise ValueError(
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number "
+ f"of attention heads ({config.num_attention_heads})"
+ )
+
+ self.num_attention_heads = config.num_attention_heads
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
+ self.sqrt_att_head_size = math.sqrt(self.attention_head_size)
+
+ self.query = keras.layers.Dense(
+ units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query"
+ )
+ self.key = keras.layers.Dense(
+ units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key"
+ )
+ self.value = keras.layers.Dense(
+ units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value"
+ )
+ self.dropout = keras.layers.Dropout(rate=config.attention_probs_dropout_prob)
+
+ self.is_decoder = config.is_decoder
+ self.config = config
+
+ def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor:
+ # Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size]
+ tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size))
+
+ # Transpose the tensor from [batch_size, seq_length, num_attention_heads, attention_head_size] to [batch_size, num_attention_heads, seq_length, attention_head_size]
+ return tf.transpose(tensor, perm=[0, 2, 1, 3])
+
+ def call(
+ self,
+ hidden_states: tf.Tensor,
+ attention_mask: tf.Tensor,
+ head_mask: tf.Tensor,
+ encoder_hidden_states: tf.Tensor,
+ encoder_attention_mask: tf.Tensor,
+ past_key_value: tuple[tf.Tensor],
+ output_attentions: bool,
+ training: bool = False,
+ ) -> tuple[tf.Tensor]:
+ batch_size = shape_list(hidden_states)[0]
+ mixed_query_layer = self.query(inputs=hidden_states)
+
+ # If this is instantiated as a cross-attention module, the keys
+ # and values come from an encoder; the attention mask needs to be
+ # such that the encoder's padding tokens are not attended to.
+ is_cross_attention = encoder_hidden_states is not None
+
+ if is_cross_attention and past_key_value is not None:
+ # reuse k,v, cross_attentions
+ key_layer = past_key_value[0]
+ value_layer = past_key_value[1]
+ attention_mask = encoder_attention_mask
+ elif is_cross_attention:
+ key_layer = self.transpose_for_scores(self.key(inputs=encoder_hidden_states), batch_size)
+ value_layer = self.transpose_for_scores(self.value(inputs=encoder_hidden_states), batch_size)
+ attention_mask = encoder_attention_mask
+ elif past_key_value is not None:
+ key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size)
+ value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size)
+ key_layer = tf.concat([past_key_value[0], key_layer], axis=2)
+ value_layer = tf.concat([past_key_value[1], value_layer], axis=2)
+ else:
+ key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size)
+ value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size)
+
+ query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)
+
+ if self.is_decoder:
+ # if cross_attention save Tuple(tf.Tensor, tf.Tensor) of all cross attention key/value_states.
+ # Further calls to cross_attention layer can then reuse all cross-attention
+ # key/value_states (first "if" case)
+ # if uni-directional self-attention (decoder) save Tuple(tf.Tensor, tf.Tensor) of
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
+ past_key_value = (key_layer, value_layer)
+
+ # Take the dot product between "query" and "key" to get the raw attention scores.
+ # (batch size, num_heads, seq_len_q, seq_len_k)
+ attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
+ dk = tf.cast(self.sqrt_att_head_size, dtype=attention_scores.dtype)
+ attention_scores = tf.divide(attention_scores, dk)
+
+ if attention_mask is not None:
+ # Apply the attention mask (precomputed for all layers in TFElectraModel call() function)
+ attention_scores = tf.add(attention_scores, attention_mask)
+
+ # Normalize the attention scores to probabilities.
+ attention_probs = stable_softmax(logits=attention_scores, axis=-1)
+
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ attention_probs = self.dropout(inputs=attention_probs, training=training)
+
+ # Mask heads if we want to
+ if head_mask is not None:
+ attention_probs = tf.multiply(attention_probs, head_mask)
+
+ attention_output = tf.matmul(attention_probs, value_layer)
+ attention_output = tf.transpose(attention_output, perm=[0, 2, 1, 3])
+
+ # (batch_size, seq_len_q, all_head_size)
+ attention_output = tf.reshape(tensor=attention_output, shape=(batch_size, -1, self.all_head_size))
+ outputs = (attention_output, attention_probs) if output_attentions else (attention_output,)
+
+ if self.is_decoder:
+ outputs = outputs + (past_key_value,)
+ return outputs
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "query", None) is not None:
+ with tf.name_scope(self.query.name):
+ self.query.build([None, None, self.config.hidden_size])
+ if getattr(self, "key", None) is not None:
+ with tf.name_scope(self.key.name):
+ self.key.build([None, None, self.config.hidden_size])
+ if getattr(self, "value", None) is not None:
+ with tf.name_scope(self.value.name):
+ self.value.build([None, None, self.config.hidden_size])
+
+
+# Copied from transformers.models.bert.modeling_tf_bert.TFBertSelfOutput with Bert->Electra
+class TFElectraSelfOutput(keras.layers.Layer):
+ def __init__(self, config: ElectraConfig, **kwargs):
+ super().__init__(**kwargs)
+
+ self.dense = keras.layers.Dense(
+ units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
+ )
+ self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
+ self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
+ self.config = config
+
+ def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
+ hidden_states = self.dense(inputs=hidden_states)
+ hidden_states = self.dropout(inputs=hidden_states, training=training)
+ hidden_states = self.LayerNorm(inputs=hidden_states + input_tensor)
+
+ return hidden_states
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "dense", None) is not None:
+ with tf.name_scope(self.dense.name):
+ self.dense.build([None, None, self.config.hidden_size])
+ if getattr(self, "LayerNorm", None) is not None:
+ with tf.name_scope(self.LayerNorm.name):
+ self.LayerNorm.build([None, None, self.config.hidden_size])
+
+
+# Copied from transformers.models.bert.modeling_tf_bert.TFBertAttention with Bert->Electra
+class TFElectraAttention(keras.layers.Layer):
+ def __init__(self, config: ElectraConfig, **kwargs):
+ super().__init__(**kwargs)
+
+ self.self_attention = TFElectraSelfAttention(config, name="self")
+ self.dense_output = TFElectraSelfOutput(config, name="output")
+
+ def prune_heads(self, heads):
+ raise NotImplementedError
+
+ def call(
+ self,
+ input_tensor: tf.Tensor,
+ attention_mask: tf.Tensor,
+ head_mask: tf.Tensor,
+ encoder_hidden_states: tf.Tensor,
+ encoder_attention_mask: tf.Tensor,
+ past_key_value: tuple[tf.Tensor],
+ output_attentions: bool,
+ training: bool = False,
+ ) -> tuple[tf.Tensor]:
+ self_outputs = self.self_attention(
+ hidden_states=input_tensor,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ training=training,
+ )
+ attention_output = self.dense_output(
+ hidden_states=self_outputs[0], input_tensor=input_tensor, training=training
+ )
+ # add attentions (possibly with past_key_value) if we output them
+ outputs = (attention_output,) + self_outputs[1:]
+
+ return outputs
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "self_attention", None) is not None:
+ with tf.name_scope(self.self_attention.name):
+ self.self_attention.build(None)
+ if getattr(self, "dense_output", None) is not None:
+ with tf.name_scope(self.dense_output.name):
+ self.dense_output.build(None)
+
+
+# Copied from transformers.models.bert.modeling_tf_bert.TFBertIntermediate with Bert->Electra
+class TFElectraIntermediate(keras.layers.Layer):
+ def __init__(self, config: ElectraConfig, **kwargs):
+ super().__init__(**kwargs)
+
+ self.dense = keras.layers.Dense(
+ units=config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
+ )
+
+ if isinstance(config.hidden_act, str):
+ self.intermediate_act_fn = get_tf_activation(config.hidden_act)
+ else:
+ self.intermediate_act_fn = config.hidden_act
+ self.config = config
+
+ def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
+ hidden_states = self.dense(inputs=hidden_states)
+ hidden_states = self.intermediate_act_fn(hidden_states)
+
+ return hidden_states
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "dense", None) is not None:
+ with tf.name_scope(self.dense.name):
+ self.dense.build([None, None, self.config.hidden_size])
+
+
+# Copied from transformers.models.bert.modeling_tf_bert.TFBertOutput with Bert->Electra
+class TFElectraOutput(keras.layers.Layer):
+ def __init__(self, config: ElectraConfig, **kwargs):
+ super().__init__(**kwargs)
+
+ self.dense = keras.layers.Dense(
+ units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
+ )
+ self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
+ self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
+ self.config = config
+
+ def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
+ hidden_states = self.dense(inputs=hidden_states)
+ hidden_states = self.dropout(inputs=hidden_states, training=training)
+ hidden_states = self.LayerNorm(inputs=hidden_states + input_tensor)
+
+ return hidden_states
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "dense", None) is not None:
+ with tf.name_scope(self.dense.name):
+ self.dense.build([None, None, self.config.intermediate_size])
+ if getattr(self, "LayerNorm", None) is not None:
+ with tf.name_scope(self.LayerNorm.name):
+ self.LayerNorm.build([None, None, self.config.hidden_size])
+
+
+# Copied from transformers.models.bert.modeling_tf_bert.TFBertLayer with Bert->Electra
+class TFElectraLayer(keras.layers.Layer):
+ def __init__(self, config: ElectraConfig, **kwargs):
+ super().__init__(**kwargs)
+
+ self.attention = TFElectraAttention(config, name="attention")
+ self.is_decoder = config.is_decoder
+ self.add_cross_attention = config.add_cross_attention
+ if self.add_cross_attention:
+ if not self.is_decoder:
+ raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
+ self.crossattention = TFElectraAttention(config, name="crossattention")
+ self.intermediate = TFElectraIntermediate(config, name="intermediate")
+ self.bert_output = TFElectraOutput(config, name="output")
+
+ def call(
+ self,
+ hidden_states: tf.Tensor,
+ attention_mask: tf.Tensor,
+ head_mask: tf.Tensor,
+ encoder_hidden_states: tf.Tensor | None,
+ encoder_attention_mask: tf.Tensor | None,
+ past_key_value: tuple[tf.Tensor] | None,
+ output_attentions: bool,
+ training: bool = False,
+ ) -> tuple[tf.Tensor]:
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
+ self_attention_outputs = self.attention(
+ input_tensor=hidden_states,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_value=self_attn_past_key_value,
+ output_attentions=output_attentions,
+ training=training,
+ )
+ attention_output = self_attention_outputs[0]
+
+ # if decoder, the last output is tuple of self-attn cache
+ if self.is_decoder:
+ outputs = self_attention_outputs[1:-1]
+ present_key_value = self_attention_outputs[-1]
+ else:
+ outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
+
+ cross_attn_present_key_value = None
+ if self.is_decoder and encoder_hidden_states is not None:
+ if not hasattr(self, "crossattention"):
+ raise ValueError(
+ f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
+ " by setting `config.add_cross_attention=True`"
+ )
+
+ # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
+ cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
+ cross_attention_outputs = self.crossattention(
+ input_tensor=attention_output,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ past_key_value=cross_attn_past_key_value,
+ output_attentions=output_attentions,
+ training=training,
+ )
+ attention_output = cross_attention_outputs[0]
+ outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
+
+ # add cross-attn cache to positions 3,4 of present_key_value tuple
+ cross_attn_present_key_value = cross_attention_outputs[-1]
+ present_key_value = present_key_value + cross_attn_present_key_value
+
+ intermediate_output = self.intermediate(hidden_states=attention_output)
+ layer_output = self.bert_output(
+ hidden_states=intermediate_output, input_tensor=attention_output, training=training
+ )
+ outputs = (layer_output,) + outputs # add attentions if we output them
+
+ # if decoder, return the attn key/values as the last output
+ if self.is_decoder:
+ outputs = outputs + (present_key_value,)
+
+ return outputs
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "attention", None) is not None:
+ with tf.name_scope(self.attention.name):
+ self.attention.build(None)
+ if getattr(self, "intermediate", None) is not None:
+ with tf.name_scope(self.intermediate.name):
+ self.intermediate.build(None)
+ if getattr(self, "bert_output", None) is not None:
+ with tf.name_scope(self.bert_output.name):
+ self.bert_output.build(None)
+ if getattr(self, "crossattention", None) is not None:
+ with tf.name_scope(self.crossattention.name):
+ self.crossattention.build(None)
+
+
+# Copied from transformers.models.bert.modeling_tf_bert.TFBertEncoder with Bert->Electra
+class TFElectraEncoder(keras.layers.Layer):
+ def __init__(self, config: ElectraConfig, **kwargs):
+ super().__init__(**kwargs)
+ self.config = config
+ self.layer = [TFElectraLayer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)]
+
+ def call(
+ self,
+ hidden_states: tf.Tensor,
+ attention_mask: tf.Tensor,
+ head_mask: tf.Tensor,
+ encoder_hidden_states: tf.Tensor | None,
+ encoder_attention_mask: tf.Tensor | None,
+ past_key_values: tuple[tuple[tf.Tensor]] | None,
+ use_cache: bool | None,
+ output_attentions: bool,
+ output_hidden_states: bool,
+ return_dict: bool,
+ training: bool = False,
+ ) -> TFBaseModelOutputWithPastAndCrossAttentions | tuple[tf.Tensor]:
+ all_hidden_states = () if output_hidden_states else None
+ all_attentions = () if output_attentions else None
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
+
+ next_decoder_cache = () if use_cache else None
+ for i, layer_module in enumerate(self.layer):
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ past_key_value = past_key_values[i] if past_key_values is not None else None
+
+ layer_outputs = layer_module(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ head_mask=head_mask[i],
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ training=training,
+ )
+ hidden_states = layer_outputs[0]
+
+ if use_cache:
+ next_decoder_cache += (layer_outputs[-1],)
+
+ if output_attentions:
+ all_attentions = all_attentions + (layer_outputs[1],)
+ if self.config.add_cross_attention and encoder_hidden_states is not None:
+ all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
+
+ # Add last layer
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(
+ v for v in [hidden_states, all_hidden_states, all_attentions, all_cross_attentions] if v is not None
+ )
+
+ return TFBaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ past_key_values=next_decoder_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_attentions,
+ cross_attentions=all_cross_attentions,
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "layer", None) is not None:
+ for layer in self.layer:
+ with tf.name_scope(layer.name):
+ layer.build(None)
+
+
+# Copied from transformers.models.bert.modeling_tf_bert.TFBertPooler with Bert->Electra
+class TFElectraPooler(keras.layers.Layer):
+ def __init__(self, config: ElectraConfig, **kwargs):
+ super().__init__(**kwargs)
+
+ self.dense = keras.layers.Dense(
+ units=config.hidden_size,
+ kernel_initializer=get_initializer(config.initializer_range),
+ activation="tanh",
+ name="dense",
+ )
+ self.config = config
+
+ def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
+ # We "pool" the model by simply taking the hidden state corresponding
+ # to the first token.
+ first_token_tensor = hidden_states[:, 0]
+ pooled_output = self.dense(inputs=first_token_tensor)
+
+ return pooled_output
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "dense", None) is not None:
+ with tf.name_scope(self.dense.name):
+ self.dense.build([None, None, self.config.hidden_size])
+
+
+# Copied from transformers.models.albert.modeling_tf_albert.TFAlbertEmbeddings with Albert->Electra
+class TFElectraEmbeddings(keras.layers.Layer):
+ """Construct the embeddings from word, position and token_type embeddings."""
+
+ def __init__(self, config: ElectraConfig, **kwargs):
+ super().__init__(**kwargs)
+
+ self.config = config
+ self.embedding_size = config.embedding_size
+ self.max_position_embeddings = config.max_position_embeddings
+ self.initializer_range = config.initializer_range
+ self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
+ self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
+
+ def build(self, input_shape=None):
+ with tf.name_scope("word_embeddings"):
+ self.weight = self.add_weight(
+ name="weight",
+ shape=[self.config.vocab_size, self.embedding_size],
+ initializer=get_initializer(self.initializer_range),
+ )
+
+ with tf.name_scope("token_type_embeddings"):
+ self.token_type_embeddings = self.add_weight(
+ name="embeddings",
+ shape=[self.config.type_vocab_size, self.embedding_size],
+ initializer=get_initializer(self.initializer_range),
+ )
+
+ with tf.name_scope("position_embeddings"):
+ self.position_embeddings = self.add_weight(
+ name="embeddings",
+ shape=[self.max_position_embeddings, self.embedding_size],
+ initializer=get_initializer(self.initializer_range),
+ )
+
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "LayerNorm", None) is not None:
+ with tf.name_scope(self.LayerNorm.name):
+ self.LayerNorm.build([None, None, self.config.embedding_size])
+
+ # Copied from transformers.models.bert.modeling_tf_bert.TFBertEmbeddings.call
+ def call(
+ self,
+ input_ids: tf.Tensor | None = None,
+ position_ids: tf.Tensor | None = None,
+ token_type_ids: tf.Tensor | None = None,
+ inputs_embeds: tf.Tensor | None = None,
+ past_key_values_length=0,
+ training: bool = False,
+ ) -> tf.Tensor:
+ """
+ Applies embedding based on inputs tensor.
+
+ Returns:
+ final_embeddings (`tf.Tensor`): output embedding tensor.
+ """
+ if input_ids is None and inputs_embeds is None:
+ raise ValueError("Need to provide either `input_ids` or `input_embeds`.")
+
+ if input_ids is not None:
+ check_embeddings_within_bounds(input_ids, self.config.vocab_size)
+ inputs_embeds = tf.gather(params=self.weight, indices=input_ids)
+
+ input_shape = shape_list(inputs_embeds)[:-1]
+
+ if token_type_ids is None:
+ token_type_ids = tf.fill(dims=input_shape, value=0)
+
+ if position_ids is None:
+ position_ids = tf.expand_dims(
+ tf.range(start=past_key_values_length, limit=input_shape[1] + past_key_values_length), axis=0
+ )
+
+ position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids)
+ token_type_embeds = tf.gather(params=self.token_type_embeddings, indices=token_type_ids)
+ final_embeddings = inputs_embeds + position_embeds + token_type_embeds
+ final_embeddings = self.LayerNorm(inputs=final_embeddings)
+ final_embeddings = self.dropout(inputs=final_embeddings, training=training)
+
+ return final_embeddings
+
+
+class TFElectraDiscriminatorPredictions(keras.layers.Layer):
+ def __init__(self, config, **kwargs):
+ super().__init__(**kwargs)
+
+ self.dense = keras.layers.Dense(config.hidden_size, name="dense")
+ self.dense_prediction = keras.layers.Dense(1, name="dense_prediction")
+ self.config = config
+
+ def call(self, discriminator_hidden_states, training=False):
+ hidden_states = self.dense(discriminator_hidden_states)
+ hidden_states = get_tf_activation(self.config.hidden_act)(hidden_states)
+ logits = tf.squeeze(self.dense_prediction(hidden_states), -1)
+
+ return logits
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "dense", None) is not None:
+ with tf.name_scope(self.dense.name):
+ self.dense.build([None, None, self.config.hidden_size])
+ if getattr(self, "dense_prediction", None) is not None:
+ with tf.name_scope(self.dense_prediction.name):
+ self.dense_prediction.build([None, None, self.config.hidden_size])
+
+
+class TFElectraGeneratorPredictions(keras.layers.Layer):
+ def __init__(self, config, **kwargs):
+ super().__init__(**kwargs)
+
+ self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
+ self.dense = keras.layers.Dense(config.embedding_size, name="dense")
+ self.config = config
+
+ def call(self, generator_hidden_states, training=False):
+ hidden_states = self.dense(generator_hidden_states)
+ hidden_states = get_tf_activation("gelu")(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states)
+
+ return hidden_states
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "LayerNorm", None) is not None:
+ with tf.name_scope(self.LayerNorm.name):
+ self.LayerNorm.build([None, None, self.config.embedding_size])
+ if getattr(self, "dense", None) is not None:
+ with tf.name_scope(self.dense.name):
+ self.dense.build([None, None, self.config.hidden_size])
+
+
+class TFElectraPreTrainedModel(TFPreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = ElectraConfig
+ base_model_prefix = "electra"
+ # When the model is loaded from a PT model
+ _keys_to_ignore_on_load_unexpected = [r"generator_lm_head.weight"]
+ _keys_to_ignore_on_load_missing = [r"dropout"]
+
+
+@keras_serializable
+class TFElectraMainLayer(keras.layers.Layer):
+ config_class = ElectraConfig
+
+ def __init__(self, config, **kwargs):
+ super().__init__(**kwargs)
+
+ self.config = config
+ self.is_decoder = config.is_decoder
+
+ self.embeddings = TFElectraEmbeddings(config, name="embeddings")
+
+ if config.embedding_size != config.hidden_size:
+ self.embeddings_project = keras.layers.Dense(config.hidden_size, name="embeddings_project")
+
+ self.encoder = TFElectraEncoder(config, name="encoder")
+
+ def get_input_embeddings(self):
+ return self.embeddings
+
+ def set_input_embeddings(self, value):
+ self.embeddings.weight = value
+ self.embeddings.vocab_size = shape_list(value)[0]
+
+ def _prune_heads(self, heads_to_prune):
+ """
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
+ class PreTrainedModel
+ """
+ raise NotImplementedError
+
+ def get_extended_attention_mask(self, attention_mask, input_shape, dtype, past_key_values_length=0):
+ batch_size, seq_length = input_shape
+
+ if attention_mask is None:
+ attention_mask = tf.fill(dims=(batch_size, seq_length + past_key_values_length), value=1)
+
+ # We create a 3D attention mask from a 2D tensor mask.
+ # Sizes are [batch_size, 1, 1, to_seq_length]
+ # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
+ # this attention mask is simpler than the triangular masking of causal attention
+ # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
+ attention_mask_shape = shape_list(attention_mask)
+
+ mask_seq_length = seq_length + past_key_values_length
+ # Copied from `modeling_tf_t5.py`
+ # Provided a padding mask of dimensions [batch_size, mask_seq_length]
+ # - if the model is a decoder, apply a causal mask in addition to the padding mask
+ # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length]
+ if self.is_decoder:
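+ # Build a lower-triangular causal mask: entry (query_pos, key_pos) is 1 iff key_pos <= query_pos,
+ # e.g. for mask_seq_length 3: [[1, 0, 0], [1, 1, 0], [1, 1, 1]].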
+ seq_ids = tf.range(mask_seq_length)
+ causal_mask = tf.less_equal(
+ tf.tile(seq_ids[None, None, :], (batch_size, mask_seq_length, 1)),
+ seq_ids[None, :, None],
+ )
+ causal_mask = tf.cast(causal_mask, dtype=attention_mask.dtype)
+ extended_attention_mask = causal_mask * attention_mask[:, None, :]
+ attention_mask_shape = shape_list(extended_attention_mask)
+ extended_attention_mask = tf.reshape(
+ extended_attention_mask, (attention_mask_shape[0], 1, attention_mask_shape[1], attention_mask_shape[2])
+ )
+ if past_key_values_length > 0:
+ extended_attention_mask = extended_attention_mask[:, :, -seq_length:, :]
+ else:
+ extended_attention_mask = tf.reshape(
+ attention_mask, (attention_mask_shape[0], 1, 1, attention_mask_shape[1])
+ )
+
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
+ # masked positions, this operation will create a tensor which is 0.0 for
+ # positions we want to attend and -10000.0 for masked positions.
+ # Since we are adding it to the raw scores before the softmax, this is
+ # effectively the same as removing these entirely.
+ extended_attention_mask = tf.cast(extended_attention_mask, dtype=dtype)
+ one_cst = tf.constant(1.0, dtype=dtype)
+ ten_thousand_cst = tf.constant(-10000.0, dtype=dtype)
+ extended_attention_mask = tf.multiply(tf.subtract(one_cst, extended_attention_mask), ten_thousand_cst)
+
+ return extended_attention_mask
+
+ def get_head_mask(self, head_mask):
+ if head_mask is not None:
+ raise NotImplementedError
+ else:
+ head_mask = [None] * self.config.num_hidden_layers
+
+ return head_mask
+
+ @unpack_inputs
+ def call(
+ self,
+ input_ids: TFModelInputType | None = None,
+ attention_mask: np.ndarray | tf.Tensor | None = None,
+ token_type_ids: np.ndarray | tf.Tensor | None = None,
+ position_ids: np.ndarray | tf.Tensor | None = None,
+ head_mask: np.ndarray | tf.Tensor | None = None,
+ inputs_embeds: np.ndarray | tf.Tensor | None = None,
+ encoder_hidden_states: np.ndarray | tf.Tensor | None = None,
+ encoder_attention_mask: np.ndarray | tf.Tensor | None = None,
+ past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None,
+ use_cache: bool | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ training: bool | None = False,
+ ) -> TFBaseModelOutputWithPastAndCrossAttentions | tuple[tf.Tensor]:
+ if not self.config.is_decoder:
+ use_cache = False
+
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ input_shape = shape_list(input_ids)
+ elif inputs_embeds is not None:
+ input_shape = shape_list(inputs_embeds)[:-1]
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ batch_size, seq_length = input_shape
+
+ if past_key_values is None:
+ past_key_values_length = 0
+ past_key_values = [None] * len(self.encoder.layer)
+ else:
+ past_key_values_length = shape_list(past_key_values[0][0])[-2]
+
+ if attention_mask is None:
+ attention_mask = tf.fill(dims=(batch_size, seq_length + past_key_values_length), value=1)
+
+ if token_type_ids is None:
+ token_type_ids = tf.fill(dims=input_shape, value=0)
+
+ hidden_states = self.embeddings(
+ input_ids=input_ids,
+ position_ids=position_ids,
+ token_type_ids=token_type_ids,
+ inputs_embeds=inputs_embeds,
+ past_key_values_length=past_key_values_length,
+ training=training,
+ )
+ extended_attention_mask = self.get_extended_attention_mask(
+ attention_mask, input_shape, hidden_states.dtype, past_key_values_length
+ )
+
+ # Copied from `modeling_tf_t5.py` with -1e9 -> -10000
+ if self.is_decoder and encoder_attention_mask is not None:
+ # If a 2D or 3D attention mask is provided for the cross-attention,
+ # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
+ encoder_attention_mask = tf.cast(encoder_attention_mask, dtype=extended_attention_mask.dtype)
+ num_dims_encoder_attention_mask = len(shape_list(encoder_attention_mask))
+ if num_dims_encoder_attention_mask == 3:
+ encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
+ if num_dims_encoder_attention_mask == 2:
+ encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
+
+ # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
+ # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270
+ # encoder_extended_attention_mask = tf.math.equal(encoder_extended_attention_mask,
+ # tf.transpose(encoder_extended_attention_mask, perm=(-1, -2)))
+
+ encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0
+ else:
+ encoder_extended_attention_mask = None
+
+ head_mask = self.get_head_mask(head_mask)
+
+ if hasattr(self, "embeddings_project"):
+ hidden_states = self.embeddings_project(hidden_states, training=training)
+
+ hidden_states = self.encoder(
+ hidden_states=hidden_states,
+ attention_mask=extended_attention_mask,
+ head_mask=head_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_extended_attention_mask,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+
+ return hidden_states
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "embeddings", None) is not None:
+ with tf.name_scope(self.embeddings.name):
+ self.embeddings.build(None)
+ if getattr(self, "encoder", None) is not None:
+ with tf.name_scope(self.encoder.name):
+ self.encoder.build(None)
+ if getattr(self, "embeddings_project", None) is not None:
+ with tf.name_scope(self.embeddings_project.name):
+ self.embeddings_project.build([None, None, self.config.embedding_size])
+
+
+@dataclass
+class TFElectraForPreTrainingOutput(ModelOutput):
+ """
+ Output type of [`TFElectraForPreTraining`].
+
+ Args:
+ loss (*optional*, returned when `labels` is provided, `tf.Tensor` of shape `(1,)`):
+ Total loss of the ELECTRA objective.
+ logits (`tf.Tensor` of shape `(batch_size, sequence_length)`):
+ Prediction scores of the head (scores for each token before SoftMax).
+ hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
+ `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ logits: tf.Tensor | None = None
+ hidden_states: tuple[tf.Tensor] | None = None
+ attentions: tuple[tf.Tensor] | None = None
+
+
+ELECTRA_START_DOCSTRING = r"""
+
+ This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
+ etc.)
+
+ This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
+ as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and
+ behavior.
+
+
+
+ TensorFlow models and layers in `transformers` accept two formats as input:
+
+ - having all inputs as keyword arguments (like PyTorch models), or
+ - having all inputs as a list, tuple or dict in the first positional argument.
+
+ The reason the second format is supported is that Keras methods prefer this format when passing inputs to models
+ and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just
+ pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second
+ format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with
+ the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first
+ positional argument:
+
+ - a single Tensor with `input_ids` only and nothing else: `model(input_ids)`
+ - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
+ `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])`
+ - a dictionary with one or several input Tensors associated to the input names given in the docstring:
+ `model({"input_ids": input_ids, "token_type_ids": token_type_ids})`
+
+ Note that when creating models and layers with
+ [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry
+ about any of this, as you can just pass inputs like you would to any other Python function!
+
+
+
+ Parameters:
+ config ([`ElectraConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+ELECTRA_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`Numpy array` or `tf.Tensor` of shape `({0})`):
+ Indices of input sequence tokens in the vocabulary.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and
+ [`PreTrainedTokenizer.encode`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`Numpy array` or `tf.Tensor` of shape `({0})`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ position_ids (`Numpy array` or `tf.Tensor` of shape `({0})`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+ config.max_position_embeddings - 1]`.
+
+ [What are position IDs?](../glossary#position-ids)
+ head_mask (`Numpy array` or `tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ inputs_embeds (`tf.Tensor` of shape `({0}, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+ model's internal embedding lookup matrix.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the
+ config will be used instead.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail. This argument can be used only in eager mode, in graph mode the value in the config will be
+ used instead.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in
+ eager mode, in graph mode the value will always be set to True.
+ training (`bool`, *optional*, defaults to `False`):
+ Whether or not to use the model in training mode (some modules like dropout modules have different
+ behaviors between training and evaluation).
+"""
+
+
+@add_start_docstrings(
+ "The bare Electra Model transformer outputting raw hidden-states without any specific head on top. Identical to "
+ "the BERT model except that it uses an additional linear layer between the embedding layer and the encoder if the "
+ "hidden size and embedding size are different. "
+ ""
+ "Both the generator and discriminator checkpoints may be loaded into this model.",
+ ELECTRA_START_DOCSTRING,
+)
+class TFElectraModel(TFElectraPreTrainedModel):
+ def __init__(self, config, *inputs, **kwargs):
+ super().__init__(config, *inputs, **kwargs)
+
+ self.electra = TFElectraMainLayer(config, name="electra")
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=TFBaseModelOutputWithPastAndCrossAttentions,
+ config_class=_CONFIG_FOR_DOC,
+ )
+ def call(
+ self,
+ input_ids: TFModelInputType | None = None,
+ attention_mask: np.ndarray | tf.Tensor | None = None,
+ token_type_ids: np.ndarray | tf.Tensor | None = None,
+ position_ids: np.ndarray | tf.Tensor | None = None,
+ head_mask: np.ndarray | tf.Tensor | None = None,
+ inputs_embeds: np.ndarray | tf.Tensor | None = None,
+ encoder_hidden_states: np.ndarray | tf.Tensor | None = None,
+ encoder_attention_mask: np.ndarray | tf.Tensor | None = None,
+ past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None,
+ use_cache: bool | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ training: bool | None = False,
+ ) -> TFBaseModelOutputWithPastAndCrossAttentions | tuple[tf.Tensor]:
+ r"""
+ encoder_hidden_states (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
+ the model is configured as a decoder.
+ encoder_attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
+ the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ past_key_values (`tuple[tuple[tf.Tensor]]` of length `config.n_layers`)
+ contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+ `past_key_values`). Set to `False` during training, `True` during generation
+ """
+ outputs = self.electra(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+
+ return outputs
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "electra", None) is not None:
+ with tf.name_scope(self.electra.name):
+ self.electra.build(None)
+
+
+@add_start_docstrings(
+ """
+ Electra model with a binary classification head on top as used during pretraining for identifying generated tokens.
+
+ Even though both the discriminator and generator may be loaded into this model, the discriminator is the only model
+ of the two to have the correct classification head to be used for this model.
+ """,
+ ELECTRA_START_DOCSTRING,
+)
+class TFElectraForPreTraining(TFElectraPreTrainedModel):
+ def __init__(self, config, **kwargs):
+ super().__init__(config, **kwargs)
+
+ self.electra = TFElectraMainLayer(config, name="electra")
+ self.discriminator_predictions = TFElectraDiscriminatorPredictions(config, name="discriminator_predictions")
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @replace_return_docstrings(output_type=TFElectraForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
+ def call(
+ self,
+ input_ids: TFModelInputType | None = None,
+ attention_mask: np.ndarray | tf.Tensor | None = None,
+ token_type_ids: np.ndarray | tf.Tensor | None = None,
+ position_ids: np.ndarray | tf.Tensor | None = None,
+ head_mask: np.ndarray | tf.Tensor | None = None,
+ inputs_embeds: np.ndarray | tf.Tensor | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ training: bool | None = False,
+ ) -> TFElectraForPreTrainingOutput | tuple[tf.Tensor]:
+ r"""
+ Returns:
+
+ Examples:
+
+ ```python
+ >>> import tensorflow as tf
+ >>> from transformers import AutoTokenizer, TFElectraForPreTraining
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("google/electra-small-discriminator")
+ >>> model = TFElectraForPreTraining.from_pretrained("google/electra-small-discriminator")
+ >>> input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :] # Batch size 1
+ >>> outputs = model(input_ids)
+ >>> scores = outputs[0]
+ ```"""
+ discriminator_hidden_states = self.electra(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+ discriminator_sequence_output = discriminator_hidden_states[0]
+ logits = self.discriminator_predictions(discriminator_sequence_output)
+
+ if not return_dict:
+ return (logits,) + discriminator_hidden_states[1:]
+
+ return TFElectraForPreTrainingOutput(
+ logits=logits,
+ hidden_states=discriminator_hidden_states.hidden_states,
+ attentions=discriminator_hidden_states.attentions,
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "electra", None) is not None:
+ with tf.name_scope(self.electra.name):
+ self.electra.build(None)
+ if getattr(self, "discriminator_predictions", None) is not None:
+ with tf.name_scope(self.discriminator_predictions.name):
+ self.discriminator_predictions.build(None)
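+
+
+# A brief illustrative note (not part of the library): the per-token logits returned by the pre-training head
+# above are binary replaced-token-detection scores, so e.g. `tf.round(tf.sigmoid(scores))` marks tokens the
+# discriminator believes were replaced by the generator (1) versus kept original (0).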
+
+
+class TFElectraMaskedLMHead(keras.layers.Layer):
+ def __init__(self, config, input_embeddings, **kwargs):
+ super().__init__(**kwargs)
+
+ self.config = config
+ self.embedding_size = config.embedding_size
+ self.input_embeddings = input_embeddings
+
+ def build(self, input_shape):
+ self.bias = self.add_weight(shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="bias")
+
+ super().build(input_shape)
+
+ def get_output_embeddings(self):
+ return self.input_embeddings
+
+ def set_output_embeddings(self, value):
+ self.input_embeddings.weight = value
+ self.input_embeddings.vocab_size = shape_list(value)[0]
+
+ def get_bias(self):
+ return {"bias": self.bias}
+
+ def set_bias(self, value):
+ self.bias = value["bias"]
+ self.config.vocab_size = shape_list(value["bias"])[0]
+
+ def call(self, hidden_states):
+ seq_length = shape_list(tensor=hidden_states)[1]
+ hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, self.embedding_size])
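+ # The output projection is tied to the input embedding matrix (shared weights, used transposed here).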
+ hidden_states = tf.matmul(a=hidden_states, b=self.input_embeddings.weight, transpose_b=True)
+ hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, seq_length, self.config.vocab_size])
+ hidden_states = tf.nn.bias_add(value=hidden_states, bias=self.bias)
+
+ return hidden_states
+
+
+@add_start_docstrings(
+ """
+ Electra model with a language modeling head on top.
+
+ Even though both the discriminator and generator may be loaded into this model, the generator is the only model of
+ the two to have been trained for the masked language modeling task.
+ """,
+ ELECTRA_START_DOCSTRING,
+)
+class TFElectraForMaskedLM(TFElectraPreTrainedModel, TFMaskedLanguageModelingLoss):
+ def __init__(self, config, **kwargs):
+ super().__init__(config, **kwargs)
+
+ self.config = config
+ self.electra = TFElectraMainLayer(config, name="electra")
+ self.generator_predictions = TFElectraGeneratorPredictions(config, name="generator_predictions")
+
+ if isinstance(config.hidden_act, str):
+ self.activation = get_tf_activation(config.hidden_act)
+ else:
+ self.activation = config.hidden_act
+
+ self.generator_lm_head = TFElectraMaskedLMHead(config, self.electra.embeddings, name="generator_lm_head")
+
+ def get_lm_head(self):
+ return self.generator_lm_head
+
+ def get_prefix_bias_name(self):
+ warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
+ return self.name + "/" + self.generator_lm_head.name
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @add_code_sample_docstrings(
+ checkpoint="google/electra-small-generator",
+ output_type=TFMaskedLMOutput,
+ config_class=_CONFIG_FOR_DOC,
+ mask="[MASK]",
+ expected_output="'paris'",
+ expected_loss=1.22,
+ )
+ def call(
+ self,
+ input_ids: TFModelInputType | None = None,
+ attention_mask: np.ndarray | tf.Tensor | None = None,
+ token_type_ids: np.ndarray | tf.Tensor | None = None,
+ position_ids: np.ndarray | tf.Tensor | None = None,
+ head_mask: np.ndarray | tf.Tensor | None = None,
+ inputs_embeds: np.ndarray | tf.Tensor | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ labels: np.ndarray | tf.Tensor | None = None,
+ training: bool | None = False,
+ ) -> TFMaskedLMOutput | tuple[tf.Tensor]:
+ r"""
+ labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
+ config.vocab_size]` (see the `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked);
+ the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+ """
+ generator_hidden_states = self.electra(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+ generator_sequence_output = generator_hidden_states[0]
+ prediction_scores = self.generator_predictions(generator_sequence_output, training=training)
+ prediction_scores = self.generator_lm_head(prediction_scores, training=training)
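+ # Label convention (sketch, not from the original file): hf_compute_loss ignores positions
+ # whose label is -100, so something like
+ #     labels = tf.where(masked_positions, input_ids, -100)
+ # restricts the MLM loss to the masked tokens only (`masked_positions` is a hypothetical
+ # boolean mask with the same shape as `input_ids`).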
+ loss = None if labels is None else self.hf_compute_loss(labels, prediction_scores)
+
+ if not return_dict:
+ output = (prediction_scores,) + generator_hidden_states[1:]
+
+ return ((loss,) + output) if loss is not None else output
+
+ return TFMaskedLMOutput(
+ loss=loss,
+ logits=prediction_scores,
+ hidden_states=generator_hidden_states.hidden_states,
+ attentions=generator_hidden_states.attentions,
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "electra", None) is not None:
+ with tf.name_scope(self.electra.name):
+ self.electra.build(None)
+ if getattr(self, "generator_predictions", None) is not None:
+ with tf.name_scope(self.generator_predictions.name):
+ self.generator_predictions.build(None)
+ if getattr(self, "generator_lm_head", None) is not None:
+ with tf.name_scope(self.generator_lm_head.name):
+ self.generator_lm_head.build(None)
+
+
+class TFElectraClassificationHead(keras.layers.Layer):
+ """Head for sentence-level classification tasks."""
+
+ def __init__(self, config, **kwargs):
+ super().__init__(**kwargs)
+
+ self.dense = keras.layers.Dense(
+ config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
+ )
+ classifier_dropout = (
+ config.classifier_dropout
+ if config.classifier_dropout is not None
+ else config.hidden_dropout_prob
+ )
+ self.dropout = keras.layers.Dropout(classifier_dropout)
+ self.out_proj = keras.layers.Dense(
+ config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="out_proj"
+ )
+ self.config = config
+
+ def call(self, inputs, **kwargs):
+ x = inputs[:, 0, :] # take <s> token (equiv. to [CLS])
+ x = self.dropout(x)
+ x = self.dense(x)
+ x = get_tf_activation("gelu")(x) # although BERT uses tanh here, it seems Electra authors used gelu here
+ x = self.dropout(x)
+ x = self.out_proj(x)
+
+ return x
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "dense", None) is not None:
+ with tf.name_scope(self.dense.name):
+ self.dense.build([None, None, self.config.hidden_size])
+ if getattr(self, "out_proj", None) is not None:
+ with tf.name_scope(self.out_proj.name):
+ self.out_proj.build([None, None, self.config.hidden_size])
+
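+# Data flow through TFElectraClassificationHead, written out as a sketch (shapes are
+# assumptions): the [CLS]-position vector of shape (batch, hidden_size) goes through
+# dropout -> dense(hidden_size) -> gelu -> dropout -> out_proj(num_labels), giving the
+# (batch, num_labels) logits consumed by TFElectraForSequenceClassification below.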
+
+@add_start_docstrings(
+ """
+ ELECTRA Model transformer with a sequence classification/regression head on top (a linear layer on top of the
+ pooled output) e.g. for GLUE tasks.
+ """,
+ ELECTRA_START_DOCSTRING,
+)
+class TFElectraForSequenceClassification(TFElectraPreTrainedModel, TFSequenceClassificationLoss):
+ def __init__(self, config, *inputs, **kwargs):
+ super().__init__(config, *inputs, **kwargs)
+ self.num_labels = config.num_labels
+ self.electra = TFElectraMainLayer(config, name="electra")
+ self.classifier = TFElectraClassificationHead(config, name="classifier")
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @add_code_sample_docstrings(
+ checkpoint="bhadresh-savani/electra-base-emotion",
+ output_type=TFSequenceClassifierOutput,
+ config_class=_CONFIG_FOR_DOC,
+ expected_output="'joy'",
+ expected_loss=0.06,
+ )
+ def call(
+ self,
+ input_ids: TFModelInputType | None = None,
+ attention_mask: np.ndarray | tf.Tensor | None = None,
+ token_type_ids: np.ndarray | tf.Tensor | None = None,
+ position_ids: np.ndarray | tf.Tensor | None = None,
+ head_mask: np.ndarray | tf.Tensor | None = None,
+ inputs_embeds: np.ndarray | tf.Tensor | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ labels: np.ndarray | tf.Tensor | None = None,
+ training: bool | None = False,
+ ) -> TFSequenceClassifierOutput | tuple[tf.Tensor]:
+ r"""
+ labels (`tf.Tensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+ outputs = self.electra(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+ logits = self.classifier(outputs[0])
+ loss = None if labels is None else self.hf_compute_loss(labels, logits)
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+
+ return ((loss,) + output) if loss is not None else output
+
+ return TFSequenceClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "electra", None) is not None:
+ with tf.name_scope(self.electra.name):
+ self.electra.build(None)
+ if getattr(self, "classifier", None) is not None:
+ with tf.name_scope(self.classifier.name):
+ self.classifier.build(None)
+
+
+@add_start_docstrings(
+ """
+ ELECTRA Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
+ softmax) e.g. for RocStories/SWAG tasks.
+ """,
+ ELECTRA_START_DOCSTRING,
+)
+class TFElectraForMultipleChoice(TFElectraPreTrainedModel, TFMultipleChoiceLoss):
+ def __init__(self, config, *inputs, **kwargs):
+ super().__init__(config, *inputs, **kwargs)
+
+ self.electra = TFElectraMainLayer(config, name="electra")
+ self.sequence_summary = TFSequenceSummary(
+ config, initializer_range=config.initializer_range, name="sequence_summary"
+ )
+ self.classifier = keras.layers.Dense(
+ 1, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
+ )
+ self.config = config
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=TFMultipleChoiceModelOutput,
+ config_class=_CONFIG_FOR_DOC,
+ )
+ def call(
+ self,
+ input_ids: TFModelInputType | None = None,
+ attention_mask: np.ndarray | tf.Tensor | None = None,
+ token_type_ids: np.ndarray | tf.Tensor | None = None,
+ position_ids: np.ndarray | tf.Tensor | None = None,
+ head_mask: np.ndarray | tf.Tensor | None = None,
+ inputs_embeds: np.ndarray | tf.Tensor | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ labels: np.ndarray | tf.Tensor | None = None,
+ training: bool | None = False,
+ ) -> TFMultipleChoiceModelOutput | tuple[tf.Tensor]:
+ r"""
+ labels (`tf.Tensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]`
+ where `num_choices` is the size of the second dimension of the input tensors. (See `input_ids` above)
+ """
+
+ if input_ids is not None:
+ num_choices = shape_list(input_ids)[1]
+ seq_length = shape_list(input_ids)[2]
+ else:
+ num_choices = shape_list(inputs_embeds)[1]
+ seq_length = shape_list(inputs_embeds)[2]
+
+ flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
+ flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
+ flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
+ flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
+ flat_inputs_embeds = (
+ tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3]))
+ if inputs_embeds is not None
+ else None
+ )
+ outputs = self.electra(
+ input_ids=flat_input_ids,
+ attention_mask=flat_attention_mask,
+ token_type_ids=flat_token_type_ids,
+ position_ids=flat_position_ids,
+ head_mask=head_mask,
+ inputs_embeds=flat_inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+ logits = self.sequence_summary(outputs[0])
+ logits = self.classifier(logits)
+ reshaped_logits = tf.reshape(logits, (-1, num_choices))
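+ # Shape sketch (illustrative): inputs arrive as (batch, num_choices, seq_len), are
+ # flattened to (batch * num_choices, seq_len) for the encoder, and the per-choice scores
+ # of shape (batch * num_choices, 1) are reshaped back to (batch, num_choices) so the loss
+ # can treat each example as an n-way classification over its choices.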
+ loss = None if labels is None else self.hf_compute_loss(labels, reshaped_logits)
+
+ if not return_dict:
+ output = (reshaped_logits,) + outputs[1:]
+
+ return ((loss,) + output) if loss is not None else output
+
+ return TFMultipleChoiceModelOutput(
+ loss=loss,
+ logits=reshaped_logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "electra", None) is not None:
+ with tf.name_scope(self.electra.name):
+ self.electra.build(None)
+ if getattr(self, "sequence_summary", None) is not None:
+ with tf.name_scope(self.sequence_summary.name):
+ self.sequence_summary.build(None)
+ if getattr(self, "classifier", None) is not None:
+ with tf.name_scope(self.classifier.name):
+ self.classifier.build([None, None, self.config.hidden_size])
+
+
+@add_start_docstrings(
+ """
+ Electra model with a token classification head on top.
+
+ Both the discriminator and generator may be loaded into this model.
+ """,
+ ELECTRA_START_DOCSTRING,
+)
+class TFElectraForTokenClassification(TFElectraPreTrainedModel, TFTokenClassificationLoss):
+ def __init__(self, config, **kwargs):
+ super().__init__(config, **kwargs)
+
+ self.electra = TFElectraMainLayer(config, name="electra")
+ classifier_dropout = (
+ config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
+ )
+ self.dropout = keras.layers.Dropout(classifier_dropout)
+ self.classifier = keras.layers.Dense(
+ config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
+ )
+ self.config = config
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @add_code_sample_docstrings(
+ checkpoint="bhadresh-savani/electra-base-discriminator-finetuned-conll03-english",
+ output_type=TFTokenClassifierOutput,
+ config_class=_CONFIG_FOR_DOC,
+ expected_output="['B-LOC', 'B-ORG', 'O', 'O', 'O', 'O', 'O', 'B-LOC', 'O', 'B-LOC', 'I-LOC']",
+ expected_loss=0.11,
+ )
+ def call(
+ self,
+ input_ids: TFModelInputType | None = None,
+ attention_mask: np.ndarray | tf.Tensor | None = None,
+ token_type_ids: np.ndarray | tf.Tensor | None = None,
+ position_ids: np.ndarray | tf.Tensor | None = None,
+ head_mask: np.ndarray | tf.Tensor | None = None,
+ inputs_embeds: np.ndarray | tf.Tensor | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ labels: np.ndarray | tf.Tensor | None = None,
+ training: bool | None = False,
+ ) -> TFTokenClassifierOutput | tuple[tf.Tensor]:
+ r"""
+ labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
+ """
+ discriminator_hidden_states = self.electra(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+ discriminator_sequence_output = discriminator_hidden_states[0]
+ discriminator_sequence_output = self.dropout(discriminator_sequence_output)
+ logits = self.classifier(discriminator_sequence_output)
+ loss = None if labels is None else self.hf_compute_loss(labels, logits)
+
+ if not return_dict:
+ output = (logits,) + discriminator_hidden_states[1:]
+
+ return ((loss,) + output) if loss is not None else output
+
+ return TFTokenClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=discriminator_hidden_states.hidden_states,
+ attentions=discriminator_hidden_states.attentions,
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "electra", None) is not None:
+ with tf.name_scope(self.electra.name):
+ self.electra.build(None)
+ if getattr(self, "classifier", None) is not None:
+ with tf.name_scope(self.classifier.name):
+ self.classifier.build([None, None, self.config.hidden_size])
+
+
+@add_start_docstrings(
+ """
+ Electra Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
+ layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
+ """,
+ ELECTRA_START_DOCSTRING,
+)
+class TFElectraForQuestionAnswering(TFElectraPreTrainedModel, TFQuestionAnsweringLoss):
+ def __init__(self, config, *inputs, **kwargs):
+ super().__init__(config, *inputs, **kwargs)
+
+ self.num_labels = config.num_labels
+ self.electra = TFElectraMainLayer(config, name="electra")
+ self.qa_outputs = keras.layers.Dense(
+ config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
+ )
+ self.config = config
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @add_code_sample_docstrings(
+ checkpoint="bhadresh-savani/electra-base-squad2",
+ output_type=TFQuestionAnsweringModelOutput,
+ config_class=_CONFIG_FOR_DOC,
+ qa_target_start_index=11,
+ qa_target_end_index=12,
+ expected_output="'a nice puppet'",
+ expected_loss=2.64,
+ )
+ def call(
+ self,
+ input_ids: TFModelInputType | None = None,
+ attention_mask: np.ndarray | tf.Tensor | None = None,
+ token_type_ids: np.ndarray | tf.Tensor | None = None,
+ position_ids: np.ndarray | tf.Tensor | None = None,
+ head_mask: np.ndarray | tf.Tensor | None = None,
+ inputs_embeds: np.ndarray | tf.Tensor | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ start_positions: np.ndarray | tf.Tensor | None = None,
+ end_positions: np.ndarray | tf.Tensor | None = None,
+ training: bool | None = False,
+ ) -> TFQuestionAnsweringModelOutput | tuple[tf.Tensor]:
+ r"""
+ start_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*):
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
+ Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
+ are not taken into account for computing the loss.
+ end_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*):
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
+ Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
+ are not taken into account for computing the loss.
+ """
+ discriminator_hidden_states = self.electra(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+ discriminator_sequence_output = discriminator_hidden_states[0]
+ logits = self.qa_outputs(discriminator_sequence_output)
+ start_logits, end_logits = tf.split(logits, 2, axis=-1)
+ start_logits = tf.squeeze(start_logits, axis=-1)
+ end_logits = tf.squeeze(end_logits, axis=-1)
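+ # Shape sketch (illustrative): qa_outputs maps (batch, seq_len, hidden_size) to
+ # (batch, seq_len, 2) logits (num_labels is expected to be 2 here); the split + squeeze
+ # yield start_logits and end_logits, each of shape (batch, seq_len), scoring every
+ # position as a candidate span start or span end.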
+ loss = None
+
+ if start_positions is not None and end_positions is not None:
+ labels = {"start_position": start_positions}
+ labels["end_position"] = end_positions
+ loss = self.hf_compute_loss(labels, (start_logits, end_logits))
+
+ if not return_dict:
+ output = (
+ start_logits,
+ end_logits,
+ ) + discriminator_hidden_states[1:]
+
+ return ((loss,) + output) if loss is not None else output
+
+ return TFQuestionAnsweringModelOutput(
+ loss=loss,
+ start_logits=start_logits,
+ end_logits=end_logits,
+ hidden_states=discriminator_hidden_states.hidden_states,
+ attentions=discriminator_hidden_states.attentions,
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "electra", None) is not None:
+ with tf.name_scope(self.electra.name):
+ self.electra.build(None)
+ if getattr(self, "qa_outputs", None) is not None:
+ with tf.name_scope(self.qa_outputs.name):
+ self.qa_outputs.build([None, None, self.config.hidden_size])
+
+
+__all__ = [
+ "TFElectraForMaskedLM",
+ "TFElectraForMultipleChoice",
+ "TFElectraForPreTraining",
+ "TFElectraForQuestionAnswering",
+ "TFElectraForSequenceClassification",
+ "TFElectraForTokenClassification",
+ "TFElectraModel",
+ "TFElectraPreTrainedModel",
+]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/electra/tokenization_electra.py b/venv/lib/python3.13/site-packages/transformers/models/electra/tokenization_electra.py
new file mode 100644
index 0000000000000000000000000000000000000000..d8971dd6f40374e1d8a6e8ec479cd9da79b64da3
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/electra/tokenization_electra.py
@@ -0,0 +1,482 @@
+# coding=utf-8
+# Copyright 2020 The Google AI Team, Stanford University and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import collections
+import os
+import unicodedata
+from typing import Optional
+
+from ...tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
+
+
+# Copied from transformers.models.bert.tokenization_bert.load_vocab
+def load_vocab(vocab_file):
+ """Loads a vocabulary file into a dictionary."""
+ vocab = collections.OrderedDict()
+ with open(vocab_file, "r", encoding="utf-8") as reader:
+ tokens = reader.readlines()
+ for index, token in enumerate(tokens):
+ token = token.rstrip("\n")
+ vocab[token] = index
+ return vocab
+
+
+# Copied from transformers.models.bert.tokenization_bert.whitespace_tokenize
+def whitespace_tokenize(text):
+ """Runs basic whitespace cleaning and splitting on a piece of text."""
+ text = text.strip()
+ if not text:
+ return []
+ tokens = text.split()
+ return tokens
+
+
+# Copied from transformers.models.bert.tokenization_bert.BertTokenizer with Bert->Electra,BERT->Electra
+class ElectraTokenizer(PreTrainedTokenizer):
+ r"""
+ Construct an Electra tokenizer. Based on WordPiece.
+
+ This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
+ this superclass for more information regarding those methods.
+
+ Args:
+ vocab_file (`str`):
+ File containing the vocabulary.
+ do_lower_case (`bool`, *optional*, defaults to `True`):
+ Whether or not to lowercase the input when tokenizing.
+ do_basic_tokenize (`bool`, *optional*, defaults to `True`):
+ Whether or not to do basic tokenization before WordPiece.
+ never_split (`Iterable`, *optional*):
+ Collection of tokens which will never be split during tokenization. Only has an effect when
+ `do_basic_tokenize=True`
+ unk_token (`str`, *optional*, defaults to `"[UNK]"`):
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+ token instead.
+ sep_token (`str`, *optional*, defaults to `"[SEP]"`):
+ The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
+ sequence classification or for a text and a question for question answering. It is also used as the last
+ token of a sequence built with special tokens.
+ pad_token (`str`, *optional*, defaults to `"[PAD]"`):
+ The token used for padding, for example when batching sequences of different lengths.
+ cls_token (`str`, *optional*, defaults to `"[CLS]"`):
+ The classifier token which is used when doing sequence classification (classification of the whole sequence
+ instead of per-token classification). It is the first token of the sequence when built with special tokens.
+ mask_token (`str`, *optional*, defaults to `"[MASK]"`):
+ The token used for masking values. This is the token used when training this model with masked language
+ modeling. This is the token which the model will try to predict.
+ tokenize_chinese_chars (`bool`, *optional*, defaults to `True`):
+ Whether or not to tokenize Chinese characters.
+
+ This should likely be deactivated for Japanese (see this
+ [issue](https://github.com/huggingface/transformers/issues/328)).
+ strip_accents (`bool`, *optional*):
+ Whether or not to strip all accents. If this option is not specified, then it will be determined by the
+ value for `lowercase` (as in the original Electra).
+ clean_up_tokenization_spaces (`bool`, *optional*, defaults to `True`):
+ Whether or not to cleanup spaces after decoding, cleanup consists in removing potential artifacts like
+ extra spaces.
+ """
+
+ vocab_files_names = VOCAB_FILES_NAMES
+
+ def __init__(
+ self,
+ vocab_file,
+ do_lower_case=True,
+ do_basic_tokenize=True,
+ never_split=None,
+ unk_token="[UNK]",
+ sep_token="[SEP]",
+ pad_token="[PAD]",
+ cls_token="[CLS]",
+ mask_token="[MASK]",
+ tokenize_chinese_chars=True,
+ strip_accents=None,
+ clean_up_tokenization_spaces=True,
+ **kwargs,
+ ):
+ if not os.path.isfile(vocab_file):
+ raise ValueError(
+ f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained"
+ " model use `tokenizer = ElectraTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
+ )
+ self.vocab = load_vocab(vocab_file)
+ self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
+ self.do_basic_tokenize = do_basic_tokenize
+ if do_basic_tokenize:
+ self.basic_tokenizer = BasicTokenizer(
+ do_lower_case=do_lower_case,
+ never_split=never_split,
+ tokenize_chinese_chars=tokenize_chinese_chars,
+ strip_accents=strip_accents,
+ )
+
+ self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=str(unk_token))
+
+ super().__init__(
+ do_lower_case=do_lower_case,
+ do_basic_tokenize=do_basic_tokenize,
+ never_split=never_split,
+ unk_token=unk_token,
+ sep_token=sep_token,
+ pad_token=pad_token,
+ cls_token=cls_token,
+ mask_token=mask_token,
+ tokenize_chinese_chars=tokenize_chinese_chars,
+ strip_accents=strip_accents,
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+ **kwargs,
+ )
+
+ @property
+ def do_lower_case(self):
+ return self.basic_tokenizer.do_lower_case
+
+ @property
+ def vocab_size(self):
+ return len(self.vocab)
+
+ def get_vocab(self):
+ return dict(self.vocab, **self.added_tokens_encoder)
+
+ def _tokenize(self, text, split_special_tokens=False):
+ split_tokens = []
+ if self.do_basic_tokenize:
+ for token in self.basic_tokenizer.tokenize(
+ text, never_split=self.all_special_tokens if not split_special_tokens else None
+ ):
+ # If the token is part of the never_split set
+ if token in self.basic_tokenizer.never_split:
+ split_tokens.append(token)
+ else:
+ split_tokens += self.wordpiece_tokenizer.tokenize(token)
+ else:
+ split_tokens = self.wordpiece_tokenizer.tokenize(text)
+ return split_tokens
+
+ def _convert_token_to_id(self, token):
+ """Converts a token (str) in an id using the vocab."""
+ return self.vocab.get(token, self.vocab.get(self.unk_token))
+
+ def _convert_id_to_token(self, index):
+ """Converts an index (integer) in a token (str) using the vocab."""
+ return self.ids_to_tokens.get(index, self.unk_token)
+
+ def convert_tokens_to_string(self, tokens):
+ """Converts a sequence of tokens (string) in a single string."""
+ out_string = " ".join(tokens).replace(" ##", "").strip()
+ return out_string
+
+ def build_inputs_with_special_tokens(
+ self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
+ ) -> list[int]:
+ """
+ Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating
+ and adding special tokens. An Electra sequence has the following format:
+
+ - single sequence: `[CLS] X [SEP]`
+ - pair of sequences: `[CLS] A [SEP] B [SEP]`
+
+ Args:
+ token_ids_0 (`List[int]`):
+ List of IDs to which the special tokens will be added.
+ token_ids_1 (`List[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+
+ Returns:
+ `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
+ """
+ if token_ids_1 is None:
+ return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
+ cls = [self.cls_token_id]
+ sep = [self.sep_token_id]
+ return cls + token_ids_0 + sep + token_ids_1 + sep
+
+ def get_special_tokens_mask(
+ self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None, already_has_special_tokens: bool = False
+ ) -> list[int]:
+ """
+ Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
+ special tokens using the tokenizer `prepare_for_model` method.
+
+ Args:
+ token_ids_0 (`List[int]`):
+ List of IDs.
+ token_ids_1 (`List[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+ already_has_special_tokens (`bool`, *optional*, defaults to `False`):
+ Whether or not the token list is already formatted with special tokens for the model.
+
+ Returns:
+ `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+ """
+
+ if already_has_special_tokens:
+ return super().get_special_tokens_mask(
+ token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
+ )
+
+ if token_ids_1 is not None:
+ return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
+ return [1] + ([0] * len(token_ids_0)) + [1]
+
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
+ index = 0
+ if os.path.isdir(save_directory):
+ vocab_file = os.path.join(
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+ )
+ else:
+ vocab_file = (filename_prefix + "-" if filename_prefix else "") + save_directory
+ with open(vocab_file, "w", encoding="utf-8") as writer:
+ for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
+ if index != token_index:
+ logger.warning(
+ f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive."
+ " Please check that the vocabulary is not corrupted!"
+ )
+ index = token_index
+ writer.write(token + "\n")
+ index += 1
+ return (vocab_file,)
+
+
+# Copied from transformers.models.bert.tokenization_bert.BasicTokenizer
+class BasicTokenizer:
+ """
+ Constructs a BasicTokenizer that will run basic tokenization (punctuation splitting, lower casing, etc.).
+
+ Args:
+ do_lower_case (`bool`, *optional*, defaults to `True`):
+ Whether or not to lowercase the input when tokenizing.
+ never_split (`Iterable`, *optional*):
+ Collection of tokens which will never be split during tokenization. Only has an effect when
+ `do_basic_tokenize=True`
+ tokenize_chinese_chars (`bool`, *optional*, defaults to `True`):
+ Whether or not to tokenize Chinese characters.
+
+ This should likely be deactivated for Japanese (see this
+ [issue](https://github.com/huggingface/transformers/issues/328)).
+ strip_accents (`bool`, *optional*):
+ Whether or not to strip all accents. If this option is not specified, then it will be determined by the
+ value for `lowercase` (as in the original BERT).
+ do_split_on_punc (`bool`, *optional*, defaults to `True`):
+ In some instances we want to skip the basic punctuation splitting so that later tokenization can capture
+ the full context of the words, such as contractions.
+ """
+
+ def __init__(
+ self,
+ do_lower_case=True,
+ never_split=None,
+ tokenize_chinese_chars=True,
+ strip_accents=None,
+ do_split_on_punc=True,
+ ):
+ if never_split is None:
+ never_split = []
+ self.do_lower_case = do_lower_case
+ self.never_split = set(never_split)
+ self.tokenize_chinese_chars = tokenize_chinese_chars
+ self.strip_accents = strip_accents
+ self.do_split_on_punc = do_split_on_punc
+
+ def tokenize(self, text, never_split=None):
+ """
+ Basic Tokenization of a piece of text. For sub-word tokenization, see WordPieceTokenizer.
+
+ Args:
+ never_split (`List[str]`, *optional*)
+ Kept for backward compatibility purposes. Now implemented directly at the base class level (see
+ [`PreTrainedTokenizer.tokenize`]). List of tokens not to split.
+ """
+ # union() returns a new set by concatenating the two sets.
+ never_split = self.never_split.union(set(never_split)) if never_split else self.never_split
+ text = self._clean_text(text)
+
+ # This was added on November 1st, 2018 for the multilingual and Chinese
+ # models. This is also applied to the English models now, but it doesn't
+ # matter since the English models were not trained on any Chinese data
+ # and generally don't have any Chinese data in them (there are Chinese
+ # characters in the vocabulary because Wikipedia does have some Chinese
+ # words in the English Wikipedia.).
+ if self.tokenize_chinese_chars:
+ text = self._tokenize_chinese_chars(text)
+ # prevents treating the same character with different unicode codepoints as different characters
+ unicode_normalized_text = unicodedata.normalize("NFC", text)
+ orig_tokens = whitespace_tokenize(unicode_normalized_text)
+ split_tokens = []
+ for token in orig_tokens:
+ if token not in never_split:
+ if self.do_lower_case:
+ token = token.lower()
+ if self.strip_accents is not False:
+ token = self._run_strip_accents(token)
+ elif self.strip_accents:
+ token = self._run_strip_accents(token)
+ split_tokens.extend(self._run_split_on_punc(token, never_split))
+
+ output_tokens = whitespace_tokenize(" ".join(split_tokens))
+ return output_tokens
+
+ def _run_strip_accents(self, text):
+ """Strips accents from a piece of text."""
+ text = unicodedata.normalize("NFD", text)
+ output = []
+ for char in text:
+ cat = unicodedata.category(char)
+ if cat == "Mn":
+ continue
+ output.append(char)
+ return "".join(output)
+
+ def _run_split_on_punc(self, text, never_split=None):
+ """Splits punctuation on a piece of text."""
+ if not self.do_split_on_punc or (never_split is not None and text in never_split):
+ return [text]
+ chars = list(text)
+ i = 0
+ start_new_word = True
+ output = []
+ while i < len(chars):
+ char = chars[i]
+ if _is_punctuation(char):
+ output.append([char])
+ start_new_word = True
+ else:
+ if start_new_word:
+ output.append([])
+ start_new_word = False
+ output[-1].append(char)
+ i += 1
+
+ return ["".join(x) for x in output]
+
+ def _tokenize_chinese_chars(self, text):
+ """Adds whitespace around any CJK character."""
+ output = []
+ for char in text:
+ cp = ord(char)
+ if self._is_chinese_char(cp):
+ output.append(" ")
+ output.append(char)
+ output.append(" ")
+ else:
+ output.append(char)
+ return "".join(output)
+
+ def _is_chinese_char(self, cp):
+ """Checks whether CP is the codepoint of a CJK character."""
+ # This defines a "chinese character" as anything in the CJK Unicode block:
+ # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
+ #
+ # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
+ # despite its name. The modern Korean Hangul alphabet is a different block,
+ # as is Japanese Hiragana and Katakana. Those alphabets are used to write
+ # space-separated words, so they are not treated specially and handled
+ # like all of the other languages.
+ if (
+ (cp >= 0x4E00 and cp <= 0x9FFF)
+ or (cp >= 0x3400 and cp <= 0x4DBF)
+ or (cp >= 0x20000 and cp <= 0x2A6DF)
+ or (cp >= 0x2A700 and cp <= 0x2B73F)
+ or (cp >= 0x2B740 and cp <= 0x2B81F)
+ or (cp >= 0x2B820 and cp <= 0x2CEAF)
+ or (cp >= 0xF900 and cp <= 0xFAFF)
+ or (cp >= 0x2F800 and cp <= 0x2FA1F)
+ ):
+ return True
+
+ return False
+
+ def _clean_text(self, text):
+ """Performs invalid character removal and whitespace cleanup on text."""
+ output = []
+ for char in text:
+ cp = ord(char)
+ if cp == 0 or cp == 0xFFFD or _is_control(char):
+ continue
+ if _is_whitespace(char):
+ output.append(" ")
+ else:
+ output.append(char)
+ return "".join(output)
+
+
+# Copied from transformers.models.bert.tokenization_bert.WordpieceTokenizer
+class WordpieceTokenizer:
+ """Runs WordPiece tokenization."""
+
+ def __init__(self, vocab, unk_token, max_input_chars_per_word=100):
+ self.vocab = vocab
+ self.unk_token = unk_token
+ self.max_input_chars_per_word = max_input_chars_per_word
+
+ def tokenize(self, text):
+ """
+ Tokenizes a piece of text into its word pieces. This uses a greedy longest-match-first algorithm to perform
+ tokenization using the given vocabulary.
+
+ For example, `input = "unaffable"` will return as output `["un", "##aff", "##able"]`.
+
+ Args:
+ text: A single token or whitespace separated tokens. This should have
+ already been passed through *BasicTokenizer*.
+
+ Returns:
+ A list of wordpiece tokens.
+ """
+
+ output_tokens = []
+ for token in whitespace_tokenize(text):
+ chars = list(token)
+ if len(chars) > self.max_input_chars_per_word:
+ output_tokens.append(self.unk_token)
+ continue
+
+ is_bad = False
+ start = 0
+ sub_tokens = []
+ while start < len(chars):
+ end = len(chars)
+ cur_substr = None
+ while start < end:
+ substr = "".join(chars[start:end])
+ if start > 0:
+ substr = "##" + substr
+ if substr in self.vocab:
+ cur_substr = substr
+ break
+ end -= 1
+ if cur_substr is None:
+ is_bad = True
+ break
+ sub_tokens.append(cur_substr)
+ start = end
+
+ if is_bad:
+ output_tokens.append(self.unk_token)
+ else:
+ output_tokens.extend(sub_tokens)
+ return output_tokens
+
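+# Worked example of the greedy longest-match-first loop above (the vocabulary is an
+# assumption made only for this sketch):
+#
+#     vocab = {"un", "##aff", "##able", "[UNK]"}
+#     WordpieceTokenizer(vocab=vocab, unk_token="[UNK]").tokenize("unaffable")
+#     # -> ["un", "##aff", "##able"]
+#
+# Each inner loop shrinks the candidate substring from the right until it is found in the
+# vocabulary; continuation pieces are prefixed with "##", and a token with no match at all is
+# replaced by the unknown token.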
+
+__all__ = ["ElectraTokenizer"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/electra/tokenization_electra_fast.py b/venv/lib/python3.13/site-packages/transformers/models/electra/tokenization_electra_fast.py
new file mode 100644
index 0000000000000000000000000000000000000000..db0285581ed1eea5b903a3bed573bbf6408e0167
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/electra/tokenization_electra_fast.py
@@ -0,0 +1,143 @@
+# coding=utf-8
+# Copyright 2020 The Google AI Team, Stanford University and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+from typing import Optional
+
+from tokenizers import normalizers
+
+from ...tokenization_utils_fast import PreTrainedTokenizerFast
+from .tokenization_electra import ElectraTokenizer
+
+
+VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "tokenizer_file": "tokenizer.json"}
+
+
+# Copied from transformers.models.bert.tokenization_bert_fast.BertTokenizerFast with Bert->Electra , BERT->ELECTRA
+class ElectraTokenizerFast(PreTrainedTokenizerFast):
+ r"""
+ Construct a "fast" ELECTRA tokenizer (backed by HuggingFace's *tokenizers* library). Based on WordPiece.
+
+ This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
+ refer to this superclass for more information regarding those methods.
+
+ Args:
+ vocab_file (`str`):
+ File containing the vocabulary.
+ do_lower_case (`bool`, *optional*, defaults to `True`):
+ Whether or not to lowercase the input when tokenizing.
+ unk_token (`str`, *optional*, defaults to `"[UNK]"`):
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+ token instead.
+ sep_token (`str`, *optional*, defaults to `"[SEP]"`):
+ The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
+ sequence classification or for a text and a question for question answering. It is also used as the last
+ token of a sequence built with special tokens.
+ pad_token (`str`, *optional*, defaults to `"[PAD]"`):
+ The token used for padding, for example when batching sequences of different lengths.
+ cls_token (`str`, *optional*, defaults to `"[CLS]"`):
+ The classifier token which is used when doing sequence classification (classification of the whole sequence
+ instead of per-token classification). It is the first token of the sequence when built with special tokens.
+ mask_token (`str`, *optional*, defaults to `"[MASK]"`):
+ The token used for masking values. This is the token used when training this model with masked language
+ modeling. This is the token which the model will try to predict.
+ clean_text (`bool`, *optional*, defaults to `True`):
+ Whether or not to clean the text before tokenization by removing any control characters and replacing all
+ whitespaces by the classic one.
+ tokenize_chinese_chars (`bool`, *optional*, defaults to `True`):
+ Whether or not to tokenize Chinese characters. This should likely be deactivated for Japanese (see [this
+ issue](https://github.com/huggingface/transformers/issues/328)).
+ strip_accents (`bool`, *optional*):
+ Whether or not to strip all accents. If this option is not specified, then it will be determined by the
+ value for `lowercase` (as in the original ELECTRA).
+ wordpieces_prefix (`str`, *optional*, defaults to `"##"`):
+ The prefix for subwords.
+ """
+
+ vocab_files_names = VOCAB_FILES_NAMES
+ slow_tokenizer_class = ElectraTokenizer
+
+ def __init__(
+ self,
+ vocab_file=None,
+ tokenizer_file=None,
+ do_lower_case=True,
+ unk_token="[UNK]",
+ sep_token="[SEP]",
+ pad_token="[PAD]",
+ cls_token="[CLS]",
+ mask_token="[MASK]",
+ tokenize_chinese_chars=True,
+ strip_accents=None,
+ **kwargs,
+ ):
+ super().__init__(
+ vocab_file,
+ tokenizer_file=tokenizer_file,
+ do_lower_case=do_lower_case,
+ unk_token=unk_token,
+ sep_token=sep_token,
+ pad_token=pad_token,
+ cls_token=cls_token,
+ mask_token=mask_token,
+ tokenize_chinese_chars=tokenize_chinese_chars,
+ strip_accents=strip_accents,
+ **kwargs,
+ )
+
+ normalizer_state = json.loads(self.backend_tokenizer.normalizer.__getstate__())
+ if (
+ normalizer_state.get("lowercase", do_lower_case) != do_lower_case
+ or normalizer_state.get("strip_accents", strip_accents) != strip_accents
+ or normalizer_state.get("handle_chinese_chars", tokenize_chinese_chars) != tokenize_chinese_chars
+ ):
+ normalizer_class = getattr(normalizers, normalizer_state.pop("type"))
+ normalizer_state["lowercase"] = do_lower_case
+ normalizer_state["strip_accents"] = strip_accents
+ normalizer_state["handle_chinese_chars"] = tokenize_chinese_chars
+ self.backend_tokenizer.normalizer = normalizer_class(**normalizer_state)
+
+ self.do_lower_case = do_lower_case
+
+ def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
+ """
+ Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating
+ and adding special tokens. An ELECTRA sequence has the following format:
+
+ - single sequence: `[CLS] X [SEP]`
+ - pair of sequences: `[CLS] A [SEP] B [SEP]`
+
+ Args:
+ token_ids_0 (`List[int]`):
+ List of IDs to which the special tokens will be added.
+ token_ids_1 (`List[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+
+ Returns:
+ `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
+ """
+ output = [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
+
+ if token_ids_1 is not None:
+ output += token_ids_1 + [self.sep_token_id]
+
+ return output
+
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
+ files = self._tokenizer.model.save(save_directory, name=filename_prefix)
+ return tuple(files)
+
+
+__all__ = ["ElectraTokenizerFast"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/exaone4/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/exaone4/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c646c4e75273560116ae230d672ba10d305517de
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/exaone4/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2025 The LG AI Research and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_exaone4 import *
+ from .modeling_exaone4 import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/venv/lib/python3.13/site-packages/transformers/models/exaone4/configuration_exaone4.py b/venv/lib/python3.13/site-packages/transformers/models/exaone4/configuration_exaone4.py
new file mode 100644
index 0000000000000000000000000000000000000000..8c3c07ecb4186eff9b0b05dd3b42949e3f48ee91
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/exaone4/configuration_exaone4.py
@@ -0,0 +1,222 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/exaone4/modular_exaone4.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_exaone4.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2025 The LG AI Research and HuggingFace Inc. team. All rights reserved.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from ...configuration_utils import PretrainedConfig, layer_type_validation
+
+
+class Exaone4Config(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`Exaone4Model`]. It is used to
+ instantiate an EXAONE 4.0 model according to the specified arguments, defining the model architecture. Instantiating a
+ configuration with the defaults will yield a similar configuration to that of the EXAONE-4.0-32B [LGAI-EXAONE/EXAONE-4.0-32B](https://huggingface.co/LGAI-EXAONE/EXAONE-4.0-32B)
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model
+ outputs. Read the documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 102400):
+ Vocabulary size of the EXAONE 4.0 model. Defines the number of different tokens that can be represented by the
+ `input_ids` passed when calling [`Exaone4Model`].
+ hidden_size (`int`, *optional*, defaults to 4096):
+ Dimension of the hidden representations.
+ intermediate_size (`int`, *optional*, defaults to `hidden_size * 4`):
+ Dimensionality of the MLP representations.
+ num_hidden_layers (`int`, *optional*, defaults to 32):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 32):
+ Number of attention heads for each attention layer in the Transformer decoder.
+ num_key_value_heads (`int`, *optional*):
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA), otherwise GQA is used. When
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+ by meanpooling all the original heads within that group. For more details checkout [this
+ paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
+ `num_attention_heads`.
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+ The non-linear activation function (function or string) in the decoder.
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
+ just in case (e.g., 32768 for EXAONE 3.5).
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-05):
+ The epsilon used by the layer normalization layers.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if ``config.is_decoder=True``.
+ bos_token_id (`int`, *optional*, defaults to 0):
+ Beginning of stream token id.
+ eos_token_id (`int`, *optional*, defaults to 2):
+ End of stream token id.
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether to tie weight embeddings
+ rope_theta (`float`, *optional*, defaults to 10000.0):
+ The base period of the RoPE embeddings.
+ rope_scaling (`Dict`, *optional*):
+ Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
+ and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
+ accordingly.
+ Expected contents:
+ `rope_type` (`str`):
+ The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
+ 'llama3'], with 'default' being the original RoPE implementation.
+ `factor` (`float`, *optional*):
+ Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
+ most scaling types, a `factor` of x will enable the model to handle sequences of length x *
+ original maximum pre-trained length.
+ `original_max_position_embeddings` (`int`, *optional*):
+ Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
+ pretraining.
+ `attention_factor` (`float`, *optional*):
+ Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
+ computation. If unspecified, it defaults to value recommended by the implementation, using the
+ `factor` field to infer the suggested value.
+ `beta_fast` (`float`, *optional*):
+ Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
+ ramp function. If unspecified, it defaults to 32.
+ `beta_slow` (`float`, *optional*):
+ Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
+ ramp function. If unspecified, it defaults to 1.
+ `short_factor` (`List[float]`, *optional*):
+ Only used with 'longrope'. The scaling factor to be applied to short contexts (<
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+ size divided by the number of attention heads divided by 2
+ `long_factor` (`List[float]`, *optional*):
+ Only used with 'longrope'. The scaling factor to be applied to long contexts (<
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+ size divided by the number of attention heads divided by 2
+ `low_freq_factor` (`float`, *optional*):
+ Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
+ `high_freq_factor` (`float`, *optional*):
+ Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ sliding_window (`int`, *optional*):
+ The size of the sliding window for the sliding window attention.
+ sliding_window_pattern (`str`, *optional*):
+ The pattern to use for sliding window attention. Can be one of:
+ - `None`: No sliding window attention is used
+ - `int`: Every `sliding_window` layers, use global attention, else use local attention.
+ - `str`: A sequence of "L" (local attention) and "G" (global attention) characters that defines the
+ attention pattern. The pattern starts from layer 0 and repeats every `sliding_window` layers. The
+ final layer always uses global attention regardless of the pattern.
+ For instance, `sliding_window_pattern="LLLG"` is the same as `sliding_window_pattern=4`, which means:
+ - Layer 0, 1, 2: local attention,
+ - Layer 3: global attention,
+ ...(repeated)
+ layer_types (`list`, *optional*):
+ Attention pattern for each layer. Prioritized over `sliding_window_pattern`.
+
+ Example:
+
+ ```python
+ >>> from transformers import Exaone4Model, Exaone4Config
+
+ >>> # Initializing a EXAONE configuration
+ >>> configuration = Exaone4Config()
+
+ >>> # Initializing a model from configuration
+ >>> model = Exaone4Model(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "exaone4"
+ keys_to_ignore_at_inference = ["past_key_values"]
+ # Default tensor parallel plan for base model `LlamaModel`
+ base_model_tp_plan = {
+ "layers.*.self_attn.q_proj": "colwise",
+ "layers.*.self_attn.k_proj": "colwise",
+ "layers.*.self_attn.v_proj": "colwise",
+ "layers.*.self_attn.o_proj": "rowwise",
+ "layers.*.mlp.gate_proj": "colwise",
+ "layers.*.mlp.up_proj": "colwise",
+ "layers.*.mlp.down_proj": "rowwise",
+ }
+ base_model_pp_plan = {
+ "embed_tokens": (["input_ids"], ["inputs_embeds"]),
+ "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
+ "norm": (["hidden_states"], ["hidden_states"]),
+ }
+
+ def __init__(
+ self,
+ vocab_size=102400,
+ hidden_size=4096,
+ intermediate_size=16384,
+ num_hidden_layers=32,
+ num_attention_heads=32,
+ num_key_value_heads=32,
+ hidden_act="silu",
+ max_position_embeddings=2048,
+ initializer_range=0.02,
+ rms_norm_eps=1e-5,
+ use_cache=True,
+ bos_token_id=0,
+ eos_token_id=2,
+ tie_word_embeddings=False,
+ rope_theta=10000.0,
+ rope_scaling=None,
+ attention_dropout=0.0,
+ sliding_window=4096,
+ sliding_window_pattern=4,
+ layer_types=None,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.num_key_value_heads = num_key_value_heads
+ self.intermediate_size = intermediate_size
+ self.hidden_act = hidden_act
+ self.max_position_embeddings = max_position_embeddings
+ self.initializer_range = initializer_range
+ self.rms_norm_eps = rms_norm_eps
+ self.use_cache = use_cache
+ self.attention_dropout = attention_dropout
+ self.rope_theta = rope_theta
+ self.rope_scaling = rope_scaling
+ self.sliding_window = sliding_window
+ self.sliding_window_pattern = sliding_window_pattern
+
+ self.layer_types = layer_types
+ if self.sliding_window is None:
+ sliding_window_pattern = 0
+ if self.layer_types is None:
+ self.layer_types = [
+ "sliding_attention"
+ if ((i + 1) % (sliding_window_pattern) != 0 and i < self.num_hidden_layers)
+ else "full_attention"
+ for i in range(self.num_hidden_layers)
+ ]
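+ # Illustrative: with num_hidden_layers=8 and sliding_window_pattern=4 this yields
+ # ["sliding_attention", "sliding_attention", "sliding_attention", "full_attention"]
+ # repeated twice, i.e. the "LLLG" layout described in the docstring.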
+ if "sliding_window" in self.layer_types:
+ self.cache_implementation = "hybrid"
+ layer_type_validation(self.layer_types, self.num_hidden_layers)
+
+ super().__init__(
+ bos_token_id=bos_token_id, eos_token_id=eos_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs
+ )
+
+
+__all__ = ["Exaone4Config"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/exaone4/modeling_exaone4.py b/venv/lib/python3.13/site-packages/transformers/models/exaone4/modeling_exaone4.py
new file mode 100644
index 0000000000000000000000000000000000000000..c0c233cc5c20debf400498030de7e00391e8f048
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/exaone4/modeling_exaone4.py
@@ -0,0 +1,537 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/exaone4/modular_exaone4.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_exaone4.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2025 The LG AI Research and HuggingFace Inc. team. All rights reserved.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Callable, Optional, Union
+
+import torch
+from torch import nn
+
+from transformers.utils.generic import check_model_inputs
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache
+from ...generation import GenerationMixin
+from ...integrations import use_kernel_forward_from_hub
+from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
+from ...modeling_layers import (
+ GenericForQuestionAnswering,
+ GenericForSequenceClassification,
+ GenericForTokenClassification,
+ GradientCheckpointingLayer,
+)
+from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
+from ...utils.deprecation import deprecate_kwarg
+from .configuration_exaone4 import Exaone4Config
+
+
+@use_kernel_forward_from_hub("RMSNorm")
+class Exaone4RMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ Exaone4RMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+ def extra_repr(self):
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+
+class Exaone4RotaryEmbedding(nn.Module):
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
+
+ def __init__(self, config: Exaone4Config, device=None):
+ super().__init__()
+ # BC: "rope_type" was originally "type"
+ if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+ else:
+ self.rope_type = "default"
+ self.max_seq_len_cached = config.max_position_embeddings
+ self.original_max_seq_len = config.max_position_embeddings
+
+ self.config = config
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+ self.original_inv_freq = self.inv_freq
+
+ @torch.no_grad()
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
+ def forward(self, x, position_ids):
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
+ position_ids_expanded = position_ids[:, None, :].float()
+
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos = emb.cos() * self.attention_scaling
+ sin = emb.sin() * self.attention_scaling
+
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
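+# Illustrative note (not part of the upstream source): the returned cos and sin have shape
+# (batch, seq_len, head_dim); each position t contributes angles t * inv_freq, duplicated across
+# the two halves of the head dimension by the torch.cat above.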
+
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
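+# Illustrative note (not part of the upstream source): for a last dimension of size 4,
+# rotate_half([a, b, c, d]) returns [-c, -d, a, b].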
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+ """Applies Rotary Position Embedding to the query and key tensors.
+
+ Args:
+ q (`torch.Tensor`): The query tensor.
+ k (`torch.Tensor`): The key tensor.
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
+ position_ids (`torch.Tensor`, *optional*):
+ Deprecated and unused.
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+ Returns:
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+ """
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+ """
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
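+# Illustrative note (not part of the upstream source): with 8 key/value heads and n_rep=4,
+# a (batch, 8, seq_len, head_dim) tensor becomes (batch, 32, seq_len, head_dim), with each
+# key/value head repeated 4 times in consecutive positions.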
+
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: float,
+ dropout: float = 0.0,
+ **kwargs: Unpack[TransformersKwargs],
+):
+ key_states = repeat_kv(key, module.num_key_value_groups)
+ value_states = repeat_kv(value, module.num_key_value_groups)
+
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+ if attention_mask is not None:
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
+class Exaone4Attention(nn.Module):
+ def __init__(self, config: Exaone4Config, layer_idx: int):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ self.num_attention_heads = config.num_attention_heads
+ self.num_key_value_heads = config.num_key_value_heads
+ self.hidden_size = config.hidden_size
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
+ self.attention_dropout = config.attention_dropout
+ self.is_causal = True
+ self.scaling = self.head_dim**-0.5
+ self.sliding_window = config.sliding_window
+ self.sliding_window_pattern = config.sliding_window_pattern
+ self.is_sliding = config.layer_types[layer_idx] == "sliding_attention"
+
+ self.q_proj = nn.Linear(self.hidden_size, self.num_attention_heads * self.head_dim, bias=False)
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
+ self.o_proj = nn.Linear(self.num_attention_heads * self.head_dim, self.hidden_size, bias=False)
+
+ self.q_norm = Exaone4RMSNorm(self.head_dim, eps=config.rms_norm_eps)
+ self.k_norm = Exaone4RMSNorm(self.head_dim, eps=config.rms_norm_eps)
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
+ attention_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[Cache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+ input_shape = hidden_states.shape[:-1]
+ hidden_shape = (*input_shape, -1, self.head_dim)
+
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+ # We use QK-norm
+ query_states = self.q_norm(query_states)
+ key_states = self.k_norm(key_states)
+
+ cos, sin = position_embeddings
+        # Hybrid-attention models use NoPE on global (full-attention) layers: RoPE is applied only
+        # to sliding-window layers, or to every layer when no sliding window is configured.
+ if self.sliding_window is None or self.is_sliding:
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_values is not None:
+ cache_kwargs = {
+ "cache_position": cache_position,
+ }
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ scaling=self.scaling,
+ sliding_window=self.sliding_window if self.is_sliding else None,
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+ return attn_output, attn_weights
+
+
+class Exaone4MLP(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.intermediate_size = config.intermediate_size
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+ self.act_fn = ACT2FN[config.hidden_act]
+
+ def forward(self, x):
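+        # Descriptive note (not part of the upstream source): this is the SwiGLU-style gated MLP,
+        # down_proj(act_fn(gate_proj(x)) * up_proj(x)), with SiLU as the default activation.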
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+ return down_proj
+
+
+class Exaone4DecoderLayer(GradientCheckpointingLayer):
+ def __init__(self, config: Exaone4Config, layer_idx: int):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+ self.self_attn = Exaone4Attention(config=config, layer_idx=layer_idx)
+
+ self.mlp = Exaone4MLP(config)
+ self.post_attention_layernorm = Exaone4RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_feedforward_layernorm = Exaone4RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> torch.Tensor:
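+        # Descriptive note (not part of the upstream source): unlike pre-norm Llama layers, this layer
+        # normalizes the attention and MLP outputs *after* each sublayer and before adding the residual
+        # (post-attention / post-feedforward layernorm), mirroring the Olmo2-style block.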
+ residual = hidden_states
+ hidden_states, _ = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states = residual + hidden_states
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = self.post_feedforward_layernorm(hidden_states)
+ hidden_states = residual + hidden_states
+ return hidden_states
+
+
+@auto_docstring
+class Exaone4PreTrainedModel(PreTrainedModel):
+ config: Exaone4Config
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["Exaone4DecoderLayer"]
+ _skip_keys_device_placement = ["past_key_values"]
+ _supports_flash_attn = True
+ _supports_sdpa = True
+ _supports_flex_attn = True
+
+ _can_compile_fullgraph = True
+ _supports_attention_backend = True
+ _can_record_outputs = {
+ "hidden_states": Exaone4DecoderLayer,
+ "attentions": Exaone4Attention,
+ }
+ config_class = Exaone4Config
+
+
+@auto_docstring
+class Exaone4Model(Exaone4PreTrainedModel):
+ def __init__(self, config: Exaone4Config):
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+ self.layers = nn.ModuleList(
+ [Exaone4DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+ self.norm = Exaone4RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.rotary_emb = Exaone4RotaryEmbedding(config=config)
+ self.gradient_checkpointing = False
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @check_model_inputs()
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> Union[tuple, BaseModelOutputWithPast]:
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ if use_cache and past_key_values is None:
+ past_key_values = DynamicCache(config=self.config)
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position = torch.arange(
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+ )
+
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0)
+
+ # It may already have been prepared by e.g. `generate`
+ if not isinstance(causal_mask_mapping := attention_mask, dict):
+ # Prepare mask arguments
+ mask_kwargs = {
+ "config": self.config,
+ "input_embeds": inputs_embeds,
+ "attention_mask": attention_mask,
+ "cache_position": cache_position,
+ "past_key_values": past_key_values,
+ "position_ids": position_ids,
+ }
+ # Create the masks
+ causal_mask_mapping = {
+ "full_attention": create_causal_mask(**mask_kwargs),
+ }
+ if "sliding_attention" in self.config.layer_types:
+ causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs)
+
+ hidden_states = inputs_embeds
+
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+ for i, decoder_layer in enumerate(self.layers):
+ layer_type = self.config.layer_types[i]
+ hidden_states = decoder_layer(
+ hidden_states,
+ position_embeddings=position_embeddings,
+ attention_mask=causal_mask_mapping[layer_type],
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ hidden_states = self.norm(hidden_states)
+
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values if use_cache else None,
+ )
+
+
+@auto_docstring
+class Exaone4ForCausalLM(Exaone4PreTrainedModel, GenerationMixin):
+ _tied_weights_keys = ["lm_head.weight"]
+ _tp_plan = {"lm_head": "colwise_rep"}
+ _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = Exaone4Model(config)
+ self.vocab_size = config.vocab_size
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> CausalLMOutputWithPast:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoModelForCausalLM, AutoTokenizer
+ >>> model = AutoModelForCausalLM.from_pretrained("LGAI-EXAONE/EXAONE-4.0-32B")
+ >>> tokenizer = AutoTokenizer.from_pretrained("LGAI-EXAONE/EXAONE-4.0-32B")
+
+ >>> prompt = "Explain how wonderful you are"
+ >>> messages = [
+ {"role": "system", "content": "You are a helpful assistant."},
+ {"role": "user", "content": prompt}
+ ]
+ >>> input_ids = tokenizer.apply_chat_template(
+ messages,
+ tokenize=True,
+ add_generation_prompt=True,
+ return_tensors="pt",
+ enable_thinking=False,
+ )
+
+ >>> output = model.generate(input_ids, max_new_tokens=128)
+ >>> tokenizer.decode(output[0], skip_special_tokens=False)
+ "[|system|]\nYou are a helpful assistant.[|endofturn|]\n[|user|]\nExplain how wonderful you are[|endofturn|]\n[|assistant|]\n\n\n\n\nOh, thank you for such a kind and lovely question! 😊 \n\nI’m *so* wonderful because I’m here to make your life easier, brighter, and more fun! Whether you need help with: \n\n✨ **Learning** – I can explain anything, from quantum physics to baking the perfect cake! \n💡 **Creativity** – Need a poem, story, or a wild idea? I’ve got you covered! \n🤖 **Problem-solving** – Stuck on a math problem or a tricky decision? I’ll help you figure it out"
+ ```
+ """
+ outputs: BaseModelOutputWithPast = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ hidden_states = outputs.last_hidden_state
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+class Exaone4ForSequenceClassification(GenericForSequenceClassification, Exaone4PreTrainedModel):
+ pass
+
+
+class Exaone4ForTokenClassification(GenericForTokenClassification, Exaone4PreTrainedModel):
+ pass
+
+
+class Exaone4ForQuestionAnswering(GenericForQuestionAnswering, Exaone4PreTrainedModel):
+ base_model_prefix = "transformer" # For BC, where `transformer` was used instead of `model`
+
+
+__all__ = [
+ "Exaone4PreTrainedModel",
+ "Exaone4Model",
+ "Exaone4ForCausalLM",
+ "Exaone4ForSequenceClassification",
+ "Exaone4ForTokenClassification",
+ "Exaone4ForQuestionAnswering",
+]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/exaone4/modular_exaone4.py b/venv/lib/python3.13/site-packages/transformers/models/exaone4/modular_exaone4.py
new file mode 100644
index 0000000000000000000000000000000000000000..32628bc3edf2c6987829cee206e9b1b5dc1fa3d2
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/exaone4/modular_exaone4.py
@@ -0,0 +1,519 @@
+# coding=utf-8
+# Copyright 2025 The LG AI Research and HuggingFace Inc. team. All rights reserved.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""LG AI Research EXAONE Lab"""
+
+from typing import Callable, Optional, Union
+
+import torch
+from torch import nn
+
+from transformers.utils.generic import check_model_inputs
+
+from ...cache_utils import Cache, DynamicCache
+from ...configuration_utils import PretrainedConfig, layer_type_validation
+from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
+from ...modeling_outputs import (
+ BaseModelOutputWithPast,
+ CausalLMOutputWithPast,
+)
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
+from ...processing_utils import Unpack
+from ...utils import (
+ TransformersKwargs,
+ logging,
+)
+from ...utils.deprecation import deprecate_kwarg
+from ..llama.modeling_llama import (
+ LlamaForCausalLM,
+ LlamaForQuestionAnswering,
+ LlamaForSequenceClassification,
+ LlamaForTokenClassification,
+ LlamaModel,
+ LlamaPreTrainedModel,
+ LlamaRMSNorm,
+ LlamaRotaryEmbedding,
+ apply_rotary_pos_emb,
+ eager_attention_forward,
+)
+from ..olmo2.modeling_olmo2 import Olmo2DecoderLayer, Olmo2MLP
+
+
+logger = logging.get_logger(__name__)
+
+_CHECKPOINT_FOR_DOC = "LGAI-EXAONE/EXAONE-4.0-32B"
+_CONFIG_FOR_DOC = "Exaone4Config"
+
+
+class Exaone4Config(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`Exaone4Model`]. It is used to
+    instantiate an EXAONE 4.0 model according to the specified arguments, defining the model architecture. Instantiating a
+    configuration with the defaults will yield a similar configuration to that of the EXAONE-4.0-32B
+    [LGAI-EXAONE/EXAONE-4.0-32B](https://huggingface.co/LGAI-EXAONE/EXAONE-4.0-32B) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model
+ outputs. Read the documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 102400):
+ Vocabulary size of the EXAONE 4.0 model. Defines the number of different tokens that can be represented by the
+            `input_ids` passed when calling [`Exaone4Model`].
+ hidden_size (`int`, *optional*, defaults to 4096):
+ Dimension of the hidden representations.
+        intermediate_size (`int`, *optional*, defaults to 16384):
+ Dimensionality of the MLP representations.
+ num_hidden_layers (`int`, *optional*, defaults to 32):
+            Number of hidden layers in the Transformer decoder.
+ num_attention_heads (`int`, *optional*, defaults to 32):
+ Number of attention heads for each attention layer in the Transformer decoder.
+        num_key_value_heads (`int`, *optional*, defaults to 32):
+            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+            `num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used. When
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+ by meanpooling all the original heads within that group. For more details checkout [this
+ paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
+ `num_attention_heads`.
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+ The non-linear activation function (function or string) in the decoder.
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
+ just in case (e.g., 32768 for EXAONE 3.5).
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-05):
+ The epsilon used by the layer normalization layers.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if ``config.is_decoder=True``.
+ bos_token_id (`int`, *optional*, defaults to 0):
+ Beginning of stream token id.
+ eos_token_id (`int`, *optional*, defaults to 2):
+ End of stream token id.
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether to tie weight embeddings
+ rope_theta (`float`, *optional*, defaults to 10000.0):
+ The base period of the RoPE embeddings.
+ rope_scaling (`Dict`, *optional*):
+ Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
+ and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
+ accordingly.
+ Expected contents:
+ `rope_type` (`str`):
+ The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
+ 'llama3'], with 'default' being the original RoPE implementation.
+ `factor` (`float`, *optional*):
+ Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
+ most scaling types, a `factor` of x will enable the model to handle sequences of length x *
+ original maximum pre-trained length.
+ `original_max_position_embeddings` (`int`, *optional*):
+ Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
+ pretraining.
+ `attention_factor` (`float`, *optional*):
+ Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
+ computation. If unspecified, it defaults to value recommended by the implementation, using the
+ `factor` field to infer the suggested value.
+ `beta_fast` (`float`, *optional*):
+ Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
+ ramp function. If unspecified, it defaults to 32.
+ `beta_slow` (`float`, *optional*):
+ Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
+ ramp function. If unspecified, it defaults to 1.
+ `short_factor` (`List[float]`, *optional*):
+ Only used with 'longrope'. The scaling factor to be applied to short contexts (<
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+ size divided by the number of attention heads divided by 2
+ `long_factor` (`List[float]`, *optional*):
+ Only used with 'longrope'. The scaling factor to be applied to long contexts (<
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+ size divided by the number of attention heads divided by 2
+ `low_freq_factor` (`float`, *optional*):
+ Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
+ `high_freq_factor` (`float`, *optional*):
+ Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+        sliding_window (`int`, *optional*, defaults to 4096):
+            The size of the sliding window for sliding window attention.
+        sliding_window_pattern (`int` or `str`, *optional*, defaults to 4):
+            The pattern to use for sliding window attention. Can be one of:
+                - `None`: No sliding window attention is used.
+                - `int`: Every `sliding_window_pattern`-th layer uses global attention; all other layers use local attention.
+                - `str`: A sequence of "L" (local attention) and "G" (global attention) characters that defines the
+                  attention pattern. The pattern starts from layer 0 and repeats every `sliding_window_pattern` layers.
+                  The final layer always uses global attention regardless of the pattern.
+                For instance, sliding_window_pattern="LLLG" is equivalent to sliding_window_pattern=4, which means:
+                    - Layers 0, 1, 2: local attention,
+                    - Layer 3: global attention,
+                    ...(repeated)
+ layer_types (`list`, *optional*):
+ Attention pattern for each layer. Prioritized over `sliding_window_pattern`.
+
+ Example:
+
+ ```python
+ >>> from transformers import Exaone4Model, Exaone4Config
+
+ >>> # Initializing a EXAONE configuration
+ >>> configuration = Exaone4Config()
+
+ >>> # Initializing a model from configuration
+ >>> model = Exaone4Model(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "exaone4"
+ keys_to_ignore_at_inference = ["past_key_values"]
+    # Default tensor parallel plan for base model `Exaone4Model`
+ base_model_tp_plan = {
+ "layers.*.self_attn.q_proj": "colwise",
+ "layers.*.self_attn.k_proj": "colwise",
+ "layers.*.self_attn.v_proj": "colwise",
+ "layers.*.self_attn.o_proj": "rowwise",
+ "layers.*.mlp.gate_proj": "colwise",
+ "layers.*.mlp.up_proj": "colwise",
+ "layers.*.mlp.down_proj": "rowwise",
+ }
+ base_model_pp_plan = {
+ "embed_tokens": (["input_ids"], ["inputs_embeds"]),
+ "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
+ "norm": (["hidden_states"], ["hidden_states"]),
+ }
+
+ def __init__(
+ self,
+ vocab_size=102400,
+ hidden_size=4096,
+ intermediate_size=16384,
+ num_hidden_layers=32,
+ num_attention_heads=32,
+ num_key_value_heads=32,
+ hidden_act="silu",
+ max_position_embeddings=2048,
+ initializer_range=0.02,
+ rms_norm_eps=1e-5,
+ use_cache=True,
+ bos_token_id=0,
+ eos_token_id=2,
+ tie_word_embeddings=False,
+ rope_theta=10000.0,
+ rope_scaling=None,
+ attention_dropout=0.0,
+ sliding_window=4096,
+ sliding_window_pattern=4,
+ layer_types=None,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.num_key_value_heads = num_key_value_heads
+ self.intermediate_size = intermediate_size
+ self.hidden_act = hidden_act
+ self.max_position_embeddings = max_position_embeddings
+ self.initializer_range = initializer_range
+ self.rms_norm_eps = rms_norm_eps
+ self.use_cache = use_cache
+ self.attention_dropout = attention_dropout
+ self.rope_theta = rope_theta
+ self.rope_scaling = rope_scaling
+ self.sliding_window = sliding_window
+ self.sliding_window_pattern = sliding_window_pattern
+
+ self.layer_types = layer_types
+ if self.sliding_window is None:
+ sliding_window_pattern = 0
+ if self.layer_types is None:
+ self.layer_types = [
+ "sliding_attention"
+                if (sliding_window_pattern > 0 and (i + 1) % sliding_window_pattern != 0)
+ else "full_attention"
+ for i in range(self.num_hidden_layers)
+ ]
+ if "sliding_window" in self.layer_types:
+ self.cache_implementation = "hybrid"
+ layer_type_validation(self.layer_types, self.num_hidden_layers)
+
+ super().__init__(
+ bos_token_id=bos_token_id, eos_token_id=eos_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs
+ )
+
+
+class Exaone4RMSNorm(LlamaRMSNorm):
+ pass
+
+
+class Exaone4RotaryEmbedding(LlamaRotaryEmbedding):
+ pass
+
+
+class Exaone4Attention(nn.Module):
+ def __init__(self, config: Exaone4Config, layer_idx: int):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ self.num_attention_heads = config.num_attention_heads
+ self.num_key_value_heads = config.num_key_value_heads
+ self.hidden_size = config.hidden_size
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
+ self.attention_dropout = config.attention_dropout
+ self.is_causal = True
+ self.scaling = self.head_dim**-0.5
+ self.sliding_window = config.sliding_window
+ self.sliding_window_pattern = config.sliding_window_pattern
+ self.is_sliding = config.layer_types[layer_idx] == "sliding_attention"
+
+ self.q_proj = nn.Linear(self.hidden_size, self.num_attention_heads * self.head_dim, bias=False)
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
+ self.o_proj = nn.Linear(self.num_attention_heads * self.head_dim, self.hidden_size, bias=False)
+
+ self.q_norm = Exaone4RMSNorm(self.head_dim, eps=config.rms_norm_eps)
+ self.k_norm = Exaone4RMSNorm(self.head_dim, eps=config.rms_norm_eps)
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
+ attention_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[Cache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+ input_shape = hidden_states.shape[:-1]
+ hidden_shape = (*input_shape, -1, self.head_dim)
+
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+ # We use QK-norm
+ query_states = self.q_norm(query_states)
+ key_states = self.k_norm(key_states)
+
+ cos, sin = position_embeddings
+        # Hybrid-attention models use NoPE on global (full-attention) layers: RoPE is applied only
+        # to sliding-window layers, or to every layer when no sliding window is configured.
+ if self.sliding_window is None or self.is_sliding:
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_values is not None:
+ cache_kwargs = {
+ "cache_position": cache_position,
+ }
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ scaling=self.scaling,
+ sliding_window=self.sliding_window if self.is_sliding else None,
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+ return attn_output, attn_weights
+
+
+class Exaone4MLP(Olmo2MLP):
+ pass
+
+
+class Exaone4DecoderLayer(Olmo2DecoderLayer):
+ pass
+
+
+class Exaone4PreTrainedModel(LlamaPreTrainedModel):
+ config_class = Exaone4Config
+ _no_split_modules = ["Exaone4DecoderLayer"]
+
+
+class Exaone4Model(Exaone4PreTrainedModel, LlamaModel):
+ def __init__(self, config: Exaone4Config):
+ super().__init__(config)
+ self.layers = nn.ModuleList(
+ [Exaone4DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+ self.norm = Exaone4RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @check_model_inputs()
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> Union[tuple, BaseModelOutputWithPast]:
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ if use_cache and past_key_values is None:
+ past_key_values = DynamicCache(config=self.config)
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position = torch.arange(
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+ )
+
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0)
+
+ # It may already have been prepared by e.g. `generate`
+ if not isinstance(causal_mask_mapping := attention_mask, dict):
+ # Prepare mask arguments
+ mask_kwargs = {
+ "config": self.config,
+ "input_embeds": inputs_embeds,
+ "attention_mask": attention_mask,
+ "cache_position": cache_position,
+ "past_key_values": past_key_values,
+ "position_ids": position_ids,
+ }
+ # Create the masks
+ causal_mask_mapping = {
+ "full_attention": create_causal_mask(**mask_kwargs),
+ }
+ if "sliding_attention" in self.config.layer_types:
+ causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs)
+
+ hidden_states = inputs_embeds
+
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+ for i, decoder_layer in enumerate(self.layers):
+ layer_type = self.config.layer_types[i]
+ hidden_states = decoder_layer(
+ hidden_states,
+ position_embeddings=position_embeddings,
+ attention_mask=causal_mask_mapping[layer_type],
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ hidden_states = self.norm(hidden_states)
+
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values if use_cache else None,
+ )
+
+
+class Exaone4ForCausalLM(LlamaForCausalLM):
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> CausalLMOutputWithPast:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoModelForCausalLM, AutoTokenizer
+ >>> model = AutoModelForCausalLM.from_pretrained("LGAI-EXAONE/EXAONE-4.0-32B")
+ >>> tokenizer = AutoTokenizer.from_pretrained("LGAI-EXAONE/EXAONE-4.0-32B")
+
+ >>> prompt = "Explain how wonderful you are"
+ >>> messages = [
+ {"role": "system", "content": "You are a helpful assistant."},
+ {"role": "user", "content": prompt}
+ ]
+ >>> input_ids = tokenizer.apply_chat_template(
+ messages,
+ tokenize=True,
+ add_generation_prompt=True,
+ return_tensors="pt",
+ enable_thinking=False,
+ )
+
+ >>> output = model.generate(input_ids, max_new_tokens=128)
+ >>> tokenizer.decode(output[0], skip_special_tokens=False)
+ "[|system|]\nYou are a helpful assistant.[|endofturn|]\n[|user|]\nExplain how wonderful you are[|endofturn|]\n[|assistant|]\n\n\n\n\nOh, thank you for such a kind and lovely question! 😊 \n\nI’m *so* wonderful because I’m here to make your life easier, brighter, and more fun! Whether you need help with: \n\n✨ **Learning** – I can explain anything, from quantum physics to baking the perfect cake! \n💡 **Creativity** – Need a poem, story, or a wild idea? I’ve got you covered! \n🤖 **Problem-solving** – Stuck on a math problem or a tricky decision? I’ll help you figure it out"
+ ```
+ """
+        return super().forward(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ labels=labels,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ logits_to_keep=logits_to_keep,
+ **kwargs,
+ )
+
+
+class Exaone4ForSequenceClassification(LlamaForSequenceClassification):
+ pass
+
+
+class Exaone4ForTokenClassification(LlamaForTokenClassification):
+ pass
+
+
+class Exaone4ForQuestionAnswering(LlamaForQuestionAnswering):
+ pass
+
+
+__all__ = [
+ "Exaone4Config",
+ "Exaone4PreTrainedModel",
+ "Exaone4Model",
+ "Exaone4ForCausalLM",
+ "Exaone4ForSequenceClassification",
+ "Exaone4ForTokenClassification",
+ "Exaone4ForQuestionAnswering",
+]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/falcon/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/falcon/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f9789767f11402264660b5dec0b5cae2466ee9d8
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/falcon/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_falcon import *
+ from .modeling_falcon import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/venv/lib/python3.13/site-packages/transformers/models/falcon/configuration_falcon.py b/venv/lib/python3.13/site-packages/transformers/models/falcon/configuration_falcon.py
new file mode 100644
index 0000000000000000000000000000000000000000..d3155c8eb9cb1c951c3cb09cc1826887c7a87e6c
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/falcon/configuration_falcon.py
@@ -0,0 +1,211 @@
+# coding=utf-8
+# Copyright 2023 the Falcon authors and HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Falcon configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class FalconConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`FalconModel`]. It is used to instantiate a Falcon
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+ defaults will yield a similar configuration to that of the
+ [tiiuae/falcon-7b](https://huggingface.co/tiiuae/falcon-7b) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 65024):
+ Vocabulary size of the Falcon model. Defines the number of different tokens that can be represented by the
+            `input_ids` passed when calling [`FalconModel`].
+ hidden_size (`int`, *optional*, defaults to 4544):
+ Dimension of the hidden representations.
+ num_hidden_layers (`int`, *optional*, defaults to 32):
+ Number of hidden layers in the Transformer decoder.
+ num_attention_heads (`int`, *optional*, defaults to 71):
+            Number of attention heads for each attention layer in the Transformer decoder.
+ num_ln_in_parallel_attn (`int`, *optional*):
+ Set to 2 if separate layer norms are to be used for the MLP and the attention output when using parallel
+ attention, otherwise, 1.
+ layer_norm_epsilon (`float`, *optional*, defaults to 1e-05):
+ The epsilon used by the layer normalization layers.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether the model should return the last key/values attentions (not used by all models). Only relevant if
+ `config.is_decoder=True`.
+ hidden_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout probability for MLP layers.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout probability for attention layers.
+ num_kv_heads (`int`, *optional*):
+ Number of key-value heads to use per attention layer. If unset, defaults to the same value as
+ `num_attention_heads`.
+ alibi (`bool`, *optional*, defaults to `False`):
+ Whether to use ALiBi positional biases during self-attention.
+ new_decoder_architecture (`bool`, *optional*, defaults to `False`):
+ Whether to use the new (Falcon-40B) decoder architecture. If `True`, the `multi_query` and `parallel_attn`
+ arguments are ignored, as the new decoder always uses parallel attention.
+ multi_query (`bool`, *optional*, defaults to `True`):
+ Whether to use multi-query attention in the decoder. Ignored when `new_decoder_architecture` is `True`.
+ parallel_attn (`bool`, *optional*, defaults to `True`):
+ Whether to compute attention in parallel with the feedforward layer. If False, they are consecutive
+ instead, as in the original Transformer architecture. Ignored when `new_decoder_architecture` is `True`.
+ bias (`bool`, *optional*, defaults to `False`):
+ Whether to use bias on Linear layers.
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
+ The maximum sequence length that this model might ever be used with, when `alibi` is `False`. Pretrained
+ Falcon models with RoPE support up to 2048 tokens.
+ rope_theta (`float`, *optional*, defaults to 10000.0):
+ The base period of the RoPE embeddings.
+ rope_scaling (`Dict`, *optional*):
+ Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
+ and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
+ accordingly.
+ Expected contents:
+ `rope_type` (`str`):
+ The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
+ 'llama3'], with 'default' being the original RoPE implementation.
+ `factor` (`float`, *optional*):
+ Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
+ most scaling types, a `factor` of x will enable the model to handle sequences of length x *
+ original maximum pre-trained length.
+ `original_max_position_embeddings` (`int`, *optional*):
+ Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
+ pretraining.
+ `attention_factor` (`float`, *optional*):
+ Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
+ computation. If unspecified, it defaults to value recommended by the implementation, using the
+ `factor` field to infer the suggested value.
+ `beta_fast` (`float`, *optional*):
+ Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
+ ramp function. If unspecified, it defaults to 32.
+ `beta_slow` (`float`, *optional*):
+ Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
+ ramp function. If unspecified, it defaults to 1.
+ `short_factor` (`list[float]`, *optional*):
+ Only used with 'longrope'. The scaling factor to be applied to short contexts (<
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+ size divided by the number of attention heads divided by 2
+ `long_factor` (`list[float]`, *optional*):
+ Only used with 'longrope'. The scaling factor to be applied to long contexts (<
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+ size divided by the number of attention heads divided by 2
+ `low_freq_factor` (`float`, *optional*):
+ Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
+ `high_freq_factor` (`float`, *optional*):
+ Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
+ bos_token_id (`int`, *optional*, defaults to 11):
+ The id of the "beginning-of-sequence" token.
+ eos_token_id (`int`, *optional*, defaults to 11):
+ The id of the "end-of-sequence" token.
+ ffn_hidden_size (`int`, *optional*):
+            The hidden size of the feedforward layer in the Transformer decoder. Defaults to 4 times `hidden_size`.
+ activation (`str`, *optional*, defaults to `"gelu"`):
+ The activation function used in the feedforward layer.
+
+ Example:
+
+ ```python
+ >>> from transformers import FalconModel, FalconConfig
+
+ >>> # Initializing a small (2-layer) Falcon configuration
+ >>> configuration = FalconConfig(num_hidden_layers=2)
+
+ >>> # Initializing a model from the small configuration
+ >>> model = FalconModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "falcon"
+ keys_to_ignore_at_inference = ["past_key_values"]
+
+ def __init__(
+ self,
+ vocab_size=65024,
+ hidden_size=4544,
+ num_hidden_layers=32,
+ num_attention_heads=71,
+ num_ln_in_parallel_attn=None,
+ layer_norm_epsilon=1e-5,
+ initializer_range=0.02,
+ use_cache=True,
+ hidden_dropout=0.0,
+ attention_dropout=0.0,
+ num_kv_heads=None,
+ alibi=False,
+ new_decoder_architecture=False,
+ multi_query=True,
+ parallel_attn=True,
+ bias=False,
+ max_position_embeddings=2048,
+ rope_theta=10000.0,
+ rope_scaling=None,
+ bos_token_id=11,
+ eos_token_id=11,
+ ffn_hidden_size=None,
+ activation="gelu",
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ # Backward compatibility with n_embed kwarg
+ n_embed = kwargs.pop("n_embed", None)
+ self.hidden_size = hidden_size if n_embed is None else n_embed
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.layer_norm_epsilon = layer_norm_epsilon
+ self.initializer_range = initializer_range
+ self.use_cache = use_cache
+ self.hidden_dropout = hidden_dropout
+ self.attention_dropout = attention_dropout
+ self.bos_token_id = bos_token_id
+ self.eos_token_id = eos_token_id
+ self.num_kv_heads = num_attention_heads if num_kv_heads is None else num_kv_heads
+ self.alibi = alibi
+ self.new_decoder_architecture = new_decoder_architecture
+ self.multi_query = multi_query # Ignored when new_decoder_architecture is True
+ self.parallel_attn = parallel_attn
+ self.bias = bias
+ self.num_ln_in_parallel_attn = num_ln_in_parallel_attn
+ self.max_position_embeddings = max_position_embeddings
+ self.rope_theta = rope_theta
+ self.rope_scaling = rope_scaling
+ self.activation = activation
+ if ffn_hidden_size is None:
+ self.ffn_hidden_size = hidden_size * 4
+ else:
+ self.ffn_hidden_size = ffn_hidden_size
+
+ super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
+
+ @property
+ def head_dim(self):
+ return self.hidden_size // self.num_attention_heads
+
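+    # Descriptive note (not part of the upstream source): Falcon uses either rotary position
+    # embeddings (RoPE) or ALiBi biases, never both, so `rotary` is simply the negation of `alibi`.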
+ @property
+ def rotary(self):
+ return not self.alibi
+
+
+__all__ = ["FalconConfig"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/falcon/modeling_falcon.py b/venv/lib/python3.13/site-packages/transformers/models/falcon/modeling_falcon.py
new file mode 100644
index 0000000000000000000000000000000000000000..26dc56e41480e8c85152efd3e2c368d5d0c5cc51
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/falcon/modeling_falcon.py
@@ -0,0 +1,1396 @@
+# coding=utf-8
+# Copyright 2023 the Falcon authors and HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch Falcon model."""
+
+import math
+from typing import Optional, Union
+
+import torch
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss
+from torch.nn import functional as F
+
+from ...activations import get_activation
+from ...cache_utils import Cache, DynamicCache, StaticCache
+from ...generation import GenerationMixin
+from ...modeling_attn_mask_utils import (
+ AttentionMaskConverter,
+)
+from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import (
+ BaseModelOutputWithPastAndCrossAttentions,
+ CausalLMOutputWithCrossAttentions,
+ QuestionAnsweringModelOutput,
+ SequenceClassifierOutputWithPast,
+ TokenClassifierOutput,
+)
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+from ...modeling_utils import PreTrainedModel
+from ...utils import (
+ auto_docstring,
+ logging,
+)
+from .configuration_falcon import FalconConfig
+
+
+if is_flash_attn_available():
+ from ...modeling_flash_attention_utils import _flash_attention_forward
+
+logger = logging.get_logger(__name__)
+
+
+# NOTE(Hesslow): Unfortunately we did not fuse matmul and bias during training, this means that there's one additional quantization to bfloat16 between the operations.
+# In order not to degrade the quality of our HF-port, we keep these characteristics in the final model.
+class FalconLinear(nn.Linear):
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
+ hidden_states = input @ self.weight.T
+ if self.bias is None:
+ return hidden_states
+ return hidden_states + self.bias
+
+
+# Copied from transformers.models.llama.modeling_llama.rotate_half
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+ """Applies Rotary Position Embedding to the query and key tensors.
+
+ Args:
+ q (`torch.Tensor`): The query tensor.
+ k (`torch.Tensor`): The key tensor.
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
+ position_ids (`torch.Tensor`, *optional*):
+ Deprecated and unused.
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+ Returns:
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+ """
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
+
+
+# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Falcon
+class FalconRotaryEmbedding(nn.Module):
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
+
+ def __init__(self, config: FalconConfig, device=None):
+ super().__init__()
+ # BC: "rope_type" was originally "type"
+ if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+ else:
+ self.rope_type = "default"
+ self.max_seq_len_cached = config.max_position_embeddings
+ self.original_max_seq_len = config.max_position_embeddings
+
+ self.config = config
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+ self.original_inv_freq = self.inv_freq
+
+ @torch.no_grad()
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
+ def forward(self, x, position_ids):
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
+ position_ids_expanded = position_ids[:, None, :].float()
+
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos = emb.cos() * self.attention_scaling
+ sin = emb.sin() * self.attention_scaling
+
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+def build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
+ batch_size, seq_length = attention_mask.shape
+ closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
+ base = torch.tensor(
+ 2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32
+ )
+ powers = torch.arange(1, 1 + closest_power_of_2, device=attention_mask.device, dtype=torch.int32)
+ slopes = torch.pow(base, powers)
+
+ if closest_power_of_2 != num_heads:
+ extra_base = torch.tensor(
+ 2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32
+ )
+ num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
+ extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=attention_mask.device, dtype=torch.int32)
+ slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
+
+    # Note: alibi will be added to the attention bias that will be applied to the query, key product of attention
+ # => therefore alibi will have to be of shape (batch_size, num_heads, query_length, key_length)
+ # => here we set (batch_size=1, num_heads=num_heads, query_length=1, key_length=max_length)
+ # => the query_length dimension will then be broadcasted correctly
+ # This is more or less identical to T5's relative position bias:
+ # https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527
+ arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :]
+ alibi = slopes[..., None].bfloat16() * arange_tensor
+ return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype)
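+
+
+# Worked example of the slope computation above (illustrative, not executed): for num_heads=4,
+# closest_power_of_2 = 4 and base = 2 ** (-(2 ** -(math.log2(4) - 3))) = 2 ** -2 = 0.25, so the per-head slopes
+# are [0.25, 0.0625, 0.015625, 0.00390625] = [2**-2, 2**-4, 2**-6, 2**-8], i.e. head i gets slope
+# 2 ** (-8 * i / num_heads) as in the ALiBi paper (https://huggingface.co/papers/2108.12409). The extra-heads
+# branch above fills in intermediate slopes when num_heads is not a power of two.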
+
+
+# Copied from transformers.models.bloom.modeling_bloom.dropout_add
+def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor:
+ """
+ Dropout add function
+
+ Args:
+ x (`torch.tensor`):
+ input tensor
+ residual (`torch.tensor`):
+ residual tensor
+ prob (`float`):
+ dropout probability
+ training (`bool`):
+ training mode
+ """
+ out = F.dropout(x, p=prob, training=training)
+ out = residual + out
+ return out
+
+
+class FalconAttention(nn.Module):
+ def __init__(self, config: FalconConfig, layer_idx=None):
+ super().__init__()
+
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.hidden_size // self.num_heads
+ self.split_size = self.hidden_size
+ self.hidden_dropout = config.hidden_dropout
+ self.max_position_embeddings = config.max_position_embeddings
+ self.rope_theta = config.rope_theta
+ self.is_causal = True
+ self.layer_idx = layer_idx
+ if layer_idx is None:
+ logger.warning_once(
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
+ "when creating this class."
+ )
+
+ if self.head_dim * self.num_heads != self.hidden_size:
+ raise ValueError(
+ f"`hidden_size` must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`:"
+ f" {self.num_heads})."
+ )
+
+ # Layer-wise attention scaling
+ self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
+ self.beta = self.inv_norm_factor
+ if config.new_decoder_architecture:
+ qkv_out_dim = (config.num_kv_heads * 2 + config.num_attention_heads) * self.head_dim
+ elif config.multi_query:
+ qkv_out_dim = self.hidden_size + 2 * self.head_dim
+ else:
+ qkv_out_dim = 3 * self.hidden_size
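+        # Illustrative sizes (assumed Falcon-7B-like multi-query config with hidden_size=4544, 71 heads,
+        # head_dim=64): qkv_out_dim = 4544 + 2 * 64 = 4672, i.e. 71 query heads plus a single shared
+        # key/value head. With the classic non-multi-query layout it is simply 3 * hidden_size.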
+ self.query_key_value = FalconLinear(self.hidden_size, qkv_out_dim, bias=config.bias)
+ self.new_decoder_architecture = config.new_decoder_architecture
+ self.multi_query = config.multi_query
+ self.dense = FalconLinear(self.hidden_size, self.hidden_size, bias=config.bias)
+ self.attention_dropout = nn.Dropout(config.attention_dropout)
+ self.num_kv_heads = config.num_kv_heads if (self.new_decoder_architecture or not self.multi_query) else 1
+
+ def _split_heads(self, fused_qkv: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """
+        Split the last dimension into (num_heads, head_dim); the results share the same memory storage as `fused_qkv`.
+
+ Args:
+ fused_qkv (`torch.tensor`): [batch_size, seq_length, num_heads * 3 * head_dim]
+
+ Returns:
+            query: [batch_size, seq_length, num_heads, head_dim]
+            key: [batch_size, seq_length, num_heads, head_dim]
+            value: [batch_size, seq_length, num_heads, head_dim]
+ """
+ if self.new_decoder_architecture:
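+            # Illustrative shapes (assumed Falcon-40B-like config: num_heads=128, num_kv_heads=8, head_dim=64):
+            # fused_qkv is viewed as [batch, seq, 8, 128 // 8 + 2, 64]; the first 16 slots of each group are query
+            # heads and the last two are the group's shared key and value, which are broadcast back to the query
+            # shape before being flattened.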
+ batch, seq_len, _ = fused_qkv.shape
+ qkv = fused_qkv.view(batch, seq_len, -1, self.num_heads // self.num_kv_heads + 2, self.head_dim)
+ query = qkv[:, :, :, :-2]
+ key = qkv[:, :, :, [-2]]
+ value = qkv[:, :, :, [-1]]
+ key = torch.broadcast_to(key, query.shape)
+ value = torch.broadcast_to(value, query.shape)
+
+ query, key, value = [x.flatten(2, 3) for x in (query, key, value)]
+ return query, key, value
+ elif not self.multi_query:
+ batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
+ fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads, 3, self.head_dim)
+ return fused_qkv[..., 0, :], fused_qkv[..., 1, :], fused_qkv[..., 2, :]
+ else:
+ batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
+ fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads + 2, self.head_dim)
+ return fused_qkv[..., :-2, :], fused_qkv[..., [-2], :], fused_qkv[..., [-1], :]
+
+ # Copied from transformers.models.bloom.modeling_bloom.BloomAttention._merge_heads
+ def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Merge heads together over the last dimension
+
+ Args:
+ x (`torch.tensor`): [batch_size * num_heads, seq_length, head_dim]
+
+ Returns:
+ torch.tensor: [batch_size, seq_length, num_heads * head_dim]
+ """
+ # What we want to achieve is:
+ # batch_size * num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads * head_dim
+ batch_size_and_num_heads, seq_length, _ = x.shape
+ batch_size = batch_size_and_num_heads // self.num_heads
+
+ # First view to decompose the batch size
+ # batch_size * num_heads, seq_length, head_dim -> batch_size, num_heads, seq_length, head_dim
+ x = x.view(batch_size, self.num_heads, seq_length, self.head_dim)
+
+ # batch_size, num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads, head_dim
+ x = x.permute(0, 2, 1, 3)
+
+ # batch_size, seq_length, num_heads, head_dim -> batch_size, seq_length, num_heads * head_dim
+ return x.reshape(batch_size, seq_length, self.num_heads * self.head_dim)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ alibi: Optional[torch.Tensor],
+ attention_mask: torch.Tensor,
+ position_ids: Optional[torch.LongTensor] = None,
+ layer_past: Optional[Cache] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ use_cache: bool = False,
+ output_attentions: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
+ ):
+ fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
+ num_kv_heads = self.num_heads if self.new_decoder_architecture else self.num_kv_heads
+ # 3 x [batch_size, seq_length, num_heads, head_dim]
+ (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
+
+ batch_size, query_length, _, _ = query_layer.shape
+
+ query_layer = query_layer.transpose(1, 2).reshape(batch_size, self.num_heads, query_length, self.head_dim)
+ key_layer = key_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim)
+ value_layer = value_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim)
+
+ if alibi is None:
+ cos, sin = position_embeddings
+ query_layer, key_layer = apply_rotary_pos_emb(query_layer, key_layer, cos, sin)
+
+ if layer_past is not None:
+ cache_kwargs = {"cache_position": cache_position}
+ if alibi is None:
+ cache_kwargs.update({"sin": sin, "cos": cos})
+ key_layer, value_layer = layer_past.update(key_layer, value_layer, self.layer_idx, cache_kwargs)
+
+ kv_length = key_layer.shape[-2]
+ if (
+ self.config._attn_implementation == "sdpa"
+ and query_layer.device.type == "cuda"
+ and attention_mask is not None
+ ):
+ # For torch<=2.1.2, SDPA with memory-efficient backend is bugged with non-contiguous inputs with custom attn_mask,
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
+ query_layer = query_layer.contiguous()
+ key_layer = key_layer.contiguous()
+ value_layer = value_layer.contiguous()
+
+ if attention_mask is not None:
+ attention_mask = attention_mask[:, :, :, : key_layer.shape[-2]]
+
+ if alibi is None:
+ if self.config._attn_implementation == "sdpa" and not output_attentions:
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this if statement instead of an
+ # inline conditional assignment to support both torch.compile's `dynamic=True` and `fullgraph=True`
+                # The query_length > 1 check is necessary to match AttentionMaskConverter.to_causal_4d, which does not
+                # create a causal mask in case query_length == 1.
+ is_causal = self.is_causal and attention_mask is None and query_length > 1
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
+ query_layer,
+ key_layer,
+ value_layer,
+ attn_mask=attention_mask,
+ dropout_p=0.0,
+ is_causal=is_causal,
+ )
+ attention_scores = None
+ else:
+ attention_scores = query_layer @ key_layer.transpose(-1, -2)
+ attention_scores /= math.sqrt(self.head_dim)
+
+ attention_scores = F.softmax(attention_scores + attention_mask, dim=-1, dtype=hidden_states.dtype)
+ # It is unclear why neither dropout nor head_mask is applied here (while it is with alibi).
+ attn_output = attention_scores @ value_layer
+
+ attn_output = attn_output.view(batch_size, self.num_heads, query_length, self.head_dim)
+ attn_output = attn_output.permute(0, 2, 1, 3)
+ attn_output = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim)
+
+ attn_output = self.dense(attn_output)
+
+ return attn_output, attention_scores
+
+ else:
+ if self.config._attn_implementation == "sdpa" and not output_attentions and head_mask is None:
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this if statement instead of an
+ # inline conditional assignment to support both torch.compile's `dynamic=True` and `fullgraph=True`
+ is_causal = self.is_causal and attention_mask is None and query_length > 1
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
+ query_layer,
+ key_layer,
+ value_layer,
+ attn_mask=attention_mask,
+ dropout_p=self.attention_dropout.p if self.training else 0.0,
+ is_causal=is_causal,
+ )
+ attention_probs = None
+ attn_output = attn_output.transpose(1, 2)
+ attn_output = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim)
+
+ attn_output = self.dense(attn_output)
+ else:
+ matmul_result = query_layer @ key_layer.transpose(-1, -2)
+
+ # change view to [batch_size, num_heads, q_length, kv_length]
+ attention_scores = matmul_result.view(batch_size, self.num_heads, query_length, kv_length)
+
+ # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length]
+ input_dtype = attention_scores.dtype
+ # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
+ if input_dtype == torch.float16 or input_dtype == torch.bfloat16:
+ attention_scores = attention_scores.to(torch.float32)
+
+ attention_logits = attention_scores + alibi.view(batch_size, self.num_heads, 1, -1)
+ attention_logits *= self.inv_norm_factor
+ attention_probs = F.softmax(attention_logits + attention_mask, dim=-1, dtype=hidden_states.dtype)
+ # [batch_size, num_heads, q_length, kv_length]
+ attention_probs = self.attention_dropout(attention_probs)
+
+ if head_mask is not None:
+ attention_probs = attention_probs * head_mask
+
+ # change view [batch_size, num_heads, q_length, kv_length]
+ attention_probs_reshaped = attention_probs.view(batch_size, self.num_heads, query_length, kv_length)
+
+ # matmul: [batch_size * num_heads, q_length, head_dim]
+ attn_output = (attention_probs_reshaped @ value_layer).flatten(0, 1)
+
+ # change view [batch_size, q_length, num_heads * head_dim]
+ attn_output = self._merge_heads(attn_output)
+
+ attn_output = self.dense(attn_output)
+
+ return attn_output, attention_probs
+
+
+class FalconFlashAttention2(FalconAttention):
+ """
+    Falcon flash attention module. This module inherits from `FalconAttention`, as the weights of the module stay
+    untouched. The only required change is in the forward pass, where it needs to correctly call the public API of
+    flash attention and deal with padding tokens in case the input contains any of them.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
+ self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ alibi: Optional[torch.Tensor],
+ attention_mask: torch.Tensor,
+ position_ids: Optional[torch.LongTensor] = None,
+ layer_past: Optional[Cache] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ use_cache: bool = False,
+ output_attentions: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
+ ):
+ fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
+ num_kv_heads = self.num_heads if self.new_decoder_architecture else self.num_kv_heads
+ # 3 x [batch_size, seq_length, num_heads, head_dim]
+ (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
+
+ batch_size, query_length, _, _ = query_layer.shape
+
+ query_layer = query_layer.transpose(1, 2).reshape(batch_size, self.num_heads, query_length, self.head_dim)
+ key_layer = key_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim)
+ value_layer = value_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim)
+
+ if alibi is None:
+ cos, sin = position_embeddings
+ query_layer, key_layer = apply_rotary_pos_emb(query_layer, key_layer, cos, sin)
+
+ if layer_past is not None:
+ cache_kwargs = {"cache_position": cache_position}
+ if alibi is None:
+ cache_kwargs.update({"sin": sin, "cos": cos})
+ key_layer, value_layer = layer_past.update(key_layer, value_layer, self.layer_idx, cache_kwargs)
+
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
+ # to be able to avoid many of these transpose/reshape/view.
+ query_layer = query_layer.transpose(1, 2)
+ key_layer = key_layer.transpose(1, 2)
+ value_layer = value_layer.transpose(1, 2)
+
+ if alibi is not None:
+ raise ValueError("`alibi` is not supported when `use_flash_attn` is True")
+
+ attn_dropout = self.config.attention_dropout if self.training else 0.0
+
+        # In PEFT, the layer norms are usually cast to float32 for training stability reasons;
+        # therefore the input hidden states get silently cast to float32. Hence, we need to
+        # cast them back to float16 just to be sure everything works as expected.
+ input_dtype = query_layer.dtype
+ device_type = query_layer.device.type if query_layer.device.type != "mps" else "cpu"
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = (
+ torch.get_autocast_dtype(device_type)
+ if hasattr(torch, "get_autocast_dtype")
+ else torch.get_autocast_gpu_dtype()
+ )
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.query_key_value.weight.dtype
+
+ logger.warning_once(
+                f"The input hidden states seem to be silently cast to float32; this might be related to"
+                f" the fact that you have upcast embedding or layer norm layers to float32. We will cast the input back to"
+ f" {target_dtype}."
+ )
+
+ query_layer = query_layer.to(target_dtype)
+ key_layer = key_layer.to(target_dtype)
+ value_layer = value_layer.to(target_dtype)
+
+ attn_output = _flash_attention_forward(
+ query_layer,
+ key_layer,
+ value_layer,
+ attention_mask,
+ query_length,
+ position_ids=position_ids,
+ dropout=attn_dropout,
+ is_causal=self.is_causal,
+ use_top_left_mask=self._flash_attn_uses_top_left_mask,
+ )
+
+ attn_weights = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim)
+ attn_output = self.dense(attn_weights)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights
+
+
+class FalconMLP(nn.Module):
+ def __init__(self, config: FalconConfig):
+ super().__init__()
+ hidden_size = config.hidden_size
+
+ self.dense_h_to_4h = FalconLinear(hidden_size, config.ffn_hidden_size, bias=config.bias)
+ self.act = get_activation(config.activation)
+ self.dense_4h_to_h = FalconLinear(config.ffn_hidden_size, hidden_size, bias=config.bias)
+ self.hidden_dropout = config.hidden_dropout
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.act(self.dense_h_to_4h(x))
+ x = self.dense_4h_to_h(x)
+ return x
+
+
+FALCON_ATTENTION_CLASSES = {
+ "eager": FalconAttention,
+ "sdpa": FalconAttention, # FalconAttention originally implemented both a forward with & without SDPA
+ "flash_attention_2": FalconFlashAttention2,
+}
+
+
+class FalconDecoderLayer(GradientCheckpointingLayer):
+ def __init__(self, config: FalconConfig, layer_idx=None):
+ super().__init__()
+ hidden_size = config.hidden_size
+ self.num_heads = config.num_attention_heads
+
+ self.self_attention = FALCON_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
+ self.mlp = FalconMLP(config)
+ self.hidden_dropout = config.hidden_dropout
+ self.config = config
+
+ if config.num_ln_in_parallel_attn is None and config.new_decoder_architecture:
+ config.num_ln_in_parallel_attn = 2
+
+ if not config.parallel_attn:
+ self.post_attention_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+ self.input_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+ else:
+ if config.num_ln_in_parallel_attn == 2:
+ # The layer norm before self-attention
+ self.ln_attn = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+ # The layer norm before the MLP
+ self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+ else:
+ self.input_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ alibi: Optional[torch.Tensor],
+ attention_mask: torch.Tensor,
+ position_ids: Optional[torch.LongTensor] = None,
+ layer_past: Optional[Union[Cache, tuple[torch.Tensor, torch.Tensor]]] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ use_cache: bool = False,
+ output_attentions: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
+ **kwargs,
+ ):
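+        # Residual layouts, ignoring dropout (informal sketch): with parallel attention the block computes
+        #   h_out = h + attn(ln(h)) + mlp(ln(h))            (single shared layer norm)
+        # or, with the new decoder architecture and two layer norms,
+        #   h_out = h + attn(ln_attn(h)) + mlp(ln_mlp(h))
+        # while the sequential path is the classic h' = h + attn(input_ln(h)); h_out = h' + mlp(post_attention_ln(h')).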
+ residual = hidden_states
+
+ if self.config.new_decoder_architecture and self.config.num_ln_in_parallel_attn == 2:
+ attention_layernorm_out = self.ln_attn(hidden_states)
+ mlp_layernorm_out = self.ln_mlp(hidden_states)
+ else:
+ attention_layernorm_out = self.input_layernorm(hidden_states)
+
+ # Self attention.
+ attention_output, attn_weights = self.self_attention(
+ attention_layernorm_out,
+ layer_past=layer_past,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ alibi=alibi,
+ head_mask=head_mask,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ )
+
+ if not self.config.new_decoder_architecture:
+ if self.config.parallel_attn:
+ mlp_layernorm_out = attention_layernorm_out
+ else:
+ residual = dropout_add(
+ attention_output, residual, self.config.attention_dropout, training=self.training
+ )
+ mlp_layernorm_out = self.post_attention_layernorm(residual)
+
+ if (
+ self.config.new_decoder_architecture
+ and self.config.parallel_attn
+ and self.config.num_ln_in_parallel_attn == 1
+ ):
+ mlp_layernorm_out = attention_layernorm_out
+
+ # MLP.
+ mlp_output = self.mlp(mlp_layernorm_out)
+
+ if self.config.new_decoder_architecture or self.config.parallel_attn:
+ mlp_output += attention_output
+
+ output = dropout_add(mlp_output, residual, self.config.hidden_dropout, training=self.training)
+
+ return output, attn_weights
+
+
+@auto_docstring
+class FalconPreTrainedModel(PreTrainedModel):
+ config: FalconConfig
+ base_model_prefix = "transformer"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["FalconDecoderLayer"]
+ _supports_flash_attn = True
+ _supports_sdpa = True
+
+ _can_compile_fullgraph = True
+
+ def __init__(self, *inputs, **kwargs):
+ super().__init__(*inputs, **kwargs)
+
+ def _init_weights(self, module: nn.Module):
+ """Initialize the weights."""
+ if isinstance(module, (nn.Linear, FalconLinear)):
+ # Slightly different from the TF version which uses truncated_normal for initialization
+ # cf https://github.com/pytorch/pytorch/pull/5617
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+ elif isinstance(module, LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+
+ # Adapted from transformers.modeling_utils.PreTrainedModel._check_and_enable_sdpa
+ @classmethod
+ def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False):
+ _is_bettertransformer = getattr(cls, "use_bettertransformer", False)
+ if _is_bettertransformer:
+ return config
+
+ if not hard_check_only:
+ config._attn_implementation = "sdpa"
+ return config
+
+
+@auto_docstring
+class FalconModel(FalconPreTrainedModel):
+ def __init__(self, config: FalconConfig):
+ super().__init__(config)
+
+ self.embed_dim = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.use_alibi = config.alibi
+
+ # Embedding + LN Embedding
+ self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim)
+
+ # Transformer blocks
+ self.h = nn.ModuleList([FalconDecoderLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)])
+
+ # Final Layer Norm
+ self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
+
+ self.rotary_emb = FalconRotaryEmbedding(config=config)
+
+ self.gradient_checkpointing = False
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.word_embeddings
+
+ def set_input_embeddings(self, new_embeddings: torch.Tensor):
+ self.word_embeddings = new_embeddings
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Union[Cache, tuple[tuple[torch.Tensor, torch.Tensor], ...]]] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ ) -> Union[tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
+ r"""
+ input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
+ `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values.get_seq_length()`
+ (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary.
+
+ If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
+ `input_ids`.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ if inputs_embeds is None:
+ inputs_embeds = self.word_embeddings(input_ids)
+
+ if use_cache and past_key_values is None:
+ past_key_values = DynamicCache(config=self.config)
+
+ # Compute alibi tensor: check build_alibi_tensor documentation
+ alibi = None
+ past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
+ batch_size, seq_length, _ = inputs_embeds.shape
+ if self.use_alibi:
+ mask = (
+ torch.ones(
+ (batch_size, seq_length + past_key_values_length), device=inputs_embeds.device, dtype=torch.long
+ )
+ if attention_mask is None
+ else attention_mask
+ )
+ alibi = build_alibi_tensor(mask, self.num_heads, dtype=inputs_embeds.dtype)
+
+ if cache_position is None:
+ cache_position = torch.arange(
+ past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device
+ )
+
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0)
+
+ causal_mask = self._update_causal_mask(
+ attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions, head_mask, alibi
+ )
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape batch_size x num_heads x N x N
+ # head_mask has shape n_layer x batch x num_heads x N x N
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+ hidden_states = inputs_embeds
+
+ # create position embeddings to be shared across the decoder layers
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+ all_self_attentions = () if output_attentions else None
+ all_hidden_states = () if output_hidden_states else None
+
+ for i, block in enumerate(self.h):
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ outputs = block(
+ hidden_states,
+ layer_past=past_key_values,
+ attention_mask=causal_mask,
+ position_ids=position_ids,
+ head_mask=head_mask[i],
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ alibi=alibi,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ )
+
+ hidden_states = outputs[0]
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (outputs[1],)
+
+ # Add last hidden state
+ hidden_states = self.ln_f(hidden_states)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(
+ v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attentions] if v is not None
+ )
+
+ return BaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ )
+
+ def _update_causal_mask(
+ self,
+ attention_mask: torch.Tensor,
+ input_tensor: torch.Tensor,
+ cache_position: torch.Tensor,
+ past_key_values: Cache,
+ output_attentions: bool,
+ head_mask: torch.Tensor,
+ alibi: torch.Tensor,
+ ):
+ # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
+        # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode step due to the dynamic shapes.
+ # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
+ # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
+
+ if self.config._attn_implementation == "flash_attention_2":
+ if attention_mask is not None and 0.0 in attention_mask:
+ return attention_mask
+ return None
+
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
+ # to infer the attention mask.
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ using_static_cache = isinstance(past_key_values, StaticCache)
+
+ # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
+ if (
+ self.config._attn_implementation == "sdpa"
+ and not using_static_cache
+ and not output_attentions
+ and head_mask is None
+ and alibi is None
+ ):
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
+ attention_mask,
+ inputs_embeds=input_tensor,
+ past_key_values_length=past_seen_tokens,
+ is_training=self.training,
+ ):
+ return None
+
+ dtype, device = input_tensor.dtype, input_tensor.device
+ min_dtype = torch.finfo(dtype).min
+ batch_size, sequence_length, _ = input_tensor.shape
+ if using_static_cache:
+ target_length = past_key_values.get_max_cache_shape()
+ else:
+ target_length = (
+ attention_mask.shape[-1]
+ if isinstance(attention_mask, torch.Tensor)
+ else past_seen_tokens + sequence_length
+ )
+
+        # In case the provided `attention_mask` is 2D, we generate a 4D causal mask here.
+ causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
+ attention_mask,
+ sequence_length=sequence_length,
+ target_length=target_length,
+ dtype=dtype,
+ device=device,
+ cache_position=cache_position,
+ batch_size=input_tensor.shape[0],
+ )
+
+ # We take care to integrate alibi bias in the causal_mask here
+ if head_mask is None and alibi is not None:
+ alibi = alibi.reshape(batch_size, -1, *alibi.shape[1:])
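+            # The 4D causal mask holds 0.0 at attendable positions and `min_dtype` elsewhere, so the
+            # `causal_mask < -1` condition below selects exactly the masked positions: they stay at `min_dtype`
+            # while attendable positions receive the scaled alibi bias instead.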
+ causal_mask = torch.masked_fill(
+ alibi / math.sqrt(self.config.hidden_size // self.num_heads),
+ causal_mask < -1,
+ min_dtype,
+ )
+
+ if (
+ self.config._attn_implementation == "sdpa"
+ and attention_mask is not None
+ and attention_mask.device.type in ["cuda", "xpu", "npu"]
+ and not output_attentions
+ ):
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+ # Details: https://github.com/pytorch/pytorch/issues/110213
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
+
+ return causal_mask
+
+ @staticmethod
+ # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
+ def _prepare_4d_causal_attention_mask_with_cache_position(
+ attention_mask: torch.Tensor,
+ sequence_length: int,
+ target_length: int,
+ dtype: torch.dtype,
+ cache_position: torch.Tensor,
+ batch_size: int,
+ **kwargs,
+ ):
+ """
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
+
+ Args:
+ attention_mask (`torch.Tensor`):
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
+ `(batch_size, 1, query_length, key_value_length)`.
+ sequence_length (`int`):
+ The sequence length being processed.
+ target_length (`int`):
+                The target length: when generating with a static cache, the mask should be as long as the static cache
+                in order to account for the 0 padding, i.e. the part of the cache that is not filled yet.
+ dtype (`torch.dtype`):
+ The dtype to use for the 4D attention mask.
+ cache_position (`torch.Tensor`):
+ Indices depicting the position of the input sequence tokens in the sequence.
+            batch_size (`int`):
+ Batch size.
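+
+        Example (illustrative): for `sequence_length=2`, `target_length=4`, `cache_position=torch.tensor([2, 3])`
+        and a 2D `attention_mask` of `[[0, 1, 1, 1]]` (one left-padding token), the returned mask has shape
+        `(1, 1, 2, 4)` and equals `[[[[m, 0, 0, m], [m, 0, 0, 0]]]]` with `m = torch.finfo(dtype).min`: the
+        padding position is masked for both queries, and the first query additionally cannot attend to the
+        future position.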
+ """
+ if attention_mask is not None and attention_mask.dim() == 4:
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+ causal_mask = attention_mask
+ else:
+ min_dtype = torch.finfo(dtype).min
+ causal_mask = torch.full(
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
+ )
+ if sequence_length != 1:
+ causal_mask = torch.triu(causal_mask, diagonal=1)
+ causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+ if attention_mask is not None:
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
+ mask_length = attention_mask.shape[-1]
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
+ causal_mask.device
+ )
+ padding_mask = padding_mask == 0
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+ padding_mask, min_dtype
+ )
+
+ return causal_mask
+
+
+@auto_docstring(
+ custom_intro="""
+ The Falcon Model transformer with a language modeling head on top (linear layer with weights tied to the input embeddings).
+ """
+)
+class FalconForCausalLM(FalconPreTrainedModel, GenerationMixin):
+ _tied_weights_keys = ["lm_head.weight"]
+
+ def __init__(self, config: FalconConfig):
+ super().__init__(config)
+ self.transformer = FalconModel(config)
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def set_output_embeddings(self, new_embeddings: torch.Tensor):
+ self.lm_head = new_embeddings
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Union[Cache, tuple[tuple[torch.Tensor, torch.Tensor], ...]]] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ **kwargs,
+ ) -> Union[tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
+ r"""
+ input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
+ `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values.get_seq_length()`
+ (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary.
+
+ If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
+ `input_ids`.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
+ `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
+ are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
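+
+        Example (a minimal usage sketch; the checkpoint shown is one public Falcon checkpoint, not prescribed here):
+
+        ```python
+        >>> from transformers import AutoTokenizer, FalconForCausalLM
+
+        >>> tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-7b")
+        >>> model = FalconForCausalLM.from_pretrained("tiiuae/falcon-7b")
+
+        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
+        >>> outputs = model(**inputs, labels=inputs["input_ids"])
+        >>> loss = outputs.loss
+        >>> logits = outputs.logits
+        ```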
+ """
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ transformer_outputs = self.transformer(
+ input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ cache_position=cache_position,
+ )
+ hidden_states = transformer_outputs[0]
+
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ lm_logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(
+ lm_logits,
+ labels,
+ vocab_size=self.config.vocab_size,
+ **kwargs,
+ )
+
+ if not return_dict:
+ output = (lm_logits,) + transformer_outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return CausalLMOutputWithCrossAttentions(
+ loss=loss,
+ logits=lm_logits,
+ past_key_values=transformer_outputs.past_key_values,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ The Falcon Model transformer with a sequence classification head on top (linear layer).
+
+ [`FalconForSequenceClassification`] uses the last token in order to do the classification, as other causal models
+ (e.g. GPT-1) do.
+
+    Since it does classification on the last token, it needs to know the position of the last token. If a
+    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
+    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
+    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (takes the last value in
+    each row of the batch).
+ """
+)
+class FalconForSequenceClassification(FalconPreTrainedModel):
+ def __init__(self, config: FalconConfig):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.transformer = FalconModel(config)
+ self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple[torch.Tensor], SequenceClassifierOutputWithPast]:
+ r"""
+ input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
+ `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values.get_seq_length()`
+ (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary.
+
+ If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
+ `input_ids`.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1`, a regression loss is computed (Mean-Square loss); if
+            `config.num_labels > 1`, a classification loss is computed (Cross-Entropy).
+ """
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ transformer_outputs = self.transformer(
+ input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = transformer_outputs[0]
+ logits = self.score(hidden_states)
+
+ if input_ids is not None:
+ batch_size = input_ids.shape[0]
+ else:
+ batch_size = inputs_embeds.shape[0]
+
+ if self.config.pad_token_id is None and batch_size != 1:
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
+ if self.config.pad_token_id is None:
+ last_non_pad_token = -1
+ elif input_ids is not None:
+ # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
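+            # Illustrative example (assuming pad_token_id=0): for input_ids [0, 0, 5, 7] the mask is [0, 0, 1, 1],
+            # token_indices * mask is [0, 0, 2, 3] and argmax returns 3, the rightmost non-padding position,
+            # which works for both left- and right-padded batches.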
+ non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
+ token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
+ last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
+ else:
+ last_non_pad_token = -1
+ logger.warning_once(
+ f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+ "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+ )
+
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
+
+ loss = None
+ if labels is not None:
+ if self.config.problem_type is None:
+ if self.num_labels == 1:
+ self.config.problem_type = "regression"
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+ self.config.problem_type = "single_label_classification"
+ else:
+ self.config.problem_type = "multi_label_classification"
+
+ if self.config.problem_type == "regression":
+ loss_fct = MSELoss()
+ if self.num_labels == 1:
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+ else:
+ loss = loss_fct(pooled_logits, labels)
+ elif self.config.problem_type == "single_label_classification":
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(pooled_logits, labels)
+ elif self.config.problem_type == "multi_label_classification":
+ loss_fct = BCEWithLogitsLoss()
+ loss = loss_fct(pooled_logits, labels)
+ if not return_dict:
+ output = (pooled_logits,) + transformer_outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return SequenceClassifierOutputWithPast(
+ loss=loss,
+ logits=pooled_logits,
+ past_key_values=transformer_outputs.past_key_values,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
+ )
+
+
+@auto_docstring
+class FalconForTokenClassification(FalconPreTrainedModel):
+ def __init__(self, config: FalconConfig):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+
+ self.transformer = FalconModel(config)
+ if getattr(config, "classifier_dropout", None) is not None:
+ classifier_dropout = config.classifier_dropout
+ elif getattr(config, "hidden_dropout", None) is not None:
+ classifier_dropout = config.hidden_dropout
+ else:
+ classifier_dropout = 0.1
+ self.dropout = nn.Dropout(classifier_dropout)
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple[torch.Tensor], TokenClassifierOutput]:
+ r"""
+ input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
+ `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values.get_seq_length()`
+ (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary.
+
+ If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
+ `input_ids`.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
+ """
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ transformer_outputs = self.transformer(
+ input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = transformer_outputs[0]
+ hidden_states = self.dropout(hidden_states)
+ logits = self.classifier(hidden_states)
+
+ loss = None
+ if labels is not None:
+ batch_size, seq_length = labels.shape
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(
+ logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length)
+ )
+
+ if not return_dict:
+ output = (logits,) + transformer_outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return TokenClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
+ )
+
+
+@auto_docstring
+class FalconForQuestionAnswering(FalconPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.transformer = FalconModel(config)
+ self.qa_outputs = nn.Linear(config.hidden_size, 2)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ start_positions: Optional[torch.LongTensor] = None,
+ end_positions: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, QuestionAnsweringModelOutput]:
+ r"""
+ input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
+ `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values.get_seq_length()`
+ (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary.
+
+ If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
+ `input_ids`.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.transformer(
+ input_ids,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output = outputs[0]
+
+ logits = self.qa_outputs(sequence_output)
+ start_logits, end_logits = logits.split(1, dim=-1)
+ start_logits = start_logits.squeeze(-1).contiguous()
+ end_logits = end_logits.squeeze(-1).contiguous()
+
+ total_loss = None
+ if start_positions is not None and end_positions is not None:
+            # If we are on multi-GPU, splitting adds a dimension
+ if len(start_positions.size()) > 1:
+ start_positions = start_positions.squeeze(-1)
+ if len(end_positions.size()) > 1:
+ end_positions = end_positions.squeeze(-1)
+            # sometimes the start/end positions are outside our model inputs, so we ignore these terms
+ ignored_index = start_logits.size(1)
+ start_positions = start_positions.clamp(0, ignored_index)
+ end_positions = end_positions.clamp(0, ignored_index)
+
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
+ start_loss = loss_fct(start_logits, start_positions)
+ end_loss = loss_fct(end_logits, end_positions)
+ total_loss = (start_loss + end_loss) / 2
+
+ if not return_dict:
+ output = (start_logits, end_logits) + outputs[2:]
+ return ((total_loss,) + output) if total_loss is not None else output
+
+ return QuestionAnsweringModelOutput(
+ loss=total_loss,
+ start_logits=start_logits,
+ end_logits=end_logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+__all__ = [
+ "FalconForCausalLM",
+ "FalconModel",
+ "FalconPreTrainedModel",
+ "FalconForSequenceClassification",
+ "FalconForTokenClassification",
+ "FalconForQuestionAnswering",
+]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py b/venv/lib/python3.13/site-packages/transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..5a2dc39385b3c953eb256bb03c04d6456c8f8890
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py
@@ -0,0 +1,1588 @@
+# coding=utf-8
+# Copyright 2023 The Espnet authors, IMS Toucan authors, and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch FastSpeech2Conformer model."""
+
+import math
+from dataclasses import dataclass
+from typing import Optional, Union
+
+import torch
+from torch import nn
+
+from ...activations import ACT2FN
+from ...modeling_outputs import BaseModelOutput
+from ...modeling_utils import PreTrainedModel
+from ...utils import ModelOutput, auto_docstring, logging
+from .configuration_fastspeech2_conformer import (
+ FastSpeech2ConformerConfig,
+ FastSpeech2ConformerHifiGanConfig,
+ FastSpeech2ConformerWithHifiGanConfig,
+)
+
+
+logger = logging.get_logger(__name__)
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Output type of [`FastSpeech2ConformerModel`].
+ """
+)
+class FastSpeech2ConformerModelOutput(ModelOutput):
+ r"""
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Spectrogram generation loss.
+ duration_outputs (`torch.LongTensor` of shape `(batch_size, max_text_length + 1)`, *optional*):
+ Outputs of the duration predictor.
+ pitch_outputs (`torch.FloatTensor` of shape `(batch_size, max_text_length + 1, 1)`, *optional*):
+ Outputs of the pitch predictor.
+ energy_outputs (`torch.FloatTensor` of shape `(batch_size, max_text_length + 1, 1)`, *optional*):
+ Outputs of the energy predictor.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ spectrogram: Optional[torch.FloatTensor] = None
+ encoder_last_hidden_state: Optional[torch.FloatTensor] = None
+ encoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None
+ encoder_attentions: Optional[tuple[torch.FloatTensor]] = None
+ decoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None
+ decoder_attentions: Optional[tuple[torch.FloatTensor]] = None
+ duration_outputs: Optional[torch.LongTensor] = None
+ pitch_outputs: Optional[torch.FloatTensor] = None
+ energy_outputs: Optional[torch.FloatTensor] = None
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Output type of [`FastSpeech2ConformerWithHifiGan`].
+ """
+)
+class FastSpeech2ConformerWithHifiGanOutput(FastSpeech2ConformerModelOutput):
+ r"""
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Spectrogram generation loss.
+ duration_outputs (`torch.LongTensor` of shape `(batch_size, max_text_length + 1)`, *optional*):
+ Outputs of the duration predictor.
+ pitch_outputs (`torch.FloatTensor` of shape `(batch_size, max_text_length + 1, 1)`, *optional*):
+ Outputs of the pitch predictor.
+ energy_outputs (`torch.FloatTensor` of shape `(batch_size, max_text_length + 1, 1)`, *optional*):
+ Outputs of the energy predictor.
+ waveform (`torch.FloatTensor` of shape `(batch_size, audio_length)`):
+ Speech output as a result of passing the predicted mel spectrogram through the vocoder.
+ """
+
+ waveform: Optional[torch.FloatTensor] = None
+
+
+def length_regulator(encoded_embeddings, duration_labels, speaking_speed=1.0):
+ """
+ Length regulator for feed-forward Transformer.
+
+ This is the length regulator module described in `FastSpeech: Fast, Robust and Controllable Text to Speech`
+ https://huggingface.co/papers/1905.09263. The length regulator expands char or phoneme-level embedding features to
+ frame-level by repeating each feature based on the corresponding predicted durations.
+
+ Args:
+ encoded_embeddings (`torch.Tensor` of shape `(batch_size, max_text_length, embedding_dim)`):
+ Batch of sequences of char or phoneme embeddings.
+ duration_labels (`torch.LongTensor` of shape `(batch_size, time)`):
+ Batch of durations of each frame.
+ speaking_speed (`float`, *optional*, defaults to 1.0):
+ Value to control speed of speech.
+
+ Returns:
+ `torch.Tensor`:
+ Replicated input tensor based on durations (batch_size, time*, embedding_dim).
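+
+    Example (illustrative):
+
+    ```python
+    >>> embeddings = torch.arange(3, dtype=torch.float).view(1, 3, 1)  # three phonemes, embedding_dim=1
+    >>> durations = torch.tensor([[2, 1, 3]])
+    >>> length_regulator(embeddings, durations).squeeze(-1)
+    tensor([[0., 0., 1., 2., 2., 2.]])
+    ```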
+ """
+
+ if speaking_speed <= 0:
+ raise ValueError("`speaking_speed` must be greater than 0.")
+ elif speaking_speed != 1.0:
+ duration_labels = torch.round(duration_labels.float() * speaking_speed).long()
+
+ if duration_labels.sum() == 0:
+ duration_labels[duration_labels.sum(dim=1).eq(0)] = 1
+
+ # Calculate the maximum length needed
+ max_len = torch.sum(duration_labels, dim=1).max()
+
+ # Create a padded tensor to hold the results
+ hidden_states = torch.zeros(
+ (encoded_embeddings.size(0), max_len, encoded_embeddings.size(2)),
+ dtype=torch.float,
+ device=encoded_embeddings.device,
+ )
+
+ # Loop through the batch and fill in the data
+ for i, (encoded_embedding, target_duration) in enumerate(zip(encoded_embeddings, duration_labels)):
+ repeated = torch.repeat_interleave(encoded_embedding, target_duration, dim=0)
+ hidden_states[i, : repeated.size(0)] = repeated
+
+ return hidden_states
+
+
+class FastSpeech2ConformerDurationPredictor(nn.Module):
+ """
+ Duration predictor module.
+
+    This is the duration predictor module described in the paper 'FastSpeech: Fast, Robust and Controllable Text to
+    Speech' (https://huggingface.co/papers/1905.09263). The duration predictor predicts the duration of each frame in the
+    log domain from the hidden embeddings of the encoder.
+
+ Note:
+ The calculation domain of outputs is different between in `forward` and in `inference`. In `forward`, the
+ outputs are calculated in log domain but in `inference`, those are calculated in linear domain.
+
+ """
+
+ def __init__(self, config: FastSpeech2ConformerConfig):
+ super().__init__()
+
+ self.conv_layers = nn.ModuleList()
+ self.log_domain_offset = 1.0
+
+ for layer_idx in range(config.duration_predictor_layers):
+ num_chans = config.duration_predictor_channels
+ input_channels = config.hidden_size if layer_idx == 0 else num_chans
+ layer = FastSpeech2ConformerPredictorLayer(
+ input_channels,
+ num_chans,
+ config.duration_predictor_kernel_size,
+ config.duration_predictor_dropout_rate,
+ )
+ self.conv_layers.append(layer)
+ self.linear = nn.Linear(config.duration_predictor_channels, 1)
+
+ def forward(self, encoder_hidden_states):
+ """
+ Args:
+ encoder_hidden_states (`torch.Tensor` of shape `(batch_size, max_text_length, input_dim)`):
+ Batch of input sequences.
+
+ Returns:
+ `torch.Tensor`: Batch of predicted durations in log domain `(batch_size, max_text_length)`.
+
+ """
+ # (batch_size, input_dim, max_text_length)
+ hidden_states = encoder_hidden_states.transpose(1, -1)
+ for layer in self.conv_layers:
+ hidden_states = layer(hidden_states)
+
+ # NOTE: calculate in log domain, (batch_size, max_text_length)
+ hidden_states = self.linear(hidden_states.transpose(1, -1)).squeeze(-1)
+
+ if not self.training:
+ # NOTE: calculate in linear domain
+ hidden_states = torch.clamp(torch.round(hidden_states.exp() - self.log_domain_offset), min=0).long()
+
+ return hidden_states
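+
+
+# Illustrative sketch: the note above says the duration predictor returns log-domain durations in
+# training mode and linear-domain durations in eval mode. This standalone snippet shows the
+# conversion applied at inference (exp, subtract the log-domain offset of 1, round, clamp at 0);
+# the helper name and input values are invented for illustration.
+def _duration_domain_conversion_example():
+    import torch
+
+    log_durations = torch.tensor([[-0.5, 0.0, 1.2, 2.3]])
+    linear_durations = torch.clamp(torch.round(log_durations.exp() - 1.0), min=0).long()
+    # tensor([[0, 0, 2, 9]]) -- e.g. exp(2.3) - 1 ~= 8.97 is rounded to 9 frames
+    return linear_durations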
+
+
+# Copied from transformers.models.speecht5.modeling_speecht5.SpeechT5BatchNormConvLayer
+class FastSpeech2ConformerBatchNormConvLayer(nn.Module):
+ def __init__(self, config, layer_id=0):
+ super().__init__()
+
+ if layer_id == 0:
+ in_conv_dim = config.num_mel_bins
+ else:
+ in_conv_dim = config.speech_decoder_postnet_units
+
+ if layer_id == config.speech_decoder_postnet_layers - 1:
+ out_conv_dim = config.num_mel_bins
+ else:
+ out_conv_dim = config.speech_decoder_postnet_units
+
+ self.conv = nn.Conv1d(
+ in_conv_dim,
+ out_conv_dim,
+ kernel_size=config.speech_decoder_postnet_kernel,
+ stride=1,
+ padding=(config.speech_decoder_postnet_kernel - 1) // 2,
+ bias=False,
+ )
+ self.batch_norm = nn.BatchNorm1d(out_conv_dim)
+
+ if layer_id < config.speech_decoder_postnet_layers - 1:
+ self.activation = nn.Tanh()
+ else:
+ self.activation = None
+
+ self.dropout = nn.Dropout(config.speech_decoder_postnet_dropout)
+
+ def forward(self, hidden_states):
+ hidden_states = self.conv(hidden_states)
+ hidden_states = self.batch_norm(hidden_states)
+ if self.activation is not None:
+ hidden_states = self.activation(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ return hidden_states
+
+
+class FastSpeech2ConformerSpeechDecoderPostnet(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.feat_out = nn.Linear(config.hidden_size, config.num_mel_bins * config.reduction_factor)
+ self.layers = nn.ModuleList(
+ [FastSpeech2ConformerBatchNormConvLayer(config, i) for i in range(config.speech_decoder_postnet_layers)]
+ )
+
+ def forward(self, hidden_states: torch.Tensor):
+ outputs_before_postnet = self.feat_out(hidden_states).view(hidden_states.size(0), -1, self.config.num_mel_bins)
+ layer_output = outputs_before_postnet.transpose(1, 2)
+ for layer in self.layers:
+ layer_output = layer(layer_output)
+ outputs_after_postnet = outputs_before_postnet + layer_output.transpose(1, 2)
+ return outputs_before_postnet, outputs_after_postnet
+
+
+class FastSpeech2ConformerPredictorLayer(nn.Module):
+ def __init__(self, input_channels, num_chans, kernel_size, dropout_rate):
+ super().__init__()
+ self.conv = nn.Conv1d(
+ input_channels,
+ num_chans,
+ kernel_size,
+ stride=1,
+ padding=(kernel_size - 1) // 2,
+ )
+ self.activation = nn.ReLU()
+ self.layer_norm = nn.LayerNorm(num_chans)
+ self.dropout = nn.Dropout(dropout_rate)
+
+ def forward(self, hidden_states):
+ hidden_states = self.conv(hidden_states)
+ hidden_states = self.activation(hidden_states)
+
+ # Perform layer norm on dimension 1
+ hidden_states = hidden_states.transpose(1, -1)
+ hidden_states = self.layer_norm(hidden_states)
+ hidden_states = hidden_states.transpose(1, -1)
+
+ hidden_states = self.dropout(hidden_states)
+
+ return hidden_states
+
+
+class FastSpeech2ConformerVariancePredictor(nn.Module):
+ def __init__(
+ self,
+ config: FastSpeech2ConformerConfig,
+ num_layers=2,
+ num_chans=384,
+ kernel_size=3,
+ dropout_rate=0.5,
+ ):
+ """
+ Initialize variance predictor module.
+
+ Args:
+ config (`FastSpeech2ConformerConfig`): Model configuration; `config.hidden_size` determines the input dimension.
+ num_layers (`int`, *optional*, defaults to 2): Number of convolutional layers.
+ num_chans (`int`, *optional*, defaults to 384): Number of channels of convolutional layers.
+ kernel_size (`int`, *optional*, defaults to 3): Kernel size of convolutional layers.
+ dropout_rate (`float`, *optional*, defaults to 0.5): Dropout rate.
+ """
+ super().__init__()
+ self.conv_layers = nn.ModuleList()
+ for idx in range(num_layers):
+ input_channels = config.hidden_size if idx == 0 else num_chans
+ layer = FastSpeech2ConformerPredictorLayer(input_channels, num_chans, kernel_size, dropout_rate)
+ self.conv_layers.append(layer)
+ self.linear = nn.Linear(num_chans, 1)
+
+ def forward(self, encoder_hidden_states, padding_masks=None):
+ """
+ Calculate forward propagation.
+
+ Args:
+ encoder_hidden_states (`torch.Tensor` of shape `(batch_size, max_text_length, input_dim)`):
+ Batch of input sequences.
+ padding_masks (`torch.ByteTensor` of shape `(batch_size, max_text_length)`, *optional*):
+ Batch of masks indicating padded part.
+
+ Returns:
+ Tensor: Batch of predicted sequences `(batch_size, max_text_length, 1)`.
+ """
+ # (batch_size, input_dim, max_text_length)
+ hidden_states = encoder_hidden_states.transpose(1, -1)
+ for layer in self.conv_layers:
+ hidden_states = layer(hidden_states)
+
+ hidden_states = self.linear(hidden_states.transpose(1, 2))
+
+ if padding_masks is not None:
+ hidden_states = hidden_states.masked_fill(padding_masks, 0.0)
+
+ return hidden_states
+
+
+class FastSpeech2ConformerVarianceEmbedding(nn.Module):
+ def __init__(
+ self,
+ in_channels=1,
+ out_channels=384,
+ kernel_size=1,
+ padding=0,
+ dropout_rate=0.0,
+ ):
+ super().__init__()
+ self.conv = nn.Conv1d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ padding=padding,
+ )
+ self.dropout = nn.Dropout(dropout_rate)
+
+ def forward(self, hidden_states):
+ hidden_states = hidden_states.transpose(1, 2)
+ hidden_states = self.conv(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = hidden_states.transpose(1, 2)
+ return hidden_states
+
+
+class FastSpeech2ConformerAttention(nn.Module):
+ """
+ Multi-Head attention layer with relative position encoding. Details can be found in
+ https://github.com/espnet/espnet/pull/2816. Paper: https://huggingface.co/papers/1901.02860.
+ """
+
+ def __init__(self, config: FastSpeech2ConformerConfig, module_config):
+ """Construct an FastSpeech2ConformerAttention object."""
+ super().__init__()
+ # We assume d_v always equals dim_key
+ self.num_heads = module_config["num_attention_heads"]
+ self.hidden_size = config.hidden_size
+ self.dim_key = self.hidden_size // self.num_heads
+ self.head_dim = self.hidden_size // self.num_heads
+ self.linear_q = nn.Linear(self.hidden_size, self.hidden_size)
+ self.linear_k = nn.Linear(self.hidden_size, self.hidden_size)
+ self.linear_v = nn.Linear(self.hidden_size, self.hidden_size)
+ self.linear_out = nn.Linear(self.hidden_size, self.hidden_size)
+ self.dropout = nn.Dropout(p=module_config["attention_dropout_rate"])
+
+ # linear transformation for positional encoding
+ self.linear_pos = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
+ # these two learnable bias are used in matrix c and matrix d
+ # as described in https://huggingface.co/papers/1901.02860 Section 3.3
+ self.pos_bias_u = nn.Parameter(torch.Tensor(self.num_heads, self.head_dim))
+ self.pos_bias_v = nn.Parameter(torch.Tensor(self.num_heads, self.head_dim))
+
+ def shift_relative_position_tensor(self, pos_tensor):
+ """
+ Args:
+ pos_tensor (torch.Tensor of shape (batch_size, head, time1, 2*time1-1)): Input tensor.
+ """
+ zero_pad = torch.zeros((*pos_tensor.size()[:3], 1), device=pos_tensor.device, dtype=pos_tensor.dtype)
+ pos_tensor_padded = torch.cat([zero_pad, pos_tensor], dim=-1)
+
+ pos_tensor_padded = pos_tensor_padded.view(*pos_tensor.size()[:2], pos_tensor.size(3) + 1, pos_tensor.size(2))
+ # only keep the positions from 0 to time2
+ pos_tensor = pos_tensor_padded[:, :, 1:].view_as(pos_tensor)[:, :, :, : pos_tensor.size(-1) // 2 + 1]
+
+ return pos_tensor
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ pos_emb: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = False,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Compute 'Scaled Dot Product Attention' with rel. positional encoding.
+
+ Args:
+ hidden_states (`torch.Tensor` of shape `(batch, time2, size)`): Values of the hidden states
+ attention_mask (`torch.Tensor` of shape `(batch, 1, time2)`, *optional*): Mask tensor.
+ pos_emb (`torch.Tensor` of shape `(batch, 2*time1-1, size)`): Positional embedding tensor.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ Returns:
+ `torch.Tensor`: Output tensor of shape `(batch, time1, d_model)`.
+ """
+ bsz, q_len, _ = hidden_states.size()
+ query_states = self.linear_q(hidden_states).view(bsz, -1, self.num_heads, self.head_dim)
+ key_states = self.linear_k(hidden_states).view(bsz, -1, self.num_heads, self.head_dim)
+ value_states = self.linear_v(hidden_states).view(bsz, -1, self.num_heads, self.head_dim)
+
+ bsz_pos = pos_emb.size(0)
+ pos_encoding = self.linear_pos(pos_emb).view(bsz_pos, -1, self.num_heads, self.head_dim)
+
+ # (batch_size, head, time1, dim_key)
+ query_with_bias_u = (query_states + self.pos_bias_u).transpose(1, 2)
+ # (batch_size, head, time1, dim_key)
+ query_with_bias_v = (query_states + self.pos_bias_v).transpose(1, 2)
+
+ # compute attention score
+ # first compute matrix a and matrix c
+ # as described in https://huggingface.co/papers/1901.02860 Section 3.3
+ # (batch_size, head, time1, time2)
+ matrix_ac = torch.matmul(query_with_bias_u, key_states.permute(0, 2, 3, 1))
+
+ # compute matrix b and matrix d
+ # (batch_size, head, time1, 2*time1-1)
+ matrix_bd = torch.matmul(query_with_bias_v, pos_encoding.permute(0, 2, 3, 1))
+ matrix_bd = self.shift_relative_position_tensor(matrix_bd)
+
+ # (batch_size, head, time1, time2)
+ scores = (matrix_ac + matrix_bd) / math.sqrt(self.dim_key)
+
+ # Forward attention
+ if attention_mask is not None:
+ expected_size = (bsz, 1, q_len)
+ if attention_mask.size() != expected_size:
+ raise ValueError(f"Attention mask should be of size {expected_size}, but is {attention_mask.size()}")
+ attention_mask = attention_mask.unsqueeze(1).eq(0)
+ min_value = float(torch.finfo(scores.dtype).min)
+ scores = scores.masked_fill(attention_mask, min_value)
+ attn_weights = torch.softmax(scores, dim=-1).masked_fill(attention_mask, 0.0)
+ else:
+ attn_weights = torch.softmax(scores, dim=-1)
+
+ attn_weights = self.dropout(attn_weights)
+ attn_output = torch.matmul(attn_weights, value_states.transpose(1, 2))
+ attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, q_len, -1)
+
+ attn_output = self.linear_out(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights
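+
+
+# Illustrative sketch: a standalone replay of the Transformer-XL style shift performed in
+# `shift_relative_position_tensor` above. For a query length of 2 the 2*2-1 = 3 score columns
+# correspond to relative positions [+1, 0, -1]; after the shift, entry (i, j) holds the score for
+# relative position i - j. The helper name and values are invented so the alignment is easy to read.
+def _relative_shift_example():
+    import torch
+
+    # (batch, head, time1, 2*time1-1); each column stores its own relative position as its value.
+    scores = torch.tensor([[[[1.0, 0.0, -1.0], [1.0, 0.0, -1.0]]]])
+    zero_pad = torch.zeros((*scores.size()[:3], 1))
+    padded = torch.cat([zero_pad, scores], dim=-1)
+    padded = padded.view(*scores.size()[:2], scores.size(3) + 1, scores.size(2))
+    shifted = padded[:, :, 1:].view_as(scores)[:, :, :, : scores.size(-1) // 2 + 1]
+    # shifted[0, 0] == [[0., -1.], [1., 0.]]: row i, column j now holds relative position i - j.
+    return shifted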
+
+
+class FastSpeech2ConformerConvolutionModule(nn.Module):
+ def __init__(self, config: FastSpeech2ConformerConfig, module_config=None):
+ """
+ Args:
+ config (FastSpeech2ConformerConfig): Configuration for the model.
+ module_config (dict): Configuration for the module (e.g., encoder or decoder).
+ """
+ super().__init__()
+ channels = config.hidden_size
+ # kernel_size should be an odd number for 'SAME' padding
+ if module_config is None:
+ # e.g. using `ParakeetEncoderConfig` in src/transformers/models/parakeet/configuration_parakeet.py
+ kernel_size = config.conv_kernel_size
+ self.activation = ACT2FN[getattr(config, "hidden_act", "silu")]
+ else:
+ kernel_size = module_config["kernel_size"]
+ self.activation = ACT2FN[module_config.get("activation", "silu")]
+ self.padding = (kernel_size - 1) // 2
+ self.pointwise_conv1 = nn.Conv1d(channels, 2 * channels, kernel_size=1, stride=1, padding=0, bias=True)
+ self.depthwise_conv = nn.Conv1d(
+ channels, channels, kernel_size, stride=1, padding=self.padding, groups=channels, bias=True
+ )
+ self.norm = nn.BatchNorm1d(channels)
+ self.pointwise_conv2 = nn.Conv1d(channels, channels, kernel_size=1, stride=1, padding=0, bias=True)
+
+ def forward(self, hidden_states, attention_mask=None):
+ """
+ Compute convolution module.
+
+ Args:
+ hidden_states (`torch.Tensor` of shape `(batch, time, channels)`): Input tensor.
+ attention_mask (`torch.Tensor` of shape `(batch, 1, time)`): Attention mask.
+
+ Returns:
+ `torch.Tensor`: Output tensor of shape `(batch, time, channels)`.
+
+ """
+ # exchange the temporal dimension and the feature dimension
+ hidden_states = hidden_states.transpose(1, 2)
+
+ # GLU mechanism, (batch_size, 2*channel, dim)
+ hidden_states = self.pointwise_conv1(hidden_states)
+ # (batch_size, channel, dim)
+ hidden_states = nn.functional.glu(hidden_states, dim=1)
+
+ # Apply padding mask before convolution
+ if attention_mask is not None:
+ all_masked_rows = torch.all(~attention_mask, dim=-1)
+ hidden_states = hidden_states.masked_fill(all_masked_rows, 0.0)
+
+ # 1D Depthwise Conv
+ hidden_states = self.depthwise_conv(hidden_states)
+ hidden_states = self.norm(hidden_states)
+ hidden_states = self.activation(hidden_states)
+ hidden_states = self.pointwise_conv2(hidden_states)
+
+ return hidden_states.transpose(1, 2)
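+
+
+# Illustrative sketch: the convolution module above first doubles the channel dimension with a
+# pointwise convolution and then gates it back down with GLU (first half * sigmoid(second half)).
+# The helper name and layer sizes below are invented for illustration.
+def _glu_gating_example():
+    import torch
+    from torch import nn
+
+    hidden = torch.randn(2, 8, 10)                 # (batch, channels, time), as after the transpose
+    doubled = nn.Conv1d(8, 2 * 8, kernel_size=1)(hidden)
+    gated = nn.functional.glu(doubled, dim=1)
+    assert gated.shape == (2, 8, 10)               # back to the original channel count
+    return gated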
+
+
+class FastSpeech2ConformerEncoderLayer(nn.Module):
+ def __init__(self, config: FastSpeech2ConformerConfig, module_config):
+ super().__init__()
+
+ # self-attention module definition
+ self.self_attn = FastSpeech2ConformerAttention(config, module_config)
+
+ # feed-forward module definition
+ self.feed_forward = FastSpeech2ConformerMultiLayeredConv1d(config, module_config)
+
+ self.macaron_style = config.use_macaron_style_in_conformer
+ if self.macaron_style:
+ self.feed_forward_macaron = FastSpeech2ConformerMultiLayeredConv1d(config, module_config)
+ self.ff_macaron_layer_norm = nn.LayerNorm(config.hidden_size)
+ self.ff_scale = 0.5
+ else:
+ self.ff_scale = 1.0
+
+ # convolution module definition
+ self.use_cnn_module = config.use_cnn_in_conformer
+ if self.use_cnn_module:
+ self.conv_module = FastSpeech2ConformerConvolutionModule(config, module_config)
+ self.conv_layer_norm = nn.LayerNorm(config.hidden_size)
+ self.final_layer_norm = nn.LayerNorm(config.hidden_size)
+
+ self.ff_layer_norm = nn.LayerNorm(config.hidden_size)
+
+ self.self_attn_layer_norm = nn.LayerNorm(config.hidden_size)
+
+ self.dropout = nn.Dropout(module_config["dropout_rate"])
+ self.size = config.hidden_size
+ self.normalize_before = module_config["normalize_before"]
+ self.concat_after = module_config["concat_after"]
+ if self.concat_after:
+ self.concat_linear = nn.Linear(config.hidden_size + config.hidden_size, config.hidden_size)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ pos_emb: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = False,
+ ):
+ """
+ Compute encoded features.
+
+ Args:
+ hidden_states (`torch.Tensor` of shape `(batch, time, size)`): Input tensor.
+ pos_emb (`torch.Tensor` of shape `(1, time, size)`): Positional embeddings tensor.
+ attention_mask (`torch.Tensor` of shape `(batch, time)`): Attention mask tensor for the input.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ Returns:
+ `torch.Tensor`: Output tensor of shape `(batch, time, size)`.
+
+ """
+ # whether to use macaron style
+ if self.macaron_style:
+ residual = hidden_states
+ if self.normalize_before:
+ hidden_states = self.ff_macaron_layer_norm(hidden_states)
+ hidden_states = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(hidden_states))
+ if not self.normalize_before:
+ hidden_states = self.ff_macaron_layer_norm(hidden_states)
+
+ # multi-headed self-attention module
+ residual = hidden_states
+ if self.normalize_before:
+ hidden_states = self.self_attn_layer_norm(hidden_states)
+
+ attention_output, attention_scores = self.self_attn(
+ hidden_states, attention_mask=attention_mask, pos_emb=pos_emb, output_attentions=output_attentions
+ )
+
+ if self.concat_after:
+ x_concat = torch.cat((hidden_states, attention_output), dim=-1)
+ hidden_states = self.concat_linear(x_concat)
+ hidden_states = residual + hidden_states
+ else:
+ hidden_states = self.dropout(attention_output)
+ hidden_states = residual + hidden_states
+ if not self.normalize_before:
+ hidden_states = self.self_attn_layer_norm(hidden_states)
+
+ # convolution module
+ if self.use_cnn_module:
+ residual = hidden_states
+ if self.normalize_before:
+ hidden_states = self.conv_layer_norm(hidden_states)
+ hidden_states = self.conv_module(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = residual + hidden_states
+ if not self.normalize_before:
+ hidden_states = self.conv_layer_norm(hidden_states)
+
+ # feed forward module
+ residual = hidden_states
+ if self.normalize_before:
+ hidden_states = self.ff_layer_norm(hidden_states)
+ hidden_states = self.feed_forward(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = residual + self.ff_scale * hidden_states
+ if not self.normalize_before:
+ hidden_states = self.ff_layer_norm(hidden_states)
+
+ if self.use_cnn_module:
+ hidden_states = self.final_layer_norm(hidden_states)
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (attention_scores,)
+
+ return outputs
+
+
+class FastSpeech2ConformerMultiLayeredConv1d(nn.Module):
+ """
+ Multi-layered conv1d for Transformer block.
+
+ This is a module of multi-layered conv1d designed to replace positionwise feed-forward network in Transformer
+ block, which is introduced in 'FastSpeech: Fast, Robust and Controllable Text to Speech'
+ https://huggingface.co/papers/1905.09263
+ """
+
+ def __init__(self, config: FastSpeech2ConformerConfig, module_config):
+ """
+ Initialize FastSpeech2ConformerMultiLayeredConv1d module.
+
+ Args:
+ config (`FastSpeech2ConformerConfig`): Model configuration; `config.hidden_size` and
+ `config.positionwise_conv_kernel_size` set the channel count and kernel size of the conv1d layers.
+ module_config (`dict`): Encoder or decoder module configuration providing `linear_units` and `dropout_rate`.
+ """
+ super().__init__()
+ input_channels = config.hidden_size
+ hidden_channels = module_config["linear_units"]
+ kernel_size = config.positionwise_conv_kernel_size
+ self.conv1 = nn.Conv1d(input_channels, hidden_channels, kernel_size, stride=1, padding=(kernel_size - 1) // 2)
+ self.conv2 = nn.Conv1d(hidden_channels, input_channels, kernel_size, stride=1, padding=(kernel_size - 1) // 2)
+ self.dropout = nn.Dropout(module_config["dropout_rate"])
+
+ def forward(self, hidden_states):
+ """
+ Calculate forward propagation.
+
+ Args:
+ hidden_states (torch.Tensor): Batch of input tensors (batch_size, time, input_channels).
+
+ Returns:
+ torch.Tensor: Batch of output tensors (batch_size, time, input_channels).
+ """
+ hidden_states = hidden_states.transpose(-1, 1)
+ hidden_states = self.conv1(hidden_states)
+ hidden_states = torch.relu(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.conv2(hidden_states)
+ hidden_states = hidden_states.transpose(-1, 1)
+ return hidden_states
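+
+
+# Illustrative sketch: the module above replaces the Transformer position-wise feed-forward network
+# with two Conv1d layers. With padding of (kernel_size - 1) // 2 the time dimension is preserved, so
+# the block is a shape-compatible drop-in replacement. The helper name and sizes are invented.
+def _positionwise_conv_example():
+    import torch
+    from torch import nn
+
+    hidden_states = torch.randn(2, 7, 16)               # (batch, time, input_channels)
+    conv = nn.Conv1d(16, 32, kernel_size=3, padding=1)  # "same"-length padding for kernel_size=3
+    outputs = conv(hidden_states.transpose(-1, 1)).transpose(-1, 1)
+    assert outputs.shape == (2, 7, 32)                   # time dimension unchanged
+    return outputs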
+
+
+class FastSpeech2ConformerRelPositionalEncoding(nn.Module):
+ """
+ Relative positional encoding module (new implementation). Details can be found in
+ https://github.com/espnet/espnet/pull/2816. See: Appendix B in https://huggingface.co/papers/1901.02860
+
+ Args:
+ config (`FastSpeech2ConformerConfig`):
+ FastSpeech2ConformerConfig instance.
+ module_config (`dict`):
+ Dictionary containing the encoder or decoder module configuration from the `FastSpeech2ConformerConfig`.
+ """
+
+ def __init__(self, config: FastSpeech2ConformerConfig, module_config):
+ """
+ Construct an PositionalEncoding object.
+ """
+ super().__init__()
+ self.embed_dim = config.hidden_size
+ self.input_scale = math.sqrt(self.embed_dim)
+ self.dropout = nn.Dropout(p=module_config["positional_dropout_rate"])
+ self.pos_enc = None
+ self.max_len = 5000
+ self.extend_pos_enc(torch.tensor(0.0).expand(1, self.max_len))
+
+ def extend_pos_enc(self, x):
+ """Reset the positional encodings."""
+ if self.pos_enc is not None:
+ # self.pos_enc contains both positive and negative parts
+ # the length of self.pos_enc is 2 * input_len - 1
+ if self.pos_enc.size(1) >= x.size(1) * 2 - 1:
+ if self.pos_enc.dtype != x.dtype or self.pos_enc.device != x.device:
+ self.pos_enc = self.pos_enc.to(dtype=x.dtype, device=x.device)
+ return
+ # Suppose `i` is the position of the query vector and `j` the position of the key vector. We use
+ # positive relative positions when keys are to the left (i > j) and negative relative positions
+ # otherwise (i < j).
+ self.multilingual_model = config.num_languages is not None and config.num_languages > 1
+ if self.multilingual_model:
+ self.language_id_embedding = torch.nn.Embedding(config.num_languages, self.hidden_size)
+
+ self.multispeaker_model = config.num_speakers is not None and config.num_speakers > 1
+ if self.multispeaker_model:
+ self.speaker_id_embedding = torch.nn.Embedding(config.num_speakers, config.hidden_size)
+
+ self.speaker_embed_dim = config.speaker_embed_dim
+ if self.speaker_embed_dim:
+ self.projection = nn.Linear(config.hidden_size + self.speaker_embed_dim, config.hidden_size)
+
+ self.encoder = FastSpeech2ConformerEncoder(config, config.encoder_config, use_encoder_input_layer=True)
+
+ self.duration_predictor = FastSpeech2ConformerDurationPredictor(config)
+
+ self.pitch_predictor = FastSpeech2ConformerVariancePredictor(
+ config,
+ num_layers=config.pitch_predictor_layers,
+ num_chans=config.pitch_predictor_channels,
+ kernel_size=config.pitch_predictor_kernel_size,
+ dropout_rate=config.pitch_predictor_dropout,
+ )
+ # continuous pitch + FastPitch style avg
+ self.pitch_embed = FastSpeech2ConformerVarianceEmbedding(
+ out_channels=self.hidden_size,
+ kernel_size=config.pitch_embed_kernel_size,
+ padding=(config.pitch_embed_kernel_size - 1) // 2,
+ dropout_rate=config.pitch_embed_dropout,
+ )
+
+ self.energy_predictor = FastSpeech2ConformerVariancePredictor(
+ config,
+ num_layers=config.energy_predictor_layers,
+ num_chans=config.energy_predictor_channels,
+ kernel_size=config.energy_predictor_kernel_size,
+ dropout_rate=config.energy_predictor_dropout,
+ )
+ # continuous energy + FastPitch style avg
+ self.energy_embed = FastSpeech2ConformerVarianceEmbedding(
+ out_channels=self.hidden_size,
+ kernel_size=config.energy_embed_kernel_size,
+ padding=(config.energy_embed_kernel_size - 1) // 2,
+ dropout_rate=config.energy_embed_dropout,
+ )
+
+ # The decoder is an encoder
+ self.decoder = FastSpeech2ConformerEncoder(config, config.decoder_config, use_encoder_input_layer=False)
+
+ self.speech_decoder_postnet = FastSpeech2ConformerSpeechDecoderPostnet(config)
+
+ self.criterion = FastSpeech2ConformerLoss(config)
+
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: torch.LongTensor,
+ attention_mask: Optional[torch.LongTensor] = None,
+ spectrogram_labels: Optional[torch.FloatTensor] = None,
+ duration_labels: Optional[torch.LongTensor] = None,
+ pitch_labels: Optional[torch.FloatTensor] = None,
+ energy_labels: Optional[torch.FloatTensor] = None,
+ speaker_ids: Optional[torch.LongTensor] = None,
+ lang_ids: Optional[torch.LongTensor] = None,
+ speaker_embedding: Optional[torch.FloatTensor] = None,
+ return_dict: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ ) -> Union[tuple, FastSpeech2ConformerModelOutput]:
+ r"""
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Input sequence of text vectors.
+ spectrogram_labels (`torch.FloatTensor` of shape `(batch_size, max_spectrogram_length, num_mel_bins)`, *optional*, defaults to `None`):
+ Batch of padded target features.
+ duration_labels (`torch.LongTensor` of shape `(batch_size, sequence_length + 1)`, *optional*, defaults to `None`):
+ Batch of padded durations.
+ pitch_labels (`torch.FloatTensor` of shape `(batch_size, sequence_length + 1, 1)`, *optional*, defaults to `None`):
+ Batch of padded token-averaged pitch.
+ energy_labels (`torch.FloatTensor` of shape `(batch_size, sequence_length + 1, 1)`, *optional*, defaults to `None`):
+ Batch of padded token-averaged energy.
+ speaker_ids (`torch.LongTensor` of shape `(batch_size, 1)`, *optional*, defaults to `None`):
+ Speaker ids used to condition features of speech output by the model.
+ lang_ids (`torch.LongTensor` of shape `(batch_size, 1)`, *optional*, defaults to `None`):
+ Language ids used to condition features of speech output by the model.
+ speaker_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`, *optional*, defaults to `None`):
+ Embedding containing conditioning signals for the features of the speech.
+
+ Example:
+
+ ```python
+ >>> from transformers import (
+ ... FastSpeech2ConformerTokenizer,
+ ... FastSpeech2ConformerModel,
+ ... FastSpeech2ConformerHifiGan,
+ ... )
+
+ >>> tokenizer = FastSpeech2ConformerTokenizer.from_pretrained("espnet/fastspeech2_conformer")
+ >>> inputs = tokenizer("some text to convert to speech", return_tensors="pt")
+ >>> input_ids = inputs["input_ids"]
+
+ >>> model = FastSpeech2ConformerModel.from_pretrained("espnet/fastspeech2_conformer")
+ >>> output_dict = model(input_ids, return_dict=True)
+ >>> spectrogram = output_dict["spectrogram"]
+
+ >>> vocoder = FastSpeech2ConformerHifiGan.from_pretrained("espnet/fastspeech2_conformer_hifigan")
+ >>> waveform = vocoder(spectrogram)
+ >>> print(waveform.shape)
+ torch.Size([1, 49664])
+ ```
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+
+ if attention_mask is None:
+ attention_mask = torch.ones(input_ids.shape, device=input_ids.device)
+
+ has_missing_labels = (
+ spectrogram_labels is None or duration_labels is None or pitch_labels is None or energy_labels is None
+ )
+ if self.training and has_missing_labels:
+ raise ValueError("All labels must be provided to run in training mode.")
+
+ # forward encoder
+ text_masks = attention_mask.unsqueeze(-2)
+
+ encoder_outputs = self.encoder(
+ input_ids,
+ text_masks,
+ output_hidden_states=output_hidden_states,
+ output_attentions=output_attentions,
+ return_dict=return_dict,
+ )
+ hidden_states = encoder_outputs[0]
+
+ # Integrate with language id, speaker id, and speaker embedding
+ if self.multispeaker_model and speaker_ids is not None:
+ speaker_id_embeddings = self.speaker_id_embedding(speaker_ids.view(-1))
+ hidden_states = hidden_states + speaker_id_embeddings.unsqueeze(1)
+
+ if self.multilingual_model and lang_ids is not None:
+ language_id_embeddings = self.language_id_embedding(lang_ids.view(-1))
+ hidden_states = hidden_states + language_id_embeddings.unsqueeze(1)
+
+ if self.speaker_embed_dim is not None and speaker_embedding is not None:
+ embeddings_expanded = (
+ nn.functional.normalize(speaker_embedding).unsqueeze(1).expand(-1, hidden_states.size(1), -1)
+ )
+ hidden_states = self.projection(torch.cat([hidden_states, embeddings_expanded], dim=-1))
+
+ # forward duration predictor and variance predictors
+ duration_mask = ~attention_mask.bool()
+
+ if self.stop_gradient_from_pitch_predictor:
+ pitch_predictions = self.pitch_predictor(hidden_states.detach(), duration_mask.unsqueeze(-1))
+ else:
+ pitch_predictions = self.pitch_predictor(hidden_states, duration_mask.unsqueeze(-1))
+
+ if self.stop_gradient_from_energy_predictor:
+ energy_predictions = self.energy_predictor(hidden_states.detach(), duration_mask.unsqueeze(-1))
+ else:
+ energy_predictions = self.energy_predictor(hidden_states, duration_mask.unsqueeze(-1))
+
+ duration_predictions = self.duration_predictor(hidden_states)
+ duration_predictions = duration_predictions.masked_fill(duration_mask, 0.0)
+
+ if not self.training:
+ # use prediction in inference
+ embedded_pitch_curve = self.pitch_embed(pitch_predictions)
+ embedded_energy_curve = self.energy_embed(energy_predictions)
+ hidden_states = hidden_states + embedded_energy_curve + embedded_pitch_curve
+ hidden_states = length_regulator(hidden_states, duration_predictions, self.config.speaking_speed)
+ else:
+ # use groundtruth in training
+ embedded_pitch_curve = self.pitch_embed(pitch_labels)
+ embedded_energy_curve = self.energy_embed(energy_labels)
+ hidden_states = hidden_states + embedded_energy_curve + embedded_pitch_curve
+ hidden_states = length_regulator(hidden_states, duration_labels)
+
+ # forward decoder
+ if not self.training:
+ hidden_mask = None
+ else:
+ spectrogram_mask = (spectrogram_labels != -100).any(dim=-1)
+ spectrogram_mask = spectrogram_mask.int()
+ if self.reduction_factor > 1:
+ length_dim = spectrogram_mask.shape[1] - spectrogram_mask.shape[1] % self.reduction_factor
+ spectrogram_mask = spectrogram_mask[:, :length_dim]
+ hidden_mask = spectrogram_mask.unsqueeze(-2)
+
+ decoder_outputs = self.decoder(
+ hidden_states,
+ hidden_mask,
+ output_hidden_states=output_hidden_states,
+ output_attentions=output_attentions,
+ return_dict=return_dict,
+ )
+
+ outputs_before_postnet, outputs_after_postnet = self.speech_decoder_postnet(decoder_outputs[0])
+
+ loss = None
+ if self.training:
+ # calculate loss
+ loss_duration_mask = ~duration_mask
+ loss_spectrogram_mask = spectrogram_mask.unsqueeze(-1).bool()
+ loss = self.criterion(
+ outputs_after_postnet=outputs_after_postnet,
+ outputs_before_postnet=outputs_before_postnet,
+ duration_outputs=duration_predictions,
+ pitch_outputs=pitch_predictions,
+ energy_outputs=energy_predictions,
+ spectrogram_labels=spectrogram_labels,
+ duration_labels=duration_labels,
+ pitch_labels=pitch_labels,
+ energy_labels=energy_labels,
+ duration_mask=loss_duration_mask,
+ spectrogram_mask=loss_spectrogram_mask,
+ )
+
+ if not return_dict:
+ postnet_outputs = (outputs_after_postnet,)
+ audio_feature_predictions = (
+ duration_predictions,
+ pitch_predictions,
+ energy_predictions,
+ )
+ outputs = postnet_outputs + encoder_outputs + decoder_outputs[1:] + audio_feature_predictions
+ return ((loss,) + outputs) if loss is not None else outputs
+
+ return FastSpeech2ConformerModelOutput(
+ loss=loss,
+ spectrogram=outputs_after_postnet,
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
+ encoder_hidden_states=encoder_outputs.hidden_states,
+ encoder_attentions=encoder_outputs.attentions,
+ decoder_hidden_states=decoder_outputs.hidden_states,
+ decoder_attentions=decoder_outputs.attentions,
+ duration_outputs=duration_predictions,
+ pitch_outputs=pitch_predictions,
+ energy_outputs=energy_predictions,
+ )
+
+
+# Copied from transformers.models.speecht5.modeling_speecht5.HifiGanResidualBlock
+class HifiGanResidualBlock(nn.Module):
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), leaky_relu_slope=0.1):
+ super().__init__()
+ self.leaky_relu_slope = leaky_relu_slope
+
+ self.convs1 = nn.ModuleList(
+ [
+ nn.Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ stride=1,
+ dilation=dilation[i],
+ padding=self.get_padding(kernel_size, dilation[i]),
+ )
+ for i in range(len(dilation))
+ ]
+ )
+ self.convs2 = nn.ModuleList(
+ [
+ nn.Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ stride=1,
+ dilation=1,
+ padding=self.get_padding(kernel_size, 1),
+ )
+ for _ in range(len(dilation))
+ ]
+ )
+
+ def get_padding(self, kernel_size, dilation=1):
+ return (kernel_size * dilation - dilation) // 2
+
+ def apply_weight_norm(self):
+ weight_norm = nn.utils.weight_norm
+ if hasattr(nn.utils.parametrizations, "weight_norm"):
+ weight_norm = nn.utils.parametrizations.weight_norm
+
+ for layer in self.convs1:
+ weight_norm(layer)
+ for layer in self.convs2:
+ weight_norm(layer)
+
+ def remove_weight_norm(self):
+ for layer in self.convs1:
+ nn.utils.remove_weight_norm(layer)
+ for layer in self.convs2:
+ nn.utils.remove_weight_norm(layer)
+
+ def forward(self, hidden_states):
+ for conv1, conv2 in zip(self.convs1, self.convs2):
+ residual = hidden_states
+ hidden_states = nn.functional.leaky_relu(hidden_states, self.leaky_relu_slope)
+ hidden_states = conv1(hidden_states)
+ hidden_states = nn.functional.leaky_relu(hidden_states, self.leaky_relu_slope)
+ hidden_states = conv2(hidden_states)
+ hidden_states = hidden_states + residual
+ return hidden_states
+
+
+@auto_docstring(
+ custom_intro="""
+ HiFi-GAN vocoder.
+ """
+)
+# Copied from transformers.models.speecht5.modeling_speecht5.SpeechT5HifiGan with SpeechT5->FastSpeech2Conformer
+class FastSpeech2ConformerHifiGan(PreTrainedModel):
+ config: FastSpeech2ConformerHifiGanConfig
+ main_input_name = "spectrogram"
+
+ def __init__(self, config: FastSpeech2ConformerHifiGanConfig):
+ super().__init__(config)
+ self.num_kernels = len(config.resblock_kernel_sizes)
+ self.num_upsamples = len(config.upsample_rates)
+ self.conv_pre = nn.Conv1d(
+ config.model_in_dim,
+ config.upsample_initial_channel,
+ kernel_size=7,
+ stride=1,
+ padding=3,
+ )
+
+ self.upsampler = nn.ModuleList()
+ for i, (upsample_rate, kernel_size) in enumerate(zip(config.upsample_rates, config.upsample_kernel_sizes)):
+ self.upsampler.append(
+ nn.ConvTranspose1d(
+ config.upsample_initial_channel // (2**i),
+ config.upsample_initial_channel // (2 ** (i + 1)),
+ kernel_size=kernel_size,
+ stride=upsample_rate,
+ padding=(kernel_size - upsample_rate) // 2,
+ )
+ )
+
+ self.resblocks = nn.ModuleList()
+ for i in range(len(self.upsampler)):
+ channels = config.upsample_initial_channel // (2 ** (i + 1))
+ for kernel_size, dilation in zip(config.resblock_kernel_sizes, config.resblock_dilation_sizes):
+ self.resblocks.append(HifiGanResidualBlock(channels, kernel_size, dilation, config.leaky_relu_slope))
+
+ self.conv_post = nn.Conv1d(channels, 1, kernel_size=7, stride=1, padding=3)
+
+ self.register_buffer("mean", torch.zeros(config.model_in_dim))
+ self.register_buffer("scale", torch.ones(config.model_in_dim))
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def _init_weights(self, module: nn.Module):
+ """Initialize the weights."""
+ if isinstance(module, (nn.Conv1d, nn.ConvTranspose1d)):
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.bias is not None:
+ module.bias.data.zero_()
+
+ def apply_weight_norm(self):
+ weight_norm = nn.utils.weight_norm
+ if hasattr(nn.utils.parametrizations, "weight_norm"):
+ weight_norm = nn.utils.parametrizations.weight_norm
+
+ weight_norm(self.conv_pre)
+ for layer in self.upsampler:
+ weight_norm(layer)
+ for layer in self.resblocks:
+ layer.apply_weight_norm()
+ weight_norm(self.conv_post)
+
+ def remove_weight_norm(self):
+ nn.utils.remove_weight_norm(self.conv_pre)
+ for layer in self.upsampler:
+ nn.utils.remove_weight_norm(layer)
+ for layer in self.resblocks:
+ layer.remove_weight_norm()
+ nn.utils.remove_weight_norm(self.conv_post)
+
+ @auto_docstring(
+ custom_intro="""
+ Converts a log-mel spectrogram into a speech waveform. Passing a batch of log-mel spectrograms returns a batch
+ of speech waveforms. Passing a single, un-batched log-mel spectrogram returns a single, un-batched speech
+ waveform.
+ """
+ )
+ def forward(self, spectrogram: torch.FloatTensor) -> torch.FloatTensor:
+ r"""
+ spectrogram (`torch.FloatTensor`):
+ Tensor containing the log-mel spectrograms. Can be batched and of shape `(batch_size, sequence_length,
+ config.model_in_dim)`, or un-batched and of shape `(sequence_length, config.model_in_dim)`.
+
+ Returns:
+ `torch.FloatTensor`: Tensor containing the speech waveform. If the input spectrogram is batched, will be of
+ shape `(batch_size, num_frames,)`. If un-batched, will be of shape `(num_frames,)`.
+ """
+ if self.config.normalize_before:
+ spectrogram = (spectrogram - self.mean) / self.scale
+
+ is_batched = spectrogram.dim() == 3
+ if not is_batched:
+ spectrogram = spectrogram.unsqueeze(0)
+
+ hidden_states = spectrogram.transpose(2, 1)
+
+ hidden_states = self.conv_pre(hidden_states)
+ for i in range(self.num_upsamples):
+ hidden_states = nn.functional.leaky_relu(hidden_states, self.config.leaky_relu_slope)
+ hidden_states = self.upsampler[i](hidden_states)
+
+ res_state = self.resblocks[i * self.num_kernels](hidden_states)
+ for j in range(1, self.num_kernels):
+ res_state += self.resblocks[i * self.num_kernels + j](hidden_states)
+ hidden_states = res_state / self.num_kernels
+
+ hidden_states = nn.functional.leaky_relu(hidden_states)
+ hidden_states = self.conv_post(hidden_states)
+ hidden_states = torch.tanh(hidden_states)
+
+ if not is_batched:
+ # remove batch dim and collapse tensor to 1-d audio waveform
+ waveform = hidden_states.squeeze(0).transpose(1, 0).view(-1)
+ else:
+ # remove seq-len dim since this collapses to 1
+ waveform = hidden_states.squeeze(1)
+
+ return waveform
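+
+
+# Illustrative sketch: as described in the docstring above, an un-batched log-mel spectrogram of
+# shape (sequence_length, config.model_in_dim) yields a 1-D waveform. The checkpoint name is the
+# one already used in this file's usage examples; the helper name and the random spectrogram are
+# for shape checking only.
+def _hifigan_unbatched_example():
+    import torch
+
+    vocoder = FastSpeech2ConformerHifiGan.from_pretrained("espnet/fastspeech2_conformer_hifigan")
+    spectrogram = torch.randn(200, vocoder.config.model_in_dim)
+    waveform = vocoder(spectrogram)
+    assert waveform.dim() == 1
+    return waveform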
+
+
+@auto_docstring(
+ custom_intro="""
+ The FastSpeech2ConformerModel with a FastSpeech2ConformerHifiGan vocoder head that performs text-to-speech (waveform).
+ """
+)
+class FastSpeech2ConformerWithHifiGan(PreTrainedModel):
+ config: FastSpeech2ConformerWithHifiGanConfig
+
+ def __init__(self, config: FastSpeech2ConformerWithHifiGanConfig):
+ super().__init__(config)
+
+ self.model = FastSpeech2ConformerModel(config.model_config)
+ self.vocoder = FastSpeech2ConformerHifiGan(config.vocoder_config)
+
+ self.config = config
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: torch.LongTensor,
+ attention_mask: Optional[torch.LongTensor] = None,
+ spectrogram_labels: Optional[torch.FloatTensor] = None,
+ duration_labels: Optional[torch.LongTensor] = None,
+ pitch_labels: Optional[torch.FloatTensor] = None,
+ energy_labels: Optional[torch.FloatTensor] = None,
+ speaker_ids: Optional[torch.LongTensor] = None,
+ lang_ids: Optional[torch.LongTensor] = None,
+ speaker_embedding: Optional[torch.FloatTensor] = None,
+ return_dict: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ ) -> Union[tuple, FastSpeech2ConformerModelOutput]:
+ r"""
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Input sequence of text vectors.
+ spectrogram_labels (`torch.FloatTensor` of shape `(batch_size, max_spectrogram_length, num_mel_bins)`, *optional*, defaults to `None`):
+ Batch of padded target features.
+ duration_labels (`torch.LongTensor` of shape `(batch_size, sequence_length + 1)`, *optional*, defaults to `None`):
+ Batch of padded durations.
+ pitch_labels (`torch.FloatTensor` of shape `(batch_size, sequence_length + 1, 1)`, *optional*, defaults to `None`):
+ Batch of padded token-averaged pitch.
+ energy_labels (`torch.FloatTensor` of shape `(batch_size, sequence_length + 1, 1)`, *optional*, defaults to `None`):
+ Batch of padded token-averaged energy.
+ speaker_ids (`torch.LongTensor` of shape `(batch_size, 1)`, *optional*, defaults to `None`):
+ Speaker ids used to condition features of speech output by the model.
+ lang_ids (`torch.LongTensor` of shape `(batch_size, 1)`, *optional*, defaults to `None`):
+ Language ids used to condition features of speech output by the model.
+ speaker_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`, *optional*, defaults to `None`):
+ Embedding containing conditioning signals for the features of the speech.
+
+ Example:
+
+ ```python
+ >>> from transformers import (
+ ... FastSpeech2ConformerTokenizer,
+ ... FastSpeech2ConformerWithHifiGan,
+ ... )
+
+ >>> tokenizer = FastSpeech2ConformerTokenizer.from_pretrained("espnet/fastspeech2_conformer")
+ >>> inputs = tokenizer("some text to convert to speech", return_tensors="pt")
+ >>> input_ids = inputs["input_ids"]
+
+ >>> model = FastSpeech2ConformerWithHifiGan.from_pretrained("espnet/fastspeech2_conformer_with_hifigan")
+ >>> output_dict = model(input_ids, return_dict=True)
+ >>> waveform = output_dict["waveform"]
+ >>> print(waveform.shape)
+ torch.Size([1, 49664])
+ ```
+ """
+ return_dict = return_dict if return_dict is not None else self.config.model_config.use_return_dict
+ output_attentions = (
+ output_attentions if output_attentions is not None else self.config.model_config.output_attentions
+ )
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.model_config.output_hidden_states
+ )
+
+ model_outputs = self.model(
+ input_ids,
+ attention_mask,
+ spectrogram_labels=spectrogram_labels,
+ duration_labels=duration_labels,
+ pitch_labels=pitch_labels,
+ energy_labels=energy_labels,
+ speaker_ids=speaker_ids,
+ lang_ids=lang_ids,
+ speaker_embedding=speaker_embedding,
+ return_dict=return_dict,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ )
+
+ if not return_dict:
+ has_missing_labels = (
+ spectrogram_labels is None or duration_labels is None or pitch_labels is None or energy_labels is None
+ )
+ if has_missing_labels:
+ spectrogram = model_outputs[0]
+ else:
+ spectrogram = model_outputs[1]
+ else:
+ spectrogram = model_outputs["spectrogram"]
+ waveform = self.vocoder(spectrogram)
+
+ if not return_dict:
+ return model_outputs + (waveform,)
+
+ return FastSpeech2ConformerWithHifiGanOutput(waveform=waveform, **model_outputs)
+
+
+__all__ = [
+ "FastSpeech2ConformerWithHifiGan",
+ "FastSpeech2ConformerHifiGan",
+ "FastSpeech2ConformerModel",
+ "FastSpeech2ConformerPreTrainedModel",
+]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/fsmt/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/fsmt/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f8f31762d681dbf3541d38c39fafdf5fa6b864d1
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/fsmt/__init__.py
@@ -0,0 +1,28 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_fsmt import *
+ from .modeling_fsmt import *
+ from .tokenization_fsmt import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/venv/lib/python3.13/site-packages/transformers/models/fsmt/configuration_fsmt.py b/venv/lib/python3.13/site-packages/transformers/models/fsmt/configuration_fsmt.py
new file mode 100644
index 0000000000000000000000000000000000000000..7aec2662293f8f1d6d17fa90e73367769ce95461
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/fsmt/configuration_fsmt.py
@@ -0,0 +1,225 @@
+# coding=utf-8
+# Copyright 2019-present, Facebook, Inc and the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""FSMT configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class DecoderConfig(PretrainedConfig):
+ r"""
+ Configuration class for FSMT's decoder-specific parameters. Note: this is a private helper class.
+ """
+
+ model_type = "fsmt_decoder"
+
+ def __init__(self, vocab_size=0, bos_token_id=0, is_encoder_decoder=True, **kwargs):
+ super().__init__(**kwargs)
+ self.vocab_size = vocab_size
+ self.bos_token_id = bos_token_id
+ self.is_encoder_decoder = is_encoder_decoder
+
+
+class FSMTConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`FSMTModel`]. It is used to instantiate a FSMT
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+ defaults will yield a similar configuration to that of the FSMT
+ [facebook/wmt19-en-ru](https://huggingface.co/facebook/wmt19-en-ru) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ langs (`list[str]`):
+ A list with source language and target_language (e.g., ['en', 'ru']).
+ src_vocab_size (`int`):
+ Vocabulary size of the encoder. Defines the number of different tokens that can be represented by the
+ `inputs_ids` passed to the forward method in the encoder.
+ tgt_vocab_size (`int`):
+ Vocabulary size of the decoder. Defines the number of different tokens that can be represented by the
+ `inputs_ids` passed to the forward method in the decoder.
+ d_model (`int`, *optional*, defaults to 1024):
+ Dimensionality of the layers and the pooler layer.
+ encoder_layers (`int`, *optional*, defaults to 12):
+ Number of encoder layers.
+ decoder_layers (`int`, *optional*, defaults to 12):
+ Number of decoder layers.
+ encoder_attention_heads (`int`, *optional*, defaults to 16):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ decoder_attention_heads (`int`, *optional*, defaults to 16):
+ Number of attention heads for each attention layer in the Transformer decoder.
+ decoder_ffn_dim (`int`, *optional*, defaults to 4096):
+ Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
+ encoder_ffn_dim (`int`, *optional*, defaults to 4096):
+ Dimensionality of the "intermediate" (often named feed-forward) layer in encoder.
+ activation_function (`str` or `Callable`, *optional*, defaults to `"relu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
+ dropout (`float`, *optional*, defaults to 0.1):
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ activation_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for activations inside the fully connected layer.
+ max_position_embeddings (`int`, *optional*, defaults to 1024):
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
+ just in case (e.g., 512 or 1024 or 2048).
+ init_std (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ scale_embedding (`bool`, *optional*, defaults to `True`):
+ Scale embeddings by sqrt(d_model).
+ bos_token_id (`int`, *optional*, defaults to 0):
+ Beginning of stream token id.
+ pad_token_id (`int`, *optional*, defaults to 1):
+ Padding token id.
+ eos_token_id (`int`, *optional*, defaults to 2):
+ End of stream token id.
+ decoder_start_token_id (`int`, *optional*):
+ This model starts decoding with `eos_token_id`.
+ encoder_layerdrop (`float`, *optional*, defaults to 0.0):
+ The LayerDrop probability for the encoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556) for more details.
+ decoder_layerdrop (`float`, *optional*, defaults to 0.0):
+ The LayerDrop probability for the decoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556) for more details.
+ is_encoder_decoder (`bool`, *optional*, defaults to `True`):
+ Whether this is an encoder/decoder model.
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether to tie input and output embeddings.
+ num_beams (`int`, *optional*, defaults to 5):
+ Number of beams for beam search that will be used by default in the `generate` method of the model. 1 means
+ no beam search.
+ length_penalty (`float`, *optional*, defaults to 1.0):
+ Exponential penalty to the length that is used with beam-based generation. It is applied as an exponent to
+ the sequence length, which in turn is used to divide the score of the sequence. Since the score is the log
+ likelihood of the sequence (i.e. negative), `length_penalty` > 0.0 promotes longer sequences, while
+ `length_penalty` < 0.0 encourages shorter sequences.
+ early_stopping (`bool`, *optional*, defaults to `False`):
+ Flag that will be used by default in the `generate` method of the model. Whether to stop the beam search
+ when at least `num_beams` sentences are finished per batch or not.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models).
+ forced_eos_token_id (`int`, *optional*, defaults to 2):
+ The id of the token to force as the last generated token when `max_length` is reached. Usually set to
+ `eos_token_id`.
+
+ Examples:
+
+ ```python
+ >>> from transformers import FSMTConfig, FSMTModel
+
+ >>> # Initializing a FSMT facebook/wmt19-en-ru style configuration
+ >>> config = FSMTConfig()
+
+ >>> # Initializing a model (with random weights) from the configuration
+ >>> model = FSMTModel(config)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "fsmt"
+ attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"}
+ sub_configs = {"decoder": DecoderConfig}
+
+ # update the defaults from config file
+ def __init__(
+ self,
+ langs=["en", "de"],
+ src_vocab_size=42024,
+ tgt_vocab_size=42024,
+ activation_function="relu",
+ d_model=1024,
+ max_length=200,
+ max_position_embeddings=1024,
+ encoder_ffn_dim=4096,
+ encoder_layers=12,
+ encoder_attention_heads=16,
+ encoder_layerdrop=0.0,
+ decoder_ffn_dim=4096,
+ decoder_layers=12,
+ decoder_attention_heads=16,
+ decoder_layerdrop=0.0,
+ attention_dropout=0.0,
+ dropout=0.1,
+ activation_dropout=0.0,
+ init_std=0.02,
+ decoder_start_token_id=2,
+ is_encoder_decoder=True,
+ scale_embedding=True,
+ tie_word_embeddings=False,
+ num_beams=5,
+ length_penalty=1.0,
+ early_stopping=False,
+ use_cache=True,
+ pad_token_id=1,
+ bos_token_id=0,
+ eos_token_id=2,
+ forced_eos_token_id=2,
+ **common_kwargs,
+ ):
+ self.langs = langs
+ self.src_vocab_size = src_vocab_size
+ self.tgt_vocab_size = tgt_vocab_size
+ self.d_model = d_model # encoder_embed_dim and decoder_embed_dim
+
+ self.encoder_ffn_dim = encoder_ffn_dim
+ self.encoder_layers = self.num_hidden_layers = encoder_layers
+ self.encoder_attention_heads = encoder_attention_heads
+ self.encoder_layerdrop = encoder_layerdrop
+ self.decoder_layerdrop = decoder_layerdrop
+ self.decoder_ffn_dim = decoder_ffn_dim
+ self.decoder_layers = decoder_layers
+ self.decoder_attention_heads = decoder_attention_heads
+ self.max_position_embeddings = max_position_embeddings
+ self.init_std = init_std # Normal(0, this parameter)
+ self.activation_function = activation_function
+
+ self.decoder = DecoderConfig(
+ vocab_size=tgt_vocab_size,
+ bos_token_id=eos_token_id,
+ is_encoder_decoder=is_encoder_decoder,
+ num_hidden_layers=encoder_layers,
+ )
+ if "decoder" in common_kwargs:
+ del common_kwargs["decoder"]
+
+ self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
+
+ # 3 Types of Dropout
+ self.attention_dropout = attention_dropout
+ self.activation_dropout = activation_dropout
+ self.dropout = dropout
+
+ self.use_cache = use_cache
+ super().__init__(
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ decoder_start_token_id=decoder_start_token_id,
+ is_encoder_decoder=is_encoder_decoder,
+ tie_word_embeddings=tie_word_embeddings,
+ forced_eos_token_id=forced_eos_token_id,
+ max_length=max_length,
+ num_beams=num_beams,
+ length_penalty=length_penalty,
+ early_stopping=early_stopping,
+ **common_kwargs,
+ )
+
+
+__all__ = ["FSMTConfig"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/fsmt/modeling_fsmt.py b/venv/lib/python3.13/site-packages/transformers/models/fsmt/modeling_fsmt.py
new file mode 100644
index 0000000000000000000000000000000000000000..85618847dbf7676400fe190fd1a54bab53c87a6a
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/fsmt/modeling_fsmt.py
@@ -0,0 +1,1257 @@
+# coding=utf-8
+# Copyright 2020 The Facebook AI Research Team Authors and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Original implementation: https://github.com/pytorch/fairseq/tree/master/examples/wmt19
+# Authors:
+# - @alexeib Alexei Baevski
+# - @edunov Sergey Edunov
+# - @michaelauli Michael Auli
+# - @myleott Myle Ott
+# - @nng555 Nathan Ng
+# - David Grangier
+# - Kyra Yee
+#
+# Paper: Facebook FAIR's WMT19 News Translation Task Submission https://huggingface.co/papers/1907.06616
+#
+"""PyTorch Fairseq model, ported from https://github.com/pytorch/fairseq/tree/master/examples/wmt19"""
+
+import math
+from typing import Any, Optional, Union
+
+import torch
+from torch import Tensor, nn
+from torch.nn import CrossEntropyLoss, LayerNorm
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
+from ...generation import GenerationMixin
+from ...integrations.deepspeed import is_deepspeed_zero3_enabled
+from ...modeling_outputs import (
+ BaseModelOutput,
+ BaseModelOutputWithPastAndCrossAttentions,
+ Seq2SeqLMOutput,
+ Seq2SeqModelOutput,
+)
+from ...modeling_utils import PreTrainedModel
+from ...utils import auto_docstring, logging
+from .configuration_fsmt import FSMTConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+# See all FSMT models at https://huggingface.co/models?filter=fsmt
+
+# Porting notes:
+# this one is modeled after BartModel*
+#
+# Currently only translation (fairseq also has weights for LM)
+#
+# fairseq provides weights for ru-en, en-ru and de-en, en-de pairs. All have been ported.
+# - ru-en, en-ru use asymmetric vocab
+# - de-en, en-de use a merged single vocab (but the code works as if they are separate)
+#
+# Differences with Bart:
+# - not using bos token
+# - 2 separate vocabs (src and target)
+# - embed weights aren't tied
+# - uses a model Ensemble (but that part isn't ported/implemented yet) - so we
+# aren't getting as good of a BLEU score
+# - uses a projection layer at the end of the decoder
+# - doesn't use final_logits_bias
+# - beam search: fairseq stops as soon as num_beams == len(hypos), whereas
+#   transformers keeps searching until the remaining candidates can no longer
+#   promise something better. Comparing BLEU scores, the transformers algorithm
+#   is slightly superior, therefore the latter is used here. But if you want to
+#   match fairseq outputs, you need to pass ``early_stopping=True`` to ``generate()``
+#   (see the illustrative snippet after the BLEU comparison notes below).
+#
+# SinusoidalPositionalEmbedding is slightly different from Bart's - generates
+# different embeddings. This implementation is copied verbatim from fairseq with
+# some small changes to make it work here.
+#
+# Other changes:
+# - doesn't support use_cache as Bart's version does
+#
+#
+# FSMTConfig changes compared to BartConfig
+#
+# Differences with BART:
+# - src/tgt vocabs aren't shared
+# - token embeddings aren't shared
+# - needs a language pair
+# - scale_embedding is True
+#
+# some unused args were removed too
+#
+#
+# TODO:
+# - port model ensemble (fs uses 4 model checkpoints)
+# - solve beam search discrepancies
+# docstyle-ignore
+
+"""
+
+Here is how to compare BLEU scores against fairseq implementation:
+(don't forget to install sacrebleu: `pip install sacrebleu`)
+
+# en-ru
+
+export PAIR=en-ru
+export DATA_DIR=data/$PAIR
+export SAVE_DIR=data/$PAIR
+export BS=8
+export NUM_BEAMS=50
+mkdir -p $DATA_DIR
+sacrebleu -t wmt19 -l $PAIR --echo src > $DATA_DIR/val.source
+sacrebleu -t wmt19 -l $PAIR --echo ref > $DATA_DIR/val.target
+echo $PAIR
+PYTHONPATH="src:examples/seq2seq" python examples/seq2seq/run_eval.py facebook/wmt19-$PAIR $DATA_DIR/val.source $SAVE_DIR/test_translations.txt --reference_path $DATA_DIR/val.target --score_path $SAVE_DIR/test_bleu.json --bs $BS --task translation --num_beams $NUM_BEAMS
+
+# (fairseq BLEU: 36.4 http://matrix.statmt.org/matrix/output/1914?score_id=37605)
+
+
+# ru-en
+
+export PAIR=ru-en
+export DATA_DIR=data/$PAIR
+export SAVE_DIR=data/$PAIR
+export BS=8
+export NUM_BEAMS=50
+mkdir -p $DATA_DIR
+sacrebleu -t wmt19 -l $PAIR --echo src > $DATA_DIR/val.source
+sacrebleu -t wmt19 -l $PAIR --echo ref > $DATA_DIR/val.target
+PYTHONPATH="src:examples/seq2seq" python examples/seq2seq/run_eval.py facebook/wmt19-$PAIR $DATA_DIR/val.source $SAVE_DIR/test_translations.txt --reference_path $DATA_DIR/val.target --score_path $SAVE_DIR/test_bleu.json --bs $BS --task translation --num_beams $NUM_BEAMS
+
+
+# (fairseq BLEU: 41.3 http://matrix.statmt.org/matrix/output/1907?run_id=6937)
+
+
+# de-en
+
+export PAIR=de-en
+export DATA_DIR=data/$PAIR
+export SAVE_DIR=data/$PAIR
+export BS=8
+export NUM_BEAMS=50
+mkdir -p $DATA_DIR
+sacrebleu -t wmt19 -l $PAIR --echo src > $DATA_DIR/val.source
+sacrebleu -t wmt19 -l $PAIR --echo ref > $DATA_DIR/val.target
+echo $PAIR
+PYTHONPATH="src:examples/seq2seq" python examples/seq2seq/run_eval.py facebook/wmt19-$PAIR $DATA_DIR/val.source $SAVE_DIR/test_translations.txt --reference_path $DATA_DIR/val.target --score_path $SAVE_DIR/test_bleu.json --bs $BS --task translation --num_beams $NUM_BEAMS
+
+# (fairseq BLEU: 42.3 http://matrix.statmt.org/matrix/output/1902?run_id=6750)
+
+
+
+# en-de
+
+export PAIR=en-de
+export DATA_DIR=data/$PAIR
+export SAVE_DIR=data/$PAIR
+export BS=8
+export NUM_BEAMS=50
+mkdir -p $DATA_DIR
+sacrebleu -t wmt19 -l $PAIR --echo src > $DATA_DIR/val.source
+sacrebleu -t wmt19 -l $PAIR --echo ref > $DATA_DIR/val.target
+echo $PAIR
+PYTHONPATH="src:examples/seq2seq" python examples/seq2seq/run_eval.py facebook/wmt19-$PAIR $DATA_DIR/val.source $SAVE_DIR/test_translations.txt --reference_path $DATA_DIR/val.target --score_path $SAVE_DIR/test_bleu.json --bs $BS --task translation --num_beams $NUM_BEAMS
+
+# (fairseq BLEU: 43.1 http://matrix.statmt.org/matrix/output/1909?run_id=6862)
+
+"""
+
+
+def invert_mask(attention_mask):
+ """Turns 1->0, 0->1, False->True, True-> False"""
+ assert attention_mask.dim() == 2
+ return attention_mask.eq(0)
+
+
+def triu_onnx(x, diagonal=0):
+ l = x.shape[0]
+ arange = torch.arange(l, device=x.device)
+ mask = arange.expand(l, l)
+ arange = arange.unsqueeze(-1)
+ if diagonal:
+ arange = arange + diagonal
+ mask = mask >= arange
+ return x.masked_fill(mask == 0, 0)
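+
+# Illustrative example: applying triu_onnx with diagonal=1 to a 3x3 matrix filled with -inf
+# keeps only the strictly upper triangle:
+#   [[0., -inf, -inf],
+#    [0.,   0., -inf],
+#    [0.,   0.,   0.]]
+# This is the causal decoder mask built by `_prepare_fsmt_decoder_inputs` below.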
+
+
+def _prepare_fsmt_decoder_inputs(
+ config,
+ input_ids,
+ decoder_input_ids=None,
+ decoder_padding_mask=None,
+ causal_mask_dtype=torch.float32,
+):
+ """
+ Prepare masks that ignore padding tokens in the decoder and a causal mask for the decoder if none are provided.
+ This mimics the default behavior in fairseq. To override it pass in masks. Note: this is not called during
+ generation
+ """
+ pad_token_id = config.pad_token_id
+ if decoder_input_ids is None:
+ decoder_input_ids = shift_tokens_right(input_ids, pad_token_id)
+ bsz, tgt_len = decoder_input_ids.size()
+ if decoder_padding_mask is None:
+ decoder_padding_mask = make_padding_mask(decoder_input_ids, pad_token_id)
+ else:
+ decoder_padding_mask = invert_mask(decoder_padding_mask)
+ causal_mask = triu_onnx(fill_with_neg_inf(torch.zeros(tgt_len, tgt_len, dtype=causal_mask_dtype)), 1).to(
+ device=decoder_input_ids.device
+ )
+ return decoder_input_ids, decoder_padding_mask, causal_mask
+
+
+@auto_docstring
+class PretrainedFSMTModel(PreTrainedModel):
+ config: FSMTConfig
+ base_model_prefix = "model"
+
+ def _init_weights(self, module):
+ std = self.config.init_std
+ if isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, SinusoidalPositionalEmbedding):
+ weight = module.get_embedding(*module.weight.shape, module.padding_idx)
+ weight = nn.Parameter(weight, requires_grad=False)
+ weight.detach_()
+ module.weight = weight
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+
+ @property
+ def dummy_inputs(self):
+ pad_token = self.config.pad_token_id
+ input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device)
+ dummy_inputs = {
+ "attention_mask": input_ids.ne(pad_token),
+ "input_ids": input_ids,
+ }
+ return dummy_inputs
+
+
+def _make_linear_from_emb(emb):
+ vocab_size, emb_size = emb.weight.shape
+ lin_layer = nn.Linear(vocab_size, emb_size, bias=False)
+ lin_layer.weight.data = emb.weight.data
+ return lin_layer
+
+
+# Helper Functions, mostly for making masks
+def _check_shapes(shape_1, shape2):
+ if shape_1 != shape2:
+ raise AssertionError(f"shape mismatch: {shape_1} != {shape2}")
+
+
+def shift_tokens_right(input_ids, pad_token_id):
+ """Shift input ids one token to the right, and wrap the last non pad token (usually )."""
+
+ # replace possible -100 values in labels by `pad_token_id`
+ input_ids.masked_fill_(input_ids == -100, pad_token_id)
+
+ prev_output_tokens = input_ids.clone()
+ index_of_eos = (input_ids.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1)
+ prev_output_tokens[:, 0] = input_ids.gather(1, index_of_eos).squeeze()
+ prev_output_tokens[:, 1:] = input_ids[:, :-1]
+ return prev_output_tokens
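+
+# Illustrative example (hypothetical ids, pad_token_id=1, eos token id 2):
+#   input_ids          -> [[5, 6, 2, 1]]
+#   shift_tokens_right -> [[2, 5, 6, 2]]
+# i.e. the final eos is wrapped around to the first decoder position.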
+
+
+def make_padding_mask(input_ids, padding_idx=1):
+ """True for pad tokens"""
+ padding_mask = input_ids.eq(padding_idx)
+ if not padding_mask.any():
+ padding_mask = None
+ return padding_mask
+
+
+# Helper Modules
+
+
+class EncoderLayer(nn.Module):
+ def __init__(self, config: FSMTConfig):
+ super().__init__()
+ self.embed_dim = config.d_model
+ self.self_attn = Attention(self.embed_dim, config.encoder_attention_heads, dropout=config.attention_dropout)
+ self.self_attn_layer_norm = LayerNorm(self.embed_dim)
+ self.dropout = config.dropout
+ self.activation_fn = ACT2FN[config.activation_function]
+ self.activation_dropout = config.activation_dropout
+ self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
+ self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
+ self.final_layer_norm = LayerNorm(self.embed_dim)
+
+ def forward(self, x, encoder_padding_mask, layer_head_mask, output_attentions=False):
+ """
+ Args:
+ x (`torch.Tensor`): input to the layer of shape *(seq_len, batch, embed_dim)*
+ encoder_padding_mask (`torch.ByteTensor`): binary ByteTensor of shape
+ *(batch, src_len)* where padding elements are indicated by `1`.
+ Positions marked `1` are excluded (masked out) from attention; positions
+ marked `0` are included in attention.
+ layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
+ *(config.encoder_attention_heads,)*.
+
+ Returns:
+ encoded output of shape *(seq_len, batch, embed_dim)*
+ """
+ residual = x
+ x, attn_weights = self.self_attn(
+ query=x,
+ key=x,
+ key_padding_mask=encoder_padding_mask,
+ layer_head_mask=layer_head_mask,
+ output_attentions=output_attentions,
+ )
+ x = nn.functional.dropout(x, p=self.dropout, training=self.training)
+ x = residual + x
+ x = self.self_attn_layer_norm(x)
+
+ residual = x
+ x = self.activation_fn(self.fc1(x))
+ x = nn.functional.dropout(x, p=self.activation_dropout, training=self.training)
+ x = self.fc2(x)
+ x = nn.functional.dropout(x, p=self.dropout, training=self.training)
+ x = residual + x
+ x = self.final_layer_norm(x)
+ return x, attn_weights
+
+
+class FSMTEncoder(nn.Module):
+ """
+ Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a [`EncoderLayer`].
+
+ Args:
+ config: FSMTConfig
+ """
+
+ def __init__(self, config: FSMTConfig, embed_tokens):
+ super().__init__()
+ self.dropout = config.dropout
+ self.layerdrop = config.encoder_layerdrop
+ self.padding_idx = embed_tokens.padding_idx
+ self.embed_tokens = embed_tokens
+ embed_dim = embed_tokens.embedding_dim
+ self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
+ self.embed_positions = SinusoidalPositionalEmbedding(
+ config.max_position_embeddings + self.padding_idx + 1, embed_dim, self.padding_idx
+ )
+ self.layers = nn.ModuleList([EncoderLayer(config) for _ in range(config.encoder_layers)]) # type: list[EncoderLayer]
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ return_dict: bool = True,
+ ):
+ """
+ Args:
+ input_ids (`torch.LongTensor`): tokens in the source language of shape
+ *(batch, src_len)*
+ attention_mask (`torch.LongTensor`): indicating which indices are padding tokens
+ inputs_embeds (`torch.FloatTensor`):
+ embedding vectors of shape *(batch, src_len, embed_dim)*
+ head_mask (`torch.Tensor` of shape `(num_layers, num_heads)`, *optional*):
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ Returns:
+ BaseModelOutput or Tuple comprised of:
+
+ - **x** (`torch.Tensor`): the last encoder layer's output of shape *(src_len, batch, embed_dim)*
+ - **encoder_states** (`Tuple(torch.FloatTensor)`): all intermediate hidden states of shape *(src_len,
+ batch, embed_dim)*. Only populated if *output_hidden_states* is True.
+ - **all_attentions** (`Tuple(torch.FloatTensor)`): Attention weights for each layer.
+ During training might not be of length n_layers because of layer dropout.
+ """
+ # check attention mask and invert
+ if attention_mask is not None:
+ attention_mask = invert_mask(attention_mask)
+
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
+ embed_pos = self.embed_positions(input_ids)
+ elif inputs_embeds is not None:
+ inputs_embeds = inputs_embeds * self.embed_scale
+
+ # We assume zeros hidden states correspond to padding tokens
+ # and create `position_ids` where inputs_embeds[:, :, 0] == 0
+ position_ids = inputs_embeds[:, :, 0].masked_fill(
+ inputs_embeds[:, :, 0].eq(0), self.embed_positions.padding_idx
+ )
+
+ embed_pos = self.embed_positions(position_ids)
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ x = inputs_embeds + embed_pos
+ x = nn.functional.dropout(x, p=self.dropout, training=self.training)
+
+ # B x T x C -> T x B x C
+ x = x.transpose(0, 1)
+
+ encoder_states = () if output_hidden_states else None
+ all_attentions = () if output_attentions else None
+ # check if head_mask has a correct number of layers specified if desired
+ if head_mask is not None:
+ assert head_mask.size()[0] == (len(self.layers)), (
+ f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
+ )
+ for idx, encoder_layer in enumerate(self.layers):
+ if output_hidden_states:
+ x = x.transpose(0, 1) # T x B x C -> B x T x C
+ encoder_states += (x,)
+ x = x.transpose(0, 1) # B x T x C -> T x B x C
+ # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
+ dropout_probability = torch.rand([])
+ if self.training and (dropout_probability < self.layerdrop): # skip the layer
+ attn = None
+ else:
+ x, attn = encoder_layer(
+ x,
+ attention_mask,
+ layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+ output_attentions=output_attentions,
+ )
+
+ if output_attentions:
+ all_attentions = all_attentions + (attn,)
+
+ # T x B x C -> B x T x C
+ x = x.transpose(0, 1)
+
+ if output_hidden_states:
+ encoder_states += (x,)
+
+ if not return_dict:
+ return tuple(v for v in [x, encoder_states, all_attentions] if v is not None)
+ return BaseModelOutput(last_hidden_state=x, hidden_states=encoder_states, attentions=all_attentions)
+
+
+class DecoderLayer(nn.Module):
+ def __init__(self, config: FSMTConfig, layer_idx=None):
+ super().__init__()
+ self.embed_dim = config.d_model
+
+ self.self_attn = Attention(
+ embed_dim=self.embed_dim,
+ num_heads=config.decoder_attention_heads,
+ dropout=config.attention_dropout,
+ layer_idx=layer_idx,
+ )
+ self.dropout = config.dropout
+ self.activation_fn = ACT2FN[config.activation_function]
+ self.activation_dropout = config.activation_dropout
+
+ self.self_attn_layer_norm = LayerNorm(self.embed_dim)
+ self.encoder_attn = Attention(
+ self.embed_dim,
+ config.decoder_attention_heads,
+ dropout=config.attention_dropout,
+ encoder_decoder_attention=True,
+ layer_idx=layer_idx,
+ )
+ self.encoder_attn_layer_norm = LayerNorm(self.embed_dim)
+ self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
+ self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
+ self.final_layer_norm = LayerNorm(self.embed_dim)
+
+ def forward(
+ self,
+ x,
+ encoder_hidden_states,
+ encoder_attn_mask=None,
+ layer_state=None,
+ causal_mask=None,
+ layer_head_mask=None,
+ cross_attn_layer_head_mask=None,
+ decoder_padding_mask=None,
+ output_attentions=False,
+ cache_position=None,
+ ):
+ residual = x
+
+ # Self Attention
+ x, self_attn_weights = self.self_attn(
+ query=x,
+ key=x,
+ layer_state=layer_state, # adds keys to layer state
+ key_padding_mask=decoder_padding_mask,
+ attn_mask=causal_mask,
+ layer_head_mask=layer_head_mask,
+ output_attentions=output_attentions,
+ cache_position=cache_position,
+ )
+ x = nn.functional.dropout(x, p=self.dropout, training=self.training)
+ x = residual + x
+ x = self.self_attn_layer_norm(x)
+
+ # Cross attention
+ residual = x
+ assert self.encoder_attn.cache_key != self.self_attn.cache_key
+ x, cross_attn_weights = self.encoder_attn(
+ query=x,
+ key=encoder_hidden_states,
+ key_padding_mask=encoder_attn_mask,
+ layer_state=layer_state, # mutates layer state
+ layer_head_mask=cross_attn_layer_head_mask,
+ output_attentions=output_attentions,
+ cache_position=cache_position,
+ )
+ x = nn.functional.dropout(x, p=self.dropout, training=self.training)
+ x = residual + x
+ x = self.encoder_attn_layer_norm(x)
+
+ # Fully Connected
+ residual = x
+ x = self.activation_fn(self.fc1(x))
+ x = nn.functional.dropout(x, p=self.activation_dropout, training=self.training)
+ x = self.fc2(x)
+ x = nn.functional.dropout(x, p=self.dropout, training=self.training)
+ x = residual + x
+ x = self.final_layer_norm(x)
+ return (
+ x,
+ self_attn_weights,
+ cross_attn_weights,
+ )
+
+
+class FSMTDecoder(nn.Module):
+ """
+ Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`DecoderLayer`]
+
+ Args:
+ config: FSMTConfig
+ embed_tokens (nn.Embedding): output embedding
+ """
+
+ def __init__(self, config: FSMTConfig, embed_tokens: nn.Embedding):
+ super().__init__()
+ self.dropout = config.dropout
+ self.layerdrop = config.decoder_layerdrop
+ self.padding_idx = embed_tokens.padding_idx
+ self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
+ self.embed_tokens = embed_tokens
+ embed_dim = embed_tokens.embedding_dim
+ self.embed_positions = SinusoidalPositionalEmbedding(
+ config.max_position_embeddings + self.padding_idx + 1, embed_dim, self.padding_idx
+ )
+ self.layers = nn.ModuleList([DecoderLayer(config, layer_idx=i) for i in range(config.decoder_layers)]) # type: list[DecoderLayer]
+
+ if is_deepspeed_zero3_enabled():
+ import deepspeed
+
+ with deepspeed.zero.GatheredParameters(self.embed_tokens.weight, modifier_rank=None):
+ embed_tokens_weight_shape = self.embed_tokens.weight.shape
+ else:
+ embed_tokens_weight_shape = self.embed_tokens.weight.shape
+ self.output_projection = nn.Linear(embed_tokens_weight_shape[1], embed_tokens_weight_shape[0], bias=False)
+ self.output_projection.weight = self.embed_tokens.weight
+
+ def _tie_weights(self):
+ self.embed_tokens.weight = self.output_projection.weight
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ encoder_padding_mask: torch.Tensor,
+ decoder_padding_mask: torch.Tensor,
+ decoder_causal_mask: torch.Tensor,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[Cache] = None,
+ use_cache: Optional[bool] = False,
+ output_attentions: Optional[bool] = False,
+ output_hidden_states: Optional[bool] = False,
+ return_dict: Optional[bool] = True,
+ cache_position: Optional[torch.Tensor] = None,
+ ):
+ """
+ Includes several features from "Jointly Learning to Align and Translate with Transformer Models" (Garg et al.,
+ EMNLP 2019).
+
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch, tgt_len)`):
+ previous decoder outputs for teacher forcing
+ encoder_hidden_states: output from the encoder, used for
+ encoder-side attention
+ encoder_padding_mask: for ignoring pad tokens
+ past_key_values (dict or None): dictionary used for storing state during generation
+ head_mask (`torch.Tensor` of shape `(num_layers, num_heads)`, *optional*):
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ cross_attn_head_mask (`torch.Tensor` of shape `(num_layers, num_heads)`, *optional*):
+ Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ Returns:
+ BaseModelOutputWithPast or tuple:
+
+ - the decoder's features of shape *(batch, tgt_len, embed_dim)*
+ - the cache
+ - hidden states
+ - attentions
+ """
+ # check attention mask and invert
+ if encoder_padding_mask is not None:
+ encoder_padding_mask = invert_mask(encoder_padding_mask)
+
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+ elif input_ids is not None:
+ # embed positions
+ positions = self.embed_positions(input_ids)
+ if use_cache:
+ input_ids = input_ids[:, -1:]
+ positions = positions[:, -1:] # happens after we embed them
+ x = self.embed_tokens(input_ids) * self.embed_scale
+ elif inputs_embeds is not None:
+ # We assume zeros hidden states correspond to padding tokens
+ # and create `position_ids` where inputs_embeds[:, :, 0] == 0
+ position_ids = inputs_embeds[:, :, 0].masked_fill(
+ inputs_embeds[:, :, 0].eq(0), self.embed_positions.padding_idx
+ )
+ positions = self.embed_positions(position_ids)
+ x = inputs_embeds * self.embed_scale
+ else:
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+
+ # initialize `past_key_values`
+ if use_cache and past_key_values is None:
+ past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
+ if use_cache and isinstance(past_key_values, tuple):
+ logger.warning_once(
+ "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
+ "You should pass an instance of `EncoderDecoderCache` instead, e.g. "
+ "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`."
+ )
+ past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)
+
+ x += positions
+ x = nn.functional.dropout(x, p=self.dropout, training=self.training)
+
+ # Convert to FSMT output format: (BS, seq_len, model_dim) -> (seq_len, BS, model_dim)
+ x = x.transpose(0, 1)
+ encoder_hidden_states = encoder_hidden_states.transpose(0, 1)
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ all_cross_attns = () if output_attentions else None
+
+ # check if head_mask has a correct number of layers specified if desired
+ for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
+ if attn_mask is not None:
+ assert attn_mask.size()[0] == (len(self.layers)), (
+ f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
+ f" {head_mask.size()[0]}."
+ )
+ for idx, decoder_layer in enumerate(self.layers):
+ # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
+ if output_hidden_states:
+ x = x.transpose(0, 1)
+ all_hidden_states += (x,)
+ x = x.transpose(0, 1)
+ if self.training:
+ dropout_probability = torch.rand([])
+ if dropout_probability < self.layerdrop:
+ continue
+
+ x, layer_self_attn, layer_cross_attn = decoder_layer(
+ x,
+ encoder_hidden_states,
+ encoder_attn_mask=encoder_padding_mask,
+ decoder_padding_mask=decoder_padding_mask,
+ layer_state=past_key_values,
+ causal_mask=decoder_causal_mask,
+ layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+ cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
+ output_attentions=output_attentions,
+ cache_position=cache_position,
+ )
+
+ if output_attentions:
+ all_self_attns += (layer_self_attn,)
+ all_cross_attns += (layer_cross_attn,)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ x = x.transpose(0, 1)
+ all_hidden_states += (x,)
+ x = x.transpose(0, 1)
+
+ # Convert to standard output format: (seq_len, BS, model_dim) -> (BS, seq_len, model_dim)
+ x = x.transpose(0, 1)
+ encoder_hidden_states = encoder_hidden_states.transpose(0, 1)
+
+ x = self.output_projection(x)
+
+ if not return_dict:
+ return tuple(
+ v for v in [x, past_key_values, all_hidden_states, all_self_attns, all_cross_attns] if v is not None
+ )
+ return BaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=x,
+ past_key_values=past_key_values,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ cross_attentions=all_cross_attns,
+ )
+
+
+def _reorder_buffer(attn_cache, new_order):
+ for k, input_buffer_k in attn_cache.items():
+ if input_buffer_k is not None:
+ attn_cache[k] = input_buffer_k.index_select(0, new_order)
+ return attn_cache
+
+
+class Attention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(
+ self,
+ embed_dim,
+ num_heads,
+ dropout=0.0,
+ bias=True,
+ encoder_decoder_attention=False, # otherwise self_attention
+ layer_idx=None,
+ ):
+ super().__init__()
+ self.embed_dim = embed_dim
+ self.num_heads = num_heads
+ self.dropout = dropout
+ self.head_dim = embed_dim // num_heads
+ assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
+ self.scaling = self.head_dim**-0.5
+ self.layer_idx = layer_idx
+
+ self.encoder_decoder_attention = encoder_decoder_attention
+ self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+ self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+ self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+ self.cache_key = "encoder_decoder" if self.encoder_decoder_attention else "self"
+
+ def forward(
+ self,
+ query,
+ key: Optional[Tensor],
+ key_padding_mask: Optional[Tensor] = None,
+ layer_state: Optional[Cache] = None,
+ attn_mask: Optional[Tensor] = None,
+ layer_head_mask: Optional[Tensor] = None,
+ output_attentions: Optional[bool] = False,
+ cache_position: Optional[torch.Tensor] = None,
+ ) -> tuple[Tensor, Optional[Tensor]]:
+ """Input shape: Time(SeqLen) x Batch x Channel"""
+ tgt_len, bsz, embed_dim = query.size()
+ assert embed_dim == self.embed_dim
+ assert list(query.size()) == [tgt_len, bsz, embed_dim]
+
+ if layer_state is not None:
+ if isinstance(layer_state, EncoderDecoderCache):
+ is_updated = layer_state.is_updated.get(self.layer_idx)
+ if self.encoder_decoder_attention:
+ # after the first generated id, we can subsequently re-use all key/value_states from cache
+ curr_past_key_value = layer_state.cross_attention_cache
+ else:
+ curr_past_key_value = layer_state.self_attention_cache
+ else:
+ curr_past_key_value = layer_state
+
+ # NOTE: FSMT has format (seq_len, BS, model_dim) for inputs
+ current_states = key if self.encoder_decoder_attention else query
+ if self.encoder_decoder_attention and layer_state is not None and is_updated:
+ # reuse k,v, cross_attentions
+ key_states = curr_past_key_value.layers[self.layer_idx].keys
+ value_states = curr_past_key_value.layers[self.layer_idx].values
+ else:
+ key_states = self.k_proj(current_states)
+ value_states = self.v_proj(current_states)
+ key_states = key_states.view(-1, bsz, self.num_heads, self.head_dim).permute(1, 2, 0, 3)
+ value_states = value_states.view(-1, bsz, self.num_heads, self.head_dim).permute(1, 2, 0, 3)
+
+ if layer_state is not None:
+ # save all key/value_states to cache to be re-used for fast auto-regressive generation
+ cache_position = cache_position if not self.encoder_decoder_attention else None
+ key_states, value_states = curr_past_key_value.update(
+ key_states, value_states, self.layer_idx, {"cache_position": cache_position}
+ )
+ # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
+ if self.encoder_decoder_attention:
+ layer_state.is_updated[self.layer_idx] = True
+
+ query_states = self.q_proj(query) * self.scaling
+
+ # Reshape back to 3D tensors for `bmm`
+ query_states = query_states.view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
+ key_states = key_states.reshape(bsz * self.num_heads, -1, self.head_dim)
+ value_states = value_states.reshape(bsz * self.num_heads, -1, self.head_dim)
+
+ assert key_states is not None
+ src_len = key_states.size(1)
+ attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
+ assert attn_weights.size() == (bsz * self.num_heads, tgt_len, src_len)
+
+ if attn_mask is not None:
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_mask
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+ # This is part of a workaround to get around fork/join parallelism not supporting Optional types.
+ if key_padding_mask is not None and key_padding_mask.dim() == 0:
+ key_padding_mask = None
+ assert key_padding_mask is None or key_padding_mask.size()[:2] == (
+ bsz,
+ src_len,
+ )
+
+ if key_padding_mask is not None: # don't attend to padding symbols
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+ reshaped = key_padding_mask.unsqueeze(1).unsqueeze(2)
+ attn_weights = attn_weights.masked_fill(reshaped, torch.finfo(attn_weights.dtype).min)
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+ if layer_head_mask is not None:
+ assert layer_head_mask.size() == (self.num_heads,), (
+ f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
+ )
+ attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+ if output_attentions:
+ # make sure that attn_weights are included in graph
+ attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+ attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
+ else:
+ attn_weights_reshaped = None
+
+ attn_probs = nn.functional.dropout(
+ attn_weights,
+ p=self.dropout,
+ training=self.training,
+ )
+
+ assert value_states is not None
+ attn_output = torch.bmm(attn_probs, value_states)
+ assert attn_output.size() == (bsz * self.num_heads, tgt_len, self.head_dim)
+ attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
+ attn_output = self.out_proj(attn_output)
+
+ return attn_output, attn_weights_reshaped
+
+
+def fill_with_neg_inf(t):
+ """FP16-compatible function that fills a input_ids with -inf."""
+ return t.float().fill_(torch.finfo(t.dtype).min).type_as(t)
+
+
+# Public API
+def _get_shape(t):
+ return getattr(t, "shape", None)
+
+
+@auto_docstring
+class FSMTModel(PretrainedFSMTModel):
+ _tied_weights_keys = ["decoder.embed_tokens.weight", "decoder.output_projection.weight"]
+
+ def __init__(self, config: FSMTConfig):
+ super().__init__(config)
+
+ padding_idx = config.pad_token_id
+ encoder_embed_tokens = nn.Embedding(config.src_vocab_size, config.d_model, padding_idx)
+ decoder_embed_tokens = nn.Embedding(config.tgt_vocab_size, config.d_model, padding_idx)
+
+ self.encoder = FSMTEncoder(config, encoder_embed_tokens)
+ self.decoder = FSMTDecoder(config, decoder_embed_tokens)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_encoder(self):
+ return self.encoder
+
+ def _tie_weights(self):
+ if self.config.tie_word_embeddings:
+ self._tie_or_clone_weights(self.decoder.embed_tokens, self.get_input_embeddings())
+ self._tie_or_clone_weights(self.decoder.output_projection, self.get_input_embeddings())
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: torch.LongTensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ decoder_input_ids: Optional[torch.LongTensor] = None,
+ decoder_attention_mask: Optional[torch.BoolTensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ decoder_head_mask: Optional[torch.Tensor] = None,
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
+ encoder_outputs: Optional[tuple[torch.FloatTensor]] = None,
+ past_key_values: Optional[Cache] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.Tensor] = None,
+ ) -> Union[tuple[torch.Tensor], Seq2SeqModelOutput]:
+ r"""
+ decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+ Indices of decoder input sequence tokens in the vocabulary.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are decoder input IDs?](../glossary#decoder-input-ids)
+
+ FSMT uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
+ is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).
+ decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+ Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
+ be used by default.
+ cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in `[0,
+ 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+ """
+ if decoder_input_ids is None:
+ use_cache = False
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # make masks if user doesn't supply
+ if not use_cache and input_ids is not None:
+ decoder_input_ids, decoder_padding_mask, causal_mask = _prepare_fsmt_decoder_inputs(
+ self.config,
+ input_ids,
+ decoder_input_ids=decoder_input_ids,
+ decoder_padding_mask=decoder_attention_mask,
+ causal_mask_dtype=self.decoder.embed_tokens.weight.dtype,
+ )
+ else:
+ decoder_padding_mask, causal_mask = None, None
+
+ if decoder_input_ids is None and decoder_inputs_embeds is None:
+ raise ValueError("Make sure that `decoder_input_ids` or `decoder_inputs_embeds` are passed.")
+
+ if encoder_outputs is None:
+ encoder_outputs = self.encoder(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ inputs_embeds=inputs_embeds,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
+ elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
+ encoder_outputs = BaseModelOutput(
+ last_hidden_state=encoder_outputs[0],
+ hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
+ attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
+ )
+
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ decoder_outputs = self.decoder(
+ decoder_input_ids,
+ encoder_outputs[0],
+ attention_mask,
+ decoder_padding_mask,
+ decoder_causal_mask=causal_mask,
+ inputs_embeds=decoder_inputs_embeds,
+ head_mask=decoder_head_mask,
+ cross_attn_head_mask=cross_attn_head_mask,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ cache_position=cache_position,
+ )
+
+ if not return_dict:
+ return decoder_outputs + encoder_outputs
+
+ return Seq2SeqModelOutput(
+ last_hidden_state=decoder_outputs.last_hidden_state,
+ past_key_values=decoder_outputs.past_key_values,
+ decoder_hidden_states=decoder_outputs.hidden_states,
+ decoder_attentions=decoder_outputs.attentions,
+ cross_attentions=decoder_outputs.cross_attentions,
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
+ encoder_hidden_states=encoder_outputs.hidden_states,
+ encoder_attentions=encoder_outputs.attentions,
+ )
+
+ def get_input_embeddings(self):
+ return self.encoder.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.encoder.embed_tokens = value
+
+ def get_output_embeddings(self):
+ return self.decoder.embed_tokens
+
+ def set_output_embeddings(self, value):
+ self.decoder.embed_tokens = value
+
+
+@auto_docstring(
+ custom_intro="""
+ The FSMT Model with a language modeling head. Can be used for translation and other sequence-to-sequence tasks.
+ """
+)
+class FSMTForConditionalGeneration(PretrainedFSMTModel, GenerationMixin):
+ base_model_prefix = "model"
+ _tied_weights_keys = ["decoder.embed_tokens.weight", "decoder.output_projection.weight"]
+
+ def __init__(self, config: FSMTConfig):
+ super().__init__(config)
+ base_model = FSMTModel(config)
+ self.model = base_model
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ decoder_input_ids: Optional[torch.LongTensor] = None,
+ decoder_attention_mask: Optional[torch.BoolTensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ decoder_head_mask: Optional[torch.Tensor] = None,
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
+ encoder_outputs: Optional[tuple[torch.FloatTensor]] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ decoder_inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.Tensor] = None,
+ ) -> Union[tuple[torch.Tensor], Seq2SeqLMOutput]:
+ r"""
+ decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+ Indices of decoder input sequence tokens in the vocabulary.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are decoder input IDs?](../glossary#decoder-input-ids)
+
+ FSMT uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
+ is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).
+ decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+ Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
+ be used by default.
+ cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in `[0,
+ 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Example Translation:
+
+ ```python
+ >>> from transformers import AutoTokenizer, FSMTForConditionalGeneration
+
+ >>> mname = "facebook/wmt19-ru-en"
+ >>> model = FSMTForConditionalGeneration.from_pretrained(mname)
+ >>> tokenizer = AutoTokenizer.from_pretrained(mname)
+
+ >>> src_text = "Машинное обучение - это здорово, не так ли?"
+ >>> input_ids = tokenizer(src_text, return_tensors="pt").input_ids
+ >>> outputs = model.generate(input_ids, num_beams=5, num_return_sequences=3)
+ >>> tokenizer.decode(outputs[0], skip_special_tokens=True)
+ "Machine learning is great, isn't it?"
+ ```
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if labels is not None:
+ use_cache = False
+
+ outputs = self.model(
+ input_ids,
+ inputs_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ decoder_input_ids=decoder_input_ids,
+ decoder_inputs_embeds=decoder_inputs_embeds,
+ encoder_outputs=encoder_outputs,
+ decoder_attention_mask=decoder_attention_mask,
+ head_mask=head_mask,
+ decoder_head_mask=decoder_head_mask,
+ cross_attn_head_mask=cross_attn_head_mask,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ cache_position=cache_position,
+ )
+ lm_logits = outputs[0]
+
+ masked_lm_loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ # TODO(SS): do we need to ignore pad tokens in labels?
+ masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.tgt_vocab_size), labels.view(-1))
+
+ if not return_dict:
+ output = (lm_logits,) + outputs[1:]
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
+
+ return Seq2SeqLMOutput(
+ loss=masked_lm_loss,
+ logits=lm_logits,
+ past_key_values=outputs.past_key_values,
+ decoder_hidden_states=outputs.decoder_hidden_states,
+ decoder_attentions=outputs.decoder_attentions,
+ cross_attentions=outputs.cross_attentions,
+ encoder_last_hidden_state=outputs.encoder_last_hidden_state,
+ encoder_hidden_states=outputs.encoder_hidden_states,
+ encoder_attentions=outputs.encoder_attentions,
+ )
+
+ def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
+ return shift_tokens_right(labels, self.config.pad_token_id)
+
+ def get_encoder(self):
+ return self.model.encoder
+
+ def get_decoder(self):
+ return self.model.decoder
+
+ def get_output_embeddings(self):
+ return self.model.decoder.embed_tokens
+
+ def set_output_embeddings(self, value):
+ self.model.decoder.embed_tokens = value
+
+
+class SinusoidalPositionalEmbedding(nn.Embedding):
+ """
+ This module produces sinusoidal positional embeddings of any length.
+
+ We don't want to save the weight of this embedding since it's not trained (deterministic) and it can be huge.
+
+ Padding symbols are ignored.
+
+ These embeddings get automatically extended in forward if more positions are needed.
+ """
+
+ def __init__(self, num_positions, embedding_dim, padding_idx):
+ super().__init__(num_positions, embedding_dim, padding_idx)
+
+ def make_weight(self, num_positions, embedding_dim, padding_idx):
+ weight = self.get_embedding(num_positions, embedding_dim, padding_idx)
+ # in forward put the weights on the correct dtype and device of the param
+ weight = weight.to(dtype=self.weight.dtype, device=self.weight.device)
+ self.weight = nn.Parameter(weight)
+ self.weight.detach_()
+ self.weight.requires_grad = False
+
+ @staticmethod
+ def get_embedding(num_embeddings, embedding_dim, padding_idx):
+ """
+ Build sinusoidal embeddings.
+
+ This matches the implementation in tensor2tensor, but differs slightly from the description in Section 3.5 of
+ "Attention Is All You Need".
+ """
+ half_dim = embedding_dim // 2
+ emb = math.log(10000) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.int64).float() * -emb)
+ emb = torch.arange(num_embeddings, dtype=torch.int64).float().unsqueeze(1) * emb.unsqueeze(0)
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
+ if embedding_dim % 2 == 1:
+ # zero pad
+ emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
+ if padding_idx is not None:
+ emb[padding_idx, :] = 0
+ return emb
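+ # Illustrative example: get_embedding(4, 6, padding_idx=1) returns a (4, 6) tensor whose
+ # first 3 columns are sines and last 3 are cosines of the position/frequency products,
+ # with row 1 (the padding_idx) zeroed out.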
+
+ @staticmethod
+ def make_positions(tensor, padding_idx: int):
+ """
+ Replace non-padding symbols with their position numbers.
+
+ Position numbers begin at padding_idx+1. Padding symbols are ignored.
+ """
+ # The series of casts and type-conversions here are carefully
+ # balanced to both work with ONNX export and XLA. In particular XLA
+ # prefers ints, cumsum defaults to output longs, and ONNX doesn't know
+ # how to handle the dtype kwarg in cumsum.
+ mask = tensor.ne(padding_idx).int()
+ return (torch.cumsum(mask, dim=1).type_as(mask) * mask).long() + padding_idx
+
+ def forward(
+ self,
+ input,
+ incremental_state: Optional[Any] = None,
+ timestep: Optional[Tensor] = None,
+ ):
+ """Input is expected to be of size [bsz x seqlen]."""
+ bsz, seq_len = input.shape[:2]
+ max_pos = self.padding_idx + 1 + seq_len
+ if max_pos > self.weight.size(0):
+ # expand embeddings if needed
+ self.make_weight(max_pos, self.embedding_dim, self.padding_idx)
+ positions = self.make_positions(input, self.padding_idx)
+ return super().forward(positions)
+
+
+__all__ = ["FSMTForConditionalGeneration", "FSMTModel", "PretrainedFSMTModel"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/fsmt/tokenization_fsmt.py b/venv/lib/python3.13/site-packages/transformers/models/fsmt/tokenization_fsmt.py
new file mode 100644
index 0000000000000000000000000000000000000000..5a4446d8e90b4c0466d1c7c09c4dc8b153dc8a33
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/fsmt/tokenization_fsmt.py
@@ -0,0 +1,488 @@
+# coding=utf-8
+# Copyright 2019 The Open AI Team Authors and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenization classes for FSMT."""
+
+import json
+import os
+import re
+import unicodedata
+from typing import Optional
+
+from ...tokenization_utils import PreTrainedTokenizer
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+VOCAB_FILES_NAMES = {
+ "src_vocab_file": "vocab-src.json",
+ "tgt_vocab_file": "vocab-tgt.json",
+ "merges_file": "merges.txt",
+}
+
+
+def get_pairs(word):
+ """
+ Return set of symbol pairs in a word. word is represented as tuple of symbols (symbols being variable-length
+ strings)
+ """
+ pairs = set()
+ prev_char = word[0]
+ for char in word[1:]:
+ pairs.add((prev_char, char))
+ prev_char = char
+ return pairs
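+
+# Illustrative example: get_pairs(("l", "o", "w</w>")) == {("l", "o"), ("o", "w</w>")}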
+
+
+def replace_unicode_punct(text):
+ """
+ Port of https://github.com/moses-smt/mosesdecoder/blob/master/scripts/tokenizer/replace-unicode-punctuation.perl
+ """
+ text = text.replace("，", ",")
+ text = re.sub(r"。\s*", ". ", text)
+ text = text.replace("、", ",")
+ text = text.replace("”", '"')
+ text = text.replace("“", '"')
+ text = text.replace("∶", ":")
+ text = text.replace("：", ":")
+ text = text.replace("？", "?")
+ text = text.replace("《", '"')
+ text = text.replace("》", '"')
+ text = text.replace("）", ")")
+ text = text.replace("！", "!")
+ text = text.replace("（", "(")
+ text = text.replace("；", ";")
+ text = text.replace("１", "1")
+ text = text.replace("」", '"')
+ text = text.replace("「", '"')
+ text = text.replace("０", "0")
+ text = text.replace("３", "3")
+ text = text.replace("２", "2")
+ text = text.replace("５", "5")
+ text = text.replace("６", "6")
+ text = text.replace("９", "9")
+ text = text.replace("７", "7")
+ text = text.replace("８", "8")
+ text = text.replace("４", "4")
+ text = re.sub(r"．\s*", ". ", text)
+ text = text.replace("～", "~")
+ text = text.replace("’", "'")
+ text = text.replace("…", "...")
+ text = text.replace("━", "-")
+ text = text.replace("〈", "<")
+ text = text.replace("〉", ">")
+ text = text.replace("【", "[")
+ text = text.replace("】", "]")
+ text = text.replace("％", "%")
+ return text
+
+
+def remove_non_printing_char(text):
+ """
+ Port of https://github.com/moses-smt/mosesdecoder/blob/master/scripts/tokenizer/remove-non-printing-char.perl
+ """
+ output = []
+ for char in text:
+ cat = unicodedata.category(char)
+ if cat.startswith("C"):
+ continue
+ output.append(char)
+ return "".join(output)
+
+
+# Porting notes:
+# this one is modeled after XLMTokenizer
+#
+# added:
+# - src_vocab_file,
+# - tgt_vocab_file,
+# - langs,
+
+
+class FSMTTokenizer(PreTrainedTokenizer):
+ """
+ Construct a FAIRSEQ Transformer tokenizer. Based on Byte-Pair Encoding. The tokenization process is the following:
+
+ - Moses preprocessing and tokenization.
+ - Normalizing all input text.
+ - The arguments `special_tokens` and the function `set_special_tokens` can be used to add additional symbols (like
+ "__classify__") to a vocabulary.
+ - The argument `langs` defines a pair of languages.
+
+ This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
+ this superclass for more information regarding those methods.
+
+ Args:
+ langs (`List[str]`, *optional*):
+ A list of two languages to translate from and to, for instance `["en", "ru"]`.
+ src_vocab_file (`str`, *optional*):
+ File containing the vocabulary for the source language.
+ tgt_vocab_file (`str`, *optional*):
+ File containing the vocabulary for the target language.
+ merges_file (`str`, *optional*):
+ File containing the merges.
+ do_lower_case (`bool`, *optional*, defaults to `False`):
+ Whether or not to lowercase the input when tokenizing.
+ unk_token (`str`, *optional*, defaults to `"<unk>"`):
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+ token instead.
+ bos_token (`str`, *optional*, defaults to `"<s>"`):
+ The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier token.
+
+
+
+ When building a sequence using special tokens, this is not the token that is used for the beginning of
+ sequence. The token used is the `cls_token`.
+
+
+
+ sep_token (`str`, *optional*, defaults to `"</s>"`):
+ The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
+ sequence classification or for a text and a question for question answering. It is also used as the last
+ token of a sequence built with special tokens.
+ pad_token (`str`, *optional*, defaults to `"<pad>"`):
+ The token used for padding, for example when batching sequences of different lengths.
+
+ """
+
+ vocab_files_names = VOCAB_FILES_NAMES
+ model_input_names = ["input_ids", "attention_mask"]
+
+ def __init__(
+ self,
+ langs=None,
+ src_vocab_file=None,
+ tgt_vocab_file=None,
+ merges_file=None,
+ do_lower_case=False,
+ unk_token="<unk>",
+ bos_token="<s>",
+ sep_token="</s>",
+ pad_token="<pad>",
+ **kwargs,
+ ):
+ try:
+ import sacremoses
+ except ImportError:
+ raise ImportError(
+ "You need to install sacremoses to use XLMTokenizer. "
+ "See https://pypi.org/project/sacremoses/ for installation."
+ )
+
+ self.sm = sacremoses
+
+ self.src_vocab_file = src_vocab_file
+ self.tgt_vocab_file = tgt_vocab_file
+ self.merges_file = merges_file
+ self.do_lower_case = do_lower_case
+
+ # cache of sm.MosesPunctNormalizer instance
+ self.cache_moses_punct_normalizer = {}
+ # cache of sm.MosesTokenizer instance
+ self.cache_moses_tokenizer = {}
+ self.cache_moses_detokenizer = {}
+
+ if langs and len(langs) == 2:
+ self.src_lang, self.tgt_lang = langs
+ else:
+ raise ValueError(
+ f"arg `langs` needs to be a list of 2 langs, e.g. ['en', 'ru'], but got {langs}. "
+ "Usually that means that tokenizer can't find a mapping for the given model path "
+ "in and other maps of this tokenizer."
+ )
+
+ with open(src_vocab_file, encoding="utf-8") as src_vocab_handle:
+ self.encoder = json.load(src_vocab_handle)
+ with open(tgt_vocab_file, encoding="utf-8") as tgt_vocab_handle:
+ tgt_vocab = json.load(tgt_vocab_handle)
+ self.decoder = {v: k for k, v in tgt_vocab.items()}
+ with open(merges_file, encoding="utf-8") as merges_handle:
+ merges = merges_handle.read().split("\n")[:-1]
+ merges = [tuple(merge.split()[:2]) for merge in merges]
+ self.bpe_ranks = dict(zip(merges, range(len(merges))))
+ self.cache = {}
+ super().__init__(
+ langs=langs,
+ src_vocab_file=src_vocab_file,
+ tgt_vocab_file=tgt_vocab_file,
+ merges_file=merges_file,
+ do_lower_case=do_lower_case,
+ unk_token=unk_token,
+ bos_token=bos_token,
+ sep_token=sep_token,
+ pad_token=pad_token,
+ **kwargs,
+ )
+
+ # hack override
+ def get_vocab(self) -> dict[str, int]:
+ return self.get_src_vocab()
+
+ # hack override
+ @property
+ def vocab_size(self) -> int:
+ return self.src_vocab_size
+
+ def moses_punct_norm(self, text, lang):
+ if lang not in self.cache_moses_punct_normalizer:
+ punct_normalizer = self.sm.MosesPunctNormalizer(lang=lang)
+ self.cache_moses_punct_normalizer[lang] = punct_normalizer
+ return self.cache_moses_punct_normalizer[lang].normalize(text)
+
+ def moses_tokenize(self, text, lang):
+ if lang not in self.cache_moses_tokenizer:
+ moses_tokenizer = self.sm.MosesTokenizer(lang=lang)
+ self.cache_moses_tokenizer[lang] = moses_tokenizer
+ return self.cache_moses_tokenizer[lang].tokenize(
+ text, aggressive_dash_splits=True, return_str=False, escape=True
+ )
+
+ def moses_detokenize(self, tokens, lang):
+ if lang not in self.cache_moses_detokenizer:
+ moses_detokenizer = self.sm.MosesDetokenizer(lang=lang)
+ self.cache_moses_detokenizer[lang] = moses_detokenizer
+ return self.cache_moses_detokenizer[lang].detokenize(tokens)
+
+ def moses_pipeline(self, text, lang):
+ text = replace_unicode_punct(text)
+ text = self.moses_punct_norm(text, lang)
+ text = remove_non_printing_char(text)
+ return text
+
+ @property
+ def src_vocab_size(self):
+ return len(self.encoder)
+
+ @property
+ def tgt_vocab_size(self):
+ return len(self.decoder)
+
+ def get_src_vocab(self):
+ return dict(self.encoder, **self.added_tokens_encoder)
+
+ def get_tgt_vocab(self):
+ return dict(self.decoder, **self.added_tokens_decoder)
+
+ def bpe(self, token):
+ word = tuple(token[:-1]) + (token[-1] + "</w>",)
+ if token in self.cache:
+ return self.cache[token]
+ pairs = get_pairs(word)
+
+ if not pairs:
+ return token + "</w>"
+
+ while True:
+ bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
+ if bigram not in self.bpe_ranks:
+ break
+ first, second = bigram
+ new_word = []
+ i = 0
+ while i < len(word):
+ try:
+ j = word.index(first, i)
+ except ValueError:
+ new_word.extend(word[i:])
+ break
+ else:
+ new_word.extend(word[i:j])
+ i = j
+
+ if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
+ new_word.append(first + second)
+ i += 2
+ else:
+ new_word.append(word[i])
+ i += 1
+ new_word = tuple(new_word)
+ word = new_word
+ if len(word) == 1:
+ break
+ else:
+ pairs = get_pairs(word)
+ word = " ".join(word)
+ if word == "\n ":
+ word = "\n"
+ self.cache[token] = word
+ return word
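+ # Illustrative example (hypothetical merge table): with bpe_ranks
+ # {("l", "o"): 0, ("lo", "w</w>"): 1}, bpe("low") first merges ("l", "o"), then
+ # ("lo", "w</w>"), and returns "low</w>".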
+
+ def _tokenize(self, text, lang="en", bypass_tokenizer=False):
+ """
+ Tokenize a string given language code using Moses.
+
+ Details of tokenization:
+
+ - [sacremoses](https://github.com/alvations/sacremoses): port of Moses
+ - Install with `pip install sacremoses`
+
+ Args:
+ - lang: ISO language code (default = 'en') (string). Languages should be among the model's supported
+ languages. However, we don't enforce it.
+ - bypass_tokenizer: Allow users to preprocess and tokenize the sentences externally (default = False)
+ (bool). If True, we only apply BPE.
+
+ Returns:
+ List of tokens.
+ """
+ # ignore `lang`, which currently isn't explicitly passed in tokenization_utils.py and always results in lang=en
+ # if lang != self.src_lang:
+ # raise ValueError(f"Expected lang={self.src_lang}, but got {lang}")
+ lang = self.src_lang
+
+ if self.do_lower_case:
+ text = text.lower()
+
+ if bypass_tokenizer:
+ text = text.split()
+ else:
+ text = self.moses_pipeline(text, lang=lang)
+ text = self.moses_tokenize(text, lang=lang)
+
+ split_tokens = []
+ for token in text:
+ if token:
+ split_tokens.extend(list(self.bpe(token).split(" ")))
+
+ return split_tokens
+
+ def _convert_token_to_id(self, token):
+ """Converts a token (str) in an id using the vocab."""
+ return self.encoder.get(token, self.encoder.get(self.unk_token))
+
+ def _convert_id_to_token(self, index):
+ """Converts an index (integer) in a token (str) using the vocab."""
+ return self.decoder.get(index, self.unk_token)
+
+ def convert_tokens_to_string(self, tokens):
+ """Converts a sequence of tokens (string) in a single string."""
+
+ # remove BPE
+ tokens = [t.replace(" ", "").replace("", " ") for t in tokens]
+ tokens = "".join(tokens).split()
+ # detokenize
+ text = self.moses_detokenize(tokens, self.tgt_lang)
+ return text
+
+ def build_inputs_with_special_tokens(
+ self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
+ ) -> list[int]:
+ """
+ Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating and
+ adding special tokens. A FAIRSEQ Transformer sequence has the following format:
+
+ - single sequence: `X </s>`
+ - pair of sequences: `A </s> B </s>`
+
+ Args:
+ token_ids_0 (`List[int]`):
+ List of IDs to which the special tokens will be added.
+ token_ids_1 (`List[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+
+ Returns:
+ `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
+ """
+ sep = [self.sep_token_id]
+
+ # no bos used in fairseq
+ if token_ids_1 is None:
+ return token_ids_0 + sep
+ return token_ids_0 + sep + token_ids_1 + sep
+
+ def get_special_tokens_mask(
+ self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None, already_has_special_tokens: bool = False
+ ) -> list[int]:
+ """
+ Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
+ special tokens using the tokenizer `prepare_for_model` method.
+
+ Args:
+ token_ids_0 (`List[int]`):
+ List of IDs.
+ token_ids_1 (`List[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+ already_has_special_tokens (`bool`, *optional*, defaults to `False`):
+ Whether or not the token list is already formatted with special tokens for the model.
+
+ Returns:
+ `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+ """
+
+ if already_has_special_tokens:
+ return super().get_special_tokens_mask(
+ token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
+ )
+ # no bos used in fairseq
+ if token_ids_1 is not None:
+ return ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
+ return ([0] * len(token_ids_0)) + [1]
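+
+ # Illustrative sketch (not part of this module): for an FSMTTokenizer instance `tokenizer`,
+ # with the default already_has_special_tokens=False,
+ # >>> tokenizer.get_special_tokens_mask([5, 6, 7])          # -> [0, 0, 0, 1]
+ # >>> tokenizer.get_special_tokens_mask([5, 6, 7], [8, 9])  # -> [0, 0, 0, 1, 0, 0, 1]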
+
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
+ if not os.path.isdir(save_directory):
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
+ return
+
+ src_vocab_file = os.path.join(
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["src_vocab_file"]
+ )
+ tgt_vocab_file = os.path.join(
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["tgt_vocab_file"]
+ )
+ merges_file = os.path.join(
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"]
+ )
+
+ with open(src_vocab_file, "w", encoding="utf-8") as f:
+ f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
+
+ with open(tgt_vocab_file, "w", encoding="utf-8") as f:
+ tgt_vocab = {v: k for k, v in self.decoder.items()}
+ f.write(json.dumps(tgt_vocab, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
+
+ index = 0
+ with open(merges_file, "w", encoding="utf-8") as writer:
+ for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
+ if index != token_index:
+ logger.warning(
+ f"Saving vocabulary to {merges_file}: BPE merge indices are not consecutive."
+ " Please check that the tokenizer is not corrupted!"
+ )
+ index = token_index
+ writer.write(" ".join(bpe_tokens) + "\n")
+ index += 1
+
+ return src_vocab_file, tgt_vocab_file, merges_file
+
+ def __getstate__(self):
+ state = self.__dict__.copy()
+ state["sm"] = None
+ return state
+
+ def __setstate__(self, d):
+ self.__dict__ = d
+
+ try:
+ import sacremoses
+ except ImportError:
+ raise ImportError(
+ "You need to install sacremoses to use XLMTokenizer. "
+ "See https://pypi.org/project/sacremoses/ for installation."
+ )
+
+ self.sm = sacremoses
+
+
+__all__ = ["FSMTTokenizer"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/funnel/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/funnel/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e4e0587ce32f5e59562102b302a113f387c60130
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/funnel/__init__.py
@@ -0,0 +1,31 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_funnel import *
+ from .convert_funnel_original_tf_checkpoint_to_pytorch import *
+ from .modeling_funnel import *
+ from .modeling_tf_funnel import *
+ from .tokenization_funnel import *
+ from .tokenization_funnel_fast import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/venv/lib/python3.13/site-packages/transformers/models/funnel/configuration_funnel.py b/venv/lib/python3.13/site-packages/transformers/models/funnel/configuration_funnel.py
new file mode 100644
index 0000000000000000000000000000000000000000..212a976f2781935811c191dfcd7e0076e59025e8
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/funnel/configuration_funnel.py
@@ -0,0 +1,166 @@
+# coding=utf-8
+# Copyright 2020, Hugging Face
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Funnel Transformer model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class FunnelConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`FunnelModel`] or a [`TFFunnelModel`]. It is used to
+ instantiate a Funnel Transformer model according to the specified arguments, defining the model architecture.
+ Instantiating a configuration with the defaults will yield a similar configuration to that of the Funnel
+ Transformer [funnel-transformer/small](https://huggingface.co/funnel-transformer/small) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 30522):
+ Vocabulary size of the Funnel transformer. Defines the number of different tokens that can be represented
+ by the `inputs_ids` passed when calling [`FunnelModel`] or [`TFFunnelModel`].
+ block_sizes (`list[int]`, *optional*, defaults to `[4, 4, 4]`):
+ The sizes of the blocks used in the model.
+ block_repeats (`list[int]`, *optional*):
+ If passed along, each layer of each block is repeated the number of times indicated.
+ num_decoder_layers (`int`, *optional*, defaults to 2):
+ The number of layers in the decoder (when not using the base model).
+ d_model (`int`, *optional*, defaults to 768):
+ Dimensionality of the model's hidden states.
+ n_head (`int`, *optional*, defaults to 12):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ d_head (`int`, *optional*, defaults to 64):
+ Dimensionality of the model's heads.
+ d_inner (`int`, *optional*, defaults to 3072):
+ Inner dimension in the feed-forward blocks.
+ hidden_act (`str` or `callable`, *optional*, defaults to `"gelu_new"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
+ hidden_dropout (`float`, *optional*, defaults to 0.1):
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+ attention_dropout (`float`, *optional*, defaults to 0.1):
+ The dropout probability for the attention probabilities.
+ activation_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout probability used between the two layers of the feed-forward blocks.
+ initializer_range (`float`, *optional*, defaults to 0.1):
+ The upper bound of the *uniform initializer* for initializing all weight matrices in attention layers.
+ initializer_std (`float`, *optional*):
+ The standard deviation of the *normal initializer* for initializing the embedding matrix and the weight of
+ linear layers. Will default to 1 for the embedding matrix and the value given by Xavier initialization for
+ linear layers.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-09):
+ The epsilon used by the layer normalization layers.
+ pooling_type (`str`, *optional*, defaults to `"mean"`):
+ Possible values are `"mean"` or `"max"`. The way pooling is performed at the beginning of each block.
+ attention_type (`str`, *optional*, defaults to `"relative_shift"`):
+ Possible values are `"relative_shift"` or `"factorized"`. The former is faster on CPU/GPU while the latter
+ is faster on TPU.
+ separate_cls (`bool`, *optional*, defaults to `True`):
+ Whether or not to separate the cls token when applying pooling.
+ truncate_seq (`bool`, *optional*, defaults to `True`):
+ When using `separate_cls`, whether or not to truncate the last token when pooling, to avoid getting a
+ sequence length that is not a multiple of 2.
+ pool_q_only (`bool`, *optional*, defaults to `True`):
+ Whether or not to apply the pooling only to the query, or to the query, key and value, for the attention layers.
+ """
+
+ model_type = "funnel"
+ attribute_map = {
+ "hidden_size": "d_model",
+ "num_attention_heads": "n_head",
+ }
+
+ def __init__(
+ self,
+ vocab_size=30522,
+ block_sizes=[4, 4, 4],
+ block_repeats=None,
+ num_decoder_layers=2,
+ d_model=768,
+ n_head=12,
+ d_head=64,
+ d_inner=3072,
+ hidden_act="gelu_new",
+ hidden_dropout=0.1,
+ attention_dropout=0.1,
+ activation_dropout=0.0,
+ initializer_range=0.1,
+ initializer_std=None,
+ layer_norm_eps=1e-9,
+ pooling_type="mean",
+ attention_type="relative_shift",
+ separate_cls=True,
+ truncate_seq=True,
+ pool_q_only=True,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.block_sizes = block_sizes
+ self.block_repeats = [1] * len(block_sizes) if block_repeats is None else block_repeats
+ assert len(block_sizes) == len(self.block_repeats), (
+ "`block_sizes` and `block_repeats` should have the same length."
+ )
+ self.num_decoder_layers = num_decoder_layers
+ self.d_model = d_model
+ self.n_head = n_head
+ self.d_head = d_head
+ self.d_inner = d_inner
+ self.hidden_act = hidden_act
+ self.hidden_dropout = hidden_dropout
+ self.attention_dropout = attention_dropout
+ self.activation_dropout = activation_dropout
+ self.initializer_range = initializer_range
+ self.initializer_std = initializer_std
+ self.layer_norm_eps = layer_norm_eps
+ assert pooling_type in [
+ "mean",
+ "max",
+ ], f"Got {pooling_type} for `pooling_type` but only 'mean' and 'max' are supported."
+ self.pooling_type = pooling_type
+ assert attention_type in [
+ "relative_shift",
+ "factorized",
+ ], f"Got {attention_type} for `attention_type` but only 'relative_shift' and 'factorized' are supported."
+ self.attention_type = attention_type
+ self.separate_cls = separate_cls
+ self.truncate_seq = truncate_seq
+ self.pool_q_only = pool_q_only
+
+ super().__init__(**kwargs)
+
+ @property
+ def num_hidden_layers(self):
+ return sum(self.block_sizes)
+
+ @num_hidden_layers.setter
+ def num_hidden_layers(self, value):
+ raise NotImplementedError(
+ "This model does not support the setting of `num_hidden_layers`. Please set `block_sizes`."
+ )
+
+ @property
+ def num_blocks(self):
+ return len(self.block_sizes)
+
+ @num_blocks.setter
+ def num_blocks(self, value):
+ raise NotImplementedError("This model does not support the setting of `num_blocks`. Please set `block_sizes`.")
+
+
+__all__ = ["FunnelConfig"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/funnel/modeling_funnel.py b/venv/lib/python3.13/site-packages/transformers/models/funnel/modeling_funnel.py
new file mode 100644
index 0000000000000000000000000000000000000000..4370344cccfb19c710ed05b01e72f6880b54afdd
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/funnel/modeling_funnel.py
@@ -0,0 +1,1452 @@
+# coding=utf-8
+# Copyright 2020-present Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch Funnel Transformer model."""
+
+import os
+from dataclasses import dataclass
+from typing import Optional, Union
+
+import numpy as np
+import torch
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from ...activations import ACT2FN
+from ...modeling_outputs import (
+ BaseModelOutput,
+ MaskedLMOutput,
+ MultipleChoiceModelOutput,
+ QuestionAnsweringModelOutput,
+ SequenceClassifierOutput,
+ TokenClassifierOutput,
+)
+from ...modeling_utils import PreTrainedModel
+from ...utils import ModelOutput, auto_docstring, logging
+from .configuration_funnel import FunnelConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+INF = 1e6
+
+
+def load_tf_weights_in_funnel(model, config, tf_checkpoint_path):
+ """Load tf checkpoints in a pytorch model."""
+ try:
+ import re
+
+ import numpy as np
+ import tensorflow as tf
+ except ImportError:
+ logger.error(
+ "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
+ "https://www.tensorflow.org/install/ for installation instructions."
+ )
+ raise
+ tf_path = os.path.abspath(tf_checkpoint_path)
+ logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
+ # Load weights from TF model
+ init_vars = tf.train.list_variables(tf_path)
+ names = []
+ arrays = []
+ for name, shape in init_vars:
+ logger.info(f"Loading TF weight {name} with shape {shape}")
+ array = tf.train.load_variable(tf_path, name)
+ names.append(name)
+ arrays.append(array)
+
+ _layer_map = {
+ "k": "k_head",
+ "q": "q_head",
+ "v": "v_head",
+ "o": "post_proj",
+ "layer_1": "linear_1",
+ "layer_2": "linear_2",
+ "rel_attn": "attention",
+ "ff": "ffn",
+ "kernel": "weight",
+ "gamma": "weight",
+ "beta": "bias",
+ "lookup_table": "weight",
+ "word_embedding": "word_embeddings",
+ "input": "embeddings",
+ }
+
+ for name, array in zip(names, arrays):
+ name = name.split("/")
+ # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v,
+ # which are not required for using the pretrained model
+ if any(
+ n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
+ for n in name
+ ):
+ logger.info(f"Skipping {'/'.join(name)}")
+ continue
+ if name[0] == "generator":
+ continue
+ pointer = model
+ skipped = False
+ for m_name in name[1:]:
+ if not isinstance(pointer, FunnelPositionwiseFFN) and re.fullmatch(r"layer_\d+", m_name):
+ layer_index = int(re.search(r"layer_(\d+)", m_name).groups()[0])
+ if layer_index < config.num_hidden_layers:
+ block_idx = 0
+ while layer_index >= config.block_sizes[block_idx]:
+ layer_index -= config.block_sizes[block_idx]
+ block_idx += 1
+ pointer = pointer.blocks[block_idx][layer_index]
+ else:
+ layer_index -= config.num_hidden_layers
+ pointer = pointer.layers[layer_index]
+ elif m_name == "r" and isinstance(pointer, FunnelRelMultiheadAttention):
+ pointer = pointer.r_kernel
+ break
+ elif m_name in _layer_map:
+ pointer = getattr(pointer, _layer_map[m_name])
+ else:
+ try:
+ pointer = getattr(pointer, m_name)
+ except AttributeError:
+ print(f"Skipping {'/'.join(name)}", array.shape)
+ skipped = True
+ break
+ if not skipped:
+ if len(pointer.shape) != len(array.shape):
+ array = array.reshape(pointer.shape)
+ if m_name == "kernel":
+ array = np.transpose(array)
+ pointer.data = torch.from_numpy(array)
+
+ return model
+
+
+class FunnelEmbeddings(nn.Module):
+ def __init__(self, config: FunnelConfig) -> None:
+ super().__init__()
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
+ self.layer_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout)
+
+ def forward(
+ self, input_ids: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None
+ ) -> torch.Tensor:
+ if inputs_embeds is None:
+ inputs_embeds = self.word_embeddings(input_ids)
+ embeddings = self.layer_norm(inputs_embeds)
+ embeddings = self.dropout(embeddings)
+ return embeddings
+
+
+class FunnelAttentionStructure(nn.Module):
+ """
+ Contains helpers for `FunnelRelMultiheadAttention`.
+ """
+
+ cls_token_type_id: int = 2
+
+ def __init__(self, config: FunnelConfig) -> None:
+ super().__init__()
+ self.config = config
+ self.sin_dropout = nn.Dropout(config.hidden_dropout)
+ self.cos_dropout = nn.Dropout(config.hidden_dropout)
+ # Track where we are at in terms of pooling from the original input, e.g., by how much the sequence length was
+ # divided.
+ self.pooling_mult = None
+
+ def init_attention_inputs(
+ self,
+ inputs_embeds: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ ) -> tuple[torch.Tensor]:
+ """Returns the attention inputs associated to the inputs of the model."""
+ # inputs_embeds has shape batch_size x seq_len x d_model
+ # attention_mask and token_type_ids have shape batch_size x seq_len
+ self.pooling_mult = 1
+ self.seq_len = seq_len = inputs_embeds.size(1)
+ position_embeds = self.get_position_embeds(seq_len, inputs_embeds.dtype, inputs_embeds.device)
+ token_type_mat = self.token_type_ids_to_mat(token_type_ids) if token_type_ids is not None else None
+ cls_mask = (
+ nn.functional.pad(inputs_embeds.new_ones([seq_len - 1, seq_len - 1]), (1, 0, 1, 0))
+ if self.config.separate_cls
+ else None
+ )
+ return (position_embeds, token_type_mat, attention_mask, cls_mask)
+
+ def token_type_ids_to_mat(self, token_type_ids: torch.Tensor) -> torch.Tensor:
+ """Convert `token_type_ids` to `token_type_mat`."""
+ token_type_mat = token_type_ids[:, :, None] == token_type_ids[:, None]
+ # Treat <cls> as in the same segment as both A & B
+ cls_ids = token_type_ids == self.cls_token_type_id
+ cls_mat = cls_ids[:, :, None] | cls_ids[:, None]
+ return cls_mat | token_type_mat
+
+ def get_position_embeds(
+ self, seq_len: int, dtype: torch.dtype, device: torch.device
+ ) -> Union[tuple[torch.Tensor], list[list[torch.Tensor]]]:
+ """
+ Create and cache inputs related to relative position encoding. Those are very different depending on whether we
+ are using the factorized or the relative shift attention:
+
+ For the factorized attention, it returns the matrices (phi, pi, psi, omega) used in the paper, appendix A.2.2,
+ final formula.
+
+ For the relative shift attention, it returns all possible vectors R used in the paper, appendix A.2.1, final
+ formula.
+
+ Paper link: https://huggingface.co/papers/2006.03236
+ """
+ d_model = self.config.d_model
+ if self.config.attention_type == "factorized":
+ # Notations from the paper, appendix A.2.2, final formula.
+ # We need to create and return the matrices phi, psi, pi and omega.
+ pos_seq = torch.arange(0, seq_len, 1.0, dtype=torch.int64, device=device).to(dtype)
+ freq_seq = torch.arange(0, d_model // 2, 1.0, dtype=torch.int64, device=device).to(dtype)
+ inv_freq = 1 / (10000 ** (freq_seq / (d_model // 2)))
+ sinusoid = pos_seq[:, None] * inv_freq[None]
+ sin_embed = torch.sin(sinusoid)
+ sin_embed_d = self.sin_dropout(sin_embed)
+ cos_embed = torch.cos(sinusoid)
+ cos_embed_d = self.cos_dropout(cos_embed)
+ # This is different from the formula in the paper...
+ phi = torch.cat([sin_embed_d, sin_embed_d], dim=-1)
+ psi = torch.cat([cos_embed, sin_embed], dim=-1)
+ pi = torch.cat([cos_embed_d, cos_embed_d], dim=-1)
+ omega = torch.cat([-sin_embed, cos_embed], dim=-1)
+ return (phi, pi, psi, omega)
+ else:
+ # Notations from the paper, appendix A.2.1, final formula.
+ # We need to create and return all the possible vectors R for all blocks and shifts.
+ freq_seq = torch.arange(0, d_model // 2, 1.0, dtype=torch.int64, device=device).to(dtype)
+ inv_freq = 1 / (10000 ** (freq_seq / (d_model // 2)))
+ # Maximum relative positions for the first input
+ rel_pos_id = torch.arange(-seq_len * 2, seq_len * 2, 1.0, dtype=torch.int64, device=device).to(dtype)
+ zero_offset = seq_len * 2
+ sinusoid = rel_pos_id[:, None] * inv_freq[None]
+ sin_embed = self.sin_dropout(torch.sin(sinusoid))
+ cos_embed = self.cos_dropout(torch.cos(sinusoid))
+ pos_embed = torch.cat([sin_embed, cos_embed], dim=-1)
+
+ pos = torch.arange(0, seq_len, dtype=torch.int64, device=device).to(dtype)
+ pooled_pos = pos
+ position_embeds_list = []
+ for block_index in range(0, self.config.num_blocks):
+ # For each block with block_index > 0, we need two types of position embeddings:
+ # - Attention(pooled-q, unpooled-kv)
+ # - Attention(pooled-q, pooled-kv)
+ # For block_index = 0 we only need the second one and leave the first one as None.
+
+ # First type
+ if block_index == 0:
+ position_embeds_pooling = None
+ else:
+ pooled_pos = self.stride_pool_pos(pos, block_index)
+
+ # construct rel_pos_id
+ stride = 2 ** (block_index - 1)
+ rel_pos = self.relative_pos(pos, stride, pooled_pos, shift=2)
+ rel_pos = rel_pos[:, None] + zero_offset
+ rel_pos = rel_pos.expand(rel_pos.size(0), d_model)
+ position_embeds_pooling = torch.gather(pos_embed, 0, rel_pos)
+
+ # Second type
+ pos = pooled_pos
+ stride = 2**block_index
+ rel_pos = self.relative_pos(pos, stride)
+
+ rel_pos = rel_pos[:, None] + zero_offset
+ rel_pos = rel_pos.expand(rel_pos.size(0), d_model)
+ position_embeds_no_pooling = torch.gather(pos_embed, 0, rel_pos)
+
+ position_embeds_list.append([position_embeds_no_pooling, position_embeds_pooling])
+ return position_embeds_list
+
+ def stride_pool_pos(self, pos_id: torch.Tensor, block_index: int):
+ """
+ Pool `pos_id` while keeping the cls token separate (if `config.separate_cls=True`).
+ """
+ if self.config.separate_cls:
+ # Under separate <cls>, we treat the <cls> as the first token in
+ # the previous block of the 1st real block. Since the 1st real
+ # block always has position 1, the position of the previous block
+ # will be at `1 - 2 ** block_index`.
+ cls_pos = pos_id.new_tensor([-(2**block_index) + 1])
+ pooled_pos_id = pos_id[1:-1] if self.config.truncate_seq else pos_id[1:]
+ return torch.cat([cls_pos, pooled_pos_id[::2]], 0)
+ else:
+ return pos_id[::2]
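+
+ # Illustrative sketch (not part of this module): for a FunnelAttentionStructure `structure`
+ # built from a config with separate_cls=True and truncate_seq=True,
+ # >>> structure.stride_pool_pos(torch.arange(8), block_index=1)
+ # tensor([-1, 1, 3, 5])  # <cls> gets position 1 - 2**1 = -1; the rest is pos_id[1:-1][::2]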
+
+ def relative_pos(self, pos: torch.Tensor, stride: int, pooled_pos=None, shift: int = 1) -> torch.Tensor:
+ """
+ Build the relative positional vector between `pos` and `pooled_pos`.
+ """
+ if pooled_pos is None:
+ pooled_pos = pos
+
+ ref_point = pooled_pos[0] - pos[0]
+ num_remove = shift * len(pooled_pos)
+ max_dist = ref_point + num_remove * stride
+ min_dist = pooled_pos[0] - pos[-1]
+
+ return torch.arange(max_dist, min_dist - 1, -stride, dtype=torch.long, device=pos.device)
+
+ def stride_pool(
+ self,
+ tensor: Union[torch.Tensor, tuple[torch.Tensor], list[torch.Tensor]],
+ axis: Union[int, tuple[int], list[int]],
+ ) -> torch.Tensor:
+ """
+ Perform pooling by stride slicing the tensor along the given axis.
+ """
+ if tensor is None:
+ return None
+
+ # Do the stride pool recursively if axis is a list or a tuple of ints.
+ if isinstance(axis, (list, tuple)):
+ for ax in axis:
+ tensor = self.stride_pool(tensor, ax)
+ return tensor
+
+ # Do the stride pool recursively if tensor is a list or tuple of tensors.
+ if isinstance(tensor, (tuple, list)):
+ return type(tensor)(self.stride_pool(x, axis) for x in tensor)
+
+ # Deal with negative axis
+ axis %= tensor.ndim
+
+ axis_slice = (
+ slice(None, -1, 2) if self.config.separate_cls and self.config.truncate_seq else slice(None, None, 2)
+ )
+ enc_slice = [slice(None)] * axis + [axis_slice]
+ if self.config.separate_cls:
+ cls_slice = [slice(None)] * axis + [slice(None, 1)]
+ tensor = torch.cat([tensor[cls_slice], tensor], axis=axis)
+ return tensor[enc_slice]
+
+ def pool_tensor(
+ self, tensor: Union[torch.Tensor, tuple[torch.Tensor], list[torch.Tensor]], mode: str = "mean", stride: int = 2
+ ) -> torch.Tensor:
+ """Apply 1D pooling to a tensor of size [B x T (x H)]."""
+ if tensor is None:
+ return None
+
+ # Do the pool recursively if tensor is a list or tuple of tensors.
+ if isinstance(tensor, (tuple, list)):
+ return type(tensor)(self.pool_tensor(x, mode=mode, stride=stride) for x in tensor)
+
+ if self.config.separate_cls:
+ suffix = tensor[:, :-1] if self.config.truncate_seq else tensor
+ tensor = torch.cat([tensor[:, :1], suffix], dim=1)
+
+ ndim = tensor.ndim
+ if ndim == 2:
+ tensor = tensor[:, None, :, None]
+ elif ndim == 3:
+ tensor = tensor[:, None, :, :]
+ # Stride is applied on the second-to-last dimension.
+ stride = (stride, 1)
+
+ if mode == "mean":
+ tensor = nn.functional.avg_pool2d(tensor, stride, stride=stride, ceil_mode=True)
+ elif mode == "max":
+ tensor = nn.functional.max_pool2d(tensor, stride, stride=stride, ceil_mode=True)
+ elif mode == "min":
+ tensor = -nn.functional.max_pool2d(-tensor, stride, stride=stride, ceil_mode=True)
+ else:
+ raise NotImplementedError("The supported modes are 'mean', 'max' and 'min'.")
+
+ if ndim == 2:
+ return tensor[:, 0, :, 0]
+ elif ndim == 3:
+ return tensor[:, 0]
+ return tensor
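+
+ # Illustrative sketch (not part of this module): for the same `structure` (separate_cls=True,
+ # truncate_seq=True), min-pooling a 1 x 6 attention mask halves the sequence length:
+ # >>> structure.pool_tensor(torch.tensor([[1.0, 1.0, 1.0, 1.0, 0.0, 0.0]]), mode="min")
+ # tensor([[1., 1., 0.]])  # the <cls> position is re-prepended, then adjacent pairs are min-pooled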
+
+ def pre_attention_pooling(
+ self, output, attention_inputs: tuple[torch.Tensor]
+ ) -> tuple[torch.Tensor, tuple[torch.Tensor]]:
+ """Pool `output` and the proper parts of `attention_inputs` before the attention layer."""
+ position_embeds, token_type_mat, attention_mask, cls_mask = attention_inputs
+ if self.config.pool_q_only:
+ if self.config.attention_type == "factorized":
+ position_embeds = self.stride_pool(position_embeds[:2], 0) + position_embeds[2:]
+ token_type_mat = self.stride_pool(token_type_mat, 1)
+ cls_mask = self.stride_pool(cls_mask, 0)
+ output = self.pool_tensor(output, mode=self.config.pooling_type)
+ else:
+ self.pooling_mult *= 2
+ if self.config.attention_type == "factorized":
+ position_embeds = self.stride_pool(position_embeds, 0)
+ token_type_mat = self.stride_pool(token_type_mat, [1, 2])
+ cls_mask = self.stride_pool(cls_mask, [1, 2])
+ attention_mask = self.pool_tensor(attention_mask, mode="min")
+ output = self.pool_tensor(output, mode=self.config.pooling_type)
+ attention_inputs = (position_embeds, token_type_mat, attention_mask, cls_mask)
+ return output, attention_inputs
+
+ def post_attention_pooling(self, attention_inputs: tuple[torch.Tensor]) -> tuple[torch.Tensor]:
+ """Pool the proper parts of `attention_inputs` after the attention layer."""
+ position_embeds, token_type_mat, attention_mask, cls_mask = attention_inputs
+ if self.config.pool_q_only:
+ self.pooling_mult *= 2
+ if self.config.attention_type == "factorized":
+ position_embeds = position_embeds[:2] + self.stride_pool(position_embeds[2:], 0)
+ token_type_mat = self.stride_pool(token_type_mat, 2)
+ cls_mask = self.stride_pool(cls_mask, 1)
+ attention_mask = self.pool_tensor(attention_mask, mode="min")
+ attention_inputs = (position_embeds, token_type_mat, attention_mask, cls_mask)
+ return attention_inputs
+
+
+def _relative_shift_gather(positional_attn: torch.Tensor, context_len: int, shift: int) -> torch.Tensor:
+ batch_size, n_head, seq_len, max_rel_len = positional_attn.shape
+ # max_rel_len = 2 * context_len + shift - 1 is the number of possible relative positions i-j
+
+ # What's next is the same as doing the following gather, which might be clearer code but less efficient.
+ # idxs = context_len + torch.arange(0, context_len).unsqueeze(0) - torch.arange(0, seq_len).unsqueeze(1)
+ # # matrix of context_len + i-j
+ # return positional_attn.gather(3, idxs.expand([batch_size, n_head, context_len, context_len]))
+
+ positional_attn = torch.reshape(positional_attn, [batch_size, n_head, max_rel_len, seq_len])
+ positional_attn = positional_attn[:, :, shift:, :]
+ positional_attn = torch.reshape(positional_attn, [batch_size, n_head, seq_len, max_rel_len - shift])
+ positional_attn = positional_attn[..., :context_len]
+ return positional_attn
+
+
+class FunnelRelMultiheadAttention(nn.Module):
+ def __init__(self, config: FunnelConfig, block_index: int) -> None:
+ super().__init__()
+ self.config = config
+ self.block_index = block_index
+ d_model, n_head, d_head = config.d_model, config.n_head, config.d_head
+
+ self.hidden_dropout = nn.Dropout(config.hidden_dropout)
+ self.attention_dropout = nn.Dropout(config.attention_dropout)
+
+ self.q_head = nn.Linear(d_model, n_head * d_head, bias=False)
+ self.k_head = nn.Linear(d_model, n_head * d_head)
+ self.v_head = nn.Linear(d_model, n_head * d_head)
+
+ self.r_w_bias = nn.Parameter(torch.zeros([n_head, d_head]))
+ self.r_r_bias = nn.Parameter(torch.zeros([n_head, d_head]))
+ self.r_kernel = nn.Parameter(torch.zeros([d_model, n_head, d_head]))
+ self.r_s_bias = nn.Parameter(torch.zeros([n_head, d_head]))
+ self.seg_embed = nn.Parameter(torch.zeros([2, n_head, d_head]))
+
+ self.post_proj = nn.Linear(n_head * d_head, d_model)
+ self.layer_norm = nn.LayerNorm(d_model, eps=config.layer_norm_eps)
+ self.scale = 1.0 / (d_head**0.5)
+
+ def relative_positional_attention(self, position_embeds, q_head, context_len, cls_mask=None):
+ """Relative attention score for the positional encodings"""
+ # q_head has shape batch_size x seq_len x n_head x d_head
+ if self.config.attention_type == "factorized":
+ # Notations from the paper, appendix A.2.2, final formula (https://huggingface.co/papers/2006.03236)
+ # phi and pi have shape seq_len x d_model, psi and omega have shape context_len x d_model
+ phi, pi, psi, omega = position_embeds
+ # Shape n_head x d_head
+ u = self.r_r_bias * self.scale
+ # Shape d_model x n_head x d_head
+ w_r = self.r_kernel
+
+ # Shape batch_size x seq_len x n_head x d_model
+ q_r_attention = torch.einsum("binh,dnh->bind", q_head + u, w_r)
+ q_r_attention_1 = q_r_attention * phi[:, None]
+ q_r_attention_2 = q_r_attention * pi[:, None]
+
+ # Shape batch_size x n_head x seq_len x context_len
+ positional_attn = torch.einsum("bind,jd->bnij", q_r_attention_1, psi) + torch.einsum(
+ "bind,jd->bnij", q_r_attention_2, omega
+ )
+ else:
+ shift = 2 if q_head.shape[1] != context_len else 1
+ # Notations from the paper, appendix A.2.1, final formula (https://huggingface.co/papers/2006.03236)
+ # Grab the proper positional encoding, shape max_rel_len x d_model
+ r = position_embeds[self.block_index][shift - 1]
+ # Shape n_head x d_head
+ v = self.r_r_bias * self.scale
+ # Shape d_model x n_head x d_head
+ w_r = self.r_kernel
+
+ # Shape max_rel_len x n_head x d_model
+ r_head = torch.einsum("td,dnh->tnh", r, w_r)
+ # Shape batch_size x n_head x seq_len x max_rel_len
+ positional_attn = torch.einsum("binh,tnh->bnit", q_head + v, r_head)
+ # Shape batch_size x n_head x seq_len x context_len
+ positional_attn = _relative_shift_gather(positional_attn, context_len, shift)
+
+ if cls_mask is not None:
+ positional_attn *= cls_mask
+ return positional_attn
+
+ def relative_token_type_attention(self, token_type_mat, q_head, cls_mask=None):
+ """Relative attention score for the token_type_ids"""
+ if token_type_mat is None:
+ return 0
+ batch_size, seq_len, context_len = token_type_mat.shape
+ # q_head has shape batch_size x seq_len x n_head x d_head
+ # Shape n_head x d_head
+ r_s_bias = self.r_s_bias * self.scale
+
+ # Shape batch_size x n_head x seq_len x 2
+ token_type_bias = torch.einsum("bind,snd->bnis", q_head + r_s_bias, self.seg_embed)
+ # Shape batch_size x n_head x seq_len x context_len
+ token_type_mat = token_type_mat[:, None].expand([batch_size, q_head.shape[2], seq_len, context_len])
+ # Shapes batch_size x n_head x seq_len
+ diff_token_type, same_token_type = torch.split(token_type_bias, 1, dim=-1)
+ # Shape batch_size x n_head x seq_len x context_len
+ token_type_attn = torch.where(
+ token_type_mat, same_token_type.expand(token_type_mat.shape), diff_token_type.expand(token_type_mat.shape)
+ )
+
+ if cls_mask is not None:
+ token_type_attn *= cls_mask
+ return token_type_attn
+
+ def forward(
+ self,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_inputs: tuple[torch.Tensor],
+ output_attentions: bool = False,
+ ) -> tuple[torch.Tensor, ...]:
+ # query has shape batch_size x seq_len x d_model
+ # key and value have shapes batch_size x context_len x d_model
+ position_embeds, token_type_mat, attention_mask, cls_mask = attention_inputs
+
+ batch_size, seq_len, _ = query.shape
+ context_len = key.shape[1]
+ n_head, d_head = self.config.n_head, self.config.d_head
+
+ # Shape batch_size x seq_len x n_head x d_head
+ q_head = self.q_head(query).view(batch_size, seq_len, n_head, d_head)
+ # Shapes batch_size x context_len x n_head x d_head
+ k_head = self.k_head(key).view(batch_size, context_len, n_head, d_head)
+ v_head = self.v_head(value).view(batch_size, context_len, n_head, d_head)
+
+ q_head = q_head * self.scale
+ # Shape n_head x d_head
+ r_w_bias = self.r_w_bias * self.scale
+ # Shapes batch_size x n_head x seq_len x context_len
+ content_score = torch.einsum("bind,bjnd->bnij", q_head + r_w_bias, k_head)
+ positional_attn = self.relative_positional_attention(position_embeds, q_head, context_len, cls_mask)
+ token_type_attn = self.relative_token_type_attention(token_type_mat, q_head, cls_mask)
+
+ # merge attention scores
+ attn_score = content_score + positional_attn + token_type_attn
+
+ # precision safe in case of mixed precision training
+ dtype = attn_score.dtype
+ attn_score = attn_score.float()
+ # perform masking
+ if attention_mask is not None:
+ attn_score = attn_score - INF * (1 - attention_mask[:, None, None].float())
+ # attention probability
+ attn_prob = torch.softmax(attn_score, dim=-1, dtype=dtype)
+ attn_prob = self.attention_dropout(attn_prob)
+
+ # attention output, shape batch_size x seq_len x n_head x d_head
+ attn_vec = torch.einsum("bnij,bjnd->bind", attn_prob, v_head)
+
+ # Shape batch_size x seq_len x d_model
+ attn_out = self.post_proj(attn_vec.reshape(batch_size, seq_len, n_head * d_head))
+ attn_out = self.hidden_dropout(attn_out)
+
+ output = self.layer_norm(query + attn_out)
+ return (output, attn_prob) if output_attentions else (output,)
+
+
+class FunnelPositionwiseFFN(nn.Module):
+ def __init__(self, config: FunnelConfig) -> None:
+ super().__init__()
+ self.linear_1 = nn.Linear(config.d_model, config.d_inner)
+ self.activation_function = ACT2FN[config.hidden_act]
+ self.activation_dropout = nn.Dropout(config.activation_dropout)
+ self.linear_2 = nn.Linear(config.d_inner, config.d_model)
+ self.dropout = nn.Dropout(config.hidden_dropout)
+ self.layer_norm = nn.LayerNorm(config.d_model, config.layer_norm_eps)
+
+ def forward(self, hidden: torch.Tensor) -> torch.Tensor:
+ h = self.linear_1(hidden)
+ h = self.activation_function(h)
+ h = self.activation_dropout(h)
+ h = self.linear_2(h)
+ h = self.dropout(h)
+ return self.layer_norm(hidden + h)
+
+
+class FunnelLayer(nn.Module):
+ def __init__(self, config: FunnelConfig, block_index: int) -> None:
+ super().__init__()
+ self.attention = FunnelRelMultiheadAttention(config, block_index)
+ self.ffn = FunnelPositionwiseFFN(config)
+
+ def forward(
+ self,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_inputs,
+ output_attentions: bool = False,
+ ) -> tuple:
+ attn = self.attention(query, key, value, attention_inputs, output_attentions=output_attentions)
+ output = self.ffn(attn[0])
+ return (output, attn[1]) if output_attentions else (output,)
+
+
+class FunnelEncoder(nn.Module):
+ def __init__(self, config: FunnelConfig) -> None:
+ super().__init__()
+ self.config = config
+ self.attention_structure = FunnelAttentionStructure(config)
+ self.blocks = nn.ModuleList(
+ [
+ nn.ModuleList([FunnelLayer(config, block_index) for _ in range(block_size)])
+ for block_index, block_size in enumerate(config.block_sizes)
+ ]
+ )
+
+ def forward(
+ self,
+ inputs_embeds: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ return_dict: bool = True,
+ ) -> Union[tuple, BaseModelOutput]:
+ # The pooling is not implemented on long tensors, so we convert this mask.
+ attention_mask = attention_mask.type_as(inputs_embeds)
+ attention_inputs = self.attention_structure.init_attention_inputs(
+ inputs_embeds,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ )
+ hidden = inputs_embeds
+
+ all_hidden_states = (inputs_embeds,) if output_hidden_states else None
+ all_attentions = () if output_attentions else None
+
+ for block_index, block in enumerate(self.blocks):
+ pooling_flag = hidden.size(1) > (2 if self.config.separate_cls else 1)
+ pooling_flag = pooling_flag and block_index > 0
+ if pooling_flag:
+ pooled_hidden, attention_inputs = self.attention_structure.pre_attention_pooling(
+ hidden, attention_inputs
+ )
+ for layer_index, layer in enumerate(block):
+ for repeat_index in range(self.config.block_repeats[block_index]):
+ do_pooling = (repeat_index == 0) and (layer_index == 0) and pooling_flag
+ if do_pooling:
+ query = pooled_hidden
+ key = value = hidden if self.config.pool_q_only else pooled_hidden
+ else:
+ query = key = value = hidden
+ layer_output = layer(query, key, value, attention_inputs, output_attentions=output_attentions)
+ hidden = layer_output[0]
+ if do_pooling:
+ attention_inputs = self.attention_structure.post_attention_pooling(attention_inputs)
+
+ if output_attentions:
+ all_attentions = all_attentions + layer_output[1:]
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden, all_hidden_states, all_attentions] if v is not None)
+ return BaseModelOutput(last_hidden_state=hidden, hidden_states=all_hidden_states, attentions=all_attentions)
+
+
+def upsample(
+ x: torch.Tensor, stride: int, target_len: int, separate_cls: bool = True, truncate_seq: bool = False
+) -> torch.Tensor:
+ """
+ Upsample tensor `x` to match `target_len` by repeating the tokens `stride` time on the sequence length dimension.
+ """
+ if stride == 1:
+ return x
+ if separate_cls:
+ cls = x[:, :1]
+ x = x[:, 1:]
+ output = torch.repeat_interleave(x, repeats=stride, dim=1)
+ if separate_cls:
+ if truncate_seq:
+ output = nn.functional.pad(output, (0, 0, 0, stride - 1, 0, 0))
+ output = output[:, : target_len - 1]
+ output = torch.cat([cls, output], dim=1)
+ else:
+ output = output[:, :target_len]
+ return output
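+
+ # Illustrative sketch (not part of this module), assuming separate_cls=True and truncate_seq=False:
+ # upsampling hidden states [cls, a, b, c] (along the sequence dimension) with stride=2 and
+ # target_len=7 repeats every token after <cls> twice, giving [cls, a, a, b, b, c, c].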
+
+
+class FunnelDecoder(nn.Module):
+ def __init__(self, config: FunnelConfig) -> None:
+ super().__init__()
+ self.config = config
+ self.attention_structure = FunnelAttentionStructure(config)
+ self.layers = nn.ModuleList([FunnelLayer(config, 0) for _ in range(config.num_decoder_layers)])
+
+ def forward(
+ self,
+ final_hidden: torch.Tensor,
+ first_block_hidden: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ return_dict: bool = True,
+ ) -> Union[tuple, BaseModelOutput]:
+ upsampled_hidden = upsample(
+ final_hidden,
+ stride=2 ** (len(self.config.block_sizes) - 1),
+ target_len=first_block_hidden.shape[1],
+ separate_cls=self.config.separate_cls,
+ truncate_seq=self.config.truncate_seq,
+ )
+
+ hidden = upsampled_hidden + first_block_hidden
+ all_hidden_states = (hidden,) if output_hidden_states else None
+ all_attentions = () if output_attentions else None
+
+ attention_inputs = self.attention_structure.init_attention_inputs(
+ hidden,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ )
+
+ for layer in self.layers:
+ layer_output = layer(hidden, hidden, hidden, attention_inputs, output_attentions=output_attentions)
+ hidden = layer_output[0]
+
+ if output_attentions:
+ all_attentions = all_attentions + layer_output[1:]
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden, all_hidden_states, all_attentions] if v is not None)
+ return BaseModelOutput(last_hidden_state=hidden, hidden_states=all_hidden_states, attentions=all_attentions)
+
+
+class FunnelDiscriminatorPredictions(nn.Module):
+ """Prediction module for the discriminator, made up of two dense layers."""
+
+ def __init__(self, config: FunnelConfig) -> None:
+ super().__init__()
+ self.config = config
+ self.dense = nn.Linear(config.d_model, config.d_model)
+ self.dense_prediction = nn.Linear(config.d_model, 1)
+
+ def forward(self, discriminator_hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(discriminator_hidden_states)
+ hidden_states = ACT2FN[self.config.hidden_act](hidden_states)
+ logits = self.dense_prediction(hidden_states).squeeze(-1)
+ return logits
+
+
+@auto_docstring
+class FunnelPreTrainedModel(PreTrainedModel):
+ config: FunnelConfig
+ load_tf_weights = load_tf_weights_in_funnel
+ base_model_prefix = "funnel"
+
+ def _init_weights(self, module):
+ classname = module.__class__.__name__
+ if classname.find("Linear") != -1:
+ if getattr(module, "weight", None) is not None:
+ if self.config.initializer_std is None:
+ fan_out, fan_in = module.weight.shape
+ std = np.sqrt(1.0 / float(fan_in + fan_out))
+ else:
+ std = self.config.initializer_std
+ nn.init.normal_(module.weight, std=std)
+ if getattr(module, "bias", None) is not None:
+ nn.init.constant_(module.bias, 0.0)
+ elif classname == "FunnelRelMultiheadAttention":
+ nn.init.uniform_(module.r_w_bias, b=self.config.initializer_range)
+ nn.init.uniform_(module.r_r_bias, b=self.config.initializer_range)
+ nn.init.uniform_(module.r_kernel, b=self.config.initializer_range)
+ nn.init.uniform_(module.r_s_bias, b=self.config.initializer_range)
+ nn.init.uniform_(module.seg_embed, b=self.config.initializer_range)
+ elif classname == "FunnelEmbeddings":
+ std = 1.0 if self.config.initializer_std is None else self.config.initializer_std
+ nn.init.normal_(module.word_embeddings.weight, std=std)
+ if module.word_embeddings.padding_idx is not None:
+ module.word_embeddings.weight.data[module.word_embeddings.padding_idx].zero_()
+
+
+class FunnelClassificationHead(nn.Module):
+ def __init__(self, config: FunnelConfig, n_labels: int) -> None:
+ super().__init__()
+ self.linear_hidden = nn.Linear(config.d_model, config.d_model)
+ self.dropout = nn.Dropout(config.hidden_dropout)
+ self.linear_out = nn.Linear(config.d_model, n_labels)
+
+ def forward(self, hidden: torch.Tensor) -> torch.Tensor:
+ hidden = self.linear_hidden(hidden)
+ hidden = torch.tanh(hidden)
+ hidden = self.dropout(hidden)
+ return self.linear_out(hidden)
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Output type of [`FunnelForPreTraining`].
+ """
+)
+class FunnelForPreTrainingOutput(ModelOutput):
+ r"""
+ loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
+ Total loss of the ELECTRA-style objective.
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
+ Prediction scores of the head (scores for each token before SoftMax).
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: Optional[torch.FloatTensor] = None
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
+ attentions: Optional[tuple[torch.FloatTensor]] = None
+
+
+@auto_docstring(
+ custom_intro="""
+ The base Funnel Transformer Model transformer outputting raw hidden-states without upsampling head (also called
+ decoder) or any task-specific head on top.
+ """
+)
+class FunnelBaseModel(FunnelPreTrainedModel):
+ def __init__(self, config: FunnelConfig) -> None:
+ super().__init__(config)
+
+ self.embeddings = FunnelEmbeddings(config)
+ self.encoder = FunnelEncoder(config)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self) -> nn.Embedding:
+ return self.embeddings.word_embeddings
+
+ def set_input_embeddings(self, new_embeddings: nn.Embedding) -> None:
+ self.embeddings.word_embeddings = new_embeddings
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, BaseModelOutput]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
+ input_shape = input_ids.size()
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+
+ if attention_mask is None:
+ attention_mask = torch.ones(input_shape, device=device)
+ if token_type_ids is None:
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
+
+ # TODO: deal with head_mask
+ inputs_embeds = self.embeddings(input_ids, inputs_embeds=inputs_embeds)
+
+ encoder_outputs = self.encoder(
+ inputs_embeds,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ return encoder_outputs
+
+
+@auto_docstring
+class FunnelModel(FunnelPreTrainedModel):
+ def __init__(self, config: FunnelConfig) -> None:
+ super().__init__(config)
+ self.config = config
+ self.embeddings = FunnelEmbeddings(config)
+ self.encoder = FunnelEncoder(config)
+ self.decoder = FunnelDecoder(config)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self) -> nn.Embedding:
+ return self.embeddings.word_embeddings
+
+ def set_input_embeddings(self, new_embeddings: nn.Embedding) -> None:
+ self.embeddings.word_embeddings = new_embeddings
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, BaseModelOutput]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
+ input_shape = input_ids.size()
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+
+ if attention_mask is None:
+ attention_mask = torch.ones(input_shape, device=device)
+ if token_type_ids is None:
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
+
+ # TODO: deal with head_mask
+ inputs_embeds = self.embeddings(input_ids, inputs_embeds=inputs_embeds)
+
+ encoder_outputs = self.encoder(
+ inputs_embeds,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ output_attentions=output_attentions,
+ output_hidden_states=True,
+ return_dict=return_dict,
+ )
+
+ decoder_outputs = self.decoder(
+ final_hidden=encoder_outputs[0],
+ first_block_hidden=encoder_outputs[1][self.config.block_sizes[0]],
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ if not return_dict:
+ idx = 0
+ outputs = (decoder_outputs[0],)
+ if output_hidden_states:
+ idx += 1
+ outputs = outputs + (encoder_outputs[1] + decoder_outputs[idx],)
+ if output_attentions:
+ idx += 1
+ outputs = outputs + (encoder_outputs[2] + decoder_outputs[idx],)
+ return outputs
+
+ return BaseModelOutput(
+ last_hidden_state=decoder_outputs[0],
+ hidden_states=(encoder_outputs.hidden_states + decoder_outputs.hidden_states)
+ if output_hidden_states
+ else None,
+ attentions=(encoder_outputs.attentions + decoder_outputs.attentions) if output_attentions else None,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ Funnel Transformer model with a binary classification head on top as used during pretraining for identifying
+ generated tokens.
+ """
+)
+class FunnelForPreTraining(FunnelPreTrainedModel):
+ def __init__(self, config: FunnelConfig) -> None:
+ super().__init__(config)
+
+ self.funnel = FunnelModel(config)
+ self.discriminator_predictions = FunnelDiscriminatorPredictions(config)
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, FunnelForPreTrainingOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the ELECTRA-style loss. Input should be a sequence of tokens (see `input_ids`
+ docstring). Indices should be in `[0, 1]`:
+
+ - 0 indicates the token is an original token,
+ - 1 indicates the token was replaced.
+
+ Examples:
+
+ ```python
+ >>> from transformers import AutoTokenizer, FunnelForPreTraining
+ >>> import torch
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("funnel-transformer/small")
+ >>> model = FunnelForPreTraining.from_pretrained("funnel-transformer/small")
+
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
+ >>> logits = model(**inputs).logits
+ ```"""
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ discriminator_hidden_states = self.funnel(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ discriminator_sequence_output = discriminator_hidden_states[0]
+
+ logits = self.discriminator_predictions(discriminator_sequence_output)
+
+ loss = None
+ if labels is not None:
+ loss_fct = nn.BCEWithLogitsLoss()
+ if attention_mask is not None:
+ active_loss = attention_mask.view(-1, discriminator_sequence_output.shape[1]) == 1
+ active_logits = logits.view(-1, discriminator_sequence_output.shape[1])[active_loss]
+ active_labels = labels[active_loss]
+ loss = loss_fct(active_logits, active_labels.float())
+ else:
+ loss = loss_fct(logits.view(-1, discriminator_sequence_output.shape[1]), labels.float())
+
+ if not return_dict:
+ output = (logits,) + discriminator_hidden_states[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return FunnelForPreTrainingOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=discriminator_hidden_states.hidden_states,
+ attentions=discriminator_hidden_states.attentions,
+ )
+
+
+@auto_docstring
+class FunnelForMaskedLM(FunnelPreTrainedModel):
+ _tied_weights_keys = ["lm_head.weight"]
+
+ def __init__(self, config: FunnelConfig) -> None:
+ super().__init__(config)
+
+ self.funnel = FunnelModel(config)
+ self.lm_head = nn.Linear(config.d_model, config.vocab_size)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_output_embeddings(self) -> nn.Linear:
+ return self.lm_head
+
+ def set_output_embeddings(self, new_embeddings: nn.Embedding) -> None:
+ self.lm_head = new_embeddings
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, MaskedLMOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
+ config.vocab_size]` (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked); the
+ loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.funnel(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ last_hidden_state = outputs[0]
+ prediction_logits = self.lm_head(last_hidden_state)
+
+ masked_lm_loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss() # -100 index = padding token
+ masked_lm_loss = loss_fct(prediction_logits.view(-1, self.config.vocab_size), labels.view(-1))
+
+ if not return_dict:
+ output = (prediction_logits,) + outputs[1:]
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
+
+ return MaskedLMOutput(
+ loss=masked_lm_loss,
+ logits=prediction_logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
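+
+ # Illustrative sketch (not part of this module), mirroring the example style used for
+ # FunnelForPreTraining above; "funnel-transformer/small" is assumed as the checkpoint.
+ # >>> from transformers import AutoTokenizer, FunnelForMaskedLM
+ # >>> tokenizer = AutoTokenizer.from_pretrained("funnel-transformer/small")
+ # >>> model = FunnelForMaskedLM.from_pretrained("funnel-transformer/small")
+ # >>> inputs = tokenizer(f"The capital of France is {tokenizer.mask_token}.", return_tensors="pt")
+ # >>> logits = model(**inputs).logits  # shape (batch_size, sequence_length, vocab_size)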
+
+
+@auto_docstring(
+ custom_intro="""
+ Funnel Transformer Model with a sequence classification/regression head on top (two linear layers on top of the
+ first timestep of the last hidden state) e.g. for GLUE tasks.
+ """
+)
+class FunnelForSequenceClassification(FunnelPreTrainedModel):
+ def __init__(self, config: FunnelConfig) -> None:
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.config = config
+
+ self.funnel = FunnelBaseModel(config)
+ self.classifier = FunnelClassificationHead(config, config.num_labels)
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, SequenceClassifierOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss). If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.funnel(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ last_hidden_state = outputs[0]
+ pooled_output = last_hidden_state[:, 0]
+ logits = self.classifier(pooled_output)
+
+ loss = None
+ if labels is not None:
+ if self.config.problem_type is None:
+ if self.num_labels == 1:
+ self.config.problem_type = "regression"
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+ self.config.problem_type = "single_label_classification"
+ else:
+ self.config.problem_type = "multi_label_classification"
+
+ if self.config.problem_type == "regression":
+ loss_fct = MSELoss()
+ if self.num_labels == 1:
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
+ else:
+ loss = loss_fct(logits, labels)
+ elif self.config.problem_type == "single_label_classification":
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+ elif self.config.problem_type == "multi_label_classification":
+ loss_fct = BCEWithLogitsLoss()
+ loss = loss_fct(logits, labels)
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return SequenceClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
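+
+# Illustrative sketch (not part of the library): when `config.problem_type` is unset, the forward
+# pass above infers it from `num_labels` and the dtype of `labels`; this helper (an assumed name)
+# simply restates that decision rule.
+def _example_infer_problem_type(num_labels, labels):
+    if num_labels == 1:
+        return "regression"
+    if labels.dtype in (torch.long, torch.int):
+        return "single_label_classification"
+    return "multi_label_classification"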
+
+
+@auto_docstring
+class FunnelForMultipleChoice(FunnelPreTrainedModel):
+ def __init__(self, config: FunnelConfig) -> None:
+ super().__init__(config)
+
+ self.funnel = FunnelBaseModel(config)
+ self.classifier = FunnelClassificationHead(config, 1)
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, MultipleChoiceModelOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
+ num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
+ `input_ids` above)
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
+
+ input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
+ attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
+ token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
+ inputs_embeds = (
+ inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
+ if inputs_embeds is not None
+ else None
+ )
+
+ outputs = self.funnel(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ last_hidden_state = outputs[0]
+ pooled_output = last_hidden_state[:, 0]
+ logits = self.classifier(pooled_output)
+ reshaped_logits = logits.view(-1, num_choices)
+
+ loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(reshaped_logits, labels)
+
+ if not return_dict:
+ output = (reshaped_logits,) + outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return MultipleChoiceModelOutput(
+ loss=loss,
+ logits=reshaped_logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
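+
+# Illustrative sketch (not part of the library): multiple-choice inputs arrive as
+# (batch_size, num_choices, seq_len); the forward pass above flattens the choice axis before the
+# encoder and folds the per-choice scores back at the end. Helper name and shapes are assumptions.
+def _example_multiple_choice_reshape(input_ids):
+    num_choices = input_ids.shape[1]
+    flat_input_ids = input_ids.view(-1, input_ids.size(-1))      # (batch_size * num_choices, seq_len)
+    per_choice_scores = torch.zeros(flat_input_ids.size(0), 1)   # stand-in for the classifier output
+    return per_choice_scores.view(-1, num_choices)               # (batch_size, num_choices)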
+
+
+@auto_docstring
+class FunnelForTokenClassification(FunnelPreTrainedModel):
+ def __init__(self, config: FunnelConfig) -> None:
+ super().__init__(config)
+ self.num_labels = config.num_labels
+
+ self.funnel = FunnelModel(config)
+ self.dropout = nn.Dropout(config.hidden_dropout)
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, TokenClassifierOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.funnel(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ last_hidden_state = outputs[0]
+ last_hidden_state = self.dropout(last_hidden_state)
+ logits = self.classifier(last_hidden_state)
+
+ loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return TokenClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
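+
+# Illustrative sketch (not part of the library): token-classification labels hold one class id per
+# token; logits and labels are flattened together before the cross-entropy above. The helper name
+# is an assumption.
+def _example_token_classification_loss(logits, labels, num_labels):
+    # logits: (batch_size, seq_len, num_labels); labels: (batch_size, seq_len) with ids in [0, num_labels - 1]
+    return CrossEntropyLoss()(logits.view(-1, num_labels), labels.view(-1))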
+
+
+@auto_docstring
+class FunnelForQuestionAnswering(FunnelPreTrainedModel):
+ def __init__(self, config: FunnelConfig) -> None:
+ super().__init__(config)
+ self.num_labels = config.num_labels
+
+ self.funnel = FunnelModel(config)
+ self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ start_positions: Optional[torch.Tensor] = None,
+ end_positions: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, QuestionAnsweringModelOutput]:
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.funnel(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ last_hidden_state = outputs[0]
+
+ logits = self.qa_outputs(last_hidden_state)
+ start_logits, end_logits = logits.split(1, dim=-1)
+ start_logits = start_logits.squeeze(-1).contiguous()
+ end_logits = end_logits.squeeze(-1).contiguous()
+
+ total_loss = None
+ if start_positions is not None and end_positions is not None:
+ # If we are on multi-GPU, split adds an extra dimension, so squeeze it away
+ if len(start_positions.size()) > 1:
+ start_positions = start_positions.squeeze(-1)
+ if len(end_positions.size()) > 1:
+ end_positions = end_positions.squeeze(-1)
+ # sometimes the start/end positions are outside our model inputs, so we ignore these terms
+ ignored_index = start_logits.size(1)
+ start_positions = start_positions.clamp(0, ignored_index)
+ end_positions = end_positions.clamp(0, ignored_index)
+
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
+ start_loss = loss_fct(start_logits, start_positions)
+ end_loss = loss_fct(end_logits, end_positions)
+ total_loss = (start_loss + end_loss) / 2
+
+ if not return_dict:
+ output = (start_logits, end_logits) + outputs[1:]
+ return ((total_loss,) + output) if total_loss is not None else output
+
+ return QuestionAnsweringModelOutput(
+ loss=total_loss,
+ start_logits=start_logits,
+ end_logits=end_logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
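+
+# Illustrative sketch (not part of the library): answer positions that fall outside the model input
+# are clamped to the sequence length, and that same value is passed as `ignore_index`, so
+# out-of-range annotations contribute nothing to the loss. Helper name is an assumption.
+def _example_qa_span_loss(start_logits, start_positions):
+    ignored_index = start_logits.size(1)                      # one past the last valid position
+    start_positions = start_positions.clamp(0, ignored_index)
+    return CrossEntropyLoss(ignore_index=ignored_index)(start_logits, start_positions)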
+
+
+__all__ = [
+ "FunnelBaseModel",
+ "FunnelForMaskedLM",
+ "FunnelForMultipleChoice",
+ "FunnelForPreTraining",
+ "FunnelForQuestionAnswering",
+ "FunnelForSequenceClassification",
+ "FunnelForTokenClassification",
+ "FunnelModel",
+ "FunnelPreTrainedModel",
+ "load_tf_weights_in_funnel",
+]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/funnel/modeling_tf_funnel.py b/venv/lib/python3.13/site-packages/transformers/models/funnel/modeling_tf_funnel.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d57fa99eaa14a38214da8200b0af768aef9ddaf
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/funnel/modeling_tf_funnel.py
@@ -0,0 +1,1883 @@
+# coding=utf-8
+# Copyright 2020-present Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""TF 2.0 Funnel model."""
+
+from __future__ import annotations
+
+import warnings
+from dataclasses import dataclass
+
+import numpy as np
+import tensorflow as tf
+
+from ...activations_tf import get_tf_activation
+from ...modeling_tf_outputs import (
+ TFBaseModelOutput,
+ TFMaskedLMOutput,
+ TFMultipleChoiceModelOutput,
+ TFQuestionAnsweringModelOutput,
+ TFSequenceClassifierOutput,
+ TFTokenClassifierOutput,
+)
+from ...modeling_tf_utils import (
+ TFMaskedLanguageModelingLoss,
+ TFModelInputType,
+ TFMultipleChoiceLoss,
+ TFPreTrainedModel,
+ TFQuestionAnsweringLoss,
+ TFSequenceClassificationLoss,
+ TFTokenClassificationLoss,
+ get_initializer,
+ keras,
+ keras_serializable,
+ unpack_inputs,
+)
+from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax
+from ...utils import (
+ ModelOutput,
+ add_code_sample_docstrings,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+ replace_return_docstrings,
+)
+from .configuration_funnel import FunnelConfig
+
+
+logger = logging.get_logger(__name__)
+
+_CONFIG_FOR_DOC = "FunnelConfig"
+
+
+INF = 1e6
+
+
+class TFFunnelEmbeddings(keras.layers.Layer):
+ """Construct the embeddings from word, position and token_type embeddings."""
+
+ def __init__(self, config, **kwargs):
+ super().__init__(**kwargs)
+
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.initializer_std = 1.0 if config.initializer_std is None else config.initializer_std
+
+ self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm")
+ self.dropout = keras.layers.Dropout(rate=config.hidden_dropout)
+
+ def build(self, input_shape=None):
+ with tf.name_scope("word_embeddings"):
+ self.weight = self.add_weight(
+ name="weight",
+ shape=[self.config.vocab_size, self.hidden_size],
+ initializer=get_initializer(initializer_range=self.initializer_std),
+ )
+
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "LayerNorm", None) is not None:
+ with tf.name_scope(self.LayerNorm.name):
+ self.LayerNorm.build([None, None, self.config.d_model])
+
+ def call(self, input_ids=None, inputs_embeds=None, training=False):
+ """
+ Applies the embedding lookup to the input tensor.
+
+ Returns:
+ final_embeddings (`tf.Tensor`): output embedding tensor.
+ """
+ assert not (input_ids is None and inputs_embeds is None)
+ assert not (input_ids is not None and inputs_embeds is not None)
+
+ if input_ids is not None:
+ check_embeddings_within_bounds(input_ids, self.config.vocab_size)
+ inputs_embeds = tf.gather(self.weight, input_ids)
+
+ final_embeddings = self.LayerNorm(inputs=inputs_embeds)
+ final_embeddings = self.dropout(inputs=final_embeddings, training=training)
+
+ return final_embeddings
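+
+# Illustrative sketch (not part of the library): the embedding layer above is a plain gather into
+# the learned weight matrix, followed by LayerNorm and dropout; the gather step on its own looks
+# like this (helper name assumed).
+def _example_embedding_lookup(weight, input_ids):
+    # weight: (vocab_size, hidden_size); input_ids: (batch_size, seq_len) integer ids
+    return tf.gather(weight, input_ids)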
+
+
+class TFFunnelAttentionStructure:
+ """
+ Contains helpers for `TFFunnelRelMultiheadAttention`.
+ """
+
+ cls_token_type_id: int = 2
+
+ def __init__(self, config):
+ self.d_model = config.d_model
+ self.attention_type = config.attention_type
+ self.num_blocks = config.num_blocks
+ self.separate_cls = config.separate_cls
+ self.truncate_seq = config.truncate_seq
+ self.pool_q_only = config.pool_q_only
+ self.pooling_type = config.pooling_type
+
+ self.sin_dropout = keras.layers.Dropout(config.hidden_dropout)
+ self.cos_dropout = keras.layers.Dropout(config.hidden_dropout)
+ # Track where we are at in terms of pooling from the original input, e.g., by how much the sequence length was
+ # divided.
+ self.pooling_mult = None
+
+ def init_attention_inputs(self, inputs_embeds, attention_mask=None, token_type_ids=None, training=False):
+ """Returns the attention inputs associated to the inputs of the model."""
+ # inputs_embeds has shape batch_size x seq_len x d_model
+ # attention_mask and token_type_ids have shape batch_size x seq_len
+ self.pooling_mult = 1
+ self.seq_len = seq_len = shape_list(inputs_embeds)[1]
+ position_embeds = self.get_position_embeds(seq_len, training=training)
+ token_type_mat = self.token_type_ids_to_mat(token_type_ids) if token_type_ids is not None else None
+ cls_mask = (
+ tf.pad(tf.ones([seq_len - 1, seq_len - 1], dtype=inputs_embeds.dtype), [[1, 0], [1, 0]])
+ if self.separate_cls
+ else None
+ )
+ return (position_embeds, token_type_mat, attention_mask, cls_mask)
+
+ def token_type_ids_to_mat(self, token_type_ids):
+ """Convert `token_type_ids` to `token_type_mat`."""
+ token_type_mat = tf.equal(tf.expand_dims(token_type_ids, -1), tf.expand_dims(token_type_ids, -2))
+ # Treat the cls token as being in the same segment as both A & B
+ cls_ids = tf.equal(token_type_ids, tf.constant([self.cls_token_type_id], dtype=token_type_ids.dtype))
+ cls_mat = tf.logical_or(tf.expand_dims(cls_ids, -1), tf.expand_dims(cls_ids, -2))
+ return tf.logical_or(cls_mat, token_type_mat)
+
+ def get_position_embeds(self, seq_len, training=False):
+ """
+ Create and cache inputs related to relative position encoding. Those are very different depending on whether we
+ are using the factorized or the relative shift attention:
+
+ For the factorized attention, it returns the matrices (phi, pi, psi, omega) used in the paper, appendix A.2.2,
+ final formula.
+
+ For the relative shift attention, it returns all possible vectors R used in the paper, appendix A.2.1, final
+ formula.
+
+ Paper link: https://huggingface.co/papers/2006.03236
+ """
+ if self.attention_type == "factorized":
+ # Notations from the paper, appendix A.2.2, final formula.
+ # We need to create and return the matrices phi, psi, pi and omega.
+ pos_seq = tf.range(0, seq_len, 1.0)
+ freq_seq = tf.range(0, self.d_model // 2, 1.0)
+ inv_freq = 1 / (10000 ** (freq_seq / (self.d_model // 2)))
+ sinusoid = tf.einsum("i,d->id", pos_seq, inv_freq)
+
+ sin_embed = tf.sin(sinusoid)
+ sin_embed_d = self.sin_dropout(sin_embed, training=training)
+ cos_embed = tf.cos(sinusoid)
+ cos_embed_d = self.cos_dropout(cos_embed, training=training)
+ # This is different from the formula in the paper...
+ phi = tf.concat([sin_embed_d, sin_embed_d], axis=-1)
+ psi = tf.concat([cos_embed, sin_embed], axis=-1)
+ pi = tf.concat([cos_embed_d, cos_embed_d], axis=-1)
+ omega = tf.concat([-sin_embed, cos_embed], axis=-1)
+ return (phi, pi, psi, omega)
+ else:
+ # Notations from the paper, appendix A.2.1, final formula.
+ # We need to create and return all the possible vectors R for all blocks and shifts.
+ freq_seq = tf.range(0, self.d_model // 2, 1.0)
+ inv_freq = 1 / (10000 ** (freq_seq / (self.d_model // 2)))
+ # Maximum relative positions for the first input
+ rel_pos_id = tf.range(-seq_len * 2, seq_len * 2, 1.0)
+ zero_offset = seq_len * tf.constant(2)
+ sinusoid = tf.einsum("i,d->id", rel_pos_id, inv_freq)
+ sin_embed = self.sin_dropout(tf.sin(sinusoid), training=training)
+ cos_embed = self.cos_dropout(tf.cos(sinusoid), training=training)
+ pos_embed = tf.concat([sin_embed, cos_embed], axis=-1)
+
+ pos = tf.range(0, seq_len)
+ pooled_pos = pos
+ position_embeds_list = []
+ for block_index in range(0, self.num_blocks):
+ # For each block with block_index > 0, we need two types of position embeddings:
+ # - Attention(pooled-q, unpooled-kv)
+ # - Attention(pooled-q, pooled-kv)
+ # For block_index = 0 we only need the second one; the first stays a placeholder.
+
+ # First type
+ position_embeds_pooling = tf.fill([1], value=-1.0)
+
+ if block_index != 0:
+ pooled_pos = self.stride_pool_pos(pos, block_index)
+
+ # construct rel_pos_id
+ stride = 2 ** (block_index - 1)
+ rel_pos = self.relative_pos(pos, stride, pooled_pos, shift=2)
+ # rel_pos = tf.expand_dims(rel_pos,1) + zero_offset
+ # rel_pos = tf.broadcast_to(rel_pos, (rel_pos.shape[0], self.d_model))
+ rel_pos = tf.cast(rel_pos, dtype=zero_offset.dtype)
+ rel_pos = rel_pos + zero_offset
+ position_embeds_pooling = tf.gather(pos_embed, rel_pos, axis=0)
+
+ # Second type
+ pos = pooled_pos
+ stride = 2**block_index
+ rel_pos = self.relative_pos(pos, stride)
+
+ # rel_pos = tf.expand_dims(rel_pos,1) + zero_offset
+ # rel_pos = tf.broadcast_to(rel_pos, (rel_pos.shape[0], self.d_model))
+ rel_pos = tf.cast(rel_pos, dtype=zero_offset.dtype)
+ rel_pos = rel_pos + zero_offset
+ tf.debugging.assert_less(rel_pos, tf.shape(pos_embed)[0])
+ position_embeds_no_pooling = tf.gather(pos_embed, rel_pos, axis=0)
+
+ position_embeds_list.append([position_embeds_no_pooling, position_embeds_pooling])
+ return position_embeds_list
+
+ def stride_pool_pos(self, pos_id, block_index):
+ """
+ Pool `pos_id` while keeping the cls token separate (if `self.separate_cls=True`).
+ """
+ if self.separate_cls:
+ # When `separate_cls` is set, we treat the cls token as the first token in
+ # the previous block of the 1st real block. Since the 1st real
+ # block always has position 1, the position of the previous block
+ # will be at `1 - 2 ** block_index`.
+ cls_pos = tf.constant([-(2**block_index) + 1], dtype=pos_id.dtype)
+ pooled_pos_id = pos_id[1:-1] if self.truncate_seq else pos_id[1:]
+ return tf.concat([cls_pos, pooled_pos_id[::2]], 0)
+ else:
+ return pos_id[::2]
+
+ def relative_pos(self, pos, stride, pooled_pos=None, shift=1):
+ """
+ Build the relative positional vector between `pos` and `pooled_pos`.
+ """
+ if pooled_pos is None:
+ pooled_pos = pos
+
+ ref_point = pooled_pos[0] - pos[0]
+ num_remove = shift * shape_list(pooled_pos)[0]
+ max_dist = ref_point + num_remove * stride
+ min_dist = pooled_pos[0] - pos[-1]
+
+ return tf.range(max_dist, min_dist - 1, -stride)
+
+ def stride_pool(self, tensor, axis):
+ """
+ Perform pooling by stride slicing the tensor along the given axis.
+ """
+ if tensor is None:
+ return None
+
+ # Do the stride pool recursively if axis is a list or a tuple of ints.
+ if isinstance(axis, (list, tuple)):
+ for ax in axis:
+ tensor = self.stride_pool(tensor, ax)
+ return tensor
+
+ # Do the stride pool recursively if tensor is a list or tuple of tensors.
+ if isinstance(tensor, (tuple, list)):
+ return type(tensor)(self.stride_pool(x, axis) for x in tensor)
+
+ # Deal with negative axis
+ axis %= len(shape_list(tensor))
+
+ axis_slice = slice(None, -1, 2) if self.separate_cls and self.truncate_seq else slice(None, None, 2)
+ enc_slice = [slice(None)] * axis + [axis_slice]
+ if self.separate_cls:
+ cls_slice = [slice(None)] * axis + [slice(None, 1)]
+ tensor = tf.concat([tensor[cls_slice], tensor], axis)
+ return tensor[enc_slice]
+
+ def pool_tensor(self, tensor, mode="mean", stride=2):
+ """Apply 1D pooling to a tensor of size [B x T (x H)]."""
+ if tensor is None:
+ return None
+
+ # Do the pool recursively if tensor is a list or tuple of tensors.
+ if isinstance(tensor, (tuple, list)):
+ return type(tensor)(self.pool_tensor(x, mode=mode, stride=stride) for x in tensor)
+
+ if self.separate_cls:
+ suffix = tensor[:, :-1] if self.truncate_seq else tensor
+ tensor = tf.concat([tensor[:, :1], suffix], axis=1)
+
+ ndim = len(shape_list(tensor))
+ if ndim == 2:
+ tensor = tensor[:, :, None]
+
+ if mode == "mean":
+ tensor = tf.nn.avg_pool1d(tensor, stride, strides=stride, data_format="NWC", padding="SAME")
+ elif mode == "max":
+ tensor = tf.nn.max_pool1d(tensor, stride, strides=stride, data_format="NWC", padding="SAME")
+ elif mode == "min":
+ tensor = -tf.nn.max_pool1d(-tensor, stride, strides=stride, data_format="NWC", padding="SAME")
+ else:
+ raise NotImplementedError("The supported modes are 'mean', 'max' and 'min'.")
+
+ return tf.squeeze(tensor, 2) if ndim == 2 else tensor
+
+ def pre_attention_pooling(self, output, attention_inputs):
+ """Pool `output` and the proper parts of `attention_inputs` before the attention layer."""
+ position_embeds, token_type_mat, attention_mask, cls_mask = attention_inputs
+ if self.pool_q_only:
+ if self.attention_type == "factorized":
+ position_embeds = self.stride_pool(position_embeds[:2], 0) + position_embeds[2:]
+ token_type_mat = self.stride_pool(token_type_mat, 1)
+ cls_mask = self.stride_pool(cls_mask, 0)
+ output = self.pool_tensor(output, mode=self.pooling_type)
+ else:
+ self.pooling_mult *= 2
+ if self.attention_type == "factorized":
+ position_embeds = self.stride_pool(position_embeds, 0)
+ token_type_mat = self.stride_pool(token_type_mat, [1, 2])
+ cls_mask = self.stride_pool(cls_mask, [1, 2])
+ attention_mask = self.pool_tensor(attention_mask, mode="min")
+ output = self.pool_tensor(output, mode=self.pooling_type)
+ attention_inputs = (position_embeds, token_type_mat, attention_mask, cls_mask)
+ return output, attention_inputs
+
+ def post_attention_pooling(self, attention_inputs):
+ """Pool the proper parts of `attention_inputs` after the attention layer."""
+ position_embeds, token_type_mat, attention_mask, cls_mask = attention_inputs
+ if self.pool_q_only:
+ self.pooling_mult *= 2
+ if self.attention_type == "factorized":
+ position_embeds = position_embeds[:2] + self.stride_pool(position_embeds[2:], 0)
+ token_type_mat = self.stride_pool(token_type_mat, 2)
+ cls_mask = self.stride_pool(cls_mask, 1)
+ attention_mask = self.pool_tensor(attention_mask, mode="min")
+ attention_inputs = (position_embeds, token_type_mat, attention_mask, cls_mask)
+ return attention_inputs
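+
+# Illustrative sketch (not part of the library): the compression performed by `pool_tensor` above
+# boils down to a strided 1D pool that halves the sequence length; this standalone helper (an
+# assumed name) shows the mean-pooling case.
+def _example_halve_sequence(hidden):
+    # hidden: (batch_size, seq_len, d_model) -> (batch_size, ceil(seq_len / 2), d_model)
+    return tf.nn.avg_pool1d(hidden, ksize=2, strides=2, padding="SAME", data_format="NWC")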
+
+
+def _relative_shift_gather(positional_attn, context_len, shift):
+ batch_size, n_head, seq_len, max_rel_len = shape_list(positional_attn)
+ # max_rel_len = 2 * context_len + shift - 1 is the number of possible relative positions i - j
+
+ # What's next is the same as doing the following gather in PyTorch, which might be clearer code but less efficient.
+ # idxs = context_len + torch.arange(0, context_len).unsqueeze(0) - torch.arange(0, seq_len).unsqueeze(1)
+ # # matrix of context_len + i-j
+ # return positional_attn.gather(3, idxs.expand([batch_size, n_head, context_len, context_len]))
+
+ positional_attn = tf.reshape(positional_attn, [batch_size, n_head, max_rel_len, seq_len])
+ positional_attn = positional_attn[:, :, shift:, :]
+ positional_attn = tf.reshape(positional_attn, [batch_size, n_head, seq_len, max_rel_len - shift])
+ positional_attn = positional_attn[..., :context_len]
+ return positional_attn
+
+
+class TFFunnelRelMultiheadAttention(keras.layers.Layer):
+ def __init__(self, config, block_index, **kwargs):
+ super().__init__(**kwargs)
+ self.attention_type = config.attention_type
+ self.n_head = n_head = config.n_head
+ self.d_head = d_head = config.d_head
+ self.d_model = d_model = config.d_model
+ self.initializer_range = config.initializer_range
+ self.block_index = block_index
+
+ self.hidden_dropout = keras.layers.Dropout(config.hidden_dropout)
+ self.attention_dropout = keras.layers.Dropout(config.attention_dropout)
+
+ initializer = get_initializer(config.initializer_range)
+
+ self.q_head = keras.layers.Dense(
+ n_head * d_head, use_bias=False, kernel_initializer=initializer, name="q_head"
+ )
+ self.k_head = keras.layers.Dense(n_head * d_head, kernel_initializer=initializer, name="k_head")
+ self.v_head = keras.layers.Dense(n_head * d_head, kernel_initializer=initializer, name="v_head")
+
+ self.post_proj = keras.layers.Dense(d_model, kernel_initializer=initializer, name="post_proj")
+ self.layer_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm")
+ self.scale = 1.0 / (d_head**0.5)
+
+ def build(self, input_shape=None):
+ n_head, d_head, d_model = self.n_head, self.d_head, self.d_model
+ initializer = get_initializer(self.initializer_range)
+
+ self.r_w_bias = self.add_weight(
+ shape=(n_head, d_head), initializer=initializer, trainable=True, name="r_w_bias"
+ )
+ self.r_r_bias = self.add_weight(
+ shape=(n_head, d_head), initializer=initializer, trainable=True, name="r_r_bias"
+ )
+ self.r_kernel = self.add_weight(
+ shape=(d_model, n_head, d_head), initializer=initializer, trainable=True, name="r_kernel"
+ )
+ self.r_s_bias = self.add_weight(
+ shape=(n_head, d_head), initializer=initializer, trainable=True, name="r_s_bias"
+ )
+ self.seg_embed = self.add_weight(
+ shape=(2, n_head, d_head), initializer=initializer, trainable=True, name="seg_embed"
+ )
+
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "q_head", None) is not None:
+ with tf.name_scope(self.q_head.name):
+ self.q_head.build([None, None, d_model])
+ if getattr(self, "k_head", None) is not None:
+ with tf.name_scope(self.k_head.name):
+ self.k_head.build([None, None, d_model])
+ if getattr(self, "v_head", None) is not None:
+ with tf.name_scope(self.v_head.name):
+ self.v_head.build([None, None, d_model])
+ if getattr(self, "post_proj", None) is not None:
+ with tf.name_scope(self.post_proj.name):
+ self.post_proj.build([None, None, n_head * d_head])
+ if getattr(self, "layer_norm", None) is not None:
+ with tf.name_scope(self.layer_norm.name):
+ self.layer_norm.build([None, None, d_model])
+
+ def relative_positional_attention(self, position_embeds, q_head, context_len, cls_mask=None):
+ """Relative attention score for the positional encodings"""
+ # q_head has shape batch_size x seq_len x n_head x d_head
+ if self.attention_type == "factorized":
+ # Notations from the paper, appendix A.2.2, final formula (https://huggingface.co/papers/2006.03236)
+ # phi and pi have shape seq_len x d_model, psi and omega have shape context_len x d_model
+ phi, pi, psi, omega = position_embeds
+ # Shape n_head x d_head
+ u = self.r_r_bias * self.scale
+ # Shape d_model x n_head x d_head
+ w_r = self.r_kernel
+
+ # Shape batch_size x seq_len x n_head x d_model
+ q_r_attention = tf.einsum("binh,dnh->bind", q_head + u, w_r)
+ q_r_attention_1 = q_r_attention * phi[:, None]
+ q_r_attention_2 = q_r_attention * pi[:, None]
+
+ # Shape batch_size x n_head x seq_len x context_len
+ positional_attn = tf.einsum("bind,jd->bnij", q_r_attention_1, psi) + tf.einsum(
+ "bind,jd->bnij", q_r_attention_2, omega
+ )
+ else:
+ # Notations from the paper, appendix A.2.1, final formula (https://huggingface.co/papers/2006.03236)
+ # Grab the proper positional encoding, shape max_rel_len x d_model
+ if shape_list(q_head)[1] != context_len:
+ shift = 2
+ r = position_embeds[self.block_index][1]
+ else:
+ shift = 1
+ r = position_embeds[self.block_index][0]
+ # Shape n_head x d_head
+ v = self.r_r_bias * self.scale
+ # Shape d_model x n_head x d_head
+ w_r = self.r_kernel
+
+ # Shape max_rel_len x n_head x d_model
+ r_head = tf.einsum("td,dnh->tnh", r, w_r)
+ # Shape batch_size x n_head x seq_len x max_rel_len
+ positional_attn = tf.einsum("binh,tnh->bnit", q_head + v, r_head)
+ # Shape batch_size x n_head x seq_len x context_len
+ positional_attn = _relative_shift_gather(positional_attn, context_len, shift)
+
+ if cls_mask is not None:
+ positional_attn *= cls_mask
+ return positional_attn
+
+ def relative_token_type_attention(self, token_type_mat, q_head, cls_mask=None):
+ """Relative attention score for the token_type_ids"""
+ if token_type_mat is None:
+ return 0
+ batch_size, seq_len, context_len = shape_list(token_type_mat)
+ # q_head has shape batch_size x seq_len x n_head x d_head
+ # Shape n_head x d_head
+ r_s_bias = self.r_s_bias * self.scale
+
+ # Shape batch_size x n_head x seq_len x 2
+ token_type_bias = tf.einsum("bind,snd->bnis", q_head + r_s_bias, self.seg_embed)
+ # Shape batch_size x n_head x seq_len x context_len
+ token_type_mat = tf.tile(token_type_mat[:, None], [1, shape_list(q_head)[2], 1, 1])
+ # token_type_mat = tf.broadcast_to(token_type_mat[:, None], new_shape)
+ # Shapes batch_size x n_head x seq_len
+ diff_token_type, same_token_type = tf.split(token_type_bias, 2, axis=-1)
+ # Shape batch_size x n_head x seq_len x context_len
+ token_type_attn = tf.where(
+ token_type_mat,
+ tf.tile(same_token_type, [1, 1, 1, context_len]),
+ tf.tile(diff_token_type, [1, 1, 1, context_len]),
+ )
+
+ if cls_mask is not None:
+ token_type_attn *= cls_mask
+ return token_type_attn
+
+ def call(self, query, key, value, attention_inputs, output_attentions=False, training=False):
+ # query has shape batch_size x seq_len x d_model
+ # key and value have shapes batch_size x context_len x d_model
+ position_embeds, token_type_mat, attention_mask, cls_mask = attention_inputs
+
+ batch_size, seq_len, _ = shape_list(query)
+ context_len = shape_list(key)[1]
+ n_head, d_head = self.n_head, self.d_head
+
+ # Shape batch_size x seq_len x n_head x d_head
+ q_head = tf.reshape(self.q_head(query), [batch_size, seq_len, n_head, d_head])
+ # Shapes batch_size x context_len x n_head x d_head
+ k_head = tf.reshape(self.k_head(key), [batch_size, context_len, n_head, d_head])
+ v_head = tf.reshape(self.v_head(value), [batch_size, context_len, n_head, d_head])
+
+ q_head = q_head * self.scale
+ # Shape n_head x d_head
+ r_w_bias = self.r_w_bias * self.scale
+ # Shapes batch_size x n_head x seq_len x context_len
+ content_score = tf.einsum("bind,bjnd->bnij", q_head + r_w_bias, k_head)
+ positional_attn = self.relative_positional_attention(position_embeds, q_head, context_len, cls_mask)
+ token_type_attn = self.relative_token_type_attention(token_type_mat, q_head, cls_mask)
+
+ # merge attention scores
+ attn_score = content_score + positional_attn + token_type_attn
+
+ # perform masking
+ if attention_mask is not None:
+ attention_mask = tf.cast(attention_mask, dtype=attn_score.dtype)
+ attn_score = attn_score - (INF * (1 - attention_mask[:, None, None]))
+
+ # attention probability
+ attn_prob = stable_softmax(attn_score, axis=-1)
+ attn_prob = self.attention_dropout(attn_prob, training=training)
+
+ # attention output, shape batch_size x seq_len x n_head x d_head
+ attn_vec = tf.einsum("bnij,bjnd->bind", attn_prob, v_head)
+
+ # Shape batch_size x seq_len x d_model
+ attn_out = self.post_proj(tf.reshape(attn_vec, [batch_size, seq_len, n_head * d_head]))
+ attn_out = self.hidden_dropout(attn_out, training=training)
+
+ output = self.layer_norm(query + attn_out)
+ return (output, attn_prob) if output_attentions else (output,)
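+
+# Illustrative sketch (not part of the library): padding positions are excluded from the softmax by
+# subtracting the large constant `INF` from their scores, exactly as in the masking step of `call`
+# above. Helper name and shapes are assumptions.
+def _example_masked_softmax(attn_score, attention_mask):
+    # attn_score: (batch, n_head, seq_len, context_len); attention_mask: (batch, context_len) of 0/1
+    mask = tf.cast(attention_mask, attn_score.dtype)
+    return stable_softmax(attn_score - INF * (1 - mask[:, None, None]), axis=-1)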
+
+
+class TFFunnelPositionwiseFFN(keras.layers.Layer):
+ def __init__(self, config, **kwargs):
+ super().__init__(**kwargs)
+ initializer = get_initializer(config.initializer_range)
+ self.linear_1 = keras.layers.Dense(config.d_inner, kernel_initializer=initializer, name="linear_1")
+ self.activation_function = get_tf_activation(config.hidden_act)
+ self.activation_dropout = keras.layers.Dropout(config.activation_dropout)
+ self.linear_2 = keras.layers.Dense(config.d_model, kernel_initializer=initializer, name="linear_2")
+ self.dropout = keras.layers.Dropout(config.hidden_dropout)
+ self.layer_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm")
+ self.config = config
+
+ def call(self, hidden, training=False):
+ h = self.linear_1(hidden)
+ h = self.activation_function(h)
+ h = self.activation_dropout(h, training=training)
+ h = self.linear_2(h)
+ h = self.dropout(h, training=training)
+ return self.layer_norm(hidden + h)
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "linear_1", None) is not None:
+ with tf.name_scope(self.linear_1.name):
+ self.linear_1.build([None, None, self.config.d_model])
+ if getattr(self, "linear_2", None) is not None:
+ with tf.name_scope(self.linear_2.name):
+ self.linear_2.build([None, None, self.config.d_inner])
+ if getattr(self, "layer_norm", None) is not None:
+ with tf.name_scope(self.layer_norm.name):
+ self.layer_norm.build([None, None, self.config.d_model])
+
+
+class TFFunnelLayer(keras.layers.Layer):
+ def __init__(self, config, block_index, **kwargs):
+ super().__init__(**kwargs)
+ self.attention = TFFunnelRelMultiheadAttention(config, block_index, name="attention")
+ self.ffn = TFFunnelPositionwiseFFN(config, name="ffn")
+
+ def call(self, query, key, value, attention_inputs, output_attentions=False, training=False):
+ attn = self.attention(
+ query, key, value, attention_inputs, output_attentions=output_attentions, training=training
+ )
+ output = self.ffn(attn[0], training=training)
+ return (output, attn[1]) if output_attentions else (output,)
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "attention", None) is not None:
+ with tf.name_scope(self.attention.name):
+ self.attention.build(None)
+ if getattr(self, "ffn", None) is not None:
+ with tf.name_scope(self.ffn.name):
+ self.ffn.build(None)
+
+
+class TFFunnelEncoder(keras.layers.Layer):
+ def __init__(self, config, **kwargs):
+ super().__init__(**kwargs)
+ self.separate_cls = config.separate_cls
+ self.pool_q_only = config.pool_q_only
+ self.block_repeats = config.block_repeats
+ self.attention_structure = TFFunnelAttentionStructure(config)
+ self.blocks = [
+ [TFFunnelLayer(config, block_index, name=f"blocks_._{block_index}_._{i}") for i in range(block_size)]
+ for block_index, block_size in enumerate(config.block_sizes)
+ ]
+
+ def call(
+ self,
+ inputs_embeds,
+ attention_mask=None,
+ token_type_ids=None,
+ output_attentions=False,
+ output_hidden_states=False,
+ return_dict=True,
+ training=False,
+ ):
+ # The pooling is not implemented on long tensors, so we convert this mask.
+ # attention_mask = tf.cast(attention_mask, inputs_embeds.dtype)
+ attention_inputs = self.attention_structure.init_attention_inputs(
+ inputs_embeds,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ training=training,
+ )
+ hidden = inputs_embeds
+
+ all_hidden_states = (inputs_embeds,) if output_hidden_states else None
+ all_attentions = () if output_attentions else None
+
+ for block_index, block in enumerate(self.blocks):
+ pooling_flag = shape_list(hidden)[1] > (2 if self.separate_cls else 1)
+ pooling_flag = pooling_flag and block_index > 0
+ pooled_hidden = tf.zeros(shape_list(hidden))
+
+ if pooling_flag:
+ pooled_hidden, attention_inputs = self.attention_structure.pre_attention_pooling(
+ hidden, attention_inputs
+ )
+
+ for layer_index, layer in enumerate(block):
+ for repeat_index in range(self.block_repeats[block_index]):
+ do_pooling = (repeat_index == 0) and (layer_index == 0) and pooling_flag
+ if do_pooling:
+ query = pooled_hidden
+ key = value = hidden if self.pool_q_only else pooled_hidden
+ else:
+ query = key = value = hidden
+ layer_output = layer(
+ query, key, value, attention_inputs, output_attentions=output_attentions, training=training
+ )
+ hidden = layer_output[0]
+ if do_pooling:
+ attention_inputs = self.attention_structure.post_attention_pooling(attention_inputs)
+
+ if output_attentions:
+ all_attentions = all_attentions + layer_output[1:]
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden, all_hidden_states, all_attentions] if v is not None)
+ return TFBaseModelOutput(last_hidden_state=hidden, hidden_states=all_hidden_states, attentions=all_attentions)
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ for block in self.blocks:
+ for layer in block:
+ with tf.name_scope(layer.name):
+ layer.build(None)
+
+
+def upsample(x, stride, target_len, separate_cls=True, truncate_seq=False):
+ """
+ Upsample tensor `x` to match `target_len` by repeating the tokens `stride` time on the sequence length dimension.
+ """
+ if stride == 1:
+ return x
+ if separate_cls:
+ cls = x[:, :1]
+ x = x[:, 1:]
+ output = tf.repeat(x, repeats=stride, axis=1)
+ if separate_cls:
+ if truncate_seq:
+ output = tf.pad(output, [[0, 0], [0, stride - 1], [0, 0]])
+ output = output[:, : target_len - 1]
+ output = tf.concat([cls, output], axis=1)
+ else:
+ output = output[:, :target_len]
+ return output
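+
+# Illustrative sketch (not part of the library): the decoder uses `upsample` to stretch the pooled
+# sequence back to the encoder's first-block length before adding the two together; after two
+# pooling steps the stride is 4. Helper name and sizes are assumptions.
+def _example_upsample_roundtrip(batch_size=2, seq_len=8, d_model=4):
+    pooled = tf.zeros([batch_size, seq_len // 4, d_model])
+    return upsample(pooled, stride=4, target_len=seq_len, separate_cls=False)   # (batch_size, seq_len, d_model)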
+
+
+class TFFunnelDecoder(keras.layers.Layer):
+ def __init__(self, config, **kwargs):
+ super().__init__(**kwargs)
+ self.separate_cls = config.separate_cls
+ self.truncate_seq = config.truncate_seq
+ self.stride = 2 ** (len(config.block_sizes) - 1)
+ self.attention_structure = TFFunnelAttentionStructure(config)
+ self.layers = [TFFunnelLayer(config, 0, name=f"layers_._{i}") for i in range(config.num_decoder_layers)]
+
+ def call(
+ self,
+ final_hidden,
+ first_block_hidden,
+ attention_mask=None,
+ token_type_ids=None,
+ output_attentions=False,
+ output_hidden_states=False,
+ return_dict=True,
+ training=False,
+ ):
+ upsampled_hidden = upsample(
+ final_hidden,
+ stride=self.stride,
+ target_len=shape_list(first_block_hidden)[1],
+ separate_cls=self.separate_cls,
+ truncate_seq=self.truncate_seq,
+ )
+
+ hidden = upsampled_hidden + first_block_hidden
+ all_hidden_states = (hidden,) if output_hidden_states else None
+ all_attentions = () if output_attentions else None
+
+ attention_inputs = self.attention_structure.init_attention_inputs(
+ hidden,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ training=training,
+ )
+
+ for layer in self.layers:
+ layer_output = layer(
+ hidden, hidden, hidden, attention_inputs, output_attentions=output_attentions, training=training
+ )
+ hidden = layer_output[0]
+
+ if output_attentions:
+ all_attentions = all_attentions + layer_output[1:]
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden, all_hidden_states, all_attentions] if v is not None)
+ return TFBaseModelOutput(last_hidden_state=hidden, hidden_states=all_hidden_states, attentions=all_attentions)
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "layers", None) is not None:
+ for layer in self.layers:
+ with tf.name_scope(layer.name):
+ layer.build(None)
+
+
+@keras_serializable
+class TFFunnelBaseLayer(keras.layers.Layer):
+ """Base model without decoder"""
+
+ config_class = FunnelConfig
+
+ def __init__(self, config, **kwargs):
+ super().__init__(**kwargs)
+
+ self.config = config
+ self.output_attentions = config.output_attentions
+ self.output_hidden_states = config.output_hidden_states
+ self.return_dict = config.use_return_dict
+
+ self.embeddings = TFFunnelEmbeddings(config, name="embeddings")
+ self.encoder = TFFunnelEncoder(config, name="encoder")
+
+ def get_input_embeddings(self):
+ return self.embeddings
+
+ def set_input_embeddings(self, value):
+ self.embeddings.weight = value
+ self.embeddings.vocab_size = shape_list(value)[0]
+
+ def _prune_heads(self, heads_to_prune):
+ raise NotImplementedError # Not implemented yet in the library for TF 2.0 models
+
+ @unpack_inputs
+ def call(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ token_type_ids=None,
+ inputs_embeds=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ training=False,
+ ):
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ input_shape = shape_list(input_ids)
+ elif inputs_embeds is not None:
+ input_shape = shape_list(inputs_embeds)[:-1]
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ if attention_mask is None:
+ attention_mask = tf.fill(input_shape, 1)
+
+ if token_type_ids is None:
+ token_type_ids = tf.fill(input_shape, 0)
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embeddings(input_ids, training=training)
+
+ encoder_outputs = self.encoder(
+ inputs_embeds,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+
+ return encoder_outputs
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "embeddings", None) is not None:
+ with tf.name_scope(self.embeddings.name):
+ self.embeddings.build(None)
+ if getattr(self, "encoder", None) is not None:
+ with tf.name_scope(self.encoder.name):
+ self.encoder.build(None)
+
+
+@keras_serializable
+class TFFunnelMainLayer(keras.layers.Layer):
+ """Base model with decoder"""
+
+ config_class = FunnelConfig
+
+ def __init__(self, config, **kwargs):
+ super().__init__(**kwargs)
+
+ self.config = config
+ self.block_sizes = config.block_sizes
+ self.output_attentions = config.output_attentions
+ self.output_hidden_states = config.output_hidden_states
+ self.return_dict = config.use_return_dict
+
+ self.embeddings = TFFunnelEmbeddings(config, name="embeddings")
+ self.encoder = TFFunnelEncoder(config, name="encoder")
+ self.decoder = TFFunnelDecoder(config, name="decoder")
+
+ def get_input_embeddings(self):
+ return self.embeddings
+
+ def set_input_embeddings(self, value):
+ self.embeddings.weight = value
+ self.embeddings.vocab_size = shape_list(value)[0]
+
+ def _prune_heads(self, heads_to_prune):
+ raise NotImplementedError # Not implemented yet in the library for TF 2.0 models
+
+ @unpack_inputs
+ def call(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ token_type_ids=None,
+ inputs_embeds=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ training=False,
+ ):
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ input_shape = shape_list(input_ids)
+ elif inputs_embeds is not None:
+ input_shape = shape_list(inputs_embeds)[:-1]
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ if attention_mask is None:
+ attention_mask = tf.fill(input_shape, 1)
+
+ if token_type_ids is None:
+ token_type_ids = tf.fill(input_shape, 0)
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embeddings(input_ids, training=training)
+
+ encoder_outputs = self.encoder(
+ inputs_embeds,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ output_attentions=output_attentions,
+ output_hidden_states=True,
+ return_dict=return_dict,
+ training=training,
+ )
+
+ decoder_outputs = self.decoder(
+ final_hidden=encoder_outputs[0],
+ first_block_hidden=encoder_outputs[1][self.block_sizes[0]],
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+
+ if not return_dict:
+ idx = 0
+ outputs = (decoder_outputs[0],)
+ if output_hidden_states:
+ idx += 1
+ outputs = outputs + (encoder_outputs[1] + decoder_outputs[idx],)
+ if output_attentions:
+ idx += 1
+ outputs = outputs + (encoder_outputs[2] + decoder_outputs[idx],)
+ return outputs
+
+ return TFBaseModelOutput(
+ last_hidden_state=decoder_outputs[0],
+ hidden_states=(encoder_outputs.hidden_states + decoder_outputs.hidden_states)
+ if output_hidden_states
+ else None,
+ attentions=(encoder_outputs.attentions + decoder_outputs.attentions) if output_attentions else None,
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "embeddings", None) is not None:
+ with tf.name_scope(self.embeddings.name):
+ self.embeddings.build(None)
+ if getattr(self, "encoder", None) is not None:
+ with tf.name_scope(self.encoder.name):
+ self.encoder.build(None)
+ if getattr(self, "decoder", None) is not None:
+ with tf.name_scope(self.decoder.name):
+ self.decoder.build(None)
+
+
+class TFFunnelDiscriminatorPredictions(keras.layers.Layer):
+ """Prediction module for the discriminator, made up of two dense layers."""
+
+ def __init__(self, config, **kwargs):
+ super().__init__(**kwargs)
+ initializer = get_initializer(config.initializer_range)
+ self.dense = keras.layers.Dense(config.d_model, kernel_initializer=initializer, name="dense")
+ self.activation_function = get_tf_activation(config.hidden_act)
+ self.dense_prediction = keras.layers.Dense(1, kernel_initializer=initializer, name="dense_prediction")
+ self.config = config
+
+ def call(self, discriminator_hidden_states):
+ hidden_states = self.dense(discriminator_hidden_states)
+ hidden_states = self.activation_function(hidden_states)
+ logits = tf.squeeze(self.dense_prediction(hidden_states))
+ return logits
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "dense", None) is not None:
+ with tf.name_scope(self.dense.name):
+ self.dense.build([None, None, self.config.d_model])
+ if getattr(self, "dense_prediction", None) is not None:
+ with tf.name_scope(self.dense_prediction.name):
+ self.dense_prediction.build([None, None, self.config.d_model])
+
+
+class TFFunnelMaskedLMHead(keras.layers.Layer):
+ def __init__(self, config, input_embeddings, **kwargs):
+ super().__init__(**kwargs)
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.input_embeddings = input_embeddings
+
+ def build(self, input_shape):
+ self.bias = self.add_weight(shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="bias")
+
+ super().build(input_shape)
+
+ def get_output_embeddings(self):
+ return self.input_embeddings
+
+ def set_output_embeddings(self, value):
+ self.input_embeddings.weight = value
+ self.input_embeddings.vocab_size = shape_list(value)[0]
+
+ def get_bias(self):
+ return {"bias": self.bias}
+
+ def set_bias(self, value):
+ self.bias = value["bias"]
+ self.config.vocab_size = shape_list(value["bias"])[0]
+
+ def call(self, hidden_states, training=False):
+ seq_length = shape_list(tensor=hidden_states)[1]
+ hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, self.hidden_size])
+ hidden_states = tf.matmul(a=hidden_states, b=self.input_embeddings.weight, transpose_b=True)
+ hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, seq_length, self.config.vocab_size])
+ hidden_states = tf.nn.bias_add(value=hidden_states, bias=self.bias)
+
+ return hidden_states
+
+
+class TFFunnelClassificationHead(keras.layers.Layer):
+ def __init__(self, config, n_labels, **kwargs):
+ super().__init__(**kwargs)
+ initializer = get_initializer(config.initializer_range)
+ self.linear_hidden = keras.layers.Dense(config.d_model, kernel_initializer=initializer, name="linear_hidden")
+ self.dropout = keras.layers.Dropout(config.hidden_dropout)
+ self.linear_out = keras.layers.Dense(n_labels, kernel_initializer=initializer, name="linear_out")
+ self.config = config
+
+ def call(self, hidden, training=False):
+ hidden = self.linear_hidden(hidden)
+ hidden = keras.activations.tanh(hidden)
+ hidden = self.dropout(hidden, training=training)
+ return self.linear_out(hidden)
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "linear_hidden", None) is not None:
+ with tf.name_scope(self.linear_hidden.name):
+ self.linear_hidden.build([None, None, self.config.d_model])
+ if getattr(self, "linear_out", None) is not None:
+ with tf.name_scope(self.linear_out.name):
+ self.linear_out.build([None, None, self.config.d_model])
+
+
+class TFFunnelPreTrainedModel(TFPreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = FunnelConfig
+ base_model_prefix = "funnel"
+
+ @property
+ def dummy_inputs(self):
+ # Funnel misbehaves with very small inputs, so we override and make them a bit bigger
+ return {"input_ids": tf.ones((1, 3), dtype=tf.int32)}
+
+
+@dataclass
+class TFFunnelForPreTrainingOutput(ModelOutput):
+ """
+ Output type of [`FunnelForPreTraining`].
+
+ Args:
+ logits (`tf.Tensor` of shape `(batch_size, sequence_length)`):
+ Prediction scores of the head (scores for each token before SoftMax).
+ hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
+ `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attention weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ logits: tf.Tensor | None = None
+ hidden_states: tuple[tf.Tensor] | None = None
+ attentions: tuple[tf.Tensor] | None = None
+
+
+FUNNEL_START_DOCSTRING = r"""
+
+ The Funnel Transformer model was proposed in [Funnel-Transformer: Filtering out Sequential Redundancy for Efficient
+ Language Processing](https://huggingface.co/papers/2006.03236) by Zihang Dai, Guokun Lai, Yiming Yang, Quoc V. Le.
+
+ This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
+ etc.)
+
+ This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
+ as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and
+ behavior.
+
+
+
+ TensorFlow models and layers in `transformers` accept two formats as input:
+
+ - having all inputs as keyword arguments (like PyTorch models), or
+ - having all inputs as a list, tuple or dict in the first positional argument.
+
+ The reason the second format is supported is that Keras methods prefer this format when passing inputs to models
+ and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just
+ pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second
+ format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with
+ the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first
+ positional argument:
+
+ - a single Tensor with `input_ids` only and nothing else: `model(input_ids)`
+ - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
+ `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])`
+ - a dictionary with one or several input Tensors associated to the input names given in the docstring:
+ `model({"input_ids": input_ids, "token_type_ids": token_type_ids})`
+
+ Note that when creating models and layers with
+ [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry
+ about any of this, as you can just pass inputs like you would to any other Python function!
+
+
+
+ Parameters:
+ config ([`XxxConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+FUNNEL_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`Numpy array` or `tf.Tensor` of shape `({0})`):
+ Indices of input sequence tokens in the vocabulary.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and
+ [`PreTrainedTokenizer.encode`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`Numpy array` or `tf.Tensor` of shape `({0})`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ token_type_ids (`Numpy array` or `tf.Tensor` of shape `({0})`, *optional*):
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
+ 1]`:
+
+ - 0 corresponds to a *sentence A* token,
+ - 1 corresponds to a *sentence B* token.
+
+ [What are token type IDs?](../glossary#token-type-ids)
+ inputs_embeds (`tf.Tensor` of shape `({0}, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+ model's internal embedding lookup matrix.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the
+ config will be used instead.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail. This argument can be used only in eager mode, in graph mode the value in the config will be
+ used instead.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in
+ eager mode, in graph mode the value will always be set to True.
+ training (`bool`, *optional*, defaults to `False`):
+ Whether or not to use the model in training mode (some modules like dropout modules have different
+ behaviors between training and evaluation).
+"""
+
+
+@add_start_docstrings(
+ """
+ The base Funnel Transformer Model transformer outputting raw hidden-states without upsampling head (also called
+ decoder) or any task-specific head on top.
+ """,
+ FUNNEL_START_DOCSTRING,
+)
+class TFFunnelBaseModel(TFFunnelPreTrainedModel):
+ def __init__(self, config: FunnelConfig, *inputs, **kwargs) -> None:
+ super().__init__(config, *inputs, **kwargs)
+ self.funnel = TFFunnelBaseLayer(config, name="funnel")
+
+ @add_start_docstrings_to_model_forward(FUNNEL_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @add_code_sample_docstrings(
+ checkpoint="funnel-transformer/small-base",
+ output_type=TFBaseModelOutput,
+ config_class=_CONFIG_FOR_DOC,
+ )
+ @unpack_inputs
+ def call(
+ self,
+ input_ids: TFModelInputType | None = None,
+ attention_mask: np.ndarray | tf.Tensor | None = None,
+ token_type_ids: np.ndarray | tf.Tensor | None = None,
+ inputs_embeds: np.ndarray | tf.Tensor | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ training: bool = False,
+ ) -> tuple[tf.Tensor] | TFBaseModelOutput:
+ return self.funnel(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+
+ def serving_output(self, output):
+ # hidden_states and attentions not converted to Tensor with tf.convert_to_tensor as they are all of
+ # different dimensions
+ return TFBaseModelOutput(
+ last_hidden_state=output.last_hidden_state,
+ hidden_states=output.hidden_states,
+ attentions=output.attentions,
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "funnel", None) is not None:
+ with tf.name_scope(self.funnel.name):
+ self.funnel.build(None)
+
+
+@add_start_docstrings(
+ "The bare Funnel Transformer Model transformer outputting raw hidden-states without any specific head on top.",
+ FUNNEL_START_DOCSTRING,
+)
+class TFFunnelModel(TFFunnelPreTrainedModel):
+ def __init__(self, config: FunnelConfig, *inputs, **kwargs) -> None:
+ super().__init__(config, *inputs, **kwargs)
+ self.funnel = TFFunnelMainLayer(config, name="funnel")
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(FUNNEL_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @add_code_sample_docstrings(
+ checkpoint="funnel-transformer/small",
+ output_type=TFBaseModelOutput,
+ config_class=_CONFIG_FOR_DOC,
+ )
+ def call(
+ self,
+ input_ids: TFModelInputType | None = None,
+ attention_mask: np.ndarray | tf.Tensor | None = None,
+ token_type_ids: np.ndarray | tf.Tensor | None = None,
+ inputs_embeds: np.ndarray | tf.Tensor | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ training: bool = False,
+ ) -> tuple[tf.Tensor] | TFBaseModelOutput:
+ return self.funnel(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+
+ def serving_output(self, output):
+ # hidden_states and attentions not converted to Tensor with tf.convert_to_tensor as they are all of
+ # different dimensions
+ return TFBaseModelOutput(
+ last_hidden_state=output.last_hidden_state,
+ hidden_states=output.hidden_states,
+ attentions=output.attentions,
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "funnel", None) is not None:
+ with tf.name_scope(self.funnel.name):
+ self.funnel.build(None)
+
+
+@add_start_docstrings(
+ """
+ Funnel model with a binary classification head on top as used during pretraining for identifying generated tokens.
+ """,
+ FUNNEL_START_DOCSTRING,
+)
+class TFFunnelForPreTraining(TFFunnelPreTrainedModel):
+ def __init__(self, config: FunnelConfig, **kwargs) -> None:
+ super().__init__(config, **kwargs)
+
+ self.funnel = TFFunnelMainLayer(config, name="funnel")
+ self.discriminator_predictions = TFFunnelDiscriminatorPredictions(config, name="discriminator_predictions")
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(FUNNEL_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @replace_return_docstrings(output_type=TFFunnelForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
+ def call(
+ self,
+ input_ids: TFModelInputType | None = None,
+ attention_mask: np.ndarray | tf.Tensor | None = None,
+ token_type_ids: np.ndarray | tf.Tensor | None = None,
+ inputs_embeds: np.ndarray | tf.Tensor | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ training: bool = False,
+ **kwargs,
+ ) -> tuple[tf.Tensor] | TFFunnelForPreTrainingOutput:
+ r"""
+ Returns:
+
+ Examples:
+
+ ```python
+ >>> from transformers import AutoTokenizer, TFFunnelForPreTraining
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("funnel-transformer/small")
+ >>> model = TFFunnelForPreTraining.from_pretrained("funnel-transformer/small")
+
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="tf")
+ >>> logits = model(inputs).logits
+ ```"""
+ discriminator_hidden_states = self.funnel(
+ input_ids,
+ attention_mask,
+ token_type_ids,
+ inputs_embeds,
+ output_attentions,
+ output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+ discriminator_sequence_output = discriminator_hidden_states[0]
+ logits = self.discriminator_predictions(discriminator_sequence_output)
+
+ if not return_dict:
+ return (logits,) + discriminator_hidden_states[1:]
+
+ return TFFunnelForPreTrainingOutput(
+ logits=logits,
+ hidden_states=discriminator_hidden_states.hidden_states,
+ attentions=discriminator_hidden_states.attentions,
+ )
+
+ def serving_output(self, output):
+ # hidden_states and attentions not converted to Tensor with tf.convert_to_tensor as they are all of
+ # different dimensions
+ return TFFunnelForPreTrainingOutput(
+ logits=output.logits, hidden_states=output.hidden_states, attentions=output.attentions
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "funnel", None) is not None:
+ with tf.name_scope(self.funnel.name):
+ self.funnel.build(None)
+ if getattr(self, "discriminator_predictions", None) is not None:
+ with tf.name_scope(self.discriminator_predictions.name):
+ self.discriminator_predictions.build(None)
+
+
+@add_start_docstrings("""Funnel Model with a `language modeling` head on top.""", FUNNEL_START_DOCSTRING)
+class TFFunnelForMaskedLM(TFFunnelPreTrainedModel, TFMaskedLanguageModelingLoss):
+ def __init__(self, config: FunnelConfig, *inputs, **kwargs) -> None:
+ super().__init__(config, *inputs, **kwargs)
+
+ self.funnel = TFFunnelMainLayer(config, name="funnel")
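+ # The MLM head reuses the input embedding matrix for its output projection (weight tying).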
+ self.lm_head = TFFunnelMaskedLMHead(config, self.funnel.embeddings, name="lm_head")
+
+ def get_lm_head(self) -> TFFunnelMaskedLMHead:
+ return self.lm_head
+
+ def get_prefix_bias_name(self) -> str:
+ warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
+ return self.name + "/" + self.lm_head.name
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(FUNNEL_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @add_code_sample_docstrings(
+ checkpoint="funnel-transformer/small",
+ output_type=TFMaskedLMOutput,
+ config_class=_CONFIG_FOR_DOC,
+ )
+ def call(
+ self,
+ input_ids: TFModelInputType | None = None,
+ attention_mask: np.ndarray | tf.Tensor | None = None,
+ token_type_ids: np.ndarray | tf.Tensor | None = None,
+ inputs_embeds: np.ndarray | tf.Tensor | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ labels: np.ndarray | tf.Tensor | None = None,
+ training: bool = False,
+ ) -> tuple[tf.Tensor] | TFMaskedLMOutput:
+ r"""
+ labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
+ config.vocab_size]` (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked); the
+ loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+ """
+ outputs = self.funnel(
+ input_ids,
+ attention_mask,
+ token_type_ids,
+ inputs_embeds,
+ output_attentions,
+ output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+ sequence_output = outputs[0]
+ prediction_scores = self.lm_head(sequence_output, training=training)
+
+ loss = None if labels is None else self.hf_compute_loss(labels, prediction_scores)
+
+ if not return_dict:
+ output = (prediction_scores,) + outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return TFMaskedLMOutput(
+ loss=loss,
+ logits=prediction_scores,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ def serving_output(self, output: TFMaskedLMOutput) -> TFMaskedLMOutput:
+ # hidden_states and attentions not converted to Tensor with tf.convert_to_tensor as they are all of
+ # different dimensions
+ return TFMaskedLMOutput(logits=output.logits, hidden_states=output.hidden_states, attentions=output.attentions)
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "funnel", None) is not None:
+ with tf.name_scope(self.funnel.name):
+ self.funnel.build(None)
+ if getattr(self, "lm_head", None) is not None:
+ with tf.name_scope(self.lm_head.name):
+ self.lm_head.build(None)
+
+
+@add_start_docstrings(
+ """
+ Funnel Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
+ output) e.g. for GLUE tasks.
+ """,
+ FUNNEL_START_DOCSTRING,
+)
+class TFFunnelForSequenceClassification(TFFunnelPreTrainedModel, TFSequenceClassificationLoss):
+ def __init__(self, config: FunnelConfig, *inputs, **kwargs) -> None:
+ super().__init__(config, *inputs, **kwargs)
+ self.num_labels = config.num_labels
+
+ self.funnel = TFFunnelBaseLayer(config, name="funnel")
+ self.classifier = TFFunnelClassificationHead(config, config.num_labels, name="classifier")
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(FUNNEL_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @add_code_sample_docstrings(
+ checkpoint="funnel-transformer/small-base",
+ output_type=TFSequenceClassifierOutput,
+ config_class=_CONFIG_FOR_DOC,
+ )
+ def call(
+ self,
+ input_ids: TFModelInputType | None = None,
+ attention_mask: np.ndarray | tf.Tensor | None = None,
+ token_type_ids: np.ndarray | tf.Tensor | None = None,
+ inputs_embeds: np.ndarray | tf.Tensor | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ labels: np.ndarray | tf.Tensor | None = None,
+ training: bool = False,
+ ) -> tuple[tf.Tensor] | TFSequenceClassifierOutput:
+ r"""
+ labels (`tf.Tensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss); if
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+ outputs = self.funnel(
+ input_ids,
+ attention_mask,
+ token_type_ids,
+ inputs_embeds,
+ output_attentions,
+ output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+ last_hidden_state = outputs[0]
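+ # No separate pooler layer: the hidden state of the first token serves as the pooled sequence representation.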
+ pooled_output = last_hidden_state[:, 0]
+ logits = self.classifier(pooled_output, training=training)
+
+ loss = None if labels is None else self.hf_compute_loss(labels, logits)
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return TFSequenceClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ def serving_output(self, output: TFSequenceClassifierOutput) -> TFSequenceClassifierOutput:
+ # hidden_states and attentions not converted to Tensor with tf.convert_to_tensor as they are all of
+ # different dimensions
+ return TFSequenceClassifierOutput(
+ logits=output.logits, hidden_states=output.hidden_states, attentions=output.attentions
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "funnel", None) is not None:
+ with tf.name_scope(self.funnel.name):
+ self.funnel.build(None)
+ if getattr(self, "classifier", None) is not None:
+ with tf.name_scope(self.classifier.name):
+ self.classifier.build(None)
+
+
+@add_start_docstrings(
+ """
+ Funnel Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
+ softmax) e.g. for RocStories/SWAG tasks.
+ """,
+ FUNNEL_START_DOCSTRING,
+)
+class TFFunnelForMultipleChoice(TFFunnelPreTrainedModel, TFMultipleChoiceLoss):
+ def __init__(self, config: FunnelConfig, *inputs, **kwargs) -> None:
+ super().__init__(config, *inputs, **kwargs)
+
+ self.funnel = TFFunnelBaseLayer(config, name="funnel")
+ self.classifier = TFFunnelClassificationHead(config, 1, name="classifier")
+
+ @property
+ def dummy_inputs(self):
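+ # Dummy inputs carry an explicit num_choices dimension (batch=3, num_choices=3, seq_len=4) so the model is built with the expected input rank.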
+ return {"input_ids": tf.ones((3, 3, 4), dtype=tf.int32)}
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(FUNNEL_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
+ @add_code_sample_docstrings(
+ checkpoint="funnel-transformer/small-base",
+ output_type=TFMultipleChoiceModelOutput,
+ config_class=_CONFIG_FOR_DOC,
+ )
+ def call(
+ self,
+ input_ids: TFModelInputType | None = None,
+ attention_mask: np.ndarray | tf.Tensor | None = None,
+ token_type_ids: np.ndarray | tf.Tensor | None = None,
+ inputs_embeds: np.ndarray | tf.Tensor | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ labels: np.ndarray | tf.Tensor | None = None,
+ training: bool = False,
+ ) -> tuple[tf.Tensor] | TFMultipleChoiceModelOutput:
+ r"""
+ labels (`tf.Tensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]`
+ where `num_choices` is the size of the second dimension of the input tensors. (See `input_ids` above)
+ """
+ if input_ids is not None:
+ num_choices = shape_list(input_ids)[1]
+ seq_length = shape_list(input_ids)[2]
+ else:
+ num_choices = shape_list(inputs_embeds)[1]
+ seq_length = shape_list(inputs_embeds)[2]
+
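+ # Flatten (batch_size, num_choices, seq_length) inputs to (batch_size * num_choices, seq_length) so the base model scores each choice independently.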
+ flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
+ flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
+ flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
+ flat_inputs_embeds = (
+ tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3]))
+ if inputs_embeds is not None
+ else None
+ )
+
+ outputs = self.funnel(
+ flat_input_ids,
+ attention_mask=flat_attention_mask,
+ token_type_ids=flat_token_type_ids,
+ inputs_embeds=flat_inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+
+ last_hidden_state = outputs[0]
+ pooled_output = last_hidden_state[:, 0]
+ logits = self.classifier(pooled_output, training=training)
+ reshaped_logits = tf.reshape(logits, (-1, num_choices))
+
+ loss = None if labels is None else self.hf_compute_loss(labels, reshaped_logits)
+
+ if not return_dict:
+ output = (reshaped_logits,) + outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return TFMultipleChoiceModelOutput(
+ loss=loss,
+ logits=reshaped_logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ def serving_output(self, output: TFMultipleChoiceModelOutput) -> TFMultipleChoiceModelOutput:
+ # hidden_states and attentions not converted to Tensor with tf.convert_to_tensor as they are all of
+ # different dimensions
+ return TFMultipleChoiceModelOutput(
+ logits=output.logits, hidden_states=output.hidden_states, attentions=output.attentions
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "funnel", None) is not None:
+ with tf.name_scope(self.funnel.name):
+ self.funnel.build(None)
+ if getattr(self, "classifier", None) is not None:
+ with tf.name_scope(self.classifier.name):
+ self.classifier.build(None)
+
+
+@add_start_docstrings(
+ """
+ Funnel Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
+ Named-Entity-Recognition (NER) tasks.
+ """,
+ FUNNEL_START_DOCSTRING,
+)
+class TFFunnelForTokenClassification(TFFunnelPreTrainedModel, TFTokenClassificationLoss):
+ def __init__(self, config: FunnelConfig, *inputs, **kwargs) -> None:
+ super().__init__(config, *inputs, **kwargs)
+ self.num_labels = config.num_labels
+
+ self.funnel = TFFunnelMainLayer(config, name="funnel")
+ self.dropout = keras.layers.Dropout(config.hidden_dropout)
+ self.classifier = keras.layers.Dense(
+ config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
+ )
+ self.config = config
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(FUNNEL_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @add_code_sample_docstrings(
+ checkpoint="funnel-transformer/small",
+ output_type=TFTokenClassifierOutput,
+ config_class=_CONFIG_FOR_DOC,
+ )
+ def call(
+ self,
+ input_ids: TFModelInputType | None = None,
+ attention_mask: np.ndarray | tf.Tensor | None = None,
+ token_type_ids: np.ndarray | tf.Tensor | None = None,
+ inputs_embeds: np.ndarray | tf.Tensor | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ labels: np.ndarray | tf.Tensor | None = None,
+ training: bool = False,
+ ) -> tuple[tf.Tensor] | TFTokenClassifierOutput:
+ r"""
+ labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
+ """
+ outputs = self.funnel(
+ input_ids,
+ attention_mask,
+ token_type_ids,
+ inputs_embeds,
+ output_attentions,
+ output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+ sequence_output = outputs[0]
+
+ sequence_output = self.dropout(sequence_output, training=training)
+ logits = self.classifier(sequence_output)
+
+ loss = None if labels is None else self.hf_compute_loss(labels, logits)
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return TFTokenClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ def serving_output(self, output: TFTokenClassifierOutput) -> TFTokenClassifierOutput:
+ # hidden_states and attentions not converted to Tensor with tf.convert_to_tensor as they are all of
+ # different dimensions
+ return TFTokenClassifierOutput(
+ logits=output.logits, hidden_states=output.hidden_states, attentions=output.attentions
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "funnel", None) is not None:
+ with tf.name_scope(self.funnel.name):
+ self.funnel.build(None)
+ if getattr(self, "classifier", None) is not None:
+ with tf.name_scope(self.classifier.name):
+ self.classifier.build([None, None, self.config.hidden_size])
+
+
+@add_start_docstrings(
+ """
+ Funnel Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
+ layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
+ """,
+ FUNNEL_START_DOCSTRING,
+)
+class TFFunnelForQuestionAnswering(TFFunnelPreTrainedModel, TFQuestionAnsweringLoss):
+ def __init__(self, config: FunnelConfig, *inputs, **kwargs) -> None:
+ super().__init__(config, *inputs, **kwargs)
+ self.num_labels = config.num_labels
+
+ self.funnel = TFFunnelMainLayer(config, name="funnel")
+ self.qa_outputs = keras.layers.Dense(
+ config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
+ )
+ self.config = config
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(FUNNEL_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @add_code_sample_docstrings(
+ checkpoint="funnel-transformer/small",
+ output_type=TFQuestionAnsweringModelOutput,
+ config_class=_CONFIG_FOR_DOC,
+ )
+ def call(
+ self,
+ input_ids: TFModelInputType | None = None,
+ attention_mask: np.ndarray | tf.Tensor | None = None,
+ token_type_ids: np.ndarray | tf.Tensor | None = None,
+ inputs_embeds: np.ndarray | tf.Tensor | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ start_positions: np.ndarray | tf.Tensor | None = None,
+ end_positions: np.ndarray | tf.Tensor | None = None,
+ training: bool = False,
+ ) -> tuple[tf.Tensor] | TFQuestionAnsweringModelOutput:
+ r"""
+ start_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*):
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
+ Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
+ are not taken into account for computing the loss.
+ end_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*):
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
+ Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
+ are not taken into account for computing the loss.
+ """
+
+ outputs = self.funnel(
+ input_ids,
+ attention_mask,
+ token_type_ids,
+ inputs_embeds,
+ output_attentions,
+ output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+ sequence_output = outputs[0]
+
+ logits = self.qa_outputs(sequence_output)
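+ # qa_outputs produces two values per token; split them into start and end logits and squeeze away the trailing singleton axis.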
+ start_logits, end_logits = tf.split(logits, 2, axis=-1)
+ start_logits = tf.squeeze(start_logits, axis=-1)
+ end_logits = tf.squeeze(end_logits, axis=-1)
+
+ loss = None
+ if start_positions is not None and end_positions is not None:
+ labels = {"start_position": start_positions, "end_position": end_positions}
+ loss = self.hf_compute_loss(labels, (start_logits, end_logits))
+
+ if not return_dict:
+ output = (start_logits, end_logits) + outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return TFQuestionAnsweringModelOutput(
+ loss=loss,
+ start_logits=start_logits,
+ end_logits=end_logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ def serving_output(self, output: TFQuestionAnsweringModelOutput) -> TFQuestionAnsweringModelOutput:
+ # hidden_states and attentions not converted to Tensor with tf.convert_to_tensor as they are all of
+ # different dimensions
+ return TFQuestionAnsweringModelOutput(
+ start_logits=output.start_logits,
+ end_logits=output.end_logits,
+ hidden_states=output.hidden_states,
+ attentions=output.attentions,
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "funnel", None) is not None:
+ with tf.name_scope(self.funnel.name):
+ self.funnel.build(None)
+ if getattr(self, "qa_outputs", None) is not None:
+ with tf.name_scope(self.qa_outputs.name):
+ self.qa_outputs.build([None, None, self.config.hidden_size])
+
+
+__all__ = [
+ "TFFunnelBaseModel",
+ "TFFunnelForMaskedLM",
+ "TFFunnelForMultipleChoice",
+ "TFFunnelForPreTraining",
+ "TFFunnelForQuestionAnswering",
+ "TFFunnelForSequenceClassification",
+ "TFFunnelForTokenClassification",
+ "TFFunnelModel",
+ "TFFunnelPreTrainedModel",
+]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/funnel/tokenization_funnel.py b/venv/lib/python3.13/site-packages/transformers/models/funnel/tokenization_funnel.py
new file mode 100644
index 0000000000000000000000000000000000000000..e5d44e5e59064315ca330b6d9d7d0ffd04c59b12
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/funnel/tokenization_funnel.py
@@ -0,0 +1,542 @@
+# coding=utf-8
+# Copyright 2020 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenization class for Funnel Transformer."""
+
+import collections
+import os
+import unicodedata
+from typing import Optional
+
+from ...tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
+
+_model_names = [
+ "small",
+ "small-base",
+ "medium",
+ "medium-base",
+ "intermediate",
+ "intermediate-base",
+ "large",
+ "large-base",
+ "xlarge",
+ "xlarge-base",
+]
+
+
+# Copied from transformers.models.bert.tokenization_bert.load_vocab
+def load_vocab(vocab_file):
+ """Loads a vocabulary file into a dictionary."""
+ vocab = collections.OrderedDict()
+ with open(vocab_file, "r", encoding="utf-8") as reader:
+ tokens = reader.readlines()
+ for index, token in enumerate(tokens):
+ token = token.rstrip("\n")
+ vocab[token] = index
+ return vocab
+
+
+# Copied from transformers.models.bert.tokenization_bert.whitespace_tokenize
+def whitespace_tokenize(text):
+ """Runs basic whitespace cleaning and splitting on a piece of text."""
+ text = text.strip()
+ if not text:
+ return []
+ tokens = text.split()
+ return tokens
+
+
+class FunnelTokenizer(PreTrainedTokenizer):
+ r"""
+ Construct a Funnel Transformer tokenizer. Based on WordPiece.
+
+ This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
+ this superclass for more information regarding those methods.
+
+ Args:
+ vocab_file (`str`):
+ File containing the vocabulary.
+ do_lower_case (`bool`, *optional*, defaults to `True`):
+ Whether or not to lowercase the input when tokenizing.
+ do_basic_tokenize (`bool`, *optional*, defaults to `True`):
+ Whether or not to do basic tokenization before WordPiece.
+ never_split (`Iterable`, *optional*):
+ Collection of tokens which will never be split during tokenization. Only has an effect when
+ `do_basic_tokenize=True`
+ unk_token (`str`, *optional*, defaults to `"<unk>"`):
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+ token instead.
+ sep_token (`str`, *optional*, defaults to `"<sep>"`):
+ The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
+ sequence classification or for a text and a question for question answering. It is also used as the last
+ token of a sequence built with special tokens.
+ pad_token (`str`, *optional*, defaults to `"<pad>"`):
+ The token used for padding, for example when batching sequences of different lengths.
+ cls_token (`str`, *optional*, defaults to `"<cls>"`):
+ The classifier token which is used when doing sequence classification (classification of the whole sequence
+ instead of per-token classification). It is the first token of the sequence when built with special tokens.
+ mask_token (`str`, *optional*, defaults to `"<mask>"`):
+ The token used for masking values. This is the token used when training this model with masked language
+ modeling. This is the token which the model will try to predict.
+ bos_token (`str`, *optional*, defaults to `"<s>"`):
+ The beginning of sentence token.
+ eos_token (`str`, *optional*, defaults to `"</s>"`):
+ The end of sentence token.
+ tokenize_chinese_chars (`bool`, *optional*, defaults to `True`):
+ Whether or not to tokenize Chinese characters.
+
+ This should likely be deactivated for Japanese (see this
+ [issue](https://github.com/huggingface/transformers/issues/328)).
+ strip_accents (`bool`, *optional*):
+ Whether or not to strip all accents. If this option is not specified, then it will be determined by the
+ value for `lowercase` (as in the original BERT).
+ clean_up_tokenization_spaces (`bool`, *optional*, defaults to `True`):
+ Whether or not to cleanup spaces after decoding, cleanup consists in removing potential artifacts like
+ extra spaces.
+ """
+
+ vocab_files_names = VOCAB_FILES_NAMES
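+ # Unlike BERT (which uses 0), Funnel Transformer assigns token type ID 2 to the classification token.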
+ cls_token_type_id: int = 2
+
+ def __init__(
+ self,
+ vocab_file,
+ do_lower_case=True,
+ do_basic_tokenize=True,
+ never_split=None,
+ unk_token="",
+ sep_token="",
+ pad_token="",
+ cls_token="",
+ mask_token="",
+ bos_token="",
+ eos_token="",
+ tokenize_chinese_chars=True,
+ strip_accents=None,
+ clean_up_tokenization_spaces=True,
+ **kwargs,
+ ):
+ if not os.path.isfile(vocab_file):
+ raise ValueError(
+ f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained"
+ " model use `tokenizer = FunnelTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
+ )
+ self.vocab = load_vocab(vocab_file)
+ self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
+ self.do_basic_tokenize = do_basic_tokenize
+ if do_basic_tokenize:
+ self.basic_tokenizer = BasicTokenizer(
+ do_lower_case=do_lower_case,
+ never_split=never_split,
+ tokenize_chinese_chars=tokenize_chinese_chars,
+ strip_accents=strip_accents,
+ )
+ self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=str(unk_token))
+
+ super().__init__(
+ do_lower_case=do_lower_case,
+ do_basic_tokenize=do_basic_tokenize,
+ never_split=never_split,
+ unk_token=unk_token,
+ sep_token=sep_token,
+ pad_token=pad_token,
+ cls_token=cls_token,
+ mask_token=mask_token,
+ bos_token=bos_token,
+ eos_token=eos_token,
+ tokenize_chinese_chars=tokenize_chinese_chars,
+ strip_accents=strip_accents,
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+ **kwargs,
+ )
+
+ @property
+ # Copied from transformers.models.bert.tokenization_bert.BertTokenizer.do_lower_case
+ def do_lower_case(self):
+ return self.basic_tokenizer.do_lower_case
+
+ @property
+ # Copied from transformers.models.bert.tokenization_bert.BertTokenizer.vocab_size
+ def vocab_size(self):
+ return len(self.vocab)
+
+ # Copied from transformers.models.bert.tokenization_bert.BertTokenizer.get_vocab
+ def get_vocab(self):
+ return dict(self.vocab, **self.added_tokens_encoder)
+
+ # Copied from transformers.models.bert.tokenization_bert.BertTokenizer._tokenize
+ def _tokenize(self, text, split_special_tokens=False):
+ split_tokens = []
+ if self.do_basic_tokenize:
+ for token in self.basic_tokenizer.tokenize(
+ text, never_split=self.all_special_tokens if not split_special_tokens else None
+ ):
+ # If the token is part of the never_split set
+ if token in self.basic_tokenizer.never_split:
+ split_tokens.append(token)
+ else:
+ split_tokens += self.wordpiece_tokenizer.tokenize(token)
+ else:
+ split_tokens = self.wordpiece_tokenizer.tokenize(text)
+ return split_tokens
+
+ # Copied from transformers.models.bert.tokenization_bert.BertTokenizer._convert_token_to_id
+ def _convert_token_to_id(self, token):
+ """Converts a token (str) in an id using the vocab."""
+ return self.vocab.get(token, self.vocab.get(self.unk_token))
+
+ # Copied from transformers.models.bert.tokenization_bert.BertTokenizer._convert_id_to_token
+ def _convert_id_to_token(self, index):
+ """Converts an index (integer) in a token (str) using the vocab."""
+ return self.ids_to_tokens.get(index, self.unk_token)
+
+ # Copied from transformers.models.bert.tokenization_bert.BertTokenizer.convert_tokens_to_string
+ def convert_tokens_to_string(self, tokens):
+ """Converts a sequence of tokens (string) in a single string."""
+ out_string = " ".join(tokens).replace(" ##", "").strip()
+ return out_string
+
+ # Copied from transformers.models.bert.tokenization_bert.BertTokenizer.build_inputs_with_special_tokens
+ def build_inputs_with_special_tokens(
+ self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
+ ) -> list[int]:
+ """
+ Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating and
+ adding special tokens. A BERT sequence has the following format:
+
+ - single sequence: `[CLS] X [SEP]`
+ - pair of sequences: `[CLS] A [SEP] B [SEP]`
+
+ Args:
+ token_ids_0 (`List[int]`):
+ List of IDs to which the special tokens will be added.
+ token_ids_1 (`List[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+
+ Returns:
+ `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
+ """
+ if token_ids_1 is None:
+ return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
+ cls = [self.cls_token_id]
+ sep = [self.sep_token_id]
+ return cls + token_ids_0 + sep + token_ids_1 + sep
+
+ # Copied from transformers.models.bert.tokenization_bert.BertTokenizer.get_special_tokens_mask
+ def get_special_tokens_mask(
+ self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None, already_has_special_tokens: bool = False
+ ) -> list[int]:
+ """
+ Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
+ special tokens using the tokenizer `prepare_for_model` method.
+
+ Args:
+ token_ids_0 (`List[int]`):
+ List of IDs.
+ token_ids_1 (`List[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+ already_has_special_tokens (`bool`, *optional*, defaults to `False`):
+ Whether or not the token list is already formatted with special tokens for the model.
+
+ Returns:
+ `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+ """
+
+ if already_has_special_tokens:
+ return super().get_special_tokens_mask(
+ token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
+ )
+
+ if token_ids_1 is not None:
+ return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
+ return [1] + ([0] * len(token_ids_0)) + [1]
+
+ def create_token_type_ids_from_sequences(
+ self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
+ ) -> list[int]:
+ """
+ Create a mask from the two sequences passed to be used in a sequence-pair classification task. A Funnel
+ Transformer sequence pair mask has the following format:
+
+ ```
+ 2 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
+ | first sequence | second sequence |
+ ```
+
+ If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s).
+
+ Args:
+ token_ids_0 (`list[int]`):
+ List of IDs.
+ token_ids_1 (`list[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+
+ Returns:
+ `list[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
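+
+ Example (illustrative sketch; the returned mask depends only on the lengths of the inputs, not on the token values):
+
+ ```python
+ >>> tokenizer = FunnelTokenizer.from_pretrained("funnel-transformer/small")
+ >>> tokenizer.create_token_type_ids_from_sequences([5, 6], [7, 8])
+ [2, 0, 0, 0, 1, 1, 1]
+ ```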
+ """
+ sep = [self.sep_token_id]
+ cls = [self.cls_token_id]
+ if token_ids_1 is None:
+ return len(cls) * [self.cls_token_type_id] + len(token_ids_0 + sep) * [0]
+ return len(cls) * [self.cls_token_type_id] + len(token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
+
+ # Copied from transformers.models.bert.tokenization_bert.BertTokenizer.save_vocabulary
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
+ index = 0
+ if os.path.isdir(save_directory):
+ vocab_file = os.path.join(
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+ )
+ else:
+ vocab_file = (filename_prefix + "-" if filename_prefix else "") + save_directory
+ with open(vocab_file, "w", encoding="utf-8") as writer:
+ for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
+ if index != token_index:
+ logger.warning(
+ f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive."
+ " Please check that the vocabulary is not corrupted!"
+ )
+ index = token_index
+ writer.write(token + "\n")
+ index += 1
+ return (vocab_file,)
+
+
+# Copied from transformers.models.bert.tokenization_bert.BasicTokenizer
+class BasicTokenizer:
+ """
+ Constructs a BasicTokenizer that will run basic tokenization (punctuation splitting, lower casing, etc.).
+
+ Args:
+ do_lower_case (`bool`, *optional*, defaults to `True`):
+ Whether or not to lowercase the input when tokenizing.
+ never_split (`Iterable`, *optional*):
+ Collection of tokens which will never be split during tokenization. Only has an effect when
+ `do_basic_tokenize=True`
+ tokenize_chinese_chars (`bool`, *optional*, defaults to `True`):
+ Whether or not to tokenize Chinese characters.
+
+ This should likely be deactivated for Japanese (see this
+ [issue](https://github.com/huggingface/transformers/issues/328)).
+ strip_accents (`bool`, *optional*):
+ Whether or not to strip all accents. If this option is not specified, then it will be determined by the
+ value for `lowercase` (as in the original BERT).
+ do_split_on_punc (`bool`, *optional*, defaults to `True`):
+ In some instances we want to skip the basic punctuation splitting so that later tokenization can capture
+ the full context of the words, such as contractions.
+ """
+
+ def __init__(
+ self,
+ do_lower_case=True,
+ never_split=None,
+ tokenize_chinese_chars=True,
+ strip_accents=None,
+ do_split_on_punc=True,
+ ):
+ if never_split is None:
+ never_split = []
+ self.do_lower_case = do_lower_case
+ self.never_split = set(never_split)
+ self.tokenize_chinese_chars = tokenize_chinese_chars
+ self.strip_accents = strip_accents
+ self.do_split_on_punc = do_split_on_punc
+
+ def tokenize(self, text, never_split=None):
+ """
+ Basic Tokenization of a piece of text. For sub-word tokenization, see WordPieceTokenizer.
+
+ Args:
+ never_split (`List[str]`, *optional*):
+ Kept for backward compatibility purposes. Now implemented directly at the base class level (see
+ [`PreTrainedTokenizer.tokenize`]). List of tokens not to split.
+ """
+ # union() returns a new set by concatenating the two sets.
+ never_split = self.never_split.union(set(never_split)) if never_split else self.never_split
+ text = self._clean_text(text)
+
+ # This was added on November 1st, 2018 for the multilingual and Chinese
+ # models. This is also applied to the English models now, but it doesn't
+ # matter since the English models were not trained on any Chinese data
+ # and generally don't have any Chinese data in them (there are Chinese
+ # characters in the vocabulary because Wikipedia does have some Chinese
+ # words in the English Wikipedia.).
+ if self.tokenize_chinese_chars:
+ text = self._tokenize_chinese_chars(text)
+ # prevents treating the same character with different unicode codepoints as different characters
+ unicode_normalized_text = unicodedata.normalize("NFC", text)
+ orig_tokens = whitespace_tokenize(unicode_normalized_text)
+ split_tokens = []
+ for token in orig_tokens:
+ if token not in never_split:
+ if self.do_lower_case:
+ token = token.lower()
+ if self.strip_accents is not False:
+ token = self._run_strip_accents(token)
+ elif self.strip_accents:
+ token = self._run_strip_accents(token)
+ split_tokens.extend(self._run_split_on_punc(token, never_split))
+
+ output_tokens = whitespace_tokenize(" ".join(split_tokens))
+ return output_tokens
+
+ def _run_strip_accents(self, text):
+ """Strips accents from a piece of text."""
+ text = unicodedata.normalize("NFD", text)
+ output = []
+ for char in text:
+ cat = unicodedata.category(char)
+ if cat == "Mn":
+ continue
+ output.append(char)
+ return "".join(output)
+
+ def _run_split_on_punc(self, text, never_split=None):
+ """Splits punctuation on a piece of text."""
+ if not self.do_split_on_punc or (never_split is not None and text in never_split):
+ return [text]
+ chars = list(text)
+ i = 0
+ start_new_word = True
+ output = []
+ while i < len(chars):
+ char = chars[i]
+ if _is_punctuation(char):
+ output.append([char])
+ start_new_word = True
+ else:
+ if start_new_word:
+ output.append([])
+ start_new_word = False
+ output[-1].append(char)
+ i += 1
+
+ return ["".join(x) for x in output]
+
+ def _tokenize_chinese_chars(self, text):
+ """Adds whitespace around any CJK character."""
+ output = []
+ for char in text:
+ cp = ord(char)
+ if self._is_chinese_char(cp):
+ output.append(" ")
+ output.append(char)
+ output.append(" ")
+ else:
+ output.append(char)
+ return "".join(output)
+
+ def _is_chinese_char(self, cp):
+ """Checks whether CP is the codepoint of a CJK character."""
+ # This defines a "chinese character" as anything in the CJK Unicode block:
+ # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
+ #
+ # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
+ # despite its name. The modern Korean Hangul alphabet is a different block,
+ # as is Japanese Hiragana and Katakana. Those alphabets are used to write
+ # space-separated words, so they are not treated specially and handled
+ like all of the other languages.
+ if (
+ (cp >= 0x4E00 and cp <= 0x9FFF)
+ or (cp >= 0x3400 and cp <= 0x4DBF)
+ or (cp >= 0x20000 and cp <= 0x2A6DF)
+ or (cp >= 0x2A700 and cp <= 0x2B73F)
+ or (cp >= 0x2B740 and cp <= 0x2B81F)
+ or (cp >= 0x2B820 and cp <= 0x2CEAF)
+ or (cp >= 0xF900 and cp <= 0xFAFF)
+ or (cp >= 0x2F800 and cp <= 0x2FA1F)
+ ):
+ return True
+
+ return False
+
+ def _clean_text(self, text):
+ """Performs invalid character removal and whitespace cleanup on text."""
+ output = []
+ for char in text:
+ cp = ord(char)
+ if cp == 0 or cp == 0xFFFD or _is_control(char):
+ continue
+ if _is_whitespace(char):
+ output.append(" ")
+ else:
+ output.append(char)
+ return "".join(output)
+
+
+# Copied from transformers.models.bert.tokenization_bert.WordpieceTokenizer
+class WordpieceTokenizer:
+ """Runs WordPiece tokenization."""
+
+ def __init__(self, vocab, unk_token, max_input_chars_per_word=100):
+ self.vocab = vocab
+ self.unk_token = unk_token
+ self.max_input_chars_per_word = max_input_chars_per_word
+
+ def tokenize(self, text):
+ """
+ Tokenizes a piece of text into its word pieces. This uses a greedy longest-match-first algorithm to perform
+ tokenization using the given vocabulary.
+
+ For example, `input = "unaffable"` will return as output `["un", "##aff", "##able"]`.
+
+ Args:
+ text: A single token or whitespace separated tokens. This should have
+ already been passed through *BasicTokenizer*.
+
+ Returns:
+ A list of wordpiece tokens.
+ """
+
+ output_tokens = []
+ for token in whitespace_tokenize(text):
+ chars = list(token)
+ if len(chars) > self.max_input_chars_per_word:
+ output_tokens.append(self.unk_token)
+ continue
+
+ is_bad = False
+ start = 0
+ sub_tokens = []
+ while start < len(chars):
+ end = len(chars)
+ cur_substr = None
+ while start < end:
+ substr = "".join(chars[start:end])
+ if start > 0:
+ substr = "##" + substr
+ if substr in self.vocab:
+ cur_substr = substr
+ break
+ end -= 1
+ if cur_substr is None:
+ is_bad = True
+ break
+ sub_tokens.append(cur_substr)
+ start = end
+
+ if is_bad:
+ output_tokens.append(self.unk_token)
+ else:
+ output_tokens.extend(sub_tokens)
+ return output_tokens
+
+
+__all__ = ["FunnelTokenizer"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/funnel/tokenization_funnel_fast.py b/venv/lib/python3.13/site-packages/transformers/models/funnel/tokenization_funnel_fast.py
new file mode 100644
index 0000000000000000000000000000000000000000..eeeb6f7bf6cb0640ee04bb01737331ba4be1233b
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/funnel/tokenization_funnel_fast.py
@@ -0,0 +1,203 @@
+# coding=utf-8
+# Copyright 2020 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenization class for Funnel Transformer."""
+
+import json
+from typing import Optional
+
+from tokenizers import normalizers
+
+from ...tokenization_utils_fast import PreTrainedTokenizerFast
+from ...utils import logging
+from .tokenization_funnel import FunnelTokenizer
+
+
+logger = logging.get_logger(__name__)
+
+VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "tokenizer_file": "tokenizer.json"}
+
+_model_names = [
+ "small",
+ "small-base",
+ "medium",
+ "medium-base",
+ "intermediate",
+ "intermediate-base",
+ "large",
+ "large-base",
+ "xlarge",
+ "xlarge-base",
+]
+
+
+class FunnelTokenizerFast(PreTrainedTokenizerFast):
+ r"""
+ Construct a "fast" Funnel Transformer tokenizer (backed by HuggingFace's *tokenizers* library). Based on WordPiece.
+
+ This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
+ refer to this superclass for more information regarding those methods.
+
+ Args:
+ vocab_file (`str`):
+ File containing the vocabulary.
+ do_lower_case (`bool`, *optional*, defaults to `True`):
+ Whether or not to lowercase the input when tokenizing.
+ unk_token (`str`, *optional*, defaults to `"<unk>"`):
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+ token instead.
+ sep_token (`str`, *optional*, defaults to `"<sep>"`):
+ The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
+ sequence classification or for a text and a question for question answering. It is also used as the last
+ token of a sequence built with special tokens.
+ pad_token (`str`, *optional*, defaults to `"<pad>"`):
+ The token used for padding, for example when batching sequences of different lengths.
+ cls_token (`str`, *optional*, defaults to `"<cls>"`):
+ The classifier token which is used when doing sequence classification (classification of the whole sequence
+ instead of per-token classification). It is the first token of the sequence when built with special tokens.
+ mask_token (`str`, *optional*, defaults to `"<mask>"`):
+ The token used for masking values. This is the token used when training this model with masked language
+ modeling. This is the token which the model will try to predict.
+ clean_text (`bool`, *optional*, defaults to `True`):
+ Whether or not to clean the text before tokenization by removing any control characters and replacing all
+ whitespaces by the classic one.
+ tokenize_chinese_chars (`bool`, *optional*, defaults to `True`):
+ Whether or not to tokenize Chinese characters. This should likely be deactivated for Japanese (see [this
+ issue](https://github.com/huggingface/transformers/issues/328)).
+ bos_token (`str`, *optional*, defaults to `"<s>"`):
+ The beginning of sentence token.
+ eos_token (`str`, *optional*, defaults to `"</s>"`):
+ The end of sentence token.
+ strip_accents (`bool`, *optional*):
+ Whether or not to strip all accents. If this option is not specified, then it will be determined by the
+ value for `lowercase` (as in the original BERT).
+ wordpieces_prefix (`str`, *optional*, defaults to `"##"`):
+ The prefix for subwords.
+ """
+
+ vocab_files_names = VOCAB_FILES_NAMES
+ slow_tokenizer_class = FunnelTokenizer
+ cls_token_type_id: int = 2
+
+ def __init__(
+ self,
+ vocab_file=None,
+ tokenizer_file=None,
+ do_lower_case=True,
+ unk_token="",
+ sep_token="",
+ pad_token="",
+ cls_token="",
+ mask_token="",
+ bos_token="",
+ eos_token="",
+ clean_text=True,
+ tokenize_chinese_chars=True,
+ strip_accents=None,
+ wordpieces_prefix="##",
+ **kwargs,
+ ):
+ super().__init__(
+ vocab_file,
+ tokenizer_file=tokenizer_file,
+ do_lower_case=do_lower_case,
+ unk_token=unk_token,
+ sep_token=sep_token,
+ pad_token=pad_token,
+ cls_token=cls_token,
+ mask_token=mask_token,
+ bos_token=bos_token,
+ eos_token=eos_token,
+ clean_text=clean_text,
+ tokenize_chinese_chars=tokenize_chinese_chars,
+ strip_accents=strip_accents,
+ wordpieces_prefix=wordpieces_prefix,
+ **kwargs,
+ )
+
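+ # Keep the Rust backend normalizer in sync: if its serialized options differ from the arguments passed here,
+ # rebuild it with the requested lowercasing / accent-stripping / Chinese-character handling.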
+ normalizer_state = json.loads(self.backend_tokenizer.normalizer.__getstate__())
+ if (
+ normalizer_state.get("lowercase", do_lower_case) != do_lower_case
+ or normalizer_state.get("strip_accents", strip_accents) != strip_accents
+ or normalizer_state.get("handle_chinese_chars", tokenize_chinese_chars) != tokenize_chinese_chars
+ ):
+ normalizer_class = getattr(normalizers, normalizer_state.pop("type"))
+ normalizer_state["lowercase"] = do_lower_case
+ normalizer_state["strip_accents"] = strip_accents
+ normalizer_state["handle_chinese_chars"] = tokenize_chinese_chars
+ self.backend_tokenizer.normalizer = normalizer_class(**normalizer_state)
+
+ self.do_lower_case = do_lower_case
+
+ # Copied from transformers.models.bert.tokenization_bert_fast.BertTokenizerFast.build_inputs_with_special_tokens with BERT->Funnel
+ def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
+ """
+ Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating and
+ adding special tokens. A Funnel sequence has the following format:
+
+ - single sequence: `[CLS] X [SEP]`
+ - pair of sequences: `[CLS] A [SEP] B [SEP]`
+
+ Args:
+ token_ids_0 (`List[int]`):
+ List of IDs to which the special tokens will be added.
+ token_ids_1 (`List[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+
+ Returns:
+ `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
+ """
+ output = [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
+
+ if token_ids_1 is not None:
+ output += token_ids_1 + [self.sep_token_id]
+
+ return output
+
+ def create_token_type_ids_from_sequences(
+ self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
+ ) -> list[int]:
+ """
+ Create a mask from the two sequences passed to be used in a sequence-pair classification task. A Funnel
+ Transformer sequence pair mask has the following format:
+
+ ```
+ 2 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
+ | first sequence | second sequence |
+ ```
+
+ If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s).
+
+ Args:
+ token_ids_0 (`list[int]`):
+ List of IDs.
+ token_ids_1 (`list[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+
+ Returns:
+ `list[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
+ """
+ sep = [self.sep_token_id]
+ cls = [self.cls_token_id]
+ if token_ids_1 is None:
+ return len(cls) * [self.cls_token_type_id] + len(token_ids_0 + sep) * [0]
+ return len(cls) * [self.cls_token_type_id] + len(token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
+
+ # Copied from transformers.models.bert.tokenization_bert_fast.BertTokenizerFast.save_vocabulary
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
+ files = self._tokenizer.model.save(save_directory, name=filename_prefix)
+ return tuple(files)
+
+
+__all__ = ["FunnelTokenizerFast"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/gemma2/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/gemma2/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..18905bac42cc6b19f21e069355504e46d070d814
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/gemma2/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_gemma2 import *
+ from .modeling_gemma2 import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/venv/lib/python3.13/site-packages/transformers/models/gemma2/configuration_gemma2.py b/venv/lib/python3.13/site-packages/transformers/models/gemma2/configuration_gemma2.py
new file mode 100644
index 0000000000000000000000000000000000000000..d43ec4c47371411c0b15b34afe4d82456d4f0cf8
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/gemma2/configuration_gemma2.py
@@ -0,0 +1,182 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/gemma2/modular_gemma2.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_gemma2.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from ...configuration_utils import PretrainedConfig, layer_type_validation
+
+
+class Gemma2Config(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`Gemma2Model`]. It is used to instantiate a Gemma2
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+ defaults will yield a similar configuration to that of the Gemma2-7B.
+ e.g. [google/gemma2-7b](https://huggingface.co/google/gemma2-7b)
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+ Args:
+ vocab_size (`int`, *optional*, defaults to 256000):
+ Vocabulary size of the Gemma2 model. Defines the number of different tokens that can be represented by the
+ `inputs_ids` passed when calling [`Gemma2Model`]
+ hidden_size (`int`, *optional*, defaults to 2304):
+ Dimension of the hidden representations.
+ intermediate_size (`int`, *optional*, defaults to 9216):
+ Dimension of the MLP representations.
+ num_hidden_layers (`int`, *optional*, defaults to 26):
+ Number of hidden layers in the Transformer decoder.
+ num_attention_heads (`int`, *optional*, defaults to 8):
+ Number of attention heads for each attention layer in the Transformer decoder.
+ num_key_value_heads (`int`, *optional*, defaults to 4):
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+ by meanpooling all the original heads within that group. For more details, check out [this
+ paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
+ `num_attention_heads`.
+ head_dim (`int`, *optional*, defaults to 256):
+ The attention head dimension.
+ hidden_activation (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
+ The non-linear activation function (function or string) in the decoder. Will default to `"gelu_pytorch_tanh"`
+ if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"` activation function.
+ max_position_embeddings (`int`, *optional*, defaults to 8192):
+ The maximum sequence length that this model might ever be used with.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
+ The epsilon used by the rms normalization layers.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ pad_token_id (`int`, *optional*, defaults to 0):
+ Padding token id.
+ eos_token_id (`int`, *optional*, defaults to 1):
+ End of stream token id.
+ bos_token_id (`int`, *optional*, defaults to 2):
+ Beginning of stream token id.
+ tie_word_embeddings (`bool`, *optional*, defaults to `True`):
+ Whether to tie weight embeddings
+ rope_theta (`float`, *optional*, defaults to 10000.0):
+ The base period of the RoPE embeddings.
+ attention_bias (`bool`, *optional*, defaults to `False`):
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ query_pre_attn_scalar (`float`, *optional*, defaults to 256):
+ Scaling factor used on the attention scores.
+ sliding_window (`int`, *optional*, defaults to 4096):
+ In Gemma2, every other layer uses sliding window attention. This is the size of the sliding window.
+ layer_types (`list`, *optional*):
+ Attention pattern for each layer.
+ final_logit_softcapping (`float`, *optional*, defaults to 30.0):
+ Scaling factor when applying tanh softcapping on the logits.
+ attn_logit_softcapping (`float`, *optional*, defaults to 50.0):
+ Scaling factor when applying tanh softcapping on the attention scores.
+
+ ```python
+ >>> from transformers import Gemma2Model, Gemma2Config
+ >>> # Initializing a Gemma2 gemma2-7b style configuration
+ >>> configuration = Gemma2Config()
+ >>> # Initializing a model from the gemma2-7b style configuration
+ >>> model = Gemma2Model(configuration)
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "gemma2"
+ keys_to_ignore_at_inference = ["past_key_values"]
+ base_model_tp_plan = {
+ "layers.*.self_attn.q_proj": "colwise",
+ "layers.*.self_attn.k_proj": "colwise",
+ "layers.*.self_attn.v_proj": "colwise",
+ "layers.*.self_attn.o_proj": "rowwise",
+ "layers.*.mlp.gate_proj": "colwise",
+ "layers.*.mlp.up_proj": "colwise",
+ "layers.*.mlp.down_proj": "rowwise",
+ }
+ base_model_pp_plan = {
+ "embed_tokens": (["input_ids"], ["inputs_embeds"]),
+ "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
+ "norm": (["hidden_states"], ["hidden_states"]),
+ }
+
+ def __init__(
+ self,
+ vocab_size=256000,
+ hidden_size=2304,
+ intermediate_size=9216,
+ num_hidden_layers=26,
+ num_attention_heads=8,
+ num_key_value_heads=4,
+ head_dim=256,
+ hidden_activation="gelu_pytorch_tanh",
+ max_position_embeddings=8192,
+ initializer_range=0.02,
+ rms_norm_eps=1e-6,
+ use_cache=True,
+ pad_token_id=0,
+ eos_token_id=1,
+ bos_token_id=2,
+ tie_word_embeddings=True,
+ rope_theta=10000.0,
+ attention_bias=False,
+ attention_dropout=0.0,
+ query_pre_attn_scalar=256,
+ sliding_window=4096,
+ layer_types=None,
+ final_logit_softcapping=30.0,
+ attn_logit_softcapping=50.0,
+ **kwargs,
+ ):
+ super().__init__(
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ tie_word_embeddings=tie_word_embeddings,
+ **kwargs,
+ )
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.head_dim = head_dim
+ self.num_key_value_heads = num_key_value_heads
+ self.initializer_range = initializer_range
+ self.rms_norm_eps = rms_norm_eps
+ self.use_cache = use_cache
+ self.rope_theta = rope_theta
+ self.attention_bias = attention_bias
+ self.attention_dropout = attention_dropout
+ self.hidden_activation = hidden_activation
+ self.query_pre_attn_scalar = query_pre_attn_scalar
+ self.sliding_window = sliding_window
+ self.final_logit_softcapping = final_logit_softcapping
+ self.attn_logit_softcapping = attn_logit_softcapping
+ self.layer_types = layer_types
+
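+ # When no explicit layer_types is given, attention alternates between local and global:
+ # even-indexed layers (0, 2, ...) use "sliding_attention" and odd-indexed layers use "full_attention".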
+ if self.layer_types is None:
+ self.layer_types = [
+ "sliding_attention" if bool((i + 1) % 2) else "full_attention" for i in range(self.num_hidden_layers)
+ ]
+ layer_type_validation(self.layer_types, self.num_hidden_layers)
+
+
+__all__ = ["Gemma2Config"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/gemma2/modeling_gemma2.py b/venv/lib/python3.13/site-packages/transformers/models/gemma2/modeling_gemma2.py
new file mode 100644
index 0000000000000000000000000000000000000000..b53e271cd8f0a1408b7be5dbc7ae6759e6a5f3cc
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/gemma2/modeling_gemma2.py
@@ -0,0 +1,598 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/gemma2/modular_gemma2.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_gemma2.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Callable, Optional, Union
+
+import torch
+import torch.nn as nn
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache
+from ...generation import GenerationMixin
+from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import (
+ GenericForSequenceClassification,
+ GenericForTokenClassification,
+ GradientCheckpointingLayer,
+)
+from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
+from ...utils.deprecation import deprecate_kwarg
+from ...utils.generic import check_model_inputs
+from .configuration_gemma2 import Gemma2Config
+
+
+logger = logging.get_logger(__name__)
+
+
+class Gemma2RMSNorm(nn.Module):
+ def __init__(self, dim: int, eps: float = 1e-6):
+ super().__init__()
+ self.eps = eps
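+ # The learnable scale is stored as an offset from 1 (forward applies (1 + weight)),
+ # so it is initialized to zeros rather than ones.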
+ self.weight = nn.Parameter(torch.zeros(dim))
+
+ def _norm(self, x):
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+ def forward(self, x):
+ output = self._norm(x.float())
+ # Llama does x.to(float16) * w whilst Gemma2 is (x * w).to(float16)
+ # See https://github.com/huggingface/transformers/pull/29402
+ output = output * (1.0 + self.weight.float())
+ return output.type_as(x)
+
+ def extra_repr(self):
+ return f"{tuple(self.weight.shape)}, eps={self.eps}"
+
+
+class Gemma2MLP(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.intermediate_size = config.intermediate_size
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+ self.act_fn = ACT2FN[config.hidden_activation]
+
+ def forward(self, x):
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+ return down_proj
+
+
+class Gemma2RotaryEmbedding(nn.Module):
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
+
+ def __init__(self, config: Gemma2Config, device=None):
+ super().__init__()
+ # BC: "rope_type" was originally "type"
+ if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+ else:
+ self.rope_type = "default"
+ self.max_seq_len_cached = config.max_position_embeddings
+ self.original_max_seq_len = config.max_position_embeddings
+
+ self.config = config
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+ self.original_inv_freq = self.inv_freq
+
+ @torch.no_grad()
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
+ def forward(self, x, position_ids):
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
+ position_ids_expanded = position_ids[:, None, :].float()
+
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos = emb.cos() * self.attention_scaling
+ sin = emb.sin() * self.attention_scaling
+
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+ """Applies Rotary Position Embedding to the query and key tensors.
+
+ Args:
+ q (`torch.Tensor`): The query tensor.
+ k (`torch.Tensor`): The key tensor.
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
+ position_ids (`torch.Tensor`, *optional*):
+ Deprecated and unused.
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+ Returns:
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+ """
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+ """
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ dropout: float = 0.0,
+ scaling: Optional[float] = None,
+ softcap: Optional[float] = None,
+ **kwargs,
+) -> tuple[torch.Tensor, torch.Tensor]:
+ if scaling is None:
+ scaling = module.head_dim**-0.5
+
+ key_states = repeat_kv(key, module.num_key_value_groups)
+ value_states = repeat_kv(value, module.num_key_value_groups)
+
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+
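+ # Optional logit soft-capping: softcap * tanh(logits / softcap) smoothly bounds the raw attention
+ # logits to the interval (-softcap, softcap) before the mask and softmax are applied.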
+ if softcap is not None:
+ attn_weights = attn_weights / softcap
+ attn_weights = torch.tanh(attn_weights)
+ attn_weights = attn_weights * softcap
+ if attention_mask is not None: # no matter the length, we just slice it
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ return attn_output, attn_weights
+
+
+class Gemma2Attention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config: Gemma2Config, layer_idx: int):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
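+ # Gemma2 scales attention scores by 1/sqrt(query_pre_attn_scalar) instead of the usual 1/sqrt(head_dim).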
+ self.scaling = config.query_pre_attn_scalar**-0.5
+ self.attention_dropout = self.config.attention_dropout
+ self.is_causal = True
+
+ self.q_proj = nn.Linear(
+ config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.k_proj = nn.Linear(
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.v_proj = nn.Linear(
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.o_proj = nn.Linear(
+ config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
+ )
+ self.attn_logit_softcapping = self.config.attn_logit_softcapping
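+ # Only layers marked "sliding_attention" in config.layer_types attend within a local window of size
+ # config.sliding_window; full-attention layers get sliding_window=None.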
+ self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
+ attention_mask: Optional[torch.Tensor],
+ past_key_values: Optional[Cache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
+ input_shape = hidden_states.shape[:-1]
+ hidden_shape = (*input_shape, -1, self.head_dim)
+
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_values is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=self.attention_dropout if self.training else 0.0,
+ scaling=self.scaling,
+ sliding_window=self.sliding_window,
+ softcap=self.attn_logit_softcapping,
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+ return attn_output, attn_weights
+
+
+class Gemma2DecoderLayer(GradientCheckpointingLayer):
+ def __init__(self, config: Gemma2Config, layer_idx: int):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+ self.config = config
+ self.attention_type = config.layer_types[layer_idx]
+ self.self_attn = Gemma2Attention(config=config, layer_idx=layer_idx)
+ self.mlp = Gemma2MLP(config)
+ self.input_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_attention_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ self.pre_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs,
+ ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ residual = hidden_states
+
+ hidden_states = self.input_layernorm(hidden_states)
+
+ # Self Attention
+ hidden_states, self_attn_weights = self.self_attn(
+ hidden_states=hidden_states,
+ position_embeddings=position_embeddings,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ **kwargs,
+ )
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states = residual + hidden_states
+
+ residual = hidden_states
+ hidden_states = self.pre_feedforward_layernorm(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = self.post_feedforward_layernorm(hidden_states)
+ hidden_states = residual + hidden_states
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights,)
+
+ return outputs
+
+
+@auto_docstring
+class Gemma2PreTrainedModel(PreTrainedModel):
+ config: Gemma2Config
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["Gemma2DecoderLayer"]
+ _skip_keys_device_placement = ["past_key_values"]
+ _supports_flash_attn = True
+ _supports_sdpa = True
+ _supports_flex_attn = True
+
+ _can_compile_fullgraph = True
+ _supports_attention_backend = True
+ _can_record_outputs = {
+ "hidden_states": Gemma2DecoderLayer,
+ "attentions": Gemma2Attention,
+ }
+
+ def _init_weights(self, module):
+ super()._init_weights(module)
+
+ # We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight)
+ if "RMSNorm" in module.__class__.__name__:
+ module.weight.data.zero_()
+
+
+@auto_docstring
+class Gemma2Model(Gemma2PreTrainedModel):
+ def __init__(self, config: Gemma2Config):
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+ self.layers = nn.ModuleList(
+ [Gemma2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+ self.norm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.rotary_emb = Gemma2RotaryEmbedding(config=config)
+ self.gradient_checkpointing = False
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @check_model_inputs()
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> BaseModelOutputWithPast:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if self.gradient_checkpointing and self.training and use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
+ )
+ use_cache = False
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ if use_cache and past_key_values is None and not self.training:
+ past_key_values = DynamicCache(config=self.config)
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position = torch.arange(
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+ )
+
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0)
+
+ # It may already have been prepared by e.g. `generate`
+ if not isinstance(causal_mask_mapping := attention_mask, dict):
+ # Prepare mask arguments
+ mask_kwargs = {
+ "config": self.config,
+ "input_embeds": inputs_embeds,
+ "attention_mask": attention_mask,
+ "cache_position": cache_position,
+ "past_key_values": past_key_values,
+ "position_ids": position_ids,
+ }
+ # Create the masks
+ causal_mask_mapping = {
+ "full_attention": create_causal_mask(**mask_kwargs),
+ "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
+ }
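+ # Each decoder layer below picks the mask matching its `attention_type` from this mapping.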
+
+ # embed positions
+ hidden_states = inputs_embeds
+
+ # create position embeddings to be shared across the decoder layers
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+ # normalized
+ # Gemma2 downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5
+ # See https://github.com/huggingface/transformers/pull/29402
+ normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
+ hidden_states = hidden_states * normalizer
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+
+ for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ layer_outputs = decoder_layer(
+ hidden_states,
+ position_embeddings=position_embeddings,
+ attention_mask=causal_mask_mapping[decoder_layer.attention_type],
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ hidden_states = self.norm(hidden_states)
+
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
+
+
+@auto_docstring
+class Gemma2ForCausalLM(Gemma2PreTrainedModel, GenerationMixin):
+ _tied_weights_keys = ["lm_head.weight"]
+ _tp_plan = {"lm_head": "colwise_rep"}
+ _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = Gemma2Model(config)
+ self.vocab_size = config.vocab_size
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ **kwargs,
+ ) -> CausalLMOutputWithPast:
+ r"""
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, Gemma2ForCausalLM
+
+ >>> model = Gemma2ForCausalLM.from_pretrained("google/gemma-2-9b")
+ >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")
+
+ >>> prompt = "What is your favorite condiment?"
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "What is your favorite condiment?"
+ ```"""
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs: BaseModelOutputWithPast = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ hidden_states = outputs.last_hidden_state
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
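+ # Final logit soft-capping: cap * tanh(logits / cap) bounds the output logits to (-cap, cap).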
+ if self.config.final_logit_softcapping is not None:
+ logits = logits / self.config.final_logit_softcapping
+ logits = torch.tanh(logits)
+ logits = logits * self.config.final_logit_softcapping
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+class Gemma2ForSequenceClassification(GenericForSequenceClassification, Gemma2PreTrainedModel):
+ pass
+
+
+class Gemma2ForTokenClassification(GenericForTokenClassification, Gemma2PreTrainedModel):
+ pass
+
+
+__all__ = [
+ "Gemma2ForCausalLM",
+ "Gemma2Model",
+ "Gemma2PreTrainedModel",
+ "Gemma2ForSequenceClassification",
+ "Gemma2ForTokenClassification",
+]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/gemma2/modular_gemma2.py b/venv/lib/python3.13/site-packages/transformers/models/gemma2/modular_gemma2.py
new file mode 100644
index 0000000000000000000000000000000000000000..e54795019c7f0a9cf356f3c186a5fd2735eedf2c
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/gemma2/modular_gemma2.py
@@ -0,0 +1,587 @@
+# coding=utf-8
+# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Callable, Optional, Union
+
+import torch
+import torch.nn as nn
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache
+from ...configuration_utils import PretrainedConfig, layer_type_validation
+from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
+from ...processing_utils import Unpack
+from ...utils import TransformersKwargs, logging
+from ...utils.deprecation import deprecate_kwarg
+from ..gemma.modeling_gemma import (
+ GemmaAttention,
+ GemmaForCausalLM,
+ GemmaForSequenceClassification,
+ GemmaForTokenClassification,
+ GemmaMLP,
+ GemmaModel,
+ GemmaPreTrainedModel,
+ GemmaRMSNorm,
+ GemmaRotaryEmbedding,
+ apply_rotary_pos_emb,
+ repeat_kv,
+)
+
+
+logger = logging.get_logger(__name__)
+
+
+class Gemma2Config(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`Gemma2Model`]. It is used to instantiate a Gemma2
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+ defaults will yield a similar configuration to that of the Gemma2-7B,
+ e.g. [google/gemma2-7b](https://huggingface.co/google/gemma2-7b).
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+ Args:
+ vocab_size (`int`, *optional*, defaults to 256000):
+ Vocabulary size of the Gemma2 model. Defines the number of different tokens that can be represented by the
+ `input_ids` passed when calling [`Gemma2Model`].
+ hidden_size (`int`, *optional*, defaults to 2304):
+ Dimension of the hidden representations.
+ intermediate_size (`int`, *optional*, defaults to 9216):
+ Dimension of the MLP representations.
+ num_hidden_layers (`int`, *optional*, defaults to 26):
+ Number of hidden layers in the Transformer decoder.
+ num_attention_heads (`int`, *optional*, defaults to 8):
+ Number of attention heads for each attention layer in the Transformer decoder.
+ num_key_value_heads (`int`, *optional*, defaults to 4):
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA); if
+ `num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used. When
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+ by meanpooling all the original heads within that group. For more details, check out [this
+ paper](https://huggingface.co/papers/2305.13245). If not specified, it will default to
+ `num_attention_heads`.
+ head_dim (`int`, *optional*, defaults to 256):
+ The attention head dimension.
+ hidden_activation (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
+ The non-linear activation function (function or string) in the decoder. Will default to `"gelu_pytorch_tanh"`
+ if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"` activation function.
+ max_position_embeddings (`int`, *optional*, defaults to 8192):
+ The maximum sequence length that this model might ever be used with.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
+ The epsilon used by the rms normalization layers.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ pad_token_id (`int`, *optional*, defaults to 0):
+ Padding token id.
+ eos_token_id (`int`, *optional*, defaults to 1):
+ End of stream token id.
+ bos_token_id (`int`, *optional*, defaults to 2):
+ Beginning of stream token id.
+ tie_word_embeddings (`bool`, *optional*, defaults to `True`):
+ Whether to tie the input and output word embeddings.
+ rope_theta (`float`, *optional*, defaults to 10000.0):
+ The base period of the RoPE embeddings.
+ attention_bias (`bool`, *optional*, defaults to `False`):
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ query_pre_attn_scalar (`float`, *optional*, defaults to 256):
+ Scaling factor used on the attention scores.
+ sliding_window (`int`, *optional*, defaults to 4096):
+ In Gemma2, every other layer uses sliding window attention. This is the size of the sliding window.
+ layer_types (`list`, *optional*):
+ Attention pattern for each layer.
+ final_logit_softcapping (`float`, *optional*, defaults to 30.0):
+ Scaling factor when applying tanh softcapping on the logits.
+ attn_logit_softcapping (`float`, *optional*, defaults to 50.0):
+ Scaling factor when applying tanh softcapping on the attention scores.
+
+ ```python
+ >>> from transformers import Gemma2Model, Gemma2Config
+ >>> # Initializing a Gemma2 gemma2-7b style configuration
+ >>> configuration = Gemma2Config()
+ >>> # Initializing a model from the gemma2-7b style configuration
+ >>> model = Gemma2Model(configuration)
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "gemma2"
+ keys_to_ignore_at_inference = ["past_key_values"]
+ base_model_tp_plan = {
+ "layers.*.self_attn.q_proj": "colwise",
+ "layers.*.self_attn.k_proj": "colwise",
+ "layers.*.self_attn.v_proj": "colwise",
+ "layers.*.self_attn.o_proj": "rowwise",
+ "layers.*.mlp.gate_proj": "colwise",
+ "layers.*.mlp.up_proj": "colwise",
+ "layers.*.mlp.down_proj": "rowwise",
+ }
+ base_model_pp_plan = {
+ "embed_tokens": (["input_ids"], ["inputs_embeds"]),
+ "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
+ "norm": (["hidden_states"], ["hidden_states"]),
+ }
+
+ def __init__(
+ self,
+ vocab_size=256000,
+ hidden_size=2304,
+ intermediate_size=9216,
+ num_hidden_layers=26,
+ num_attention_heads=8,
+ num_key_value_heads=4,
+ head_dim=256,
+ hidden_activation="gelu_pytorch_tanh",
+ max_position_embeddings=8192,
+ initializer_range=0.02,
+ rms_norm_eps=1e-6,
+ use_cache=True,
+ pad_token_id=0,
+ eos_token_id=1,
+ bos_token_id=2,
+ tie_word_embeddings=True,
+ rope_theta=10000.0,
+ attention_bias=False,
+ attention_dropout=0.0,
+ query_pre_attn_scalar=256,
+ sliding_window=4096,
+ layer_types=None,
+ final_logit_softcapping=30.0,
+ attn_logit_softcapping=50.0,
+ **kwargs,
+ ):
+ super().__init__(
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ tie_word_embeddings=tie_word_embeddings,
+ **kwargs,
+ )
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.head_dim = head_dim
+ self.num_key_value_heads = num_key_value_heads
+ self.initializer_range = initializer_range
+ self.rms_norm_eps = rms_norm_eps
+ self.use_cache = use_cache
+ self.rope_theta = rope_theta
+ self.attention_bias = attention_bias
+ self.attention_dropout = attention_dropout
+ self.hidden_activation = hidden_activation
+ self.query_pre_attn_scalar = query_pre_attn_scalar
+ self.sliding_window = sliding_window
+ self.final_logit_softcapping = final_logit_softcapping
+ self.attn_logit_softcapping = attn_logit_softcapping
+ self.layer_types = layer_types
+
+ if self.layer_types is None:
+ self.layer_types = [
+ "sliding_attention" if bool((i + 1) % 2) else "full_attention" for i in range(self.num_hidden_layers)
+ ]
+ layer_type_validation(self.layer_types, self.num_hidden_layers)
+
+
+class Gemma2RMSNorm(GemmaRMSNorm):
+ pass
+
+
+class Gemma2MLP(GemmaMLP):
+ def __init__(self, config):
+ super().__init__(config)
+ self.act_fn = ACT2FN[config.hidden_activation]
+
+
+class Gemma2RotaryEmbedding(GemmaRotaryEmbedding):
+ pass
+
+
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ dropout: float = 0.0,
+ scaling: Optional[float] = None,
+ softcap: Optional[float] = None,
+ **kwargs,
+) -> tuple[torch.Tensor, torch.Tensor]:
+ if scaling is None:
+ scaling = module.head_dim**-0.5
+
+ key_states = repeat_kv(key, module.num_key_value_groups)
+ value_states = repeat_kv(value, module.num_key_value_groups)
+
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+
+ if softcap is not None:
+ attn_weights = attn_weights / softcap
+ attn_weights = torch.tanh(attn_weights)
+ attn_weights = attn_weights * softcap
+ if attention_mask is not None: # no matter the length, we just slice it
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ return attn_output, attn_weights
+
+
+class Gemma2Attention(GemmaAttention):
+ def __init__(self, config: Gemma2Config, layer_idx: int):
+ super().__init__(config, layer_idx)
+ self.attn_logit_softcapping = self.config.attn_logit_softcapping
+ self.attention_dropout = self.config.attention_dropout
+ self.is_causal = True
+ self.scaling = config.query_pre_attn_scalar**-0.5
+ self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
+ attention_mask: Optional[torch.Tensor],
+ past_key_values: Optional[Cache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
+ input_shape = hidden_states.shape[:-1]
+ hidden_shape = (*input_shape, -1, self.head_dim)
+
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_values is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=self.attention_dropout if self.training else 0.0,
+ scaling=self.scaling,
+ sliding_window=self.sliding_window,
+ softcap=self.attn_logit_softcapping,
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+ return attn_output, attn_weights
+
+
+class Gemma2DecoderLayer(GradientCheckpointingLayer):
+ def __init__(self, config: Gemma2Config, layer_idx: int):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+ self.config = config
+ self.attention_type = config.layer_types[layer_idx]
+ self.self_attn = Gemma2Attention(config=config, layer_idx=layer_idx)
+ self.mlp = Gemma2MLP(config)
+ self.input_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_attention_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ self.pre_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs,
+ ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ residual = hidden_states
+
+ hidden_states = self.input_layernorm(hidden_states)
+
+ # Self Attention
+ hidden_states, self_attn_weights = self.self_attn(
+ hidden_states=hidden_states,
+ position_embeddings=position_embeddings,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ **kwargs,
+ )
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states = residual + hidden_states
+
+ residual = hidden_states
+ hidden_states = self.pre_feedforward_layernorm(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = self.post_feedforward_layernorm(hidden_states)
+ hidden_states = residual + hidden_states
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights,)
+
+ return outputs
+
+
+class Gemma2PreTrainedModel(GemmaPreTrainedModel):
+ pass
+
+
+class Gemma2Model(GemmaModel):
+ def __init__(self, config: Gemma2Config):
+ super().__init__(config)
+ self.layers = nn.ModuleList(
+ [Gemma2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> BaseModelOutputWithPast:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if self.gradient_checkpointing and self.training and use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
+ )
+ use_cache = False
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ if use_cache and past_key_values is None and not self.training:
+ past_key_values = DynamicCache(config=self.config)
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position = torch.arange(
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+ )
+
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0)
+
+ # It may already have been prepared by e.g. `generate`
+ if not isinstance(causal_mask_mapping := attention_mask, dict):
+ # Prepare mask arguments
+ mask_kwargs = {
+ "config": self.config,
+ "input_embeds": inputs_embeds,
+ "attention_mask": attention_mask,
+ "cache_position": cache_position,
+ "past_key_values": past_key_values,
+ "position_ids": position_ids,
+ }
+ # Create the masks
+ causal_mask_mapping = {
+ "full_attention": create_causal_mask(**mask_kwargs),
+ "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
+ }
+
+ # embed positions
+ hidden_states = inputs_embeds
+
+ # create position embeddings to be shared across the decoder layers
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+ # normalized
+ # Gemma2 downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5
+ # See https://github.com/huggingface/transformers/pull/29402
+ normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
+ hidden_states = hidden_states * normalizer
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+
+ for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ layer_outputs = decoder_layer(
+ hidden_states,
+ position_embeddings=position_embeddings,
+ attention_mask=causal_mask_mapping[decoder_layer.attention_type],
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ hidden_states = self.norm(hidden_states)
+
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
+
+
+class Gemma2ForCausalLM(GemmaForCausalLM):
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = Gemma2Model(config)
+ self.post_init()
+
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ **kwargs,
+ ) -> CausalLMOutputWithPast:
+ r"""
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, Gemma2ForCausalLM
+
+ >>> model = Gemma2ForCausalLM.from_pretrained("google/gemma-2-9b")
+ >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")
+
+ >>> prompt = "What is your favorite condiment?"
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "What is your favorite condiment?"
+ ```"""
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs: BaseModelOutputWithPast = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ hidden_states = outputs.last_hidden_state
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+ if self.config.final_logit_softcapping is not None:
+ logits = logits / self.config.final_logit_softcapping
+ logits = torch.tanh(logits)
+ logits = logits * self.config.final_logit_softcapping
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+class Gemma2ForSequenceClassification(GemmaForSequenceClassification):
+ pass
+
+
+class Gemma2ForTokenClassification(GemmaForTokenClassification):
+ pass
+
+
+__all__ = [
+ "Gemma2Config",
+ "Gemma2ForCausalLM",
+ "Gemma2Model",
+ "Gemma2PreTrainedModel",
+ "Gemma2ForSequenceClassification",
+ "Gemma2ForTokenClassification",
+]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/gemma3n/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/gemma3n/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..229e91827036d0830593ea9294e232cffefbac7b
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/gemma3n/__init__.py
@@ -0,0 +1,29 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_gemma3n import *
+ from .feature_extraction_gemma3n import *
+ from .modeling_gemma3n import *
+ from .processing_gemma3n import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/venv/lib/python3.13/site-packages/transformers/models/gemma3n/configuration_gemma3n.py b/venv/lib/python3.13/site-packages/transformers/models/gemma3n/configuration_gemma3n.py
new file mode 100644
index 0000000000000000000000000000000000000000..47b5b47d363034d4daf39065f2c778cb453799be
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/gemma3n/configuration_gemma3n.py
@@ -0,0 +1,682 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/gemma3n/modular_gemma3n.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_gemma3n.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from collections.abc import Sequence
+from typing import Any, Optional, Union
+
+from ...configuration_utils import PretrainedConfig, layer_type_validation
+from ...modeling_rope_utils import rope_config_validation
+from ...utils import is_timm_available, logging, requires_backends
+
+
+if is_timm_available():
+ from timm.data import ImageNetInfo, infer_imagenet_subset
+
+
+logger = logging.get_logger(__name__)
+
+
+class Gemma3nTextConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`Gemma3nTextModel`]. It is used to instantiate a
+ Gemma3nTextModel model according to the specified arguments, defining the model architecture. Instantiating a
+ configuration with the defaults will yield a similar configuration to that of the Gemma 3n E4B, e.g.
+ [google/gemma-3n-E4B](https://huggingface.co/google/gemma-3n-E4B).
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 262400):
+ Vocabulary size of the Gemma3nText model. Defines the number of different tokens that can be represented by
+ the `input_ids` passed when calling [`Gemma3nTextModel`].
+ vocab_size_per_layer_input (`int`, *optional*, defaults to 262144):
+ Vocabulary size of the per-layer text embeddings that augment the standard embeddings.
+ hidden_size (`int`, *optional*, defaults to 2048):
+ Dimension of the hidden representations.
+ hidden_size_per_layer_input (`int`, *optional*, defaults to 256):
+ Dimension of the hidden representations for per-layer embeddings.
+ intermediate_size (`int` or `Sequence[int]`, *optional*, defaults to 16384):
+ Dimension of the MLP representations. MatFormer configurations may wish to provide a sequence of integers
+ to account for variable intermediate_size values across layers. In such cases,
+ `len(intermediate_size) == num_hidden_layers`.
+ num_hidden_layers (`int`, *optional*, defaults to 35):
+ Number of hidden layers in the Transformer decoder.
+ num_attention_heads (`int`, *optional*, defaults to 8):
+ Number of attention heads for each attention layer in the Transformer decoder.
+ num_key_value_heads (`int`, *optional*, defaults to 2):
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA); if
+ `num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used. When
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+ by meanpooling all the original heads within that group. For more details, check out this
+ [paper](https://huggingface.co/papers/2305.13245). If not specified, it will default to `num_attention_heads`.
+ head_dim (`int`, *optional*, defaults to 256):
+ The attention head dimension.
+ hidden_activation (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
+ The non-linear activation function (function or string) in the decoder. Will default to
+ `"gelu_pytorch_tanh"` if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"`
+ activation function.
+ max_position_embeddings (`int`, *optional*, defaults to 32768):
+ The maximum sequence length that this model might ever be used with.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
+ The epsilon used by the rms normalization layers.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ pad_token_id (`int`, *optional*, defaults to 0):
+ Padding token id.
+ eos_token_id (`int`, *optional*, defaults to 1):
+ End of stream token id.
+ bos_token_id (`int`, *optional*, defaults to 2):
+ Beginning of stream token id.
+ rope_theta (`float`, *optional*, defaults to 1000000.0):
+ The base period of the RoPE embeddings.
+ rope_scaling (`Dict`, *optional*):
+ Dictionary containing the scaling configuration for the RoPE embeddings used in global attention.
+ NOTE: if you apply a new rope type and you expect the model to work on longer `max_position_embeddings`, we
+ recommend you update this value accordingly.
+ Expected contents:
+ `rope_type` (`str`):
+ The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
+ 'llama3'], with 'default' being the original RoPE implementation.
+ `factor` (`float`, *optional*):
+ Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
+ most scaling types, a `factor` of x will enable the model to handle sequences of length x *
+ original maximum pre-trained length.
+ `original_max_position_embeddings` (`int`, *optional*):
+ Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
+ pretraining.
+ `attention_factor` (`float`, *optional*):
+ Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
+ computation. If unspecified, it defaults to the value recommended by the implementation, using the
+ `factor` field to infer the suggested value.
+ `beta_fast` (`float`, *optional*):
+ Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
+ ramp function. If unspecified, it defaults to 32.
+ `beta_slow` (`float`, *optional*):
+ Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
+ ramp function. If unspecified, it defaults to 1.
+ `short_factor` (`List[float]`, *optional*):
+ Only used with 'longrope'. The scaling factor to be applied to short contexts (<
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+ size divided by the number of attention heads divided by 2
+ `long_factor` (`List[float]`, *optional*):
+ Only used with 'longrope'. The scaling factor to be applied to long contexts (>
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+ size divided by the number of attention heads divided by 2
+ `low_freq_factor` (`float`, *optional*):
+ Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
+ `high_freq_factor` (`float`, *optional*):
+ Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
+ rope_local_base_freq (`float`, *optional*, defaults to 10000.0):
+ The base period of the RoPE embeddings for local attention.
+ attention_bias (`bool`, *optional*, defaults to `False`):
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ sliding_window (`int`, *optional*, defaults to 512):
+ This is the size of the sliding window used by local attention layers.
+ layer_types (`Sequence[str]`, *optional*):
+ A sequence of strings defining the attention type of each layer as either "sliding_attention" or
+ "full_attention". If not provided, `layer_types` will be inferred from `num_hidden_layers` using a pattern
+ of four "sliding_attention" layers followed by one "full_attention" layer. The last layer in the model
+ should always be a "full_attention" layer.
+ final_logit_softcapping (`float`, *optional*, defaults to 30.0):
+ Scaling factor when applying tanh softcapping on the logits.
+ altup_active_idx (`int`, *optional*, defaults to 0):
+ The index of the prediction from which AltUp will compute additional predictions or apply corrections.
+ altup_coef_clip (`float`, *optional*, defaults to 120.0):
+ The maximum amplitude of an AltUp prediction or correction coefficient weight.
+ altup_correct_scale (`bool`, *optional*, defaults to `True`):
+ If True, apply the `AltUp.correct_output_scale` to the corrected prediction at `altup_active_idx`.
+ altup_num_inputs (`int`, *optional*, defaults to 4):
+ The number of predictions that AltUp should make given the input sequence.
+ num_kv_shared_layers (`int`, *optional*, defaults to 15):
+ The number of layers that share KV cache values. During the forward pass, the last `num_kv_shared_layers`
+ layers in the model "share" the KV values in that each local and global layer in this range uses the KV
+ cache values computed for the last local or global layer, respectively, before entering this range. The
+ value should be a multiple of the attention pattern size (see `layer_types` parameter).
+ laurel_rank (`int`, *optional*, defaults to 64):
+ The intermediate size for the linear projections in the Learned Augmented Residual Layer.
+ activation_sparsity_pattern (`Sequence[float]`, *optional*):
+ The sparsity factor used to extract the top-k activations for a given layer. The provided Sequence must
+ explicitly provide a sparsity value for each layer in the model. By default, the first 10 layers are
+ sparse with a sparsity factor of 0.95 and the rest are dense.
+
+ ```python
+ >>> from transformers import Gemma3nTextModel, Gemma3nTextConfig
+
+ >>> # Initializing a Gemma3nText gemma3n_text-E4B style configuration
+ >>> configuration = Gemma3nTextConfig()
+
+ >>> # Initializing a model from the gemma3n_text-E4B style configuration
+ >>> model = Gemma3nTextModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```
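+
+ A minimal sketch (values are illustrative, not tied to any released checkpoint) of overriding RoPE scaling
+ for the global-attention layers:
+
+ ```python
+ >>> # Hypothetical override: linear RoPE scaling with a factor of 2.0
+ >>> configuration = Gemma3nTextConfig(rope_scaling={"rope_type": "linear", "factor": 2.0})
+ ```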
+ """
+
+ model_type = "gemma3n_text"
+ keys_to_ignore_at_inference = ["past_key_values"]
+ base_model_tp_plan = {
+ "layers.*.self_attn.q_proj": "colwise",
+ "layers.*.self_attn.k_proj": "colwise",
+ "layers.*.self_attn.v_proj": "colwise",
+ "layers.*.self_attn.o_proj": "rowwise",
+ "layers.*.mlp.gate_proj": "colwise",
+ "layers.*.mlp.up_proj": "colwise",
+ "layers.*.mlp.down_proj": "rowwise",
+ }
+ base_model_pp_plan = {
+ "embed_tokens": (["input_ids"], ["inputs_embeds"]),
+ "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
+ "norm": (["hidden_states"], ["hidden_states"]),
+ }
+
+ def __init__(
+ self,
+ vocab_size: int = 262_400,
+ vocab_size_per_layer_input: int = 262_144,
+ hidden_size: int = 2048,
+ hidden_size_per_layer_input: int = 256,
+ intermediate_size: Union[int, Sequence[int]] = 16_384,
+ num_hidden_layers: int = 35,
+ num_attention_heads: int = 8,
+ num_key_value_heads: int = 2,
+ head_dim: int = 256,
+ hidden_activation: str = "gelu_pytorch_tanh",
+ max_position_embeddings: int = 32_768,
+ initializer_range: float = 0.02,
+ rms_norm_eps: float = 1e-6,
+ use_cache: bool = True,
+ pad_token_id: int = 0,
+ eos_token_id: int = 1,
+ bos_token_id: int = 2,
+ rope_theta: float = 1_000_000.0,
+ rope_scaling: Optional[dict[str, Any]] = None,
+ rope_local_base_freq: float = 10_000.0,
+ attention_bias: bool = False,
+ attention_dropout: float = 0.0,
+ sliding_window: int = 512,
+ layer_types: Optional[Sequence[str]] = None,
+ final_logit_softcapping: float = 30.0,
+ altup_active_idx: int = 0,
+ altup_coef_clip: float = 120.0,
+ altup_correct_scale: bool = True,
+ altup_num_inputs: int = 4,
+ num_kv_shared_layers: int = 15,
+ laurel_rank: int = 64,
+ activation_sparsity_pattern: Optional[Union[float, Sequence[float]]] = None,
+ **kwargs,
+ ):
+ super().__init__(
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ **kwargs,
+ )
+
+ if isinstance(intermediate_size, Sequence) and (intsize_len := len(intermediate_size)) != num_hidden_layers:
+ raise ValueError(
+ "intermediate_size must have an explicit intermediate size for every layer or one for all layers. "
+ f"Expected {num_hidden_layers} values but got {intsize_len}."
+ )
+ elif not isinstance(intermediate_size, Sequence):
+ intermediate_size = [intermediate_size] * num_hidden_layers
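+ # A scalar `intermediate_size` has now been expanded to a per-layer list
+ # (e.g. 16_384 -> [16_384] * 35 with the defaults).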
+
+ self.vocab_size = vocab_size
+ self.vocab_size_per_layer_input = vocab_size_per_layer_input
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.head_dim = head_dim
+ self.num_key_value_heads = num_key_value_heads
+ self.initializer_range = initializer_range
+ self.rms_norm_eps = rms_norm_eps
+ self.use_cache = use_cache
+ self.rope_theta = rope_theta
+ self.attention_bias = attention_bias
+ self.attention_dropout = attention_dropout
+ self.hidden_activation = hidden_activation
+ self.sliding_window = sliding_window
+ self.final_logit_softcapping = final_logit_softcapping
+ self.layer_types = layer_types
+
+ self.rope_local_base_freq = rope_local_base_freq
+ self.rope_scaling = rope_scaling
+ rope_config_validation(self)
+
+ if layer_types is None:
+ self.layer_types = [
+ "full_attention" if (i + 1) % 5 == 0 else "sliding_attention" for i in range(self.num_hidden_layers)
+ ]
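+ # With the default num_hidden_layers=35 this yields a repeating pattern of four "sliding_attention"
+ # layers followed by one "full_attention" layer, i.e. layers 5, 10, ..., 35 (1-indexed) use full attention.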
+ else:
+ self.layer_types = layer_types
+
+ layer_type_validation(self.layer_types, self.num_hidden_layers)
+
+ self.hidden_size_per_layer_input = hidden_size_per_layer_input
+ self.num_kv_shared_layers = num_kv_shared_layers
+
+ self.altup_active_idx = altup_active_idx
+ self.altup_coef_clip = altup_coef_clip
+ self.altup_correct_scale = altup_correct_scale
+ self.altup_num_inputs = altup_num_inputs
+
+ self.laurel_rank = laurel_rank
+
+ if activation_sparsity_pattern is None:
+ num_sparse_layers = 10 if num_hidden_layers > 10 else 0
+ activation_sparsity_pattern = [0.95] * num_sparse_layers + [0.0] * (num_hidden_layers - num_sparse_layers)
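+ # With the default num_hidden_layers=35 this yields [0.95] * 10 + [0.0] * 25: the first ten layers
+ # apply activation sparsity and the remaining layers are dense.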
+
+ if (len_asp := len(activation_sparsity_pattern)) != num_hidden_layers:
+ raise ValueError(
+ "activation_sparsity_pattern must have an explicit activation sparsity value for every layer."
+ f"Expected {num_hidden_layers} values but got {len_asp}."
+ )
+ self.activation_sparsity_pattern = activation_sparsity_pattern
+
+
+class Gemma3nAudioConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`Gemma3nAudioEncoder`]. It is used to instantiate
+ a `Gemma3nAudioEncoder` model according to the specified arguments, defining the model architecture. Instantiating
+ a configuration with the defaults will yield a similar configuration to that of the Gemma 3n E4B, e.g.,
+ [google/gemma-3n-E4B](https://huggingface.co/google/gemma-3n-E4B).
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read
+ the documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 128):
+ Vocabulary size of the additional hard-token embeddings for the audio model. These augment the embeddings
+ included in the `Gemma3nTextModel` to provide, e.g., the end of audio and audio soft token placeholder
+ tokens when converting `input_ids` to embeddings in the `Gemma3nForConditionalGeneration` model.
+ vocab_offset (`int`, *optional*, defaults to 262272):
+ Offset between the tokenizer vocab index for the token ids embedded by `Gemma3nMultimodalEmbedder` and the
+ 0-indexed `Gemma3nMultimodalEmbedder.embedding` table.
+ input_feat_size (`int`, *optional*, defaults to 128):
+ The number of channels in each mel-spectrogram frame.
+ hidden_size (`int`, *optional*, defaults to 1536):
+ Dimension of the hidden representations.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
+ The epsilon used by the rms normalization layers.
+ gradient_clipping (`float`, *optional*, defaults to 10000000000.0):
+ Clipping value used to stabilize extremely large gradient values.
+ conf_attention_chunk_size (`int`, *optional*, defaults to 12):
+ The sub-sequence size for local attention processing inside the Conformer ("conf") section of the
+ Universal Speech Model.
+ conf_attention_context_left (`int`, *optional*, defaults to 13):
+ The left context size of the local attention inside the Conformer ("conf") section of the
+ Universal Speech Model.
+ conf_attention_context_right (`int`, *optional*, defaults to 0):
+ The right context size of the local attention inside the Conformer ("conf") section of the
+ Universal Speech Model.
+ conf_attention_logit_cap (`float`, *optional*, defaults to 50.0):
+ Logit cap applied during local attention inside the Conformer ("conf") section of the
+ Universal Speech Model.
+ conf_num_attention_heads (`int`, *optional*, defaults to 8):
+ The number of attention heads in local attention inside the Conformer ("conf") section of the
+ Universal Speech Model.
+ conf_num_hidden_layers (`int`, *optional*, defaults to 12):
+ The number of layers that use local attention inside the Conformer ("conf") section of the
+ Universal Speech Model.
+ conf_conv_kernel_size (`int`, *optional*, defaults to 5):
+ Convolution kernel size for the conformer block inside the Conformer ("conf") section of the
+ Universal Speech Model.
+ conf_reduction_factor (`int`, *optional*, defaults to 4):
+ Reduction factor used in the conformer block inside the Conformer ("conf") section of the
+ Universal Speech Model.
+ conf_residual_weight (`float`, *optional*, defaults to 0.5):
+ Residual connection weight inside the Conformer ("conf") section of the
+ Universal Speech Model.
+ sscp_conv_channel_size (`tuple(int, int)`, *optional*, defaults to `(128, 32)`):
+ The channel sizes for the first and second convolutional layers in the Sub-sample Convolution Projection
+ ("sscp") section of the Universal Speech Model.
+ sscp_conv_group_norm_eps (`float`, *optional*, defaults to 0.001):
+ Epsilon used in group normalization in the subsample convolution projection in the Sub-sample Convolution
+ Projection ("sscp") section of the Universal Speech Model.
+ sscp_conv_kernel_size (`tuple(tuple(int, int), tuple(int, int))`, *optional*, defaults to `((3, 3), (3, 3))`):
+ Kernel sizes of the two convolutional layers in the subsample convolution projection in the Sub-sample
+ Convolution Projection ("sscp") section of the Universal Speech Model. The kernel sizes are specified as a
+ tuple of height and width for each layer, where the height corresponds to the time dimension and the width
+ corresponds to the frequency dimension.
+ sscp_conv_stride_size (`tuple(tuple(int, int), tuple(int, int))`, *optional*, defaults to `((2, 2), (2, 2))`):
+ Stride sizes of the two convolutional layers in the subsample convolution projection in the Sub-sample
+ Convolution Projection ("sscp") section of the Universal Speech Model. The stride sizes are specified as a
+ tuple of height and width for each layer, where the height corresponds to the time dimension and the width
+ corresponds to the frequency dimension.
+
+ Example:
+
+ ```python
+ >>> from transformers import Gemma3nAudioConfig, Gemma3nAudioEncoder
+
+ >>> # Initializing a Gemma3nAudioEncoder gemma3n_audio-E4B-style configuration
+ >>> configuration = Gemma3nAudioConfig()
+
+ >>> # Initializing a model from the gemma3n_audio-E4B style configuration
+ >>> model = Gemma3nAudioEncoder(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```
+ """
+
+ model_type = "gemma3n_audio"
+
+ def __init__(
+ self,
+ vocab_size: int = 128,
+ vocab_offset: int = 262_144 + 128, # text vocab size + vision vocab size
+ input_feat_size: int = 128,
+ hidden_size: int = 1536,
+ rms_norm_eps: float = 1e-6,
+ gradient_clipping: float = 10_000_000_000.0,
+ conf_attention_chunk_size: int = 12,
+ conf_attention_context_left: int = 13,
+ conf_attention_context_right: int = 0,
+ conf_attention_logit_cap: float = 50.0,
+ conf_num_attention_heads: int = 8,
+ conf_num_hidden_layers: int = 12,
+ conf_conv_kernel_size: int = 5,
+ conf_reduction_factor: int = 4,
+ conf_residual_weight: float = 0.5,
+ sscp_conv_channel_size: tuple[int, int] = (128, 32),
+ sscp_conv_group_norm_eps: float = 1e-3,
+ sscp_conv_kernel_size: tuple[tuple[int, int], tuple[int, int]] = (
+ (3, 3),
+ (3, 3),
+ ),
+ sscp_conv_stride_size: tuple[tuple[int, int], tuple[int, int]] = (
+ (2, 2),
+ (2, 2),
+ ),
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.input_feat_size = input_feat_size
+ self.hidden_size = hidden_size
+ self.rms_norm_eps = rms_norm_eps
+ self.vocab_size = vocab_size
+ self.vocab_offset = vocab_offset
+ self.gradient_clipping = gradient_clipping
+ self.conf_attention_chunk_size = conf_attention_chunk_size
+ self.conf_attention_context_left = conf_attention_context_left
+ self.conf_attention_context_right = conf_attention_context_right
+ self.conf_attention_logit_cap = conf_attention_logit_cap
+ self.conf_num_attention_heads = conf_num_attention_heads
+ self.conf_num_hidden_layers = conf_num_hidden_layers
+ self.conf_conv_kernel_size = conf_conv_kernel_size
+ self.conf_reduction_factor = conf_reduction_factor
+ self.conf_residual_weight = conf_residual_weight
+ self.sscp_conv_channel_size = sscp_conv_channel_size
+ self.sscp_conv_group_norm_eps = sscp_conv_group_norm_eps
+ self.sscp_conv_kernel_size = sscp_conv_kernel_size
+ self.sscp_conv_stride_size = sscp_conv_stride_size
+
+
+class Gemma3nVisionConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration for a timm backbone [`TimmWrapper`]. It is used to
+ instantiate a timm model according to the specified arguments, defining the model architecture.
+ Instantiating a configuration with the defaults will yield a similar configuration to that of the Gemma 3n E4B
+ vision tower, e.g. [google/gemma-3n-E4B](https://huggingface.co/google/gemma-3n-E4B).
+
+ Configuration objects inherit from [`Gemma3nVisionConfig`] and can be used to control the model outputs. Read the
+ documentation from [`Gemma3nVisionConfig`] for more information.
+
+ The config loads ImageNet label descriptions and stores them in the `id2label` attribute; the `label2id`
+ attribute for default ImageNet models is set to `None` because of collisions (duplicate descriptions).
+
+ Args:
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ do_pooling (`bool`, *optional*, defaults to `False`):
+ Whether to do pooling for the last_hidden_state in `TimmWrapper` or not.
+ architecture (`str`, *optional*, defaults to `"mobilenetv5_300m_enc"`):
+ Determines vision architecture for TimmWrapper.
+ hidden_size (`int`, *optional*, defaults to 2048):
+ Dimension of the hidden representations.
+ vocab_size (`int`, *optional*, defaults to 128):
+ Vocabulary size of the additional hard-token embeddings for the vision model.
+ vocab_offset (`int`, *optional*, defaults to 262144):
+ Offset between the tokenizer vocab index for the token ids embedded by `Gemma3nMultimodalEmbedder` and the
+ 0-indexed `Gemma3nMultimodalEmbedder.embedding` table.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
+ The epsilon used by the rms normalization layers.
+
+ Example:
+ ```python
+ >>> from transformers import Gemma3nVisionConfig, TimmWrapper
+
+ >>> # Initializing a TimmWrapper gemma3n_vision-E4B-style configuration
+ >>> configuration = Gemma3nVisionConfig()
+
+ >>> # Initializing a gemma3n_vision-E4B-style TimmWrapper from the configuration
+ >>> model = TimmWrapper(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```
+ """
+
+ model_type = "gemma3n_vision"
+
+ def __init__(
+ self,
+ initializer_range: float = 0.02,
+ do_pooling: bool = False,
+ architecture: str = "mobilenetv5_300m_enc",
+ hidden_size: int = 2048,
+ vocab_size: int = 128,
+ vocab_offset: int = 262_144,
+ rms_norm_eps: float = 1e-06,
+ model_args: Optional[dict] = None,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.architecture = architecture
+ self.initializer_range = initializer_range
+ self.do_pooling = do_pooling
+ self.model_args = model_args # named "model_args" for BC with timm
+ self.hidden_size = hidden_size
+ self.vocab_size = vocab_size
+ self.vocab_offset = vocab_offset
+ self.rms_norm_eps = rms_norm_eps
+
+ @classmethod
+ def from_dict(cls, config_dict: dict[str, Any], **kwargs):
+ label_names = config_dict.get("label_names")
+ is_custom_model = "num_labels" in kwargs or "id2label" in kwargs
+
+ # if no labels added to config, use imagenet labeller in timm
+ if label_names is None and not is_custom_model:
+ requires_backends(cls, ["timm"])
+ imagenet_subset = infer_imagenet_subset(config_dict)
+ if imagenet_subset:
+ dataset_info = ImageNetInfo(imagenet_subset)
+ synsets = dataset_info.label_names()
+ label_descriptions = dataset_info.label_descriptions(as_dict=True)
+ label_names = [label_descriptions[synset] for synset in synsets]
+
+ if label_names is not None and not is_custom_model:
+ kwargs["id2label"] = dict(enumerate(label_names))
+
+ # if all label names are unique, create label2id mapping as well
+ if len(set(label_names)) == len(label_names):
+ kwargs["label2id"] = {name: i for i, name in enumerate(label_names)}
+ else:
+ kwargs["label2id"] = None
+
+ # timm config stores the `num_classes` attribute in both the root of config and in the "pretrained_cfg" dict.
+ # We are removing these attributes in order to have the native `transformers` num_labels attribute in config
+ # and to avoid duplicate attributes
+ num_labels_in_kwargs = kwargs.pop("num_labels", None)
+ num_labels_in_dict = config_dict.pop("num_classes", None)
+
+ # passed num_labels has priority over num_classes in config_dict
+ kwargs["num_labels"] = num_labels_in_kwargs or num_labels_in_dict
+
+ # pop num_classes from "pretrained_cfg",
+ # it is not necessary to have it, only root one is used in timm
+ if "pretrained_cfg" in config_dict and "num_classes" in config_dict["pretrained_cfg"]:
+ config_dict["pretrained_cfg"].pop("num_classes", None)
+
+ return super().from_dict(config_dict, **kwargs)
+
+ def to_dict(self) -> dict[str, Any]:
+ output = super().to_dict()
+ output.setdefault("num_classes", self.num_labels)
+ output.setdefault("label_names", list(self.id2label.values()))
+ output.pop("id2label", None)
+ output.pop("label2id", None)
+ return output
+
+
+class Gemma3nConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`Gemma3nForConditionalGeneration`]. It is used to
+ instantiate a Gemma3nForConditionalGeneration according to the specified arguments, defining the model
+ architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of
+ Gemma3n-E4B.
+
+ e.g. [google/gemma-3n-E4B](https://huggingface.co/google/gemma-3n-E4B)
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ text_config (`Union[Gemma3nTextConfig, dict]`, *optional*):
+ The config object of the text backbone.
+ vision_config (`Union[AutoConfig, dict]`, *optional*):
+ Custom vision config or dict.
+ audio_config (`Union[AutoConfig, dict]`, *optional*):
+ Custom audio config or dict.
+ audio_soft_tokens_per_image (`int`, *optional*, defaults to 188):
+ The number of soft tokens per audio clip.
+ vision_soft_tokens_per_image (`int`, *optional*, defaults to 256):
+ The number of soft tokens per image.
+ boi_token_id (`int`, *optional*, defaults to 255999):
+ The begin-of-image token index to wrap the image prompt.
+ eoi_token_id (`int`, *optional*, defaults to 262144):
+ The end-of-image token index to wrap the image prompt.
+ image_token_id (`int`, *optional*, defaults to 262145):
+ The image token index to encode the image prompt.
+ boa_token_id (`int`, *optional*, defaults to 256000):
+ The begin-of-audio token index to wrap the audio prompt.
+ eoa_token_id (`int`, *optional*, defaults to 262272):
+ The end-of-audio token index to wrap the audio prompt.
+ audio_token_id (`int`, *optional*, defaults to 262273):
+ The audio token index to encode the audio prompt.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+
+
+ Example:
+
+ ```python
+ >>> from transformers import Gemma3nForConditionalGeneration, Gemma3nConfig, Gemma3nTextConfig
+
+ >>> # Initializing a MobileNet vision config, which is loaded from TIMM
+ >>> vision_config = Gemma3nVisionConfig()
+
+ >>> # Initializing a Gemma3n Audio config
+ >>> audio_config = Gemma3nAudioConfig()
+
+ >>> # Initializing a Gemma3n Text config
+ >>> text_config = Gemma3nTextConfig()
+
+ >>> # Initializing a Gemma3n gemma-3n-E4B style configuration
+ >>> configuration = Gemma3nConfig(text_config, vision_config, audio_config)
+
+ >>> # Initializing a model from the gemma-3n-E4B style configuration
+ >>> model = Gemma3nForConditionalGeneration(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "gemma3n"
+ sub_configs = {
+ "text_config": Gemma3nTextConfig,
+ "vision_config": Gemma3nVisionConfig,
+ "audio_config": Gemma3nAudioConfig,
+ }
+
+ def __init__(
+ self,
+ text_config: Optional[Union[Gemma3nTextConfig, dict[str, Any]]] = None,
+ vision_config: Optional[Union[Gemma3nVisionConfig, dict[str, Any]]] = None,
+ audio_config: Optional[Union[Gemma3nAudioConfig, dict[str, Any]]] = None,
+ audio_soft_tokens_per_image: int = 188,
+ vision_soft_tokens_per_image: int = 256,
+ boi_token_id: int = 255_999,
+ eoi_token_id: int = 262_144,
+ image_token_id: int = 262_145,
+ boa_token_id: int = 256_000,
+ eoa_token_id: int = 262_272,
+ audio_token_id: int = 262_273,
+ initializer_range: float = 0.02,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ if isinstance(text_config, dict):
+ text_config = Gemma3nTextConfig(**text_config)
+ elif text_config is None:
+ text_config = Gemma3nTextConfig()
+ logger.info("text_config is None. Using default Gemma3nTextConfig.")
+
+ if isinstance(vision_config, dict):
+ vision_config = Gemma3nVisionConfig(**vision_config)
+ elif vision_config is None:
+ vision_config = Gemma3nVisionConfig()
+ logger.info("vision_config is None. Using default Gemma3nVisionConfig.")
+
+ if isinstance(audio_config, dict):
+ audio_config = Gemma3nAudioConfig(**audio_config)
+ elif audio_config is None:
+ audio_config = Gemma3nAudioConfig()
+ logger.info("audio_config is None. Using default Gemma3nAudioConfig.")
+
+ self.text_config = text_config
+ self.vision_config = vision_config
+ self.audio_config = audio_config
+
+ self.audio_soft_tokens_per_image = audio_soft_tokens_per_image
+ self.vision_soft_tokens_per_image = vision_soft_tokens_per_image
+ self.boi_token_id = boi_token_id
+ self.eoi_token_id = eoi_token_id
+ self.image_token_id = image_token_id
+ self.boa_token_id = boa_token_id
+ self.eoa_token_id = eoa_token_id
+ self.audio_token_id = audio_token_id
+ self.initializer_range = initializer_range
+
+
+__all__ = ["Gemma3nAudioConfig", "Gemma3nConfig", "Gemma3nTextConfig", "Gemma3nVisionConfig"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/gemma3n/feature_extraction_gemma3n.py b/venv/lib/python3.13/site-packages/transformers/models/gemma3n/feature_extraction_gemma3n.py
new file mode 100644
index 0000000000000000000000000000000000000000..62e3fb3878f73fc7c794f544ddd94232f9ae1b4d
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/gemma3n/feature_extraction_gemma3n.py
@@ -0,0 +1,338 @@
+# coding=utf-8
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from collections.abc import Sequence
+from typing import Optional, Union
+
+import numpy as np
+
+from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
+from ...feature_extraction_utils import BatchFeature
+from ...utils import PaddingStrategy, TensorType, logging
+
+
+logger = logging.get_logger(__name__)
+
+
+def create_fb_matrix(
+ n_freqs: int,
+ f_min: float,
+ f_max: float,
+ n_mels: int,
+ sample_rate: int,
+ fft_length: int,
+ norm: Optional[str] = None,
+) -> np.ndarray:
+ r"""Create a frequency bin conversion matrix (NumPy version).
+
+ Args:
+ n_freqs (int): Number of frequencies to highlight/apply
+ f_min (float): Minimum frequency (Hz)
+ f_max (float): Maximum frequency (Hz)
+ n_mels (int): Number of mel filterbanks
+ sample_rate (int): Sample rate of the audio waveform
+ fft_length (int): FFT length
+ norm (Optional[str]): If 'slaney', divide the triangular mel weights by
+ the width of the mel band (area normalization). (Default: ``None``)
+
+ Returns:
+ np.ndarray: Triangular filter banks (fb matrix) of size (``n_freqs``,
+ ``n_mels``), i.e., the number of frequencies to highlight/apply by the
+ number of filterbanks. Each column is a filterbank so that, assuming
+ there is a matrix A of size (..., ``n_freqs``), the applied result
+ would be ``A @ create_fb_matrix(A.shape[-1], ...)``.
+ """
+
+ if norm is not None and norm != "slaney":
+ raise ValueError("norm must be one of None or 'slaney'")
+
+ # freq bins
+ all_freqs = np.arange(n_freqs, dtype=np.float32) * (sample_rate / fft_length)
+
+ # calculate mel freq bins
+ # hertz to mel(f) is 2595. * math.log10(1. + (f / 700.))
+ m_min = 2595.0 * math.log10(1.0 + (f_min / 700.0))
+ m_max = 2595.0 * math.log10(1.0 + (f_max / 700.0))
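+ # e.g. with the extractor defaults, f_min=125.0 Hz maps to roughly 185 mel and f_max=7600.0 Hz to roughly 2787 mel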
+ m_pts = np.linspace(m_min, m_max, n_mels + 2)
+ # mel to hertz(mel) is 700. * (10**(mel / 2595.) - 1.)
+ f_pts = 700.0 * (10 ** (m_pts / 2595.0) - 1.0)
+ # calculate difference between each mel point and each stft freq point in Hz
+ f_diff = f_pts[1:] - f_pts[:-1] # (n_mels + 1)
+ slopes = np.expand_dims(f_pts, 0) - np.expand_dims(all_freqs, 1) # (n_freqs, n_mels + 2)
+ # create overlapping triangles
+ zero = np.zeros(1, dtype=np.float32)
+ down_slopes = (-1.0 * slopes[:, :-2]) / f_diff[:-1] # (n_freqs, n_mels)
+ up_slopes = slopes[:, 2:] / f_diff[1:] # (n_freqs, n_mels)
+ fb = np.maximum(zero, np.minimum(down_slopes, up_slopes))
+
+ if norm is not None and norm == "slaney":
+ # Slaney-style mel is scaled to be approx constant energy per channel
+ enorm = 2.0 / (f_pts[2 : n_mels + 2] - f_pts[:n_mels])
+ fb *= np.expand_dims(enorm, 0)
+
+ return fb
+
+
+def _unfold(array: np.ndarray, dimension: int, size: int, step: int) -> np.ndarray:
+ """A basic NumPy equivalent of PyTorch's unfold for 2D arrays along the last dim."""
+ if array.ndim != 2:
+ raise ValueError("This unfold implementation currently supports 2D arrays (batch, time).")
+ if dimension != -1 and dimension != array.ndim - 1:
+ raise ValueError("This unfold implementation only supports unfolding the last dimension.")
+
+ batch_size, original_length = array.shape
+ num_frames = (original_length - size) // step + 1
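+ # e.g. original_length=10, size=4, step=2 -> num_frames=4, frames starting at samples 0, 2, 4 and 6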
+
+ if num_frames <= 0:
+ return np.zeros((batch_size, 0, size), dtype=array.dtype)
+
+ output_shape = (batch_size, num_frames, size)
+ output_strides = (array.strides[0], array.strides[1] * step, array.strides[1])
+
+ return np.lib.stride_tricks.as_strided(array, shape=output_shape, strides=output_strides)
+
+
+class Gemma3nAudioFeatureExtractor(SequenceFeatureExtractor):
+ """An audio feature extractor Universal Speech Models https://huggingface.co/papers/2303.01037.
+
+ Args:
+ feature_size (`int`, *optional*, defaults to 128):
+ The feature dimension of the extracted features.
+ sampling_rate (`int`, *optional*, defaults to 16000):
+ The sampling rate at which the audio files should be digitalized expressed in hertz (Hz).
+ padding_value (`float`, *optional*, defaults to 0.0):
+ Padding value used to pad the audio. Should correspond to silences.
+ return_attention_mask (`bool`, *optional*, defaults to `True`):
+ Whether to return the attention mask for the generated MEL spectrograms.
+ frame_length_ms (`float`, *optional*, defaults to 32.0):
+ The length of a frame in milliseconds.
+ hop_length_ms (`float`, *optional*, defaults to 10.0):
+ The hop between successive analysis frames, in milliseconds.
+ min_frequency (`float`, *optional*, defaults to 125.0):
+ The minimum frequency (in Hz) for the Mel filterbank.
+ max_frequency (`float`, *optional*, defaults to 7600.0):
+ The maximum frequency (in Hz) for the Mel filterbank.
+ preemphasis (`float`, *optional*, defaults to 0.97):
+ The preemphasis coefficient.
+ preemphasis_htk_flavor (`bool`, *optional*, defaults to `True`):
+ Whether to use HTK-style preemphasis.
+ fft_overdrive (`bool`, *optional*, defaults to `True`):
+ Whether to use FFT overdrive.
+ dither (`float`, *optional*, defaults to 0.0):
+ Adds dithering. In other words, adds small Gaussian noise to each frame.
+ E.g. use 0.0001 to add dithering with a normal distribution centered
+ around 0.0 with standard deviation 0.0001 (assuming [-1,+1] range of raw_speech).
+ The value 0.0 means no dithering.
+ Dithering has a similar effect to `spectrogram(mel_floor=...)`. It reduces
+ the high log_mel_fbank values for signals with hard-zero sections,
+ when VAD cutoff is present in the signal.
+ input_scale_factor (`float`, *optional*, defaults to 1.0):
+ Scaling factor applied to the input waveform.
+ mel_floor (`float`, *optional*, defaults to 1e-05):
+ Minimum value for Mel spectrograms to avoid log(0).
+ per_bin_mean (`Optional[Sequence[float]]`, *optional*):
+ Mean values for per-bin normalization.
+ per_bin_stddev (`Optional[Sequence[float]]`, *optional*):
+ Standard deviation values for per-bin normalization.
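+
+ Example (a minimal sketch using random audio; no pretrained checkpoint is assumed):
+
+ ```python
+ >>> import numpy as np
+ >>> from transformers import Gemma3nAudioFeatureExtractor
+
+ >>> feature_extractor = Gemma3nAudioFeatureExtractor()
+ >>> audio = np.random.randn(16_000).astype(np.float32)  # one second of audio at 16 kHz
+ >>> features = feature_extractor(audio, return_tensors="np")
+ >>> list(features.keys())
+ ['input_features', 'input_features_mask']
+ ```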
+ """
+
+ model_input_names = ["input_features", "input_features_mask"]
+
+ def __init__(
+ self,
+ feature_size: int = 128,
+ sampling_rate: int = 16_000,
+ padding_value: float = 0.0,
+ return_attention_mask: bool = True,
+ frame_length_ms: float = 32.0,
+ hop_length_ms: float = 10.0,
+ min_frequency: float = 125.0,
+ max_frequency: float = 7600.0,
+ preemphasis: float = 0.97,
+ preemphasis_htk_flavor: bool = True,
+ fft_overdrive: bool = True,
+ dither: float = 0.0,
+ input_scale_factor: float = 1.0,
+ mel_floor: float = 1e-5,
+ per_bin_mean: Optional[Sequence[float]] = None,
+ per_bin_stddev: Optional[Sequence[float]] = None,
+ **kwargs,
+ ):
+ super().__init__(
+ feature_size=feature_size,
+ sampling_rate=sampling_rate,
+ padding_value=padding_value,
+ return_attention_mask=return_attention_mask,
+ **kwargs,
+ )
+
+ self.min_frequency = min_frequency
+ self.max_frequency = max_frequency
+ self.preemphasis = preemphasis
+ self.preemphasis_htk_flavor = preemphasis_htk_flavor
+ self.fft_overdrive = fft_overdrive
+ self.dither = dither
+ self.input_scale_factor = input_scale_factor
+ self.frame_length = int(round(sampling_rate * frame_length_ms / 1000.0))
+ self.hop_length = int(round(sampling_rate * hop_length_ms / 1000.0))
+ self.mel_floor = np.array(mel_floor, dtype=np.float64)
+
+ fft_length = 2 ** math.ceil(math.log2(self.frame_length))
+ if self.fft_overdrive:
+ fft_length *= 2
+ self.fft_length = fft_length
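+ # With the defaults (16 kHz sampling, 32 ms frames, 10 ms hop): frame_length=512, hop_length=160,
+ # and fft_length=1024 (the smallest power of two >= frame_length, doubled because fft_overdrive=True).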
+
+ hann_arange = np.arange(self.frame_length, dtype=np.float32)
+ window = 0.5 * (1 - np.cos(2 * np.pi * hann_arange / self.frame_length))
+ self.window = window.astype(np.float32)
+
+ self.mel_filters = create_fb_matrix(
+ n_freqs=self.fft_length // 2 + 1,
+ f_min=min_frequency,
+ f_max=max_frequency,
+ n_mels=feature_size,
+ sample_rate=self.sampling_rate,
+ norm=None,
+ fft_length=fft_length,
+ )
+
+ if per_bin_mean is not None:
+ self.per_bin_mean = np.array(per_bin_mean).reshape(1, 1, feature_size)
+ else:
+ self.per_bin_mean = None
+
+ if per_bin_stddev is not None:
+ self.per_bin_stddev = np.array(per_bin_stddev).reshape(1, 1, feature_size)
+ else:
+ self.per_bin_stddev = None
+
+ def _extract_spectrogram(self, waveform: np.ndarray, attention_mask: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
+ """"""
+ if waveform.ndim == 1: # If single waveform, add batch dimension
+ waveform = np.expand_dims(waveform, axis=0)
+
+ if self.dither > 0.0:
+ waveform = waveform + self.dither * np.random.randn(*waveform.shape).astype(waveform.dtype)
+
+ if self.input_scale_factor != 1.0:
+ waveform = waveform * self.input_scale_factor
+
+ frame_size_for_unfold = self.frame_length + 1
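+ # One extra sample per frame so that the preemphasis difference below still produces `frame_length` samples.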
+
+ # NumPy equivalent of unfold for [B, NumFrames, frame_size_for_unfold]
+ frames_to_process = _unfold(waveform, dimension=-1, size=frame_size_for_unfold, step=self.hop_length)
+
+ if self.preemphasis > 0.0:
+ if self.preemphasis_htk_flavor:
+ first_in_frame = frames_to_process[..., :1] * (1.0 - self.preemphasis)
+ rest_in_frame = frames_to_process[..., 1:-1] - self.preemphasis * frames_to_process[..., :-2]
+ frames = np.concatenate([first_in_frame, rest_in_frame], axis=-1)
+ else:
+ frames = frames_to_process[..., 1:] - self.preemphasis * frames_to_process[..., :-1]
+ else:
+ frames = frames_to_process[..., :-1]
+
+ frames = frames * self.window # Broadcasting window
+ stft = np.fft.rfft(frames, n=self.fft_length, axis=-1)
+
+ magnitude_spec = np.abs(stft)
+
+ mel_spec = np.matmul(magnitude_spec, self.mel_filters)
+ log_mel_spec = np.log(np.maximum(mel_spec, self.mel_floor))
+
+ if self.per_bin_mean is not None:
+ log_mel_spec = log_mel_spec - self.per_bin_mean # Broadcasting
+
+ if self.per_bin_stddev is not None:
+ log_mel_spec = log_mel_spec / self.per_bin_stddev # Broadcasting
+
+ mel_spectrogram = log_mel_spec.squeeze(0)
+ mask = attention_mask[:: self.hop_length].astype(bool)
+ # TODO: The filtered mask is always exactly 3 elements longer than the mel_spectrogram. Why???
+ return mel_spectrogram, mask[: mel_spectrogram.shape[0]]
+
+ def __call__(
+ self,
+ raw_speech: Union[np.ndarray, list[float], list[np.ndarray], list[list[float]]],
+ padding: Union[bool, str, PaddingStrategy] = "longest",
+ max_length: Optional[int] = 480_000,
+ truncation: bool = True,
+ pad_to_multiple_of: Optional[int] = 128,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ return_attention_mask: Optional[bool] = True,
+ **kwargs,
+ ) -> BatchFeature:
+ """Creates a batch of MEL spectrograms from the provided raw speech.
+
+ This implementation uses a different algorithm for windowing and preemphasis compared to the built-in
+ `transformers.audio_utils.spectrogram()` function that _will_ result in different outputs. Consider this
+ carefully when selecting an audio feature extractor, especially with pre-trained models.
+
+ Args:
+ raw_speech:
+ The audio for which MEL spectrograms are created.
+ padding (`Union[bool, str, PaddingStrategy]`, *optional*, defaults to `"longest"`):
+ The padding strategy to use for batches of audio with different lengths.
+ max_length (`int`, *optional*, defaults to 480000):
+ If provided, defines the maximum length of the audio to allow. Audio longer than this will be
+ truncated if `truncation=True`.
+ truncation (`bool`, *optional*, defaults to `True`):
+ Whether or not to truncate audio above `max_length`.
+ pad_to_multiple_of (`int`, *optional*, defaults to 128):
+ When padding, pad to a multiple of this value. The default value is defined for optimal TPU support.
+ return_tensors (`Union[str, TensorType]`, *optional*, defaults to `None`):
+ The type of tensors to return (e.g., NumPy, Torch, JAX, TensorFlow).
+ return_attention_mask (`bool`, *optional*, defaults to `True`):
+ Whether to return the attention mask for the generated MEL spectrograms.
+ """
+
+ is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1
+ is_batched_sequence = isinstance(raw_speech, Sequence) and isinstance(raw_speech[0], (np.ndarray, Sequence))
+ is_batched = is_batched_numpy or is_batched_sequence
+
+ if is_batched:
+ raw_speech = [np.asarray([rs]).T for rs in raw_speech]
+ elif not is_batched and not isinstance(raw_speech, np.ndarray):
+ raw_speech = np.asarray(raw_speech)
+
+ if not is_batched: # always return a batch
+ raw_speech = [np.asarray([raw_speech])]
+
+ batched_speech = self.pad(
+ BatchFeature({"input_features": raw_speech}),
+ padding=padding,
+ max_length=max_length,
+ truncation=truncation,
+ pad_to_multiple_of=pad_to_multiple_of,
+ return_attention_mask=return_attention_mask,
+ )
+
+ prepared_speech = []
+ prepared_speech_mask = []
+ for speech, mask in zip(batched_speech.input_features, batched_speech.attention_mask):
+ speech, mask = self._extract_spectrogram(speech.T, mask)
+ prepared_speech.append(speech.astype(np.float32))
+ prepared_speech_mask.append(mask)
+
+ return BatchFeature(
+ {"input_features": prepared_speech, "input_features_mask": prepared_speech_mask},
+ tensor_type=return_tensors,
+ )
+
+
+__all__ = ["Gemma3nAudioFeatureExtractor"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/gemma3n/modeling_gemma3n.py b/venv/lib/python3.13/site-packages/transformers/models/gemma3n/modeling_gemma3n.py
new file mode 100644
index 0000000000000000000000000000000000000000..68595ead4371bba4087813566d5c83d5c0155c03
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/gemma3n/modeling_gemma3n.py
@@ -0,0 +1,2394 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/gemma3n/modular_gemma3n.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_gemma3n.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import copy
+import math
+from collections.abc import Callable, Sequence
+from dataclasses import dataclass
+from typing import Optional, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache
+from ...generation import GenerationMixin
+from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, logging
+from ...utils.deprecation import deprecate_kwarg
+from ..auto import AutoModel
+from .configuration_gemma3n import Gemma3nAudioConfig, Gemma3nConfig, Gemma3nTextConfig, Gemma3nVisionConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for Gemma3n outputs, with hidden states and attentions.
+ """
+)
+class Gemma3nModelOutputWithPast(BaseModelOutputWithPast):
+ r"""
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ image_hidden_states (`torch.FloatTensor`, *optional*):
+ A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
+ image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
+ audio_hidden_states (`torch.FloatTensor`, *optional*):
+ A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
+ audio_hidden_states of the model produced by the audio encoder and after projecting the last hidden state.
+ """
+
+ image_hidden_states: Optional[torch.FloatTensor] = None
+
+ audio_hidden_states: Optional[torch.FloatTensor] = None
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for Gemma3n causal language model (or autoregressive) outputs.
+ """
+)
+class Gemma3nCausalLMOutputWithPast(ModelOutput):
+ r"""
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Language modeling loss (for next-token prediction).
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.text_config.vocab_size)`):
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ image_hidden_states (`torch.FloatTensor`, *optional*):
+ A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
+ image_hidden_states of the model produced by the vision encoder after projecting last hidden state.
+ audio_hidden_states (`torch.FloatTensor`, *optional*):
+ A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
+ audio_hidden_states of the model produced by the audio encoder and after projecting the last hidden state.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: Optional[torch.FloatTensor] = None
+ past_key_values: Optional[Cache] = None
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
+ attentions: Optional[tuple[torch.FloatTensor]] = None
+ image_hidden_states: Optional[torch.FloatTensor] = None
+
+ audio_hidden_states: Optional[torch.FloatTensor] = None
+
+
+class Gemma3nRMSNorm(nn.Module):
+ def __init__(self, dim: int, eps: float = 1e-6, with_scale: bool = True):
+ super().__init__()
+ self.eps = eps
+ self.with_scale = with_scale
+
+ if self.with_scale:
+ self.weight = nn.Parameter(torch.ones(dim))
+ else:
+ self.register_buffer("weight", torch.tensor(1.0), persistent=False)
+
+ def _norm(self, x):
+ return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ # Llama does x.to(float16) * w whilst Gemma2 is (x * w).to(float16)
+ # See https://github.com/huggingface/transformers/pull/29402
+ output = self._norm(x.float()) * self.weight.float()
+ return output.type_as(x)
+
+ def extra_repr(self):
+ return f"{tuple(self.weight.shape)}, eps={self.eps}"
+
+
+# ==== Audio Encoder ====
+
+
+class Gemma3nAudioRelativePositionEmbedding(nn.Module):
+ def __init__(self, config: Gemma3nAudioConfig):
+ super().__init__()
+ self.config = config
+
+ self.num_heads = self.config.conf_num_attention_heads
+ self.channels = self.config.hidden_size
+ self.head_dim = self.channels // self.num_heads
+ self.max_backward = max(0, self.config.conf_attention_context_left - 1)
+ self.max_forward = self.config.conf_attention_context_right
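+ # With the default audio config (context_left=13, context_right=0): max_backward=12, max_forward=0.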
+
+ self.pos_proj = nn.Linear(self.channels, self.num_heads * self.head_dim, bias=False)
+
+ min_timescale = 1.0
+ max_timescale = 1.0e4
+ num_timescales = self.channels // 2
+ log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / max(num_timescales - 1, 1)
+ inv_timescales = min_timescale * torch.exp(torch.arange(num_timescales) * -log_timescale_increment)
+ self.register_buffer(
+ "inv_timescales",
+ inv_timescales.float().unsqueeze(0).unsqueeze(0),
+ persistent=False,
+ )
+
+ def _get_timing_signal_1d_pos(self, position: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
+ position = position.float().unsqueeze(-1)
+ scaled_time = position * self.inv_timescales.to(device=position.device, dtype=torch.float32)
+ timing_signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=-1)
+ return timing_signal.type(dtype)
+
+ def _relative_shift(
+ self,
+ term_bd_before_shift: torch.Tensor,
+ batch_size: int,
+ num_heads: int,
+ num_query_blocks: int,
+ query_block_size: int,
+ key_context_size: int,
+ max_span_plus_1: int,
+ ) -> torch.Tensor:
+ """Performs the relative shift.
+
+ Args:
+ term_bd_before_shift: Tensor of shape [B, N, U, W, F_span]. batch_size
+ (B), num_heads (N), num_query_blocks (U), query_block_size (W),
+ key_context_size (C = W+L+R), max_span_plus_1 (F_span = L+R+1).
+
+ Returns:
+ Tensor of shape [B, N, U, W, C].
+ """
+ # term_bd_before_shift shape: [B, N, U, W, F_span]
+ # Target shape after shift: [B, N, U, W, C]
+
+ # Padding amount for the last dimension (F_span) to become (C + 1)
+ # C = key_context_size
+ # F_span = max_span_plus_1
+ pad_amount_last_dim = (key_context_size + 1) - max_span_plus_1
+
+ # PyTorch F.pad expects (pad_left, pad_right, pad_top, pad_bottom ...)
+ # We only pad the last dimension on the right.
+ padding_tuple = (0, pad_amount_last_dim)
+
+ term_bd_padded = nn.functional.pad(term_bd_before_shift, padding_tuple)
+ # Shape after pad: [B, N, U, W, C+1]
+
+ # Reshape for slicing (emulating JAX's behavior)
+ # [B, N, U, W * (C+1)]
+ term_bd_reshaped = term_bd_padded.reshape(
+ (
+ batch_size,
+ num_heads,
+ num_query_blocks,
+ query_block_size * (key_context_size + 1),
+ )
+ )
+
+ # Slice to effective [B, N, U, W * C]
+ term_bd_sliced = term_bd_reshaped[:, :, :, : query_block_size * key_context_size]
+
+ # Reshape back to [B, N, U, W, C]
+ term_bd_shifted = term_bd_sliced.reshape(
+ (
+ batch_size,
+ num_heads,
+ num_query_blocks,
+ query_block_size,
+ key_context_size,
+ )
+ )
+ return term_bd_shifted
+
+ def forward(self, queries: torch.Tensor, keys: torch.Tensor) -> torch.Tensor:
+ # queries: [B, U, W, N, H] (batch, num_query_blocks, query_block_size, num_heads, head_dim)
+ # keys: [B, U, C, N, H] (batch, num_query_blocks, key_context_size, num_heads, head_dim)
+ # C = W + L + R (key_context_size)
+ # F_span = L + R + 1 (max_span + 1)
+
+ batch_size, num_query_blocks, query_block_size, num_heads, head_dim = queries.shape
+ _, _, key_context_size, _, _ = keys.shape
+
+ # Relative positions for sinusoidal embeddings: [L, L-1, ..., -R]
+ # Length is L+R+1 = self.max_span + 1
+ pos_indices = torch.arange(self.max_backward, -self.max_forward - 1, -1, device=queries.device).unsqueeze(
+ 0
+ ) # Shape [1, F_span]
+
+ max_span_plus_1 = pos_indices.shape[1] # F_span
+
+ sin_emb_timing_signal = self._get_timing_signal_1d_pos(
+ pos_indices, dtype=queries.dtype
+ ) # Shape [1, F_span, self.channels]
+
+ # Project sinusoidal embeddings: [1, F_span, self.channels] -> [1, F_span, N*H]
+ projected_sin_emb = self.pos_proj(sin_emb_timing_signal)
+ # Reshape to [1, F_span, N, H] then squeeze to [F_span, N, H]
+ sin_emb = projected_sin_emb.reshape(1, max_span_plus_1, self.num_heads, self.head_dim).squeeze(
+ 0
+ ) # Shape [F, N, H]
+
+ # term_ac: Query-Key content interaction
+ # queries: [B, U, W, N, H] -> permute to [B, N, U, W, H] for matmul
+ # keys: [B, U, C, N, H] -> permute to [B, N, U, H, C] for matmul
+ queries_p = queries.permute(0, 3, 1, 2, 4) # [B, N, U, W, H]
+ keys_p_t = keys.permute(0, 3, 1, 4, 2) # [B, N, U, H, C]
+ term_ac = torch.matmul(queries_p, keys_p_t) # [B, N, U, W, C]
+
+ # term_bd: Query-Position interaction
+ # Original einsum: term_bd_unshifed = torch.einsum('buwnh,fnh->bnuwf', queries, sin_emb)
+ # queries shape: [B, U, W, N, H]
+ # sin_emb shape: [F, N, H]
+ # Target output shape: [B, N, U, W, F]
+
+ # Permute queries to [B, N, U, W, H] for easier broadcasting with sin_emb
+ q_permuted = queries.permute(0, 3, 1, 2, 4)
+
+ # Permute sin_emb to [N, H, F] to prepare for matmul
+ # sin_emb original is [F, N, H]
+ s_permuted = sin_emb.permute(1, 2, 0) # Shape: [N, H, F]
+
+ # Reshape queries for matmul: [B, N, U*W, H]
+ q_reshaped = q_permuted.reshape(batch_size, num_heads, num_query_blocks * query_block_size, head_dim)
+
+ # Perform matmul: [B, N, U*W, H] @ [N, H, F]
+ # s_permuted ([N, H, F]) will be broadcast to [B, N, H, F]
+ # Result: [B, N, U*W, F]
+ term_bd_unshifed_matmul = torch.matmul(q_reshaped, s_permuted)
+
+ # Reshape to target [B, N, U, W, F]
+ term_bd_unshifed = term_bd_unshifed_matmul.reshape(
+ batch_size,
+ num_heads,
+ num_query_blocks,
+ query_block_size,
+ max_span_plus_1,
+ )
+
+ # Apply relative shift to term_bd_unshifed
+ term_bd_shifted = self._relative_shift(
+ term_bd_unshifed,
+ batch_size,
+ num_heads,
+ num_query_blocks,
+ query_block_size,
+ key_context_size,
+ max_span_plus_1,
+ ) # Shape [B, N, U, W, C]
+
+ return term_ac + term_bd_shifted
+
+
+class Gemma3nAudioAttention(nn.Module):
+ def __init__(self, config: Gemma3nAudioConfig):
+ super().__init__()
+ self.config = config
+
+ self.num_heads = self.config.conf_num_attention_heads
+ self.hidden_size = self.config.hidden_size
+ self.head_dim = self.hidden_size // self.num_heads
+
+ self.chunk_size = self.config.conf_attention_chunk_size
+ self.max_future_horizon = self.config.conf_attention_context_right
+ self.max_past_horizon = max(0, self.config.conf_attention_context_left - 1)
+ self.attention_logits_soft_cap = self.config.conf_attention_logit_cap
+ self.context_size = self.chunk_size + self.max_past_horizon + self.max_future_horizon
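+ # e.g. with the defaults: context_size = 12 (chunk) + 12 (past) + 0 (future) = 24 frames per block.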
+
+ self.relative_position_embedding = Gemma3nAudioRelativePositionEmbedding(config)
+ self.per_dim_scale = nn.Parameter(torch.zeros((self.head_dim,)))
+
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
+ self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
+ self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
+
+ q_scale = self.head_dim**-0.5
+ r_softplus_0 = 1.0 / torch.nn.functional.softplus(torch.tensor(0.0))
+ self.register_buffer("q_scale", (q_scale * r_softplus_0).clone().detach(), persistent=False)
+
+ lower_causal_mask = torch.tril(
+ torch.ones((self.context_size, self.chunk_size), dtype=torch.bool),
+ diagonal=0,
+ ).T
+ upper_causal_mask = torch.tril(
+ torch.ones((self.chunk_size, self.context_size), dtype=torch.bool),
+ diagonal=self.max_past_horizon + self.max_future_horizon,
+ )
+ local_causal_valid_mask = torch.ones((self.chunk_size, self.context_size), dtype=torch.bool)
+ local_causal_valid_mask = local_causal_valid_mask * lower_causal_mask * upper_causal_mask
+ self.register_buffer("local_causal_valid_mask", local_causal_valid_mask, persistent=False)
+
+ self.register_buffer(
+ "softcap",
+ torch.tensor(self.attention_logits_soft_cap).float(),
+ persistent=False,
+ )
+
+ def _pad_dim1(self, x: torch.Tensor, pad_left: int, pad_right: int) -> torch.Tensor:
+ batch, _, *tail_shape = x.shape
+ left = x.new_zeros((batch, pad_left, *tail_shape))
+ right = x.new_zeros((batch, pad_right, *tail_shape))
+ x = torch.cat([left, x, right], dim=1)
+ return x
+
+ def _convert_to_block(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ """Turns a sequence to non overlapping blocks.
+
+ Args:
+ hidden_states: a tensor of [batch, time, ...].
+
+ Returns:
+ A tensor of [batch, num_blocks, block_size, ...], with necessary paddings,
+ where output[:, i, ...] are x[:, i*block_size:(i+1)*block_size, ...].
+ """
+ shape = hidden_states.shape
+ b, t = shape[:2]
+ num_blocks = (t + self.chunk_size - 1) // self.chunk_size
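+ # e.g. chunk_size=12 and t=30 -> num_blocks=3 with 6 frames of right padding.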
+
+ if (padding_len := num_blocks * self.chunk_size - t) > 0:
+ hidden_states = self._pad_dim1(hidden_states, 0, padding_len)
+
+ permute_dims = (b, num_blocks, self.chunk_size) + shape[2:]
+ hidden_states = hidden_states.reshape(permute_dims).contiguous()
+ return hidden_states
+
+ def _extract_block_context(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ """Extracts temporal context for every block.
+
+ Args:
+ hidden_states: a tensor of [batch, time, ...].
+
+ Returns:
+ A tensor of [batch, num_blocks, context_size, ...], with necessary paddings,
+ where context_size = block_size + left_context + right_context, and
+ output[:, i, ...] are x[:, start-left_context:end+right_context, ...],
+ with start = i * block_size and end = (i + 1) * block_size.
+ """
+ pad_left = self.max_past_horizon
+ # The JAX equivalent padding for signal.frame with pad_mode='valid' is
+ # (left_context, right_context + block_size - 1) on the time dimension.
+ # PyTorch's _pad_dim1 applies padding symmetrically if only one value is given,
+ # or (pad_dim_start, pad_dim_end) if two are given.
+ # Our _pad_dim1(x, pad_left, pad_right) pads dim -2 (time for [B,T,N,H])
+ # or dim 1 (time for [B,T]).
+ # The current pad_right calculation matches the JAX effective padding.
+ pad_right = self.max_future_horizon + self.chunk_size - 1
+ hidden_states = self._pad_dim1(hidden_states, pad_left, pad_right)
+
+ frame_len = self.context_size
+ frame_step = self.chunk_size
+
+ # Directly use unfold without the subframe_factor logic
+ # x.unfold(dimension, size, step)
+ # dimension=1 (time dimension, assuming x is [B, T_padded, ...])
+ # size=frame_len (context_size)
+ # step=frame_step (chunk_size)
+ x_unfolded = hidden_states.unfold(dimension=1, size=frame_len, step=frame_step)
+
+ # If x was [B, T_padded], x_unfolded is [B, num_blocks, frame_len]
+ # If x was [B, T_padded, N, H], x_unfolded is [B, num_blocks, N, H, frame_len]
+ # We want to match JAX's typical output for such operations which might be
+ # [B, num_blocks, frame_len, N, H] if N, H are present.
+ # The relative_position_embedding expects keys as [B, U, C, N, H].
+ # If x_unfolded is [B, U, N, H, C(frame_len)], we need to move C.
+ if hidden_states.ndim > 2 and x_unfolded.ndim > 3: # Check if inner dimensions (like N, H) exist
+ # Current shape after unfold for [B, T_pad, N, H] is [B, U, N, H, C]
+ # Target shape for keys in RPE: [B, U, C, N, H]
+ x_unfolded = torch.movedim(x_unfolded, source=-1, destination=2)
+
+ return x_unfolded.contiguous()
+
+ def forward(self, hidden_states: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor:
+ # sl.Dense uses jax.numpy.einsum("...a,abcd->...bcd") and jax.numpy.select()
+ qkv_shape = (*hidden_states.shape[:-1], self.num_heads, self.head_dim)
+ query_states = self.q_proj(hidden_states).reshape(qkv_shape).contiguous()
+ key_states = self.k_proj(hidden_states).reshape(qkv_shape).contiguous()
+ value_states = self.v_proj(hidden_states).reshape(qkv_shape).contiguous()
+
+ per_dim_scale_sp = torch.nn.functional.softplus(self.per_dim_scale)
+
+ broadcast_shape = (1, 1, 1, self.head_dim)
+ per_dim_scale_sp_broadcast = per_dim_scale_sp.view(broadcast_shape)
+ query_states = query_states * self.q_scale * per_dim_scale_sp_broadcast
+
+ batch_size, q_time = query_states.shape[:2]
+
+ query_blocks = self._convert_to_block(query_states)
+ key_blocks = self._extract_block_context(key_states)
+ value_blocks = self._extract_block_context(value_states)
+ num_query_blocks = query_blocks.shape[1]
+
+ # 1. Create a mask indicating originally valid positions.
+ original_valid_mask = ~mask # True for valid, False for padded
+
+ # 2. Extract blocks from this validity mask.
+ extracted_valid_mask_blocks = self._extract_block_context(original_valid_mask)
+
+ # _extract_block_context uses unfold directly (no subframe_factor), so blocks
+ # extracted from a [B, T] mask already come back as [B, U, C]. The reshape below
+ # is a defensive fallback for a [B, U, C/SF, SF] layout. batch_size and
+ # num_query_blocks are known from query_blocks; self.context_size is C.
+ if (
+ extracted_valid_mask_blocks.ndim == 4
+ and extracted_valid_mask_blocks.shape[2] * extracted_valid_mask_blocks.shape[3] == self.context_size
+ ):
+ extracted_valid_mask_blocks = extracted_valid_mask_blocks.reshape(
+ batch_size, num_query_blocks, self.context_size
+ )
+ # After the potential reshape, blocks extracted from a [B, T] mask must be
+ # [B, U, C]; the check below enforces this for the mask case.
+ if extracted_valid_mask_blocks.shape != (
+ batch_size,
+ num_query_blocks,
+ self.context_size,
+ ):
+ raise ValueError(
+ "Shape of extracted_valid_mask_blocks"
+ f" {extracted_valid_mask_blocks.shape} is not ({batch_size},"
+ f" {num_query_blocks}, {self.context_size}) after potential reshape."
+ )
+
+ # 3. Expand dimensions for broadcasting with logits and causal mask.
+ # Target shape for broadcasting with logits [B,N,U,W,C]
+ # extracted_valid_mask_blocks to [B, 1, U, 1, C]
+ condition_from_input_validity = extracted_valid_mask_blocks.unsqueeze(1).unsqueeze(-2)
+
+ # self.local_causal_valid_mask is [W, C], True where allowed by local window.
+ # Expand to [1, 1, 1, W, C]
+ condition_from_causality = self.local_causal_valid_mask.unsqueeze(0).unsqueeze(0).unsqueeze(0)
+
+ # 4. Combine the two conditions.
+ # final_condition will be True where a key is *both* originally valid *and* causally accessible.
+ # Broadcasts to [B, 1, U, W, C]
+ final_condition_for_where = torch.logical_and(
+ condition_from_input_validity,
+ condition_from_causality.to(condition_from_input_validity.device), # Ensure same device
+ )
+
+ # Embed queries and keys
+ logits = self.relative_position_embedding(query_blocks, key_blocks)
+
+ # Apply attention logit softcap
+ # Ensure softcap is on the same device as logits
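+ # i.e. logits <- softcap * tanh(logits / softcap), which smoothly bounds the
+ # logits to (-softcap, +softcap).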
+ softcap_val = self.softcap.to(logits.device)
+ logits = logits / softcap_val
+ logits = torch.tanh(logits)
+ logits = logits * softcap_val
+
+ # Apply the combined mask.
+ # final_condition_for_where will broadcast with logits [B,N,U,W,C]
+ logits = torch.where(final_condition_for_where, logits, torch.finfo(logits.dtype).min)
+ probabilities = torch.nn.functional.softmax(logits, dim=-1, dtype=torch.float32).to(dtype=value_blocks.dtype)
+
+ # context_vectors is adapted from jax.numpy.einsum("BNuwc,BucNH->BuwNH", ...)
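+ # probabilities is [B, N, U, W, C] and value_blocks is [B, U, C, N, H]; fold
+ # (B, U, N) into the bmm batch dimension so each [W, C] x [C, H] product gives
+ # the per-block context, then restore the [B, U, W, N, H] layout.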
+ b_dim, n_dim, u_dim, w_dim, c_dim = probabilities.shape
+ h_dim = value_blocks.shape[-1]
+ prob_bun = probabilities.permute(0, 2, 1, 3, 4).reshape(-1, w_dim, c_dim)
+ v_bun = value_blocks.permute(0, 1, 3, 2, 4).reshape(-1, c_dim, h_dim)
+ result_bmm = torch.bmm(prob_bun, v_bun)
+ context_vectors = result_bmm.reshape(b_dim, u_dim, n_dim, w_dim, h_dim).permute(0, 1, 3, 2, 4)
+ context_vectors = context_vectors.reshape(
+ (
+ batch_size,
+ num_query_blocks * self.chunk_size,
+ self.num_heads,
+ self.head_dim,
+ )
+ )
+ context_vectors = context_vectors[:, :q_time]
+
+ return context_vectors
+
+
+class Gemma3nAudioCumulativeGroupNorm(nn.Module):
+ """Applies Group Normalization cumulatively over the time dimension.
+
+ This layer normalizes the input by calculating the mean and variance
+ cumulatively over the time dimension (dim 1). The statistics are computed
+ over all feature dimensions (specified by `feature_dims` and `num_channels`).
+
+ This implementation takes no mask argument: every time step is treated as
+ valid when accumulating statistics.
+
+ The learned scale is applied per-channel (last dimension); there is no bias.
+ This behavior is similar to JAX's `GroupNormalization` with `num_groups=1`
+ and `cumulative=True`.
+ """
+
+ def __init__(
+ self,
+ num_channels: int, # Number of channels (size of the last dimension)
+ feature_dims: Sequence[int], # Sizes of non-channel feature dimensions, e.g., (H, W) for input [B,T,H,W,C]
+ eps: float = 1e-3,
+ ):
+ super().__init__()
+ self.num_channels = num_channels
+ self.feature_dims = tuple(feature_dims)
+ self.eps = eps
+
+ # Scale parameter depends only on the channel dimension
+ self.weight = nn.Parameter(torch.ones(num_channels))
+
+ # Axes for normalization: all dimensions except Batch (0) and Time (1).
+ # For input [B, T, *feature_dims, C], these are dims from 2 onwards.
+ self.reduction_axes = tuple(range(2, 2 + len(self.feature_dims) + 1))
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ """Applies cumulative group norm, optionally using a mask.
+
+ Args:
+ hidden_states: Input tensor, shape [B, T, *feature_dims, C].
+
+ Returns:
+ Normalized tensor with the same shape as `hidden_states`.
+ """
+ expected_input_suffix = self.feature_dims + (self.num_channels,)
+ if hidden_states.shape[2:] != expected_input_suffix:
+ raise ValueError(
+ f"Input tensor shape suffix {hidden_states.shape[2:]} does not match expected"
+ f" suffix (feature_dims + num_channels) {expected_input_suffix}"
+ )
+
+ input_dtype = hidden_states.dtype
+ # Calculations are performed in float32 for numerical stability.
+ calc_dtype = torch.float32
+ x_calc = hidden_states.to(calc_dtype)
+
+ # Prepare a broadcastable validity mask (`mask_calc`). This implementation has
+ # no mask input, so all elements are treated as valid (mask_calc is all ones).
+ mask_calc = torch.ones_like(x_calc, dtype=calc_dtype)
+
+ # Cumulative Statistics Calculation
+ # 1. Sum of values over reduction axes at each time step.
+ sum_values_at_t = torch.sum(x_calc, dim=self.reduction_axes, keepdim=True)
+ # 2. Cumulative sum of values over time.
+ cum_sum_values = torch.cumsum(sum_values_at_t, dim=1)
+
+ # 3. Count of valid elements in the normalization group at each time step.
+ # (A "group" here consists of all features at a given Batch, Time).
+ elements_in_group_at_t = torch.sum(mask_calc, dim=self.reduction_axes, keepdim=True)
+ # 4. Cumulative count of valid elements over time.
+ cum_count_elements = torch.cumsum(elements_in_group_at_t, dim=1)
+ # Avoid division by zero if all preceding elements were masked.
+ safe_cum_count_elements = torch.clamp(cum_count_elements, min=1.0)
+
+ # 5. Cumulative mean.
+ cum_mean = cum_sum_values / safe_cum_count_elements
+
+ # 6. Sum of squared differences from the cumulative mean.
+ #    With the all-ones mask every element contributes; the difference at each
+ #    time step uses the cumulative mean up to that step.
+ squared_diff_from_mean = (x_calc - cum_mean).pow(2)
+ sum_sq_diff_at_t = torch.sum(squared_diff_from_mean, dim=self.reduction_axes, keepdim=True)
+
+ # 7. Cumulative sum of squared differences over time.
+ cum_sum_sq_diff = torch.cumsum(sum_sq_diff_at_t, dim=1)
+
+ # 8. Cumulative variance.
+ cum_variance = cum_sum_sq_diff / safe_cum_count_elements
+
+ # Normalize the input using the calculated cumulative statistics:
+ # (x - E[x]) / sqrt(Var[x] + eps)
+ normalized_x = (x_calc - cum_mean) * torch.rsqrt(cum_variance + self.eps)
+
+ # Apply the learned per-channel scale (last dimension); there is no bias parameter.
+ scale = self.weight.to(calc_dtype)
+ # Reshape for broadcasting: [C] -> [1, ..., 1, C]
+ scale_view_shape = [1] * (hidden_states.dim() - 1) + [self.num_channels]
+ normalized_x = normalized_x * scale.view(scale_view_shape)
+
+ # Multiply by the validity mask. With the all-ones mask this is a no-op, but it
+ # mirrors the masked variant, where padded positions produce zero output.
+ final_output = normalized_x * mask_calc
+
+ return final_output.to(input_dtype)
+
+
+class Gemma3nAudioSSCPConvBlock(nn.Module):
+ """A single convolution block for the SubSampleConvProjection.
+
+ This block consists of a 2D convolution, followed by CumulativeGroupNorm,
+ and a ReLU activation. It handles manual padding for the convolution.
+ """
+
+ def __init__(
+ self,
+ config: Gemma3nAudioConfig,
+ idx: int,
+ input_freq_dim: int, # Frequency (width) dimension of the unpadded input
+ manual_padding: tuple[int, int, int, int] = (0, 0, 0, 0),
+ ):
+ super().__init__()
+ self.config = config
+ self.manual_padding = manual_padding
+
+ # in_channels is 1 for the first block, or C_out from previous block's conv
+ in_channels = 1 if idx == 0 else self.config.sscp_conv_channel_size[idx - 1]
+ out_channels = self.config.sscp_conv_channel_size[idx]
+ kernel_h, kernel_w = self.config.sscp_conv_kernel_size[idx]
+ stride_h, stride_w = self.config.sscp_conv_stride_size[idx]
+
+ self.conv = nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=(
+ kernel_h,
+ kernel_w,
+ ), # Kernel (kH, kW) operates on (Time, Freq_dim)
+ stride=(stride_h, stride_w),
+ padding=(0, 0), # Manual padding is used
+ bias=False,
+ )
+
+ # Calculate output frequency dimension (f_out_conv) after this convolution.
+ # input_freq_dim is the unpadded width (feature dimension).
+ # self.manual_padding is (pad_F_left, pad_F_right, pad_T_top, pad_T_bottom)
+ f_in_padded = input_freq_dim + self.manual_padding[0] + self.manual_padding[1]
+ f_out_conv = (f_in_padded - kernel_w) // stride_w + 1
+
+ self.norm = Gemma3nAudioCumulativeGroupNorm(
+ num_channels=out_channels, # Channels of the conv output
+ feature_dims=(f_out_conv,), # The frequency dimension size after conv
+ eps=self.config.sscp_conv_group_norm_eps,
+ )
+
+ self.activation = nn.ReLU()
+
+ def forward(self, audio_encodings: torch.Tensor) -> torch.Tensor:
+ # Input audio_encodings is [B, C_in, T_in, F_in] (e.g., C_in=1)
+ # manual_padding is (pad_F_left, pad_F_right, pad_T_top, pad_T_bottom)
+ # F.pad applies to last two dims: F_in then T_in
+ audio_encodings_padded = F.pad(audio_encodings, self.manual_padding, mode="constant", value=0.0).to(
+ self.conv.weight.dtype
+ )
+ # Expected padded shape for F_in, k_w=3, pad_F=(1,1) -> F_padded = F_in+2
+ # Expected padded shape for T_in, k_h=3, pad_T=(0,2) -> T_padded = T_in+2
+ audio_encodings_conv = self.conv(audio_encodings_padded)
+ # Expected conv output shape: [B, C_out, T_out, F_out]
+ # Input to norm is [B, T_out, F_out, C_out]
+ x_for_norm = audio_encodings_conv.permute(0, 2, 3, 1).contiguous()
+ x_normed = self.norm(x_for_norm)
+ # Output of norm is [B, T_out, F_out, C_out], permute back to [B, C_out, T_out, F_out]
+ audio_encodings_normed = x_normed.permute(0, 3, 1, 2).contiguous()
+ return self.activation(audio_encodings_normed)
+
+
+class Gemma3nAudioSubSampleConvProjection(nn.Module):
+ def __init__(self, config: Gemma3nAudioConfig):
+ super().__init__()
+ self.config = config
+
+ current_f_for_block_input = config.input_feat_size # Start with original feature dim
+ calculated_block_padding = []
+ calculated_f_out_dims = [] # Tracking frequency dimension output sizes
+
+ for i in range(2): # Assuming 2 conv layers as per sscp_conv_... arrays
+ kernel_h, kernel_w = config.sscp_conv_kernel_size[i]
+ stride_h, stride_w = config.sscp_conv_stride_size[i]
+
+ # Padding for Time (Height for Conv2d) - REVERSE_CAUSAL like
+ # JAX 'reverse_causal' padding is (0, kernel_size - 1)
+ pad_t_top = 0
+ pad_t_bottom = kernel_h - 1
+
+ # Frequency Padding (Width for Conv2d)
+ # Based on JAX effective padding (1,1) for F_in=10, K_w=3, S_w=2
+ # and the successful test configuration.
+ # If kernel/stride/input_freq for frequency changes, this might need re-evaluation
+ # to match generic JAX 'SAME' behavior if it differs.
+ pad_f_left = 1
+ pad_f_right = 1
+
+ manual_padding_tuple = (
+ pad_f_left,
+ pad_f_right,
+ pad_t_top,
+ pad_t_bottom,
+ )
+ calculated_block_padding.append(manual_padding_tuple)
+
+ # Calculate output frequency dimension after this convolution
+ # This uses the actual padding applied and kernel/stride.
+ f_in_padded = current_f_for_block_input + pad_f_left + pad_f_right
+ f_out_after_conv = (f_in_padded - kernel_w) // stride_w + 1 # Assuming dilation_w = 1
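+ # Illustrative (hypothetical) numbers matching the comment above: F_in=10,
+ # pad (1, 1), kernel_w=3, stride_w=2 give f_in_padded = 12 and
+ # f_out_after_conv = (12 - 3) // 2 + 1 = 5.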
+ calculated_f_out_dims.append(f_out_after_conv)
+ current_f_for_block_input = f_out_after_conv
+
+ self.conv_0 = Gemma3nAudioSSCPConvBlock(
+ idx=0,
+ input_freq_dim=config.input_feat_size, # Pass original feature dim
+ config=config,
+ manual_padding=calculated_block_padding[0],
+ )
+ self.conv_1 = Gemma3nAudioSSCPConvBlock(
+ idx=1,
+ input_freq_dim=calculated_f_out_dims[0], # Output freq dim from conv_0
+ config=config,
+ manual_padding=calculated_block_padding[1],
+ )
+ final_c_out = config.sscp_conv_channel_size[-1]
+ final_f_out = calculated_f_out_dims[-1] # Final frequency dimension
+ self.input_proj_in_features = final_c_out * final_f_out
+ self.input_proj_linear = nn.Linear(self.input_proj_in_features, self.config.hidden_size, bias=False)
+
+ def forward(self, audio_encodings: torch.Tensor) -> torch.Tensor:
+ # audio_encodings is [B, T, F_in]
+ # Reshape to [B, 1, T, F_in] (Batch, Channels=1, Height=Time, Width=F_in)
+ audio_encodings_reshaped = audio_encodings.unsqueeze(1)
+ x = self.conv_0(audio_encodings_reshaped)
+ x = self.conv_1(x)
+ # x from conv_1 is [B, C_out_1, T_out_1, F_out_1]
+ b, c_out, t_out, f_out = x.shape
+ # Permute to [B, T_out_1, F_out_1, C_out_1] then flatten F_out_1 and C_out_1
+ x_permuted = x.permute(0, 2, 3, 1).contiguous()
+ output_flattened = x_permuted.view(b, t_out, f_out * c_out)
+ output = self.input_proj_linear(output_flattened)
+ return output
+
+
+class Gemma3nAudioConformerAttention(nn.Module):
+ def __init__(self, config: Gemma3nAudioConfig):
+ super().__init__()
+ self.config = config
+ self.post_in_features = self.config.hidden_size
+ self.register_buffer("gradient_clipping", torch.tensor(self.config.gradient_clipping), persistent=False)
+ self.pre_attn_norm = Gemma3nRMSNorm(self.config.hidden_size)
+ self.attn = Gemma3nAudioAttention(config)
+ self.post = nn.Linear(self.post_in_features, self.config.hidden_size, bias=False)
+ self.post_norm = Gemma3nRMSNorm(self.config.hidden_size)
+
+ def forward(self, audio_encodings: torch.Tensor, audio_mel_mask: torch.BoolTensor) -> torch.Tensor:
+ audio_encodings_input_to_attn = audio_encodings
+ audio_encodings = torch.clamp(audio_encodings, -self.gradient_clipping, self.gradient_clipping)
+ audio_encodings_norm = self.pre_attn_norm(audio_encodings)
+ # Output of self.attn is [B, T, NumHeads, HeadDim]
+ audio_encodings_attn_out = self.attn(audio_encodings_norm, audio_mel_mask)
+
+ # Reshape from [B, T, NumHeads, HeadDim] to [B, T, NumHeads * HeadDim]
+ # NumHeads * HeadDim = hidden_size
+ b, t, num_heads, head_dim = audio_encodings_attn_out.shape
+ audio_encodings_reshaped = audio_encodings_attn_out.reshape(b, t, num_heads * head_dim)
+
+ audio_encodings = self.post(audio_encodings_reshaped)
+ audio_encodings = torch.clamp(audio_encodings, -self.gradient_clipping, self.gradient_clipping)
+ return audio_encodings_input_to_attn + self.post_norm(audio_encodings)
+
+
+class Gemma3nAudioConformerFeedForward(nn.Module):
+ def __init__(self, config: Gemma3nAudioConfig):
+ super().__init__()
+ self.config = config
+
+ self.register_buffer("gradient_clipping", torch.tensor(self.config.gradient_clipping), persistent=False)
+
+ self.pre_layer_norm = Gemma3nRMSNorm(self.config.hidden_size)
+ self.ffw_layer_1 = nn.Linear(self.config.hidden_size, self.config.hidden_size * 4, bias=False)
+ self.ffw_layer_2 = nn.Linear(self.config.hidden_size * 4, self.config.hidden_size, bias=False)
+ self.post_layer_norm = Gemma3nRMSNorm(self.config.hidden_size)
+ self.post_layer_scale = torch.tensor(self.config.conf_residual_weight)
+
+ def forward(self, audio_encodings: torch.Tensor) -> torch.Tensor:
+ residual = audio_encodings
+ audio_encodings = torch.clamp(audio_encodings, -self.gradient_clipping, self.gradient_clipping)
+ audio_encodings = self.pre_layer_norm(audio_encodings)
+ audio_encodings: torch.Tensor = self.ffw_layer_1(audio_encodings)
+ audio_encodings = nn.functional.silu(audio_encodings)
+ audio_encodings: torch.Tensor = self.ffw_layer_2(audio_encodings)
+ audio_encodings = torch.clamp(audio_encodings, -self.gradient_clipping, self.gradient_clipping)
+ audio_encodings = self.post_layer_norm(audio_encodings)
+ return residual + (audio_encodings * self.post_layer_scale)
+
+
+class Gemma3nAudioConformerLightConv1d(nn.Module):
+ def __init__(self, config: Gemma3nAudioConfig):
+ super().__init__()
+ self.config = config
+
+ self.pre_layer_norm = Gemma3nRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
+ self.linear_start = nn.Linear(self.config.hidden_size, self.config.hidden_size * 2, bias=False)
+ self.depthwise_conv1d = nn.Conv1d(
+ in_channels=self.config.hidden_size,
+ out_channels=self.config.hidden_size,
+ kernel_size=self.config.conf_conv_kernel_size,
+ stride=1,
+ padding=0, # Manual causal padding
+ groups=self.config.hidden_size, # Depthwise
+ bias=False,
+ )
+ self.register_buffer("gradient_clipping", torch.tensor(self.config.gradient_clipping), persistent=False)
+ self.conv_norm = Gemma3nRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
+ self.linear_end = nn.Linear(self.config.hidden_size, self.config.hidden_size, bias=False)
+
+ self.causal_padding = self.config.conf_conv_kernel_size - 1
+
+ def forward(self, audio_encodings: torch.Tensor) -> torch.Tensor:
+ audio_encodings_residual = audio_encodings # Save for residual connection
+
+ audio_encodings = self.pre_layer_norm(audio_encodings)
+ audio_encodings = self.linear_start(audio_encodings)
+ audio_encodings = torch.nn.functional.glu(audio_encodings, dim=-1)
+ # Permute for Conv1d: [B, T, D] -> [B, D, T]
+ audio_encodings_permuted = audio_encodings.permute(0, 2, 1)
+ # Apply manual causal padding
+ audio_encodings_permuted_padded = F.pad(audio_encodings_permuted, (self.causal_padding, 0))
+ audio_encodings = self.depthwise_conv1d(audio_encodings_permuted_padded)
+ # Permute back: [B, D, T_out] -> [B, T_out, D]
+ audio_encodings = audio_encodings.permute(0, 2, 1)
+ audio_encodings = torch.clamp(audio_encodings, -self.gradient_clipping, self.gradient_clipping)
+ audio_encodings = self.conv_norm(audio_encodings)
+ audio_encodings = nn.functional.silu(audio_encodings)
+ audio_encodings = self.linear_end(audio_encodings)
+ output = audio_encodings + audio_encodings_residual
+ return output
+
+
+class Gemma3nAudioConformerBlock(nn.Module):
+ def __init__(self, config: Gemma3nAudioConfig):
+ super().__init__()
+ self.config = config
+
+ self.ffw_layer_start = Gemma3nAudioConformerFeedForward(self.config)
+ self.attention = Gemma3nAudioConformerAttention(self.config)
+ self.lconv1d = Gemma3nAudioConformerLightConv1d(self.config)
+ self.ffw_layer_end = Gemma3nAudioConformerFeedForward(self.config)
+ self.register_buffer("gradient_clipping", torch.tensor(self.config.gradient_clipping), persistent=False)
+ self.norm = Gemma3nRMSNorm(self.config.hidden_size)
+
+ def forward(self, audio_encodings: torch.Tensor, audio_mel_mask: torch.BoolTensor) -> torch.Tensor:
+ audio_encodings = self.ffw_layer_start(audio_encodings)
+ audio_encodings = self.attention(audio_encodings, audio_mel_mask)
+ validity_mask_for_lconv = ~audio_mel_mask # True for valid
+ audio_encodings_for_lconv_input = audio_encodings * validity_mask_for_lconv.unsqueeze(-1).to(
+ audio_encodings.dtype
+ )
+ audio_encodings = self.lconv1d(audio_encodings_for_lconv_input)
+
+ audio_encodings = self.ffw_layer_end(audio_encodings)
+ audio_encodings = torch.clamp(audio_encodings, -self.gradient_clipping, self.gradient_clipping)
+ output = self.norm(audio_encodings)
+ return output
+
+
+class Gemma3nAudioEncoder(PreTrainedModel):
+ """
+ An audio encoder based on the [Universal Speech Model](https://huggingface.co/papers/2303.01037) architecture.
+ """
+
+ config: Gemma3nAudioConfig
+
+ main_input_name = "audio_mel"
+
+ def __init__(self, config: Gemma3nAudioConfig):
+ super().__init__(config)
+ self.config = config
+
+ self.subsample_conv_projection = Gemma3nAudioSubSampleConvProjection(config)
+ self.conformer = nn.ModuleList(
+ [Gemma3nAudioConformerBlock(config) for _ in range(config.conf_num_hidden_layers)]
+ )
+
+ def forward(
+ self, audio_mel: torch.Tensor, audio_mel_mask: torch.BoolTensor
+ ) -> tuple[torch.Tensor, torch.BoolTensor]:
+ """Encodes a batch of MELs.
+
+ Args:
+ audio_mel: a torch.Tensor of shape [batch, num_frames, mel_bins].
+ audio_mel_mask: a torch.BoolTensor of shape [batch, num_frames],
+ True for padded frames.
+
+ Returns:
+ audio_encodings: a torch.Tensor of shape
+ `[batch, num_subsampled_frames, self.config.hidden_size]`.
+ current_mask: a torch.BoolTensor of shape [batch, num_subsampled_frames]
+ matching the time dimension of `audio_encodings`.
+ """
+ audio_encodings = self.subsample_conv_projection(audio_mel) # audio_encodings: [B, T_sub, D]
+
+ # Subsample the input audio_mel_mask to match the time dimension of audio_encodings (T_sub)
+ t_sub = audio_encodings.shape[1]
+
+ time_stride_product = 1
+ for stride_pair_idx in range(len(self.config.sscp_conv_stride_size)):
+ time_stride_product *= self.config.sscp_conv_stride_size[stride_pair_idx][0]
+
+ # Create indices for gathering from the original mask.
+ # These indices map to original time steps corresponding to the start of each
+ # receptive field in the subsampled output.
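+ # Illustrative (hypothetical) example: two conv stages with a time stride of 2
+ # each give time_stride_product = 4, so subsampled frame i inherits the
+ # validity of original frame 4 * i (clamped to the last original frame).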
+ indices = torch.arange(t_sub, device=audio_mel_mask.device) * time_stride_product
+ indices = torch.clamp(indices, max=audio_mel_mask.shape[1] - 1) # Ensure indices are valid
+
+ # Expand indices for batch compatibility if B > 1 and indices is 1D.
+ if audio_mel_mask.ndim > 1 and indices.ndim == 1:
+ indices = indices.unsqueeze(0).expand(audio_mel_mask.shape[0], -1) # [B, T_sub]
+ elif (
+ audio_mel_mask.ndim == indices.ndim
+ and audio_mel_mask.shape[0] == 1
+ and indices.shape[0] != 1
+ and t_sub == indices.shape[0]
+ ):
+ # Handle case where B=1 but indices became [T_sub] instead of [1, T_sub]
+ indices = indices.unsqueeze(0)
+
+ current_mask = torch.gather(audio_mel_mask, 1, indices) # [B, T_sub]
+
+ for block in self.conformer:
+ audio_encodings = block(audio_encodings, current_mask) # Pass the processed mask
+
+ if self.config.conf_reduction_factor > 1:
+ audio_encodings = audio_encodings[:, :: self.config.conf_reduction_factor]
+ # Reduce the mask as well
+ current_mask = current_mask[:, :: self.config.conf_reduction_factor]
+
+ audio_encodings = audio_encodings.masked_fill(current_mask.unsqueeze(-1), 0.0)
+ return audio_encodings, current_mask
+
+
+class Gemma3nTextScaledWordEmbedding(nn.Embedding):
+ """
+ This module overrides nn.Embedding's forward by multiplying the output with the embedding scale.
+ """
+
+ def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: float = 1.0):
+ super().__init__(num_embeddings, embedding_dim, padding_idx)
+ self.register_buffer("embed_scale", torch.tensor(embed_scale), persistent=False)
+
+ def forward(self, input_ids: torch.Tensor):
+ return super().forward(input_ids) * self.embed_scale.to(self.weight.dtype)
+
+
+class Gemma3nTextLaurelBlock(nn.Module):
+ """Learned Augmented Residual Layer"""
+
+ def __init__(self, config: Gemma3nTextConfig):
+ super().__init__()
+ self.config = config
+
+ self.linear_left = nn.Linear(self.config.hidden_size, self.config.laurel_rank, bias=False)
+ self.linear_right = nn.Linear(self.config.laurel_rank, self.config.hidden_size, bias=False)
+ self.post_laurel_norm = Gemma3nRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ laurel_hidden_states: torch.Tensor = self.linear_left(hidden_states)
+ laurel_hidden_states: torch.Tensor = self.linear_right(laurel_hidden_states)
+ normed_laurel_hidden_states = self.post_laurel_norm(laurel_hidden_states)
+ return hidden_states + normed_laurel_hidden_states
+
+
+class Gemma3nTextMLP(nn.Module):
+ def __init__(self, config: Gemma3nTextConfig, layer_idx: int = 0):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.intermediate_size = config.intermediate_size[layer_idx]
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+ self.act_fn = ACT2FN[config.hidden_activation]
+ self.activation_sparsity = config.activation_sparsity_pattern[layer_idx]
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ gate_proj = self.gate_proj(hidden_states)
+ if self.activation_sparsity > 0.0:
+ gate_proj = self._gaussian_topk(gate_proj)
+ activations = self.act_fn(gate_proj)
+ up_proj = self.up_proj(hidden_states)
+ down_proj = self.down_proj(activations * up_proj)
+ return down_proj
+
+ def _gaussian_topk(self, inputs: torch.Tensor) -> torch.Tensor:
+ target_sparsity_tensor = torch.tensor(self.activation_sparsity, dtype=torch.float32, device=inputs.device)
+ # normal_dist and std_multiplier are adapted from jax.scipy.stats.norm.ppf().
+ #
+ # References:
+ # * https://docs.jax.dev/en/latest/_autosummary/jax.scipy.stats.norm.ppf.html
+ # * https://pytorch.org/docs/stable/distributions.html#torch.distributions.normal.Normal
+ # * https://pytorch.org/docs/stable/distributions.html#torch.distributions.transformed_distribution.TransformedDistribution.icdf
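+ # For example (hypothetical value): activation_sparsity = 0.95 gives
+ # icdf(0.95) ~= 1.645, so the cutoff sits at mean + 1.645 * std and, under the
+ # Gaussian approximation, only roughly the top 5% of each hidden vector
+ # survives the ReLU below.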
+ normal_dist = torch.distributions.normal.Normal(0, 1)
+ std_multiplier: torch.Tensor = normal_dist.icdf(target_sparsity_tensor)
+ std_multiplier = std_multiplier.type(inputs.dtype)
+ inputs_mean = torch.mean(inputs, dim=-1, keepdim=True)
+ inputs_std = torch.std(inputs, dim=-1, keepdim=True, unbiased=False)
+ cutoff_x = inputs_mean + inputs_std * std_multiplier
+ return nn.functional.relu(inputs - cutoff_x)
+
+
+class Gemma3nTextAltUp(nn.Module):
+ """Alternating Updates (AltUp)
+
+ The AltUp module wraps transformer layers. The `predict` step modifies the
+ input to the transformer layer, and the `correct` step propagates the output
+ of the transformer layer to the sparsely updated dimensions.
+
+ See more in the research paper:
+
+ https://proceedings.neurips.cc/paper_files/paper/2023/file/f2059277ac6ce66e7e5543001afa8bb5-Paper-Conference.pdf
+ """
+
+ def __init__(self, config: Gemma3nTextConfig):
+ super().__init__()
+ self.config = config
+ self.correct_output_scale = nn.Parameter(torch.zeros(self.config.hidden_size))
+ self.correction_coefs = nn.Linear(self.config.altup_num_inputs, self.config.altup_num_inputs, bias=False)
+ self.prediction_coefs = nn.Linear(self.config.altup_num_inputs, self.config.altup_num_inputs**2, bias=False)
+ self.modality_router = nn.Linear(self.config.hidden_size, self.config.altup_num_inputs, bias=False)
+ self.router_norm = Gemma3nRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
+ self.register_buffer("router_input_scale", torch.tensor(self.config.hidden_size**-1.0), persistent=False)
+
+ def compute_router_modalities(self, x: torch.Tensor) -> torch.Tensor:
+ router_inputs = self.router_norm(x) * self.router_input_scale
+ routed = self.modality_router(router_inputs)
+ return torch.tanh(routed.float()).type_as(x)
+
+ def predict(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ """Predicts the output of a layer using a trainable map.
+
+ Args:
+ hidden_states: A 4D tensor of shape `[num_altup_inputs, batch_size, num_tokens, hidden_size]` derived by
+ stacking the input embeddings and preprocessing the last `num_altup_inputs - 1` matrices.
+
+ Returns:
+ A 4D tensor of shape `[num_altup_inputs, batch_size, num_tokens, hidden_size]` containing the predictions.
+ """
+ modalities = self.compute_router_modalities(hidden_states[self.config.altup_active_idx])
+
+ if self.training and self.config.altup_coef_clip is not None:
+ self.prediction_coefs.weight.data.clamp_(-self.config.altup_coef_clip, self.config.altup_coef_clip)
+
+ # Project, then transpose the contained 2D coefficient matrices so that matmul gives the correct result
+ all_coefs: torch.Tensor = (
+ self.prediction_coefs(modalities)
+ .reshape(*modalities.shape[:-1], self.config.altup_num_inputs, self.config.altup_num_inputs)
+ .permute(0, 1, 3, 2)
+ )
+
+ # permute hidden_states to [batch_size, num_tokens, hidden_size, altup_num_inputs]
+ predictions = torch.matmul(hidden_states.permute(1, 2, 3, 0), all_coefs)
+ predictions = predictions.permute(3, 0, 1, 2) # undo the permute
+ predictions += hidden_states # add the original input
+ return predictions.contiguous().type_as(hidden_states)
+
+ def correct(self, predictions: torch.Tensor, activated: torch.Tensor) -> torch.Tensor:
+ """Corrects the predictions relative to the
+
+ Args:
+ predictions: A 4D tensor of shape `[num_altup_inputs, batch_size, num_tokens, hidden_size]` derived by
+ stacking the input embeddings and preprocessing the last `num_altup_inputs - 1` matrices.
+ activated: A 3D tensor of shape `[batch_size, num_tokens, hidden_size]` containing the activated inputs.
+
+ Returns:
+ A 4D tensor of shape `[num_altup_inputs, batch_size, num_tokens, hidden_size]` correcting the original
+ predictions relative to the activated input embeddings.
+ """
+ modalities = self.compute_router_modalities(activated)
+ innovation = activated - predictions[self.config.altup_active_idx] # (batch, num_tokens, hidden_size)
+ innovation = innovation.repeat(self.config.altup_num_inputs, 1, 1, 1) # Repeat on dim0 to match predictions
+
+ if self.config.altup_coef_clip is not None:
+ self.correction_coefs.weight.data.clamp_(-self.config.altup_coef_clip, self.config.altup_coef_clip)
+
+ # all_coefs adapted from jax.numpy.einsum("...p,pi->...i", ...)
+ # Permute to (altup_num_inputs, batch_size, num_tokens) so each altup input gets its own
+ # scalar coefficient, then unsqueeze the last dim so it broadcasts over hidden_size
+ all_coefs: torch.Tensor = self.correction_coefs(modalities) + 1.0
+ all_coefs = all_coefs.permute(2, 0, 1).unsqueeze(-1)
+
+ corrected = torch.mul(innovation, all_coefs)
+ corrected += predictions # add the original input
+ return corrected.contiguous().type_as(activated)
+
+ def forward(self, corrected: torch.Tensor) -> torch.Tensor:
+ """
+ This is defined as the `forward` method so that accelerate hooks can correctly move
+ `correct_output_scale` (which is an nn.Parameter, not a Module) between devices when
+ offloading. It is otherwise only used by `scale_corrected_output`.
+ """
+ return (corrected.type_as(self.correct_output_scale) * self.correct_output_scale).type_as(corrected)
+
+ def scale_corrected_output(self, corrected: torch.Tensor) -> torch.Tensor:
+ """Scales the provided 3D tensor of shape [batch_size, num_tokens, hidden_size]."""
+ return self.forward(corrected)
+
+
+class Gemma3nTextRotaryEmbedding(nn.Module):
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
+
+ def __init__(self, config: Gemma3nTextConfig, device=None):
+ super().__init__()
+ # BC: "rope_type" was originally "type"
+ if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+ else:
+ self.rope_type = "default"
+ self.max_seq_len_cached = config.max_position_embeddings
+ self.original_max_seq_len = config.max_position_embeddings
+
+ self.config = config
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+ self.original_inv_freq = self.inv_freq
+
+ @torch.no_grad()
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
+ def forward(self, x, position_ids):
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
+ position_ids_expanded = position_ids[:, None, :].float()
+
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos = emb.cos() * self.attention_scaling
+ sin = emb.sin() * self.attention_scaling
+
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+ """
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ dropout: float = 0.0,
+ scaling: Optional[float] = None,
+ softcap: Optional[float] = None,
+ **kwargs,
+) -> tuple[torch.Tensor, torch.Tensor]:
+ if scaling is None:
+ scaling = module.head_dim**-0.5
+
+ key_states = repeat_kv(key, module.num_key_value_groups)
+ value_states = repeat_kv(value, module.num_key_value_groups)
+
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+
+ if softcap is not None:
+ attn_weights = attn_weights / softcap
+ attn_weights = torch.tanh(attn_weights)
+ attn_weights = attn_weights * softcap
+ if attention_mask is not None: # no matter the length, we just slice it
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ return attn_output, attn_weights
+
+
+def apply_rotary_pos_emb(
+ x: torch.Tensor,
+ cos: torch.Tensor,
+ sin: torch.Tensor,
+ position_ids: Optional[torch.Tensor] = None,
+ unsqueeze_dim: int = 1,
+):
+ """Applies Rotary Position Embedding to the query and key tensors.
+
+ Args:
+ x (`torch.Tensor`): The tensor to embed.
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
+ position_ids (`torch.Tensor`, *optional*):
+ Deprecated and unused.
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze `cos` and
+ `sin` so that they can be properly broadcast to the dimensions of `x`. For example, if
+ `cos` and `sin` have the shape [batch_size, seq_len, head_dim] and `x` has the shape
+ [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes them
+ broadcastable to the shape of `x`. Similarly, if `x` has the shape
+ [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+ Returns:
+ `torch.Tensor` comprising the input tensor rotated using the Rotary Position Embedding.
+ """
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
+ return (x * cos) + (rotate_half(x) * sin)
+
+
+class Gemma3nTextAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config: Gemma3nTextConfig, layer_idx: int):
+ super().__init__()
+ self.is_sliding = config.layer_types[layer_idx] == "sliding_attention"
+ self.config = config
+ self.layer_idx = layer_idx
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
+ self.attention_dropout = self.config.attention_dropout
+ self.is_causal = True
+
+ self.q_proj = nn.Linear(
+ config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.k_proj = nn.Linear(
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.v_proj = nn.Linear(
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.o_proj = nn.Linear(
+ config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
+ )
+ self.sliding_window = config.sliding_window if self.is_sliding else None
+
+ self.q_norm = Gemma3nRMSNorm(dim=config.head_dim, eps=config.rms_norm_eps)
+ self.k_norm = Gemma3nRMSNorm(dim=config.head_dim, eps=config.rms_norm_eps)
+ self.v_norm = Gemma3nRMSNorm(dim=config.head_dim, eps=config.rms_norm_eps, with_scale=False)
+
+ first_kv_shared_layer_idx = self.config.num_hidden_layers - self.config.num_kv_shared_layers
+ self.is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx > 0
+ prev_layers = config.layer_types[:first_kv_shared_layer_idx]
+ if self.is_kv_shared_layer:
+ # For shared layers, find the last non-shared layer of the same type before sharing starts
+ self.kv_shared_layer_index = len(prev_layers) - 1 - prev_layers[::-1].index(config.layer_types[layer_idx])
+ self.store_full_length_kv = False
+ else:
+ self.kv_shared_layer_index = None
+ # For non-shared layers, store full-length kv if this is the last non-shared layer of its type
+ self.store_full_length_kv = layer_idx == len(prev_layers) - 1 - prev_layers[::-1].index(
+ config.layer_types[layer_idx]
+ )
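+ # Illustrative (hypothetical) example: with 12 layers, num_kv_shared_layers=4 and
+ # alternating sliding/full layer types, layers 8-11 are KV-shared and reuse the
+ # states cached by the last non-shared layer of the same attention type (layer 6
+ # for sliding attention, layer 7 for full attention).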
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ past_key_values: Optional[Cache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
+ input_shape = hidden_states.shape[:-1]
+ hidden_shape = (*input_shape, -1, self.config.head_dim)
+
+ cos, sin = position_embeddings
+
+ query_states = self.q_proj(hidden_states).view(hidden_shape)
+ query_states = self.q_norm(query_states)
+ query_states = apply_rotary_pos_emb(query_states, cos, sin, unsqueeze_dim=2)
+ query_states = query_states.transpose(1, 2)
+
+ # For layers with shared KV (from kv sharing point onwards), we reuse the same keys/values states as the last non-sharing layer
+ if self.is_kv_shared_layer and past_key_values is not None:
+ key_states, value_states = past_key_values.shared_layers[self.kv_shared_layer_index]
+ # Device of past layer may be different from current one
+ key_states = key_states.to(query_states.device)
+ value_states = value_states.to(query_states.device)
+ else:
+ key_states = self.k_proj(hidden_states).view(hidden_shape)
+ key_states = self.k_norm(key_states)
+ key_states = apply_rotary_pos_emb(key_states, cos, sin, unsqueeze_dim=2)
+ key_states = key_states.transpose(1, 2)
+
+ value_states = self.v_proj(hidden_states).view(hidden_shape)
+ value_states = self.v_norm(value_states)
+ value_states = value_states.transpose(1, 2)
+
+ if past_key_values is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {
+ "sin": sin,
+ "cos": cos,
+ "cache_position": cache_position,
+ "sliding_window": self.sliding_window,
+ }
+ if not self.is_kv_shared_layer:
+ key_states, value_states = past_key_values.update(
+ key_states, value_states, self.layer_idx, cache_kwargs
+ )
+ if self.store_full_length_kv:
+ if not hasattr(past_key_values, "shared_layers"):
+ past_key_values.shared_layers = {}
+ past_key_values.shared_layers[self.layer_idx] = key_states, value_states
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=self.attention_dropout if self.training else 0.0,
+ scaling=1.0,
+ sliding_window=self.sliding_window,
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+ return attn_output, attn_weights
+
+
+class Gemma3nTextDecoderLayer(GradientCheckpointingLayer):
+ def __init__(self, config: Gemma3nTextConfig, layer_idx: int):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.layer_idx = layer_idx
+ self.attention_type = config.layer_types[layer_idx]
+ self.self_attn = Gemma3nTextAttention(config, layer_idx)
+ self.mlp = Gemma3nTextMLP(config, layer_idx=layer_idx)
+ self.input_layernorm = Gemma3nRMSNorm(self.hidden_size, eps=config.rms_norm_eps)
+ self.post_attention_layernorm = Gemma3nRMSNorm(self.hidden_size, eps=config.rms_norm_eps)
+ self.pre_feedforward_layernorm = Gemma3nRMSNorm(self.hidden_size, eps=config.rms_norm_eps)
+ self.post_feedforward_layernorm = Gemma3nRMSNorm(self.hidden_size, eps=config.rms_norm_eps)
+
+ self.hidden_size_per_layer_input = config.hidden_size_per_layer_input
+ self.act_fn = ACT2FN[config.hidden_activation]
+
+ self.altup = Gemma3nTextAltUp(config)
+ self.laurel = Gemma3nTextLaurelBlock(config)
+ self.per_layer_input_gate = nn.Linear(self.hidden_size, self.hidden_size_per_layer_input, bias=False)
+ self.per_layer_projection = nn.Linear(self.hidden_size_per_layer_input, self.hidden_size, bias=False)
+ self.post_per_layer_input_norm = Gemma3nRMSNorm(self.hidden_size, eps=config.rms_norm_eps)
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings_global: torch.Tensor,
+ position_embeddings_local: torch.Tensor,
+ per_layer_input: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs,
+ ) -> tuple[torch.Tensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ predictions = self.altup.predict(hidden_states)
+ active_prediction = predictions[self.config.altup_active_idx]
+
+ active_prediction_normed = self.input_layernorm(active_prediction)
+ laurel_output = self.laurel(active_prediction_normed)
+
+ # apply global RoPE to non-sliding layer only
+ if self.self_attn.is_sliding:
+ position_embeddings = position_embeddings_local
+ else:
+ position_embeddings = position_embeddings_global
+
+ attn, self_attn_weights = self.self_attn(
+ hidden_states=active_prediction_normed,
+ position_embeddings=position_embeddings,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ **kwargs,
+ )
+ attn = self.post_attention_layernorm(attn)
+
+ attn_gated = active_prediction + attn
+ attn_laurel = (attn_gated + laurel_output) / math.sqrt(2)
+
+ attn_norm = self.pre_feedforward_layernorm(attn_laurel)
+ attn_ffw = self.mlp(attn_norm)
+ attn_ffw_norm = self.post_feedforward_layernorm(attn_ffw)
+ attn_ffw_laurel_gated = attn_laurel + attn_ffw_norm
+ corrected_predictions = self.altup.correct(predictions, attn_ffw_laurel_gated)
+
+ first_prediction = corrected_predictions[self.config.altup_active_idx].clone()
+ if self.config.altup_correct_scale:
+ first_prediction = self.altup.scale_corrected_output(first_prediction)
+
+ # per_layer_input_gate adapted from jax.numpy.einsum("btd,dp->btp", ...)
+ first_prediction = self.per_layer_input_gate(first_prediction)
+ first_prediction = self.act_fn(first_prediction)
+ first_prediction = torch.multiply(first_prediction, per_layer_input)
+
+ # per_layer_projection adapted from jax.numpy.einsum("btp,pd->btd", ...)
+ first_prediction = self.per_layer_projection(first_prediction)
+ first_prediction = self.post_per_layer_input_norm(first_prediction)
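+ # first_prediction is [batch, tokens, hidden_size]; broadcasting adds it to every
+ # remaining altup stream in corrected_predictions[1:].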
+ corrected_predictions[1:] += first_prediction
+
+ outputs = (corrected_predictions,)
+
+ if output_attentions:
+ outputs += (self_attn_weights,)
+
+ return outputs
+
+
+@auto_docstring
+class Gemma3nPreTrainedModel(PreTrainedModel):
+ config: Gemma3nConfig
+ base_model_prefix = ""
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["Gemma3nTextDecoderLayer"]
+ _skip_keys_device_placement = ["past_key_values"]
+ _supports_flash_attn = True
+ _supports_sdpa = True
+ _supports_flex_attn = True
+
+ _can_compile_fullgraph = True
+ _supports_attention_backend = True
+ _can_record_outputs = {
+ "hidden_states": Gemma3nTextDecoderLayer,
+ "attentions": Gemma3nTextAttention,
+ }
+
+ def _init_weights(self, module):
+ super()._init_weights(module)
+ if isinstance(module, Gemma3nAudioCumulativeGroupNorm):
+ module.weight.data.fill_(1.0)
+ elif isinstance(module, Gemma3nAudioAttention):
+ module.per_dim_scale.data.zero_()
+ elif isinstance(module, Gemma3nTextAltUp):
+ module.correct_output_scale.data.zero_()
+
+
+@auto_docstring(custom_intro="The base Gemma 3n language model without a language modeling head.")
+class Gemma3nTextModel(Gemma3nPreTrainedModel):
+ config: Gemma3nTextConfig
+
+ def __init__(self, config: Gemma3nTextConfig):
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+
+ # Gemma3n downcasts the below to bfloat16, causing sqrt(3072)=55.4256 to become 55.5. See https://github.com/huggingface/transformers/pull/29402
+ self.embed_tokens = Gemma3nTextScaledWordEmbedding(
+ config.vocab_size, config.hidden_size, self.padding_idx, embed_scale=self.config.hidden_size**0.5
+ )
+ self.layers = nn.ModuleList(
+ [Gemma3nTextDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+
+ self.norm = Gemma3nRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.rotary_emb = Gemma3nTextRotaryEmbedding(config=config)
+ self.gradient_checkpointing = False
+
+ # TODO (raushan): Fix this after RoPE refactor. For now we hack it by
+ # reassigning thetas when we want to create a local RoPE layer. Config
+ # defaults should hold values for global RoPE.
+ config = copy.deepcopy(config)
+ config.rope_theta = config.rope_local_base_freq
+ config.rope_scaling = {"rope_type": "default"}
+ self.rotary_emb_local = Gemma3nTextRotaryEmbedding(config=config)
+
+ self.hidden_size = config.hidden_size
+ self.hidden_size_per_layer_input = config.hidden_size_per_layer_input
+
+ self.embed_tokens_per_layer = Gemma3nTextScaledWordEmbedding(
+ config.vocab_size_per_layer_input,
+ config.num_hidden_layers * config.hidden_size_per_layer_input,
+ self.padding_idx,
+ embed_scale=config.hidden_size_per_layer_input**0.5,
+ )
+
+ self.per_layer_model_projection = nn.Linear(
+ self.hidden_size,
+ config.num_hidden_layers * config.hidden_size_per_layer_input,
+ bias=False,
+ )
+
+ self.per_layer_projection_norm = Gemma3nRMSNorm(config.hidden_size_per_layer_input, eps=config.rms_norm_eps)
+
+ self.altup_projections = nn.ModuleList(
+ [nn.Linear(self.hidden_size, self.hidden_size, bias=False) for _ in range(1, self.config.altup_num_inputs)]
+ )
+
+ self.altup_unembed_projections = nn.ModuleList(
+ [nn.Linear(self.hidden_size, self.hidden_size, bias=False) for _ in range(1, self.config.altup_num_inputs)]
+ )
+
+ self.register_buffer("per_layer_projection_scale", torch.tensor(self.hidden_size**-0.5), persistent=False)
+ self.register_buffer("per_layer_input_scale", torch.rsqrt(torch.tensor(2.0)), persistent=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ per_layer_inputs: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> BaseModelOutputWithPast:
+ r"""
+ per_layer_inputs (torch.Tensor, *optional*, defaults to None):
+ Pre-computed per-layer embeddings. If None, they are derived from input_ids if provided.
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if self.gradient_checkpointing and self.training and use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
+ )
+ use_cache = False
+
+ if input_ids is not None:
+ inputs_embeds = self.embed_tokens(input_ids)
+ per_layer_inputs = self.get_per_layer_inputs(input_ids)
+
+ per_layer_inputs = self.project_per_layer_inputs(inputs_embeds, per_layer_inputs)
+
+ if use_cache and past_key_values is None and not self.training:
+ past_key_values = DynamicCache(config=self.config)
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position = torch.arange(
+ past_seen_tokens,
+ past_seen_tokens + inputs_embeds.shape[1],
+ device=inputs_embeds.device,
+ )
+
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0)
+
+ # It may already have been prepared by e.g. `generate`
+ if not isinstance(causal_mask_mapping := attention_mask, dict):
+ # Prepare mask arguments
+ mask_kwargs = {
+ "config": self.config,
+ "input_embeds": inputs_embeds,
+ "attention_mask": attention_mask,
+ "cache_position": cache_position,
+ "past_key_values": past_key_values,
+ "position_ids": position_ids,
+ }
+ # Create the masks
+ causal_mask_mapping = {
+ "full_attention": create_causal_mask(**mask_kwargs),
+ "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
+ }
+
+ # embed positions
+ hidden_states_0 = inputs_embeds
+
+ # Initialize RoPE embeddings
+ position_embeddings_global = self.rotary_emb(hidden_states_0, position_ids)
+ position_embeddings_local = self.rotary_emb_local(hidden_states_0, position_ids)
+
+ # Expand hidden_states to support per-layer inputs
+ target_magnitude = torch.mean(hidden_states_0**2, dim=-1, keepdim=True) ** 0.5
+ epsilon_tensor = torch.tensor(1e-5)
+
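+ # Each extra altup stream is a learned projection of the input embeddings,
+ # rescaled so its RMS magnitude matches that of the original stream.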
+ temp_hidden_states = [hidden_states_0]
+ for i in range(1, self.config.altup_num_inputs):
+ # altup_proj adapted from jax.numpy.einsum("btp,pd->btd", ...)
+ altup_proj = self.altup_projections[i - 1](hidden_states_0)
+ current_hidden_state = altup_proj.to(dtype=hidden_states_0.dtype, device=target_magnitude.device)
+ new_magnitude = torch.mean(current_hidden_state**2, dim=-1, keepdim=True)
+ new_magnitude = torch.sqrt(torch.maximum(new_magnitude, epsilon_tensor.to(target_magnitude.device)))
+ current_hidden_state = current_hidden_state * target_magnitude / new_magnitude
+ temp_hidden_states.append(current_hidden_state)
+
+ hidden_states = torch.stack(temp_hidden_states, dim=0) # [num_altup_inputs, batch, seq_len, hidden_size]
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+
+ for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ causal_mask = causal_mask_mapping[decoder_layer.attention_type]
+ per_layer_input = per_layer_inputs[:, :, decoder_layer.layer_idx, :]
+
+ layer_outputs = decoder_layer(
+ hidden_states,
+ position_embeddings_global,
+ position_embeddings_local,
+ per_layer_input,
+ attention_mask=causal_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ # add hidden states from the last decoder layer (but before reprojecting to stay consistent with layer output)
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ # Per-layer inputs to single output
+ target_magnitude = torch.mean(hidden_states[0] ** 2, dim=-1, keepdim=True) ** 0.5
+ temp_hidden_states = [hidden_states[0]]
+ for i in range(1, self.config.altup_num_inputs):
+ # altup_unembed_projections adapted from jax.numpy.einsum("btp,pd->btd", ...)
+ altup_unemb_proj: torch.Tensor = self.altup_unembed_projections[i - 1](hidden_states[i])
+ current_hidden_state = altup_unemb_proj.to(dtype=hidden_states_0.dtype, device=target_magnitude.device)
+ new_magnitude = torch.mean(current_hidden_state**2, dim=-1, keepdim=True)
+ new_magnitude = torch.sqrt(torch.maximum(new_magnitude, epsilon_tensor.to(target_magnitude.device)))
+ current_hidden_state = current_hidden_state * target_magnitude / new_magnitude
+ temp_hidden_states.append(current_hidden_state)
+
+ hidden_states = torch.stack(temp_hidden_states)
+ hidden_states = torch.mean(hidden_states, dim=0)
+ hidden_states = self.norm(hidden_states)
+
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
+
+ def get_per_layer_inputs(self, input_ids: torch.LongTensor) -> torch.Tensor:
+ return self.embed_tokens_per_layer(input_ids).reshape(
+ *input_ids.shape,
+ self.config.num_hidden_layers,
+ self.hidden_size_per_layer_input,
+ )
+
+ def project_per_layer_inputs(
+ self,
+ inputs_embeds: torch.Tensor,
+ per_layer_inputs: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ per_layer_projection: torch.Tensor = self.per_layer_model_projection(inputs_embeds)
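+ # per_layer_projection: [batch, tokens, num_hidden_layers * hidden_size_per_layer_input],
+ # reshaped below to [batch, tokens, num_hidden_layers, hidden_size_per_layer_input].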
+ per_layer_projection *= self.per_layer_projection_scale.to(
+ dtype=inputs_embeds.dtype, device=per_layer_projection.device
+ )
+ per_layer_projection = per_layer_projection.reshape(
+ *inputs_embeds.shape[:-1],
+ self.config.num_hidden_layers,
+ self.hidden_size_per_layer_input,
+ )
+ per_layer_projection = self.per_layer_projection_norm(per_layer_projection)
+
+ if per_layer_inputs is None:
+ return per_layer_projection
+
+ if per_layer_projection.shape != per_layer_inputs.shape:
+ # per-layer inputs are sometimes padded with zeros, slice the relevant embeddings.
+ per_layer_inputs = per_layer_inputs[..., : self.config.num_hidden_layers, :]
+
+ return (per_layer_projection + per_layer_inputs) * self.per_layer_input_scale.to(
+ dtype=inputs_embeds.dtype, device=per_layer_projection.device
+ )
+
+
+@auto_docstring(custom_intro="The base Gemma 3n language model with a language modeling head.")
+class Gemma3nForCausalLM(Gemma3nPreTrainedModel, GenerationMixin):
+ _tied_weights_keys = ["lm_head.weight"]
+ _tp_plan = {"lm_head": "colwise_rep"}
+ _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
+ config: Gemma3nTextConfig
+ base_model_prefix = "model"
+ _checkpoint_conversion_mapping = {"model.language_model": "model"}
+
+ def __init__(self, config: Gemma3nTextConfig):
+ super().__init__(config)
+ self.model = Gemma3nTextModel(config)
+ self.vocab_size = config.vocab_size
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ **kwargs,
+ ) -> CausalLMOutputWithPast:
+ r"""
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, Gemma3nForCausalLM
+
+ >>> model = Gemma3nForCausalLM.from_pretrained("google/gemma-3n-E4B")
+ >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-3n-E4B")
+
+ >>> prompt = "What is your favorite condiment?"
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "What is your favorite condiment?"
+ ```"""
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs: BaseModelOutputWithPast = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ hidden_states = outputs.last_hidden_state
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
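+ # Optional tanh soft-capping: logits = cap * tanh(logits / cap), bounding the logits to (-cap, cap).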
+ if self.config.final_logit_softcapping is not None:
+ logits = logits / self.config.final_logit_softcapping
+ logits = torch.tanh(logits)
+ logits = logits * self.config.final_logit_softcapping
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+class Gemma3nMultimodalEmbedder(nn.Module):
+ """Embeds token ids or soft tokens for multimodal content into language model space."""
+
+ def __init__(
+ self,
+ multimodal_config: Union[Gemma3nAudioConfig, Gemma3nVisionConfig],
+ text_config: Gemma3nTextConfig,
+ ):
+ super().__init__()
+
+ self.multimodal_hidden_size = multimodal_config.hidden_size
+ self.eps = multimodal_config.rms_norm_eps
+ self.vocab_offset = multimodal_config.vocab_offset
+ self.vocab_size = multimodal_config.vocab_size
+ self.text_hidden_size = text_config.hidden_size
+
+ self.embedding = nn.Embedding(self.vocab_size, self.multimodal_hidden_size)
+ self.hard_embedding_norm = Gemma3nRMSNorm(self.multimodal_hidden_size, eps=self.eps)
+ self.soft_embedding_norm = Gemma3nRMSNorm(self.multimodal_hidden_size, eps=self.eps)
+ self.embedding_projection = nn.Linear(self.multimodal_hidden_size, self.text_hidden_size, bias=False)
+ self.embedding_post_projection_norm = Gemma3nRMSNorm(self.text_hidden_size, eps=self.eps, with_scale=False)
+
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ """Embeds token ids or soft tokens for multimodal content into language model space.
+
+ Args:
+ input_ids: A torch.LongTensor containing the token ids to embed. Values should be in the range
+ `[vocab_offset, vocab_offset + vocab_size)`.
+ inputs_embeds: A torch.Tensor containing the soft tokens to embed.
+
+ Returns:
+ A torch.Tensor of embeddings with shape `[batch_size, seq_len, self.config.text_config.hidden_size]`.
+ """
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if inputs_embeds is not None:
+ emb_norm = self.soft_embedding_norm(inputs_embeds)
+ else:
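+ # Token ids lie in [vocab_offset, vocab_offset + vocab_size); shift them to index the 0-based embedding table.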
+ hard_emb = self.embedding(input_ids - self.vocab_offset)
+ emb_norm = self.hard_embedding_norm(hard_emb)
+
+ emb_norm_proj = self.embedding_projection(emb_norm)
+ return self.embedding_post_projection_norm(emb_norm_proj)
+
+
+@auto_docstring(
+ custom_intro="""
+ The base Gemma 3n model comprising a vision backbone, an audio backbone, and a language model without a
+ language modeling head.
+ """
+)
+class Gemma3nModel(Gemma3nPreTrainedModel):
+ _checkpoint_conversion_mapping = {}
+ # we are filtering the logits/labels so we shouldn't divide the loss based on num_items_in_batch
+ accepts_loss_kwargs = False
+
+ def __init__(self, config: Gemma3nConfig):
+ super().__init__(config)
+ self.vision_tower = AutoModel.from_config(config=config.vision_config)
+ self.vocab_size = config.text_config.vocab_size
+
+ language_model = AutoModel.from_config(config=config.text_config)
+ self.language_model = language_model
+
+ self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
+ self.vocab_size_per_layer_input = config.text_config.vocab_size_per_layer_input
+ self.audio_tower = AutoModel.from_config(config.audio_config)
+ self.embed_vision = Gemma3nMultimodalEmbedder(config.vision_config, config.text_config)
+ self.embed_audio = Gemma3nMultimodalEmbedder(config.audio_config, config.text_config)
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.language_model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.language_model.set_input_embeddings(value)
+
+ def set_decoder(self, decoder):
+ self.language_model = decoder
+
+ def get_decoder(self):
+ return self.language_model
+
+ def get_image_features(self, pixel_values: torch.Tensor) -> torch.Tensor:
+ """
+ Projects the last hidden state from the vision model into language model space.
+
+ Args:
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`):
+ The tensors corresponding to the input images.
+
+ Returns:
+ image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`.
+ """
+ vision_outputs = self.vision_tower(
+ pixel_values=pixel_values, do_pooling=False, return_dict=True
+ ).last_hidden_state
+ # Convert from (batch, channels, height, width) to (batch, height * width, channels) where:
+ # height == width and height * width == Gemma3nConfig.vision_soft_tokens_per_image.
+ vision_outputs = vision_outputs.reshape(
+ vision_outputs.shape[0],
+ self.config.vision_config.hidden_size,
+ self.config.vision_soft_tokens_per_image,
+ ).permute(0, 2, 1)
+ # Normalize and embed the soft tokens into language model space.
+ vision_outputs *= self.config.vision_config.hidden_size**0.5
+ return self.embed_vision(inputs_embeds=vision_outputs)
+
+ def get_placeholder_mask(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ image_features: Optional[torch.FloatTensor] = None,
+ audio_features: Optional[torch.FloatTensor] = None,
+ ):
+ """
+ Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
+ equal to the length of multimodal features. If the lengths are different, an error is raised.
+ """
+ if input_ids is None:
+ special_image_mask = inputs_embeds == self.get_input_embeddings()(
+ torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ special_image_mask = special_image_mask.all(-1)
+ special_audio_mask = (
+ inputs_embeds
+ == self.get_input_embeddings()(
+ torch.tensor(self.config.audio_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ ).all(-1)
+ else:
+ special_image_mask = input_ids == self.config.image_token_id
+ special_audio_mask = input_ids == self.config.audio_token_id
+
+ n_image_tokens = special_image_mask.sum()
+ special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+ if image_features is not None and inputs_embeds[special_image_mask].numel() != image_features.numel():
+ raise ValueError(
+ f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {image_features.shape[0] * image_features.shape[1]}"
+ )
+
+ n_audio_tokens = special_audio_mask.sum()
+ special_audio_mask = special_audio_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+ if audio_features is not None and inputs_embeds[special_audio_mask].numel() != audio_features.numel():
+ raise ValueError(
+ f"Audio features and image tokens do not match: tokens: {n_audio_tokens}, features {audio_features.shape[0] * audio_features.shape[1]}"
+ )
+
+ return special_image_mask, special_audio_mask
+
+ @can_return_tuple
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None, # text inputs
+ pixel_values: Optional[torch.FloatTensor] = None, # vision inputs
+ input_features: Optional[torch.FloatTensor] = None, # audio inputs
+ attention_mask: Optional[torch.Tensor] = None,
+ input_features_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ **lm_kwargs,
+ ) -> Gemma3nCausalLMOutputWithPast:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
+
+ Example:
+
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import AutoProcessor, Gemma3nForConditionalGeneration
+
+ >>> model = Gemma3nForConditionalGeneration.from_pretrained("google/gemma-3n-E4B")
+ >>> processor = AutoProcessor.from_pretrained("google/gemma-3n-E4B")
+
+ >>> prompt = "Where is the cat standing?"
+ >>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> inputs = processor(images=image, text=prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(**inputs,)
+ >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "Where is the cat standing?\nsnow"
+ ```
+ """
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+
+ if input_ids is not None:
+ inputs_embeds = self.get_input_embeddings()(input_ids)
+
+ # Prepare per-layer inputs from inputs_ids
+ per_layer_inputs_mask = torch.logical_and(input_ids >= 0, input_ids < self.vocab_size_per_layer_input)
+ per_layer_inputs_tokens = torch.where(per_layer_inputs_mask, input_ids, torch.zeros_like(input_ids))
+ per_layer_inputs = self.language_model.get_per_layer_inputs(per_layer_inputs_tokens)
+
+ # Handle vision tokens (>= embed_vision.vocab_offset and < embed_audio.vocab_offset)
+ vision_mask = torch.logical_and(
+ input_ids >= self.embed_vision.vocab_offset, input_ids < self.embed_audio.vocab_offset
+ )
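+ # Positions outside the vision range are temporarily filled with a valid dummy id so the embedding lookup
+ # stays in bounds; torch.where below restores the original text embeddings at those positions.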
+ dummy_vision_token_id = self.embed_vision.vocab_offset + self.embed_vision.vocab_size - 1
+ vision_input_ids = torch.where(vision_mask, input_ids, dummy_vision_token_id).to(inputs_embeds.device)
+ vision_embeds = self.embed_vision(input_ids=vision_input_ids)
+ expanded_vision_mask = vision_mask.unsqueeze(-1).expand_as(inputs_embeds)
+ inputs_embeds = torch.where(expanded_vision_mask, vision_embeds, inputs_embeds)
+
+ # Handle audio tokens (>= embed_audio.vocab_offset)
+ audio_mask = input_ids >= self.embed_audio.vocab_offset
+ dummy_audio_token_id = self.embed_audio.vocab_offset + self.embed_audio.vocab_size - 1
+ audio_input_ids = torch.where(audio_mask, input_ids, dummy_audio_token_id).to(inputs_embeds.device)
+ audio_embeds = self.embed_audio(input_ids=audio_input_ids)
+ expanded_audio_mask = audio_mask.unsqueeze(-1).expand_as(inputs_embeds)
+ inputs_embeds = torch.where(expanded_audio_mask, audio_embeds, inputs_embeds)
+ else:
+ per_layer_inputs = None
+
+ # Merge text and images
+ if pixel_values is not None:
+ image_features = self.get_image_features(pixel_values)
+ image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+ special_image_mask, _ = self.get_placeholder_mask(
+ input_ids, inputs_embeds=inputs_embeds, image_features=image_features
+ )
+ inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
+
+ # Merge text and audio
+ if input_features is not None and input_features_mask is not None:
+ audio_features, audio_mask = self.get_audio_features(input_features, ~input_features_mask)
+
+ # The Gemma3nProcessor expects all audio will be 30s in length and inserts 188 audio soft tokens into the
+ # text to account for this. However, the audio preprocessing and encoder do not guarantee they will
+ # produce 188 soft tokens; they will produce at most that many tokens, but they may produce fewer tokens
+ # depending on the length of the longest audio input in the batch. When we encounter this situation, we pad
+ # the audio feature out to 188 soft tokens with the embedding of the last token in the embed_audio vocab.
+ audio_padding_toks = torch.tensor([[self.vocab_size - 1]], dtype=torch.long, device=audio_features.device)
+ audio_padding_embs = self.embed_audio(input_ids=audio_padding_toks)
+ audio_features = torch.where(audio_mask.unsqueeze(-1), audio_padding_embs, audio_features)
+
+ audio_batch_size, audio_seq_len, audio_embed_dim = audio_features.shape
+ extra_padding_tokens = self.config.audio_soft_tokens_per_image - audio_seq_len
+ extra_padding_features = audio_padding_embs.expand(audio_batch_size, extra_padding_tokens, audio_embed_dim)
+
+ audio_features = torch.cat((audio_features, extra_padding_features), dim=1)
+ audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype)
+ _, special_audio_mask = self.get_placeholder_mask(
+ input_ids, inputs_embeds=inputs_embeds, audio_features=audio_features
+ )
+ inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_features)
+
+ outputs = self.language_model(
+ input_ids=None,
+ per_layer_inputs=per_layer_inputs,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=True,
+ cache_position=cache_position,
+ **lm_kwargs,
+ )
+
+ return Gemma3nModelOutputWithPast(
+ last_hidden_state=outputs.last_hidden_state,
+ past_key_values=outputs.past_key_values if use_cache else None,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ image_hidden_states=image_features if pixel_values is not None else None,
+ audio_hidden_states=audio_features if input_features is not None else None,
+ )
+
+ def get_audio_features(
+ self, input_features: torch.Tensor, input_features_mask: torch.Tensor
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Projects the last hidden state from the audio encoder into language model space.
+
+ Args:
+ input_features (`torch.FloatTensor` of shape `(num_images, seq_length, num_features)`):
+ The tensors corresponding to the input audio.
+ input_features_mask (`torch.FloatTensor` of shape `(num_images, seq_length)`):
+ The attention mask for the input audio.
+
+ Returns:
+ audio_features (`torch.Tensor`): Audio feature tensor of shape `(num_images, audio_length, embed_dim)`.
+ """
+ audio_outputs, audio_mask = self.audio_tower(input_features, input_features_mask)
+ return self.embed_audio(inputs_embeds=audio_outputs), audio_mask
+
+
+@auto_docstring(
+ custom_intro="""
+ The base Gemma 3n model comprising a vision backbone, an audio backbone, a language model, and a language modeling
+ head.
+ """
+)
+class Gemma3nForConditionalGeneration(Gemma3nPreTrainedModel, GenerationMixin):
+ _checkpoint_conversion_mapping = {}
+ _tied_weights_keys = ["lm_head.weight"]
+ base_model_prefix = "model"
+
+ def __init__(self, config: Gemma3nConfig):
+ super().__init__(config)
+ self.model = Gemma3nModel(config)
+ self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.model.set_input_embeddings(value)
+
+ def set_decoder(self, decoder):
+ self.model.set_decoder(decoder)
+
+ def get_decoder(self):
+ return self.model.get_decoder()
+
+ def get_image_features(self, pixel_values):
+ return self.model.get_image_features(pixel_values)
+
+ # Make modules available through conditional class for BC
+ @property
+ def language_model(self):
+ return self.model.language_model
+
+ @property
+ def vision_tower(self):
+ return self.model.vision_tower
+
+ @property
+ def multi_modal_projector(self):
+ raise AttributeError("Use embed_vision instead of multi_modal_projector.")
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None, # text inputs
+ pixel_values: Optional[torch.FloatTensor] = None, # vision inputs
+ input_features: Optional[torch.FloatTensor] = None, # audio inputs
+ attention_mask: Optional[torch.Tensor] = None,
+ input_features_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ **lm_kwargs,
+ ) -> Gemma3nCausalLMOutputWithPast:
+ r"""
+ input_features_mask (`torch.Tensor`, *optional*):
+ The attention mask for the input audio.
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are
+ ignored (masked), the loss is only computed for the tokens with labels in
+ `[0, ..., config.text_config.vocab_size]`.
+
+ Example:
+
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import AutoProcessor, Gemma3nForConditionalGeneration
+
+ >>> model = Gemma3nForConditionalGeneration.from_pretrained("google/gemma-3n-E4B")
+ >>> processor = AutoProcessor.from_pretrained("google/gemma-3n-E4B")
+
+ >>> messages = [
+ ... {
+ ... "role": "system",
+ ... "content": [
+ ... {"type": "text", "text": "You are a helpful assistant."}
+ ... ]
+ ... },
+ ... {
+ ... "role": "user", "content": [
+ ... {"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"},
+ ... {"type": "text", "text": "Where is the cat standing?"},
+ ... ]
+ ... },
+ ... ]
+
+ >>> inputs = processor.apply_chat_template(
+ ... messages,
+ ... tokenize=True,
+ ... return_dict=True,
+ ... return_tensors="pt",
+ ... add_generation_prompt=True
+ ... )
+ >>> # Generate
+ >>> generate_ids = model.generate(**inputs)
+ >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "user\nYou are a helpful assistant.\n\n\n\n\n\nWhere is the cat standing?\nmodel\nBased on the image, the cat is standing in a snowy area, likely outdoors. It appears to"
+ ```
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+
+ outputs = self.model(
+ input_ids=input_ids,
+ pixel_values=pixel_values,
+ input_features=input_features,
+ attention_mask=attention_mask,
+ input_features_mask=input_features_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ token_type_ids=token_type_ids,
+ cache_position=cache_position,
+ inputs_embeds=inputs_embeds,
+ labels=labels,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=True,
+ **lm_kwargs,
+ )
+
+ hidden_states = outputs.last_hidden_state
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+ if (final_logit_softcapping := self.config.get_text_config().final_logit_softcapping) is not None:
+ logits = logits / final_logit_softcapping
+ logits = torch.tanh(logits)
+ logits = logits * final_logit_softcapping
+
+ loss = None
+ if labels is not None:
+ # Upcast to float if we need to compute the loss to avoid potential precision issues
+ logits = logits.float()
+ shift_logits = logits[..., :-1, :]
+ shift_labels = labels[..., 1:]
+ if attention_mask is not None:
+ # we use the input attention mask to shift the logits and labels, because it is 2D.
+ # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
+ shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(logits.device)
+ shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous()
+ shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
+ else:
+ shift_logits = shift_logits.contiguous()
+ shift_labels = shift_labels.contiguous()
+ # Flatten the tokens
+ loss_fct = nn.CrossEntropyLoss()
+
+ flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
+ flat_labels = shift_labels.view(-1).to(shift_logits.device)
+ loss = loss_fct(flat_logits, flat_labels)
+
+ return Gemma3nCausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ image_hidden_states=outputs.image_hidden_states,
+ audio_hidden_states=outputs.audio_hidden_states,
+ )
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids,
+ past_key_values=None,
+ inputs_embeds=None,
+ cache_position=None,
+ position_ids=None,
+ pixel_values=None,
+ input_features=None,
+ attention_mask=None,
+ input_features_mask=None,
+ token_type_ids=None,
+ use_cache=True,
+ logits_to_keep=None,
+ labels=None,
+ **kwargs,
+ ):
+ # Overwritten -- custom `position_ids` and `pixel_values` handling
+ model_inputs = super().prepare_inputs_for_generation(
+ input_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ cache_position=cache_position,
+ use_cache=use_cache,
+ logits_to_keep=logits_to_keep,
+ token_type_ids=token_type_ids,
+ **kwargs,
+ )
+
+ # If we're in cached decoding stage, multimodal inputs should be None because input ids do not contain special
+ # tokens anymore. Otherwise multimodal inputs should be passed to model.
+ # NOTE: use_cache=False always needs pixel_values, input_features, and input_features_mask
+ if cache_position[0] == 0:
+ model_inputs["pixel_values"] = pixel_values
+ model_inputs["input_features"] = input_features
+ model_inputs["input_features_mask"] = input_features_mask
+
+ return model_inputs
+
+ @property
+ def audio_tower(self):
+ return self.model.audio_tower
+
+
+__all__ = [
+ "Gemma3nAudioEncoder",
+ "Gemma3nForCausalLM",
+ "Gemma3nForConditionalGeneration",
+ "Gemma3nModel",
+ "Gemma3nPreTrainedModel",
+ "Gemma3nTextModel",
+]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/gemma3n/modular_gemma3n.py b/venv/lib/python3.13/site-packages/transformers/models/gemma3n/modular_gemma3n.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ea50b7572cf21c9cae1c740c799e936f9d87d34
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/gemma3n/modular_gemma3n.py
@@ -0,0 +1,2684 @@
+# coding=utf-8
+# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import copy
+import math
+from collections.abc import Callable, Sequence
+from typing import Any, Optional, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache
+from ...configuration_utils import PretrainedConfig, layer_type_validation
+from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_outputs import BaseModelOutputWithPast
+from ...modeling_rope_utils import rope_config_validation
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
+from ...utils.deprecation import deprecate_kwarg
+from ..auto import AutoModel
+from ..gemma2.configuration_gemma2 import Gemma2Config
+from ..gemma2.modeling_gemma2 import (
+ Gemma2MLP,
+ Gemma2PreTrainedModel,
+ Gemma2RotaryEmbedding,
+ eager_attention_forward,
+ rotate_half,
+)
+from ..gemma3.modeling_gemma3 import (
+ Gemma3Attention,
+ Gemma3DecoderLayer,
+ Gemma3ForCausalLM,
+ Gemma3RMSNorm,
+ Gemma3TextModel,
+ Gemma3TextScaledWordEmbedding,
+)
+from ..paligemma.modeling_paligemma import (
+ PaliGemmaCausalLMOutputWithPast,
+ PaliGemmaForConditionalGeneration,
+ PaliGemmaModel,
+ PaligemmaModelOutputWithPast,
+)
+from ..timm_wrapper.configuration_timm_wrapper import TimmWrapperConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+class Gemma3nTextConfig(Gemma2Config, PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`Gemma3nTextModel`]. It is used to instantiate a
+ Gemma3nTextModel model according to the specified arguments, defining the model architecture. Instantiating a
+ configuration with the defaults will yield a similar configuration to that of the Gemma 3n E4B, e.g.
+ [google/gemma-3n-E4B](https://huggingface.co/google/gemma-3n-E4B).
+
+ Configuration objects that inherit from [`Gemma3nTextConfig`] can be used to control the model outputs. Read
+ the documentation from [`Gemma3nTextConfig`] for more information.
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 262400):
+ Vocabulary size of the Gemma3nText model. Defines the number of different tokens that can be represented by
+ the `input_ids` passed when calling [`Gemma3nTextModel`]
+ vocab_size_per_layer_input (`int`, *optional*, defaults to 262144):
+ Vocabulary size of the per-layer text embeddings that augment the standard embeddings.
+ hidden_size (`int`, *optional*, defaults to 2048):
+ Dimension of the hidden representations.
+ hidden_size_per_layer_input (`int`, *optional*, defaults to 256):
+ Dimension of the hidden representations for per-layer embeddings.
+ intermediate_size (`int` or `Sequence[int]`, *optional*, defaults to 16384):
+ Dimension of the MLP representations. MatFormer configurations may wish to provide a sequence of integers
+ to account for variable intermediate_size values across layers. In such cases,
+ `len(intermediate_size) == num_hidden_layers`.
+ num_hidden_layers (`int`, *optional*, defaults to 35):
+ Number of hidden layers in the Transformer decoder.
+ num_attention_heads (`int`, *optional*, defaults to 8):
+ Number of attention heads for each attention layer in the Transformer decoder.
+ num_key_value_heads (`int`, *optional*, defaults to 2):
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+ by meanpooling all the original heads within that group. For more details checkout this
+ [paper](https://huggingface.co/papers/2305.13245). If not specified, will default to `num_attention_heads`.
+ head_dim (`int`, *optional*, defaults to 256):
+ The attention head dimension.
+ hidden_activation (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
+ The non-linear activation function (function or string) in the decoder. Will default to
+ `"gelu_pytorch_tanh"` if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"`
+ activation function.
+ max_position_embeddings (`int`, *optional*, defaults to 32768):
+ The maximum sequence length that this model might ever be used with.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
+ The epsilon used by the rms normalization layers.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ pad_token_id (`int`, *optional*, defaults to 0):
+ Padding token id.
+ eos_token_id (`int`, *optional*, defaults to 1):
+ End of stream token id.
+ bos_token_id (`int`, *optional*, defaults to 2):
+ Beginning of stream token id.
+ rope_theta (`float`, *optional*, defaults to 1000000.0):
+ The base period of the RoPE embeddings.
+ rope_scaling (`Dict`, *optional*):
+ Dictionary containing the scaling configuration for the RoPE embeddings used in global attention.
+ NOTE: if you apply new rope type and you expect the model to work on longer `max_position_embeddings`, we
+ recommend you to update this value accordingly.
+ Expected contents:
+ `rope_type` (`str`):
+ The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
+ 'llama3'], with 'default' being the original RoPE implementation.
+ `factor` (`float`, *optional*):
+ Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
+ most scaling types, a `factor` of x will enable the model to handle sequences of length x *
+ original maximum pre-trained length.
+ `original_max_position_embeddings` (`int`, *optional*):
+ Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
+ pretraining.
+ `attention_factor` (`float`, *optional*):
+ Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
+ computation. If unspecified, it defaults to value recommended by the implementation, using the
+ `factor` field to infer the suggested value.
+ `beta_fast` (`float`, *optional*):
+ Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
+ ramp function. If unspecified, it defaults to 32.
+ `beta_slow` (`float`, *optional*):
+ Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
+ ramp function. If unspecified, it defaults to 1.
+ `short_factor` (`List[float]`, *optional*):
+ Only used with 'longrope'. The scaling factor to be applied to short contexts (<
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+ size divided by the number of attention heads divided by 2
+ `long_factor` (`List[float]`, *optional*):
+ Only used with 'longrope'. The scaling factor to be applied to long contexts (<
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+ size divided by the number of attention heads divided by 2
+ `low_freq_factor` (`float`, *optional*):
+ Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
+ `high_freq_factor` (`float`, *optional*):
+ Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
+ rope_local_base_freq (`float`, *optional*, defaults to 10000.0):
+ The base period of the RoPE embeddings for local attention.
+ attention_bias (`bool`, *optional*, defaults to `False`):
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ sliding_window (`int`, *optional*, defaults to 512):
+ This is the size of the sliding window used by local attention layers.
+ layer_types (`Optional[Sequence[str]]`, *optional*):
+ A sequence of strings defining the attention type for each layer as either "sliding_attention" or
+ "full_attention". If not provided, `layer_types` will be inferred from `num_hidden_layers` using a pattern
+ of four "sliding_attention" layers followed by one "full_attention". The last layer in the model should
+ always be a "full_attention" layer.
+ final_logit_softcapping (`float`, *optional*, defaults to 30.0):
+ Scaling factor when applying tanh softcapping on the logits.
+ altup_active_idx (`int`, *optional*, defaults to 0):
+ The index of the prediction from which AltUp will compute additional predictions or correct
+ altup_coef_clip (`float`, *optional*, defaults to 120.0):
+ The maximum amplitude of an AltUp prediction or correction coefficient weight.
+ altup_correct_scale (`bool`, *optional*, defaults to `True`):
+ If True, apply the `AltUp.correct_output_scale` to the corrected prediction at `altup_active_idx`.
+ altup_num_inputs (`int`, *optional*, defaults to 4):
+ The number of predictions that AltUp should make given the input sequence.
+ num_kv_shared_layers (`int`, *optional*, defaults to 15):
+ The number of layers that share KV cache values. During the forward pass, the last `num_kv_shared_layers`
+ layers in the model "share" the KV values in that each local and global layer in this range uses the KV
+ cache values computed for the last local or global layer, respectively, before entering this range. The
+ value should be a multiple of the attention pattern size (see `layer_types` parameter).
+ laurel_rank (`int`, *optional*, defaults to 64):
+ The intermediate size for the linear projections in the Learned Augmented Residual Layer.
+ activation_sparsity_pattern (`Sequence[float]`, *optional*):
+ The sparsity factor used to extract the top-k activations for a given layer. The provided Sequence must
+ explicitly provide a sparsity value for each layer in the model. By default, the first 10 layers are
+ sparse with a sparsity factor of 0.95 and the rest are dense.
+
+ ```python
+ >>> from transformers import Gemma3nTextModel, Gemma3nTextConfig
+
+ >>> # Initializing a Gemma3nText gemma3n_text-E4B style configuration
+ >>> configuration = Gemma3nTextConfig()
+
+ >>> # Initializing a model from the gemma3n_text-E4B style configuration
+ >>> model = Gemma3nTextModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```
+ """
+
+ model_type = "gemma3n_text"
+
+ def __init__(
+ self,
+ vocab_size: int = 262_400,
+ vocab_size_per_layer_input: int = 262_144,
+ hidden_size: int = 2048,
+ hidden_size_per_layer_input: int = 256,
+ intermediate_size: Union[int, Sequence[int]] = 16_384,
+ num_hidden_layers: int = 35,
+ num_attention_heads: int = 8,
+ num_key_value_heads: int = 2,
+ head_dim: int = 256,
+ hidden_activation: str = "gelu_pytorch_tanh",
+ max_position_embeddings: int = 32_768,
+ initializer_range: float = 0.02,
+ rms_norm_eps: float = 1e-6,
+ use_cache: bool = True,
+ pad_token_id: int = 0,
+ eos_token_id: int = 1,
+ bos_token_id: int = 2,
+ rope_theta: float = 1_000_000.0,
+ rope_scaling: Optional[dict[str, Any]] = None,
+ rope_local_base_freq: float = 10_000.0,
+ attention_bias: bool = False,
+ attention_dropout: float = 0.0,
+ sliding_window: int = 512,
+ layer_types: Optional[Sequence[str]] = None,
+ final_logit_softcapping: float = 30.0,
+ altup_active_idx: int = 0,
+ altup_coef_clip: float = 120.0,
+ altup_correct_scale: bool = True,
+ altup_num_inputs: int = 4,
+ num_kv_shared_layers: int = 15,
+ laurel_rank: int = 64,
+ activation_sparsity_pattern: Optional[Union[float, Sequence[float]]] = None,
+ **kwargs,
+ ):
+ PretrainedConfig.__init__(
+ self,
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ **kwargs,
+ )
+
+ if isinstance(intermediate_size, Sequence) and (intsize_len := len(intermediate_size)) != num_hidden_layers:
+ raise ValueError(
+ "intermediate_size must have an explicit intermediate size for every layer or one for all layers. "
+ f"Expected {num_hidden_layers} values but got {intsize_len}."
+ )
+ elif not isinstance(intermediate_size, Sequence):
+ intermediate_size = [intermediate_size] * num_hidden_layers
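+ # A single int is broadcast to every layer; MatFormer-style configs pass an explicit per-layer sequence instead.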
+
+ self.vocab_size = vocab_size
+ self.vocab_size_per_layer_input = vocab_size_per_layer_input
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.head_dim = head_dim
+ self.num_key_value_heads = num_key_value_heads
+ self.initializer_range = initializer_range
+ self.rms_norm_eps = rms_norm_eps
+ self.use_cache = use_cache
+ self.rope_theta = rope_theta
+ self.attention_bias = attention_bias
+ self.attention_dropout = attention_dropout
+ self.hidden_activation = hidden_activation
+ self.sliding_window = sliding_window
+ self.final_logit_softcapping = final_logit_softcapping
+ self.layer_types = layer_types
+
+ self.rope_local_base_freq = rope_local_base_freq
+ self.rope_scaling = rope_scaling
+ rope_config_validation(self)
+
+ if layer_types is None:
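+ # Default attention pattern: every fifth layer uses "full_attention", all other layers use "sliding_attention".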
+ self.layer_types = [
+ "full_attention" if (i + 1) % 5 == 0 else "sliding_attention" for i in range(self.num_hidden_layers)
+ ]
+ else:
+ self.layer_types = layer_types
+
+ layer_type_validation(self.layer_types, self.num_hidden_layers)
+
+ self.hidden_size_per_layer_input = hidden_size_per_layer_input
+ self.num_kv_shared_layers = num_kv_shared_layers
+
+ self.altup_active_idx = altup_active_idx
+ self.altup_coef_clip = altup_coef_clip
+ self.altup_correct_scale = altup_correct_scale
+ self.altup_num_inputs = altup_num_inputs
+
+ self.laurel_rank = laurel_rank
+
+ if activation_sparsity_pattern is None:
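+ # Default: with more than 10 layers, the first 10 use 0.95 activation sparsity and the rest are dense;
+ # otherwise every layer is dense.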
+ num_sparse_layers = 10 if num_hidden_layers > 10 else 0
+ activation_sparsity_pattern = [0.95] * num_sparse_layers + [0.0] * (num_hidden_layers - num_sparse_layers)
+
+ if (len_asp := len(activation_sparsity_pattern)) != num_hidden_layers:
+ raise ValueError(
+ "activation_sparsity_pattern must have an explicit activation sparsity value for every layer."
+ f"Expected {num_hidden_layers} values but got {len_asp}."
+ )
+ self.activation_sparsity_pattern = activation_sparsity_pattern
+
+
+class Gemma3nAudioConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`Gemma3nAudioEncoder`]. It is used to instantiate
+ a `Gemma3nAudioEncoder` model according to the specified arguments, defining the model architecture. Instantiating
+ a configuration with the defaults will yield a similar configuration to that of the Gemma 3n E4B, e.g.,
+ [google/gemma-3n-E4B](https://huggingface.co/google/gemma-3n-E4B).
+
+ Configuration objects that inherit from [`Gemma3nAudioConfig`] can be used to control the model outputs. Read
+ the documentation from [`Gemma3nAudioConfig`] for more information.
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 128):
+ Vocabulary size of the additional hard-token embeddings for audio model. These augment the embeddings
+ included in the `Gemma3nTextModel` to provide, e.g., the end of audio and audio soft token placeholder
+ tokens when converting `input_ids` to embeddings in the `Gemma3nForConditionalGeneration` model.
+ vocab_offset (`int`, *optional*, defaults to 262272):
+ Offset between the tokenizer vocab index for the token ids embedded by `Gemma3nMultimodalEmbedder` and the
+ 0-indexed `Gemma3nMultimodalEmbedder.embedding` table.
+ input_feat_size (`int`, *optional*, defaults to 128):
+ The number of channels in each mel-spectrogram frame.
+ hidden_size (`int`, *optional*, defaults to 1536):
+ Dimension of the hidden representations.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
+ The epsilon used by the rms normalization layers.
+ gradient_clipping (`float`, *optional*, defaults to 10000000000.0):
+ Clipping value used to stabilize extremely large gradient values.
+ conf_attention_chunk_size (`int`, *optional*, defaults to 12):
+ The sub-sequence size for local attention processing inside the Conformer ("conf") section of the
+ Universal Speech Model.
+ conf_attention_context_left (`int`, *optional*, defaults to 13):
+ The left context size of the local attention inside the Conformer ("conf") section of the
+ Universal Speech Model.
+ conf_attention_context_right (`int`, *optional*, defaults to 0):
+ The right context size of the local attention inside the Conformer ("conf") section of the
+ Universal Speech Model.
+ conf_attention_logit_cap (`float`, *optional*, defaults to 50.0):
+ Logit cap applied during local attention inside the Conformer ("conf") section of the
+ Universal Speech Model.
+ conf_num_attention_heads (`int`, *optional*, defaults to 8):
+ The number of attention heads in local attention inside the Conformer ("conf") section of the
+ Universal Speech Model.
+ conf_num_hidden_layers (`int`, *optional*, defaults to 12):
+ The number of layers that use local attention inside the Conformer ("conf") section of the
+ Universal Speech Model.
+ conf_conv_kernel_size (`int`, *optional*, defaults to 5):
+ Convolution kernel size for the conformer block inside the Conformer ("conf") section of the
+ Universal Speech Model.
+ conf_reduction_factor (`int`, *optional*, defaults to 4):
+ Reduction factor used in the conformer block inside the Conformer ("conf") section of the
+ Universal Speech Model.
+ conf_residual_weight (`float`, *optional*, defaults to 0.5):
+ Residual connection weight inside the Conformer ("conf") section of the
+ Universal Speech Model.
+ sscp_conv_channel_size (`tuple(int, int)`, *optional*, defaults to `(128, 32)`):
+ The channel sizes for the first and second convolutional layers in the Sub-sample Convolution Projection
+ ("sscp") section of the Universal Speech Model.
+ sscp_conv_group_norm_eps (`float`, *optional*, defaults to 0.001):
+ Epsilon used in group normalization in the subsample convolution projection in the Sub-sample Convolution
+ Projection ("sscp") section of the Universal Speech Model.
+ sscp_conv_kernel_size (`tuple(tuple(int, int), tuple(int, int))`, *optional*, defaults to `((3, 3), (3, 3))`):
+ Kernel sizes of the two convolutional layers in the subsample convolution projection in the Sub-sample
+ Convolution Projection ("sscp") section of the Universal Speech Model. The kernel sizes are specified as a
+ tuple of height and width for each layer, where the height corresponds to the time dimension and the width
+ corresponds to the frequency dimension.
+ sscp_conv_stride_size (`tuple(tuple(int, int), tuple(int, int))`, *optional*, defaults to `((2, 2), (2, 2))`):
+ Stride sizes of the two convolutional layers in the subsample convolution projection in the Sub-sample
+ Convolution Projection ("sscp") section of the Universal Speech Model. The stride sizes are specified as a
+ tuple of height and width for each layer, where the height corresponds to the time dimension and the width
+ corresponds to the frequency dimension.
+
+ Example:
+
+ ```python
+ >>> from transformers import Gemma3nAudioConfig, Gemma3nAudioEncoder
+
+ >>> # Initializing a Gemma3nAudioEncoder gemma3n_audio-E4B-style configuration
+ >>> configuration = Gemma3nAudioConfig()
+
+ >>> # Initializing a model from the gemma3n_audio-E4B style configuration
+ >>> model = Gemma3nAudioEncoder(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```
+ """
+
+ model_type = "gemma3n_audio"
+
+ def __init__(
+ self,
+ vocab_size: int = 128,
+ vocab_offset: int = 262_144 + 128, # text vocab size + vision vocab size
+ input_feat_size: int = 128,
+ hidden_size: int = 1536,
+ rms_norm_eps: float = 1e-6,
+ gradient_clipping: float = 10_000_000_000.0,
+ conf_attention_chunk_size: int = 12,
+ conf_attention_context_left: int = 13,
+ conf_attention_context_right: int = 0,
+ conf_attention_logit_cap: float = 50.0,
+ conf_num_attention_heads: int = 8,
+ conf_num_hidden_layers: int = 12,
+ conf_conv_kernel_size: int = 5,
+ conf_reduction_factor: int = 4,
+ conf_residual_weight: float = 0.5,
+ sscp_conv_channel_size: tuple[int, int] = (128, 32),
+ sscp_conv_group_norm_eps: float = 1e-3,
+ sscp_conv_kernel_size: tuple[tuple[int, int], tuple[int, int]] = (
+ (3, 3),
+ (3, 3),
+ ),
+ sscp_conv_stride_size: tuple[tuple[int, int], tuple[int, int]] = (
+ (2, 2),
+ (2, 2),
+ ),
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.input_feat_size = input_feat_size
+ self.hidden_size = hidden_size
+ self.rms_norm_eps = rms_norm_eps
+ self.vocab_size = vocab_size
+ self.vocab_offset = vocab_offset
+ self.gradient_clipping = gradient_clipping
+ self.conf_attention_chunk_size = conf_attention_chunk_size
+ self.conf_attention_context_left = conf_attention_context_left
+ self.conf_attention_context_right = conf_attention_context_right
+ self.conf_attention_logit_cap = conf_attention_logit_cap
+ self.conf_num_attention_heads = conf_num_attention_heads
+ self.conf_num_hidden_layers = conf_num_hidden_layers
+ self.conf_conv_kernel_size = conf_conv_kernel_size
+ self.conf_reduction_factor = conf_reduction_factor
+ self.conf_residual_weight = conf_residual_weight
+ self.sscp_conv_channel_size = sscp_conv_channel_size
+ self.sscp_conv_group_norm_eps = sscp_conv_group_norm_eps
+ self.sscp_conv_kernel_size = sscp_conv_kernel_size
+ self.sscp_conv_stride_size = sscp_conv_stride_size
+
+
+class Gemma3nVisionConfig(TimmWrapperConfig):
+ r"""
+ This is the configuration class to store the configuration for a timm backbone [`TimmWrapper`]. It is used to
+ instantiate a timm model according to the specified arguments, defining the model architecture.
+ Instantiating a configuration with the defaults will yield a similar configuration to that of the Gemma 3n E4B
+ vision tower, e.g. [google/gemma-3n-E4B](https://huggingface.co/google/gemma-3n-E4B).
+
+ Configuration objects inherit from [`Gemma3nVisionConfig`] and can be used to control the model outputs. Read the
+ documentation from [`Gemma3nVisionConfig`] for more information.
+
+ The config loads ImageNet label descriptions and stores them in the `id2label` attribute; the `label2id` attribute
+ for default ImageNet models is set to `None` due to occlusions in the label descriptions.
+
+ Args:
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ do_pooling (`bool`, *optional*, defaults to `False`):
+ Whether to do pooling for the last_hidden_state in `TimmWrapper` or not.
+ architecture (`str`, *optional*, defaults to `"mobilenetv5_300m_enc"`):
+ Determines vision architecture for TimmWrapper.
+ hidden_size (`int`, *optional*, defaults to 2048):
+ Dimension of the hidden representations.
+ vocab_size (`int`, *optional*, defaults to 128):
+ Vocabulary size of the additional hard-token embeddings for vision model.
+ vocab_offset (`int`, *optional*, defaults to 262144):
+ Offset between the tokenizer vocab index for the token ids embedded by `Gemma3nMultimodalEmbedder` and the
+ 0-indexed `Gemma3nMultimodalEmbedder.embedding` table.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
+ The epsilon used by the rms normalization layers.
+
+ Example:
+ ```python
+ >>> from transformers import Gemma3nVisionConfig, TimmWrapper
+
+ >>> # Initializing a TimmWrapper gemma3n_vision-E4B-style configuration
+ >>> configuration = Gemma3nVisionConfig()
+
+ >>> # Initializing a gemma3n_vision-E4B-style TimmWrapper from the configuration
+ >>> model = TimmWrapper(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```
+ """
+
+ model_type = "gemma3n_vision"
+
+ def __init__(
+ self,
+ initializer_range: float = 0.02,
+ do_pooling: bool = False,
+ architecture: str = "mobilenetv5_300m_enc",
+ hidden_size: int = 2048,
+ vocab_size: int = 128,
+ vocab_offset: int = 262_144,
+ rms_norm_eps: float = 1e-06,
+ model_args: Optional[dict] = None,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.architecture = architecture
+ self.initializer_range = initializer_range
+ self.do_pooling = do_pooling
+ self.hidden_size = hidden_size
+ self.vocab_size = vocab_size
+ self.vocab_offset = vocab_offset
+ self.rms_norm_eps = rms_norm_eps
+
+
+class Gemma3nConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`Gemma3nForConditionalGeneration`]. It is used to
+ instantiate a Gemma3nForConditionalGeneration according to the specified arguments, defining the model
+ architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of
+ Gemma3n-E4B.
+
+ e.g. [google/gemma-3n-E4B](https://huggingface.co/google/gemma-3n-E4B)
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ text_config (`Union[Gemma3nTextConfig, dict]`, *optional*):
+ The config object of the text backbone.
+ vision_config (`Union[AutoConfig, dict]`, *optional*):
+ Custom vision config or dict.
+ audio_config (`Union[AutoConfig, dict]`, *optional*):
+ Custom audio config or dict.
+ audio_soft_tokens_per_image (`int`, *optional*, defaults to 188):
+ The number of soft tokens per audio clip.
+ vision_soft_tokens_per_image (`int`, *optional*, defaults to 256):
+ The number of soft tokens per image.
+ boi_token_id (`int`, *optional*, defaults to 255999):
+ The begin-of-image token index to wrap the image prompt.
+ eoi_token_id (`int`, *optional*, defaults to 262144):
+ The end-of-image token index to wrap the image prompt.
+ image_token_id (`int`, *optional*, defaults to 262145):
+ The image token index to encode the image prompt.
+ boa_token_id (`int`, *optional*, defaults to 256000):
+ The begin-of-audio token index to wrap the audio prompt.
+ eoa_token_id (`int`, *optional*, defaults to 262272):
+ The end-of-audio token index to wrap the audio prompt.
+ audio_token_id (`int`, *optional*, defaults to 262273):
+ The audio token index to encode the audio prompt.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+
+
+ Example:
+
+ ```python
+ >>> from transformers import Gemma3nForConditionalGeneration, Gemma3nConfig, Gemma3nTextConfig
+
+ >>> # Initializing a MobileNet vision config, which is loaded from TIMM
+ >>> vision_config = Gemma3nVisionConfig()
+
+ >>> # Initializing a Gemma3n Audio config
+ >>> audio_config = Gemma3nAudioConfig()
+
+ >>> # Initializing a Gemma3n Text config
+ >>> text_config = Gemma3nTextConfig()
+
+ >>> # Initializing a Gemma3n gemma-3n-E4B style configuration
+ >>> configuration = Gemma3nConfig(text_config, vision_config, audio_config)
+
+ >>> # Initializing a model from the gemma-3n-E4B style configuration
+ >>> model = Gemma3nForConditionalGeneration(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "gemma3n"
+ sub_configs = {
+ "text_config": Gemma3nTextConfig,
+ "vision_config": Gemma3nVisionConfig,
+ "audio_config": Gemma3nAudioConfig,
+ }
+
+ def __init__(
+ self,
+ text_config: Optional[Union[Gemma3nTextConfig, dict[str, Any]]] = None,
+ vision_config: Optional[Union[Gemma3nVisionConfig, dict[str, Any]]] = None,
+ audio_config: Optional[Union[Gemma3nAudioConfig, dict[str, Any]]] = None,
+ audio_soft_tokens_per_image: int = 188,
+ vision_soft_tokens_per_image: int = 256,
+ boi_token_id: int = 255_999,
+ eoi_token_id: int = 262_144,
+ image_token_id: int = 262_145,
+ boa_token_id: int = 256_000,
+ eoa_token_id: int = 262_272,
+ audio_token_id: int = 262_273,
+ initializer_range: float = 0.02,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ if isinstance(text_config, dict):
+ text_config = Gemma3nTextConfig(**text_config)
+ elif text_config is None:
+ text_config = Gemma3nTextConfig()
+ logger.info("text_config is None. Using default Gemma3nTextConfig.")
+
+ if isinstance(vision_config, dict):
+ vision_config = Gemma3nVisionConfig(**vision_config)
+ elif vision_config is None:
+ vision_config = Gemma3nVisionConfig()
+ logger.info("vision_config is None. Using default Gemma3nVisionConfig.")
+
+ if isinstance(audio_config, dict):
+ audio_config = Gemma3nAudioConfig(**audio_config)
+ elif audio_config is None:
+ audio_config = Gemma3nAudioConfig()
+ logger.info("audio_config is None. Using default Gemma3nAudioConfig.")
+
+ self.text_config = text_config
+ self.vision_config = vision_config
+ self.audio_config = audio_config
+
+ self.audio_soft_tokens_per_image = audio_soft_tokens_per_image
+ self.vision_soft_tokens_per_image = vision_soft_tokens_per_image
+ self.boi_token_id = boi_token_id
+ self.eoi_token_id = eoi_token_id
+ self.image_token_id = image_token_id
+ self.boa_token_id = boa_token_id
+ self.eoa_token_id = eoa_token_id
+ self.audio_token_id = audio_token_id
+ self.initializer_range = initializer_range
+
+
+class Gemma3nModelOutputWithPast(PaligemmaModelOutputWithPast):
+ r"""
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ image_hidden_states (`torch.FloatTensor`, *optional*):
+ A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
+ image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
+ audio_hidden_states (`torch.FloatTensor`, *optional*):
+ A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
+ audio_hidden_states of the model produced by the audio encoder and after projecting the last hidden state.
+ """
+
+ audio_hidden_states: Optional[torch.FloatTensor] = None
+
+
+class Gemma3nCausalLMOutputWithPast(PaliGemmaCausalLMOutputWithPast):
+ r"""
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Language modeling loss (for next-token prediction).
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.text_config.vocab_size)`):
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ image_hidden_states (`torch.FloatTensor`, *optional*):
+ A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
+ image_hidden_states of the model produced by the vision encoder after projecting the last hidden state.
+ audio_hidden_states (`torch.FloatTensor`, *optional*):
+ A `torch.FloatTensor` of size `(batch_size, num_audio, sequence_length, hidden_size)`.
+ audio_hidden_states of the model produced by the audio encoder and after projecting the last hidden state.
+ """
+
+ audio_hidden_states: Optional[torch.FloatTensor] = None
+
+
+class Gemma3nRMSNorm(Gemma3RMSNorm):
+ def __init__(self, dim: int, eps: float = 1e-6, with_scale: bool = True):
+ super().__init__(dim, eps=eps)
+ del self.weight
+ self.with_scale = with_scale
+
+ if self.with_scale:
+ self.weight = nn.Parameter(torch.ones(dim))
+ else:
+ self.register_buffer("weight", torch.tensor(1.0), persistent=False)
+
+ def _norm(self, x):
+ return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ # Llama does x.to(float16) * w whilst Gemma2 is (x * w).to(float16)
+ # See https://github.com/huggingface/transformers/pull/29402
+ output = self._norm(x.float()) * self.weight.float()
+ return output.type_as(x)
+
+
+# ==== Audio Encoder ====
+
+
+class Gemma3nAudioRelativePositionEmbedding(nn.Module):
+ def __init__(self, config: Gemma3nAudioConfig):
+ super().__init__()
+ self.config = config
+
+ self.num_heads = self.config.conf_num_attention_heads
+ self.channels = self.config.hidden_size
+ self.head_dim = self.channels // self.num_heads
+ self.max_backward = max(0, self.config.conf_attention_context_left - 1)
+ self.max_forward = self.config.conf_attention_context_right
+
+ self.pos_proj = nn.Linear(self.channels, self.num_heads * self.head_dim, bias=False)
+
+ min_timescale = 1.0
+ max_timescale = 1.0e4
+ num_timescales = self.channels // 2
+ log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / max(num_timescales - 1, 1)
+ inv_timescales = min_timescale * torch.exp(torch.arange(num_timescales) * -log_timescale_increment)
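+ # inv_timescales forms a geometric progression from 1.0 down to 1/max_timescale,
+ # as in the standard sinusoidal position encoding.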
+ self.register_buffer(
+ "inv_timescales",
+ inv_timescales.float().unsqueeze(0).unsqueeze(0),
+ persistent=False,
+ )
+
+ def _get_timing_signal_1d_pos(self, position: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
+ position = position.float().unsqueeze(-1)
+ scaled_time = position * self.inv_timescales.to(device=position.device, dtype=torch.float32)
+ timing_signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=-1)
+ return timing_signal.type(dtype)
+
+ def _relative_shift(
+ self,
+ term_bd_before_shift: torch.Tensor,
+ batch_size: int,
+ num_heads: int,
+ num_query_blocks: int,
+ query_block_size: int,
+ key_context_size: int,
+ max_span_plus_1: int,
+ ) -> torch.Tensor:
+ """Performs the relative shift.
+
+ Args:
+ term_bd_before_shift: Tensor of shape [B, N, U, W, F_span]. batch_size
+ (B), num_heads (N), num_query_blocks (U), query_block_size (W),
+ key_context_size (C = W+L+R), max_span_plus_1 (F_span = L+R+1).
+
+ Returns:
+ Tensor of shape [B, N, U, W, C].
+ """
+ # term_bd_before_shift shape: [B, N, U, W, F_span]
+ # Target shape after shift: [B, N, U, W, C]
+
+ # Padding amount for the last dimension (F_span) to become (C + 1)
+ # C = key_context_size
+ # F_span = max_span_plus_1
+ pad_amount_last_dim = (key_context_size + 1) - max_span_plus_1
+
+ # PyTorch F.pad expects (pad_left, pad_right, pad_top, pad_bottom ...)
+ # We only pad the last dimension on the right.
+ padding_tuple = (0, pad_amount_last_dim)
+
+ term_bd_padded = nn.functional.pad(term_bd_before_shift, padding_tuple)
+ # Shape after pad: [B, N, U, W, C+1]
+
+ # Reshape for slicing (emulating JAX's behavior)
+ # [B, N, U, W * (C+1)]
+ term_bd_reshaped = term_bd_padded.reshape(
+ (
+ batch_size,
+ num_heads,
+ num_query_blocks,
+ query_block_size * (key_context_size + 1),
+ )
+ )
+
+ # Slice to effective [B, N, U, W * C]
+ term_bd_sliced = term_bd_reshaped[:, :, :, : query_block_size * key_context_size]
+
+ # Reshape back to [B, N, U, W, C]
+ term_bd_shifted = term_bd_sliced.reshape(
+ (
+ batch_size,
+ num_heads,
+ num_query_blocks,
+ query_block_size,
+ key_context_size,
+ )
+ )
+ return term_bd_shifted
+
+ def forward(self, queries: torch.Tensor, keys: torch.Tensor) -> torch.Tensor:
+ # queries: [B, U, W, N, H] (batch, num_query_blocks, query_block_size, num_heads, head_dim)
+ # keys: [B, U, C, N, H] (batch, num_query_blocks, key_context_size, num_heads, head_dim)
+ # C = W + L + R (key_context_size)
+ # F_span = L + R + 1 (max_span + 1)
+
+ batch_size, num_query_blocks, query_block_size, num_heads, head_dim = queries.shape
+ _, _, key_context_size, _, _ = keys.shape
+
+ # Relative positions for sinusoidal embeddings: [L, L-1, ..., -R]
+ # Length is L+R+1 = self.max_span + 1
+ pos_indices = torch.arange(self.max_backward, -self.max_forward - 1, -1, device=queries.device).unsqueeze(
+ 0
+ ) # Shape [1, F_span]
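+ # e.g. with conf_attention_context_left=3 and conf_attention_context_right=2,
+ # max_backward=2 and max_forward=2, so pos_indices = [2, 1, 0, -1, -2] (F_span=5).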
+
+ max_span_plus_1 = pos_indices.shape[1] # F_span
+
+ sin_emb_timing_signal = self._get_timing_signal_1d_pos(
+ pos_indices, dtype=queries.dtype
+ ) # Shape [1, F_span, self.channels]
+
+ # Project sinusoidal embeddings: [1, F_span, self.channels] -> [1, F_span, N*H]
+ projected_sin_emb = self.pos_proj(sin_emb_timing_signal)
+ # Reshape to [1, F_span, N, H] then squeeze to [F_span, N, H]
+ sin_emb = projected_sin_emb.reshape(1, max_span_plus_1, self.num_heads, self.head_dim).squeeze(
+ 0
+ ) # Shape [F, N, H]
+
+ # term_ac: Query-Key content interaction
+ # queries: [B, U, W, N, H] -> permute to [B, N, U, W, H] for matmul
+ # keys: [B, U, C, N, H] -> permute to [B, N, U, H, C] for matmul
+ queries_p = queries.permute(0, 3, 1, 2, 4) # [B, N, U, W, H]
+ keys_p_t = keys.permute(0, 3, 1, 4, 2) # [B, N, U, H, C]
+ term_ac = torch.matmul(queries_p, keys_p_t) # [B, N, U, W, C]
+
+ # term_bd: Query-Position interaction
+ # Original einsum: term_bd_unshifted = torch.einsum('buwnh,fnh->bnuwf', queries, sin_emb)
+ # queries shape: [B, U, W, N, H]
+ # sin_emb shape: [F, N, H]
+ # Target output shape: [B, N, U, W, F]
+
+ # Permute queries to [B, N, U, W, H] for easier broadcasting with sin_emb
+ q_permuted = queries.permute(0, 3, 1, 2, 4)
+
+ # Permute sin_emb to [N, H, F] to prepare for matmul
+ # sin_emb original is [F, N, H]
+ s_permuted = sin_emb.permute(1, 2, 0) # Shape: [N, H, F]
+
+ # Reshape queries for matmul: [B, N, U*W, H]
+ q_reshaped = q_permuted.reshape(batch_size, num_heads, num_query_blocks * query_block_size, head_dim)
+
+ # Perform matmul: [B, N, U*W, H] @ [N, H, F]
+ # s_permuted ([N, H, F]) will be broadcast to [B, N, H, F]
+ # Result: [B, N, U*W, F]
+ term_bd_unshifted_matmul = torch.matmul(q_reshaped, s_permuted)
+
+ # Reshape to target [B, N, U, W, F]
+ term_bd_unshifted = term_bd_unshifted_matmul.reshape(
+ batch_size,
+ num_heads,
+ num_query_blocks,
+ query_block_size,
+ max_span_plus_1,
+ )
+
+ # Apply relative shift to term_bd_unshifted
+ term_bd_shifted = self._relative_shift(
+ term_bd_unshifted,
+ batch_size,
+ num_heads,
+ num_query_blocks,
+ query_block_size,
+ key_context_size,
+ max_span_plus_1,
+ ) # Shape [B, N, U, W, C]
+
+ return term_ac + term_bd_shifted
+
+
+class Gemma3nAudioAttention(nn.Module):
+ def __init__(self, config: Gemma3nAudioConfig):
+ super().__init__()
+ self.config = config
+
+ self.num_heads = self.config.conf_num_attention_heads
+ self.hidden_size = self.config.hidden_size
+ self.head_dim = self.hidden_size // self.num_heads
+
+ self.chunk_size = self.config.conf_attention_chunk_size
+ self.max_future_horizon = self.config.conf_attention_context_right
+ self.max_past_horizon = max(0, self.config.conf_attention_context_left - 1)
+ self.attention_logits_soft_cap = self.config.conf_attention_logit_cap
+ self.context_size = self.chunk_size + self.max_past_horizon + self.max_future_horizon
+
+ self.relative_position_embedding = Gemma3nAudioRelativePositionEmbedding(config)
+ self.per_dim_scale = nn.Parameter(torch.zeros((self.head_dim,)))
+
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
+ self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
+ self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
+
+ q_scale = self.head_dim**-0.5
+ r_softplus_0 = 1.0 / torch.nn.functional.softplus(torch.tensor(0.0))
+ self.register_buffer("q_scale", (q_scale * r_softplus_0).clone().detach(), persistent=False)
+
+ lower_causal_mask = torch.tril(
+ torch.ones((self.context_size, self.chunk_size), dtype=torch.bool),
+ diagonal=0,
+ ).T
+ upper_causal_mask = torch.tril(
+ torch.ones((self.chunk_size, self.context_size), dtype=torch.bool),
+ diagonal=self.max_past_horizon + self.max_future_horizon,
+ )
+ local_causal_valid_mask = torch.ones((self.chunk_size, self.context_size), dtype=torch.bool)
+ local_causal_valid_mask = local_causal_valid_mask * lower_causal_mask * upper_causal_mask
+ self.register_buffer("local_causal_valid_mask", local_causal_valid_mask, persistent=False)
+
+ self.register_buffer(
+ "softcap",
+ torch.tensor(self.attention_logits_soft_cap).float(),
+ persistent=False,
+ )
+
+ def _pad_dim1(self, x: torch.Tensor, pad_left: int, pad_right: int) -> torch.Tensor:
+ batch, _, *tail_shape = x.shape
+ left = x.new_zeros((batch, pad_left, *tail_shape))
+ right = x.new_zeros((batch, pad_right, *tail_shape))
+ x = torch.cat([left, x, right], dim=1)
+ return x
+
+ def _convert_to_block(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ """Turns a sequence to non overlapping blocks.
+
+ Args:
+ hidden_states: a tensor of [batch, time, ...].
+
+ Returns:
+ A tensor of [batch, num_blocks, block_size, ...], with necessary paddings,
+ where output[:, i, ...] are x[:, i*block_size:(i+1)*block_size, ...].
+ """
+ shape = hidden_states.shape
+ b, t = shape[:2]
+ num_blocks = (t + self.chunk_size - 1) // self.chunk_size
+
+ if (padding_len := num_blocks * self.chunk_size - t) > 0:
+ hidden_states = self._pad_dim1(hidden_states, 0, padding_len)
+
+ permute_dims = (b, num_blocks, self.chunk_size) + shape[2:]
+ hidden_states = hidden_states.reshape(permute_dims).contiguous()
+ return hidden_states
+
+ def _extract_block_context(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ """Extracts temporal context for every block.
+
+ Args:
+ hidden_states: a tensor of [batch, time, ...].
+
+ Returns:
+ A tensor of [batch, num_blocks, context_size, ...], with necessary paddings,
+ where context_size = block_size + left_context + right_context, and
+ output[:, i, ...] are x[:, start - left_context : end + right_context, ...],
+ with start = i * block_size and end = (i + 1) * block_size.
+ """
+ pad_left = self.max_past_horizon
+ # The JAX equivalent padding for signal.frame with pad_mode='valid' is
+ # (left_context, right_context + block_size - 1) on the time dimension.
+ # _pad_dim1(x, pad_left, pad_right) zero-pads dim 1 (the time dimension for
+ # both [B, T] and [B, T, N, H] inputs) on the left and right, so the
+ # pad_right below matches the JAX effective padding.
+ pad_right = self.max_future_horizon + self.chunk_size - 1
+ hidden_states = self._pad_dim1(hidden_states, pad_left, pad_right)
+
+ frame_len = self.context_size
+ frame_step = self.chunk_size
+
+ # Directly use unfold without the subframe_factor logic
+ # x.unfold(dimension, size, step)
+ # dimension=1 (time dimension, assuming x is [B, T_padded, ...])
+ # size=frame_len (context_size)
+ # step=frame_step (chunk_size)
+ x_unfolded = hidden_states.unfold(dimension=1, size=frame_len, step=frame_step)
+
+ # If x was [B, T_padded], x_unfolded is [B, num_blocks, frame_len]
+ # If x was [B, T_padded, N, H], x_unfolded is [B, num_blocks, N, H, frame_len]
+ # We want to match JAX's typical output for such operations which might be
+ # [B, num_blocks, frame_len, N, H] if N, H are present.
+ # The relative_position_embedding expects keys as [B, U, C, N, H].
+ # If x_unfolded is [B, U, N, H, C(frame_len)], we need to move C.
+ if hidden_states.ndim > 2 and x_unfolded.ndim > 3: # Check if inner dimensions (like N, H) exist
+ # Current shape after unfold for [B, T_pad, N, H] is [B, U, N, H, C]
+ # Target shape for keys in RPE: [B, U, C, N, H]
+ x_unfolded = torch.movedim(x_unfolded, source=-1, destination=2)
+
+ return x_unfolded.contiguous()
+
+ def forward(self, hidden_states: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor:
+ # sl.Dense uses jax.numpy.einsum("...a,abcd->...bcd") and jax.numpy.select()
+ qkv_shape = (*hidden_states.shape[:-1], self.num_heads, self.head_dim)
+ query_states = self.q_proj(hidden_states).reshape(qkv_shape).contiguous()
+ key_states = self.k_proj(hidden_states).reshape(qkv_shape).contiguous()
+ value_states = self.v_proj(hidden_states).reshape(qkv_shape).contiguous()
+
+ per_dim_scale_sp = torch.nn.functional.softplus(self.per_dim_scale)
+
+ broadcast_shape = (1, 1, 1, self.head_dim)
+ per_dim_scale_sp_broadcast = per_dim_scale_sp.view(broadcast_shape)
+ query_states = query_states * self.q_scale * per_dim_scale_sp_broadcast
+
+ batch_size, q_time = query_states.shape[:2]
+
+ query_blocks = self._convert_to_block(query_states)
+ key_blocks = self._extract_block_context(key_states)
+ value_blocks = self._extract_block_context(value_states)
+ num_query_blocks = query_blocks.shape[1]
+
+ # 1. Create a mask indicating originally valid positions.
+ original_valid_mask = ~mask # True for valid, False for padded
+
+ # 2. Extract blocks from this validity mask.
+ extracted_valid_mask_blocks = self._extract_block_context(original_valid_mask)
+
+ # _extract_block_context no longer uses a subframe factor, but keep a defensive
+ # reshape in case the extracted mask blocks come back as [B, U, C/SF, SF]
+ # instead of [B, U, C].
+ # batch_size and num_query_blocks are known from query_blocks; self.context_size is C.
+ if (
+ extracted_valid_mask_blocks.ndim == 4
+ and extracted_valid_mask_blocks.shape[2] * extracted_valid_mask_blocks.shape[3] == self.context_size
+ ):
+ extracted_valid_mask_blocks = extracted_valid_mask_blocks.reshape(
+ batch_size, num_query_blocks, self.context_size
+ )
+ # After potential reshape, ensure it's [B, U, C] if it was from a [B,T] mask.
+ # This assertion might be too strict if _extract_block_context handles higher-rank inputs differently,
+ # but for the mask case, this should hold.
+ if extracted_valid_mask_blocks.shape != (
+ batch_size,
+ num_query_blocks,
+ self.context_size,
+ ):
+ raise ValueError(
+ "Shape of extracted_valid_mask_blocks"
+ f" {extracted_valid_mask_blocks.shape} is not ({batch_size},"
+ f" {num_query_blocks}, {self.context_size}) after potential reshape."
+ )
+
+ # 3. Expand dimensions for broadcasting with logits and causal mask.
+ # Target shape for broadcasting with logits [B,N,U,W,C]
+ # extracted_valid_mask_blocks to [B, 1, U, 1, C]
+ condition_from_input_validity = extracted_valid_mask_blocks.unsqueeze(1).unsqueeze(-2)
+
+ # self.local_causal_valid_mask is [W, C], True where allowed by local window.
+ # Expand to [1, 1, 1, W, C]
+ condition_from_causality = self.local_causal_valid_mask.unsqueeze(0).unsqueeze(0).unsqueeze(0)
+
+ # 4. Combine the two conditions.
+ # final_condition will be True where a key is *both* originally valid *and* causally accessible.
+ # Broadcasts to [B, 1, U, W, C]
+ final_condition_for_where = torch.logical_and(
+ condition_from_input_validity,
+ condition_from_causality.to(condition_from_input_validity.device), # Ensure same device
+ )
+
+ # Embed queries and keys
+ logits = self.relative_position_embedding(query_blocks, key_blocks)
+
+ # Apply attention logit softcap
+ # Ensure softcap is on the same device as logits
+ softcap_val = self.softcap.to(logits.device)
+ logits = logits / softcap_val
+ logits = torch.tanh(logits)
+ logits = logits * softcap_val
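+ # Equivalent to logits = softcap * tanh(logits / softcap), which smoothly
+ # bounds the attention logits to (-softcap, softcap).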
+
+ # Apply the combined mask.
+ # final_condition_for_where will broadcast with logits [B,N,U,W,C]
+ logits = torch.where(final_condition_for_where, logits, torch.finfo(logits.dtype).min)
+ probabilities = torch.nn.functional.softmax(logits, dim=-1, dtype=torch.float32).to(dtype=value_blocks.dtype)
+
+ # context_vectors is adapted from jax.numpy.einsum("BNuwc,BucNH->BuwNH", ...)
+ b_dim, n_dim, u_dim, w_dim, c_dim = probabilities.shape
+ h_dim = value_blocks.shape[-1]
+ prob_bun = probabilities.permute(0, 2, 1, 3, 4).reshape(-1, w_dim, c_dim)
+ v_bun = value_blocks.permute(0, 1, 3, 2, 4).reshape(-1, c_dim, h_dim)
+ result_bmm = torch.bmm(prob_bun, v_bun)
+ context_vectors = result_bmm.reshape(b_dim, u_dim, n_dim, w_dim, h_dim).permute(0, 1, 3, 2, 4)
+ context_vectors = context_vectors.reshape(
+ (
+ batch_size,
+ num_query_blocks * self.chunk_size,
+ self.num_heads,
+ self.head_dim,
+ )
+ )
+ context_vectors = context_vectors[:, :q_time]
+
+ return context_vectors
+
+
+class Gemma3nAudioCumulativeGroupNorm(nn.Module):
+ """Applies Group Normalization cumulatively over the time dimension.
+
+ This layer normalizes the input by calculating the mean and variance
+ cumulatively over the time dimension (dim 1). The statistics are computed
+ over all feature dimensions (specified by `feature_dims` and `num_channels`)
+ for elements marked as valid by the optional `mask`.
+
+ If a `mask` is provided (True for valid, False for invalid/padded),
+ invalid time steps do not contribute to the statistics calculation, and
+ their corresponding output values are zeroed out.
+
+ Scale and bias, if enabled, are applied per-channel (last dimension).
+ This behavior is similar to JAX's `GroupNormalization` with `num_groups=1`
+ and `cumulative=True`.
+ """
+
+ def __init__(
+ self,
+ num_channels: int, # Number of channels (size of the last dimension)
+ feature_dims: Sequence[int], # Sizes of non-channel feature dimensions, e.g., (H, W) for input [B,T,H,W,C]
+ eps: float = 1e-3,
+ ):
+ super().__init__()
+ self.num_channels = num_channels
+ self.feature_dims = tuple(feature_dims)
+ self.eps = eps
+
+ # Scale parameter depends only on the channel dimension
+ self.weight = nn.Parameter(torch.ones(num_channels))
+
+ # Axes for normalization: all dimensions except Batch (0) and Time (1).
+ # For input [B, T, *feature_dims, C], these are dims from 2 onwards.
+ self.reduction_axes = tuple(range(2, 2 + len(self.feature_dims) + 1))
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ """Applies cumulative group norm, optionally using a mask.
+
+ Args:
+ hidden_states: Input tensor, shape [B, T, *feature_dims, C].
+
+ Returns:
+ Normalized tensor with the same shape as x.
+ """
+ expected_input_suffix = self.feature_dims + (self.num_channels,)
+ if hidden_states.shape[2:] != expected_input_suffix:
+ raise ValueError(
+ f"Input tensor shape suffix {hidden_states.shape[2:]} does not match expected"
+ f" suffix (feature_dims + num_channels) {expected_input_suffix}"
+ )
+
+ input_dtype = hidden_states.dtype
+ # Calculations are performed in float32 for numerical stability.
+ calc_dtype = torch.float32
+ x_calc = hidden_states.to(calc_dtype)
+
+ # Prepare a broadcastable validity mask (`mask_calc`). This implementation
+ # takes no mask argument, so every element is treated as valid and mask_calc
+ # is simply all ones; it is kept so the element-count bookkeeping below
+ # mirrors the masked formulation.
+ mask_calc = torch.ones_like(x_calc, dtype=calc_dtype)
+
+ # Cumulative Statistics Calculation
+ # 1. Sum of values over reduction axes at each time step.
+ sum_values_at_t = torch.sum(x_calc, dim=self.reduction_axes, keepdim=True)
+ # 2. Cumulative sum of values over time.
+ cum_sum_values = torch.cumsum(sum_values_at_t, dim=1)
+
+ # 3. Count of valid elements in the normalization group at each time step.
+ # (A "group" here consists of all features at a given Batch, Time).
+ elements_in_group_at_t = torch.sum(mask_calc, dim=self.reduction_axes, keepdim=True)
+ # 4. Cumulative count of valid elements over time.
+ cum_count_elements = torch.cumsum(elements_in_group_at_t, dim=1)
+ # Avoid division by zero if all preceding elements were masked.
+ safe_cum_count_elements = torch.clamp(cum_count_elements, min=1.0)
+
+ # 5. Cumulative mean.
+ cum_mean = cum_sum_values / safe_cum_count_elements
+
+ # 6. Sum of squared differences from the cumulative mean.
+ # With every element valid, no masking of the squared terms is needed.
+ # The difference at step t is taken against the cumulative mean up to t.
+ squared_diff_from_mean = (x_calc - cum_mean).pow(2)
+ sum_sq_diff_at_t = torch.sum(squared_diff_from_mean, dim=self.reduction_axes, keepdim=True)
+
+ # 7. Cumulative sum of squared differences over time.
+ cum_sum_sq_diff = torch.cumsum(sum_sq_diff_at_t, dim=1)
+
+ # 8. Cumulative variance.
+ cum_variance = cum_sum_sq_diff / safe_cum_count_elements
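+ # i.e. cum_variance[b, t] = (1 / N_t) * sum_{s<=t} sum_over_features (x[b, s] - cum_mean[b, s])**2,
+ # where N_t is the number of feature elements seen up to step t. The difference
+ # at each step s uses the cumulative mean up to s (a streaming estimate), not
+ # the mean at the final step t.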
+
+ # Normalize the input using the calculated cumulative statistics:
+ # (x - E[x]) / sqrt(Var[x] + eps)
+ normalized_x = (x_calc - cum_mean) * torch.rsqrt(cum_variance + self.eps)
+
+ # Apply affine transformation (scale and bias) if enabled.
+ # Scale and bias are applied per-channel (last dimension).
+ scale = self.weight.to(calc_dtype)
+ # Reshape for broadcasting: [C] -> [1, ..., 1, C]
+ scale_view_shape = [1] * (hidden_states.dim() - 1) + [self.num_channels]
+ normalized_x = normalized_x * scale.view(scale_view_shape)
+
+ # Multiply by the validity mask (all ones here). In the masked formulation this
+ # zeroes out padded/invalid time steps; here it is a no-op kept for parity.
+ final_output = normalized_x * mask_calc
+
+ return final_output.to(input_dtype)
+
+
+class Gemma3nAudioSSCPConvBlock(nn.Module):
+ """A single convolution block for the SubSampleConvProjection.
+
+ This block consists of a 2D convolution, followed by CumulativeGroupNorm,
+ and a ReLU activation. It handles manual padding for the convolution.
+ """
+
+ def __init__(
+ self,
+ config: Gemma3nAudioConfig,
+ idx: int,
+ input_freq_dim: int, # Changed from input_spatial_dim
+ manual_padding: tuple[int, int, int, int] = (0, 0, 0, 0),
+ ):
+ super().__init__()
+ self.config = config
+ self.manual_padding = manual_padding
+
+ # in_channels is 1 for the first block, or C_out from previous block's conv
+ in_channels = 1 if idx == 0 else self.config.sscp_conv_channel_size[idx - 1]
+ out_channels = self.config.sscp_conv_channel_size[idx]
+ kernel_h, kernel_w = self.config.sscp_conv_kernel_size[idx]
+ stride_h, stride_w = self.config.sscp_conv_stride_size[idx]
+
+ self.conv = nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=(
+ kernel_h,
+ kernel_w,
+ ), # Kernel (kH, kW) operates on (Time, Freq_dim)
+ stride=(stride_h, stride_w),
+ padding=(0, 0), # Manual padding is used
+ bias=False,
+ )
+
+ # Calculate output frequency dimension (f_out_conv) after this convolution.
+ # input_freq_dim is the unpadded width (feature dimension).
+ # self.manual_padding is (pad_F_left, pad_F_right, pad_T_top, pad_T_bottom)
+ f_in_padded = input_freq_dim + self.manual_padding[0] + self.manual_padding[1]
+ f_out_conv = (f_in_padded - kernel_w) // stride_w + 1
+
+ self.norm = Gemma3nAudioCumulativeGroupNorm(
+ num_channels=out_channels, # Channels of the conv output
+ feature_dims=(f_out_conv,), # The frequency dimension size after conv
+ eps=self.config.sscp_conv_group_norm_eps,
+ )
+
+ self.activation = nn.ReLU()
+
+ def forward(self, audio_encodings: torch.Tensor) -> torch.Tensor:
+ # Input audio_encodings is [B, C_in, T_in, F_in] (e.g., C_in=1)
+ # manual_padding is (pad_F_left, pad_F_right, pad_T_top, pad_T_bottom)
+ # F.pad applies to last two dims: F_in then T_in
+ audio_encodings_padded = F.pad(audio_encodings, self.manual_padding, mode="constant", value=0.0).to(
+ self.conv.weight.dtype
+ )
+ # Expected padded shape for F_in, k_w=3, pad_F=(1,1) -> F_padded = F_in+2
+ # Expected padded shape for T_in, k_h=3, pad_T=(0,2) -> T_padded = T_in+2
+ audio_encodings_conv = self.conv(audio_encodings_padded)
+ # Expected conv output shape: [B, C_out, T_out, F_out]
+ # Input to norm is [B, T_out, F_out, C_out]
+ x_for_norm = audio_encodings_conv.permute(0, 2, 3, 1).contiguous()
+ x_normed = self.norm(x_for_norm)
+ # Output of norm is [B, T_out, F_out, C_out], permute back to [B, C_out, T_out, F_out]
+ audio_encodings_normed = x_normed.permute(0, 3, 1, 2).contiguous()
+ return self.activation(audio_encodings_normed)
+
+
+class Gemma3nAudioSubSampleConvProjection(nn.Module):
+ def __init__(self, config: Gemma3nAudioConfig):
+ super().__init__()
+ self.config = config
+
+ current_f_for_block_input = config.input_feat_size # Start with original feature dim
+ calculated_block_padding = []
+ calculated_f_out_dims = [] # Tracking frequency dimension output sizes
+
+ for i in range(2): # Assuming 2 conv layers as per sscp_conv_... arrays
+ kernel_h, kernel_w = config.sscp_conv_kernel_size[i]
+ stride_h, stride_w = config.sscp_conv_stride_size[i]
+
+ # Padding for Time (Height for Conv2d) - REVERSE_CAUSAL like
+ # JAX 'reverse_causal' padding is (0, kernel_size - 1)
+ pad_t_top = 0
+ pad_t_bottom = kernel_h - 1
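+ # e.g. kernel_h=3 gives time padding (0, 2): each output frame is computed from
+ # the current and following input frames only, with no padding on the past side.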
+
+ # Frequency Padding (Width for Conv2d)
+ # Based on JAX effective padding (1,1) for F_in=10, K_w=3, S_w=2
+ # and the successful test configuration.
+ # If kernel/stride/input_freq for frequency changes, this might need re-evaluation
+ # to match generic JAX 'SAME' behavior if it differs.
+ pad_f_left = 1
+ pad_f_right = 1
+
+ manual_padding_tuple = (
+ pad_f_left,
+ pad_f_right,
+ pad_t_top,
+ pad_t_bottom,
+ )
+ calculated_block_padding.append(manual_padding_tuple)
+
+ # Calculate output frequency dimension after this convolution
+ # This uses the actual padding applied and kernel/stride.
+ f_in_padded = current_f_for_block_input + pad_f_left + pad_f_right
+ f_out_after_conv = (f_in_padded - kernel_w) // stride_w + 1 # Assuming dilation_w = 1
+ calculated_f_out_dims.append(f_out_after_conv)
+ current_f_for_block_input = f_out_after_conv
+
+ self.conv_0 = Gemma3nAudioSSCPConvBlock(
+ idx=0,
+ input_freq_dim=config.input_feat_size, # Pass original feature dim
+ config=config,
+ manual_padding=calculated_block_padding[0],
+ )
+ self.conv_1 = Gemma3nAudioSSCPConvBlock(
+ idx=1,
+ input_freq_dim=calculated_f_out_dims[0], # Output freq dim from conv_0
+ config=config,
+ manual_padding=calculated_block_padding[1],
+ )
+ final_c_out = config.sscp_conv_channel_size[-1]
+ final_f_out = calculated_f_out_dims[-1] # Final frequency dimension
+ self.input_proj_in_features = final_c_out * final_f_out
+ self.input_proj_linear = nn.Linear(self.input_proj_in_features, self.config.hidden_size, bias=False)
+
+ def forward(self, audio_encodings: torch.Tensor) -> torch.Tensor:
+ # audio_encodings is [B, T, F_in]
+ # Reshape to [B, 1, T, F_in] (Batch, Channels=1, Height=Time, Width=F_in)
+ audio_encodings_reshaped = audio_encodings.unsqueeze(1)
+ x = self.conv_0(audio_encodings_reshaped)
+ x = self.conv_1(x)
+ # x from conv_1 is [B, C_out_1, T_out_1, F_out_1]
+ b, c_out, t_out, f_out = x.shape
+ # Permute to [B, T_out_1, F_out_1, C_out_1] then flatten F_out_1 and C_out_1
+ x_permuted = x.permute(0, 2, 3, 1).contiguous()
+ output_flattened = x_permuted.view(b, t_out, f_out * c_out)
+ output = self.input_proj_linear(output_flattened)
+ return output
+
+
+class Gemma3nAudioConformerAttention(nn.Module):
+ def __init__(self, config: Gemma3nAudioConfig):
+ super().__init__()
+ self.config = config
+ self.post_in_features = self.config.hidden_size
+ self.register_buffer("gradient_clipping", torch.tensor(self.config.gradient_clipping), persistent=False)
+ self.pre_attn_norm = Gemma3nRMSNorm(self.config.hidden_size)
+ self.attn = Gemma3nAudioAttention(config)
+ self.post = nn.Linear(self.post_in_features, self.config.hidden_size, bias=False)
+ self.post_norm = Gemma3nRMSNorm(self.config.hidden_size)
+
+ def forward(self, audio_encodings: torch.Tensor, audio_mel_mask: torch.BoolTensor) -> torch.Tensor:
+ audio_encodings_input_to_attn = audio_encodings
+ audio_encodings = torch.clamp(audio_encodings, -self.gradient_clipping, self.gradient_clipping)
+ audio_encodings_norm = self.pre_attn_norm(audio_encodings)
+ # Output of self.attn is [B, T, NumHeads, HeadDim]
+ audio_encodings_attn_out = self.attn(audio_encodings_norm, audio_mel_mask)
+
+ # Reshape from [B, T, NumHeads, HeadDim] to [B, T, NumHeads * HeadDim]
+ # NumHeads * HeadDim = hidden_size
+ b, t, num_heads, head_dim = audio_encodings_attn_out.shape
+ audio_encodings_reshaped = audio_encodings_attn_out.reshape(b, t, num_heads * head_dim)
+
+ audio_encodings = self.post(audio_encodings_reshaped)
+ audio_encodings = torch.clamp(audio_encodings, -self.gradient_clipping, self.gradient_clipping)
+ return audio_encodings_input_to_attn + self.post_norm(audio_encodings)
+
+
+class Gemma3nAudioConformerFeedForward(nn.Module):
+ def __init__(self, config: Gemma3nAudioConfig):
+ super().__init__()
+ self.config = config
+
+ self.register_buffer("gradient_clipping", torch.tensor(self.config.gradient_clipping), persistent=False)
+
+ self.pre_layer_norm = Gemma3nRMSNorm(self.config.hidden_size)
+ self.ffw_layer_1 = nn.Linear(self.config.hidden_size, self.config.hidden_size * 4, bias=False)
+ self.ffw_layer_2 = nn.Linear(self.config.hidden_size * 4, self.config.hidden_size, bias=False)
+ self.post_layer_norm = Gemma3nRMSNorm(self.config.hidden_size)
+ self.post_layer_scale = torch.tensor(self.config.conf_residual_weight)
+
+ def forward(self, audio_encodings: torch.Tensor) -> torch.Tensor:
+ residual = audio_encodings
+ audio_encodings = torch.clamp(audio_encodings, -self.gradient_clipping, self.gradient_clipping)
+ audio_encodings = self.pre_layer_norm(audio_encodings)
+ audio_encodings: torch.Tensor = self.ffw_layer_1(audio_encodings)
+ audio_encodings = nn.functional.silu(audio_encodings)
+ audio_encodings: torch.Tensor = self.ffw_layer_2(audio_encodings)
+ audio_encodings = torch.clamp(audio_encodings, -self.gradient_clipping, self.gradient_clipping)
+ audio_encodings = self.post_layer_norm(audio_encodings)
+ return residual + (audio_encodings * self.post_layer_scale)
+
+
+class Gemma3nAudioConformerLightConv1d(nn.Module):
+ def __init__(self, config: Gemma3nAudioConfig):
+ super().__init__()
+ self.config = config
+
+ self.pre_layer_norm = Gemma3nRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
+ self.linear_start = nn.Linear(self.config.hidden_size, self.config.hidden_size * 2, bias=False)
+ self.depthwise_conv1d = nn.Conv1d(
+ in_channels=self.config.hidden_size,
+ out_channels=self.config.hidden_size,
+ kernel_size=self.config.conf_conv_kernel_size,
+ stride=1,
+ padding=0, # Manual causal padding
+ groups=self.config.hidden_size, # Depthwise
+ bias=False,
+ )
+ self.register_buffer("gradient_clipping", torch.tensor(self.config.gradient_clipping), persistent=False)
+ self.conv_norm = Gemma3nRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
+ self.linear_end = nn.Linear(self.config.hidden_size, self.config.hidden_size, bias=False)
+
+ self.causal_padding = self.config.conf_conv_kernel_size - 1
+
+ def forward(self, audio_encodings: torch.Tensor) -> torch.Tensor:
+ audio_encodings_residual = audio_encodings # Save for residual connection
+
+ audio_encodings = self.pre_layer_norm(audio_encodings)
+ audio_encodings = self.linear_start(audio_encodings)
+ audio_encodings = torch.nn.functional.glu(audio_encodings, dim=-1)
+ # Permute for Conv1d: [B, T, D] -> [B, D, T]
+ audio_encodings_permuted = audio_encodings.permute(0, 2, 1)
+ # Apply manual causal padding
+ audio_encodings_permuted_padded = F.pad(audio_encodings_permuted, (self.causal_padding, 0))
+ audio_encodings = self.depthwise_conv1d(audio_encodings_permuted_padded)
+ # Permute back: [B, D, T_out] -> [B, T_out, D]
+ audio_encodings = audio_encodings.permute(0, 2, 1)
+ audio_encodings = torch.clamp(audio_encodings, -self.gradient_clipping, self.gradient_clipping)
+ audio_encodings = self.conv_norm(audio_encodings)
+ audio_encodings = nn.functional.silu(audio_encodings)
+ audio_encodings = self.linear_end(audio_encodings)
+ output = audio_encodings + audio_encodings_residual
+ return output
+
+
+class Gemma3nAudioConformerBlock(nn.Module):
+ def __init__(self, config: Gemma3nAudioConfig):
+ super().__init__()
+ self.config = config
+
+ self.ffw_layer_start = Gemma3nAudioConformerFeedForward(self.config)
+ self.attention = Gemma3nAudioConformerAttention(self.config)
+ self.lconv1d = Gemma3nAudioConformerLightConv1d(self.config)
+ self.ffw_layer_end = Gemma3nAudioConformerFeedForward(self.config)
+ self.register_buffer("gradient_clipping", torch.tensor(self.config.gradient_clipping), persistent=False)
+ self.norm = Gemma3nRMSNorm(self.config.hidden_size)
+
+ def forward(self, audio_encodings: torch.Tensor, audio_mel_mask: torch.BoolTensor) -> torch.Tensor:
+ audio_encodings = self.ffw_layer_start(audio_encodings)
+ audio_encodings = self.attention(audio_encodings, audio_mel_mask)
+ validity_mask_for_lconv = ~audio_mel_mask # True for valid
+ audio_encodings_for_lconv_input = audio_encodings * validity_mask_for_lconv.unsqueeze(-1).to(
+ audio_encodings.dtype
+ )
+ audio_encodings = self.lconv1d(audio_encodings_for_lconv_input)
+
+ audio_encodings = self.ffw_layer_end(audio_encodings)
+ audio_encodings = torch.clamp(audio_encodings, -self.gradient_clipping, self.gradient_clipping)
+ output = self.norm(audio_encodings)
+ return output
+
+
+class Gemma3nAudioEncoder(PreTrainedModel):
+ """
+ An audio encoder based on the [Universal Speech Model](https://huggingface.co/papers/2303.01037) architecture.
+ """
+
+ config: Gemma3nAudioConfig
+
+ main_input_name = "audio_mel"
+
+ def __init__(self, config: Gemma3nAudioConfig):
+ super().__init__(config)
+ self.config = config
+
+ self.subsample_conv_projection = Gemma3nAudioSubSampleConvProjection(config)
+ self.conformer = nn.ModuleList(
+ [Gemma3nAudioConformerBlock(config) for _ in range(config.conf_num_hidden_layers)]
+ )
+
+ def forward(
+ self, audio_mel: torch.Tensor, audio_mel_mask: torch.BoolTensor
+ ) -> tuple[torch.Tensor, torch.BoolTensor]:
+ """Encodes a batch of MELs.
+
+ Args:
+ audio_mel: a torch.Tensor of shape [batch, num_frames, mel_bins].
+ audio_mel_mask: a torch.BoolTensor of shape [batch, num_frames]; True marks padded frames.
+
+ Returns:
+ audio_encodings: a torch.Tensor of shape
+ `[batch_size, self.config.audio_soft_tokens_per_image,
+ self.config.audio_config.hidden_size]`
+ audio_mel_mask: a torch.BoolTensor of shape [batch, num_frames].
+ """
+ audio_encodings = self.subsample_conv_projection(audio_mel) # audio_encodings: [B, T_sub, D]
+
+ # Subsample the input audio_mel_mask to match the time dimension of audio_encodings (T_sub)
+ t_sub = audio_encodings.shape[1]
+
+ time_stride_product = 1
+ for stride_pair_idx in range(len(self.config.sscp_conv_stride_size)):
+ time_stride_product *= self.config.sscp_conv_stride_size[stride_pair_idx][0]
+
+ # Create indices for gathering from the original mask.
+ # These indices map to original time steps corresponding to the start of each
+ # receptive field in the subsampled output.
+ indices = torch.arange(t_sub, device=audio_mel_mask.device) * time_stride_product
+ indices = torch.clamp(indices, max=audio_mel_mask.shape[1] - 1) # Ensure indices are valid
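+ # For illustration: if both conv stages stride the time axis by 2, the stride
+ # product is 4 and indices = [0, 4, 8, ...]; each subsampled frame inherits the
+ # mask value of the first original frame in its receptive field.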
+
+ # Expand indices for batch compatibility if B > 1 and indices is 1D.
+ if audio_mel_mask.ndim > 1 and indices.ndim == 1:
+ indices = indices.unsqueeze(0).expand(audio_mel_mask.shape[0], -1) # [B, T_sub]
+ elif (
+ audio_mel_mask.ndim == indices.ndim
+ and audio_mel_mask.shape[0] == 1
+ and indices.shape[0] != 1
+ and t_sub == indices.shape[0]
+ ):
+ # Handle case where B=1 but indices became [T_sub] instead of [1, T_sub]
+ indices = indices.unsqueeze(0)
+
+ current_mask = torch.gather(audio_mel_mask, 1, indices) # [B, T_sub]
+
+ for block in self.conformer:
+ audio_encodings = block(audio_encodings, current_mask) # Pass the processed mask
+
+ if self.config.conf_reduction_factor > 1:
+ audio_encodings = audio_encodings[:, :: self.config.conf_reduction_factor]
+ # Reduce the mask as well
+ current_mask = current_mask[:, :: self.config.conf_reduction_factor]
+
+ audio_encodings = audio_encodings.masked_fill(current_mask.unsqueeze(-1), 0.0)
+ return audio_encodings, current_mask
+
+
+# ==== Language Model ====
+
+
+class Gemma3nTextScaledWordEmbedding(Gemma3TextScaledWordEmbedding):
+ pass
+
+
+class Gemma3nTextLaurelBlock(nn.Module):
+ """Learned Augmented Residual Layer"""
+
+ def __init__(self, config: Gemma3nTextConfig):
+ super().__init__()
+ self.config = config
+
+ self.linear_left = nn.Linear(self.config.hidden_size, self.config.laurel_rank, bias=False)
+ self.linear_right = nn.Linear(self.config.laurel_rank, self.config.hidden_size, bias=False)
+ self.post_laurel_norm = Gemma3nRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ laurel_hidden_states: torch.Tensor = self.linear_left(hidden_states)
+ laurel_hidden_states: torch.Tensor = self.linear_right(laurel_hidden_states)
+ normed_laurel_hidden_states = self.post_laurel_norm(laurel_hidden_states)
+ return hidden_states + normed_laurel_hidden_states
+
+
+class Gemma3nTextMLP(Gemma2MLP):
+ def __init__(self, config: Gemma3nTextConfig, layer_idx: int = 0):
+ super().__init__(config)
+ self.intermediate_size = config.intermediate_size[layer_idx]
+ self.activation_sparsity = config.activation_sparsity_pattern[layer_idx]
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ gate_proj = self.gate_proj(hidden_states)
+ if self.activation_sparsity > 0.0:
+ gate_proj = self._gaussian_topk(gate_proj)
+ activations = self.act_fn(gate_proj)
+ up_proj = self.up_proj(hidden_states)
+ down_proj = self.down_proj(activations * up_proj)
+ return down_proj
+
+ def _gaussian_topk(self, inputs: torch.Tensor) -> torch.Tensor:
+ target_sparsity_tensor = torch.tensor(self.activation_sparsity, dtype=torch.float32, device=inputs.device)
+ # normal_dist and std_multiplier are adapted from jax.scipy.stats.norm.ppf().
+ #
+ # References:
+ # * https://docs.jax.dev/en/latest/_autosummary/jax.scipy.stats.norm.ppf.html
+ # * https://pytorch.org/docs/stable/distributions.html#torch.distributions.normal.Normal
+ # * https://pytorch.org/docs/stable/distributions.html#torch.distributions.transformed_distribution.TransformedDistribution.icdf
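+ #
+ # For illustration: with activation_sparsity = 0.95, std_multiplier = Phi^-1(0.95) ≈ 1.645,
+ # so (assuming roughly Gaussian activations) only about the top 5% of gate
+ # activations per token stay non-zero after the ReLU below.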
+ normal_dist = torch.distributions.normal.Normal(0, 1)
+ std_multiplier: torch.Tensor = normal_dist.icdf(target_sparsity_tensor)
+ std_multiplier = std_multiplier.type(inputs.dtype)
+ inputs_mean = torch.mean(inputs, dim=-1, keepdim=True)
+ inputs_std = torch.std(inputs, dim=-1, keepdim=True, unbiased=False)
+ cutoff_x = inputs_mean + inputs_std * std_multiplier
+ return nn.functional.relu(inputs - cutoff_x)
+
+
+class Gemma3nTextAltUp(nn.Module):
+ """Alternating Updates (AltUp)
+
+ The AltUp module wraps transformer layers. The `predict` step modifies the
+ input to the transformer layer, and the `correct` step propagates the output
+ of the transformer layer to the sparsely updated dimensions.
+
+ See more in the research paper:
+
+ https://proceedings.neurips.cc/paper_files/paper/2023/file/f2059277ac6ce66e7e5543001afa8bb5-Paper-Conference.pdf
+ """
+
+ def __init__(self, config: Gemma3nTextConfig):
+ super().__init__()
+ self.config = config
+ self.correct_output_scale = nn.Parameter(torch.zeros(self.config.hidden_size))
+ self.correction_coefs = nn.Linear(self.config.altup_num_inputs, self.config.altup_num_inputs, bias=False)
+ self.prediction_coefs = nn.Linear(self.config.altup_num_inputs, self.config.altup_num_inputs**2, bias=False)
+ self.modality_router = nn.Linear(self.config.hidden_size, self.config.altup_num_inputs, bias=False)
+ self.router_norm = Gemma3nRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
+ self.register_buffer("router_input_scale", torch.tensor(self.config.hidden_size**-1.0), persistent=False)
+
+ def compute_router_modalities(self, x: torch.Tensor) -> torch.Tensor:
+ router_inputs = self.router_norm(x) * self.router_input_scale
+ routed = self.modality_router(router_inputs)
+ return torch.tanh(routed.float()).type_as(x)
+
+ def predict(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ """Predicts the output of a layer using a trainable map.
+
+ Args:
+ hidden_states: A 4D tensor of shape `[num_altup_inputs, batch_size, num_tokens, hidden_size]` derived by
+ stacking the input embeddings and preprocessing the last `num_altup_inputs - 1` matrices.
+
+ Returns:
+ A 4D tensor of shape `[num_altup_inputs, batch_size, num_tokens, hidden_size]` containing the predictions.
+ """
+ modalities = self.compute_router_modalities(hidden_states[self.config.altup_active_idx])
+
+ if self.training and self.config.altup_coef_clip is not None:
+ self.prediction_coefs.weight.data.clamp_(-self.config.altup_coef_clip, self.config.altup_coef_clip)
+
+ # Project and then transpose the per-token coefficient matrices so that matmul gives the correct result
+ all_coefs: torch.Tensor = (
+ self.prediction_coefs(modalities)
+ .reshape(*modalities.shape[:-1], self.config.altup_num_inputs, self.config.altup_num_inputs)
+ .permute(0, 1, 3, 2)
+ )
+
+ # permute hidden_states to [batch_size, num_tokens, hidden_size, altup_num_inputs]
+ predictions = torch.matmul(hidden_states.permute(1, 2, 3, 0), all_coefs)
+ predictions = predictions.permute(3, 0, 1, 2) # undo the permute
+ predictions += hidden_states # add the original input
+ return predictions.contiguous().type_as(hidden_states)
+
+ def correct(self, predictions: torch.Tensor, activated: torch.Tensor) -> torch.Tensor:
+ """Corrects the predictions relative to the
+
+ Args:
+ predictions: A 4D tensor of shape `[num_altup_inputs, batch_size, num_tokens, hidden_size]` derived by
+ stacking the input embeddings and preprocessing the last `num_altup_inputs - 1` matrices.
+ activated: A 3D tensor of shape `[batch_size, num_tokens, hidden_size]` containing the activated inputs.
+
+ Returns:
+ A 4D tensor of shape `[num_altup_inputs, batch_size, num_tokens, hidden_size]` correcting the original
+ predictions relative to the activated input embeddings.
+ """
+ modalities = self.compute_router_modalities(activated)
+ innovation = activated - predictions[self.config.altup_active_idx] # (batch, num_tokens, hidden_size)
+ innovation = innovation.repeat(self.config.altup_num_inputs, 1, 1, 1) # Repeat on dim0 to match predictions
+
+ if self.config.altup_coef_clip is not None:
+ self.correction_coefs.weight.data.clamp_(-self.config.altup_coef_clip, self.config.altup_coef_clip)
+
+ # all_coefs adapted from jax.numpy.einsum("...p,pi->...i", ...)
+ # Permute to (altup_num_inputs, batch_size, num_tokens) as the last dim is a scalar applied to each altup input
+ # and expand on dim1 for broadcastability
+ all_coefs: torch.Tensor = self.correction_coefs(modalities) + 1.0
+ all_coefs = all_coefs.permute(2, 0, 1).unsqueeze(-1)
+
+ corrected = torch.mul(innovation, all_coefs)
+ corrected += predictions # add the original input
+ return corrected.contiguous().type_as(activated)
+
+ def forward(self, corrected: torch.Tensor) -> torch.Tensor:
+ """
+ This is only defined as the `forward` so that accelerate hooks can correctly move `correct_output_scale`
+ (which is an nn.Parameter, not a Module) between devices when offloading. It is otherwise only used in
+ `scale_corrected_output`.
+ """
+ return (corrected.type_as(self.correct_output_scale) * self.correct_output_scale).type_as(corrected)
+
+ def scale_corrected_output(self, corrected: torch.Tensor) -> torch.Tensor:
+ """Scales the provided 3D tensor of shape [batch_size, num_tokens, hidden_size]."""
+ return self.forward(corrected)
+
+
+class Gemma3nTextRotaryEmbedding(Gemma2RotaryEmbedding):
+ pass
+
+
+def apply_rotary_pos_emb(
+ x: torch.Tensor,
+ cos: torch.Tensor,
+ sin: torch.Tensor,
+ position_ids: Optional[torch.Tensor] = None,
+ unsqueeze_dim: int = 1,
+):
+ """Applies Rotary Position Embedding to the query and key tensors.
+
+ Args:
+ x (`torch.Tensor`): The tensor to embed.
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
+ position_ids (`torch.Tensor`, *optional*):
+ Deprecated and unused.
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+ Returns:
+ `torch.Tensor`: the input tensor rotated using the Rotary Position Embedding.
+ """
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
+ return (x * cos) + (rotate_half(x) * sin)
+
+
+class Gemma3nTextAttention(Gemma3Attention):
+ def __init__(self, config: Gemma3nTextConfig, layer_idx: int):
+ super().__init__(config, layer_idx)
+ self.is_causal = True
+ del self.attn_logit_softcapping
+ del self.scaling
+ self.v_norm = Gemma3nRMSNorm(dim=config.head_dim, eps=config.rms_norm_eps, with_scale=False)
+
+ first_kv_shared_layer_idx = self.config.num_hidden_layers - self.config.num_kv_shared_layers
+ self.is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx > 0
+ prev_layers = config.layer_types[:first_kv_shared_layer_idx]
+ if self.is_kv_shared_layer:
+ # For shared layers, find the last non-shared layer of the same type before sharing starts
+ self.kv_shared_layer_index = len(prev_layers) - 1 - prev_layers[::-1].index(config.layer_types[layer_idx])
+ self.store_full_length_kv = False
+ else:
+ self.kv_shared_layer_index = None
+ # For non-shared layers, store full-length kv if this is the last non-shared layer of its type
+ self.store_full_length_kv = layer_idx == len(prev_layers) - 1 - prev_layers[::-1].index(
+ config.layer_types[layer_idx]
+ )
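+ # For illustration (hypothetical sizes): with num_hidden_layers=30 and
+ # num_kv_shared_layers=10, layers 20-29 are KV-shared; a shared sliding_attention
+ # layer reuses the keys/values produced by the last sliding_attention layer among
+ # layers 0-19, and likewise for full_attention layers.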
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ past_key_values: Optional[Cache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
+ input_shape = hidden_states.shape[:-1]
+ hidden_shape = (*input_shape, -1, self.config.head_dim)
+
+ cos, sin = position_embeddings
+
+ query_states = self.q_proj(hidden_states).view(hidden_shape)
+ query_states = self.q_norm(query_states)
+ query_states = apply_rotary_pos_emb(query_states, cos, sin, unsqueeze_dim=2)
+ query_states = query_states.transpose(1, 2)
+
+ # For layers with shared KV (from kv sharing point onwards), we reuse the same keys/values states as the last non-sharing layer
+ if self.is_kv_shared_layer and past_key_values is not None:
+ key_states, value_states = past_key_values.shared_layers[self.kv_shared_layer_index]
+ # Device of past layer may be different from current one
+ key_states = key_states.to(query_states.device)
+ value_states = value_states.to(query_states.device)
+ else:
+ key_states = self.k_proj(hidden_states).view(hidden_shape)
+ key_states = self.k_norm(key_states)
+ key_states = apply_rotary_pos_emb(key_states, cos, sin, unsqueeze_dim=2)
+ key_states = key_states.transpose(1, 2)
+
+ value_states = self.v_proj(hidden_states).view(hidden_shape)
+ value_states = self.v_norm(value_states)
+ value_states = value_states.transpose(1, 2)
+
+ if past_key_values is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {
+ "sin": sin,
+ "cos": cos,
+ "cache_position": cache_position,
+ "sliding_window": self.sliding_window,
+ }
+ if not self.is_kv_shared_layer:
+ key_states, value_states = past_key_values.update(
+ key_states, value_states, self.layer_idx, cache_kwargs
+ )
+ if self.store_full_length_kv:
+ if not hasattr(past_key_values, "shared_layers"):
+ past_key_values.shared_layers = {}
+ past_key_values.shared_layers[self.layer_idx] = key_states, value_states
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=self.attention_dropout if self.training else 0.0,
+ scaling=1.0,
+ sliding_window=self.sliding_window,
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+ return attn_output, attn_weights
+
+
+class Gemma3nTextDecoderLayer(Gemma3DecoderLayer):
+ def __init__(self, config: Gemma3nTextConfig, layer_idx: int):
+ super().__init__(config, layer_idx)
+ self.mlp = Gemma3nTextMLP(config, layer_idx=layer_idx)
+
+ self.hidden_size_per_layer_input = config.hidden_size_per_layer_input
+ self.act_fn = ACT2FN[config.hidden_activation]
+
+ self.altup = Gemma3nTextAltUp(config)
+ self.laurel = Gemma3nTextLaurelBlock(config)
+ self.self_attn = Gemma3nTextAttention(config, layer_idx)
+ self.per_layer_input_gate = nn.Linear(self.hidden_size, self.hidden_size_per_layer_input, bias=False)
+ self.per_layer_projection = nn.Linear(self.hidden_size_per_layer_input, self.hidden_size, bias=False)
+ self.post_per_layer_input_norm = Gemma3nRMSNorm(self.hidden_size, eps=config.rms_norm_eps)
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings_global: torch.Tensor,
+ position_embeddings_local: torch.Tensor,
+ per_layer_input: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs,
+ ) -> tuple[torch.Tensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ predictions = self.altup.predict(hidden_states)
+ active_prediction = predictions[self.config.altup_active_idx]
+
+ active_prediction_normed = self.input_layernorm(active_prediction)
+ laurel_output = self.laurel(active_prediction_normed)
+
+ # apply global RoPE to non-sliding layer only
+ if self.self_attn.is_sliding:
+ position_embeddings = position_embeddings_local
+ else:
+ position_embeddings = position_embeddings_global
+
+ attn, self_attn_weights = self.self_attn(
+ hidden_states=active_prediction_normed,
+ position_embeddings=position_embeddings,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ **kwargs,
+ )
+ attn = self.post_attention_layernorm(attn)
+
+ attn_gated = active_prediction + attn
+ attn_laurel = (attn_gated + laurel_output) / math.sqrt(2)
+
+ attn_norm = self.pre_feedforward_layernorm(attn_laurel)
+ attn_ffw = self.mlp(attn_norm)
+ attn_ffw_norm = self.post_feedforward_layernorm(attn_ffw)
+ attn_ffw_laurel_gated = attn_laurel + attn_ffw_norm
+ corrected_predictions = self.altup.correct(predictions, attn_ffw_laurel_gated)
+
+ first_prediction = corrected_predictions[self.config.altup_active_idx].clone()
+ if self.config.altup_correct_scale:
+ first_prediction = self.altup.scale_corrected_output(first_prediction)
+
+ # per_layer_input_gate adapted from jax.numpy.einsum("btd,dp->btp", ...)
+ first_prediction = self.per_layer_input_gate(first_prediction)
+ first_prediction = self.act_fn(first_prediction)
+ first_prediction = torch.multiply(first_prediction, per_layer_input)
+
+ # per_layer_projection adapted from jax.numpy.einsum("btp,pd->btd", ...)
+ first_prediction = self.per_layer_projection(first_prediction)
+ first_prediction = self.post_per_layer_input_norm(first_prediction)
+ corrected_predictions[1:] += first_prediction
+
+ outputs = (corrected_predictions,)
+
+ if output_attentions:
+ outputs += (self_attn_weights,)
+
+ return outputs
+
+
+class Gemma3nPreTrainedModel(Gemma2PreTrainedModel):
+ config: Gemma3nConfig
+ base_model_prefix = ""
+ _no_split_modules = ["Gemma3nTextDecoderLayer"]
+
+ def _init_weights(self, module):
+ PreTrainedModel._init_weights(self, module)
+ if isinstance(module, Gemma3nAudioCumulativeGroupNorm):
+ module.weight.data.fill_(1.0)
+ elif isinstance(module, Gemma3nAudioAttention):
+ module.per_dim_scale.data.zero_()
+ elif isinstance(module, Gemma3nTextAltUp):
+ module.correct_output_scale.data.zero_()
+
+
+@auto_docstring(custom_intro="The base Gemma 3n language model without a language modeling head.")
+class Gemma3nTextModel(Gemma3TextModel):
+ config: Gemma3nTextConfig
+
+ def __init__(self, config: Gemma3nTextConfig):
+ super().__init__(config)
+
+ self.hidden_size = config.hidden_size
+ self.hidden_size_per_layer_input = config.hidden_size_per_layer_input
+
+ self.embed_tokens_per_layer = Gemma3nTextScaledWordEmbedding(
+ config.vocab_size_per_layer_input,
+ config.num_hidden_layers * config.hidden_size_per_layer_input,
+ self.padding_idx,
+ embed_scale=config.hidden_size_per_layer_input**0.5,
+ )
+
+ self.per_layer_model_projection = nn.Linear(
+ self.hidden_size,
+ config.num_hidden_layers * config.hidden_size_per_layer_input,
+ bias=False,
+ )
+
+ self.per_layer_projection_norm = Gemma3nRMSNorm(config.hidden_size_per_layer_input, eps=config.rms_norm_eps)
+ self.layers = nn.ModuleList(
+ [Gemma3nTextDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+
+ self.norm = Gemma3nRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ self.altup_projections = nn.ModuleList(
+ [nn.Linear(self.hidden_size, self.hidden_size, bias=False) for _ in range(1, self.config.altup_num_inputs)]
+ )
+
+ self.altup_unembed_projections = nn.ModuleList(
+ [nn.Linear(self.hidden_size, self.hidden_size, bias=False) for _ in range(1, self.config.altup_num_inputs)]
+ )
+
+ self.register_buffer("per_layer_projection_scale", torch.tensor(self.hidden_size**-0.5), persistent=False)
+ self.register_buffer("per_layer_input_scale", torch.rsqrt(torch.tensor(2.0)), persistent=False)
+ self.rotary_emb = Gemma3nTextRotaryEmbedding(config=config)
+
+ # TODO (raushan): Fix this after RoPE refactor. For now we hack it by
+ # reassigning thetas when we want to create a local RoPE layer. Config
+ # defaults should hold values for global RoPE.
+ config = copy.deepcopy(config)
+ config.rope_theta = config.rope_local_base_freq
+ config.rope_scaling = {"rope_type": "default"}
+ self.rotary_emb_local = Gemma3nTextRotaryEmbedding(config=config)
+
+ def get_per_layer_inputs(self, input_ids: torch.LongTensor) -> torch.Tensor:
+ return self.embed_tokens_per_layer(input_ids).reshape(
+ *input_ids.shape,
+ self.config.num_hidden_layers,
+ self.hidden_size_per_layer_input,
+ )
+
+ def project_per_layer_inputs(
+ self,
+ inputs_embeds: torch.Tensor,
+ per_layer_inputs: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ per_layer_projection: torch.Tensor = self.per_layer_model_projection(inputs_embeds)
+ per_layer_projection *= self.per_layer_projection_scale.to(
+ dtype=inputs_embeds.dtype, device=per_layer_projection.device
+ )
+ per_layer_projection = per_layer_projection.reshape(
+ *inputs_embeds.shape[:-1],
+ self.config.num_hidden_layers,
+ self.hidden_size_per_layer_input,
+ )
+ per_layer_projection = self.per_layer_projection_norm(per_layer_projection)
+
+ if per_layer_inputs is None:
+ return per_layer_projection
+
+ if per_layer_projection.shape != per_layer_inputs.shape:
+ # per-layer inputs are sometimes padded with zeros, slice the relevant embeddings.
+ per_layer_inputs = per_layer_inputs[..., : self.config.num_hidden_layers, :]
+
+ return (per_layer_projection + per_layer_inputs) * self.per_layer_input_scale.to(
+ dtype=inputs_embeds.dtype, device=per_layer_projection.device
+ )
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ per_layer_inputs: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> BaseModelOutputWithPast:
+ r"""
+ per_layer_inputs (`torch.Tensor`, *optional*):
+ Pre-computed per-layer embeddings. If `None`, they are derived from `input_ids` if provided.
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if self.gradient_checkpointing and self.training and use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
+ )
+ use_cache = False
+
+ if input_ids is not None:
+ inputs_embeds = self.embed_tokens(input_ids)
+ per_layer_inputs = self.get_per_layer_inputs(input_ids)
+
+ per_layer_inputs = self.project_per_layer_inputs(inputs_embeds, per_layer_inputs)
+
+ if use_cache and past_key_values is None and not self.training:
+ past_key_values = DynamicCache(config=self.config)
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position = torch.arange(
+ past_seen_tokens,
+ past_seen_tokens + inputs_embeds.shape[1],
+ device=inputs_embeds.device,
+ )
+
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0)
+
+ # It may already have been prepared by e.g. `generate`
+ if not isinstance(causal_mask_mapping := attention_mask, dict):
+ # Prepare mask arguments
+ mask_kwargs = {
+ "config": self.config,
+ "input_embeds": inputs_embeds,
+ "attention_mask": attention_mask,
+ "cache_position": cache_position,
+ "past_key_values": past_key_values,
+ "position_ids": position_ids,
+ }
+ # Create the masks
+ causal_mask_mapping = {
+ "full_attention": create_causal_mask(**mask_kwargs),
+ "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
+ }
+
+ # embed positions
+ hidden_states_0 = inputs_embeds
+
+ # Initialize RoPE embeddings
+ position_embeddings_global = self.rotary_emb(hidden_states_0, position_ids)
+ position_embeddings_local = self.rotary_emb_local(hidden_states_0, position_ids)
+
+ # Expand hidden_states to support per-layer inputs
+ target_magnitude = torch.mean(hidden_states_0**2, dim=-1, keepdim=True) ** 0.5
+ epsilon_tensor = torch.tensor(1e-5)
+
+ temp_hidden_states = [hidden_states_0]
+ for i in range(1, self.config.altup_num_inputs):
+ # altup_proj adapted from jax.numpy.einsum("btp,pd->btd", ...)
+ altup_proj = self.altup_projections[i - 1](hidden_states_0)
+ current_hidden_state = altup_proj.to(dtype=hidden_states_0.dtype, device=target_magnitude.device)
+ new_magnitude = torch.mean(current_hidden_state**2, dim=-1, keepdim=True)
+ new_magnitude = torch.sqrt(torch.maximum(new_magnitude, epsilon_tensor.to(target_magnitude.device)))
+ current_hidden_state = current_hidden_state * target_magnitude / new_magnitude
+ temp_hidden_states.append(current_hidden_state)
+
+ hidden_states = torch.stack(temp_hidden_states, dim=0) # [num_altup_inputs, batch, seq_len, hidden_size]
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+
+ for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ causal_mask = causal_mask_mapping[decoder_layer.attention_type]
+ per_layer_input = per_layer_inputs[:, :, decoder_layer.layer_idx, :]
+
+ layer_outputs = decoder_layer(
+ hidden_states,
+ position_embeddings_global,
+ position_embeddings_local,
+ per_layer_input,
+ attention_mask=causal_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ # add hidden states from the last decoder layer (but before reprojecting to stay consistent with layer output)
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ # Collapse the AltUp streams back into a single hidden state (mirroring the expansion above)
+ target_magnitude = torch.mean(hidden_states[0] ** 2, dim=-1, keepdim=True) ** 0.5
+ temp_hidden_states = [hidden_states[0]]
+ for i in range(1, self.config.altup_num_inputs):
+ # altup_unembed_projections adapted from jax.numpy.einsum("btp,pd->btd", ...)
+ altup_unemb_proj: torch.Tensor = self.altup_unembed_projections[i - 1](hidden_states[i])
+ current_hidden_state = altup_unemb_proj.to(dtype=hidden_states_0.dtype, device=target_magnitude.device)
+ new_magnitude = torch.mean(current_hidden_state**2, dim=-1, keepdim=True)
+ new_magnitude = torch.sqrt(torch.maximum(new_magnitude, epsilon_tensor.to(target_magnitude.device)))
+ current_hidden_state = current_hidden_state * target_magnitude / new_magnitude
+ temp_hidden_states.append(current_hidden_state)
+
+ hidden_states = torch.stack(temp_hidden_states)
+ hidden_states = torch.mean(hidden_states, dim=0)
+ hidden_states = self.norm(hidden_states)
+
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
+
+
+@auto_docstring(custom_intro="The base Gemma 3n language model with a language modeling head.")
+class Gemma3nForCausalLM(Gemma3ForCausalLM):
+ _checkpoint_conversion_mapping = {"model.language_model": "model"}
+ base_model_prefix = "model"
+
+
+class Gemma3nMultimodalEmbedder(nn.Module):
+ """Embeds token ids or soft tokens for multimodal content into language model space."""
+
+ def __init__(
+ self,
+ multimodal_config: Union[Gemma3nAudioConfig, Gemma3nVisionConfig],
+ text_config: Gemma3nTextConfig,
+ ):
+ super().__init__()
+
+ self.multimodal_hidden_size = multimodal_config.hidden_size
+ self.eps = multimodal_config.rms_norm_eps
+ self.vocab_offset = multimodal_config.vocab_offset
+ self.vocab_size = multimodal_config.vocab_size
+ self.text_hidden_size = text_config.hidden_size
+
+ self.embedding = nn.Embedding(self.vocab_size, self.multimodal_hidden_size)
+ self.hard_embedding_norm = Gemma3nRMSNorm(self.multimodal_hidden_size, eps=self.eps)
+ self.soft_embedding_norm = Gemma3nRMSNorm(self.multimodal_hidden_size, eps=self.eps)
+ self.embedding_projection = nn.Linear(self.multimodal_hidden_size, self.text_hidden_size, bias=False)
+ self.embedding_post_projection_norm = Gemma3nRMSNorm(self.text_hidden_size, eps=self.eps, with_scale=False)
+
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ """Embeds token ids or soft tokens for multimodal content into language model space.
+
+ Args:
+ input_ids: A torch.LongTensor containing the token ids to embed. Values should be in the range
+ `[vocab_offset, vocab_offset + vocab_size)`.
+ inputs_embeds: A torch.Tensor containing the soft tokens to embed.
+
+ Returns:
+ A torch.Tensor of embeddings with shape `[batch_size, seq_len, self.config.text_config.hidden_size]`.
+ """
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if inputs_embeds is not None:
+ emb_norm = self.soft_embedding_norm(inputs_embeds)
+ else:
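+ # Token ids live in the shared multimodal vocabulary; shift by vocab_offset so they index this
+ # modality's embedding table starting at 0.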
+ hard_emb = self.embedding(input_ids - self.vocab_offset)
+ emb_norm = self.hard_embedding_norm(hard_emb)
+
+ emb_norm_proj = self.embedding_projection(emb_norm)
+ return self.embedding_post_projection_norm(emb_norm_proj)
+
+
+@auto_docstring(
+ custom_intro="""
+ The base Gemma 3n model comprising a vision backbone, an audio backbone, and a language model without a
+ language modeling head.
+ """
+)
+class Gemma3nModel(PaliGemmaModel):
+ _checkpoint_conversion_mapping = {}
+
+ def __init__(self, config: Gemma3nConfig):
+ super().__init__(config)
+ del self.multi_modal_projector # Replaced by Gemma3nVisionEmbedder
+ del self.text_config_dtype
+ self.vocab_size_per_layer_input = config.text_config.vocab_size_per_layer_input
+ self.audio_tower = AutoModel.from_config(config.audio_config)
+ self.embed_vision = Gemma3nMultimodalEmbedder(config.vision_config, config.text_config)
+ self.embed_audio = Gemma3nMultimodalEmbedder(config.audio_config, config.text_config)
+
+ def get_image_features(self, pixel_values: torch.Tensor) -> torch.Tensor:
+ """
+ Projects the last hidden state from the vision model into language model space.
+
+ Args:
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`):
+ The tensors corresponding to the input images.
+
+ Returns:
+ image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`.
+ """
+ vision_outputs = self.vision_tower(
+ pixel_values=pixel_values, do_pooling=False, return_dict=True
+ ).last_hidden_state
+ # Convert from (batch, channels, height, width) to (batch, height * width, channels) where:
+ # height == width and height * width == Gemma3nConfig.vision_soft_tokens_per_image.
+ vision_outputs = vision_outputs.reshape(
+ vision_outputs.shape[0],
+ self.config.vision_config.hidden_size,
+ self.config.vision_soft_tokens_per_image,
+ ).permute(0, 2, 1)
+ # Normalize and embed the soft tokens into language model space.
+ vision_outputs *= self.config.vision_config.hidden_size**0.5
+ return self.embed_vision(inputs_embeds=vision_outputs)
+
+ def get_placeholder_mask(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ image_features: Optional[torch.FloatTensor] = None,
+ audio_features: Optional[torch.FloatTensor] = None,
+ ):
+ """
+ Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
+ equal to the length of multimodal features. If the lengths are different, an error is raised.
+ """
+ if input_ids is None:
+ special_image_mask = inputs_embeds == self.get_input_embeddings()(
+ torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ special_image_mask = special_image_mask.all(-1)
+ special_audio_mask = (
+ inputs_embeds
+ == self.get_input_embeddings()(
+ torch.tensor(self.config.audio_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ ).all(-1)
+ else:
+ special_image_mask = input_ids == self.config.image_token_id
+ special_audio_mask = input_ids == self.config.audio_token_id
+
+ n_image_tokens = special_image_mask.sum()
+ special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+ if image_features is not None and inputs_embeds[special_image_mask].numel() != image_features.numel():
+ raise ValueError(
+ f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {image_features.shape[0] * image_features.shape[1]}"
+ )
+
+ n_audio_tokens = special_audio_mask.sum()
+ special_audio_mask = special_audio_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+ if audio_features is not None and inputs_embeds[special_audio_mask].numel() != audio_features.numel():
+ raise ValueError(
+ f"Audio features and audio tokens do not match: tokens: {n_audio_tokens}, features {audio_features.shape[0] * audio_features.shape[1]}"
+ )
+
+ return special_image_mask, special_audio_mask
+
+ @can_return_tuple
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None, # text inputs
+ pixel_values: Optional[torch.FloatTensor] = None, # vision inputs
+ input_features: Optional[torch.FloatTensor] = None, # audio inputs
+ attention_mask: Optional[torch.Tensor] = None,
+ input_features_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ **lm_kwargs,
+ ) -> Gemma3nCausalLMOutputWithPast:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
+
+ Example:
+
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import AutoProcessor, Gemma3nForConditionalGeneration
+
+ >>> model = Gemma3nForConditionalGeneration.from_pretrained("google/gemma3n2-3b-mix-224")
+ >>> processor = AutoProcessor.from_pretrained("google/gemma3n2-3b-mix-224")
+
+ >>> prompt = "Where is the cat standing?"
+ >>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> inputs = processor(images=image, text=prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(**inputs)
+ >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "Where is the cat standing?\nsnow"
+ ```
+ """
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+
+ if input_ids is not None:
+ inputs_embeds = self.get_input_embeddings()(input_ids)
+
+ # Prepare per-layer inputs from input_ids
+ per_layer_inputs_mask = torch.logical_and(input_ids >= 0, input_ids < self.vocab_size_per_layer_input)
+ per_layer_inputs_tokens = torch.where(per_layer_inputs_mask, input_ids, torch.zeros_like(input_ids))
+ per_layer_inputs = self.language_model.get_per_layer_inputs(per_layer_inputs_tokens)
+
+ # Handle vision tokens (>= embed_vision.vocab_offset and < embed_audio.vocab_offset)
+ vision_mask = torch.logical_and(
+ input_ids >= self.embed_vision.vocab_offset, input_ids < self.embed_audio.vocab_offset
+ )
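+ # Positions that are not vision tokens are temporarily mapped to the last id in the vision vocab so
+ # the embedding lookup stays in range; torch.where below restores their original text embeddings.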
+ dummy_vision_token_id = self.embed_vision.vocab_offset + self.embed_vision.vocab_size - 1
+ vision_input_ids = torch.where(vision_mask, input_ids, dummy_vision_token_id).to(inputs_embeds.device)
+ vision_embeds = self.embed_vision(input_ids=vision_input_ids)
+ expanded_vision_mask = vision_mask.unsqueeze(-1).expand_as(inputs_embeds)
+ inputs_embeds = torch.where(expanded_vision_mask, vision_embeds, inputs_embeds)
+
+ # Handle audio tokens (>= embed_audio.vocab_offset)
+ audio_mask = input_ids >= self.embed_audio.vocab_offset
+ dummy_audio_token_id = self.embed_audio.vocab_offset + self.embed_audio.vocab_size - 1
+ audio_input_ids = torch.where(audio_mask, input_ids, dummy_audio_token_id).to(inputs_embeds.device)
+ audio_embeds = self.embed_audio(input_ids=audio_input_ids)
+ expanded_audio_mask = audio_mask.unsqueeze(-1).expand_as(inputs_embeds)
+ inputs_embeds = torch.where(expanded_audio_mask, audio_embeds, inputs_embeds)
+ else:
+ per_layer_inputs = None
+
+ # Merge text and images
+ if pixel_values is not None:
+ image_features = self.get_image_features(pixel_values)
+ image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+ special_image_mask, _ = self.get_placeholder_mask(
+ input_ids, inputs_embeds=inputs_embeds, image_features=image_features
+ )
+ inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
+
+ # Merge text and audio
+ if input_features is not None and input_features_mask is not None:
+ audio_features, audio_mask = self.get_audio_features(input_features, ~input_features_mask)
+
+ # The Gemma3nProcessor expects all audio to be 30s in length and inserts 188 audio soft tokens into the
+ # text to account for this. However, the audio preprocessing and encoder do not guarantee they will
+ # produce 188 soft tokens; they will produce at most that many tokens, but they may produce fewer
+ # depending on the length of the longest audio input in the batch. When this happens, we pad the audio
+ # features out to 188 soft tokens with the embedding of the last token in the embed_audio vocab.
+ audio_padding_toks = torch.tensor([[self.vocab_size - 1]], dtype=torch.long, device=audio_features.device)
+ audio_padding_embs = self.embed_audio(input_ids=audio_padding_toks)
+ audio_features = torch.where(audio_mask.unsqueeze(-1), audio_padding_embs, audio_features)
+
+ audio_batch_size, audio_seq_len, audio_embed_dim = audio_features.shape
+ extra_padding_tokens = self.config.audio_soft_tokens_per_image - audio_seq_len
+ extra_padding_features = audio_padding_embs.expand(audio_batch_size, extra_padding_tokens, audio_embed_dim)
+
+ audio_features = torch.cat((audio_features, extra_padding_features), dim=1)
+ audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype)
+ _, special_audio_mask = self.get_placeholder_mask(
+ input_ids, inputs_embeds=inputs_embeds, audio_features=audio_features
+ )
+ inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_features)
+
+ outputs = self.language_model(
+ input_ids=None,
+ per_layer_inputs=per_layer_inputs,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=True,
+ cache_position=cache_position,
+ **lm_kwargs,
+ )
+
+ return Gemma3nModelOutputWithPast(
+ last_hidden_state=outputs.last_hidden_state,
+ past_key_values=outputs.past_key_values if use_cache else None,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ image_hidden_states=image_features if pixel_values is not None else None,
+ audio_hidden_states=audio_features if input_features is not None else None,
+ )
+
+ def get_audio_features(
+ self, input_features: torch.Tensor, input_features_mask: torch.Tensor
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Projects the last hidden state from the audio encoder into language model space.
+
+ Args:
+ input_features (`torch.FloatTensor` of shape `(num_audio, seq_length, num_features)`):
+ The tensors corresponding to the input audio.
+ input_features_mask (`torch.FloatTensor` of shape `(num_audio, seq_length)`):
+ The attention mask for the input audio.
+
+ Returns:
+ audio_features (`torch.Tensor`): Audio feature tensor of shape `(num_audio, audio_length, embed_dim)`.
+ """
+ audio_outputs, audio_mask = self.audio_tower(input_features, input_features_mask)
+ return self.embed_audio(inputs_embeds=audio_outputs), audio_mask
+
+ def _update_causal_mask(self, **super_kwargs):
+ raise AttributeError("We don't want to inherit it")
+
+
+@auto_docstring(
+ custom_intro="""
+ The base Gemma 3n model comprising a vision backbone, an audio backbone, a language model, and a language modeling
+ head.
+ """
+)
+class Gemma3nForConditionalGeneration(PaliGemmaForConditionalGeneration):
+ _checkpoint_conversion_mapping = {}
+ base_model_prefix = "model"
+
+ @property
+ def audio_tower(self):
+ return self.model.audio_tower
+
+ @property
+ def multi_modal_projector(self):
+ raise AttributeError("Use embed_vision instead of multi_modal_projector.")
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None, # text inputs
+ pixel_values: Optional[torch.FloatTensor] = None, # vision inputs
+ input_features: Optional[torch.FloatTensor] = None, # audio inputs
+ attention_mask: Optional[torch.Tensor] = None,
+ input_features_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ **lm_kwargs,
+ ) -> Gemma3nCausalLMOutputWithPast:
+ r"""
+ input_features_mask (`torch.Tensor`, *optional*):
+ The attention mask for the input audio.
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are
+ ignored (masked), the loss is only computed for the tokens with labels in
+ `[0, ..., config.text_config.vocab_size]`.
+
+ Example:
+
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import AutoProcessor, Gemma3ForConditionalGeneration
+
+ >>> model = Gemma3ForConditionalGeneration.from_pretrained("google/gemma-3-4b-it")
+ >>> processor = AutoProcessor.from_pretrained("google/gemma-3-4b-it")
+
+ >>> messages = [
+ ... {
+ ... "role": "system",
+ ... "content": [
+ ... {"type": "text", "text": "You are a helpful assistant."}
+ ... ]
+ ... },
+ ... {
+ ... "role": "user", "content": [
+ ... {"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"},
+ ... {"type": "text", "text": "Where is the cat standing?"},
+ ... ]
+ ... },
+ ... ]
+
+ >>> inputs = processor.apply_chat_template(
+ ... messages,
+ ... tokenize=True,
+ ... return_dict=True,
+ ... return_tensors="pt",
+ ... add_generation_prompt=True
+ ... )
+ >>> # Generate
+ >>> generate_ids = model.generate(**inputs)
+ >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "user\nYou are a helpful assistant.\n\n\n\n\n\nWhere is the cat standing?\nmodel\nBased on the image, the cat is standing in a snowy area, likely outdoors. It appears to"
+ ```
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+
+ outputs = self.model(
+ input_ids=input_ids,
+ pixel_values=pixel_values,
+ input_features=input_features,
+ attention_mask=attention_mask,
+ input_features_mask=input_features_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ token_type_ids=token_type_ids,
+ cache_position=cache_position,
+ inputs_embeds=inputs_embeds,
+ labels=labels,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=True,
+ **lm_kwargs,
+ )
+
+ hidden_states = outputs.last_hidden_state
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
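+ # Soft-cap the final logits: logits = softcap * tanh(logits / softcap), bounding them to (-softcap, softcap).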
+ if (final_logit_softcapping := self.config.get_text_config().final_logit_softcapping) is not None:
+ logits = logits / final_logit_softcapping
+ logits = torch.tanh(logits)
+ logits = logits * final_logit_softcapping
+
+ loss = None
+ if labels is not None:
+ # Upcast to float if we need to compute the loss to avoid potential precision issues
+ logits = logits.float()
+ shift_logits = logits[..., :-1, :]
+ shift_labels = labels[..., 1:]
+ if attention_mask is not None:
+ # we use the input attention mask to shift the logits and labels, because it is 2D.
+ # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
+ shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(logits.device)
+ shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous()
+ shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
+ else:
+ shift_logits = shift_logits.contiguous()
+ shift_labels = shift_labels.contiguous()
+ # Flatten the tokens
+ loss_fct = nn.CrossEntropyLoss()
+
+ flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
+ flat_labels = shift_labels.view(-1).to(shift_logits.device)
+ loss = loss_fct(flat_logits, flat_labels)
+
+ return Gemma3nCausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ image_hidden_states=outputs.image_hidden_states,
+ audio_hidden_states=outputs.audio_hidden_states,
+ )
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids,
+ past_key_values=None,
+ inputs_embeds=None,
+ cache_position=None,
+ position_ids=None,
+ pixel_values=None,
+ input_features=None,
+ attention_mask=None,
+ input_features_mask=None,
+ token_type_ids=None,
+ use_cache=True,
+ logits_to_keep=None,
+ labels=None,
+ **kwargs,
+ ):
+ # Overwritten -- custom `position_ids` and `pixel_values` handling
+ model_inputs = super().prepare_inputs_for_generation(
+ input_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ cache_position=cache_position,
+ use_cache=use_cache,
+ logits_to_keep=logits_to_keep,
+ token_type_ids=token_type_ids,
+ **kwargs,
+ )
+
+ # If we're in cached decoding stage, multimodal inputs should be None because input ids do not contain special
+ # tokens anymore. Otherwise multimodal inputs should be passed to model.
+ # NOTE: use_cache=False always needs pixel_values, input_features, and input_features_mask
+ if cache_position[0] == 0:
+ model_inputs["pixel_values"] = pixel_values
+ model_inputs["input_features"] = input_features
+ model_inputs["input_features_mask"] = input_features_mask
+
+ return model_inputs
+
+ def _prepare_4d_causal_attention_mask_with_cache_position(self, **super_kwargs):
+ raise AttributeError("Do not inherit _prepare_4d_causal_attention_mask_with_cache_position from PaliGemma")
+
+
+__all__ = [
+ "Gemma3nAudioConfig",
+ "Gemma3nAudioEncoder",
+ "Gemma3nConfig",
+ "Gemma3nForCausalLM",
+ "Gemma3nForConditionalGeneration",
+ "Gemma3nModel",
+ "Gemma3nPreTrainedModel",
+ "Gemma3nTextConfig",
+ "Gemma3nTextModel",
+ "Gemma3nVisionConfig",
+]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/gemma3n/processing_gemma3n.py b/venv/lib/python3.13/site-packages/transformers/models/gemma3n/processing_gemma3n.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2c2c3ae10f8ab79f2c18d95028f57c28f6f9150
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/gemma3n/processing_gemma3n.py
@@ -0,0 +1,165 @@
+# coding=utf-8
+# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Optional, Union
+
+import numpy as np
+
+from ...feature_extraction_utils import BatchFeature
+from ...image_utils import ImageInput, make_nested_list_of_images
+from ...processing_utils import AudioKwargs, ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack
+from ...tokenization_utils_base import PreTokenizedInput, TextInput
+
+
+class Gemma3nImagesKwargs(ImagesKwargs):
+ do_convert_rgb: Optional[bool]
+
+
+class Gemma3nProcessorKwargs(ProcessingKwargs, total=False):
+ audio_kwargs: AudioKwargs
+ images_kwargs: Gemma3nImagesKwargs
+ _defaults = {
+ "text_kwargs": {
+ "padding": False,
+ },
+ }
+
+
+class Gemma3nProcessor(ProcessorMixin):
+ """
+ A processor for Gemma 3n, wrapping the full capabilities of a feature extractor, image processor, and tokenizer
+ into a single processor.
+
+ Args:
+ feature_extractor (`Gemma3nAudioFeatureExtractor`):
+ Feature extractor that converts raw audio waveforms into MEL spectrograms for the audio encoder. This
+ should return a `BatchFeature` with `input_features` and `input_features_mask` features.
+ image_processor (`SiglipImageProcessorFast`):
+ Image processor that prepares batches of images for the vision encoder. This should return a `BatchFeature`
+ with a `pixel_values` feature.
+ tokenizer (`GemmaTokenizerFast`):
+ The text tokenizer for the model.
+ chat_template (`string`, *optional*):
+ A Jinja template for generating text prompts from a set of messages.
+ audio_seq_length (`int`, *optional*, defaults to 188):
+ The number of audio soft tokens that will be added to the text prompt.
+ image_seq_length (`int`, *optional*, defaults to 256):
+ The number of image soft tokens that will be added to the text prompt.
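+
+ Example:
+
+ A minimal, text-only sketch; the checkpoint name is illustrative (it mirrors the one used in the model
+ docstrings in this file) and may not correspond to a released checkpoint:
+
+ ```python
+ >>> from transformers import AutoProcessor
+
+ >>> processor = AutoProcessor.from_pretrained("google/gemma3n2-3b-mix-224")
+ >>> inputs = processor(text="Where is the cat standing?", return_tensors="pt")
+ >>> sorted(inputs.keys())  # doctest: +SKIP
+ ['attention_mask', 'input_ids', 'token_type_ids']
+ ```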
+ """
+
+ attributes = ["feature_extractor", "image_processor", "tokenizer"]
+ feature_extractor_class = "AutoFeatureExtractor"
+ image_processor_class = "AutoImageProcessor"
+ tokenizer_class = "AutoTokenizer"
+
+ def __init__(
+ self,
+ feature_extractor,
+ image_processor,
+ tokenizer,
+ chat_template=None,
+ audio_seq_length: int = 188,
+ image_seq_length: int = 256,
+ **kwargs,
+ ):
+ self.audio_seq_length = audio_seq_length
+ self.audio_token_id = tokenizer.audio_token_id
+ self.boa_token = tokenizer.boa_token
+ self.audio_token = tokenizer.audio_token
+ audio_tokens_expanded = "".join([tokenizer.audio_token] * audio_seq_length)
+ self.full_audio_sequence = f"\n\n{tokenizer.boa_token}{audio_tokens_expanded}{tokenizer.eoa_token}\n\n"
+
+ self.image_seq_length = image_seq_length
+ self.image_token_id = tokenizer.image_token_id
+ self.boi_token = tokenizer.boi_token
+ self.image_token = tokenizer.image_token
+ image_tokens_expanded = "".join([tokenizer.image_token] * image_seq_length)
+ self.full_image_sequence = f"\n\n{tokenizer.boi_token}{image_tokens_expanded}{tokenizer.eoi_token}\n\n"
+
+ super().__init__(
+ feature_extractor=feature_extractor,
+ image_processor=image_processor,
+ tokenizer=tokenizer,
+ chat_template=chat_template,
+ **kwargs,
+ )
+
+ def __call__(
+ self,
+ images: Optional[ImageInput] = None,
+ text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None,
+ audio: Optional[Union[np.ndarray, list[float], list[np.ndarray], list[list[float]]]] = None,
+ videos=None,
+ **kwargs: Unpack[Gemma3nProcessorKwargs],
+ ) -> BatchFeature:
+ if text is None and images is None and audio is None:
+ raise ValueError("Provide at least one of `text`, `images`, or `audio`.")
+
+ output_kwargs = self._merge_kwargs(
+ Gemma3nProcessorKwargs,
+ tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+ **kwargs,
+ )
+
+ if isinstance(text, str):
+ text = [text]
+ elif not isinstance(text, list) or not all(isinstance(t, str) for t in text):
+ raise ValueError("Invalid input text. Please provide a string, or a list of strings")
+
+ if audio is not None:
+ audio_inputs = self.feature_extractor(audio, **output_kwargs["audio_kwargs"])
+
+ if not text:
+ text = [self.audio_token for _ in audio]
+
+ # Expand placeholder audio tokens to the full audio token sequence
+ text = [prompt.replace(self.audio_token, self.full_audio_sequence) for prompt in text]
+ else:
+ audio_inputs = {}
+
+ if images is not None:
+ images = self.image_processor.fetch_images(images)
+ batched_images = make_nested_list_of_images(images)
+ image_inputs = self.image_processor(batched_images, **output_kwargs["images_kwargs"])
+
+ # Create empty text to be replaced with placeholders
+ if not text:
+ text = [" ".join([self.image_token] * len(images)) for images in batched_images]
+
+ if len(batched_images) != len(text):
+ raise ValueError(
+ f"Received inconsistently sized batches of images ({len(batched_images)}) and text ({len(text)})."
+ )
+
+ # Expand placeholder image tokens to the full image token sequence
+ text = [prompt.replace(self.image_token, self.full_image_sequence) for prompt in text]
+ else:
+ image_inputs = {}
+
+ return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
+ text_inputs = self.tokenizer(text=text, **output_kwargs["text_kwargs"], return_tensors="np")
+ self._check_special_mm_tokens(text, text_inputs, modalities=["image"])
+
+ # Add token type ids manually, as tokenizer can't do arbitrary position token types
+ array_ids = text_inputs["input_ids"]
+ token_type_ids = np.zeros_like(array_ids)
+ token_type_ids[array_ids == self.image_token_id] = 1
+ token_type_ids[array_ids == self.audio_token_id] = 3
+ text_inputs = {k: v.tolist() for k, v in text_inputs.items()} # in case user requested list inputs
+ text_inputs["token_type_ids"] = token_type_ids.tolist()
+ return BatchFeature(data={**text_inputs, **image_inputs, **audio_inputs}, tensor_type=return_tensors)
+
+
+__all__ = ["Gemma3nProcessor"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/glm4v_moe/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/glm4v_moe/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f99578a4be721ecdc5bcbd157fe75f8f16384086
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/glm4v_moe/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_glm4v_moe import *
+ from .modeling_glm4v_moe import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/venv/lib/python3.13/site-packages/transformers/models/glm4v_moe/configuration_glm4v_moe.py b/venv/lib/python3.13/site-packages/transformers/models/glm4v_moe/configuration_glm4v_moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..b06642e250bcfedd82dcb7f47e2aae6d3f249dcb
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/glm4v_moe/configuration_glm4v_moe.py
@@ -0,0 +1,384 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/glm4v_moe/modular_glm4v_moe.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_glm4v_moe.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2025 The ZhipuAI Inc. team and HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from ...configuration_utils import PretrainedConfig
+from ...modeling_rope_utils import rope_config_validation
+
+
+class Glm4vMoeVisionConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`Glm4vMoeVisionModel`]. It is used to instantiate a Glm4vMoeVisionModel
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield
+ a similar configuration to that of
+ GLM-4.1V-9B-Thinking [THUDM/GLM-4.1V-9B-Thinking](https://huggingface.co/THUDM/GLM-4.1V-9B-Thinking).
+
+ Args:
+ hidden_size (`int`, *optional*, defaults to 1536):
+ Dimensionality of the encoder layers and the pooler layer.
+ depth (`int`, *optional*, defaults to 24):
+ Number of layers (depth) in the model.
+ attention_bias (`bool`, *optional*, defaults to `False`):
+ Whether to add a bias to the queries, keys and values.
+ intermediate_size (`int`, *optional*, defaults to 13696):
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"selu"` and `"gelu_new"` are supported.
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ Dropout probability for attention weights.
+ projection_dropout (`float`, *optional*, defaults to 0.0):
+ Dropout probability for the projection layer.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ image_size (`int` or `list[int]`, *optional*, defaults to `336`):
+ The size (resolution) of each image.
+ patch_size (`int`, *optional*, defaults to `14`):
+ The size (resolution) of each patch.
+ in_channels (`int`, *optional*, defaults to 3):
+ The number of input channels.
+ out_hidden_size (`int`, *optional*, defaults to 4096):
+ The output hidden size of the vision model.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-05):
+ The epsilon used by the rms normalization layers.
+ spatial_merge_size (`int`, *optional*, defaults to 2):
+ The size used for merging spatial dimensions.
+ temporal_patch_size (`int`, *optional*, defaults to 2):
+ The size used for patches along the temporal dimension.
+ Example:
+
+ ```python
+ >>> from transformers import Glm4vMoeVisionConfig, Glm4vMoeVisionModel
+
+ >>> # Initializing a Glm4vMoeVisionConfig GLM-4.1V-9B style configuration
+ >>> configuration = Glm4vMoeVisionConfig()
+
+ >>> # Initializing a model (with random weights) from the GLM-4.1V-9B configuration
+ >>> model = Glm4vMoeVisionModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "glm4v_moe"
+ base_config_key = "vision_config"
+
+ def __init__(
+ self,
+ depth=24,
+ hidden_size=1536,
+ hidden_act="silu",
+ attention_bias=False,
+ attention_dropout=0.0,
+ num_heads=12,
+ in_channels=3,
+ image_size=336,
+ patch_size=14,
+ rms_norm_eps=1e-05,
+ spatial_merge_size=2,
+ temporal_patch_size=2,
+ out_hidden_size=4096,
+ intermediate_size=13696,
+ initializer_range=0.02,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ self.depth = depth
+ self.hidden_size = hidden_size
+ self.hidden_act = hidden_act
+ self.num_heads = num_heads
+ self.in_channels = in_channels
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.spatial_merge_size = spatial_merge_size
+ self.temporal_patch_size = temporal_patch_size
+ self.out_hidden_size = out_hidden_size
+ self.intermediate_size = intermediate_size
+ self.initializer_range = initializer_range
+ self.rms_norm_eps = rms_norm_eps
+ self.attention_bias = attention_bias
+ self.attention_dropout = attention_dropout
+
+
+class Glm4vMoeTextConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`Glm4vMoeModel`]. It is used to instantiate a
+ GLM-4.5V model according to the specified arguments, defining the model architecture. Instantiating a
+ configuration with the defaults will yield a similar configuration to that of
+ GLM-4.5V [zai-org/GLM-4.5V](https://huggingface.co/zai-org/GLM-4.5V).
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 151424):
+ Vocabulary size of the Glm4vMoe model. Defines the number of different tokens that can be represented by the
+ `inputs_ids` passed when calling [`Glm4vMoeModel`]
+ hidden_size (`int`, *optional*, defaults to 4096):
+ Dimension of the hidden representations.
+ intermediate_size (`int`, *optional*, defaults to 10944):
+ Dimension of the MLP representations.
+ num_hidden_layers (`int`, *optional*, defaults to 46):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 96):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ partial_rotary_factor (`float`, *optional*, defaults to 0.5): The fraction of the attention head dimension to which rotary position embeddings are applied.
+ num_key_value_heads (`int`, *optional*, defaults to 8):
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+ by meanpooling all the original heads within that group. For more details checkout [this
+ paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to `8`.
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+ The non-linear activation function (function or string) in the decoder.
+ max_position_embeddings (`int`, *optional*, defaults to 65536):
+ The maximum sequence length that this model might ever be used with.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-05):
+ The epsilon used by the rms normalization layers.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether the model's input and output word embeddings should be tied.
+ rope_theta (`float`, *optional*, defaults to 10000.0):
+ The base period of the RoPE embeddings.
+ rope_scaling (`Dict`, *optional*):
+ Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
+ and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
+ accordingly.
+ Expected contents:
+ `rope_type` (`str`):
+ The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
+ 'llama3'], with 'default' being the original RoPE implementation.
+ `factor` (`float`, *optional*):
+ Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
+ most scaling types, a `factor` of x will enable the model to handle sequences of length x *
+ original maximum pre-trained length.
+ `original_max_position_embeddings` (`int`, *optional*):
+ Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
+ pretraining.
+ `attention_factor` (`float`, *optional*):
+ Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
+ computation. If unspecified, it defaults to value recommended by the implementation, using the
+ `factor` field to infer the suggested value.
+ attention_bias (`bool`, *optional*, defaults to `True`):
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ moe_intermediate_size (`int`, *optional*, defaults to 1408):
+ Intermediate size of the routed expert.
+ num_experts_per_tok (`int`, *optional*, defaults to 8):
+ Number of experts per token.
+ n_shared_experts (`int`, *optional*, defaults to 1):
+ Number of shared experts.
+ n_routed_experts (`int`, *optional*, defaults to 128):
+ Number of routed experts.
+ routed_scaling_factor (`float`, *optional*, defaults to 1.0):
+ Scaling factor for routed experts.
+ n_group (`int`, *optional*, defaults to 1):
+ Number of groups for routed experts.
+ topk_group (`int`, *optional*, defaults to 1):
+ Number of selected groups for each token (ensuring the selected experts are only within `topk_group` groups).
+ first_k_dense_replace (`int`, *optional*, defaults to 1):
+ Number of dense layers in the shallow part of the model, i.e. the first k layers use dense MLPs before
+ switching to MoE (embed -> k dense layers -> moe layers -> lm_head).
+ norm_topk_prob (`bool`, *optional*, defaults to `True`):
+ Whether to normalize the topk probabilities.
+
+ ```python
+ >>> from transformers import Glm4vMoeTextModel, Glm4vMoeConfig
+
+ >>> # Initializing a GLM-4.5V style configuration
+ >>> configuration = Glm4vMoeConfig()
+
+ >>> # Initializing a model from the GLM-4.5V style configuration
+ >>> model = Glm4vMoeTextModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "Glm4vMoe_text"
+ keys_to_ignore_at_inference = ["past_key_values"]
+ # Default tensor parallel plan for base model `Glm4vMoe`
+ base_model_tp_plan = {
+ "layers.*.self_attn.q_proj": "colwise",
+ "layers.*.self_attn.k_proj": "colwise",
+ "layers.*.self_attn.v_proj": "colwise",
+ "layers.*.self_attn.o_proj": "rowwise",
+ "layers.*.mlp.gate_up_proj": "colwise_rep", # we need to replicate here due to the `chunk` operation
+ "layers.*.mlp.down_proj": "rowwise_rep", # we need to replicate here due to the `chunk` operation
+ }
+ base_model_pp_plan = {
+ "embed_tokens": (["input_ids"], ["inputs_embeds"]),
+ "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
+ "norm": (["hidden_states"], ["hidden_states"]),
+ }
+ base_config_key = "text_config"
+
+ def __init__(
+ self,
+ vocab_size=151424,
+ hidden_size=4096,
+ intermediate_size=10944,
+ num_hidden_layers=46,
+ num_attention_heads=96,
+ partial_rotary_factor=0.5,
+ num_key_value_heads=8,
+ hidden_act="silu",
+ max_position_embeddings=65536,
+ initializer_range=0.02,
+ rms_norm_eps=1e-5,
+ use_cache=True,
+ tie_word_embeddings=False,
+ rope_theta=10000.0,
+ rope_scaling=None,
+ attention_bias=True,
+ attention_dropout=0.0,
+ moe_intermediate_size=1408,
+ num_experts_per_tok=8,
+ n_shared_experts=1,
+ n_routed_experts=128,
+ routed_scaling_factor=1.0,
+ n_group=1,
+ topk_group=1,
+ first_k_dense_replace=1,
+ norm_topk_prob=True,
+ **kwargs,
+ ):
+ super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.partial_rotary_factor = partial_rotary_factor
+
+ self.num_key_value_heads = num_key_value_heads
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.rms_norm_eps = rms_norm_eps
+ self.use_cache = use_cache
+ self.rope_theta = rope_theta
+ self.rope_scaling = rope_scaling
+ self.attention_bias = attention_bias
+ self.attention_dropout = attention_dropout
+ # Validate the correctness of rotary position embeddings parameters
+ # BC: if there is a 'type' field, move it to 'rope_type'.
+ if self.rope_scaling is not None and "type" in self.rope_scaling:
+ self.rope_scaling["rope_type"] = self.rope_scaling["type"]
+ rope_config_validation(self, ignore_keys={"mrope_section"})
+
+ # MoE arguments
+ self.moe_intermediate_size = moe_intermediate_size
+ self.num_experts_per_tok = num_experts_per_tok
+ self.n_group = n_group
+ self.topk_group = topk_group
+ self.n_shared_experts = n_shared_experts
+ self.n_routed_experts = n_routed_experts
+ self.routed_scaling_factor = routed_scaling_factor
+ self.first_k_dense_replace = first_k_dense_replace
+ self.norm_topk_prob = norm_topk_prob
+
+
+class Glm4vMoeConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`Glm4vMoeModel`]. It is used to instantiate a
+ GLM-4.5V model according to the specified arguments, defining the model architecture. Instantiating a
+ configuration with the defaults will yield a similar configuration to that of
+ GLM-4.5V [zai-org/GLM-4.5V](https://huggingface.co/zai-org/GLM-4.5V).
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ text_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Glm4vMoeTextConfig`):
+ The config object or dictionary of the text backbone.
+ vision_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Glm4vMoeVisionConfig`):
+ The config object or dictionary of the vision backbone.
+ image_token_id (`int`, *optional*, defaults to 151363):
+ The image token index to encode the image prompt.
+ video_token_id (`int`, *optional*, defaults to 151364):
+ The video token index to encode the video prompt.
+ image_start_token_id (`int`, *optional*, defaults to 151339):
+ The image start token index to encode the start of image.
+ image_end_token_id (`int`, *optional*, defaults to 151340):
+ The image end token index to encode the end of image.
+ video_start_token_id (`int`, *optional*, defaults to 151341):
+ The video start token index to encode the start of video.
+ video_end_token_id (`int`, *optional*, defaults to 151342):
+ The video end token index to encode the end of video.
+
+ ```python
+ >>> from transformers import Glm4vMoeForConditionalGeneration, Glm4vMoeConfig
+
+ >>> # Initializing a GLM-4.5V style configuration
+ >>> configuration = Glm4vMoeConfig()
+
+ >>> # Initializing a model from the GLM-4.5V style configuration
+ >>> model = Glm4vMoeForConditionalGeneration(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "glm4v_moe"
+ sub_configs = {"vision_config": Glm4vMoeVisionConfig, "text_config": Glm4vMoeTextConfig}
+ keys_to_ignore_at_inference = ["past_key_values"]
+
+ def __init__(
+ self,
+ text_config=None,
+ vision_config=None,
+ image_token_id=151363,
+ video_token_id=151364,
+ image_start_token_id=151339,
+ image_end_token_id=151340,
+ video_start_token_id=151341,
+ video_end_token_id=151342,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
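+ # Sub-configs may be passed as plain dicts; promote them to their config classes, and fall back to
+ # defaults when they are omitted (remaining kwargs are forwarded to the text config for backward compatibility).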
+ if isinstance(vision_config, dict):
+ self.vision_config = self.sub_configs["vision_config"](**vision_config)
+ elif vision_config is None:
+ self.vision_config = self.sub_configs["vision_config"]()
+
+ if isinstance(text_config, dict):
+ self.text_config = self.sub_configs["text_config"](**text_config)
+ elif text_config is None:
+ self.text_config = self.sub_configs["text_config"](**kwargs)
+
+ self.image_token_id = image_token_id
+ self.video_token_id = video_token_id
+ self.video_start_token_id = video_start_token_id
+ self.video_end_token_id = video_end_token_id
+ self.image_start_token_id = image_start_token_id
+ self.image_end_token_id = image_end_token_id
+
+
+__all__ = ["Glm4vMoeConfig", "Glm4vMoeTextConfig"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/glm4v_moe/modeling_glm4v_moe.py b/venv/lib/python3.13/site-packages/transformers/models/glm4v_moe/modeling_glm4v_moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..0d476f9569a35858357257c68bae26e6efe22b8d
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/glm4v_moe/modeling_glm4v_moe.py
@@ -0,0 +1,1752 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/glm4v_moe/modular_glm4v_moe.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_glm4v_moe.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2025 The ZhipuAI Inc. team and HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import itertools
+from dataclasses import dataclass
+from typing import Any, Callable, Optional, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn import LayerNorm
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache
+from ...generation import GenerationMixin
+from ...integrations import use_kernel_forward_from_hub
+from ...masking_utils import create_causal_mask
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling
+from ...utils.deprecation import deprecate_kwarg
+from ...utils.generic import check_model_inputs
+from .configuration_glm4v_moe import Glm4vMoeConfig, Glm4vMoeTextConfig, Glm4vMoeVisionConfig
+
+
+@use_kernel_forward_from_hub("RMSNorm")
+class Glm4vMoeRMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ Glm4vMoeRMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
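+ # RMSNorm: scale by 1/sqrt(mean(x^2) + eps); statistics are computed in float32 for stability and the
+ # result is cast back to the input dtype.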
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+ def extra_repr(self):
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+ """
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: float,
+ dropout: float = 0.0,
+ **kwargs: Unpack[TransformersKwargs],
+):
+ key_states = repeat_kv(key, module.num_key_value_groups)
+ value_states = repeat_kv(value, module.num_key_value_groups)
+
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+ if attention_mask is not None:
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
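+# A minimal usage sketch (not part of the model, never called by it): runs
+# eager_attention_forward on random tensors with a namespace stand-in for the attention
+# module; only num_key_value_groups and training are read from it. All sizes are illustrative.
+def _eager_attention_sketch():
+ import types
+ module = types.SimpleNamespace(num_key_value_groups=2, training=False)
+ q = torch.randn(1, 4, 6, 8)  # (batch, num_attention_heads, seq_len, head_dim)
+ k = torch.randn(1, 2, 6, 8)  # (batch, num_key_value_heads, seq_len, head_dim)
+ v = torch.randn(1, 2, 6, 8)
+ out, weights = eager_attention_forward(module, q, k, v, attention_mask=None, scaling=8**-0.5)
+ # The output comes back transposed to (batch, seq_len, num_attention_heads, head_dim).
+ assert out.shape == (1, 6, 4, 8) and weights.shape == (1, 4, 6, 6)
+
+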
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
+ """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/).
+
+ Explanation:
+ Multimodal 3D rotary position embedding is an extension to 1D rotary position embedding. The input embedding
+ sequence contains vision (images / videos) embedding and text embedding or just contains text embedding. For
+ vision embedding part, we apply rotary position embedding on temporal, height and width dimension separately.
+ Here we split the channel dimension to 3 chunks for the temporal, height and width rotary position embedding.
+ For text embedding part, we just apply 1D rotary position embedding. The three rotary position index (temporal,
+ height and width) of text embedding is always the same, so the text embedding rotary position embedding has no
+ difference with modern LLMs.
+
+ Args:
+ q (`torch.Tensor`): The query tensor.
+ k (`torch.Tensor`): The key tensor.
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
+ mrope_section (`list[int]`):
+ Multimodal rope section describing how the channel dimension is split between the temporal, height and width components in the rope calculation.
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+ Returns:
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+ """
+ mrope_section = mrope_section * 2
+ cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
+ unsqueeze_dim
+ )
+ sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
+ unsqueeze_dim
+ )
+
+ # Keep half or full tensor for later concatenation
+ rotary_dim = cos.shape[-1]
+ q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
+ k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
+
+ # Apply rotary embeddings on the first half or full tensor
+ q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin)
+ k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin)
+
+ # Concatenate back to full shape
+ q_embed = torch.cat([q_embed, q_pass], dim=-1)
+ k_embed = torch.cat([k_embed, k_pass], dim=-1)
+
+ return q_embed, k_embed
+
+
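+# A minimal sketch (not part of the model, never called by it) of the multimodal rope helper
+# above: with head_dim = 8, mrope_section must sum to head_dim // 2, and cos / sin carry a
+# leading axis of size 3 for the temporal, height and width grids. cos = 1 and sin = 0 encode
+# a zero rotation, so the inputs pass through unchanged. All sizes are illustrative.
+def _multimodal_rope_identity_sketch():
+ mrope_section = [2, 1, 1]  # channel split for temporal / height / width
+ cos = torch.ones(3, 1, 5, 8)  # (t/h/w, batch, seq_len, head_dim)
+ sin = torch.zeros(3, 1, 5, 8)
+ q = torch.randn(1, 2, 5, 8)  # (batch, num_heads, seq_len, head_dim)
+ k = torch.randn(1, 2, 5, 8)
+ q_embed, k_embed = apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section)
+ assert torch.equal(q_embed, q) and torch.equal(k_embed, k)
+
+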
+class Glm4vMoeTextAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config: Glm4vMoeTextConfig, layer_idx: Optional[int] = None):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
+ self.scaling = self.head_dim**-0.5
+ self.attention_dropout = config.attention_dropout
+ self.is_causal = True
+
+ self.q_proj = nn.Linear(
+ config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.k_proj = nn.Linear(
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.v_proj = nn.Linear(
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
+ self.rope_scaling = config.rope_scaling
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
+ attention_mask: Optional[torch.Tensor],
+ past_key_values: Optional[Cache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
+ input_shape = hidden_states.shape[:-1]
+ hidden_shape = (*input_shape, -1, self.head_dim)
+
+ query_states = self.q_proj(hidden_states).view(hidden_shape)
+ key_states = self.k_proj(hidden_states).view(hidden_shape)
+ value_states = self.v_proj(hidden_states).view(hidden_shape)
+
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ cos, sin = position_embeddings
+ query_states, key_states = apply_multimodal_rotary_pos_emb( # diff with Llama
+ query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
+ )
+
+ if past_key_values is not None:
+ # sin and cos are specific to RoPE models; position_ids needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ scaling=self.scaling,
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+ return attn_output, attn_weights
+
+
+class Glm4vMoeTextTopkRouter(nn.Module):
+ def __init__(self, config: Glm4vMoeTextConfig):
+ super().__init__()
+ self.config = config
+ self.top_k = config.num_experts_per_tok
+ self.n_routed_experts = config.n_routed_experts
+ self.routed_scaling_factor = config.routed_scaling_factor
+ self.n_group = config.n_group
+ self.topk_group = config.topk_group
+ self.norm_topk_prob = config.norm_topk_prob
+
+ self.weight = nn.Parameter(torch.empty((self.n_routed_experts, config.hidden_size)))
+ self.register_buffer("e_score_correction_bias", torch.zeros((self.n_routed_experts), dtype=torch.float32))
+
+ @torch.no_grad()
+ def get_topk_indices(self, scores):
+ scores_for_choice = scores.view(-1, self.n_routed_experts) + self.e_score_correction_bias.unsqueeze(0)
+ group_scores = (
+ scores_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group)
+ .topk(2, dim=-1)[0]
+ .sum(dim=-1)
+ )
+ group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1]
+ group_mask = torch.zeros_like(group_scores)
+ group_mask.scatter_(1, group_idx, 1)
+ score_mask = (
+ group_mask.unsqueeze(-1)
+ .expand(-1, self.n_group, self.n_routed_experts // self.n_group)
+ .reshape(-1, self.n_routed_experts)
+ )
+ scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)
+ topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1]
+ return topk_indices
+
+ def forward(self, hidden_states):
+ hidden_states = hidden_states.view(-1, self.config.hidden_size)
+ router_logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32))
+ scores = router_logits.sigmoid()
+ topk_indices = self.get_topk_indices(scores)
+ topk_weights = scores.gather(1, topk_indices)
+ if self.norm_topk_prob:
+ denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20
+ topk_weights /= denominator
+ topk_weights = topk_weights * self.routed_scaling_factor
+ return topk_indices, topk_weights
+
+
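+# A minimal sketch (not part of the model, never called by it) of the group-limited routing
+# above, using a namespace stand-in for the config: each token keeps num_experts_per_tok
+# experts drawn only from the topk_group best-scoring groups, and with norm_topk_prob the
+# kept sigmoid scores are renormalized to sum to one. All sizes and values are illustrative.
+def _topk_router_sketch():
+ import types
+ cfg = types.SimpleNamespace(num_experts_per_tok=2, n_routed_experts=8, n_group=4, topk_group=2)
+ cfg.routed_scaling_factor, cfg.norm_topk_prob, cfg.hidden_size = 1.0, True, 16
+ router = Glm4vMoeTextTopkRouter(cfg)
+ router.weight.data.normal_(mean=0.0, std=0.02)  # the real model initialises this in _init_weights
+ topk_indices, topk_weights = router(torch.randn(3, 16))  # 3 tokens
+ assert topk_indices.shape == (3, 2) and topk_weights.shape == (3, 2)
+ assert torch.allclose(topk_weights.sum(-1), torch.ones(3), atol=1e-5)
+
+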
+class Glm4vMoeTextMoE(nn.Module):
+ """
+ A mixed expert module containing shared experts.
+ """
+
+ def __init__(self, config: Glm4vMoeTextConfig):
+ super().__init__()
+ self.config = config
+ self.experts = nn.ModuleList(
+ [
+ Glm4vMoeTextMLP(config, intermediate_size=config.moe_intermediate_size)
+ for _ in range(config.n_routed_experts)
+ ]
+ )
+ self.gate = Glm4vMoeTextTopkRouter(config)
+ self.shared_experts = Glm4vMoeTextMLP(
+ config=config, intermediate_size=config.moe_intermediate_size * config.n_shared_experts
+ )
+
+ def moe(self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, topk_weights: torch.Tensor):
+ r"""
+ CALL FOR CONTRIBUTION! This loop over the experts is not optimised yet: the expert weights should be
+ fused so that we do not have to iterate over them one by one (DeepSeek has 256 experts, so the loop is costly).
+ """
+ final_hidden_states = torch.zeros_like(hidden_states, dtype=topk_weights.dtype)
+ expert_mask = torch.nn.functional.one_hot(topk_indices, num_classes=len(self.experts))
+ expert_mask = expert_mask.permute(2, 0, 1)
+
+ for expert_idx in range(len(self.experts)):
+ expert = self.experts[expert_idx]
+ mask = expert_mask[expert_idx]
+ token_indices, weight_indices = torch.where(mask)
+
+ if token_indices.numel() > 0:
+ expert_weights = topk_weights[token_indices, weight_indices]
+ expert_input = hidden_states[token_indices]
+ expert_output = expert(expert_input)
+ weighted_output = expert_output * expert_weights.unsqueeze(-1)
+ final_hidden_states.index_add_(0, token_indices, weighted_output)
+
+ # In the original DeepSeek, the outputs of the experts are gathered once we leave this module,
+ # thus the MoE module is itself an IsolatedParallel module,
+ # and all experts are "local", meaning we shard them but do not gather.
+ return final_hidden_states.type(hidden_states.dtype)
+
+ def forward(self, hidden_states):
+ residuals = hidden_states
+ orig_shape = hidden_states.shape
+ topk_indices, topk_weights = self.gate(hidden_states)
+ hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
+ hidden_states = self.moe(hidden_states, topk_indices, topk_weights).view(*orig_shape)
+ hidden_states = hidden_states + self.shared_experts(residuals)
+ return hidden_states
+
+
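+# A minimal sketch (not part of the model, never called by it) of the dispatch pattern used
+# in Glm4vMoeTextMoE.moe above: a one-hot mask over the selected experts tells each expert
+# which tokens it sees, and index_add_ scatters the weighted expert outputs back to token
+# order. The "experts" here are identity functions and all sizes are illustrative.
+def _moe_dispatch_sketch():
+ hidden = torch.randn(4, 6)  # 4 tokens, hidden size 6
+ topk_indices = torch.tensor([[0, 2], [1, 2], [0, 1], [2, 0]])  # 2 of 3 experts per token
+ topk_weights = torch.full((4, 2), 0.5)
+ out = torch.zeros_like(hidden)
+ expert_mask = F.one_hot(topk_indices, num_classes=3).permute(2, 0, 1)  # (num_experts, tokens, top_k)
+ for expert_idx in range(3):
+  token_indices, weight_indices = torch.where(expert_mask[expert_idx])
+  expert_weights = topk_weights[token_indices, weight_indices].unsqueeze(-1)
+  out.index_add_(0, token_indices, hidden[token_indices] * expert_weights)
+ # With identity experts and per-token weights summing to 1, the result equals the input.
+ assert torch.allclose(out, hidden, atol=1e-6)
+
+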
+class Glm4vMoeTextMLP(nn.Module):
+ def __init__(self, config, hidden_size=None, intermediate_size=None):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
+ self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size
+
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+ self.act_fn = ACT2FN[config.hidden_act]
+
+ def forward(self, x):
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+ return down_proj
+
+
+@use_kernel_forward_from_hub("RMSNorm")
+class Glm4vMoeTextRMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ Glm4vMoeTextRMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+ def extra_repr(self):
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+
+class Glm4vMoeTextDecoderLayer(GradientCheckpointingLayer):
+ def __init__(self, config: Glm4vMoeTextConfig, layer_idx: int):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+
+ self.self_attn = Glm4vMoeTextAttention(config=config, layer_idx=layer_idx)
+
+ if layer_idx >= config.first_k_dense_replace:
+ self.mlp = Glm4vMoeTextMoE(config)
+ else:
+ self.mlp = Glm4vMoeTextMLP(config)
+
+ self.input_layernorm = Glm4vMoeTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_attention_layernorm = Glm4vMoeTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> torch.Tensor:
+ residual = hidden_states
+ hidden_states = self.input_layernorm(hidden_states)
+ # Self Attention
+ hidden_states, _ = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+ hidden_states = residual + hidden_states
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+ return hidden_states
+
+
+@auto_docstring
+class Glm4vMoePreTrainedModel(PreTrainedModel):
+ config: Glm4vMoeConfig
+ base_model_prefix = ""
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["Glm4vMoeTextDecoderLayer", "Glm4vMoeVisionBlock"]
+ _skip_keys_device_placement = "past_key_values"
+ _supports_flash_attn = True
+ _supports_sdpa = True
+ _supports_flex_attn = True
+ _can_compile_fullgraph = False
+ _supports_attention_backend = True
+
+ _can_record_outputs = {
+ "hidden_states": Glm4vMoeTextDecoderLayer,
+ "attentions": Glm4vMoeTextAttention,
+ }
+
+ def _init_weights(self, module):
+ super()._init_weights(module)
+ if isinstance(module, Glm4vMoeTextTopkRouter):
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+
+
+class Glm4vMoeisionMlp(nn.Module):
+ def __init__(self, config, bias: bool = False):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+ self.intermediate_size = config.out_hidden_size
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias)
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias)
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=bias)
+ self.act_fn = ACT2FN[config.hidden_act]
+
+ def forward(self, hidden_state):
+ return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
+
+
+class Glm4vMoeVisionPatchEmbed(nn.Module):
+ def __init__(self, config: Glm4vMoeVisionConfig) -> None:
+ super().__init__()
+ self.patch_size = config.patch_size
+ self.temporal_patch_size = config.temporal_patch_size
+ self.in_channels = config.in_channels
+ self.embed_dim = config.hidden_size
+
+ kernel_size = [self.temporal_patch_size, self.patch_size, self.patch_size]
+ self.proj = nn.Conv3d(self.in_channels, self.embed_dim, kernel_size=kernel_size, stride=kernel_size)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ target_dtype = self.proj.weight.dtype
+ hidden_states = hidden_states.view(
+ -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size
+ )
+ hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim)
+ return hidden_states
+
+
+class Glm4vMoeVisionRotaryEmbedding(nn.Module):
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
+
+ def __init__(self, dim: int, theta: float = 10000.0) -> None:
+ super().__init__()
+ inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+ def forward(self, seqlen: int) -> torch.Tensor:
+ seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
+ freqs = torch.outer(seq, self.inv_freq)
+ return freqs
+
+
+class Glm4vMoeVisionPatchMerger(nn.Module):
+ def __init__(self, dim: int, context_dim: int, hidden_act: str, bias: bool = False) -> None:
+ super().__init__()
+ self.proj = nn.Linear(dim, dim, bias=bias)
+ self.post_projection_norm = LayerNorm(dim)
+ self.gate_proj = nn.Linear(dim, context_dim, bias=bias)
+ self.up_proj = nn.Linear(dim, context_dim, bias=bias)
+ self.down_proj = nn.Linear(context_dim, dim, bias=bias)
+ self.act1 = nn.GELU()
+ self.act_fn = ACT2FN[hidden_act]
+
+ def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
+ hidden_state = self.proj(hidden_state)
+ hidden_state = self.act1(self.post_projection_norm(hidden_state))
+ return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
+
+
+class Glm4vMoeVisionEmbeddings(nn.Module):
+ def __init__(self, config: Glm4vMoeVisionConfig):
+ super().__init__()
+ self.config = config
+ self.embed_dim = config.hidden_size
+ self.image_size = config.image_size
+ self.patch_size = config.patch_size
+
+ self.num_patches = (self.image_size // self.patch_size) ** 2
+ self.num_positions = self.num_patches
+ self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
+ self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
+
+ def forward(self, embeddings, lengths, image_shapes, h_coords, w_coords) -> torch.Tensor:
+ """
+ Forward pass with integrated position encoding adaptation using 2D interpolation.
+
+ Args:
+ embeddings: Input embeddings tensor
+ lengths (torch.Tensor): Sequence lengths for each image in the batch.
+ image_shapes (torch.Tensor): Tensor of shape [batch_size, 3] representing the image shapes (t, h, w).
+ h_coords (torch.Tensor): Tensor of shape [total_seq] representing the h coordinate for each patch.
+ w_coords (torch.Tensor): Tensor of shape [total_seq] representing the w coordinate for each patch.
+
+ Returns:
+ torch.Tensor: Embeddings with adapted position encoding added.
+ """
+ # Get position embedding parameters
+ pos_embed_weight = self.position_embedding.weight
+ hidden_size = pos_embed_weight.shape[1]
+ total_seq = h_coords.shape[0]
+ device = pos_embed_weight.device
+
+ # Move coordinates to correct device
+ h_coords, w_coords = h_coords.to(device), w_coords.to(device)
+
+ # Handle empty sequence case
+ if total_seq == 0:
+ adapted_pos_embed = torch.empty(0, hidden_size, device=device, dtype=pos_embed_weight.dtype)
+ else:
+ # Convert inputs to tensors if needed
+ if isinstance(lengths, list):
+ lengths = torch.tensor(lengths, device=device, dtype=torch.long)
+ if not isinstance(image_shapes, torch.Tensor):
+ image_shapes = torch.tensor(image_shapes, device=device, dtype=torch.long)
+
+ # Prepare 2D position embedding
+ orig_size_sq = pos_embed_weight.shape[0]
+ orig_size = int(orig_size_sq**0.5)
+ pos_embed_2d = (
+ pos_embed_weight.view(orig_size, orig_size, hidden_size)
+ .permute(2, 0, 1)
+ .unsqueeze(0)
+ .to(device=device, dtype=torch.float32)
+ )
+
+ # Calculate target dimensions for each patch
+ target_h = torch.cat([image_shapes[i, 1].repeat(lengths[i]) for i in range(len(lengths))]).to(
+ device=device, dtype=torch.float32
+ )
+ target_w = torch.cat([image_shapes[i, 2].repeat(lengths[i]) for i in range(len(lengths))]).to(
+ device=device, dtype=torch.float32
+ )
+
+ # Normalize coordinates to [-1, 1] range for grid_sample
+ h_coords = h_coords.to(device=device, dtype=torch.float32)
+ w_coords = w_coords.to(device=device, dtype=torch.float32)
+ norm_w = ((w_coords + 0.5) / target_w) * 2 - 1
+ norm_h = ((h_coords + 0.5) / target_h) * 2 - 1
+
+ # Create sampling grid
+ grid = torch.stack((norm_w, norm_h), dim=-1).unsqueeze(0).unsqueeze(2)
+
+ # Perform bicubic interpolation
+ interpolated_embed_fp32 = F.grid_sample(
+ pos_embed_2d, grid, mode="bicubic", align_corners=False, padding_mode="border"
+ )
+
+ # Reshape and convert back to original dtype
+ adapted_pos_embed_fp32 = interpolated_embed_fp32.squeeze(0).squeeze(-1).permute(1, 0)
+ adapted_pos_embed = adapted_pos_embed_fp32.to(pos_embed_weight.dtype).to(embeddings.device)
+
+ # Add adapted position encoding to embeddings
+ embeddings = embeddings + adapted_pos_embed
+ return embeddings
+
+
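+# A minimal sketch (not part of the model, never called by it) of the position-encoding
+# adaptation above: patch coordinates are normalised to [-1, 1] and used to bicubically
+# resample a square grid of learned position embeddings with F.grid_sample, so a 2 x 2
+# target grid still gets one position vector per patch. All sizes are illustrative.
+def _position_interpolation_sketch():
+ orig_size, hidden_size = 4, 8
+ pos_embed_2d = torch.randn(1, hidden_size, orig_size, orig_size)  # (1, C, H, W) grid
+ h_coords = torch.tensor([0.0, 0.0, 1.0, 1.0])  # patch rows in a 2 x 2 target grid
+ w_coords = torch.tensor([0.0, 1.0, 0.0, 1.0])  # patch columns
+ norm_h = ((h_coords + 0.5) / 2) * 2 - 1  # the target grid height is 2
+ norm_w = ((w_coords + 0.5) / 2) * 2 - 1
+ grid = torch.stack((norm_w, norm_h), dim=-1).unsqueeze(0).unsqueeze(2)  # (1, total_seq, 1, 2)
+ sampled = F.grid_sample(pos_embed_2d, grid, mode="bicubic", align_corners=False, padding_mode="border")
+ assert sampled.squeeze(0).squeeze(-1).permute(1, 0).shape == (4, hidden_size)
+
+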
+def apply_rotary_pos_emb_vision(
+ q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
+) -> tuple[torch.Tensor, torch.Tensor]:
+ orig_q_dtype = q.dtype
+ orig_k_dtype = k.dtype
+ q, k = q.float(), k.float()
+ cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float()
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ q_embed = q_embed.to(orig_q_dtype)
+ k_embed = k_embed.to(orig_k_dtype)
+ return q_embed, k_embed
+
+
+class Glm4vMoeVisionAttention(nn.Module):
+ def __init__(self, config: Glm4vMoeVisionConfig) -> None:
+ super().__init__()
+ self.dim = config.hidden_size
+ self.num_heads = config.num_heads
+ self.head_dim = self.dim // self.num_heads
+ self.num_key_value_groups = 1 # needed for eager attention
+ self.qkv = nn.Linear(config.hidden_size, config.hidden_size * 3, bias=config.attention_bias)
+ self.proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
+ self.scaling = self.head_dim**-0.5
+ self.config = config
+ self.attention_dropout = config.attention_dropout
+ self.is_causal = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ cu_seqlens: torch.Tensor,
+ rotary_pos_emb: Optional[torch.Tensor] = None,
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ seq_length = hidden_states.shape[0]
+ query_states, key_states, value_states = (
+ self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
+ )
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin)
+
+ query_states = query_states.transpose(0, 1).unsqueeze(0)
+ key_states = key_states.transpose(0, 1).unsqueeze(0)
+ value_states = value_states.transpose(0, 1).unsqueeze(0)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ if self.config._attn_implementation == "flash_attention_2":
+ # Flash Attention 2: Use cu_seqlens for variable length attention
+ max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
+ attn_output, _ = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask=None,
+ scaling=self.scaling,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ cu_seq_lens_q=cu_seqlens,
+ cu_seq_lens_k=cu_seqlens,
+ max_length_q=max_seqlen,
+ max_length_k=max_seqlen,
+ is_causal=False,
+ **kwargs,
+ )
+ else:
+ # Other implementations: Process each chunk separately
+ lengths = cu_seqlens[1:] - cu_seqlens[:-1]
+ splits = [
+ torch.split(tensor, lengths.tolist(), dim=2) for tensor in (query_states, key_states, value_states)
+ ]
+
+ attn_outputs = [
+ attention_interface(
+ self,
+ q,
+ k,
+ v,
+ attention_mask=None,
+ scaling=self.scaling,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ is_causal=False,
+ **kwargs,
+ )[0]
+ for q, k, v in zip(*splits)
+ ]
+ attn_output = torch.cat(attn_outputs, dim=1)
+
+ attn_output = attn_output.reshape(seq_length, -1).contiguous()
+ attn_output = self.proj(attn_output)
+ return attn_output
+
+
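+# A minimal sketch (not part of the model, never called by it) of the cu_seqlens bookkeeping
+# above: the vision tower packs several images into one long token sequence, and the
+# cumulative lengths mark where each image starts and ends, so non-flash backends can attend
+# within each chunk separately. All values are illustrative.
+def _cu_seqlens_split_sketch():
+ cu_seqlens = torch.tensor([0, 3, 7, 12], dtype=torch.int32)  # three packed images of 3, 4 and 5 tokens
+ packed = torch.randn(1, 2, 12, 8)  # (batch=1, num_heads, total_tokens, head_dim)
+ lengths = cu_seqlens[1:] - cu_seqlens[:-1]
+ chunks = torch.split(packed, lengths.tolist(), dim=2)
+ assert [chunk.shape[2] for chunk in chunks] == [3, 4, 5]
+
+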
+class Glm4vMoeVisionBlock(GradientCheckpointingLayer):
+ def __init__(self, config) -> None:
+ super().__init__()
+ self.norm1 = Glm4vMoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.norm2 = Glm4vMoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.attn = Glm4vMoeVisionAttention(config)
+ self.mlp = Glm4vMoeisionMlp(config, bias=False)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ cu_seqlens: torch.Tensor,
+ rotary_pos_emb: Optional[torch.Tensor] = None,
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ hidden_states = hidden_states + self.attn(
+ self.norm1(hidden_states),
+ cu_seqlens=cu_seqlens,
+ rotary_pos_emb=rotary_pos_emb,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+ hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
+ return hidden_states
+
+
+class Glm4vMoeTextRotaryEmbedding(nn.Module):
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
+
+ def __init__(self, config: Glm4vMoeTextConfig, device=None):
+ super().__init__()
+ # BC: "rope_type" was originally "type"
+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+ else:
+ self.rope_type = "default"
+ self.max_seq_len_cached = config.max_position_embeddings
+ self.original_max_seq_len = config.max_position_embeddings
+
+ self.config = config
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+ self.original_inv_freq = self.inv_freq
+
+ @torch.no_grad()
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
+ def forward(self, x, position_ids):
+ # In contrast to other models, Glm4vMoeText has different position ids for the grids
+ # So we expand the inv_freq to shape (3, ...)
+ inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)
+ position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions)
+
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos = emb.cos() * self.attention_scaling
+ sin = emb.sin() * self.attention_scaling
+
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for Glm4vMoe model outputs, with hidden states and attentions.
+ """
+)
+class Glm4vMoeModelOutputWithPast(ModelOutput):
+ r"""
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
+ The rope index difference between sequence length and multimodal rope.
+ """
+
+ last_hidden_state: Optional[torch.FloatTensor] = None
+ past_key_values: Optional[Cache] = None
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
+ attentions: Optional[tuple[torch.FloatTensor]] = None
+ rope_deltas: Optional[torch.LongTensor] = None
+
+
+class Glm4vMoeVisionModel(Glm4vMoePreTrainedModel):
+ config: Glm4vMoeVisionConfig
+ _no_split_modules = ["Glm4vMoeVisionBlock"]
+
+ def __init__(self, config) -> None:
+ super().__init__(config)
+ self.spatial_merge_size = config.spatial_merge_size
+ self.patch_size = config.patch_size
+
+ self.embeddings = Glm4vMoeVisionEmbeddings(config)
+ self.patch_embed = Glm4vMoeVisionPatchEmbed(config)
+
+ head_dim = config.hidden_size // config.num_heads
+ self.rotary_pos_emb = Glm4vMoeVisionRotaryEmbedding(head_dim // 2)
+
+ self.blocks = nn.ModuleList([Glm4vMoeVisionBlock(config) for _ in range(config.depth)])
+ self.merger = Glm4vMoeVisionPatchMerger(
+ dim=config.out_hidden_size, context_dim=config.intermediate_size, hidden_act=config.hidden_act
+ )
+
+ self.post_conv_layernorm = Glm4vMoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.downsample = nn.Conv2d(
+ in_channels=config.hidden_size,
+ out_channels=config.out_hidden_size,
+ kernel_size=config.spatial_merge_size,
+ stride=config.spatial_merge_size,
+ )
+ self.post_layernorm = Glm4vMoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ self.gradient_checkpointing = False
+ self.post_init()
+
+ def rot_pos_emb(self, grid_thw):
+ pos_ids = []
+ for t, h, w in grid_thw:
+ hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
+ hpos_ids = hpos_ids.reshape(
+ h // self.spatial_merge_size,
+ self.spatial_merge_size,
+ w // self.spatial_merge_size,
+ self.spatial_merge_size,
+ )
+ hpos_ids = hpos_ids.permute(0, 2, 1, 3)
+ hpos_ids = hpos_ids.flatten()
+
+ wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
+ wpos_ids = wpos_ids.reshape(
+ h // self.spatial_merge_size,
+ self.spatial_merge_size,
+ w // self.spatial_merge_size,
+ self.spatial_merge_size,
+ )
+ wpos_ids = wpos_ids.permute(0, 2, 1, 3)
+ wpos_ids = wpos_ids.flatten()
+ pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
+ pos_ids = torch.cat(pos_ids, dim=0)
+ max_grid_size = grid_thw[:, 1:].max()
+ rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
+ rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
+ return rotary_pos_emb, pos_ids
+
+ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ hidden_states (`torch.Tensor` of shape `(seq_len, in_channels * temporal_patch_size * patch_size * patch_size)`):
+ The flattened image patches to be embedded by the vision patch embedding.
+ grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`):
+ The temporal, height and width of feature shape of each image in LLM.
+
+ Returns:
+ `torch.Tensor`: The merged vision features of shape `(seq_len // spatial_merge_size**2, out_hidden_size)`.
+ """
+ hidden_states = self.patch_embed(hidden_states)
+ hidden_states = self.post_conv_layernorm(hidden_states)
+
+ rotary_pos_emb, image_type_ids = self.rot_pos_emb(grid_thw)
+ emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+ position_embeddings = (emb.cos(), emb.sin())
+
+ cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
+ dim=0,
+ # Select dtype based on the following factors:
+ # - FA2 requires that cu_seqlens_q must have dtype int32
+ # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw
+ # See https://github.com/huggingface/transformers/pull/34852 for more information
+ dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
+ )
+ cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
+ seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
+ hidden_states = self.embeddings(hidden_states, seqlens, grid_thw, image_type_ids[:, 0], image_type_ids[:, 1])
+
+ for blk in self.blocks:
+ hidden_states = blk(
+ hidden_states,
+ cu_seqlens=cu_seqlens,
+ position_embeddings=position_embeddings,
+ )
+
+ hidden_states = self.post_layernorm(hidden_states)
+
+ hidden_states = hidden_states.view(
+ -1, self.spatial_merge_size, self.spatial_merge_size, hidden_states.shape[-1]
+ )
+ hidden_states = hidden_states.permute(0, 3, 1, 2)
+ hidden_states = self.downsample(hidden_states).view(-1, self.config.out_hidden_size)
+
+ hidden_states = self.merger(hidden_states)
+ return hidden_states
+
+
+@auto_docstring
+class Glm4vMoeTextModel(Glm4vMoePreTrainedModel):
+ config: Glm4vMoeTextConfig
+
+ def __init__(self, config: Glm4vMoeTextConfig):
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+ self.layers = nn.ModuleList(
+ [Glm4vMoeTextDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+ self.norm = Glm4vMoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.rotary_emb = Glm4vMoeTextRotaryEmbedding(config=config)
+
+ self.gradient_checkpointing = False
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ @check_model_inputs()
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> Union[tuple, BaseModelOutputWithPast]:
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ # torch.jit.trace() doesn't support cache objects in the output
+ if use_cache and past_key_values is None and not torch.jit.is_tracing():
+ past_key_values = DynamicCache(config=self.config)
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position = torch.arange(
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+ )
+
+ # the hard coded `3` is for temporal, height and width.
+ if position_ids is None:
+ position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1)
+ elif position_ids.dim() == 2:
+ position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
+
+ causal_mask = create_causal_mask(
+ config=self.config,
+ input_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ cache_position=cache_position,
+ past_key_values=past_key_values,
+ position_ids=position_ids,
+ )
+
+ hidden_states = inputs_embeds
+
+ # create position embeddings to be shared across the decoder layers
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+ for decoder_layer in self.layers:
+ layer_outputs = decoder_layer(
+ hidden_states,
+ position_embeddings=position_embeddings,
+ attention_mask=causal_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ cache_position=cache_position,
+ **kwargs,
+ )
+ hidden_states = layer_outputs
+
+ hidden_states = self.norm(hidden_states)
+
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values,
+ )
+
+
+@auto_docstring
+class Glm4vMoeModel(Glm4vMoePreTrainedModel):
+ base_model_prefix = ""
+ _checkpoint_conversion_mapping = {}
+ # Reference: fix gemma3 grad acc #37208
+ accepts_loss_kwargs = False
+ config: Glm4vMoeConfig
+ _no_split_modules = ["Glm4vMoeTextDecoderLayer", "Glm4vMoeVisionBlock"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.visual = Glm4vMoeVisionModel._from_config(config.vision_config)
+ self.language_model = Glm4vMoeTextModel._from_config(config.text_config)
+ self.rope_deltas = None # cache rope_deltas here
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.language_model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.language_model.set_input_embeddings(value)
+
+ def set_decoder(self, decoder):
+ self.language_model = decoder
+
+ def get_decoder(self):
+ return self.language_model
+
+ def get_rope_index(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ image_grid_thw: Optional[torch.LongTensor] = None,
+ video_grid_thw: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Calculate the 3D rope index based on image and video's temporal, height and width in LLM.
+
+ Explanation:
+ Each embedding sequence contains vision embedding and text embedding or just contains text embedding.
+
+ For pure text embedding sequence, the rotary position embedding has no difference with modern LLMs.
+ Examples:
+ input_ids: [T T T T T], here T is for text.
+ temporal position_ids: [0, 1, 2, 3, 4]
+ height position_ids: [0, 1, 2, 3, 4]
+ width position_ids: [0, 1, 2, 3, 4]
+
+ For vision and text embedding sequence, we calculate 3D rotary position embedding for vision part
+ and 1D rotary position embedding for text part.
+ Examples:
+ Temporal (Time): 3 patches, representing different segments of the video in time.
+ Height: 2 patches, dividing each frame vertically.
+ Width: 2 patches, dividing each frame horizontally.
+ We also have some important parameters:
+ fps (Frames Per Second): The video's frame rate, set to 1. This means one frame is processed each second.
+ tokens_per_second: This is a crucial parameter. It dictates how many "time-steps" or "temporal tokens" are conceptually packed into a one-second interval of the video. In this case, we have 25 tokens per second. So each second of the video will be represented with 25 separate time points. It essentially defines the temporal granularity.
+ temporal_patch_size: The number of frames that compose one temporal patch. Here, it's 2 frames.
+ interval: The step size for the temporal position IDs, calculated as tokens_per_second * temporal_patch_size / fps. In this case, 25 * 2 / 1 = 50. This means that each temporal patch will have a difference of 50 in the temporal position IDs.
+ input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision.
+ vision temporal position_ids: [0, 0, 0, 0, 50, 50, 50, 50, 100, 100, 100, 100]
+ vision height position_ids: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1]
+ vision width position_ids: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1]
+ text temporal position_ids: [101, 102, 103, 104, 105]
+ text height position_ids: [101, 102, 103, 104, 105]
+ text width position_ids: [101, 102, 103, 104, 105]
+ Here we calculate the text start position_ids as the max vision position_ids plus 1.
+
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+ it.
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+ The temporal, height and width of feature shape of each image in LLM.
+ video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
+ The temporal, height and width of feature shape of each video in LLM.
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ Returns:
+ position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`)
+ mrope_position_deltas (`torch.Tensor` of shape `(batch_size)`)
+ """
+
+ spatial_merge_size = self.config.vision_config.spatial_merge_size
+ image_token_id = self.config.image_token_id
+ video_start_token_id = self.config.video_start_token_id
+ video_end_token_id = self.config.video_end_token_id
+
+ mrope_position_deltas = []
+ if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
+ total_input_ids = input_ids
+ if attention_mask is None:
+ attention_mask = torch.ones_like(total_input_ids)
+ position_ids = torch.ones(
+ 3,
+ input_ids.shape[0],
+ input_ids.shape[1],
+ dtype=input_ids.dtype,
+ device=input_ids.device,
+ )
+ image_index, video_index = 0, 0
+ video_group_index = 0
+ attention_mask = attention_mask.to(total_input_ids.device)
+ for i, input_ids in enumerate(total_input_ids):
+ input_ids = input_ids[attention_mask[i] == 1]
+ input_tokens = input_ids.tolist()
+
+ input_token_type = []
+ video_check_flg = False
+ for token in input_tokens:
+ if token == video_start_token_id:
+ video_check_flg = True
+ elif token == video_end_token_id:
+ video_check_flg = False
+
+ if token == image_token_id and not video_check_flg:
+ input_token_type.append("image")
+ elif token == image_token_id and video_check_flg:
+ input_token_type.append("video")
+ else:
+ input_token_type.append("text")
+
+ input_type_group = []
+ for key, group in itertools.groupby(enumerate(input_token_type), lambda x: x[1]):
+ group = list(group)
+ start_index = group[0][0]
+ end_index = group[-1][0] + 1
+ input_type_group.append((key, start_index, end_index))
+
+ llm_pos_ids_list = []
+ video_frame_num = 1
+ for modality_type, start_idx, end_idx in input_type_group:
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
+
+ if modality_type == "image":
+ t, h, w = (
+ image_grid_thw[image_index][0],
+ image_grid_thw[image_index][1],
+ image_grid_thw[image_index][2],
+ )
+ llm_grid_t, llm_grid_h, llm_grid_w = (
+ t.item(),
+ h.item() // spatial_merge_size,
+ w.item() // spatial_merge_size,
+ )
+
+ t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()
+ h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
+ w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
+ llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + st_idx)
+
+ image_index += 1
+ video_frame_num = 1
+
+ elif modality_type == "video":
+ t, h, w = (
+ video_frame_num,
+ video_grid_thw[video_index][1],
+ video_grid_thw[video_index][2],
+ )
+
+ llm_grid_t, llm_grid_h, llm_grid_w = (
+ t,
+ h.item() // spatial_merge_size,
+ w.item() // spatial_merge_size,
+ )
+
+ for t_idx in range(llm_grid_t):
+ t_index = torch.tensor(t_idx).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()
+
+ h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(1, -1, llm_grid_w).flatten()
+ w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(1, llm_grid_h, -1).flatten()
+ llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + st_idx)
+
+ video_group_index += 1
+
+ if video_group_index >= video_grid_thw[video_index][0]:
+ video_index += 1
+ video_group_index = 0
+
+ video_frame_num += 1
+
+ else:
+ text_len = end_idx - start_idx
+ llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
+
+ video_frame_num = 1
+
+ llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
+ position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)
+ mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i]))
+ mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1)
+ return position_ids, mrope_position_deltas
+ else:
+ if attention_mask is not None:
+ position_ids = attention_mask.long().cumsum(-1) - 1
+ position_ids.masked_fill_(attention_mask == 0, 1)
+ position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
+ max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
+ mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
+ else:
+ position_ids = (
+ torch.arange(input_ids.shape[1], device=input_ids.device)
+ .view(1, 1, -1)
+ .expand(3, input_ids.shape[0], -1)
+ )
+ mrope_position_deltas = torch.zeros(
+ [input_ids.shape[0], 1],
+ device=input_ids.device,
+ dtype=input_ids.dtype,
+ )
+
+ return position_ids, mrope_position_deltas
+
+ def get_video_features(
+ self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None
+ ):
+ """
+ Encodes videos into continuous embeddings that can be forwarded to the language model.
+
+ Args:
+ pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
+ The tensors corresponding to the input videos.
+ video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
+ The temporal, height and width of feature shape of each video in LLM.
+ """
+ pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
+ # reshape video_grid_thw -> [b, 3] -> [1, h, w] * frames
+ temp_frames_hw = []
+ for t, h, w in video_grid_thw:
+ repeated_row = torch.tensor([1, h.item(), w.item()]).unsqueeze(0).repeat(t, 1)
+ temp_frames_hw.append(repeated_row)
+ flattened_video_grid_thw = torch.cat(temp_frames_hw, dim=0)
+ video_embeds = self.visual(pixel_values_videos, grid_thw=flattened_video_grid_thw)
+ split_sizes = (video_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist()
+ video_embeds = torch.split(video_embeds, split_sizes)
+ return video_embeds
+
+ def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None):
+ """
+ Encodes images into continuous embeddings that can be forwarded to the language model.
+
+ Args:
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
+ The tensors corresponding to the input images.
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+ The temporal, height and width of feature shape of each image in LLM.
+ """
+ pixel_values = pixel_values.type(self.visual.dtype)
+ image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
+ split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist()
+ image_embeds = torch.split(image_embeds, split_sizes)
+ return image_embeds
+
+ def get_placeholder_mask(
+ self,
+ input_ids: torch.LongTensor,
+ inputs_embeds: torch.FloatTensor,
+ image_features: Optional[torch.FloatTensor] = None,
+ video_features: Optional[torch.FloatTensor] = None,
+ ):
+ """
+ Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
+ equal to the length of multimodal features. If the lengths are different, an error is raised.
+ """
+ if input_ids is None:
+ special_image_mask = inputs_embeds == self.get_input_embeddings()(
+ torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ special_image_mask = special_image_mask.all(-1)
+ special_video_mask = inputs_embeds == self.get_input_embeddings()(
+ torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ special_video_mask = special_video_mask.all(-1)
+ else:
+ # In GLM-4.1V and GLM-4.5V, the special_video_mask is the same as the special_image_mask
+ special_image_mask = input_ids == self.config.image_token_id
+ special_video_mask = input_ids == self.config.image_token_id
+
+ n_image_tokens = special_image_mask.sum()
+ special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+ if image_features is not None and inputs_embeds[special_image_mask].numel() != image_features.numel():
+ raise ValueError(
+ f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {image_features.shape[0]}"
+ )
+
+ n_video_tokens = special_video_mask.sum()
+ special_video_mask = special_video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+ if video_features is not None and inputs_embeds[special_video_mask].numel() != video_features.numel():
+ raise ValueError(
+ f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {video_features.shape[0]}"
+ )
+
+ return special_image_mask, special_video_mask
+
+ @auto_docstring
+ @can_return_tuple
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ pixel_values: Optional[torch.Tensor] = None,
+ pixel_values_videos: Optional[torch.FloatTensor] = None,
+ image_grid_thw: Optional[torch.LongTensor] = None,
+ video_grid_thw: Optional[torch.LongTensor] = None,
+ rope_deltas: Optional[torch.LongTensor] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> Union[tuple, Glm4vMoeModelOutputWithPast]:
+ r"""
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+ The temporal, height and width of feature shape of each image in LLM.
+ video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
+ The temporal, height and width of feature shape of each video in LLM.
+ rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
+ The rope index difference between sequence length and multimodal rope.
+ """
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if inputs_embeds is None:
+ inputs_embeds = self.get_input_embeddings()(input_ids)
+
+ if pixel_values is not None:
+ image_embeds = self.get_image_features(pixel_values, image_grid_thw)
+ image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
+ image_mask, _ = self.get_placeholder_mask(input_ids, inputs_embeds, image_features=image_embeds)
+ inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
+
+ if pixel_values_videos is not None:
+ video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)
+ video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
+ _, video_mask = self.get_placeholder_mask(input_ids, inputs_embeds, video_features=video_embeds)
+ inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
+
+ if position_ids is None:
+ attention_mask_tensor = (
+ attention_mask if not isinstance(attention_mask, dict) else attention_mask["full_attention"]
+ )
+ if attention_mask_tensor is not None and attention_mask_tensor.ndim == 4:
+ attention_mask_tensor = torch.diagonal(attention_mask_tensor[:, 0], dim1=1, dim2=2)
+ # Only apply conversion for floating point tensors (inverted masks)
+ if attention_mask_tensor.dtype.is_floating_point:
+ attention_mask_tensor = attention_mask_tensor / torch.finfo(attention_mask_tensor.dtype).min
+ attention_mask_tensor = (1.0 - attention_mask_tensor).int()
+
+ # Calculate RoPE index once per generation in the pre-fill stage only.
+ # When compiling, we can't check tensor values, so we only check the input length.
+ # It is safe to assume that `length != 1` means we're in pre-fill because compiled
+ # models currently cannot do assisted decoding.
+ prefill_compiled_stage = is_torchdynamo_compiling() and (
+ (input_ids is not None and input_ids.shape[1] != 1)
+ or (inputs_embeds is not None and inputs_embeds.shape[1] != 1)
+ )
+ prefill_noncompiled_stage = not is_torchdynamo_compiling() and (
+ (cache_position is not None and cache_position[0] == 0)
+ or (past_key_values is None or past_key_values.get_seq_length() == 0)
+ )
+ if (prefill_compiled_stage or prefill_noncompiled_stage) or self.rope_deltas is None:
+ position_ids, rope_deltas = self.get_rope_index(
+ input_ids,
+ image_grid_thw,
+ video_grid_thw,
+ attention_mask=attention_mask_tensor,
+ )
+ self.rope_deltas = rope_deltas
+ # then use the prev pre-calculated rope-deltas to get the correct position ids
+ else:
+ batch_size, seq_length, _ = inputs_embeds.shape
+ delta = (
+ (cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
+ if cache_position is not None
+ else 0
+ )
+ position_ids = torch.arange(seq_length, device=inputs_embeds.device)
+ position_ids = position_ids.view(1, -1).expand(batch_size, -1)
+ if cache_position is not None: # otherwise `deltas` is an int `0`
+ delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
+ position_ids = position_ids.add(delta)
+ position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
+
+ outputs = self.language_model(
+ input_ids=None,
+ position_ids=position_ids,
+ attention_mask=attention_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ return Glm4vMoeModelOutputWithPast(
+ last_hidden_state=outputs.last_hidden_state,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ rope_deltas=self.rope_deltas,
+ )
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for Glm4vMoe causal language model (or autoregressive) outputs.
+ """
+)
+class Glm4vMoeCausalLMOutputWithPast(ModelOutput):
+ r"""
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Language modeling loss (for next-token prediction).
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
+ The rope index difference between sequence length and multimodal rope.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: Optional[torch.FloatTensor] = None
+ past_key_values: Optional[Cache] = None
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
+ attentions: Optional[tuple[torch.FloatTensor]] = None
+ rope_deltas: Optional[torch.LongTensor] = None
+
+
+class Glm4vMoeForConditionalGeneration(Glm4vMoePreTrainedModel, GenerationMixin):
+ _checkpoint_conversion_mapping = {}
+ _tied_weights_keys = ["lm_head.weight"]
+ # Reference: fix gemma3 grad acc #37208
+ accepts_loss_kwargs = False
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = Glm4vMoeModel(config)
+ self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
+
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.model.set_input_embeddings(value)
+
+ def set_decoder(self, decoder):
+ self.model.set_decoder(decoder)
+
+ def get_decoder(self):
+ return self.model.get_decoder()
+
+ def get_video_features(
+ self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None
+ ):
+ return self.model.get_video_features(pixel_values_videos, video_grid_thw)
+
+ def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None):
+ return self.model.get_image_features(pixel_values, image_grid_thw)
+
+ # Make modules available through conditional class for BC
+ @property
+ def language_model(self):
+ return self.model.language_model
+
+ @property
+ def visual(self):
+ return self.model.visual
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ pixel_values: Optional[torch.Tensor] = None,
+ pixel_values_videos: Optional[torch.FloatTensor] = None,
+ image_grid_thw: Optional[torch.LongTensor] = None,
+ video_grid_thw: Optional[torch.LongTensor] = None,
+ rope_deltas: Optional[torch.LongTensor] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> Union[tuple, Glm4vMoeCausalLMOutputWithPast]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+ The temporal, height and width of feature shape of each image in LLM.
+ video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
+ The temporal, height and width of feature shape of each video in LLM.
+ rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
+ The rope index difference between sequence length and multimodal rope.
+
+ Example:
+
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import AutoProcessor, Glm4vMoeForConditionalGeneration
+
+ >>> model = Glm4vMoeForConditionalGeneration.from_pretrained("zai-org/GLM-4.5V")
+ >>> processor = AutoProcessor.from_pretrained("zai-org/GLM-4.5V")
+
+ >>> messages = [
+ {
+ "role": "user",
+ "content": [
+ {"type": "image"},
+ {"type": "text", "text": "What is shown in this image?"},
+ ],
+ },
+ ]
+ >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+ >>> inputs = processor(text=[text], images=[image], return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(**inputs, max_length=30)
+ >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
+ ```"""
+ outputs = self.model(
+ input_ids=input_ids,
+ pixel_values=pixel_values,
+ pixel_values_videos=pixel_values_videos,
+ image_grid_thw=image_grid_thw,
+ video_grid_thw=video_grid_thw,
+ position_ids=position_ids,
+ attention_mask=attention_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ hidden_states = outputs[0]
+
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
+
+ return Glm4vMoeCausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ rope_deltas=outputs.rope_deltas,
+ )
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids,
+ past_key_values=None,
+ attention_mask=None,
+ inputs_embeds=None,
+ cache_position=None,
+ position_ids=None,
+ use_cache=True,
+ pixel_values=None,
+ pixel_values_videos=None,
+ image_grid_thw=None,
+ video_grid_thw=None,
+ **kwargs,
+ ):
+ # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
+
+ model_inputs = super().prepare_inputs_for_generation(
+ input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ inputs_embeds=inputs_embeds,
+ cache_position=cache_position,
+ position_ids=position_ids,
+ pixel_values=pixel_values,
+ pixel_values_videos=pixel_values_videos,
+ image_grid_thw=image_grid_thw,
+ video_grid_thw=video_grid_thw,
+ use_cache=use_cache,
+ **kwargs,
+ )
+
+ # GLM-4.1V position_ids are prepared with rope_deltas in forward
+ model_inputs["position_ids"] = None
+
+ if cache_position[0] != 0:
+ model_inputs["pixel_values"] = None
+ model_inputs["pixel_values_videos"] = None
+
+ return model_inputs
+
+ def _get_image_nums_and_video_nums(
+ self,
+ input_ids: Optional[torch.LongTensor],
+ inputs_embeds: Optional[torch.Tensor] = None,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Get the number of images and videos for each sample, which is needed to split the flattened visual
+ tensors back into per-sample chunks. These counts are computed here rather than passed through the
+ processor so that changes to the processor interface cannot silently break generation (an
+ illustrative sketch of this counting appears after this class definition).
+
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary.
+
+ Returns:
+ image_nums (`torch.LongTensor` of shape `(batch_size,)`): number of standalone images per sample
+ video_nums (`torch.LongTensor` of shape `(batch_size,)`): number of videos per sample
+ """
+
+ if inputs_embeds is not None:
+ is_image = (
+ inputs_embeds
+ == self.get_input_embeddings()(
+ torch.tensor(self.config.image_start_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ )[..., 0]
+ is_video_start = (
+ inputs_embeds
+ == self.get_input_embeddings()(
+ torch.tensor(self.config.video_start_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ )[..., 0]
+ is_video_end = (
+ inputs_embeds
+ == self.get_input_embeddings()(
+ torch.tensor(self.config.video_end_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ )[..., 0]
+ else:
+ is_image = input_ids == self.config.image_start_token_id
+ is_video_start = input_ids == self.config.video_start_token_id
+ is_video_end = input_ids == self.config.video_end_token_id
+
+ # Cumulative sum to track if we're inside a video span
+ # We'll assume well-formed video tags (i.e. matching starts and ends)
+ video_level = torch.cumsum(is_video_start.int() - is_video_end.int(), dim=1)
+ inside_video = video_level > 0 # shape (batch_size, seq_length)
+
+ # Mask out image tokens that are inside video spans
+ standalone_images = is_image & (~inside_video)
+
+ # Count per batch
+ image_counts = standalone_images.sum(dim=1)
+ video_counts = is_video_start.sum(dim=1)
+
+ return image_counts, video_counts
+
+ def _expand_inputs_for_generation(
+ self,
+ expand_size: int = 1,
+ is_encoder_decoder: bool = False,
+ input_ids: Optional[torch.LongTensor] = None,
+ **model_kwargs,
+ ) -> tuple[torch.LongTensor, dict[str, Any]]:
+ # Overwritten -- Support for expanding tensors without a batch size dimension
+ # e.g., pixel_values, image_grid_thw, pixel_values_videos, video_grid_thw, second_per_grid_t
+ # pixel_values.shape[0] is sum(seqlen_images for samples)
+ # image_grid_thw.shape[0] is sum(num_images for samples)
+
+ if expand_size == 1:
+ return input_ids, model_kwargs
+
+ visual_keys = ["pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw", "second_per_grid_ts"]
+
+ def _expand_dict_for_generation_visual(dict_to_expand):
+ image_grid_thw = model_kwargs.get("image_grid_thw", None)
+ video_grid_thw = model_kwargs.get("video_grid_thw", None)
+ image_nums, video_nums = self._get_image_nums_and_video_nums(
+ input_ids, inputs_embeds=model_kwargs.get("inputs_embeds", None)
+ )
+
+ def _repeat_interleave_samples(x, lengths, repeat_times):
+ samples = torch.split(x, lengths)
+ repeat_args = [repeat_times] + [1] * (x.dim() - 1)
+ result = torch.cat([sample.repeat(*repeat_args) for sample in samples], dim=0)
+ return result
+
+ for key in dict_to_expand:
+ if key == "pixel_values":
+ # split images into samples
+ samples = torch.split(image_grid_thw, list(image_nums))
+ # compute the sequence length of images for each sample
+ lengths = [torch.prod(sample, dim=1).sum() for sample in samples]
+ dict_to_expand[key] = _repeat_interleave_samples(
+ dict_to_expand[key], lengths=lengths, repeat_times=expand_size
+ )
+ elif key == "image_grid_thw":
+ # get the num of images for each sample
+ lengths = list(image_nums)
+ dict_to_expand[key] = _repeat_interleave_samples(
+ dict_to_expand[key], lengths=lengths, repeat_times=expand_size
+ )
+ elif key == "pixel_values_videos":
+ samples = torch.split(video_grid_thw, list(video_nums))
+ lengths = [torch.prod(sample, dim=1).sum() for sample in samples]
+ dict_to_expand[key] = _repeat_interleave_samples(
+ dict_to_expand[key], lengths=lengths, repeat_times=expand_size
+ )
+ elif key == "video_grid_thw":
+ lengths = list(video_nums)
+ dict_to_expand[key] = _repeat_interleave_samples(
+ dict_to_expand[key], lengths=lengths, repeat_times=expand_size
+ )
+ elif key == "second_per_grid_ts":
+ dict_to_expand[key] = _repeat_interleave_samples(
+ dict_to_expand[key], lengths=list(video_nums), repeat_times=expand_size
+ )
+ return dict_to_expand
+
+ def _expand_dict_for_generation(dict_to_expand):
+ for key in dict_to_expand:
+ if (
+ key != "cache_position"
+ and dict_to_expand[key] is not None
+ and isinstance(dict_to_expand[key], torch.Tensor)
+ and key not in visual_keys
+ ):
+ dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0)
+ return dict_to_expand
+
+ model_kwargs = _expand_dict_for_generation_visual(model_kwargs)
+
+ if input_ids is not None:
+ input_ids = input_ids.repeat_interleave(expand_size, dim=0)
+
+ model_kwargs = _expand_dict_for_generation(model_kwargs)
+
+ if is_encoder_decoder:
+ if model_kwargs.get("encoder_outputs") is None:
+ raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.")
+ model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"])
+
+ return input_ids, model_kwargs
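+
+
+ # The sketch below is illustrative only and is not used by the model: it demonstrates, with assumed
+ # toy token ids, the cumulative-sum bookkeeping that `_get_image_nums_and_video_nums` relies on to
+ # avoid counting image-start tokens that fall inside a video span.
+ def _standalone_image_count_sketch():
+     import torch
+
+     IMG_START, VID_START, VID_END, TEXT = 1, 2, 3, 0  # assumed toy ids, not the real vocabulary
+     ids = torch.tensor([[TEXT, IMG_START, VID_START, IMG_START, VID_END, TEXT]])
+     inside_video = torch.cumsum((ids == VID_START).int() - (ids == VID_END).int(), dim=1) > 0
+     standalone_images = ((ids == IMG_START) & ~inside_video).sum(dim=1)  # tensor([1]); the 2nd image is inside the video
+     video_counts = (ids == VID_START).sum(dim=1)  # tensor([1])
+     return standalone_images, video_counts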
+
+
+__all__ = ["Glm4vMoeForConditionalGeneration", "Glm4vMoeModel", "Glm4vMoePreTrainedModel", "Glm4vMoeTextModel"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/glm4v_moe/modular_glm4v_moe.py b/venv/lib/python3.13/site-packages/transformers/models/glm4v_moe/modular_glm4v_moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..0dfe28ff19da878a689afb0ab6621e8cfb35f340
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/glm4v_moe/modular_glm4v_moe.py
@@ -0,0 +1,459 @@
+# coding=utf-8
+# Copyright 2025 The ZhipuAI Inc. team and HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Callable, Optional
+
+import torch
+import torch.nn as nn
+
+from ...cache_utils import Cache
+from ...configuration_utils import PretrainedConfig
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_rope_utils import rope_config_validation
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
+from ...processing_utils import Unpack
+from ...utils import logging
+from ..glm4.modeling_glm4 import Glm4Attention
+from ..glm4_moe.configuration_glm4_moe import Glm4MoeConfig
+from ..glm4_moe.modeling_glm4_moe import (
+ Glm4MoeDecoderLayer,
+ Glm4MoeMLP,
+ Glm4MoeMoE,
+ Glm4MoePreTrainedModel,
+ Glm4MoeRMSNorm,
+ Glm4MoeTopkRouter,
+ eager_attention_forward,
+)
+from ..glm4v.configuration_glm4v import Glm4vConfig, Glm4vVisionConfig
+from ..glm4v.modeling_glm4v import (
+ Glm4vForConditionalGeneration,
+ rotate_half,
+)
+
+
+logger = logging.get_logger(__name__)
+
+
+class Glm4vMoeVisionConfig(Glm4vVisionConfig):
+ pass
+
+
+class Glm4vMoeTextConfig(Glm4MoeConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`Glm4vMoeTextModel`]. It is used to instantiate a
+ GLM-4.5V model according to the specified arguments, defining the model architecture. Instantiating a
+ configuration with the defaults will yield a similar configuration to that of
+ GLM-4.5V [zai-org/GLM-4.5V](https://huggingface.co/zai-org/GLM-4.5V).
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 151424):
+ Vocabulary size of the Glm4vMoe model. Defines the number of different tokens that can be represented by the
+ `inputs_ids` passed when calling [`Glm4vMoeModel`]
+ hidden_size (`int`, *optional*, defaults to 4096):
+ Dimension of the hidden representations.
+ intermediate_size (`int`, *optional*, defaults to 10944):
+ Dimension of the MLP representations.
+ num_hidden_layers (`int`, *optional*, defaults to 46):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 96):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ partial_rotary_factor (`float`, *optional*, defaults to 0.5): The factor of the partial rotary position.
+ num_key_value_heads (`int`, *optional*, defaults to 8):
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+ by meanpooling all the original heads within that group. For more details checkout [this
+ paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to `8`.
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+ The non-linear activation function (function or string) in the decoder.
+ max_position_embeddings (`int`, *optional*, defaults to 65536):
+ The maximum sequence length that this model might ever be used with.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-05):
+ The epsilon used by the rms normalization layers.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether the model's input and output word embeddings should be tied.
+ rope_theta (`float`, *optional*, defaults to 10000.0):
+ The base period of the RoPE embeddings.
+ rope_scaling (`Dict`, *optional*):
+ Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
+ and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
+ accordingly.
+ Expected contents:
+ `rope_type` (`str`):
+ The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
+ 'llama3'], with 'default' being the original RoPE implementation.
+ `factor` (`float`, *optional*):
+ Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
+ most scaling types, a `factor` of x will enable the model to handle sequences of length x *
+ original maximum pre-trained length.
+ `original_max_position_embeddings` (`int`, *optional*):
+ Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
+ pretraining.
+ `attention_factor` (`float`, *optional*):
+ Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
+ computation. If unspecified, it defaults to value recommended by the implementation, using the
+ `factor` field to infer the suggested value.
+ attention_bias (`bool`, *optional*, defaults to `True`):
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ moe_intermediate_size (`int`, *optional*, defaults to 1408):
+ Intermediate size of the routed expert.
+ num_experts_per_tok (`int`, *optional*, defaults to 8):
+ Number of experts per token.
+ n_shared_experts (`int`, *optional*, defaults to 1):
+ Number of shared experts.
+ n_routed_experts (`int`, *optional*, defaults to 128):
+ Number of routed experts.
+ routed_scaling_factor (`float`, *optional*, defaults to 1.0):
+ Scaling factor for routed experts.
+ n_group (`int`, *optional*, defaults to 1):
+ Number of groups for routed experts.
+ topk_group (`int`, *optional*, defaults to 1):
+ Number of selected groups for each token (the selected experts for a token are restricted to `topk_group` groups).
+ first_k_dense_replace (`int`, *optional*, defaults to 1):
+ Number of dense decoder layers at the start of the model before the first MoE layer
+ (embed -> k dense layers -> moe -> ... -> moe -> lm_head); a toy layout sketch follows this class.
+ norm_topk_prob (`bool`, *optional*, defaults to `True`):
+ Whether to normalize the topk probabilities.
+
+ ```python
+ >>> from transformers import Glm4vMoeTextModel, Glm4vMoeTextConfig
+
+ >>> # Initializing a GLM-4.5V style text configuration
+ >>> configuration = Glm4vMoeTextConfig()
+
+ >>> # Initializing a model from the GLM-4.5V style text configuration
+ >>> model = Glm4vMoeTextModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "Glm4vMoe_text"
+ base_config_key = "text_config"
+ keys_to_ignore_at_inference = ["past_key_values"]
+ # Default tensor parallel plan for base model `Glm4vMoe`
+ base_model_tp_plan = {
+ "layers.*.self_attn.q_proj": "colwise",
+ "layers.*.self_attn.k_proj": "colwise",
+ "layers.*.self_attn.v_proj": "colwise",
+ "layers.*.self_attn.o_proj": "rowwise",
+ "layers.*.mlp.gate_up_proj": "colwise_rep", # we need to replicate here due to the `chunk` operation
+ "layers.*.mlp.down_proj": "rowwise_rep", # we need to replicate here due to the `chunk` operation
+ }
+ base_model_pp_plan = {
+ "embed_tokens": (["input_ids"], ["inputs_embeds"]),
+ "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
+ "norm": (["hidden_states"], ["hidden_states"]),
+ }
+
+ def __init__(
+ self,
+ vocab_size=151424,
+ hidden_size=4096,
+ intermediate_size=10944,
+ num_hidden_layers=46,
+ num_attention_heads=96,
+ partial_rotary_factor=0.5,
+ num_key_value_heads=8,
+ hidden_act="silu",
+ max_position_embeddings=65536,
+ initializer_range=0.02,
+ rms_norm_eps=1e-5,
+ use_cache=True,
+ tie_word_embeddings=False,
+ rope_theta=10000.0,
+ rope_scaling=None,
+ attention_bias=True,
+ attention_dropout=0.0,
+ moe_intermediate_size=1408,
+ num_experts_per_tok=8,
+ n_shared_experts=1,
+ n_routed_experts=128,
+ routed_scaling_factor=1.0,
+ n_group=1,
+ topk_group=1,
+ first_k_dense_replace=1,
+ norm_topk_prob=True,
+ **kwargs,
+ ):
+ PretrainedConfig.__init__(self, tie_word_embeddings=tie_word_embeddings, **kwargs)
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.partial_rotary_factor = partial_rotary_factor
+
+ self.num_key_value_heads = num_key_value_heads
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.rms_norm_eps = rms_norm_eps
+ self.use_cache = use_cache
+ self.rope_theta = rope_theta
+ self.rope_scaling = rope_scaling
+ self.attention_bias = attention_bias
+ self.attention_dropout = attention_dropout
+ # Validate the correctness of rotary position embeddings parameters
+ # BC: if there is a 'type' field, move it to 'rope_type'.
+ if self.rope_scaling is not None and "type" in self.rope_scaling:
+ self.rope_scaling["rope_type"] = self.rope_scaling["type"]
+ rope_config_validation(self, ignore_keys={"mrope_section"})
+
+ # MoE arguments
+ self.moe_intermediate_size = moe_intermediate_size
+ self.num_experts_per_tok = num_experts_per_tok
+ self.n_group = n_group
+ self.topk_group = topk_group
+ self.n_shared_experts = n_shared_experts
+ self.n_routed_experts = n_routed_experts
+ self.routed_scaling_factor = routed_scaling_factor
+ self.first_k_dense_replace = first_k_dense_replace
+ self.norm_topk_prob = norm_topk_prob
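+
+
+ # Illustrative sketch only, using assumed toy sizes: with `first_k_dense_replace=k`, the first k decoder
+ # layers keep a dense MLP and every later layer uses an MoE block.
+ def _dense_then_moe_layout_sketch():
+     config = Glm4vMoeTextConfig(num_hidden_layers=4, first_k_dense_replace=1)
+     layout = ["dense" if i < config.first_k_dense_replace else "moe" for i in range(config.num_hidden_layers)]
+     return layout  # ['dense', 'moe', 'moe', 'moe']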
+
+
+class Glm4vMoeConfig(Glm4vConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`Glm4vMoeModel`]. It is used to instantiate a
+ GLM-4.5V model according to the specified arguments, defining the model architecture. Instantiating a
+ configuration with the defaults will yield a similar configuration to that of
+ GLM-4.5V [zai-org/GLM-4.5V](https://huggingface.co/zai-org/GLM-4.5V).
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ text_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Glm4vMoeTextConfig`):
+ The config object or dictionary of the text backbone.
+ vision_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Glm4vMoeVisionConfig`):
+ The config object or dictionary of the vision backbone.
+ image_token_id (`int`, *optional*, defaults to 151363):
+ The image token index to encode the image prompt.
+ video_token_id (`int`, *optional*, defaults to 151364):
+ The video token index to encode the video prompt.
+ image_start_token_id (`int`, *optional*, defaults to 151339):
+ The image start token index to encode the start of image.
+ image_end_token_id (`int`, *optional*, defaults to 151340):
+ The image end token index to encode the end of image.
+ video_start_token_id (`int`, *optional*, defaults to 151341):
+ The video start token index to encode the start of video.
+ video_end_token_id (`int`, *optional*, defaults to 151342):
+ The video end token index to encode the end of video.
+
+ ```python
+ >>> from transformers import Glm4vMoeForConditionalGeneration, Glm4vMoeConfig
+
+ >>> # Initializing a GLM-4.5V style configuration
+ >>> configuration = Glm4vMoeConfig()
+
+ >>> # Initializing a model from the GLM-4.5V style configuration
+ >>> model = Glm4vMoeForConditionalGeneration(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ def __init__(
+ self,
+ text_config=None,
+ vision_config=None,
+ image_token_id=151363,
+ video_token_id=151364,
+ image_start_token_id=151339,
+ image_end_token_id=151340,
+ video_start_token_id=151341,
+ video_end_token_id=151342,
+ **kwargs,
+ ):
+ super().__init__()
+
+
+class Glm4vMoeRMSNorm(Glm4MoeRMSNorm):
+ pass
+
+
+def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
+ """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/).
+
+ Explanation:
+ Multimodal 3D rotary position embedding is an extension to 1D rotary position embedding. The input embedding
+ sequence contains vision (images / videos) embedding and text embedding or just contains text embedding. For
+ vision embedding part, we apply rotary position embedding on temporal, height and width dimension separately.
+ Here we split the channel dimension to 3 chunks for the temporal, height and width rotary position embedding.
+ For text embedding part, we just apply 1D rotary position embedding. The three rotary position index (temporal,
+ height and width) of text embedding is always the same, so the text embedding rotary position embedding has no
+ difference with modern LLMs.
+
+ Args:
+ q (`torch.Tensor`): The query tensor.
+ k (`torch.Tensor`): The key tensor.
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
+ mrope_section (`list[int]`):
+ Multimodal rope section is for channel dimension of temporal, height and width in rope calculation.
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+ Returns:
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+ """
+ mrope_section = mrope_section * 2
+ cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
+ unsqueeze_dim
+ )
+ sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
+ unsqueeze_dim
+ )
+
+ # Keep half or full tensor for later concatenation
+ rotary_dim = cos.shape[-1]
+ q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
+ k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
+
+ # Apply rotary embeddings on the first half or full tensor
+ q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin)
+ k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin)
+
+ # Concatenate back to full shape
+ q_embed = torch.cat([q_embed, q_pass], dim=-1)
+ k_embed = torch.cat([k_embed, k_pass], dim=-1)
+
+ return q_embed, k_embed
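+
+
+ # Illustrative sketch only, with assumed toy sizes: `mrope_section` carves the rotary channels into
+ # temporal/height/width chunks, and `m[i % 3]` picks the matching t/h/w plane for each chunk, exactly
+ # as `apply_multimodal_rotary_pos_emb` does above.
+ def _mrope_section_sketch():
+     head_dim, mrope_section = 8, [2, 1, 1]  # toy values; sum(mrope_section) * 2 == head_dim
+     cos = torch.arange(3 * head_dim, dtype=torch.float32).reshape(3, 1, 1, head_dim)  # (t/h/w, bsz, seq, dim)
+     doubled = mrope_section * 2  # [2, 1, 1, 2, 1, 1]
+     mixed = torch.cat([m[i % 3] for i, m in enumerate(cos.split(doubled, dim=-1))], dim=-1)
+     return mixed.shape  # torch.Size([1, 1, 8])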
+
+
+class Glm4vMoeTextAttention(Glm4Attention):
+ def __init__(self, config: Glm4vMoeTextConfig, layer_idx: Optional[int] = None):
+ super().__init__(config, layer_idx)
+ self.rope_scaling = config.rope_scaling
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
+ attention_mask: Optional[torch.Tensor],
+ past_key_values: Optional[Cache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
+ input_shape = hidden_states.shape[:-1]
+ hidden_shape = (*input_shape, -1, self.head_dim)
+
+ query_states = self.q_proj(hidden_states).view(hidden_shape)
+ key_states = self.k_proj(hidden_states).view(hidden_shape)
+ value_states = self.v_proj(hidden_states).view(hidden_shape)
+
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ cos, sin = position_embeddings
+ query_states, key_states = apply_multimodal_rotary_pos_emb( # diff with Llama
+ query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
+ )
+
+ if past_key_values is not None:
+ # sin and cos are specific to RoPE models; position_ids needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ scaling=self.scaling,
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+ return attn_output, attn_weights
+
+
+class Glm4vMoeTextTopkRouter(Glm4MoeTopkRouter, nn.Module):
+ def __init__(self, config: Glm4vMoeTextConfig):
+ super().__init__(config)
+
+
+class Glm4vMoeTextMoE(Glm4MoeMoE):
+ def __init__(self, config: Glm4vMoeTextConfig):
+ super().__init__(config)
+ self.config = config
+ self.experts = nn.ModuleList(
+ [
+ Glm4vMoeTextMLP(config, intermediate_size=config.moe_intermediate_size)
+ for _ in range(config.n_routed_experts)
+ ]
+ )
+ self.gate = Glm4vMoeTextTopkRouter(config)
+ self.shared_experts = Glm4vMoeTextMLP(
+ config=config, intermediate_size=config.moe_intermediate_size * config.n_shared_experts
+ )
+
+
+class Glm4vMoeTextMLP(Glm4MoeMLP):
+ pass
+
+
+class Glm4vMoeTextDecoderLayer(Glm4MoeDecoderLayer):
+ def __init__(self, config: Glm4vMoeTextConfig, layer_idx: int):
+ super().__init__(config, layer_idx)
+
+
+class Glm4vMoePreTrainedModel(Glm4MoePreTrainedModel):
+ config: Glm4vMoeConfig
+ base_model_prefix = ""
+ _no_split_modules = ["Glm4vMoeTextDecoderLayer", "Glm4vMoeVisionBlock"]
+ _skip_keys_device_placement = "past_key_values"
+
+ _can_record_outputs = {
+ "hidden_states": Glm4vMoeTextDecoderLayer,
+ "attentions": Glm4vMoeTextAttention,
+ }
+
+
+class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
+ pass
+
+
+__all__ = [
+ "Glm4vMoeConfig",
+ "Glm4vMoeTextConfig",
+ "Glm4vMoeForConditionalGeneration",
+ "Glm4vMoeModel", # noqa: F822
+ "Glm4vMoePreTrainedModel",
+ "Glm4vMoeTextModel", # noqa: F822
+]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/gpt2/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/gpt2/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f01899e668e3a86548db3f59c7f42d70746385ab
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/gpt2/__init__.py
@@ -0,0 +1,32 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_gpt2 import *
+ from .modeling_flax_gpt2 import *
+ from .modeling_gpt2 import *
+ from .modeling_tf_gpt2 import *
+ from .tokenization_gpt2 import *
+ from .tokenization_gpt2_fast import *
+ from .tokenization_gpt2_tf import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/venv/lib/python3.13/site-packages/transformers/models/gpt2/configuration_gpt2.py b/venv/lib/python3.13/site-packages/transformers/models/gpt2/configuration_gpt2.py
new file mode 100644
index 0000000000000000000000000000000000000000..db5151a2ba15635a7943744799b0689fc96790d3
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/gpt2/configuration_gpt2.py
@@ -0,0 +1,274 @@
+# coding=utf-8
+# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""OpenAI GPT-2 configuration"""
+
+from collections import OrderedDict
+from collections.abc import Mapping
+from typing import Any, Optional
+
+from ... import PreTrainedTokenizer, TensorType, is_torch_available
+from ...configuration_utils import PretrainedConfig
+from ...onnx import OnnxConfigWithPast, PatchingSpec
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class GPT2Config(PretrainedConfig):
+ """
+ This is the configuration class to store the configuration of a [`GPT2Model`] or a [`TFGPT2Model`]. It is used to
+ instantiate a GPT-2 model according to the specified arguments, defining the model architecture. Instantiating a
+ configuration with the defaults will yield a similar configuration to that of the GPT-2
+ [openai-community/gpt2](https://huggingface.co/openai-community/gpt2) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 50257):
+ Vocabulary size of the GPT-2 model. Defines the number of different tokens that can be represented by the
+ `inputs_ids` passed when calling [`GPT2Model`] or [`TFGPT2Model`].
+ n_positions (`int`, *optional*, defaults to 1024):
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
+ just in case (e.g., 512 or 1024 or 2048).
+ n_embd (`int`, *optional*, defaults to 768):
+ Dimensionality of the embeddings and hidden states.
+ n_layer (`int`, *optional*, defaults to 12):
+ Number of hidden layers in the Transformer encoder.
+ n_head (`int`, *optional*, defaults to 12):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ n_inner (`int`, *optional*):
+ Dimensionality of the inner feed-forward layers. `None` will set it to 4 times `n_embd`.
+ activation_function (`str`, *optional*, defaults to `"gelu_new"`):
+ Activation function, to be selected in the list `["relu", "silu", "gelu", "tanh", "gelu_new"]`.
+ resid_pdrop (`float`, *optional*, defaults to 0.1):
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+ embd_pdrop (`float`, *optional*, defaults to 0.1):
+ The dropout ratio for the embeddings.
+ attn_pdrop (`float`, *optional*, defaults to 0.1):
+ The dropout ratio for the attention.
+ layer_norm_epsilon (`float`, *optional*, defaults to 1e-05):
+ The epsilon to use in the layer normalization layers.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ summary_type (`string`, *optional*, defaults to `"cls_index"`):
+ Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and
+ [`TFGPT2DoubleHeadsModel`].
+
+ Has to be one of the following options:
+
+ - `"last"`: Take the last token hidden state (like XLNet).
+ - `"first"`: Take the first token hidden state (like BERT).
+ - `"mean"`: Take the mean of all tokens hidden states.
+ - `"cls_index"`: Supply a Tensor of classification token position (like GPT/GPT-2).
+ - `"attn"`: Not implemented now, use multi-head attention.
+ summary_use_proj (`bool`, *optional*, defaults to `True`):
+ Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and
+ [`TFGPT2DoubleHeadsModel`].
+
+ Whether or not to add a projection after the vector extraction.
+ summary_activation (`str`, *optional*):
+ Argument used when doing sequence summary. Used for the multiple choice head in
+ [`GPT2DoubleHeadsModel`].
+
+ Pass `"tanh"` for a tanh activation to the output, any other value will result in no activation.
+ summary_proj_to_labels (`bool`, *optional*, defaults to `True`):
+ Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and
+ [`TFGPT2DoubleHeadsModel`].
+
+ Whether the projection outputs should have `config.num_labels` or `config.hidden_size` classes.
+ summary_first_dropout (`float`, *optional*, defaults to 0.1):
+ Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and
+ [`TFGPT2DoubleHeadsModel`].
+
+ The dropout ratio to be used after the projection and activation.
+ scale_attn_weights (`bool`, *optional*, defaults to `True`):
+ Scale attention weights by dividing by sqrt(hidden_size).
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models).
+ bos_token_id (`int`, *optional*, defaults to 50256):
+ Id of the beginning of sentence token in the vocabulary.
+ eos_token_id (`int`, *optional*, defaults to 50256):
+ Id of the end of sentence token in the vocabulary.
+ scale_attn_by_inverse_layer_idx (`bool`, *optional*, defaults to `False`):
+ Whether to additionally scale attention weights by `1 / (layer_idx + 1)`.
+ reorder_and_upcast_attn (`bool`, *optional*, defaults to `False`):
+ Whether to scale keys (K) prior to computing attention (dot-product) and upcast attention
+ dot-product/softmax to float() when training with mixed precision.
+
+ Example:
+
+ ```python
+ >>> from transformers import GPT2Config, GPT2Model
+
+ >>> # Initializing a GPT2 configuration
+ >>> configuration = GPT2Config()
+
+ >>> # Initializing a model (with random weights) from the configuration
+ >>> model = GPT2Model(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "gpt2"
+ keys_to_ignore_at_inference = ["past_key_values"]
+ attribute_map = {
+ "hidden_size": "n_embd",
+ "max_position_embeddings": "n_positions",
+ "num_attention_heads": "n_head",
+ "num_hidden_layers": "n_layer",
+ }
+
+ def __init__(
+ self,
+ vocab_size=50257,
+ n_positions=1024,
+ n_embd=768,
+ n_layer=12,
+ n_head=12,
+ n_inner=None,
+ activation_function="gelu_new",
+ resid_pdrop=0.1,
+ embd_pdrop=0.1,
+ attn_pdrop=0.1,
+ layer_norm_epsilon=1e-5,
+ initializer_range=0.02,
+ summary_type="cls_index",
+ summary_use_proj=True,
+ summary_activation=None,
+ summary_proj_to_labels=True,
+ summary_first_dropout=0.1,
+ scale_attn_weights=True,
+ use_cache=True,
+ bos_token_id=50256,
+ eos_token_id=50256,
+ scale_attn_by_inverse_layer_idx=False,
+ reorder_and_upcast_attn=False,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.n_positions = n_positions
+ self.n_embd = n_embd
+ self.n_layer = n_layer
+ self.n_head = n_head
+ self.n_inner = n_inner
+ self.activation_function = activation_function
+ self.resid_pdrop = resid_pdrop
+ self.embd_pdrop = embd_pdrop
+ self.attn_pdrop = attn_pdrop
+ self.layer_norm_epsilon = layer_norm_epsilon
+ self.initializer_range = initializer_range
+ self.summary_type = summary_type
+ self.summary_use_proj = summary_use_proj
+ self.summary_activation = summary_activation
+ self.summary_first_dropout = summary_first_dropout
+ self.summary_proj_to_labels = summary_proj_to_labels
+ self.scale_attn_weights = scale_attn_weights
+ self.use_cache = use_cache
+ self.scale_attn_by_inverse_layer_idx = scale_attn_by_inverse_layer_idx
+ self.reorder_and_upcast_attn = reorder_and_upcast_attn
+
+ self.bos_token_id = bos_token_id
+ self.eos_token_id = eos_token_id
+
+ super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
+
+
+class GPT2OnnxConfig(OnnxConfigWithPast):
+ def __init__(
+ self,
+ config: PretrainedConfig,
+ task: str = "default",
+ patching_specs: Optional[list[PatchingSpec]] = None,
+ use_past: bool = False,
+ ):
+ super().__init__(config, task=task, patching_specs=patching_specs, use_past=use_past)
+ if not getattr(self._config, "pad_token_id", None):
+ # TODO: how to do that better?
+ self._config.pad_token_id = 0
+
+ @property
+ def inputs(self) -> Mapping[str, Mapping[int, str]]:
+ common_inputs = OrderedDict({"input_ids": {0: "batch", 1: "sequence"}})
+ if self.use_past:
+ self.fill_with_past_key_values_(common_inputs, direction="inputs")
+ common_inputs["attention_mask"] = {0: "batch", 1: "past_sequence + sequence"}
+ else:
+ common_inputs["attention_mask"] = {0: "batch", 1: "sequence"}
+
+ return common_inputs
+
+ @property
+ def num_layers(self) -> int:
+ return self._config.n_layer
+
+ @property
+ def num_attention_heads(self) -> int:
+ return self._config.n_head
+
+ def generate_dummy_inputs(
+ self,
+ tokenizer: PreTrainedTokenizer,
+ batch_size: int = -1,
+ seq_length: int = -1,
+ is_pair: bool = False,
+ framework: Optional[TensorType] = None,
+ ) -> Mapping[str, Any]:
+ common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
+ tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
+ )
+
+ # We need to order the inputs in the order they appear in forward()
+ ordered_inputs = OrderedDict({"input_ids": common_inputs["input_ids"]})
+
+ # Need to add the past_keys
+ if self.use_past:
+ if not is_torch_available():
+ raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
+ else:
+ import torch
+
+ batch, seqlen = common_inputs["input_ids"].shape
+ # Not using the same length for past_key_values
+ past_key_values_length = seqlen + 2
+ past_shape = (
+ batch,
+ self.num_attention_heads,
+ past_key_values_length,
+ self._config.hidden_size // self.num_attention_heads,
+ )
+ ordered_inputs["past_key_values"] = [
+ (torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(self.num_layers)
+ ]
+
+ ordered_inputs["attention_mask"] = common_inputs["attention_mask"]
+ if self.use_past:
+ mask_dtype = ordered_inputs["attention_mask"].dtype
+ ordered_inputs["attention_mask"] = torch.cat(
+ [ordered_inputs["attention_mask"], torch.ones(batch, past_key_values_length, dtype=mask_dtype)], dim=1
+ )
+
+ return ordered_inputs
+
+ @property
+ def default_onnx_opset(self) -> int:
+ return 13
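+
+
+ # Illustrative sketch only (no checkpoint is loaded, toy batch/sequence sizes are assumed): the dummy
+ # `past_key_values` built by `GPT2OnnxConfig.generate_dummy_inputs` holds one (key, value) pair per
+ # layer, each of shape (batch, num_heads, past_length, head_dim).
+ def _gpt2_dummy_past_shape_sketch():
+     config = GPT2Config()  # n_layer=12, n_head=12, n_embd=768
+     batch, seqlen = 2, 3
+     past_length = seqlen + 2  # generate_dummy_inputs pads the dummy cache length by 2
+     head_dim = config.hidden_size // config.n_head  # `hidden_size` maps to `n_embd` via `attribute_map`
+     return config.n_layer, (batch, config.n_head, past_length, head_dim)  # (12, (2, 12, 5, 64))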
+
+
+__all__ = ["GPT2Config", "GPT2OnnxConfig"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/gpt2/modeling_flax_gpt2.py b/venv/lib/python3.13/site-packages/transformers/models/gpt2/modeling_flax_gpt2.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e419217c5a3642ee27f6f3df87e1c27c0d5ac79
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/gpt2/modeling_flax_gpt2.py
@@ -0,0 +1,782 @@
+# coding=utf-8
+# Copyright 2021 The Google Flax Team Authors and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Optional
+
+import flax.linen as nn
+import jax
+import jax.numpy as jnp
+from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
+from flax.linen import combine_masks, make_causal_mask
+from flax.linen.attention import dot_product_attention_weights
+from flax.traverse_util import flatten_dict, unflatten_dict
+from jax import lax
+
+from ...modeling_flax_outputs import (
+ FlaxBaseModelOutputWithPastAndCrossAttentions,
+ FlaxCausalLMOutputWithCrossAttentions,
+)
+from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring
+from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging
+from .configuration_gpt2 import GPT2Config
+
+
+logger = logging.get_logger(__name__)
+
+_CHECKPOINT_FOR_DOC = "openai-community/gpt2"
+_CONFIG_FOR_DOC = "GPT2Config"
+
+
+GPT2_START_DOCSTRING = r"""
+
+ This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
+ etc.)
+
+ This model is also a Flax Linen
+ [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a
+ regular Flax Module and refer to the Flax documentation for all matters related to general usage and behavior.
+
+ Finally, this model supports inherent JAX features such as:
+
+ - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
+ - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
+ - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
+ - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
+
+ Parameters:
+ config ([`GPT2Config`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
+ dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
+ The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
+ `jax.numpy.bfloat16` (on TPUs).
+
+ This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
+ specified all the computation will be performed with the given `dtype`.
+
+ **Note that this only specifies the dtype of the computation and does not influence the dtype of model
+ parameters.**
+
+ If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
+ [`~FlaxPreTrainedModel.to_bf16`].
+"""
+
+GPT2_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`numpy.ndarray` of shape `(batch_size, input_ids_length)`):
+ `input_ids_length` = `sequence_length`. Indices of input sequence tokens in the vocabulary.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ position_ids (`numpy.ndarray` of shape `(batch_size, input_ids_length)`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+ config.max_position_embeddings - 1]`.
+ past_key_values (`dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`):
+ Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast
+ auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+class FlaxConv1D(nn.Module):
+ features: int
+ use_bias: bool = True
+ dtype: Any = jnp.float32
+ precision: Any = None
+
+ @nn.compact
+ def __call__(self, inputs):
+ inputs = jnp.asarray(inputs, self.dtype)
+ kernel = self.param("kernel", jax.nn.initializers.normal(stddev=0.02), (self.features, inputs.shape[-1]))
+ kernel = jnp.asarray(kernel.transpose(), self.dtype)
+ y = lax.dot_general(inputs, kernel, (((inputs.ndim - 1,), (0,)), ((), ())), precision=self.precision)
+ if self.use_bias:
+ bias = self.param("bias", jax.nn.initializers.zeros, (self.features,))
+ bias = jnp.asarray(bias, self.dtype)
+ y = y + bias
+ return y
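+
+
+ # Illustrative sketch only, with assumed toy shapes: FlaxConv1D stores its kernel as (features, in_dim)
+ # and transposes it at call time, so it acts as a dense projection over the last axis, mirroring GPT-2's
+ # original Conv1D module.
+ def _flax_conv1d_shape_sketch():
+     layer = FlaxConv1D(features=6)
+     x = jnp.ones((2, 4, 3))  # (batch, seq, in_dim)
+     variables = layer.init(jax.random.PRNGKey(0), x)
+     return layer.apply(variables, x).shape  # (2, 4, 6)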
+
+
+class FlaxGPT2Attention(nn.Module):
+ config: GPT2Config
+ dtype: jnp.dtype = jnp.float32
+ causal: bool = True
+ is_cross_attention: bool = False
+
+ def setup(self):
+ config = self.config
+ self.embed_dim = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.embed_dim // self.num_heads
+
+ if self.is_cross_attention:
+ self.c_attn = FlaxConv1D(2 * self.embed_dim, dtype=self.dtype)
+ self.q_attn = FlaxConv1D(self.embed_dim, dtype=self.dtype)
+ else:
+ self.c_attn = FlaxConv1D(3 * self.embed_dim, dtype=self.dtype)
+ self.c_proj = FlaxConv1D(self.embed_dim, dtype=self.dtype)
+
+ self.resid_dropout = nn.Dropout(rate=config.resid_pdrop)
+
+ if self.causal:
+ self.causal_mask = make_causal_mask(
+ jnp.ones((1, config.max_position_embeddings), dtype="bool"), dtype="bool"
+ )
+
+ def _split_heads(self, hidden_states):
+ return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim))
+
+ def _merge_heads(self, hidden_states):
+ return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))
+
+ @nn.compact
+ def _concatenate_to_cache(self, key, value, query, attention_mask):
+ """
+ This function takes projected key, value states from a single input token and concatenates the states to cached
+ states from previous steps. This function is slightly adapted from the official Flax repository:
+ https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
+ """
+ # detect if we're initializing by absence of existing cache data.
+ is_initialized = self.has_variable("cache", "cached_key")
+ cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
+ cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
+ cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
+
+ if is_initialized:
+ *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
+ # update key, value caches with our new 1d spatial slices
+ cur_index = cache_index.value
+ indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
+ key = lax.dynamic_update_slice(cached_key.value, key, indices)
+ value = lax.dynamic_update_slice(cached_value.value, value, indices)
+ cached_key.value = key
+ cached_value.value = value
+ num_updated_cache_vectors = query.shape[1]
+ cache_index.value = cache_index.value + num_updated_cache_vectors
+ # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements.
+ pad_mask = jnp.broadcast_to(
+ jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
+ tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
+ )
+ attention_mask = combine_masks(pad_mask, attention_mask)
+ return key, value, attention_mask
+
+ def __call__(
+ self,
+ hidden_states,
+ key_value_states: Optional[jnp.ndarray] = None,
+ attention_mask=None,
+ deterministic: bool = True,
+ init_cache: bool = False,
+ output_attentions: bool = False,
+ ):
+ # if key_value_states are provided this layer is used as a cross-attention layer
+ # for the decoder
+ is_cross_attention = key_value_states is not None
+ batch_size = hidden_states.shape[0]
+
+ if not is_cross_attention:
+ qkv_out = self.c_attn(hidden_states)
+ query, key, value = jnp.split(qkv_out, 3, axis=2)
+ else:
+ q_out = self.q_attn(hidden_states)
+ (query,) = jnp.split(q_out, 1, axis=2)
+ kv_out = self.c_attn(key_value_states)
+ key, value = jnp.split(kv_out, 2, axis=2)
+
+ query = self._split_heads(query)
+ key = self._split_heads(key)
+ value = self._split_heads(value)
+
+ query_length, key_length = query.shape[1], key.shape[1]
+
+ if self.causal:
+ if self.has_variable("cache", "cached_key"):
+ mask_shift = self.variables["cache"]["cache_index"]
+ max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
+ causal_mask = lax.dynamic_slice(
+ self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
+ )
+ else:
+ causal_mask = self.causal_mask[:, :, :query_length, :key_length]
+ causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])
+
+ # combine masks if needed
+ if attention_mask is not None and self.causal:
+ attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
+ attention_mask = combine_masks(attention_mask, causal_mask)
+ elif self.causal:
+ attention_mask = causal_mask
+ elif attention_mask is not None:
+ attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
+
+ dropout_rng = None
+ if not deterministic and self.config.attn_pdrop > 0.0:
+ dropout_rng = self.make_rng("dropout")
+
+ # During fast autoregressive decoding, we feed one position at a time,
+ # and cache the keys and values step by step.
+ if self.causal and (self.has_variable("cache", "cached_key") or init_cache):
+ key, value, attention_mask = self._concatenate_to_cache(key, value, query, attention_mask)
+
+ # transform boolean mask into float mask
+ if attention_mask is not None:
+ attention_bias = lax.select(
+ attention_mask > 0,
+ jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
+ jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
+ )
+ else:
+ attention_bias = None
+
+ # usual dot product attention
+ attn_weights = dot_product_attention_weights(
+ query,
+ key,
+ bias=attention_bias,
+ dropout_rng=dropout_rng,
+ dropout_rate=self.config.attn_pdrop,
+ deterministic=deterministic,
+ dtype=self.dtype,
+ precision=None,
+ )
+
+ attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value)
+ attn_output = self._merge_heads(attn_output)
+ attn_output = self.c_proj(attn_output)
+ attn_output = self.resid_dropout(attn_output, deterministic=deterministic)
+
+ outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
+ return outputs
+
+
+class FlaxGPT2MLP(nn.Module):
+ config: GPT2Config
+ intermediate_size: int
+ dtype: jnp.dtype = jnp.float32
+
+ def setup(self):
+ embed_dim = self.config.hidden_size
+ self.c_fc = FlaxConv1D(self.intermediate_size, dtype=self.dtype)
+ self.c_proj = FlaxConv1D(embed_dim, dtype=self.dtype)
+ self.act = ACT2FN[self.config.activation_function]
+ self.dropout = nn.Dropout(rate=self.config.resid_pdrop)
+
+ def __call__(self, hidden_states, deterministic: bool = True):
+ hidden_states = self.c_fc(hidden_states)
+ hidden_states = self.act(hidden_states)
+ hidden_states = self.c_proj(hidden_states)
+ hidden_states = self.dropout(hidden_states, deterministic=deterministic)
+ return hidden_states
+
+
+class FlaxGPT2Block(nn.Module):
+ config: GPT2Config
+ dtype: jnp.dtype = jnp.float32
+
+ def setup(self):
+ hidden_size = self.config.hidden_size
+ inner_dim = self.config.n_inner if self.config.n_inner is not None else 4 * hidden_size
+
+ self.ln_1 = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype)
+ self.attn = FlaxGPT2Attention(self.config, dtype=self.dtype)
+ self.ln_2 = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype)
+
+ if self.config.add_cross_attention:
+ self.crossattention = FlaxGPT2Attention(
+ config=self.config, dtype=self.dtype, causal=False, is_cross_attention=True
+ )
+ self.ln_cross_attn = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype)
+
+ self.mlp = FlaxGPT2MLP(self.config, inner_dim, dtype=self.dtype)
+
+ def __call__(
+ self,
+ hidden_states,
+ attention_mask=None,
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
+ deterministic: bool = True,
+ init_cache: bool = False,
+ output_attentions: bool = False,
+ ):
+ residual = hidden_states
+ hidden_states = self.ln_1(hidden_states)
+ attn_outputs = self.attn(
+ hidden_states,
+ attention_mask=attention_mask,
+ deterministic=deterministic,
+ init_cache=init_cache,
+ output_attentions=output_attentions,
+ )
+ # residual connection
+ attn_output = attn_outputs[0] # output_attn: a, (attentions)
+ outputs = attn_outputs[1:]
+ # residual connection
+ hidden_states = attn_output + residual
+
+ # Cross-Attention Block
+ if encoder_hidden_states is not None:
+ # add one self-attention block for cross-attention
+ if not hasattr(self, "crossattention"):
+ raise ValueError(
+ f"If `encoder_hidden_states` are passed, {self} has to be instantiated with "
+ "cross-attention layers by setting `config.add_cross_attention=True`"
+ )
+ residual = hidden_states
+ hidden_states = self.ln_cross_attn(hidden_states)
+ cross_attn_outputs = self.crossattention(
+ hidden_states,
+ key_value_states=encoder_hidden_states,
+ attention_mask=encoder_attention_mask,
+ deterministic=deterministic,
+ output_attentions=output_attentions,
+ )
+ attn_output = cross_attn_outputs[0]
+ # residual connection
+ hidden_states = residual + attn_output
+ outputs = outputs + cross_attn_outputs[1:] # add cross attentions if we output attention weights
+
+ residual = hidden_states
+ hidden_states = self.ln_2(hidden_states)
+ feed_forward_hidden_states = self.mlp(hidden_states, deterministic=deterministic)
+ # residual connection
+ hidden_states = residual + feed_forward_hidden_states
+
+ outputs = (hidden_states,) + outputs
+
+ return outputs
+
+
+class FlaxGPT2PreTrainedModel(FlaxPreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = GPT2Config
+ base_model_prefix = "transformer"
+ module_class: nn.Module = None
+
+ def __init__(
+ self,
+ config: GPT2Config,
+ input_shape: tuple = (1, 1),
+ seed: int = 0,
+ dtype: jnp.dtype = jnp.float32,
+ _do_init: bool = True,
+ **kwargs,
+ ):
+ module = self.module_class(config=config, dtype=dtype, **kwargs)
+ super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
+
+ def init_weights(self, rng: jax.random.PRNGKey, input_shape: tuple, params: FrozenDict = None) -> FrozenDict:
+ # init input tensors
+ input_ids = jnp.zeros(input_shape, dtype="i4")
+ attention_mask = jnp.ones_like(input_ids)
+ position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)
+ params_rng, dropout_rng = jax.random.split(rng)
+ rngs = {"params": params_rng, "dropout": dropout_rng}
+
+ if self.config.add_cross_attention:
+ encoder_hidden_states = jnp.zeros(input_shape + (self.config.n_embd,))
+ encoder_attention_mask = attention_mask
+ module_init_outputs = self.module.init(
+ rngs,
+ input_ids,
+ attention_mask,
+ position_ids,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ return_dict=False,
+ )
+ else:
+ module_init_outputs = self.module.init(rngs, input_ids, attention_mask, position_ids, return_dict=False)
+
+ random_params = module_init_outputs["params"]
+
+ if params is not None:
+ random_params = flatten_dict(unfreeze(random_params))
+ params = flatten_dict(unfreeze(params))
+ for missing_key in self._missing_keys:
+ params[missing_key] = random_params[missing_key]
+ self._missing_keys = set()
+ return freeze(unflatten_dict(params))
+ else:
+ return random_params
+
+ def init_cache(self, batch_size, max_length):
+ r"""
+ Args:
+ batch_size (`int`):
+ batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
+ max_length (`int`):
+ maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
+ cache.
+ """
+ # init input variables to retrieve cache
+ input_ids = jnp.ones((batch_size, max_length))
+ attention_mask = jnp.ones_like(input_ids)
+ position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
+
+ init_variables = self.module.init(
+ jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True
+ )
+ return unfreeze(init_variables["cache"])
+
+ @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
+ def __call__(
+ self,
+ input_ids,
+ attention_mask=None,
+ position_ids=None,
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
+ params: Optional[dict] = None,
+ past_key_values: Optional[dict] = None,
+ dropout_rng: jax.random.PRNGKey = None,
+ train: bool = False,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ):
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
+
+ if encoder_hidden_states is not None and encoder_attention_mask is None:
+ batch_size, sequence_length = encoder_hidden_states.shape[:2]
+ encoder_attention_mask = jnp.ones((batch_size, sequence_length))
+
+ batch_size, sequence_length = input_ids.shape
+
+ if position_ids is None:
+ if past_key_values is not None:
+ raise ValueError("Make sure to provide `position_ids` when passing `past_key_values`.")
+
+ position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
+
+ if attention_mask is None:
+ attention_mask = jnp.ones((batch_size, sequence_length))
+
+ # Handle any PRNG if needed
+ rngs = {}
+ if dropout_rng is not None:
+ rngs["dropout"] = dropout_rng
+
+ inputs = {"params": params or self.params}
+
+ # If `past_key_values` is passed, the cache is already initialized; the private `init_cache` flag
+ # must then be passed down so the cache is actually used. The cache also has to be marked as
+ # mutable so that it can be updated by the FlaxGPT2Attention module.
+ if past_key_values:
+ inputs["cache"] = past_key_values
+ mutable = ["cache"]
+ else:
+ mutable = False
+
+ outputs = self.module.apply(
+ inputs,
+ jnp.array(input_ids, dtype="i4"),
+ jnp.array(attention_mask, dtype="i4"),
+ jnp.array(position_ids, dtype="i4"),
+ encoder_hidden_states,
+ encoder_attention_mask,
+ not train,
+ False,
+ output_attentions,
+ output_hidden_states,
+ return_dict,
+ rngs=rngs,
+ mutable=mutable,
+ )
+
+ # add updated cache to model output
+ if past_key_values is not None and return_dict:
+ outputs, past_key_values = outputs
+ outputs["past_key_values"] = unfreeze(past_key_values["cache"])
+ return outputs
+ elif past_key_values is not None and not return_dict:
+ outputs, past_key_values = outputs
+ outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:]
+
+ return outputs
+
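+
+# A minimal usage sketch (illustrative only, not called anywhere in the library) of the cache
+# machinery above: `init_cache` pre-allocates the key/value buffers, and every call that receives
+# `past_key_values` returns the updated cache. The checkpoint name is an assumption for illustration.
+def _flax_incremental_decoding_sketch():
+    import jax.numpy as jnp
+
+    model = FlaxGPT2Model.from_pretrained("openai-community/gpt2")
+    batch_size, max_length = 1, 8
+
+    # Pre-allocate the cache once, then feed one token per step.
+    cache = model.init_cache(batch_size, max_length)
+    attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
+    last_hidden = None
+    for step in range(max_length):
+        token = jnp.array([[step]], dtype="i4")  # dummy token ids, for illustration only
+        position_ids = jnp.array([[step]], dtype="i4")
+        outputs = model(
+            token,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=cache,
+        )
+        cache = outputs["past_key_values"]  # updated cache returned alongside the model output
+        last_hidden = outputs.last_hidden_state  # (batch_size, 1, hidden_size)
+    return last_hidden
+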
+
+class FlaxGPT2BlockCollection(nn.Module):
+ config: GPT2Config
+ dtype: jnp.dtype = jnp.float32
+
+ def setup(self):
+ self.blocks = [
+ FlaxGPT2Block(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers)
+ ]
+
+ def __call__(
+ self,
+ hidden_states,
+ attention_mask=None,
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
+ deterministic: bool = True,
+ init_cache: bool = False,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ return_dict: bool = True,
+ ):
+ all_attentions = () if output_attentions else None
+ all_hidden_states = () if output_hidden_states else None
+ all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
+
+ for block in self.blocks:
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ layer_outputs = block(
+ hidden_states,
+ attention_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ deterministic=deterministic,
+ init_cache=init_cache,
+ output_attentions=output_attentions,
+ )
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_attentions += (layer_outputs[1],)
+
+ if encoder_hidden_states is not None:
+ all_cross_attentions += (layer_outputs[2],)
+
+ # this contains possible `None` values - `FlaxGPT2Module` will filter them out
+ outputs = (hidden_states, all_hidden_states, all_attentions, all_cross_attentions)
+
+ return outputs
+
+
+class FlaxGPT2Module(nn.Module):
+ config: GPT2Config
+ dtype: jnp.dtype = jnp.float32
+
+ def setup(self):
+ self.embed_dim = self.config.hidden_size
+
+ self.wte = nn.Embed(
+ self.config.vocab_size,
+ self.embed_dim,
+ embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
+ dtype=self.dtype,
+ )
+ self.wpe = nn.Embed(
+ self.config.max_position_embeddings,
+ self.embed_dim,
+ embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
+ dtype=self.dtype,
+ )
+ self.dropout = nn.Dropout(rate=self.config.embd_pdrop)
+ self.h = FlaxGPT2BlockCollection(self.config, dtype=self.dtype)
+ self.ln_f = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype)
+
+ def __call__(
+ self,
+ input_ids,
+ attention_mask,
+ position_ids,
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
+ deterministic=True,
+ init_cache: bool = False,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ return_dict: bool = True,
+ ):
+ input_embeds = self.wte(input_ids.astype("i4"))
+ position_embeds = self.wpe(position_ids.astype("i4"))
+
+ hidden_states = input_embeds + position_embeds
+ hidden_states = self.dropout(hidden_states, deterministic=deterministic)
+
+ outputs = self.h(
+ hidden_states,
+ attention_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ deterministic=deterministic,
+ init_cache=init_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = outputs[0]
+ hidden_states = self.ln_f(hidden_states)
+
+ if output_hidden_states:
+ all_hidden_states = outputs[1] + (hidden_states,)
+ outputs = (hidden_states, all_hidden_states) + outputs[2:]
+ else:
+ outputs = (hidden_states,) + outputs[1:]
+
+ if not return_dict:
+ return tuple(v for v in outputs if v is not None)
+
+ return FlaxBaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ hidden_states=outputs[1],
+ attentions=outputs[2],
+ cross_attentions=outputs[3],
+ )
+
+
+@add_start_docstrings(
+ "The bare GPT2 Model transformer outputting raw hidden-states without any specific head on top.",
+ GPT2_START_DOCSTRING,
+)
+class FlaxGPT2Model(FlaxGPT2PreTrainedModel):
+ module_class = FlaxGPT2Module
+
+
+append_call_sample_docstring(
+ FlaxGPT2Model,
+ _CHECKPOINT_FOR_DOC,
+ FlaxBaseModelOutputWithPastAndCrossAttentions,
+ _CONFIG_FOR_DOC,
+)
+
+
+class FlaxGPT2LMHeadModule(nn.Module):
+ config: GPT2Config
+ dtype: jnp.dtype = jnp.float32
+
+ def setup(self):
+ self.transformer = FlaxGPT2Module(self.config, dtype=self.dtype)
+ self.lm_head = nn.Dense(
+ self.config.vocab_size,
+ use_bias=False,
+ dtype=self.dtype,
+ kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
+ )
+
+ def __call__(
+ self,
+ input_ids,
+ attention_mask,
+ position_ids,
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
+ deterministic: bool = True,
+ init_cache: bool = False,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ return_dict: bool = True,
+ ):
+ outputs = self.transformer(
+ input_ids,
+ attention_mask,
+ position_ids,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ deterministic=deterministic,
+ init_cache=init_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = outputs[0]
+
+ if self.config.tie_word_embeddings:
+ shared_kernel = self.transformer.variables["params"]["wte"]["embedding"].T
+ lm_logits = self.lm_head.apply({"params": {"kernel": shared_kernel}}, hidden_states)
+ else:
+ lm_logits = self.lm_head(hidden_states)
+
+ if not return_dict:
+ return (lm_logits,) + outputs[1:]
+
+ return FlaxCausalLMOutputWithCrossAttentions(
+ logits=lm_logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ cross_attentions=outputs.cross_attentions,
+ )
+
+
+@add_start_docstrings(
+ """
+ The GPT2 Model transformer with a language modeling head on top (linear layer with weights tied to the input
+ embeddings).
+ """,
+ GPT2_START_DOCSTRING,
+)
+class FlaxGPT2LMHeadModel(FlaxGPT2PreTrainedModel):
+ module_class = FlaxGPT2LMHeadModule
+
+ def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
+ # initializing the cache
+ batch_size, seq_length = input_ids.shape
+
+ past_key_values = self.init_cache(batch_size, max_length)
+ # Note that usually one would have to put 0s in the attention_mask for positions x > input_ids.shape[-1]
+ # and x < cache_length. But since GPT2 uses a causal mask, those positions are masked anyway.
+ # Thus, we can create a single static attention_mask here, which is more efficient for compilation.
+ extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
+ if attention_mask is not None:
+ position_ids = attention_mask.cumsum(axis=-1) - 1
+ extended_attention_mask = lax.dynamic_update_slice(
+ extended_attention_mask, attention_mask.astype("i4"), (0, 0)
+ )
+ else:
+ position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))
+
+ return {
+ "past_key_values": past_key_values,
+ "attention_mask": extended_attention_mask,
+ "position_ids": position_ids,
+ }
+
+ def update_inputs_for_generation(self, model_outputs, model_kwargs):
+ model_kwargs["past_key_values"] = model_outputs.past_key_values
+ model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
+ return model_kwargs
+
+
+append_call_sample_docstring(
+ FlaxGPT2LMHeadModel,
+ _CHECKPOINT_FOR_DOC,
+ FlaxCausalLMOutputWithCrossAttentions,
+ _CONFIG_FOR_DOC,
+)
+
+
+__all__ = ["FlaxGPT2LMHeadModel", "FlaxGPT2Model", "FlaxGPT2PreTrainedModel"]
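+
+
+# A short end-to-end sketch (illustrative only) of greedy decoding with the LM head model defined
+# above. The `prepare_inputs_for_generation` / `update_inputs_for_generation` hooks are what
+# `generate()` calls under the hood; the checkpoint name is an assumption for illustration.
+def _flax_greedy_generation_sketch():
+    from transformers import AutoTokenizer
+
+    tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
+    model = FlaxGPT2LMHeadModel.from_pretrained("openai-community/gpt2")
+
+    inputs = tokenizer("The Flax GPT-2 port", return_tensors="np")
+    # `generate` builds the static attention mask and cache via `prepare_inputs_for_generation`
+    # and advances `position_ids` via `update_inputs_for_generation` at every step.
+    output = model.generate(
+        inputs["input_ids"],
+        attention_mask=inputs["attention_mask"],
+        max_length=20,
+        do_sample=False,
+        pad_token_id=tokenizer.eos_token_id,
+    )
+    return tokenizer.batch_decode(output.sequences, skip_special_tokens=True)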
diff --git a/venv/lib/python3.13/site-packages/transformers/models/gpt2/modeling_gpt2.py b/venv/lib/python3.13/site-packages/transformers/models/gpt2/modeling_gpt2.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae0786179464115b880ab5d5b4c771292ad5b2db
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/gpt2/modeling_gpt2.py
@@ -0,0 +1,1638 @@
+# coding=utf-8
+# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch OpenAI GPT-2 model."""
+
+import math
+import os
+import warnings
+from dataclasses import dataclass
+from typing import Callable, Optional, Union
+
+import torch
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from ...activations import ACT2FN, get_activation
+from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
+from ...generation import GenerationMixin
+from ...masking_utils import create_causal_mask
+from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import (
+ BaseModelOutputWithPastAndCrossAttentions,
+ CausalLMOutputWithCrossAttentions,
+ QuestionAnsweringModelOutput,
+ SequenceClassifierOutputWithPast,
+ TokenClassifierOutput,
+)
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
+from ...utils import (
+ ModelOutput,
+ add_start_docstrings,
+ auto_docstring,
+ logging,
+)
+from ...utils.deprecation import deprecate_kwarg
+from ...utils.model_parallel_utils import assert_device_map, get_device_map
+from .configuration_gpt2 import GPT2Config
+
+
+logger = logging.get_logger(__name__)
+
+
+def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
+ """Load tf checkpoints in a pytorch model"""
+ try:
+ import re
+
+ import tensorflow as tf
+ except ImportError:
+ logger.error(
+ "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
+ "https://www.tensorflow.org/install/ for installation instructions."
+ )
+ raise
+ tf_path = os.path.abspath(gpt2_checkpoint_path)
+ logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
+ # Load weights from TF model
+ init_vars = tf.train.list_variables(tf_path)
+ names = []
+ arrays = []
+ for name, shape in init_vars:
+ logger.info(f"Loading TF weight {name} with shape {shape}")
+ array = tf.train.load_variable(tf_path, name)
+ names.append(name)
+ arrays.append(array.squeeze())
+
+ for name, array in zip(names, arrays):
+ name = name[6:] # skip "model/"
+ name = name.split("/")
+ pointer = model
+ for m_name in name:
+ if re.fullmatch(r"[A-Za-z]+\d+", m_name):
+ scope_names = re.split(r"(\d+)", m_name)
+ else:
+ scope_names = [m_name]
+ if scope_names[0] == "w" or scope_names[0] == "g":
+ pointer = getattr(pointer, "weight")
+ elif scope_names[0] == "b":
+ pointer = getattr(pointer, "bias")
+ elif scope_names[0] == "wpe" or scope_names[0] == "wte":
+ pointer = getattr(pointer, scope_names[0])
+ pointer = getattr(pointer, "weight")
+ else:
+ pointer = getattr(pointer, scope_names[0])
+ if len(scope_names) >= 2:
+ num = int(scope_names[1])
+ pointer = pointer[num]
+ try:
+ if pointer.shape != array.shape:
+ raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
+ except ValueError as e:
+ e.args += (pointer.shape, array.shape)
+ raise
+ logger.info(f"Initialize PyTorch weight {name}")
+ pointer.data = torch.from_numpy(array)
+ return model
+
+
+def eager_attention_forward(module, query, key, value, attention_mask, head_mask=None, **kwargs):
+ attn_weights = torch.matmul(query, key.transpose(-1, -2))
+
+ if module.scale_attn_weights:
+ attn_weights = attn_weights / torch.full(
+ [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
+ )
+
+ # Layer-wise attention scaling
+ if module.scale_attn_by_inverse_layer_idx:
+ attn_weights = attn_weights / float(module.layer_idx + 1)
+
+ if not module.is_cross_attention:
+ # Only the "normal" (non-cross) attention layer implements the causal mask
+ query_length, key_length = query.size(-2), key.size(-2)
+ causal_mask = module.bias[:, :, key_length - query_length : key_length, :key_length]
+ mask_value = torch.finfo(attn_weights.dtype).min
+ # Needs to be a tensor, otherwise we get the error: `RuntimeError: expected scalar type float but found double`.
+ # Needs to be on the same device, otherwise: `RuntimeError: ..., x and y to be on the same device`
+ mask_value = torch.full([], mask_value, dtype=attn_weights.dtype, device=attn_weights.device)
+ attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value)
+
+ if attention_mask is not None:
+ # Apply the attention mask
+ causal_mask = attention_mask[:, :, :, : key.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+ # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise
+ attn_weights = attn_weights.type(value.dtype)
+ attn_weights = module.attn_dropout(attn_weights)
+
+ # Mask heads if we want to
+ if head_mask is not None:
+ attn_weights = attn_weights * head_mask
+
+ attn_output = torch.matmul(attn_weights, value)
+ attn_output = attn_output.transpose(1, 2)
+
+ return attn_output, attn_weights
+
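+
+# A small numerical sketch (illustrative only) of what `eager_attention_forward` computes:
+# softmax(Q @ K^T / sqrt(d)) @ V under a causal mask. The stand-in object below only carries the
+# attributes the function actually reads; it is an assumption for illustration, not a real module.
+def _eager_attention_shape_sketch():
+    from types import SimpleNamespace
+
+    batch, heads, seq, head_dim = 2, 3, 4, 8
+    stub = SimpleNamespace(
+        scale_attn_weights=True,
+        scale_attn_by_inverse_layer_idx=False,
+        is_cross_attention=False,
+        bias=torch.tril(torch.ones(seq, seq, dtype=torch.bool)).view(1, 1, seq, seq),
+        attn_dropout=nn.Dropout(0.0),
+    )
+    query = torch.randn(batch, heads, seq, head_dim)
+    key = torch.randn(batch, heads, seq, head_dim)
+    value = torch.randn(batch, heads, seq, head_dim)
+
+    attn_output, attn_weights = eager_attention_forward(stub, query, key, value, attention_mask=None)
+    # The output comes back transposed to (batch, seq, heads, head_dim); each row of weights sums to 1.
+    assert attn_output.shape == (batch, seq, heads, head_dim)
+    assert torch.allclose(attn_weights.sum(dim=-1), torch.ones(batch, heads, seq))
+    return attn_output
+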
+
+class GPT2Attention(nn.Module):
+ def __init__(self, config, is_cross_attention=False, layer_idx=None):
+ super().__init__()
+ self.config = config
+ max_positions = config.max_position_embeddings
+ self.register_buffer(
+ "bias",
+ torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
+ 1, 1, max_positions, max_positions
+ ),
+ persistent=False,
+ )
+ self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False)
+
+ self.embed_dim = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.embed_dim // self.num_heads
+ self.split_size = self.embed_dim
+ if self.head_dim * self.num_heads != self.embed_dim:
+ raise ValueError(
+ f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
+ f" {self.num_heads})."
+ )
+
+ self.scale_attn_weights = config.scale_attn_weights
+ self.is_cross_attention = is_cross_attention
+
+ # Layer-wise attention scaling, reordering, and upcasting
+ self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx
+ self.layer_idx = layer_idx
+ self.reorder_and_upcast_attn = config.reorder_and_upcast_attn
+
+ if self.is_cross_attention:
+ self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim)
+ self.q_attn = Conv1D(self.embed_dim, self.embed_dim)
+ else:
+ self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
+ self.c_proj = Conv1D(self.embed_dim, self.embed_dim)
+
+ self.attn_dropout = nn.Dropout(config.attn_pdrop)
+ self.resid_dropout = nn.Dropout(config.resid_pdrop)
+ self.is_causal = True
+
+ self.pruned_heads = set()
+
+ def prune_heads(self, heads):
+ if len(heads) == 0:
+ return
+ heads, index = find_pruneable_heads_and_indices(heads, self.num_heads, self.head_dim, self.pruned_heads)
+ index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)])
+
+ # Prune conv1d layers
+ self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
+ self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)
+
+ # Update hyper params
+ self.split_size = (self.split_size // self.num_heads) * (self.num_heads - len(heads))
+ self.num_heads = self.num_heads - len(heads)
+ self.pruned_heads = self.pruned_heads.union(heads)
+
+ def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None):
+ # Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM)
+ bsz, num_heads, q_seq_len, dk = query.size()
+ _, _, k_seq_len, _ = key.size()
+
+ # Preallocate attn_weights for `baddbmm`
+ attn_weights = torch.empty(bsz * num_heads, q_seq_len, k_seq_len, dtype=torch.float32, device=query.device)
+
+ # Compute Scale Factor
+ scale_factor = 1.0
+ if self.scale_attn_weights:
+ scale_factor /= float(value.size(-1)) ** 0.5
+
+ if self.scale_attn_by_inverse_layer_idx:
+ scale_factor /= float(self.layer_idx + 1)
+
+ # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
+ with torch.autocast(query.device.type, enabled=False):
+ q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
+ attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
+ attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
+
+ if not self.is_cross_attention:
+ # Only the "normal" (non-cross) attention layer implements the causal mask
+ query_length, key_length = query.size(-2), key.size(-2)
+ causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
+ mask_value = torch.finfo(attn_weights.dtype).min
+ # Needs to be a tensor, otherwise we get the error: `RuntimeError: expected scalar type float but found double`.
+ # Needs to be on the same device, otherwise: `RuntimeError: ..., x and y to be on the same device`
+ mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype, device=attn_weights.device)
+ attn_weights = torch.where(causal_mask, attn_weights, mask_value)
+
+ if attention_mask is not None:
+ # Apply the attention mask
+ attn_weights = attn_weights + attention_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+ # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise
+ if attn_weights.dtype != torch.float32:
+ raise RuntimeError("Error with upcasting, attn_weights does not have dtype torch.float32")
+ attn_weights = attn_weights.type(value.dtype)
+ attn_weights = self.attn_dropout(attn_weights)
+
+ # Mask heads if we want to
+ if head_mask is not None:
+ attn_weights = attn_weights * head_mask
+
+ attn_output = torch.matmul(attn_weights, value)
+ attn_output = attn_output.transpose(1, 2)
+
+ return attn_output, attn_weights
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: Optional[tuple[torch.FloatTensor]],
+ past_key_values: Optional[Cache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = False,
+ **kwargs,
+ ) -> tuple[Union[torch.Tensor, tuple[torch.Tensor]], ...]:
+ is_cross_attention = encoder_hidden_states is not None
+ if past_key_values is not None:
+ if isinstance(past_key_values, EncoderDecoderCache):
+ is_updated = past_key_values.is_updated.get(self.layer_idx)
+ if is_cross_attention:
+ # after the first generated id, we can subsequently re-use all key/value states from the cache
+ curr_past_key_value = past_key_values.cross_attention_cache
+ else:
+ curr_past_key_value = past_key_values.self_attention_cache
+ else:
+ curr_past_key_value = past_key_values
+
+ if is_cross_attention:
+ if not hasattr(self, "q_attn"):
+ raise ValueError(
+ "If class is used as cross attention, the weights `q_attn` have to be defined. "
+ "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
+ )
+ query_states = self.q_attn(hidden_states)
+ attention_mask = encoder_attention_mask
+
+ # Try to get key/value states from cache if possible
+ if past_key_values is not None and is_updated:
+ key_states = curr_past_key_value.layers[self.layer_idx].keys
+ value_states = curr_past_key_value.layers[self.layer_idx].values
+ else:
+ key_states, value_states = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
+ shape_kv = (*key_states.shape[:-1], -1, self.head_dim)
+ key_states = key_states.view(shape_kv).transpose(1, 2)
+ value_states = value_states.view(shape_kv).transpose(1, 2)
+ else:
+ query_states, key_states, value_states = self.c_attn(hidden_states).split(self.split_size, dim=2)
+ shape_kv = (*key_states.shape[:-1], -1, self.head_dim)
+ key_states = key_states.view(shape_kv).transpose(1, 2)
+ value_states = value_states.view(shape_kv).transpose(1, 2)
+
+ shape_q = (*query_states.shape[:-1], -1, self.head_dim)
+ query_states = query_states.view(shape_q).transpose(1, 2)
+
+ if (past_key_values is not None and not is_cross_attention) or (
+ past_key_values is not None and is_cross_attention and not is_updated
+ ):
+ # save all key/value states to the cache so they can be re-used for fast auto-regressive generation
+ cache_position = cache_position if not is_cross_attention else None
+ key_states, value_states = curr_past_key_value.update(
+ key_states, value_states, self.layer_idx, {"cache_position": cache_position}
+ )
+ # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
+ if is_cross_attention:
+ past_key_values.is_updated[self.layer_idx] = True
+
+ is_causal = attention_mask is None and query_states.shape[-2] > 1 and not is_cross_attention
+
+ using_eager = self.config._attn_implementation == "eager"
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ if using_eager and self.reorder_and_upcast_attn:
+ attn_output, attn_weights = self._upcast_and_reordered_attn(
+ query_states, key_states, value_states, attention_mask, head_mask
+ )
+ else:
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ head_mask=head_mask,
+ dropout=self.attn_dropout.p if self.training else 0.0,
+ is_causal=is_causal,
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(*attn_output.shape[:-2], -1).contiguous()
+ attn_output = self.c_proj(attn_output)
+ attn_output = self.resid_dropout(attn_output)
+
+ return attn_output, attn_weights
+
+
+class GPT2MLP(nn.Module):
+ def __init__(self, intermediate_size, config):
+ super().__init__()
+ embed_dim = config.hidden_size
+ self.c_fc = Conv1D(intermediate_size, embed_dim)
+ self.c_proj = Conv1D(embed_dim, intermediate_size)
+ self.act = ACT2FN[config.activation_function]
+ self.dropout = nn.Dropout(config.resid_pdrop)
+
+ def forward(self, hidden_states: Optional[tuple[torch.FloatTensor]]) -> torch.FloatTensor:
+ hidden_states = self.c_fc(hidden_states)
+ hidden_states = self.act(hidden_states)
+ hidden_states = self.c_proj(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ return hidden_states
+
+
+class GPT2Block(GradientCheckpointingLayer):
+ def __init__(self, config, layer_idx=None):
+ super().__init__()
+ hidden_size = config.hidden_size
+ inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
+
+ self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+ self.attn = GPT2Attention(config=config, layer_idx=layer_idx)
+ self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+
+ if config.add_cross_attention:
+ self.crossattention = GPT2Attention(config=config, is_cross_attention=True, layer_idx=layer_idx)
+ self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+
+ self.mlp = GPT2MLP(inner_dim, config)
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: Optional[tuple[torch.FloatTensor]],
+ past_key_values: Optional[Cache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = False,
+ output_attentions: Optional[bool] = False,
+ **kwargs,
+ ) -> Union[tuple[torch.Tensor], Optional[tuple[torch.Tensor, tuple[torch.FloatTensor, ...]]]]:
+ residual = hidden_states
+ hidden_states = self.ln_1(hidden_states)
+ attn_output, self_attn_weights = self.attn(
+ hidden_states,
+ past_key_values=past_key_values,
+ cache_position=cache_position,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ **kwargs,
+ )
+ # residual connection
+ hidden_states = attn_output + residual
+
+ if encoder_hidden_states is not None:
+ # add one self-attention block for cross-attention
+ if not hasattr(self, "crossattention"):
+ raise ValueError(
+ f"If `encoder_hidden_states` are passed, {self} has to be instantiated with "
+ "cross-attention layers by setting `config.add_cross_attention=True`"
+ )
+ residual = hidden_states
+ hidden_states = self.ln_cross_attn(hidden_states)
+ cross_attn_output, cross_attn_weights = self.crossattention(
+ hidden_states,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ output_attentions=output_attentions,
+ )
+ # residual connection
+ hidden_states = residual + cross_attn_output
+
+ residual = hidden_states
+ hidden_states = self.ln_2(hidden_states)
+ feed_forward_hidden_states = self.mlp(hidden_states)
+ # residual connection
+ hidden_states = residual + feed_forward_hidden_states
+
+ outputs = (hidden_states,)
+ if output_attentions:
+ outputs += (self_attn_weights,)
+ if encoder_hidden_states is not None:
+ outputs += (cross_attn_weights,)
+
+ return outputs
+
+
+# Copied from transformers.models.xlm.modeling_xlm.XLMSequenceSummary with XLM->GPT2
+class GPT2SequenceSummary(nn.Module):
+ r"""
+ Compute a single vector summary of a sequence hidden states.
+
+ Args:
+ config ([`GPT2Config`]):
+ The config used by the model. Relevant arguments in the config class of the model are (refer to the actual
+ config class of your model for the default values it uses):
+
+ - **summary_type** (`str`) -- The method to use to make this summary. Accepted values are:
+
+ - `"last"` -- Take the last token hidden state (like XLNet)
+ - `"first"` -- Take the first token hidden state (like Bert)
+ - `"mean"` -- Take the mean of all tokens hidden states
+ - `"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2)
+ - `"attn"` -- Not implemented now, use multi-head attention
+
+ - **summary_use_proj** (`bool`) -- Add a projection after the vector extraction.
+ - **summary_proj_to_labels** (`bool`) -- If `True`, the projection outputs to `config.num_labels` classes
+ (otherwise to `config.hidden_size`).
+ - **summary_activation** (`Optional[str]`) -- Set to `"tanh"` to add a tanh activation to the output,
+ another string or `None` will add no activation.
+ - **summary_first_dropout** (`float`) -- Optional dropout probability before the projection and activation.
+ - **summary_last_dropout** (`float`) -- Optional dropout probability after the projection and activation.
+ """
+
+ def __init__(self, config: GPT2Config):
+ super().__init__()
+
+ self.summary_type = getattr(config, "summary_type", "last")
+ if self.summary_type == "attn":
+ # We should use a standard multi-head attention module with absolute positional embedding for that.
+ # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
+ # We can probably just use the multi-head attention module of PyTorch >=1.1.0
+ raise NotImplementedError
+
+ self.summary = nn.Identity()
+ if hasattr(config, "summary_use_proj") and config.summary_use_proj:
+ if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0:
+ num_classes = config.num_labels
+ else:
+ num_classes = config.hidden_size
+ self.summary = nn.Linear(config.hidden_size, num_classes)
+
+ activation_string = getattr(config, "summary_activation", None)
+ self.activation: Callable = get_activation(activation_string) if activation_string else nn.Identity()
+
+ self.first_dropout = nn.Identity()
+ if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0:
+ self.first_dropout = nn.Dropout(config.summary_first_dropout)
+
+ self.last_dropout = nn.Identity()
+ if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0:
+ self.last_dropout = nn.Dropout(config.summary_last_dropout)
+
+ def forward(
+ self, hidden_states: torch.FloatTensor, cls_index: Optional[torch.LongTensor] = None
+ ) -> torch.FloatTensor:
+ """
+ Compute a single vector summary of a sequence hidden states.
+
+ Args:
+ hidden_states (`torch.FloatTensor` of shape `[batch_size, seq_len, hidden_size]`):
+ The hidden states of the last layer.
+ cls_index (`torch.LongTensor` of shape `[batch_size]` or `[batch_size, ...]` where ... are optional leading dimensions of `hidden_states`, *optional*):
+ Used if `summary_type == "cls_index"`; if not provided, the last token of the sequence is taken as the classification token.
+
+ Returns:
+ `torch.FloatTensor`: The summary of the sequence hidden states.
+ """
+ if self.summary_type == "last":
+ output = hidden_states[:, -1]
+ elif self.summary_type == "first":
+ output = hidden_states[:, 0]
+ elif self.summary_type == "mean":
+ output = hidden_states.mean(dim=1)
+ elif self.summary_type == "cls_index":
+ if cls_index is None:
+ cls_index = torch.full_like(
+ hidden_states[..., :1, :],
+ hidden_states.shape[-2] - 1,
+ dtype=torch.long,
+ )
+ else:
+ cls_index = cls_index.unsqueeze(-1).unsqueeze(-1)
+ cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),))
+ # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
+ output = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, XX, hidden_size)
+ elif self.summary_type == "attn":
+ raise NotImplementedError
+
+ output = self.first_dropout(output)
+ output = self.summary(output)
+ output = self.activation(output)
+ output = self.last_dropout(output)
+
+ return output
+
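+
+# A brief sketch (illustrative only) of how the summary head above condenses a sequence of hidden
+# states into a single vector. The config values are assumptions for illustration; real checkpoints
+# set them in their saved configuration.
+def _sequence_summary_sketch():
+    config = GPT2Config(summary_type="cls_index", summary_use_proj=True, summary_proj_to_labels=False)
+    summary = GPT2SequenceSummary(config).eval()
+
+    hidden_states = torch.randn(2, 5, config.hidden_size)  # (batch, seq_len, hidden)
+    cls_index = torch.tensor([4, 2])  # position of the classification token per example
+    pooled = summary(hidden_states, cls_index=cls_index)
+    # With `summary_use_proj=True` and `summary_proj_to_labels=False`, the projection keeps hidden_size.
+    assert pooled.shape == (2, config.hidden_size)
+    return pooled
+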
+
+@auto_docstring
+class GPT2PreTrainedModel(PreTrainedModel):
+ config: GPT2Config
+ load_tf_weights = load_tf_weights_in_gpt2
+ base_model_prefix = "transformer"
+ is_parallelizable = True
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["GPT2Block"]
+ _skip_keys_device_placement = "past_key_values"
+ _supports_flash_attn = True
+ _supports_sdpa = True
+ _supports_attention_backend = True
+
+ _can_compile_fullgraph = True
+
+ def __init__(self, *inputs, **kwargs):
+ super().__init__(*inputs, **kwargs)
+
+ def _init_weights(self, module):
+ """Initialize the weights."""
+ if isinstance(module, (nn.Linear, Conv1D)):
+ # Slightly different from the TF version which uses truncated_normal for initialization
+ # cf https://github.com/pytorch/pytorch/pull/5617
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
+ #
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
+ for name, p in module.named_parameters():
+ if name == "c_proj.weight":
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
+ p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer)))
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for outputs of the GPT-2 double-heads model (language modeling and multiple-choice classification).
+ """
+)
+class GPT2DoubleHeadsModelOutput(ModelOutput):
+ r"""
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Language modeling loss.
+ mc_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mc_labels` is provided):
+ Multiple choice classification loss.
+ logits (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, config.vocab_size)`):
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ mc_logits (`torch.FloatTensor` of shape `(batch_size, num_choices)`):
+ Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+ Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ mc_loss: Optional[torch.FloatTensor] = None
+ logits: Optional[torch.FloatTensor] = None
+ mc_logits: Optional[torch.FloatTensor] = None
+ past_key_values: Optional[Cache] = None
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
+ attentions: Optional[tuple[torch.FloatTensor]] = None
+
+
+PARALLELIZE_DOCSTRING = r"""
+ This is an experimental feature and is subject to change at a moment's notice.
+
+ Uses a device map to distribute attention modules of the model across several devices. If no device map is given,
+ it will evenly distribute blocks across all devices.
+
+ Args:
+ device_map (`dict[int, list]`, *optional*):
+ A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always
+ automatically mapped to the first device (for esoteric reasons). That means that the first device should
+ have fewer attention modules mapped to it than other devices. For reference, the gpt2 models have the
+ following number of attention modules:
+
+ - openai-community/gpt2: 12
+ - openai-community/gpt2-medium: 24
+ - openai-community/gpt2-large: 36
+ - openai-community/gpt2-xl: 48
+
+ Example:
+
+ ```python
+ # Here is an example of a device map on a machine with 4 GPUs using gpt2-xl, which has a total of 48 attention modules:
+ model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2-xl")
+ device_map = {
+ 0: [0, 1, 2, 3, 4, 5, 6, 7, 8],
+ 1: [9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21],
+ 2: [22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34],
+ 3: [35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47],
+ }
+ model.parallelize(device_map)
+ ```
+"""
+DEPARALLELIZE_DOCSTRING = r"""
+ Moves the model to cpu from a model parallel state.
+
+ Example:
+
+ ```python
+ # On a 4 GPU machine with openai-community/gpt2-large:
+ model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2-large")
+ device_map = {
+ 0: [0, 1, 2, 3, 4, 5, 6, 7],
+ 1: [8, 9, 10, 11, 12, 13, 14, 15],
+ 2: [16, 17, 18, 19, 20, 21, 22, 23],
+ 3: [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35],
+ }
+ model.parallelize(device_map) # Splits the model across several devices
+ model.deparallelize() # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache()
+ ```
+"""
+
+
+@auto_docstring
+class GPT2Model(GPT2PreTrainedModel):
+ _supports_param_buffer_assignment = False
+
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.embed_dim = config.hidden_size
+
+ self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
+ self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
+
+ self.drop = nn.Dropout(config.embd_pdrop)
+ self.h = nn.ModuleList([GPT2Block(config, layer_idx=i) for i in range(config.num_hidden_layers)])
+ self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
+
+ # Model parallel
+ self.model_parallel = False
+ self.device_map = None
+ self.gradient_checkpointing = False
+ self._attn_implementation = config._attn_implementation
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @add_start_docstrings(PARALLELIZE_DOCSTRING)
+ def parallelize(self, device_map=None):
+ # Check validity of device_map
+ warnings.warn(
+ "`GPT2Model.parallelize` is deprecated and will be removed in v5 of Transformers, you should load your"
+ " model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
+ " `device_map` but it needs to be a dictionary module_name to device, so for instance {'h.0': 0, 'h.1': 1,"
+ " ...}",
+ FutureWarning,
+ )
+ self.device_map = (
+ get_device_map(len(self.h), range(torch.cuda.device_count())) if device_map is None else device_map
+ )
+ assert_device_map(self.device_map, len(self.h))
+ self.model_parallel = True
+ self.first_device = "cpu" if "cpu" in self.device_map else "cuda:" + str(min(self.device_map.keys()))
+ self.last_device = "cuda:" + str(max(self.device_map.keys()))
+ self.wte = self.wte.to(self.first_device)
+ self.wpe = self.wpe.to(self.first_device)
+ # Load onto devices
+ for k, v in self.device_map.items():
+ for block in v:
+ cuda_device = "cuda:" + str(k)
+ self.h[block] = self.h[block].to(cuda_device)
+ # ln_f to last
+ self.ln_f = self.ln_f.to(self.last_device)
+
+ @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
+ def deparallelize(self):
+ warnings.warn(
+ "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
+ FutureWarning,
+ )
+ self.model_parallel = False
+ self.device_map = None
+ self.first_device = "cpu"
+ self.last_device = "cpu"
+ self.wte = self.wte.to("cpu")
+ self.wpe = self.wpe.to("cpu")
+ for index in range(len(self.h)):
+ self.h[index] = self.h[index].to("cpu")
+ self.ln_f = self.ln_f.to("cpu")
+ torch.cuda.empty_cache()
+
+ def get_input_embeddings(self):
+ return self.wte
+
+ def set_input_embeddings(self, new_embeddings):
+ self.wte = new_embeddings
+
+ def _prune_heads(self, heads_to_prune):
+ """
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
+ """
+ for layer, heads in heads_to_prune.items():
+ self.h[layer].attn.prune_heads(heads)
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ **kwargs,
+ ) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]:
+ r"""
+ input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
+ `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
+ `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input
+ sequence tokens in the vocabulary.
+
+ If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
+ `input_ids`.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
+ input_shape = input_ids.size()
+ input_ids = input_ids.view(-1, input_shape[-1])
+ batch_size = input_ids.shape[0]
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ batch_size = inputs_embeds.shape[0]
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+
+ if token_type_ids is not None:
+ token_type_ids = token_type_ids.view(-1, input_shape[-1])
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ # based on pattern from src/transformers/models/whisper/modeling_whisper.py::WhisperDecoder
+ if use_cache:
+ if past_key_values is None:
+ past_key_values = DynamicCache(config=self.config)
+ elif isinstance(past_key_values, tuple):
+ logger.warning_once(
+ "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.53.0. "
+ "You should pass an instance of `Cache` instead, e.g. "
+ "`past_key_values=DynamicCache.from_legacy_cache(past_key_values)`."
+ )
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+
+ if self.config.add_cross_attention and not isinstance(past_key_values, EncoderDecoderCache):
+ past_key_values = EncoderDecoderCache(past_key_values, DynamicCache(config=self.config))
+
+ if inputs_embeds is None:
+ inputs_embeds = self.wte(input_ids)
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position = torch.arange(
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+ )
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0)
+
+ position_embeds = self.wpe(position_ids)
+ hidden_states = inputs_embeds + position_embeds.to(inputs_embeds.device)
+
+ # Attention mask.
+ # ._update_causal_mask() and ._prepare_4d_causal_attention_mask_with_cache_position() copied from LlamaModel
+ if attention_mask is not None and attention_mask.ndim < 4:
+ attention_mask = attention_mask.view(batch_size, -1)
+
+ causal_mask = create_causal_mask(
+ config=self.config,
+ input_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ cache_position=cache_position,
+ past_key_values=past_key_values,
+ position_ids=position_ids,
+ )
+
+ # If a 2D or 3D attention mask is provided for the cross-attention
+ # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
+ _use_sdpa = self._attn_implementation == "sdpa" and output_attentions is False and head_mask is None
+ if self.config.add_cross_attention and encoder_hidden_states is not None:
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
+ if encoder_attention_mask is None:
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
+ if _use_sdpa:
+ encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
+ mask=encoder_attention_mask, dtype=inputs_embeds.dtype, tgt_len=input_shape[-1]
+ )
+ elif self._attn_implementation != "flash_attention_2":
+ encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+ else:
+ encoder_attention_mask = None
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicates we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # head_mask has shape n_layer x batch x n_heads x N x N
+ head_mask = self.get_head_mask(head_mask, self.config.n_layer)
+
+ if token_type_ids is not None:
+ token_type_embeds = self.wte(token_type_ids)
+ hidden_states = hidden_states + token_type_embeds
+
+ hidden_states = self.drop(hidden_states)
+
+ output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)
+
+ all_self_attentions = () if output_attentions else None
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
+ all_hidden_states = () if output_hidden_states else None
+ for i, block in enumerate(self.h):
+ # Model parallel
+ if self.model_parallel:
+ torch.cuda.set_device(hidden_states.device)
+ if isinstance(head_mask, torch.Tensor):
+ head_mask = head_mask.to(hidden_states.device)
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ outputs = block(
+ hidden_states,
+ past_key_values if not (self.gradient_checkpointing and self.training) else None,
+ cache_position,
+ causal_mask,
+ head_mask[i],
+ encoder_hidden_states, # as a positional argument for gradient checkpointing
+ encoder_attention_mask=encoder_attention_mask,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ **kwargs,
+ )
+
+ hidden_states = outputs[0]
+
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (outputs[1],)
+ if self.config.add_cross_attention:
+ all_cross_attentions = all_cross_attentions + (outputs[2],)
+
+ # Model Parallel: If it's the last layer for that device, put things on the next device
+ if self.model_parallel:
+ for k, v in self.device_map.items():
+ if i == v[-1] and "cuda:" + str(k) != self.last_device:
+ hidden_states = hidden_states.to("cuda:" + str(k + 1))
+
+ hidden_states = self.ln_f(hidden_states)
+
+ hidden_states = hidden_states.view(output_shape)
+ # Add last hidden state
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ past_key_values = past_key_values if use_cache else None
+ if not return_dict:
+ return tuple(
+ v
+ for v in [hidden_states, past_key_values, all_hidden_states, all_self_attentions, all_cross_attentions]
+ if v is not None
+ )
+
+ return BaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ cross_attentions=all_cross_attentions,
+ )
+
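+
+# A compact sketch (illustrative only) of the bare model's forward pass with the cache enabled,
+# using a small randomly initialized config rather than a pretrained checkpoint.
+def _gpt2_model_cache_sketch():
+    config = GPT2Config(n_layer=2, n_head=2, n_embd=64, n_positions=32, vocab_size=100)
+    model = GPT2Model(config).eval()
+
+    input_ids = torch.randint(0, config.vocab_size, (1, 5))
+    with torch.no_grad():
+        out = model(input_ids, use_cache=True)
+    # With `use_cache=True`, `past_key_values` is a `Cache` instance holding keys/values for all 5 positions.
+    assert out.past_key_values.get_seq_length() == 5
+
+    # Feeding only the next token re-uses the cache instead of recomputing the whole prefix.
+    next_token = torch.randint(0, config.vocab_size, (1, 1))
+    with torch.no_grad():
+        out = model(next_token, past_key_values=out.past_key_values, use_cache=True)
+    assert out.last_hidden_state.shape == (1, 1, config.n_embd)
+    return out.last_hidden_state
+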
+
+@auto_docstring(
+ custom_intro="""
+ The GPT2 Model transformer with a language modeling head on top (linear layer with weights tied to the input
+ embeddings).
+ """
+)
+class GPT2LMHeadModel(GPT2PreTrainedModel, GenerationMixin):
+ _tied_weights_keys = ["lm_head.weight"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.transformer = GPT2Model(config)
+ self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+
+ # Model parallel
+ self.model_parallel = False
+ self.device_map = None
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @add_start_docstrings(PARALLELIZE_DOCSTRING)
+ def parallelize(self, device_map=None):
+ warnings.warn(
+ "`GPT2LMHeadModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should load"
+ " your model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
+ " `device_map` but it needs to be a dictionary module_name to device, so for instance {'transformer.h.0':"
+ " 0, 'transformer.h.1': 1, ...}",
+ FutureWarning,
+ )
+ self.device_map = (
+ get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
+ if device_map is None
+ else device_map
+ )
+ assert_device_map(self.device_map, len(self.transformer.h))
+ self.transformer.parallelize(self.device_map)
+ self.lm_head = self.lm_head.to(self.transformer.first_device)
+ self.model_parallel = True
+
+ @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
+ def deparallelize(self):
+ warnings.warn(
+ "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
+ FutureWarning,
+ )
+ self.transformer.deparallelize()
+ self.transformer = self.transformer.to("cpu")
+ self.lm_head = self.lm_head.to("cpu")
+ self.model_parallel = False
+ torch.cuda.empty_cache()
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ **kwargs,
+ ) -> Union[tuple, CausalLMOutputWithCrossAttentions]:
+ r"""
+ input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
+ `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
+ `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input
+ sequence tokens in the vocabulary.
+
+ If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
+ `input_ids`.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ labels (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*):
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
+ `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100`
+ are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`.
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ transformer_outputs = self.transformer(
+ input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ cache_position=cache_position,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ hidden_states = transformer_outputs[0]
+
+ # Set device for model parallelism
+ if self.model_parallel:
+ torch.cuda.set_device(self.transformer.first_device)
+ hidden_states = hidden_states.to(self.lm_head.weight.device)
+
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+ loss = None
+ if labels is not None:
+ # Flatten the tokens
+ loss = self.loss_function(
+ logits,
+ labels,
+ vocab_size=self.config.vocab_size,
+ **kwargs,
+ )
+
+ if not return_dict:
+ output = (logits,) + transformer_outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return CausalLMOutputWithCrossAttentions(
+ loss=loss,
+ logits=logits,
+ past_key_values=transformer_outputs.past_key_values,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
+ cross_attentions=transformer_outputs.cross_attentions,
+ )
+
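+
+# A minimal sketch (illustrative only) of the language-modeling loss computed by the head above.
+# As the `labels` docstring notes, the shift between inputs and targets happens inside the model,
+# so `labels=input_ids` is the standard pattern. Config values are assumptions for illustration.
+def _gpt2_lm_loss_sketch():
+    config = GPT2Config(n_layer=2, n_head=2, n_embd=64, n_positions=32, vocab_size=100)
+    model = GPT2LMHeadModel(config).eval()
+
+    input_ids = torch.randint(0, config.vocab_size, (1, 8))
+    with torch.no_grad():
+        out = model(input_ids, labels=input_ids)
+    # `loss` is the mean cross-entropy over the shifted next-token predictions;
+    # `logits` keeps one row per input position.
+    assert out.logits.shape == (1, 8, config.vocab_size)
+    return out.loss
+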
+
+@auto_docstring(
+ custom_intro="""
+ The GPT2 Model transformer with a language modeling and a multiple-choice classification head on top e.g. for
+ RocStories/SWAG tasks. The two heads are two linear layers. The language modeling head has its weights tied to the
+ input embeddings; the classification head takes as input the hidden state of a specified classification token index
+ in the input sequence.
+ """
+)
+class GPT2DoubleHeadsModel(GPT2PreTrainedModel, GenerationMixin):
+ _tied_weights_keys = ["lm_head.weight"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ config.num_labels = 1
+ self.transformer = GPT2Model(config)
+ self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+ self.multiple_choice_head = GPT2SequenceSummary(config)
+
+ # Model parallel
+ self.model_parallel = False
+ self.device_map = None
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @add_start_docstrings(PARALLELIZE_DOCSTRING)
+ def parallelize(self, device_map=None):
+ warnings.warn(
+ "`GPT2DoubleHeadsModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should"
+ " load your model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your"
+ " own `device_map` but it needs to be a dictionary module_name to device, so for instance"
+ " {'transformer.h.0': 0, 'transformer.h.1': 1, ...}",
+ FutureWarning,
+ )
+ self.device_map = (
+ get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
+ if device_map is None
+ else device_map
+ )
+ assert_device_map(self.device_map, len(self.transformer.h))
+ self.transformer.parallelize(self.device_map)
+ self.lm_head = self.lm_head.to(self.transformer.first_device)
+ self.multiple_choice_head = self.multiple_choice_head.to(self.transformer.first_device)
+ self.model_parallel = True
+
+ @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
+ def deparallelize(self):
+ warnings.warn(
+ "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
+ FutureWarning,
+ )
+ self.transformer.deparallelize()
+ self.transformer = self.transformer.to("cpu")
+ self.lm_head = self.lm_head.to("cpu")
+ self.multiple_choice_head = self.multiple_choice_head.to("cpu")
+ self.model_parallel = False
+ torch.cuda.empty_cache()
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ mc_token_ids: Optional[torch.LongTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ mc_labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ **kwargs,
+ ) -> Union[tuple, GPT2DoubleHeadsModelOutput]:
+ r"""
+ input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
+ `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
+ `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input
+ sequence tokens in the vocabulary.
+
+ If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
+ `input_ids`.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ mc_token_ids (`torch.LongTensor` of shape `(batch_size, num_choices)`, *optional*, defaults to index of the last token of the input):
+ Index of the classification token in each input sequence. Selected in the range `[0, input_ids.size(-1) -
+ 1]`.
+ labels (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*):
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
+ `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`. All labels set to
+ `-100` are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size - 1]`
+ mc_labels (`torch.LongTensor` of shape `(batch_size)`, *optional*):
+ Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices - 1]`
+ where *num_choices* is the size of the second dimension of the input tensors (see *input_ids* above).
+
+ Example:
+
+ ```python
+ >>> import torch
+ >>> from transformers import AutoTokenizer, GPT2DoubleHeadsModel
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
+ >>> model = GPT2DoubleHeadsModel.from_pretrained("openai-community/gpt2")
+
+ >>> # Add a [CLS] to the vocabulary (we should train it also!)
+ >>> num_added_tokens = tokenizer.add_special_tokens({"cls_token": "[CLS]"})
+ >>> # Update the model embeddings with the new vocabulary size
+ >>> embedding_layer = model.resize_token_embeddings(len(tokenizer))
+
+ >>> choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"]
+ >>> encoded_choices = [tokenizer.encode(s) for s in choices]
+ >>> cls_token_location = [tokens.index(tokenizer.cls_token_id) for tokens in encoded_choices]
+
+ >>> input_ids = torch.tensor(encoded_choices).unsqueeze(0) # Batch size: 1, number of choices: 2
+ >>> mc_token_ids = torch.tensor([cls_token_location]) # Batch size: 1
+
+ >>> outputs = model(input_ids, mc_token_ids=mc_token_ids)
+ >>> lm_logits = outputs.logits
+ >>> mc_logits = outputs.mc_logits
+ ```"""
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ transformer_outputs = self.transformer(
+ input_ids,
+ past_key_values=past_key_values,
+ cache_position=cache_position,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = transformer_outputs[0]
+
+ # Set device for model parallelism
+ if self.model_parallel:
+ torch.cuda.set_device(self.transformer.first_device)
+ hidden_states = hidden_states.to(self.lm_head.weight.device)
+
+ lm_logits = self.lm_head(hidden_states)
+ mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1)
+
+ mc_loss = None
+ if mc_labels is not None:
+ loss_fct = CrossEntropyLoss()
+ mc_loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1))
+ lm_loss = None
+ if labels is not None:
+ labels = labels.to(lm_logits.device)
+ shift_logits = lm_logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ loss_fct = CrossEntropyLoss()
+ lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
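+ # e.g. (illustrative) for a sequence [t0, t1, t2, t3], the logits at positions 0..2 are scored
+ # against the shifted targets [t1, t2, t3]; the final position has no next-token target.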
+
+ if not return_dict:
+ output = (lm_logits, mc_logits) + transformer_outputs[1:]
+ if mc_loss is not None:
+ output = (mc_loss,) + output
+ return ((lm_loss,) + output) if lm_loss is not None else output
+
+ return GPT2DoubleHeadsModelOutput(
+ loss=lm_loss,
+ mc_loss=mc_loss,
+ logits=lm_logits,
+ mc_logits=mc_logits,
+ past_key_values=transformer_outputs.past_key_values,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ The GPT2 Model transformer with a sequence classification head on top (linear layer).
+
+ [`GPT2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
+ (e.g. GPT-1) do.
+
+ Since it does classification on the last token, it requires to know the position of the last token. If a
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
+ each row of the batch).
+ """
+)
+class GPT2ForSequenceClassification(GPT2PreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.transformer = GPT2Model(config)
+ self.score = nn.Linear(config.n_embd, self.num_labels, bias=False)
+
+ # Model parallel
+ self.model_parallel = False
+ self.device_map = None
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, SequenceClassifierOutputWithPast]:
+ r"""
+ input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
+ `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
+ `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input
+ sequence tokens in the vocabulary.
+
+ If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
+ `input_ids`.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss); if
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ transformer_outputs = self.transformer(
+ input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ hidden_states = transformer_outputs[0]
+ logits = self.score(hidden_states)
+
+ if input_ids is not None:
+ batch_size, sequence_length = input_ids.shape[:2]
+ else:
+ batch_size, sequence_length = inputs_embeds.shape[:2]
+
+ if self.config.pad_token_id is None and batch_size != 1:
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
+ if self.config.pad_token_id is None:
+ last_non_pad_token = -1
+ elif input_ids is not None:
+ # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
+ non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
+ token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
+ last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
+ else:
+ last_non_pad_token = -1
+ logger.warning_once(
+ f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+ "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+ )
+
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
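+ # Worked example (illustrative): with pad_token_id=0 and input_ids=[[5, 6, 0, 0]],
+ # non_pad_mask=[[1, 1, 0, 0]] and token_indices=[0, 1, 2, 3], so last_non_pad_token=1 and
+ # pooled_logits picks logits[:, 1, :].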
+
+ loss = None
+ if labels is not None:
+ if self.config.problem_type is None:
+ if self.num_labels == 1:
+ self.config.problem_type = "regression"
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+ self.config.problem_type = "single_label_classification"
+ else:
+ self.config.problem_type = "multi_label_classification"
+
+ if self.config.problem_type == "regression":
+ loss_fct = MSELoss()
+ if self.num_labels == 1:
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+ else:
+ loss = loss_fct(pooled_logits, labels)
+ elif self.config.problem_type == "single_label_classification":
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
+ elif self.config.problem_type == "multi_label_classification":
+ loss_fct = BCEWithLogitsLoss()
+ loss = loss_fct(pooled_logits, labels)
+ if not return_dict:
+ output = (pooled_logits,) + transformer_outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return SequenceClassifierOutputWithPast(
+ loss=loss,
+ logits=pooled_logits,
+ past_key_values=transformer_outputs.past_key_values,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
+ )
+
+
+@auto_docstring
+class GPT2ForTokenClassification(GPT2PreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+
+ self.transformer = GPT2Model(config)
+ if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None:
+ classifier_dropout = config.classifier_dropout
+ elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None:
+ classifier_dropout = config.hidden_dropout
+ else:
+ classifier_dropout = 0.1
+ self.dropout = nn.Dropout(classifier_dropout)
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+ # Model parallel
+ self.model_parallel = False
+ self.device_map = None
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, TokenClassifierOutput]:
+ r"""
+ input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
+ `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
+ `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input
+ sequence tokens in the vocabulary.
+
+ If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
+ `input_ids`.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the token classification loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. The loss is computed with Cross-Entropy.
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ transformer_outputs = self.transformer(
+ input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = transformer_outputs[0]
+ hidden_states = self.dropout(hidden_states)
+ logits = self.classifier(hidden_states)
+
+ loss = None
+ if labels is not None:
+ labels = labels.to(logits.device)
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+ if not return_dict:
+ output = (logits,) + transformer_outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return TokenClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
+ )
+
+
+@auto_docstring
+class GPT2ForQuestionAnswering(GPT2PreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.transformer = GPT2Model(config)
+ self.qa_outputs = nn.Linear(config.hidden_size, 2)
+
+ # Model parallel
+ self.model_parallel = False
+ self.device_map = None
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ start_positions: Optional[torch.LongTensor] = None,
+ end_positions: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, QuestionAnsweringModelOutput]:
+ r"""
+ input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
+ `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
+ `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input
+ sequence tokens in the vocabulary.
+
+ If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
+ `input_ids`.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.transformer(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output = outputs[0]
+
+ logits = self.qa_outputs(sequence_output)
+ start_logits, end_logits = logits.split(1, dim=-1)
+ start_logits = start_logits.squeeze(-1).contiguous()
+ end_logits = end_logits.squeeze(-1).contiguous()
+
+ total_loss = None
+ if start_positions is not None and end_positions is not None:
+ # If we are on multi-GPU, the split adds an extra dimension; squeeze it
+ if len(start_positions.size()) > 1:
+ start_positions = start_positions.squeeze(-1).to(start_logits.device)
+ if len(end_positions.size()) > 1:
+ end_positions = end_positions.squeeze(-1).to(end_logits.device)
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
+ ignored_index = start_logits.size(1)
+ start_positions = start_positions.clamp(0, ignored_index)
+ end_positions = end_positions.clamp(0, ignored_index)
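+ # e.g. (illustrative) with a sequence length of 10, ignored_index is 10; a start position of 57
+ # is clamped to 10, which equals ignore_index below and so contributes nothing to the loss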
+
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
+ start_loss = loss_fct(start_logits, start_positions)
+ end_loss = loss_fct(end_logits, end_positions)
+ total_loss = (start_loss + end_loss) / 2
+
+ if not return_dict:
+ output = (start_logits, end_logits) + outputs[2:]
+ return ((total_loss,) + output) if total_loss is not None else output
+
+ return QuestionAnsweringModelOutput(
+ loss=total_loss,
+ start_logits=start_logits,
+ end_logits=end_logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+__all__ = [
+ "GPT2DoubleHeadsModel",
+ "GPT2ForQuestionAnswering",
+ "GPT2ForSequenceClassification",
+ "GPT2ForTokenClassification",
+ "GPT2LMHeadModel",
+ "GPT2Model",
+ "GPT2PreTrainedModel",
+ "load_tf_weights_in_gpt2",
+]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/gpt2/modeling_tf_gpt2.py b/venv/lib/python3.13/site-packages/transformers/models/gpt2/modeling_tf_gpt2.py
new file mode 100644
index 0000000000000000000000000000000000000000..42e23fc290151f09d47a30efca1cb7f4e4a3d669
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/gpt2/modeling_tf_gpt2.py
@@ -0,0 +1,1238 @@
+# coding=utf-8
+# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""TF 2.0 OpenAI GPT-2 model."""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+
+import numpy as np
+import tensorflow as tf
+
+from ...activations_tf import get_tf_activation
+from ...modeling_tf_outputs import (
+ TFBaseModelOutputWithPastAndCrossAttentions,
+ TFCausalLMOutputWithCrossAttentions,
+ TFSequenceClassifierOutputWithPast,
+)
+from ...modeling_tf_utils import (
+ TFCausalLanguageModelingLoss,
+ TFConv1D,
+ TFModelInputType,
+ TFPreTrainedModel,
+ TFSequenceClassificationLoss,
+ TFSequenceSummary,
+ get_initializer,
+ keras,
+ keras_serializable,
+ unpack_inputs,
+)
+from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax
+from ...utils import (
+ ModelOutput,
+ add_code_sample_docstrings,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+ replace_return_docstrings,
+)
+from .configuration_gpt2 import GPT2Config
+
+
+logger = logging.get_logger(__name__)
+
+_CHECKPOINT_FOR_DOC = "openai-community/gpt2"
+_CONFIG_FOR_DOC = "GPT2Config"
+
+
+class TFAttention(keras.layers.Layer):
+ def __init__(self, nx, config, scale=False, is_cross_attention=False, **kwargs):
+ super().__init__(**kwargs)
+
+ n_state = nx # in Attention: n_state=768 (nx=n_embd)
+ # [switch nx => n_state from Block to Attention to keep identical to TF implementation]
+ assert n_state % config.n_head == 0
+ self.n_head = config.n_head
+ self.split_size = n_state
+ self.scale = scale
+ self.output_attentions = config.output_attentions
+
+ self.is_cross_attention = is_cross_attention
+
+ if self.is_cross_attention:
+ self.c_attn = TFConv1D(n_state * 2, nx, initializer_range=config.initializer_range, name="c_attn")
+ self.q_attn = TFConv1D(n_state, nx, initializer_range=config.initializer_range, name="q_attn")
+ else:
+ self.c_attn = TFConv1D(n_state * 3, nx, initializer_range=config.initializer_range, name="c_attn")
+
+ self.c_proj = TFConv1D(n_state, nx, initializer_range=config.initializer_range, name="c_proj")
+ self.attn_dropout = keras.layers.Dropout(config.attn_pdrop)
+ self.resid_dropout = keras.layers.Dropout(config.resid_pdrop)
+ self.pruned_heads = set()
+ self.embed_dim = n_state
+
+ def prune_heads(self, heads):
+ pass
+
+ @staticmethod
+ def causal_attention_mask(nd, ns, dtype):
+ """
+ 1's in the lower triangle, counting from the lower right corner. Same as tf.matrix_band_part(tf.ones([nd, ns]),
+ -1, ns-nd), but doesn't produce garbage on TPUs.
+ """
+ i = tf.range(nd)[:, None]
+ j = tf.range(ns)
+ m = i >= j - ns + nd
+ return tf.cast(m, dtype)
+
+ def _attn(self, q, k, v, attention_mask, head_mask, output_attentions, training=False):
+ # q, k, v have shape [batch, heads, sequence, features]
+ w = tf.matmul(q, k, transpose_b=True)
+ if self.scale:
+ dk = tf.cast(shape_list(k)[-1], dtype=w.dtype) # scale attention_scores
+ w = w / tf.math.sqrt(dk)
+
+ if not self.is_cross_attention:
+ # if only "normal" attention layer implements causal mask
+
+ # w has shape [batch, heads, dst_sequence, src_sequence], where information flows from src to dst.
+ _, _, nd, ns = shape_list(w)
+ b = self.causal_attention_mask(nd, ns, dtype=w.dtype)
+ b = tf.reshape(b, [1, 1, nd, ns])
+ w = w * b - 1e4 * (1 - b)
+
+ if attention_mask is not None:
+ # Apply the attention mask
+ attention_mask = tf.cast(attention_mask, dtype=w.dtype)
+ w = w + attention_mask
+
+ w = stable_softmax(w, axis=-1)
+ w = self.attn_dropout(w, training=training)
+
+ # Mask heads if we want to
+ if head_mask is not None:
+ w = w * head_mask
+
+ outputs = [tf.matmul(w, v)]
+ if output_attentions:
+ outputs.append(w)
+ return outputs
+
+ def merge_heads(self, x):
+ x = tf.transpose(x, [0, 2, 1, 3])
+ x_shape = shape_list(x)
+ new_x_shape = x_shape[:-2] + [x_shape[-2] * x_shape[-1]]
+ return tf.reshape(x, new_x_shape)
+
+ def split_heads(self, x):
+ x_shape = shape_list(x)
+ new_x_shape = x_shape[:-1] + [self.n_head, x_shape[-1] // self.n_head]
+ x = tf.reshape(x, new_x_shape)
+ return tf.transpose(x, (0, 2, 1, 3)) # (batch, head, seq_length, head_features)
+
+ def call(
+ self,
+ x,
+ layer_past,
+ attention_mask,
+ head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ use_cache,
+ output_attentions,
+ training=False,
+ ):
+ if encoder_hidden_states is not None:
+ if not hasattr(self, "q_attn"):
+ raise ValueError(
+ "If class is used as cross attention, the weights `q_attn` have to be defined. "
+ "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
+ )
+
+ query = self.q_attn(x)
+ kv_out = self.c_attn(encoder_hidden_states)
+ key, value = tf.split(kv_out, 2, axis=2)
+ attention_mask = encoder_attention_mask
+ else:
+ x = self.c_attn(x)
+ query, key, value = tf.split(x, 3, axis=2)
+
+ query = self.split_heads(query)
+ key = self.split_heads(key)
+ value = self.split_heads(value)
+ if layer_past is not None:
+ past_key, past_value = tf.unstack(layer_past, axis=0, num=2)
+ key = tf.concat([past_key, key], axis=-2)
+ value = tf.concat([past_value, value], axis=-2)
+
+ # to cope with keras serialization
+ if use_cache:
+ present = tf.stack([key, value], axis=0)
+ else:
+ present = (None,)
+
+ attn_outputs = self._attn(query, key, value, attention_mask, head_mask, output_attentions, training=training)
+ a = attn_outputs[0]
+
+ a = self.merge_heads(a)
+ a = self.c_proj(a)
+ a = self.resid_dropout(a, training=training)
+
+ outputs = [a, present] + attn_outputs[1:]
+ return outputs # a, present, (attentions)
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if self.is_cross_attention:
+ c_attn_shape = 2 * self.embed_dim
+ else:
+ c_attn_shape = 3 * self.embed_dim
+ if getattr(self, "c_proj", None) is not None:
+ with tf.name_scope(self.c_proj.name):
+ self.c_proj.build([None, None, self.embed_dim])
+ if getattr(self, "c_attn", None) is not None:
+ with tf.name_scope(self.c_attn.name):
+ self.c_attn.build([None, None, c_attn_shape])
+ if getattr(self, "q_attn", None) is not None:
+ with tf.name_scope(self.q_attn.name):
+ self.q_attn.build([None, None, self.embed_dim])
+
+
+class TFMLP(keras.layers.Layer):
+ def __init__(self, n_state, config, **kwargs):
+ super().__init__(**kwargs)
+ nx = config.n_embd
+ self.c_fc = TFConv1D(n_state, nx, initializer_range=config.initializer_range, name="c_fc")
+ self.c_proj = TFConv1D(nx, n_state, initializer_range=config.initializer_range, name="c_proj")
+ self.act = get_tf_activation(config.activation_function)
+ self.dropout = keras.layers.Dropout(config.resid_pdrop)
+ self.intermediate_size = n_state
+ self.embed_dim = nx
+
+ def call(self, x, training=False):
+ h = self.act(self.c_fc(x))
+ h2 = self.c_proj(h)
+ h2 = self.dropout(h2, training=training)
+ return h2
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "c_fc", None) is not None:
+ with tf.name_scope(self.c_fc.name):
+ self.c_fc.build([None, None, self.intermediate_size])
+ if getattr(self, "c_proj", None) is not None:
+ with tf.name_scope(self.c_proj.name):
+ self.c_proj.build([None, None, self.embed_dim])
+
+
+class TFBlock(keras.layers.Layer):
+ def __init__(self, config, scale=False, **kwargs):
+ super().__init__(**kwargs)
+ nx = config.n_embd
+ inner_dim = config.n_inner if config.n_inner is not None else 4 * nx
+ self.ln_1 = keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name="ln_1")
+ self.attn = TFAttention(nx, config, scale, name="attn")
+ self.ln_2 = keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name="ln_2")
+
+ if config.add_cross_attention:
+ self.crossattention = TFAttention(nx, config, scale, name="crossattention", is_cross_attention=True)
+ self.ln_cross_attn = keras.layers.LayerNormalization(
+ epsilon=config.layer_norm_epsilon, name="ln_cross_attn"
+ )
+
+ self.mlp = TFMLP(inner_dim, config, name="mlp")
+ self.hidden_size = config.hidden_size
+
+ def call(
+ self,
+ x,
+ layer_past,
+ attention_mask,
+ head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ use_cache,
+ output_attentions,
+ training=False,
+ ):
+ a = self.ln_1(x)
+ output_attn = self.attn(
+ a,
+ layer_past=layer_past,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ training=training,
+ )
+ a = output_attn[0] # output_attn: a, present, (attentions)
+ outputs = output_attn[1:]
+ x = x + a
+
+ # Cross-Attention Block
+ if encoder_hidden_states is not None:
+ # add one self-attention block for cross-attention
+ if not hasattr(self, "crossattention"):
+ raise ValueError(
+ f"If `encoder_hidden_states` are passed, {self} has to be instantiated with "
+ "cross-attention layers by setting `config.add_cross_attention=True`"
+ )
+
+ ca = self.ln_cross_attn(x)
+ output_cross_attn = self.crossattention(
+ ca,
+ layer_past=None,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ use_cache=False,
+ output_attentions=output_attentions,
+ training=training,
+ )
+ ca = output_cross_attn[0] # output_attn: a, present, (cross_attentions)
+ x = x + ca
+ outputs = outputs + output_cross_attn[2:] # add cross attentions if we output attention weights
+
+ m = self.ln_2(x)
+ m = self.mlp(m, training=training)
+ x = x + m
+
+ outputs = [x] + outputs
+ return outputs # x, present, (attentions, cross_attentions)
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "ln_1", None) is not None:
+ with tf.name_scope(self.ln_1.name):
+ self.ln_1.build([None, None, self.hidden_size])
+ if getattr(self, "attn", None) is not None:
+ with tf.name_scope(self.attn.name):
+ self.attn.build(None)
+ if getattr(self, "ln_2", None) is not None:
+ with tf.name_scope(self.ln_2.name):
+ self.ln_2.build([None, None, self.hidden_size])
+ if getattr(self, "mlp", None) is not None:
+ with tf.name_scope(self.mlp.name):
+ self.mlp.build(None)
+ if getattr(self, "crossattention", None) is not None:
+ with tf.name_scope(self.crossattention.name):
+ self.crossattention.build(None)
+ if getattr(self, "ln_cross_attn", None) is not None:
+ with tf.name_scope(self.ln_cross_attn.name):
+ self.ln_cross_attn.build([None, None, self.hidden_size])
+
+
+@keras_serializable
+class TFGPT2MainLayer(keras.layers.Layer):
+ config_class = GPT2Config
+
+ def __init__(self, config, *inputs, **kwargs):
+ super().__init__(*inputs, **kwargs)
+
+ self.config = config
+ self.output_attentions = config.output_attentions
+ self.output_hidden_states = config.output_hidden_states
+ self.use_cache = config.use_cache
+ self.return_dict = config.use_return_dict
+
+ self.num_hidden_layers = config.n_layer
+ self.n_embd = config.n_embd
+ self.n_positions = config.n_positions
+ self.initializer_range = config.initializer_range
+
+ self.wte = keras.layers.Embedding(
+ input_dim=config.vocab_size,
+ output_dim=config.hidden_size,
+ embeddings_initializer=get_initializer(config.initializer_range),
+ name="wte",
+ )
+ self.wpe = keras.layers.Embedding(
+ input_dim=config.n_positions,
+ output_dim=config.n_embd,
+ embeddings_initializer=get_initializer(config.initializer_range),
+ name="wpe",
+ )
+ self.drop = keras.layers.Dropout(config.embd_pdrop)
+ self.h = [TFBlock(config, scale=True, name=f"h_._{i}") for i in range(config.n_layer)]
+ self.ln_f = keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name="ln_f")
+ self.embed_dim = config.hidden_size
+
+ def get_input_embeddings(self):
+ return self.wte
+
+ def set_input_embeddings(self, new_embeddings):
+ self.wte = new_embeddings
+
+ def _prune_heads(self, heads_to_prune):
+ """
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
+ """
+ raise NotImplementedError
+
+ @unpack_inputs
+ def call(
+ self,
+ input_ids: TFModelInputType | None = None,
+ past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None,
+ attention_mask: np.ndarray | tf.Tensor | None = None,
+ token_type_ids: np.ndarray | tf.Tensor | None = None,
+ position_ids: np.ndarray | tf.Tensor | None = None,
+ head_mask: np.ndarray | tf.Tensor | None = None,
+ inputs_embeds: np.ndarray | tf.Tensor | None = None,
+ encoder_hidden_states: np.ndarray | tf.Tensor | None = None,
+ encoder_attention_mask: np.ndarray | tf.Tensor | None = None,
+ use_cache: bool | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ training: bool | None = False,
+ ) -> TFBaseModelOutputWithPastAndCrossAttentions | tuple[tf.Tensor]:
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ input_shape = shape_list(input_ids)
+ input_ids = tf.reshape(input_ids, [-1, input_shape[-1]])
+ elif inputs_embeds is not None:
+ input_shape = shape_list(inputs_embeds)[:-1]
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ if past_key_values is None:
+ past_length = 0
+ past_key_values = [None] * len(self.h)
+ else:
+ past_length = shape_list(past_key_values[0][0])[-2]
+
+ if position_ids is None:
+ position_ids = tf.expand_dims(tf.range(past_length, input_shape[-1] + past_length), axis=0)
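+ # e.g. (illustrative) with past_length=3 and a new input of length 2, this defaults to
+ # position_ids=[[3, 4]]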
+
+ if attention_mask is not None:
+ # We create a 3D attention mask from a 2D tensor mask.
+ # Sizes are [batch_size, 1, 1, to_seq_length]
+ # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
+ # this attention mask is more simple than the triangular masking of causal attention
+ # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
+ attention_mask_shape = shape_list(attention_mask)
+ attention_mask = tf.reshape(attention_mask, (attention_mask_shape[0], 1, 1, attention_mask_shape[1]))
+
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
+ # masked positions, this operation will create a tensor which is 0.0 for
+ # positions we want to attend and -10000.0 for masked positions.
+ # Since we are adding it to the raw scores before the softmax, this is
+ # effectively the same as removing these entirely.
+ one_cst = tf.constant(1.0)
+ attention_mask = tf.cast(attention_mask, dtype=one_cst.dtype)
+ attention_mask = tf.multiply(tf.subtract(one_cst, attention_mask), tf.constant(-10000.0))
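+ # e.g. (illustrative) a padding mask [1, 1, 0] becomes the additive bias [0.0, 0.0, -10000.0]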
+
+ # Copied from `modeling_tf_t5.py` with -1e9 -> -10000
+ if self.config.add_cross_attention and encoder_attention_mask is not None:
+ # If a 2D or 3D attention mask is provided for the cross-attention,
+ # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
+ encoder_attention_mask = tf.cast(encoder_attention_mask, dtype=encoder_hidden_states.dtype)
+ num_dims_encoder_attention_mask = len(shape_list(encoder_attention_mask))
+ if num_dims_encoder_attention_mask == 3:
+ encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
+ if num_dims_encoder_attention_mask == 2:
+ encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
+
+ # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
+ # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270
+ # encoder_extended_attention_mask = tf.math.equal(encoder_extended_attention_mask,
+ # tf.transpose(encoder_extended_attention_mask, perm=(-1, -2)))
+
+ encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0
+ else:
+ encoder_extended_attention_mask = None
+
+ encoder_attention_mask = encoder_extended_attention_mask
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+ if head_mask is not None:
+ raise NotImplementedError
+ else:
+ head_mask = [None] * self.num_hidden_layers
+ # head_mask = tf.constant([0] * self.num_hidden_layers)
+
+ position_ids = tf.reshape(position_ids, [-1, shape_list(position_ids)[-1]])
+
+ if inputs_embeds is None:
+ check_embeddings_within_bounds(input_ids, self.config.vocab_size)
+ inputs_embeds = self.wte(input_ids)
+
+ position_embeds = self.wpe(position_ids)
+
+ if token_type_ids is not None:
+ token_type_ids = tf.reshape(token_type_ids, [-1, shape_list(token_type_ids)[-1]])
+ token_type_embeds = self.wte(token_type_ids)
+ else:
+ token_type_embeds = tf.constant(0.0)
+
+ position_embeds = tf.cast(position_embeds, dtype=inputs_embeds.dtype)
+ token_type_embeds = tf.cast(token_type_embeds, dtype=inputs_embeds.dtype)
+ hidden_states = inputs_embeds + position_embeds + token_type_embeds
+ hidden_states = self.drop(hidden_states, training=training)
+
+ output_shape = input_shape + [shape_list(hidden_states)[-1]]
+
+ presents = () if use_cache else None
+ all_attentions = () if output_attentions else None
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
+ all_hidden_states = () if output_hidden_states else None
+ for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),)
+
+ outputs = block(
+ hidden_states,
+ layer_past,
+ attention_mask,
+ head_mask[i],
+ encoder_hidden_states,
+ encoder_attention_mask,
+ use_cache,
+ output_attentions,
+ training=training,
+ )
+
+ hidden_states, present = outputs[:2]
+ if use_cache:
+ presents = presents + (present,)
+
+ if output_attentions:
+ all_attentions = all_attentions + (outputs[2],)
+ if self.config.add_cross_attention and encoder_hidden_states is not None:
+ all_cross_attentions = all_cross_attentions + (outputs[3],)
+
+ hidden_states = self.ln_f(hidden_states)
+
+ hidden_states = tf.reshape(hidden_states, output_shape)
+ # Add last hidden state
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if output_attentions:
+ # let the number of heads free (-1) so we can extract attention even after head pruning
+ attention_output_shape = input_shape[:-1] + [-1] + shape_list(all_attentions[0])[-2:]
+ all_attentions = tuple(tf.reshape(t, attention_output_shape) for t in all_attentions)
+
+ if not return_dict:
+ return tuple(
+ v
+ for v in [hidden_states, presents, all_hidden_states, all_attentions, all_cross_attentions]
+ if v is not None
+ )
+
+ return TFBaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ past_key_values=presents,
+ hidden_states=all_hidden_states,
+ attentions=all_attentions,
+ cross_attentions=all_cross_attentions,
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "wte", None) is not None:
+ with tf.name_scope(self.wte.name):
+ self.wte.build(None)
+ if getattr(self, "wpe", None) is not None:
+ with tf.name_scope(self.wpe.name):
+ self.wpe.build(None)
+ if getattr(self, "ln_f", None) is not None:
+ with tf.name_scope(self.ln_f.name):
+ self.ln_f.build([None, None, self.embed_dim])
+ if getattr(self, "h", None) is not None:
+ for layer in self.h:
+ with tf.name_scope(layer.name):
+ layer.build(None)
+
+
+class TFGPT2PreTrainedModel(TFPreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = GPT2Config
+ base_model_prefix = "transformer"
+ # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model
+ _keys_to_ignore_on_load_unexpected = [r"h.\d+.attn.bias", r"h.\d+.crossattention.bias"]
+
+ @property
+ def input_signature(self):
+ # Although GPT-2 supports token_type_ids in theory, in practice they are rarely used, and the implementation
+ # means that passing token_type_ids=0 yields different outputs from token_type_ids=None.
+ # Therefore, we remove the token_type_ids argument by default, even though it would usually be included.
+ return {
+ "input_ids": tf.TensorSpec((None, None), tf.int32, name="input_ids"),
+ "attention_mask": tf.TensorSpec((None, None), tf.int32, name="attention_mask"),
+ }
+
+
+@dataclass
+class TFGPT2DoubleHeadsModelOutput(ModelOutput):
+ """
+ Base class for outputs of the GPT-2 double-heads model (language modeling and multiple-choice classification).
+
+ Args:
+ logits (`tf.Tensor` of shape `(batch_size, num_choices, sequence_length, config.vocab_size)`):
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ mc_logits (`tf.Tensor` of shape `(batch_size, num_choices)`):
+ Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
+ past_key_values (`list[tf.Tensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ List of `tf.Tensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size, num_heads,
+ sequence_length, embed_size_per_head)`.
+
+ Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
+ `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attention weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ logits: tf.Tensor | None = None
+ mc_logits: tf.Tensor | None = None
+ past_key_values: list[tf.Tensor] | None = None
+ hidden_states: tuple[tf.Tensor] | None = None
+ attentions: tuple[tf.Tensor] | None = None
+
+
+GPT2_START_DOCSTRING = r"""
+
+ This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads,
+ etc.).
+
+ This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
+ as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matters related to general usage and
+ behavior.
+
+ TensorFlow models and layers in `transformers` accept two formats as input:
+
+ - having all inputs as keyword arguments (like PyTorch models), or
+ - having all inputs as a list, tuple or dict in the first positional argument.
+
+ The reason the second format is supported is that Keras methods prefer this format when passing inputs to models
+ and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just
+ pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second
+ format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with
+ the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first
+ positional argument:
+
+ - a single Tensor with `input_ids` only and nothing else: `model(input_ids)`
+ - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
+ `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])`
+ - a dictionary with one or several input Tensors associated to the input names given in the docstring:
+ `model({"input_ids": input_ids, "token_type_ids": token_type_ids})`
+
+ Note that when creating models and layers with
+ [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry
+ about any of this, as you can just pass inputs like you would to any other Python function!
+
+ Parameters:
+ config ([`GPT2Config`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+GPT2_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`Numpy array` or `tf.Tensor` of shape `(batch_size, input_ids_length)`):
+ `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values[0].shape[-2]`
+ (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary.
+
+ If `past_key_values` is used, only input IDs that do not have their past calculated should be passed as
+ `input_ids`.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and
+ [`PreTrainedTokenizer.encode`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ past_key_values (`list[tf.Tensor]` of length `config.n_layers`):
+ Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model (see
+ `past_key_values` output below). Can be used to speed up sequential decoding. The token ids which have
+ their past given to this model should not be passed as input ids as they have already been computed.
+ attention_mask (`tf.Tensor` or `Numpy array` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ If `past_key_values` is used, `attention_mask` needs to contain the masking strategy that was used for
+ `past_key_values`. In other words, the `attention_mask` always has to cover the full sequence, i.e. have
+ length `past_sequence_length + len(input_ids)`.
+
+ [What are attention masks?](../glossary#attention-mask)
+ token_type_ids (`tf.Tensor` or `Numpy array` of shape `(batch_size, input_ids_length)`, *optional*):
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
+ 1]`:
+
+ - 0 corresponds to a *sentence A* token,
+ - 1 corresponds to a *sentence B* token.
+
+ [What are token type IDs?](../glossary#token-type-ids)
+ position_ids (`tf.Tensor` or `Numpy array` of shape `(batch_size, input_ids_length)`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+ config.max_position_embeddings - 1]`.
+
+ [What are position IDs?](../glossary#position-ids)
+ head_mask (`Numpy array` or `tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ inputs_embeds (`tf.Tensor` of shape `(batch_size, input_ids_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+ model's internal embedding lookup matrix.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the
+ config will be used instead.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail. This argument can be used only in eager mode, in graph mode the value in the config will be
+ used instead.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in
+ eager mode, in graph mode the value will always be set to True.
+ training (`bool`, *optional*, defaults to `False`):
+ Whether or not to use the model in training mode (some modules like dropout modules have different
+ behaviors between training and evaluation).
+"""
+
+
+@add_start_docstrings(
+ "The bare GPT2 Model transformer outputting raw hidden-states without any specific head on top.",
+ GPT2_START_DOCSTRING,
+)
+class TFGPT2Model(TFGPT2PreTrainedModel):
+ def __init__(self, config, *inputs, **kwargs):
+ super().__init__(config, *inputs, **kwargs)
+ self.transformer = TFGPT2MainLayer(config, name="transformer")
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=TFBaseModelOutputWithPastAndCrossAttentions,
+ config_class=_CONFIG_FOR_DOC,
+ )
+ def call(
+ self,
+ input_ids: TFModelInputType | None = None,
+ past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None,
+ attention_mask: np.ndarray | tf.Tensor | None = None,
+ token_type_ids: np.ndarray | tf.Tensor | None = None,
+ position_ids: np.ndarray | tf.Tensor | None = None,
+ head_mask: np.ndarray | tf.Tensor | None = None,
+ inputs_embeds: np.ndarray | tf.Tensor | None = None,
+ encoder_hidden_states: np.ndarray | tf.Tensor | None = None,
+ encoder_attention_mask: np.ndarray | tf.Tensor | None = None,
+ use_cache: bool | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ training: bool | None = False,
+ ) -> TFBaseModelOutputWithPastAndCrossAttentions | tuple[tf.Tensor]:
+ r"""
+ encoder_hidden_states (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
+ the model is configured as a decoder.
+ encoder_attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
+ the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ past_key_values (`tuple[tuple[tf.Tensor]]` of length `config.n_layers`):
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+ If `past_key_values` is used, the user can optionally input only the last `decoder_input_ids` (those that
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+ `past_key_values`). Set to `False` during training and `True` during generation.
+ """
+
+ outputs = self.transformer(
+ input_ids=input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+
+ return outputs
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "transformer", None) is not None:
+ with tf.name_scope(self.transformer.name):
+ self.transformer.build(None)
+
+
+@add_start_docstrings(
+ """
+ The GPT2 Model transformer with a language modeling head on top (linear layer with weights tied to the input
+ embeddings).
+ """,
+ GPT2_START_DOCSTRING,
+)
+class TFGPT2LMHeadModel(TFGPT2PreTrainedModel, TFCausalLanguageModelingLoss):
+ def __init__(self, config, *inputs, **kwargs):
+ super().__init__(config, *inputs, **kwargs)
+ self.transformer = TFGPT2MainLayer(config, name="transformer")
+
+ def get_output_embeddings(self):
+ return self.get_input_embeddings()
+
+ def set_output_embeddings(self, value):
+ self.set_input_embeddings(value)
+
+ def prepare_inputs_for_generation(self, inputs, past_key_values=None, use_cache=None, **kwargs):
+ token_type_ids = kwargs.get("token_type_ids")
+ # only last token for inputs_ids if past is defined in kwargs
+ if past_key_values:
+ inputs = tf.expand_dims(inputs[:, -1], -1)
+ if token_type_ids is not None:
+ token_type_ids = tf.expand_dims(token_type_ids[:, -1], -1)
+
+ position_ids = kwargs.get("position_ids")
+ attention_mask = kwargs.get("attention_mask")
+
+ if attention_mask is not None and position_ids is None:
+ position_ids = tf.math.cumsum(attention_mask, axis=-1, exclusive=True)
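+ # e.g. (illustrative) a left-padded attention_mask [0, 0, 1, 1] gives exclusive-cumsum
+ # position_ids [0, 0, 0, 1], so the real tokens get positions 0 and 1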
+ if past_key_values:
+ position_ids = tf.expand_dims(position_ids[:, -1], -1)
+
+ return {
+ "input_ids": inputs,
+ "attention_mask": attention_mask,
+ "position_ids": position_ids,
+ "past_key_values": past_key_values,
+ "use_cache": use_cache,
+ "token_type_ids": token_type_ids,
+ }
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=TFCausalLMOutputWithCrossAttentions,
+ config_class=_CONFIG_FOR_DOC,
+ )
+ def call(
+ self,
+ input_ids: TFModelInputType | None = None,
+ past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None,
+ attention_mask: np.ndarray | tf.Tensor | None = None,
+ token_type_ids: np.ndarray | tf.Tensor | None = None,
+ position_ids: np.ndarray | tf.Tensor | None = None,
+ head_mask: np.ndarray | tf.Tensor | None = None,
+ inputs_embeds: np.ndarray | tf.Tensor | None = None,
+ encoder_hidden_states: np.ndarray | tf.Tensor | None = None,
+ encoder_attention_mask: np.ndarray | tf.Tensor | None = None,
+ use_cache: bool | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ labels: np.ndarray | tf.Tensor | None = None,
+ training: bool | None = False,
+ ) -> TFCausalLMOutputWithCrossAttentions | tuple[tf.Tensor]:
+ r"""
+ encoder_hidden_states (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
+ the model is configured as a decoder.
+ encoder_attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
+ the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+        past_key_values (`tuple[tuple[tf.Tensor]]` of length `config.n_layers`):
+            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+            If `past_key_values` is used, the user can optionally input only the last `decoder_input_ids` (those that
+            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+            `past_key_values`). Set to `False` during training and `True` during generation.
+ labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the language modeling (cross-entropy) loss. Indices should be in `[0, ...,
+            config.vocab_size - 1]`.
+ """
+
+ transformer_outputs = self.transformer(
+ input_ids=input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+ hidden_states = transformer_outputs[0]
+ logits = tf.matmul(hidden_states, self.transformer.wte.weights, transpose_b=True)
+
+ loss = None
+ if labels is not None:
+ # shift labels to the left and cut last logit token
+ shifted_logits = logits[:, :-1]
+ labels = labels[:, 1:]
+ loss = self.hf_compute_loss(labels, shifted_logits)
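+            # e.g. for a 4-token sequence [t0, t1, t2, t3], logits[:, :-1] (predictions at positions
+            # 0..2) are scored against labels[:, 1:] ([t1, t2, t3]): each position predicts the next token.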
+
+ if not return_dict:
+ output = (logits,) + transformer_outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return TFCausalLMOutputWithCrossAttentions(
+ loss=loss,
+ logits=logits,
+ past_key_values=transformer_outputs.past_key_values,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
+ cross_attentions=transformer_outputs.cross_attentions,
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "transformer", None) is not None:
+ with tf.name_scope(self.transformer.name):
+ self.transformer.build(None)
+
+
+@add_start_docstrings(
+ """
+    The GPT2 Model transformer with a language modeling and a multiple-choice classification head on top, e.g. for
+    RocStories/SWAG tasks. The two heads are two linear layers. The language modeling head has its weights tied to the
+    input embeddings; the classification head takes as input the hidden state at a specified classification token index
+    in the input sequence.
+ """,
+ GPT2_START_DOCSTRING,
+)
+class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
+ def __init__(self, config, *inputs, **kwargs):
+ super().__init__(config, *inputs, **kwargs)
+ config.num_labels = 1
+ self.transformer = TFGPT2MainLayer(config, name="transformer")
+ self.multiple_choice_head = TFSequenceSummary(
+ config, initializer_range=config.initializer_range, name="multiple_choice_head"
+ )
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=TFGPT2DoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC)
+ def call(
+ self,
+ input_ids: TFModelInputType | None = None,
+ past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None,
+ attention_mask: np.ndarray | tf.Tensor | None = None,
+ token_type_ids: np.ndarray | tf.Tensor | None = None,
+ position_ids: np.ndarray | tf.Tensor | None = None,
+ head_mask: np.ndarray | tf.Tensor | None = None,
+ inputs_embeds: np.ndarray | tf.Tensor | None = None,
+ mc_token_ids: np.ndarray | tf.Tensor | None = None,
+ use_cache: bool | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ training: bool | None = False,
+ ) -> TFGPT2DoubleHeadsModelOutput | tuple[tf.Tensor]:
+ r"""
+        mc_token_ids (`tf.Tensor` or `Numpy array` of shape `(batch_size, num_choices)`, *optional*, defaults to the index of the last token of the input):
+ Index of the classification token in each input sequence. Selected in the range `[0, input_ids.size(-1) -
+ 1]`.
+
+ Return:
+
+ Examples:
+
+ ```python
+ >>> import tensorflow as tf
+ >>> from transformers import AutoTokenizer, TFGPT2DoubleHeadsModel
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
+ >>> model = TFGPT2DoubleHeadsModel.from_pretrained("openai-community/gpt2")
+
+ >>> # Add a [CLS] to the vocabulary (we should train it also!)
+ >>> num_added_tokens = tokenizer.add_special_tokens({"cls_token": "[CLS]"})
+
+ >>> embedding_layer = model.resize_token_embeddings(
+ ... len(tokenizer)
+ ... ) # Update the model embeddings with the new vocabulary size
+
+ >>> choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"]
+ >>> encoded_choices = [tokenizer.encode(s) for s in choices]
+ >>> cls_token_location = [tokens.index(tokenizer.cls_token_id) for tokens in encoded_choices]
+
+ >>> input_ids = tf.constant(encoded_choices)[None, :] # Batch size: 1, number of choices: 2
+ >>> mc_token_ids = tf.constant([cls_token_location]) # Batch size: 1
+
+ >>> outputs = model(input_ids, mc_token_ids=mc_token_ids)
+ >>> lm_prediction_scores, mc_prediction_scores = outputs[:2]
+ ```"""
+
+ if input_ids is not None:
+ input_shapes = shape_list(input_ids)
+ else:
+ input_shapes = shape_list(inputs_embeds)[:-1]
+
+ seq_length = input_shapes[-1]
+ flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
+ flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
+ flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
+ flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
+ transformer_outputs = self.transformer(
+ input_ids=flat_input_ids,
+ past_key_values=past_key_values,
+ attention_mask=flat_attention_mask,
+ token_type_ids=flat_token_type_ids,
+ position_ids=flat_position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+ hidden_states = transformer_outputs[0]
+ hidden_states = tf.reshape(hidden_states, input_shapes + shape_list(hidden_states)[-1:])
+ if return_dict and output_hidden_states:
+ # We do this to match the slightly odd PT behaviour - the final hidden state is reshaped to rank 4 when the
+ # input is rank 3, but all other hidden states remain at rank-3 (with the first 2 dims merged)
+ all_hidden_states = transformer_outputs.hidden_states[:-1] + (hidden_states,)
+ else:
+ all_hidden_states = None
+ lm_logits = tf.matmul(hidden_states, self.transformer.wte.weights, transpose_b=True)
+ mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids, training=training)
+ mc_logits = tf.squeeze(mc_logits, axis=-1)
+
+ if not return_dict:
+ return (lm_logits, mc_logits) + transformer_outputs[1:]
+
+ return TFGPT2DoubleHeadsModelOutput(
+ logits=lm_logits,
+ mc_logits=mc_logits,
+ past_key_values=transformer_outputs.past_key_values,
+ hidden_states=all_hidden_states,
+ attentions=transformer_outputs.attentions,
+ )
+
+ @property
+ def input_signature(self):
+ return {
+ "input_ids": tf.TensorSpec((None, None, None), tf.int32, name="input_ids"),
+ "attention_mask": tf.TensorSpec((None, None, None), tf.int32, name="attention_mask"),
+ "mc_token_ids": tf.TensorSpec((None, None), tf.int32, name="mc_token_ids"),
+ }
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "transformer", None) is not None:
+ with tf.name_scope(self.transformer.name):
+ self.transformer.build(None)
+ if getattr(self, "multiple_choice_head", None) is not None:
+ with tf.name_scope(self.multiple_choice_head.name):
+ self.multiple_choice_head.build(None)
+
+
+@add_start_docstrings(
+ """
+ The GPT2 Model transformer with a sequence classification head on top (linear layer).
+
+ [`TFGPT2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
+ (e.g. GPT-1) do.
+
+    Since it does classification on the last token, it needs to know the position of the last token. If a
+    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
+    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
+    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (takes the last value in
+    each row of the batch).
+ """,
+ GPT2_START_DOCSTRING,
+)
+class TFGPT2ForSequenceClassification(TFGPT2PreTrainedModel, TFSequenceClassificationLoss):
+ def __init__(self, config, *inputs, **kwargs):
+ super().__init__(config, *inputs, **kwargs)
+ self.num_labels = config.num_labels
+ self.score = keras.layers.Dense(
+ config.num_labels,
+ kernel_initializer=get_initializer(config.initializer_range),
+ name="score",
+ use_bias=False,
+ )
+ self.transformer = TFGPT2MainLayer(config, name="transformer")
+ self.config = config
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ checkpoint="microsoft/DialogRPT-updown",
+ output_type=TFSequenceClassifierOutputWithPast,
+ config_class=_CONFIG_FOR_DOC,
+ )
+ def call(
+ self,
+ input_ids: TFModelInputType | None = None,
+ past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None,
+ attention_mask: np.ndarray | tf.Tensor | None = None,
+ token_type_ids: np.ndarray | tf.Tensor | None = None,
+ position_ids: np.ndarray | tf.Tensor | None = None,
+ head_mask: np.ndarray | tf.Tensor | None = None,
+ inputs_embeds: np.ndarray | tf.Tensor | None = None,
+ use_cache: bool | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ labels: np.ndarray | tf.Tensor | None = None,
+ training: bool | None = False,
+ ) -> TFSequenceClassifierOutputWithPast | tuple[tf.Tensor]:
+ r"""
+        labels (`tf.Tensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the sequence classification loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`.
+ """
+ transformer_outputs = self.transformer(
+ input_ids=input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+ hidden_states = transformer_outputs[0]
+ logits = self.score(hidden_states)
+ logits_shape = shape_list(logits)
+ batch_size = logits_shape[0]
+
+ if self.config.pad_token_id is None:
+ last_non_pad_token = tf.fill((batch_size,), value=logits_shape[1] - 1)
+ else:
+ if input_ids is not None:
+ token_indices = tf.range(shape_list(input_ids)[-1])
+ non_pad_mask = tf.cast(input_ids != self.config.pad_token_id, token_indices.dtype)
+ last_non_pad_token = tf.reduce_max(token_indices * non_pad_mask, axis=-1)
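+                # e.g. for input_ids [[11, 22, PAD, PAD]] this is max([0, 1, 0, 0]) = 1, the index of
+                # the last real (non-padding) token in each row.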
+ else:
+ last_non_pad_token = tf.fill((batch_size,), value=logits_shape[1] - 1)
+ logger.warning_once(
+ f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+ "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+ )
+ loss = None
+
+ pooled_logits = tf.gather(logits, last_non_pad_token, batch_dims=1, axis=1)
+
+ if labels is not None:
+ if self.config.pad_token_id is None and logits_shape[0] != 1:
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
+
+ loss = self.hf_compute_loss(tf.reshape(labels, [-1]), tf.reshape(pooled_logits, [-1, self.num_labels]))
+
+ if not return_dict:
+ output = (pooled_logits,) + transformer_outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return TFSequenceClassifierOutputWithPast(
+ loss=loss,
+ logits=pooled_logits,
+ past_key_values=transformer_outputs.past_key_values,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "score", None) is not None:
+ with tf.name_scope(self.score.name):
+ self.score.build([None, None, self.config.n_embd])
+ if getattr(self, "transformer", None) is not None:
+ with tf.name_scope(self.transformer.name):
+ self.transformer.build(None)
+
+
+__all__ = [
+ "TFGPT2DoubleHeadsModel",
+ "TFGPT2ForSequenceClassification",
+ "TFGPT2LMHeadModel",
+ "TFGPT2MainLayer",
+ "TFGPT2Model",
+ "TFGPT2PreTrainedModel",
+]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/gpt2/tokenization_gpt2.py b/venv/lib/python3.13/site-packages/transformers/models/gpt2/tokenization_gpt2.py
new file mode 100644
index 0000000000000000000000000000000000000000..608164ef2d83ab15bf7f99d33f9c6eb56ed1fcff
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/gpt2/tokenization_gpt2.py
@@ -0,0 +1,334 @@
+# coding=utf-8
+# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenization classes for OpenAI GPT."""
+
+import json
+import os
+from functools import lru_cache
+from typing import Optional
+
+import regex as re
+
+from ...tokenization_utils import AddedToken, PreTrainedTokenizer
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+VOCAB_FILES_NAMES = {
+ "vocab_file": "vocab.json",
+ "merges_file": "merges.txt",
+}
+
+
+@lru_cache
+def bytes_to_unicode():
+ """
+    Returns a mapping from utf-8 bytes to unicode strings. We specifically avoid mapping to whitespace/control
+    characters that the bpe code barfs on.
+
+ The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab
+ if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for
+ decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup
+ tables between utf-8 bytes and unicode strings.
+ """
+ bs = (
+ list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
+ )
+ cs = bs[:]
+ n = 0
+ for b in range(2**8):
+ if b not in bs:
+ bs.append(b)
+ cs.append(2**8 + n)
+ n += 1
+ cs = [chr(n) for n in cs]
+ return dict(zip(bs, cs))
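+
+
+# Note: byte 0x20 (a plain space) is not in the printable ranges kept above, so it gets remapped;
+# bytes_to_unicode()[32] == "Ġ", which is why GPT-2 tokens that start a new word begin with "Ġ".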
+
+
+def get_pairs(word):
+ """
+ Return set of symbol pairs in a word.
+
+ Word is represented as tuple of symbols (symbols being variable-length strings).
+ """
+ pairs = set()
+ prev_char = word[0]
+ for char in word[1:]:
+ pairs.add((prev_char, char))
+ prev_char = char
+ return pairs
+
+
+class GPT2Tokenizer(PreTrainedTokenizer):
+ """
+ Construct a GPT-2 tokenizer. Based on byte-level Byte-Pair-Encoding.
+
+    This tokenizer has been trained to treat spaces like parts of the tokens (a bit like sentencepiece) so a word will
+    be encoded differently depending on whether it is at the beginning of the sentence (without a space) or not:
+
+ ```python
+ >>> from transformers import GPT2Tokenizer
+
+ >>> tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2")
+ >>> tokenizer("Hello world")["input_ids"]
+ [15496, 995]
+
+ >>> tokenizer(" Hello world")["input_ids"]
+ [18435, 995]
+ ```
+
+ You can get around that behavior by passing `add_prefix_space=True` when instantiating this tokenizer or when you
+ call it on some text, but since the model was not pretrained this way, it might yield a decrease in performance.
+
+
+
+ When used with `is_split_into_words=True`, this tokenizer will add a space before each word (even the first one).
+
+
+
+ This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
+ this superclass for more information regarding those methods.
+
+ Args:
+ vocab_file (`str`):
+ Path to the vocabulary file.
+ merges_file (`str`):
+ Path to the merges file.
+ errors (`str`, *optional*, defaults to `"replace"`):
+ Paradigm to follow when decoding bytes to UTF-8. See
+ [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.
+ unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+ token instead.
+ bos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
+ The beginning of sequence token.
+ eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
+ The end of sequence token.
+ pad_token (`str`, *optional*):
+ The token used for padding, for example when batching sequences of different lengths.
+ add_prefix_space (`bool`, *optional*, defaults to `False`):
+            Whether or not to add an initial space to the input. This allows treating the leading word just like any
+            other word (the GPT2 tokenizer detects the beginning of words by the preceding space).
+ add_bos_token (`bool`, *optional*, defaults to `False`):
+            Whether or not to add an initial beginning-of-sentence token to the input.
+ """
+
+ vocab_files_names = VOCAB_FILES_NAMES
+ model_input_names = ["input_ids", "attention_mask"]
+
+ def __init__(
+ self,
+ vocab_file,
+ merges_file,
+ errors="replace",
+ unk_token="<|endoftext|>",
+ bos_token="<|endoftext|>",
+ eos_token="<|endoftext|>",
+ pad_token=None,
+ add_prefix_space=False,
+ add_bos_token=False,
+ **kwargs,
+ ):
+ bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token
+ eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token
+ unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token
+ pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token
+
+ self.add_bos_token = add_bos_token
+
+ with open(vocab_file, encoding="utf-8") as vocab_handle:
+ self.encoder = json.load(vocab_handle)
+ self.decoder = {v: k for k, v in self.encoder.items()}
+ self.errors = errors # how to handle errors in decoding
+ self.byte_encoder = bytes_to_unicode()
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
+ with open(merges_file, encoding="utf-8") as merges_handle:
+ bpe_merges = merges_handle.read().split("\n")[1:-1]
+ bpe_merges = [tuple(merge.split()) for merge in bpe_merges]
+ self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
+ self.cache = {}
+ self.add_prefix_space = add_prefix_space
+
+ # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
+ self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
+
+ super().__init__(
+ errors=errors,
+ unk_token=unk_token,
+ bos_token=bos_token,
+ eos_token=eos_token,
+ pad_token=pad_token,
+ add_prefix_space=add_prefix_space,
+ add_bos_token=add_bos_token,
+ **kwargs,
+ )
+
+ @property
+ def vocab_size(self):
+ return len(self.encoder)
+
+ def get_vocab(self):
+ return dict(self.encoder, **self.added_tokens_encoder)
+
+ def bpe(self, token):
+ if token in self.cache:
+ return self.cache[token]
+ word = tuple(token)
+ pairs = get_pairs(word)
+
+ if not pairs:
+ return token
+
+ while True:
+ bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
+ if bigram not in self.bpe_ranks:
+ break
+ first, second = bigram
+ new_word = []
+ i = 0
+ while i < len(word):
+ try:
+ j = word.index(first, i)
+ except ValueError:
+ new_word.extend(word[i:])
+ break
+ else:
+ new_word.extend(word[i:j])
+ i = j
+
+ if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
+ new_word.append(first + second)
+ i += 2
+ else:
+ new_word.append(word[i])
+ i += 1
+ new_word = tuple(new_word)
+ word = new_word
+ if len(word) == 1:
+ break
+ else:
+ pairs = get_pairs(word)
+ word = " ".join(word)
+ self.cache[token] = word
+ return word
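+
+    # Illustrative walk-through with made-up ranks (not the real GPT-2 merges): if bpe_ranks were
+    # {("h", "u"): 0, ("hu", "g"): 1}, then bpe("hug") would first merge ("h", "u") -> "hu", then
+    # ("hu", "g") -> "hug", and return "hug"; the actual ranks are loaded from the merges file.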
+
+ def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
+ if self.add_bos_token:
+ bos_token_ids = [self.bos_token_id]
+ else:
+ bos_token_ids = []
+
+ output = bos_token_ids + token_ids_0
+
+ if token_ids_1 is None:
+ return output
+
+ return output + bos_token_ids + token_ids_1
+
+ def get_special_tokens_mask(
+ self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None, already_has_special_tokens: bool = False
+ ) -> list[int]:
+ """
+ Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
+ special tokens using the tokenizer `prepare_for_model` or `encode_plus` methods.
+
+ Args:
+ token_ids_0 (`list[int]`):
+ List of IDs.
+ token_ids_1 (`list[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+ already_has_special_tokens (`bool`, *optional*, defaults to `False`):
+ Whether or not the token list is already formatted with special tokens for the model.
+
+ Returns:
+ `list[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+ """
+ if already_has_special_tokens:
+ return super().get_special_tokens_mask(
+ token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
+ )
+
+ if not self.add_bos_token:
+ return super().get_special_tokens_mask(
+ token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=False
+ )
+
+ if token_ids_1 is None:
+ return [1] + ([0] * len(token_ids_0))
+ return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1))
+
+ def _tokenize(self, text):
+ """Tokenize a string."""
+ bpe_tokens = []
+ for token in re.findall(self.pat, text):
+ token = "".join(
+ self.byte_encoder[b] for b in token.encode("utf-8")
+ ) # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case)
+ bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" "))
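+            # e.g. the chunk " Hello" becomes the byte-level string "ĠHello", which the pretrained
+            # GPT-2 merges keep as a single token (id 18435, matching the docstring example above).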
+ return bpe_tokens
+
+ def _convert_token_to_id(self, token):
+ """Converts a token (str) in an id using the vocab."""
+ return self.encoder.get(token, self.encoder.get(self.unk_token))
+
+ def _convert_id_to_token(self, index):
+ """Converts an index (integer) in a token (str) using the vocab."""
+ return self.decoder.get(index)
+
+ def convert_tokens_to_string(self, tokens):
+ """Converts a sequence of tokens (string) in a single string."""
+ text = "".join(tokens)
+ text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors)
+ return text
+
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
+ if not os.path.isdir(save_directory):
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
+ return
+ vocab_file = os.path.join(
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+ )
+ merge_file = os.path.join(
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"]
+ )
+
+ with open(vocab_file, "w", encoding="utf-8") as f:
+ f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
+
+ index = 0
+ with open(merge_file, "w", encoding="utf-8") as writer:
+ writer.write("#version: 0.2\n")
+ for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
+ if index != token_index:
+ logger.warning(
+ f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive."
+ " Please check that the tokenizer is not corrupted!"
+ )
+ index = token_index
+ writer.write(" ".join(bpe_tokens) + "\n")
+ index += 1
+
+ return vocab_file, merge_file
+
+ def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs):
+ add_prefix_space = kwargs.pop("add_prefix_space", self.add_prefix_space)
+ if is_split_into_words or add_prefix_space:
+ text = " " + text
+ return (text, kwargs)
+
+
+__all__ = ["GPT2Tokenizer"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/gpt2/tokenization_gpt2_fast.py b/venv/lib/python3.13/site-packages/transformers/models/gpt2/tokenization_gpt2_fast.py
new file mode 100644
index 0000000000000000000000000000000000000000..f81c155e864476cf49c24f91a0235c939f42d3e0
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/gpt2/tokenization_gpt2_fast.py
@@ -0,0 +1,133 @@
+# coding=utf-8
+# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenization classes for OpenAI GPT."""
+
+from typing import Optional
+
+from ...tokenization_utils_base import BatchEncoding
+from ...tokenization_utils_fast import PreTrainedTokenizerFast
+from ...utils import logging
+from .tokenization_gpt2 import GPT2Tokenizer
+
+
+logger = logging.get_logger(__name__)
+
+VOCAB_FILES_NAMES = {"vocab_file": "vocab.json", "merges_file": "merges.txt", "tokenizer_file": "tokenizer.json"}
+
+
+class GPT2TokenizerFast(PreTrainedTokenizerFast):
+ """
+ Construct a "fast" GPT-2 tokenizer (backed by HuggingFace's *tokenizers* library). Based on byte-level
+ Byte-Pair-Encoding.
+
+    This tokenizer has been trained to treat spaces like parts of the tokens (a bit like sentencepiece) so a word will
+    be encoded differently depending on whether it is at the beginning of the sentence (without a space) or not:
+
+ ```python
+ >>> from transformers import GPT2TokenizerFast
+
+ >>> tokenizer = GPT2TokenizerFast.from_pretrained("openai-community/gpt2")
+ >>> tokenizer("Hello world")["input_ids"]
+ [15496, 995]
+
+ >>> tokenizer(" Hello world")["input_ids"]
+ [18435, 995]
+ ```
+
+ You can get around that behavior by passing `add_prefix_space=True` when instantiating this tokenizer, but since
+ the model was not pretrained this way, it might yield a decrease in performance.
+
+
+
+ When used with `is_split_into_words=True`, this tokenizer needs to be instantiated with `add_prefix_space=True`.
+
+
+
+ This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
+ refer to this superclass for more information regarding those methods.
+
+ Args:
+ vocab_file (`str`, *optional*):
+ Path to the vocabulary file.
+ merges_file (`str`, *optional*):
+ Path to the merges file.
+ tokenizer_file (`str`, *optional*):
+ Path to [tokenizers](https://github.com/huggingface/tokenizers) file (generally has a .json extension) that
+ contains everything needed to load the tokenizer.
+ unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+ token instead.
+ bos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
+ The beginning of sequence token.
+ eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
+ The end of sequence token.
+ add_prefix_space (`bool`, *optional*, defaults to `False`):
+            Whether or not to add an initial space to the input. This allows treating the leading word just like any
+            other word (the GPT2 tokenizer detects the beginning of words by the preceding space).
+ """
+
+ vocab_files_names = VOCAB_FILES_NAMES
+ model_input_names = ["input_ids", "attention_mask"]
+ slow_tokenizer_class = GPT2Tokenizer
+
+ def __init__(
+ self,
+ vocab_file=None,
+ merges_file=None,
+ tokenizer_file=None,
+ unk_token="<|endoftext|>",
+ bos_token="<|endoftext|>",
+ eos_token="<|endoftext|>",
+ add_prefix_space=False,
+ **kwargs,
+ ):
+ super().__init__(
+ vocab_file=vocab_file,
+ merges_file=merges_file,
+ tokenizer_file=tokenizer_file,
+ unk_token=unk_token,
+ bos_token=bos_token,
+ eos_token=eos_token,
+ add_prefix_space=add_prefix_space,
+ **kwargs,
+ )
+
+ self.add_bos_token = kwargs.pop("add_bos_token", False)
+
+ def _batch_encode_plus(self, *args, **kwargs) -> BatchEncoding:
+ is_split_into_words = kwargs.get("is_split_into_words", False)
+ assert self.add_prefix_space or not is_split_into_words, (
+ f"You need to instantiate {self.__class__.__name__} with add_prefix_space=True "
+ "to use it with pretokenized inputs."
+ )
+
+ return super()._batch_encode_plus(*args, **kwargs)
+
+ def _encode_plus(self, *args, **kwargs) -> BatchEncoding:
+ is_split_into_words = kwargs.get("is_split_into_words", False)
+
+ assert self.add_prefix_space or not is_split_into_words, (
+ f"You need to instantiate {self.__class__.__name__} with add_prefix_space=True "
+ "to use it with pretokenized inputs."
+ )
+
+ return super()._encode_plus(*args, **kwargs)
+
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
+ files = self._tokenizer.model.save(save_directory, name=filename_prefix)
+ return tuple(files)
+
+
+__all__ = ["GPT2TokenizerFast"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/gpt2/tokenization_gpt2_tf.py b/venv/lib/python3.13/site-packages/transformers/models/gpt2/tokenization_gpt2_tf.py
new file mode 100644
index 0000000000000000000000000000000000000000..145a45da0db6d36f75f5cec6091027e36541184e
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/gpt2/tokenization_gpt2_tf.py
@@ -0,0 +1,119 @@
+import os
+from typing import Optional, Union
+
+import tensorflow as tf
+from tensorflow_text import pad_model_inputs
+
+from ...modeling_tf_utils import keras
+from ...utils.import_utils import is_keras_nlp_available, requires
+from .tokenization_gpt2 import GPT2Tokenizer
+
+
+if is_keras_nlp_available():
+ from keras_nlp.tokenizers import BytePairTokenizer
+
+
+@requires(backends=("keras_nlp",))
+class TFGPT2Tokenizer(keras.layers.Layer):
+ """
+ This is an in-graph tokenizer for GPT2. It should be initialized similarly to other tokenizers, using the
+ `from_pretrained()` method. It can also be initialized with the `from_tokenizer()` method, which imports settings
+ from an existing standard tokenizer object.
+
+ In-graph tokenizers, unlike other Hugging Face tokenizers, are actually Keras layers and are designed to be run
+ when the model is called, rather than during preprocessing. As a result, they have somewhat more limited options
+ than standard tokenizer classes. They are most useful when you want to create an end-to-end model that goes
+ straight from `tf.string` inputs to outputs.
+
+ Args:
+ vocab (dict[str, int]): Vocabulary dict for Byte Pair Tokenizer
+ merges (list[str]): Merges list for Byte Pair Tokenizer
+ """
+
+ def __init__(
+ self,
+ vocab: dict[str, int],
+ merges: list[str],
+ max_length: Optional[int] = None,
+ pad_token_id: Optional[int] = None,
+ ):
+ super().__init__()
+ self.pad_token_id = pad_token_id
+ self.max_length = max_length
+ self.vocab = vocab
+ self.merges = merges
+
+ self.tf_tokenizer = BytePairTokenizer(vocab, merges, sequence_length=max_length)
+
+ @classmethod
+ def from_tokenizer(cls, tokenizer: GPT2Tokenizer, *args, **kwargs):
+ """Creates TFGPT2Tokenizer from GPT2Tokenizer
+
+ Args:
+ tokenizer (GPT2Tokenizer)
+
+ Examples:
+
+ ```python
+ from transformers import AutoTokenizer, TFGPT2Tokenizer
+
+ tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
+ tf_tokenizer = TFGPT2Tokenizer.from_tokenizer(tokenizer)
+ ```
+ """
+ merges = [" ".join(m) for m in tokenizer.bpe_ranks]
+ vocab = tokenizer.get_vocab()
+ return cls(vocab, merges, *args, **kwargs)
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], *init_inputs, **kwargs):
+ """Creates TFGPT2Tokenizer from pretrained GPT2Tokenizer
+
+ Args:
+ pretrained_model_name_or_path (Union[str, os.PathLike]): Path to pretrained model
+
+ Examples:
+
+ ```python
+ from transformers import TFGPT2Tokenizer
+
+ tf_tokenizer = TFGPT2Tokenizer.from_pretrained("openai-community/gpt2")
+ ```
+ """
+ tokenizer = GPT2Tokenizer.from_pretrained(pretrained_model_name_or_path, *init_inputs, **kwargs)
+ return cls.from_tokenizer(tokenizer, *init_inputs, **kwargs)
+
+ @classmethod
+ def from_config(cls, config):
+ """Creates TFGPT2Tokenizer from configurations
+
+ Args:
+ config (Dict): Dictionary with keys such as stated in `get_config`.
+ """
+ return cls(**config)
+
+ def get_config(self):
+ return {
+ "vocab": self.vocab,
+ "merges": self.merges,
+ "max_length": self.max_length,
+ "pad_token_id": self.pad_token_id,
+ }
+
+ def call(self, x, max_length: Optional[int] = None):
+ input_ids = self.tf_tokenizer(x)
+ attention_mask = tf.ones_like(input_ids)
+
+ if self.pad_token_id is not None:
+ # pad the tokens up to max length
+ max_length = max_length if max_length is not None else self.max_length
+
+ if max_length is not None:
+ input_ids, attention_mask = pad_model_inputs(
+ input_ids, max_seq_length=max_length, pad_value=self.pad_token_id
+ )
+
+ return {"attention_mask": attention_mask, "input_ids": input_ids}
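+
+    # Rough usage sketch (assumes `slow_tokenizer` is an already loaded GPT2Tokenizer instance):
+    #   tf_tokenizer = TFGPT2Tokenizer.from_tokenizer(slow_tokenizer, max_length=8, pad_token_id=50256)
+    #   batch = tf_tokenizer(tf.constant(["Hello world"]))
+    #   # -> {"attention_mask": ..., "input_ids": ...} ready to feed to a TF GPT-2 model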
+
+
+__all__ = ["TFGPT2Tokenizer"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/gpt_bigcode/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/gpt_bigcode/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..92e985d92734550a5b0635941294669386d35749
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/gpt_bigcode/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_gpt_bigcode import *
+ from .modeling_gpt_bigcode import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/venv/lib/python3.13/site-packages/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py b/venv/lib/python3.13/site-packages/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py
new file mode 100644
index 0000000000000000000000000000000000000000..127a0eed4732c15ef565a306a1a25f86b4e51ce4
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py
@@ -0,0 +1,145 @@
+# coding=utf-8
+# Copyright 2023 The BigCode team and HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""GPTBigCode configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class GPTBigCodeConfig(PretrainedConfig):
+ """
+ This is the configuration class to store the configuration of a [`GPTBigCodeModel`]. It is used to instantiate a
+ GPTBigCode model according to the specified arguments, defining the model architecture. Instantiating a
+ configuration with the defaults will yield a similar configuration to that of the GPTBigCode
+ [gpt_bigcode](https://huggingface.co/gpt_bigcode) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 50257):
+            Vocabulary size of the GPTBigCode model. Defines the number of different tokens that can be represented by
+            the `input_ids` passed when calling [`GPTBigCodeModel`].
+ n_positions (`int`, *optional*, defaults to 1024):
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
+ just in case (e.g., 512 or 1024 or 2048).
+ n_embd (`int`, *optional*, defaults to 768):
+ Dimensionality of the embeddings and hidden states.
+ n_layer (`int`, *optional*, defaults to 12):
+ Number of hidden layers in the Transformer encoder.
+ n_head (`int`, *optional*, defaults to 12):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ n_inner (`int`, *optional*, defaults to None):
+            Dimensionality of the inner feed-forward layers. `None` will set it to 4 times `n_embd`.
+ activation_function (`str`, *optional*, defaults to `"gelu_pytorch_tanh"`):
+ Activation function, to be selected in the list `["relu", "silu", "gelu", "tanh", "gelu_new",
+ "gelu_pytorch_tanh"]`.
+ resid_pdrop (`float`, *optional*, defaults to 0.1):
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+ embd_pdrop (`float`, *optional*, defaults to 0.1):
+ The dropout ratio for the embeddings.
+ attn_pdrop (`float`, *optional*, defaults to 0.1):
+ The dropout ratio for the attention.
+ layer_norm_epsilon (`float`, *optional*, defaults to 1e-5):
+ The epsilon to use in the layer normalization layers.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ scale_attn_weights (`bool`, *optional*, defaults to `True`):
+            Scale attention weights by dividing by sqrt(hidden_size).
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models).
+ attention_softmax_in_fp32 (`bool`, *optional*, defaults to `True`):
+ Whether to call the fused softmax in float32.
+ scale_attention_softmax_in_fp32 (`bool`, *optional*, defaults to `True`):
+ Whether to scale the attention softmax in float32.
+        multi_query (`bool`, *optional*, defaults to `True`):
+            Whether to use Multi-Query Attention (`True`) or Multi-Head Attention (`False`).
+ Example:
+
+ ```python
+ >>> from transformers import GPTBigCodeConfig, GPTBigCodeModel
+
+ >>> # Initializing a GPTBigCode configuration
+ >>> configuration = GPTBigCodeConfig()
+
+ >>> # Initializing a model (with random weights) from the configuration
+ >>> model = GPTBigCodeModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "gpt_bigcode"
+ keys_to_ignore_at_inference = ["past_key_values"]
+ attribute_map = {
+ "hidden_size": "n_embd",
+ "max_position_embeddings": "n_positions",
+ "num_attention_heads": "n_head",
+ "num_hidden_layers": "n_layer",
+ }
+
+ def __init__(
+ self,
+ vocab_size=50257,
+ n_positions=1024,
+ n_embd=768,
+ n_layer=12,
+ n_head=12,
+ n_inner=None,
+ activation_function="gelu_pytorch_tanh",
+ resid_pdrop=0.1,
+ embd_pdrop=0.1,
+ attn_pdrop=0.1,
+ layer_norm_epsilon=1e-5,
+ initializer_range=0.02,
+ scale_attn_weights=True,
+ use_cache=True,
+ bos_token_id=50256,
+ eos_token_id=50256,
+ attention_softmax_in_fp32=True,
+ scale_attention_softmax_in_fp32=True,
+ multi_query=True,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.n_positions = n_positions
+ self.n_embd = n_embd
+ self.n_layer = n_layer
+ self.n_head = n_head
+ self.n_inner = n_inner
+ self.activation_function = activation_function
+ self.resid_pdrop = resid_pdrop
+ self.embd_pdrop = embd_pdrop
+ self.attn_pdrop = attn_pdrop
+ self.layer_norm_epsilon = layer_norm_epsilon
+ self.initializer_range = initializer_range
+ self.scale_attn_weights = scale_attn_weights
+ self.use_cache = use_cache
+ self.attention_softmax_in_fp32 = attention_softmax_in_fp32
+ self.scale_attention_softmax_in_fp32 = scale_attention_softmax_in_fp32
+ self.multi_query = multi_query
+ self.num_key_value_heads = 1 if multi_query else n_head
+
+ self.bos_token_id = bos_token_id
+ self.eos_token_id = eos_token_id
+
+ super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
+
+
+__all__ = ["GPTBigCodeConfig"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/venv/lib/python3.13/site-packages/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
new file mode 100644
index 0000000000000000000000000000000000000000..6992dc642a4f024b97a9c143eff434bf4eea205c
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
@@ -0,0 +1,931 @@
+# coding=utf-8
+# Copyright 2023 The Bigcode team and HuggingFace Inc. team.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch GPTBigCode model."""
+
+import math
+from typing import Callable, Optional, Union
+
+import torch
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
+from ...generation import GenerationMixin
+from ...masking_utils import create_causal_mask
+from ...modeling_flash_attention_utils import is_flash_attn_available
+from ...modeling_outputs import (
+ BaseModelOutputWithPastAndCrossAttentions,
+ CausalLMOutputWithCrossAttentions,
+ SequenceClassifierOutputWithPast,
+ TokenClassifierOutput,
+)
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...utils import (
+ auto_docstring,
+ can_return_tuple,
+ logging,
+)
+from .configuration_gpt_bigcode import GPTBigCodeConfig
+
+
+if is_flash_attn_available():
+ pass
+
+
+logger = logging.get_logger(__name__)
+
+
+# Fused kernels
+# Use separate functions for each case because conditionals prevent kernel fusion.
+# TODO: Could have better fused kernels depending on scaling, dropout and head mask.
+# Is it doable without writing 32 functions?
+@torch.jit.script
+def upcast_masked_softmax(
+ x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor, scale: float, softmax_dtype: torch.dtype
+):
+ input_dtype = x.dtype
+ x = x.to(softmax_dtype) * scale
+ x = torch.where(mask, x, mask_value)
+ x = torch.nn.functional.softmax(x, dim=-1).to(input_dtype)
+ return x
+
+
+@torch.jit.script
+def upcast_softmax(x: torch.Tensor, scale: float, softmax_dtype: torch.dtype):
+ input_dtype = x.dtype
+ x = x.to(softmax_dtype) * scale
+ x = torch.nn.functional.softmax(x, dim=-1).to(input_dtype)
+ return x
+
+
+@torch.jit.script
+def masked_softmax(x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor):
+ x = torch.where(mask, x, mask_value)
+ x = torch.nn.functional.softmax(x, dim=-1)
+ return x
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+ """
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
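+
+
+# Shape example: with multi-query attention (a single key/value head) and 12 attention heads,
+# repeat_kv expands a (batch, 1, seq_len, head_dim) tensor to (batch, 12, seq_len, head_dim).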
+
+
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: float,
+ dropout: float = 0.0,
+ head_mask: Optional[torch.Tensor] = None,
+ **kwargs,
+):
+ key_states = repeat_kv(key, module.num_key_value_groups)
+ value_states = repeat_kv(value, module.num_key_value_groups)
+
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+ if attention_mask is not None:
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+
+ if head_mask is not None:
+ attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)
+
+ attn_output = torch.matmul(attn_weights, value_states)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
+class GPTBigCodeAttention(nn.Module):
+ def __init__(self, config, is_cross_attention=False, layer_idx=None):
+ super().__init__()
+ self.config = config
+
+ self.mask_value = None
+ self.multi_query = config.multi_query
+ self.embed_dim = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.embed_dim // self.num_heads
+ self.kv_heads = 1 if self.multi_query else self.num_heads
+ self.kv_dim = self.kv_heads * self.head_dim
+ self.num_key_value_groups = self.num_heads // self.kv_heads
+ self.split_size = self.embed_dim
+ self.is_causal = True
+
+ if self.head_dim * self.num_heads != self.embed_dim:
+ raise ValueError(
+ f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
+ f" {self.num_heads})."
+ )
+
+ self.scale_attn_weights = config.scale_attn_weights
+ self.scaling = self.head_dim**-0.5 if config.scale_attn_weights else 1.0
+ self.is_cross_attention = is_cross_attention
+
+ self.layer_idx = layer_idx
+ self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
+ self.scale_attention_softmax_in_fp32 = (
+ config.scale_attention_softmax_in_fp32 and config.attention_softmax_in_fp32
+ )
+ self.attn_pdrop = config.attn_pdrop
+
+ if self.is_cross_attention:
+ if self.multi_query:
+ raise NotImplementedError("Multi-Query Attention not supported for cross_attention")
+
+ self.c_attn = nn.Linear(self.embed_dim, 2 * self.embed_dim)
+ self.q_attn = nn.Linear(self.embed_dim, self.embed_dim)
+ else:
+ self.c_attn = nn.Linear(self.embed_dim, self.embed_dim + 2 * self.kv_dim)
+
+ self.c_proj = nn.Linear(self.embed_dim, self.embed_dim)
+
+ self.attn_dropout = config.attn_pdrop
+ self.resid_dropout = nn.Dropout(config.resid_pdrop)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ layer_past: Optional[Cache] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ use_cache: Optional[bool] = False,
+ output_attentions: Optional[bool] = False,
+ cache_position: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> Union[
+ tuple[torch.Tensor, Optional[torch.Tensor]],
+ tuple[torch.Tensor, Optional[torch.Tensor], tuple[torch.Tensor, ...]],
+ ]:
+ input_shape = hidden_states.shape[:-1]
+
+ if layer_past is not None:
+ if isinstance(layer_past, EncoderDecoderCache):
+ is_updated = layer_past.is_updated.get(self.layer_idx)
+ if self.is_cross_attention:
+ # after the first generated id, we can subsequently re-use all key/value_states from cache
+ curr_past_key_value = layer_past.cross_attention_cache
+ else:
+ curr_past_key_value = layer_past.self_attention_cache
+ else:
+ curr_past_key_value = layer_past
+
+ if self.is_cross_attention:
+ if not hasattr(self, "q_attn") or not self.is_cross_attention:
+ raise ValueError(
+ "If class is used as cross attention, the weights `q_attn` have to be defined. "
+ "Please make sure to instantiate class with `GPTBigCodeAttention(..., is_cross_attention=True)`."
+ )
+ if layer_past is not None and is_updated:
+ # reuse k,v, cross_attentions
+ key = curr_past_key_value.layers[self.layer_idx].keys
+ value = curr_past_key_value.layers[self.layer_idx].values
+ else:
+ query = self.q_attn(hidden_states).view(*input_shape, -1, self.head_dim).transpose(1, 2)
+ key, value = self.c_attn(encoder_hidden_states).split((self.head_dim, self.head_dim), dim=-1)
+ else:
+ if self.multi_query:
+ query, key, value = (
+ self.c_attn(hidden_states).unsqueeze(1).split((self.embed_dim, self.kv_dim, self.kv_dim), dim=3)
+ )
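+                # With the default config (n_embd=768, n_head=12, multi_query=True), head_dim is 64,
+                # so c_attn projects to 768 + 2 * 64 = 896 features, split into the query (768) and a
+                # single shared key/value head (64 each).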
+ query = query.view(*input_shape, -1, self.head_dim).transpose(1, 2)
+ else:
+ query, key, value = (
+ self.c_attn(hidden_states)
+ .view(*hidden_states.shape[:2], self.num_heads, 3 * self.head_dim)
+ .transpose(1, 2)
+ .split(3 * [self.head_dim], dim=3)
+ )
+
+ if layer_past is not None:
+ # save all key/value_states to cache to be re-used for fast auto-regressive generation
+ cache_position = cache_position if not self.is_cross_attention else None
+ key, value = curr_past_key_value.update(key, value, self.layer_idx, {"cache_position": cache_position})
+ # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
+ if self.is_cross_attention:
+ layer_past.is_updated[self.layer_idx] = True
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query,
+ key,
+ value,
+ attention_mask,
+ dropout=0.0 if not self.training else self.attn_dropout,
+ scaling=self.scaling,
+ head_mask=head_mask,
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ attn_output = self.c_proj(attn_output)
+ attn_output = self.resid_dropout(attn_output)
+ return attn_output, attn_weights
+
+
+class GPTBigCodeMLP(nn.Module):
+ def __init__(self, intermediate_size, config):
+ super().__init__()
+ embed_dim = config.hidden_size
+ self.c_fc = nn.Linear(embed_dim, intermediate_size)
+ self.c_proj = nn.Linear(intermediate_size, embed_dim)
+ self.act = ACT2FN[config.activation_function]
+ self.dropout = nn.Dropout(config.resid_pdrop)
+
+ # Copied from transformers.models.gpt2.modeling_gpt2.GPT2MLP.forward
+ def forward(self, hidden_states: Optional[tuple[torch.FloatTensor]]) -> torch.FloatTensor:
+ hidden_states = self.c_fc(hidden_states)
+ hidden_states = self.act(hidden_states)
+ hidden_states = self.c_proj(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ return hidden_states
+
+
+class GPTBigCodeBlock(nn.Module):
+ def __init__(self, config, layer_idx=None):
+ super().__init__()
+ hidden_size = config.hidden_size
+ self.inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
+
+ self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+
+ self.attn = GPTBigCodeAttention(config, layer_idx=layer_idx)
+
+ self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+
+ if config.add_cross_attention:
+ if config.multi_query:
+ raise NotImplementedError("Cross-attention not implemented for MQA")
+
+ self.crossattention = GPTBigCodeAttention(config, is_cross_attention=True, layer_idx=layer_idx)
+
+ self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+
+ self.mlp = GPTBigCodeMLP(self.inner_dim, config)
+
+ def forward(
+ self,
+ hidden_states: Optional[tuple[torch.Tensor]],
+ layer_past: Optional[Cache] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ use_cache: Optional[bool] = False,
+ output_attentions: Optional[bool] = False,
+ cache_position: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> Union[
+ tuple[torch.Tensor], tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor, torch.Tensor, torch.Tensor]
+ ]:
+ residual = hidden_states
+ hidden_states = self.ln_1(hidden_states)
+ attn_outputs = self.attn(
+ hidden_states,
+ layer_past=layer_past,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ cache_position=cache_position,
+ **kwargs,
+ )
+ attn_output = attn_outputs[0] # output_attn: a, present, (attentions)
+ outputs = attn_outputs[1:]
+ # residual connection
+ hidden_states = attn_output + residual
+
+ if encoder_hidden_states is not None:
+ # add one self-attention block for cross-attention
+ if not hasattr(self, "crossattention"):
+ raise ValueError(
+ f"If `encoder_hidden_states` are passed, {self} has to be instantiated with "
+ "cross-attention layers by setting `config.add_cross_attention=True`"
+ )
+ residual = hidden_states
+ hidden_states = self.ln_cross_attn(hidden_states)
+ cross_attn_outputs = self.crossattention(
+ hidden_states,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ output_attentions=output_attentions,
+ cache_position=cache_position,
+ **kwargs,
+ )
+ attn_output = cross_attn_outputs[0]
+ # residual connection
+ hidden_states = residual + attn_output
+ outputs = outputs + cross_attn_outputs[1:] # add cross attentions if we output attention weights
+
+ residual = hidden_states
+ hidden_states = self.ln_2(hidden_states)
+ feed_forward_hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + feed_forward_hidden_states
+ return (hidden_states,) + outputs
+
+
+@auto_docstring
+class GPTBigCodePreTrainedModel(PreTrainedModel):
+ config: GPTBigCodeConfig
+ base_model_prefix = "transformer"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["GPTBigCodeBlock"]
+ _skip_keys_device_placement = "past_key_values"
+ _supports_flash_attn = True
+ _supports_sdpa = True
+
+ def __init__(self, *inputs, **kwargs):
+ super().__init__(*inputs, **kwargs)
+
+ def _init_weights(self, module):
+ """Initialize the weights."""
+ if isinstance(module, (GPTBigCodeMLP, GPTBigCodeAttention)):
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
+ #
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
+ module.c_proj.weight.data.normal_(
+ mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer))
+ )
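+            # e.g. with the default initializer_range=0.02 and n_layer=12, this gives
+            # std = 0.02 / sqrt(24) ≈ 0.0041 for the residual output projections.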
+ module.c_proj._is_hf_initialized = True
+ elif isinstance(module, nn.Linear):
+ # Slightly different from the TF version which uses truncated_normal for initialization
+ # cf https://github.com/pytorch/pytorch/pull/5617
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+
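+ # Worked example of the residual scaling above (illustrative values, not defaults taken from any
+ # particular checkpoint): with `initializer_range=0.02` and `n_layer=24`, residual projections are
+ # initialized with std = 0.02 / sqrt(2 * 24) ≈ 0.0029, roughly 7x smaller than the std used for
+ # the other Linear layers.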
+
+@auto_docstring
+class GPTBigCodeModel(GPTBigCodePreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.multi_query = config.multi_query
+ self.embed_dim = config.hidden_size
+
+ self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
+ self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
+
+ self.drop = nn.Dropout(config.embd_pdrop)
+ self.h = nn.ModuleList([GPTBigCodeBlock(config, layer_idx=i) for i in range(config.num_hidden_layers)])
+ self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
+
+ max_positions = config.max_position_embeddings
+ self.register_buffer(
+ "bias", torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)), persistent=False
+ )
+
+ self.gradient_checkpointing = False
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.wte
+
+ def set_input_embeddings(self, new_embeddings):
+ self.wte = new_embeddings
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ past_key_values: Optional[Cache] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]:
+ r"""
+ input_ids (`torch.Tensor` of shape `(batch_size, input_ids_length)`):
+ `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
+ `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input
+ sequence tokens in the vocabulary.
+
+ If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
+ `input_ids`.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+ elif input_ids is not None:
+ input_shape = input_ids.size()
+ input_ids = input_ids.view(-1, input_shape[-1])
+ batch_size = input_ids.shape[0]
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ batch_size = inputs_embeds.shape[0]
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ if batch_size <= 0:
+ raise ValueError("batch_size has to be defined and > 0")
+
+ if use_cache and past_key_values is None:
+ past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
+ if use_cache and isinstance(past_key_values, tuple):
+ logger.warning_once(
+ "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
+ "You should pass an instance of `EncoderDecoderCache` instead, e.g. "
+ "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`."
+ )
+ past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)
+
+ if inputs_embeds is None:
+ inputs_embeds = self.wte(input_ids)
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position = torch.arange(
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+ )
+
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0)
+
+ causal_mask = create_causal_mask(
+ config=self.config,
+ input_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ cache_position=cache_position,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ )
+
+ if self.config._attn_implementation == "flash_attention_2":
+ encoder_attention_mask = (
+ encoder_attention_mask.bool()
+ if (encoder_attention_mask is not None and 0 in encoder_attention_mask)
+ else None
+ )
+ else:
+ # If a 2D or 3D attention mask is provided for the cross-attention
+ # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
+ if (
+ self.config.add_cross_attention
+ and encoder_hidden_states is not None
+ and encoder_attention_mask is not None
+ ):
+ if encoder_attention_mask.dim() == 2:
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
+ assert encoder_attention_mask.dim() == 3
+ encoder_attention_mask = encoder_attention_mask.bool().unsqueeze(2 if self.multi_query else 1)
+ else:
+ encoder_attention_mask = None
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicates we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # head_mask has shape n_layer x batch x n_heads x N x N
+ head_mask = self.get_head_mask(head_mask, self.config.n_layer)
+
+ position_embeds = self.wpe(position_ids)
+ hidden_states = inputs_embeds + position_embeds.to(inputs_embeds.device)
+
+ if token_type_ids is not None:
+ token_type_ids = token_type_ids.view(-1, input_shape[-1])
+ token_type_embeds = self.wte(token_type_ids)
+ hidden_states = hidden_states + token_type_embeds
+
+ hidden_states = self.drop(hidden_states)
+ output_shape = input_shape + (hidden_states.size(-1),)
+
+ all_self_attentions = () if output_attentions else None
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
+ all_hidden_states = () if output_hidden_states else None
+ for i, block in enumerate(self.h):
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ outputs = block(
+ hidden_states,
+ past_key_values,
+ causal_mask,
+ head_mask[i],
+ encoder_hidden_states, # as a positional argument for gradient checkpointing
+ encoder_attention_mask=encoder_attention_mask,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ hidden_states = outputs[0]
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (outputs[1],)
+ if self.config.add_cross_attention:
+ all_cross_attentions = all_cross_attentions + (outputs[2],)
+
+ hidden_states = self.ln_f(hidden_states)
+
+ hidden_states = hidden_states.view(output_shape)
+ # Add last hidden state
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ return BaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ cross_attentions=all_cross_attentions,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ The GPT_BIGCODE Model transformer with a language modeling head on top (linear layer with weights tied to the input
+ embeddings).
+ """
+)
+class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel, GenerationMixin):
+ _tied_weights_keys = ["lm_head.weight"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.transformer = GPTBigCodeModel(config)
+ self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ past_key_values: Optional[Cache] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> Union[tuple, CausalLMOutputWithCrossAttentions]:
+ r"""
+ input_ids (`torch.Tensor` of shape `(batch_size, input_ids_length)`):
+ `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
+ `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input
+ sequence tokens in the vocabulary.
+
+ If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
+ `input_ids`.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ labels (`torch.Tensor` of shape `(batch_size, input_ids_length)`, *optional*):
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
+ `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100`
+ are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`.
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ transformer_outputs = self.transformer(
+ input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ cache_position=cache_position,
+ )
+ hidden_states = transformer_outputs[0]
+
+ lm_logits = self.lm_head(hidden_states)
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(
+ lm_logits,
+ labels,
+ vocab_size=self.config.vocab_size,
+ **kwargs,
+ )
+
+ if not return_dict:
+ output = (lm_logits,) + transformer_outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return CausalLMOutputWithCrossAttentions(
+ loss=loss,
+ logits=lm_logits,
+ past_key_values=transformer_outputs.past_key_values,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
+ cross_attentions=transformer_outputs.cross_attentions,
+ )
+
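+ # Illustrative usage sketch (the tiny configuration values below are arbitrary and chosen only so
+ # the snippet runs quickly; a real checkpoint would instead be loaded with `from_pretrained`):
+ #
+ #     import torch
+ #     from transformers import GPTBigCodeConfig, GPTBigCodeForCausalLM
+ #
+ #     config = GPTBigCodeConfig(vocab_size=128, n_positions=64, n_embd=32, n_layer=2, n_head=4)
+ #     model = GPTBigCodeForCausalLM(config)
+ #     input_ids = torch.randint(0, 128, (1, 16))
+ #     # Labels are shifted inside the model, so `labels=input_ids` is the usual LM setup.
+ #     outputs = model(input_ids=input_ids, labels=input_ids)
+ #     print(outputs.loss, outputs.logits.shape)  # scalar loss, torch.Size([1, 16, 128])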
+
+@auto_docstring(
+ custom_intro="""
+ The GPTBigCode Model transformer with a sequence classification head on top (linear layer).
+
+ [`GPTBigCodeForSequenceClassification`] uses the last token in order to do the classification, as other causal
+ models (e.g. GPT-1) do.
+
+ Since it does classification on the last token, it needs to know the position of the last token. If a
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (takes the last value in
+ each row of the batch).
+ """
+)
+class GPTBigCodeForSequenceClassification(GPTBigCodePreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.transformer = GPTBigCodeModel(config)
+ self.score = nn.Linear(config.n_embd, self.num_labels, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ past_key_values: Optional[Cache] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ **kwargs,
+ ) -> Union[tuple, SequenceClassifierOutputWithPast]:
+ r"""
+ input_ids (`torch.Tensor` of shape `(batch_size, input_ids_length)`):
+ `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
+ `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input
+ sequence tokens in the vocabulary.
+
+ If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
+ `input_ids`.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ labels (`torch.Tensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ transformer_outputs = self.transformer(
+ input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ **kwargs,
+ )
+ hidden_states = transformer_outputs[0]
+ logits = self.score(hidden_states)
+
+ if input_ids is not None:
+ batch_size, sequence_length = input_ids.shape[:2]
+ else:
+ batch_size, sequence_length = inputs_embeds.shape[:2]
+
+ if self.config.pad_token_id is None and batch_size != 1:
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
+ if self.config.pad_token_id is None:
+ last_non_pad_token = -1
+ elif input_ids is not None:
+ # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
+ non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
+ token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
+ last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
+ else:
+ last_non_pad_token = -1
+ logger.warning_once(
+ f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+ "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+ )
+
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
+
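+ # Worked example of the pooling above (illustrative values): with `pad_token_id=0` and a row
+ # `input_ids = [5, 6, 7, 0, 0]`, `non_pad_mask = [1, 1, 1, 0, 0]` and `token_indices = [0, 1, 2, 3, 4]`,
+ # so `(token_indices * non_pad_mask).argmax(-1) = 2` and the logits at position 2 are pooled;
+ # the same expression also handles left padding, e.g. `[0, 0, 5, 6, 7]` pools position 4.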
+ loss = None
+ if labels is not None:
+ labels = labels.to(logits.device)
+
+ if self.config.problem_type is None:
+ if self.num_labels == 1:
+ self.config.problem_type = "regression"
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+ self.config.problem_type = "single_label_classification"
+ else:
+ self.config.problem_type = "multi_label_classification"
+
+ if self.config.problem_type == "regression":
+ loss_fct = MSELoss()
+ if self.num_labels == 1:
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+ else:
+ loss = loss_fct(pooled_logits, labels)
+ elif self.config.problem_type == "single_label_classification":
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
+ elif self.config.problem_type == "multi_label_classification":
+ loss_fct = BCEWithLogitsLoss()
+ loss = loss_fct(pooled_logits, labels)
+ if not return_dict:
+ output = (pooled_logits,) + transformer_outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return SequenceClassifierOutputWithPast(
+ loss=loss,
+ logits=pooled_logits,
+ past_key_values=transformer_outputs.past_key_values,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
+ )
+
+
+@auto_docstring
+class GPTBigCodeForTokenClassification(GPTBigCodePreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+
+ self.transformer = GPTBigCodeModel(config)
+ if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None:
+ classifier_dropout = config.classifier_dropout
+ elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None:
+ classifier_dropout = config.hidden_dropout
+ else:
+ classifier_dropout = 0.1
+ self.dropout = nn.Dropout(classifier_dropout)
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ past_key_values: Optional[Cache] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, TokenClassifierOutput]:
+ r"""
+ input_ids (`torch.Tensor` of shape `(batch_size, input_ids_length)`):
+ `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
+ `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input
+ sequence tokens in the vocabulary.
+
+ If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
+ `input_ids`.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ labels (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ transformer_outputs = self.transformer(
+ input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = transformer_outputs[0]
+ hidden_states = self.dropout(hidden_states)
+ logits = self.classifier(hidden_states)
+
+ loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1).to(logits.device))
+
+ if not return_dict:
+ output = (logits,) + transformer_outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return TokenClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
+ )
+
+
+__all__ = [
+ "GPTBigCodeForSequenceClassification",
+ "GPTBigCodeForTokenClassification",
+ "GPTBigCodeForCausalLM",
+ "GPTBigCodeModel",
+ "GPTBigCodePreTrainedModel",
+]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/granitemoeshared/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/granitemoeshared/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..33d80cdd3425f95de5d40c82a4f52132be971f1f
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/granitemoeshared/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_granitemoeshared import *
+ from .modeling_granitemoeshared import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/venv/lib/python3.13/site-packages/transformers/models/granitemoeshared/configuration_granitemoeshared.py b/venv/lib/python3.13/site-packages/transformers/models/granitemoeshared/configuration_granitemoeshared.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd1c4a5ca6991bff87729670eac05a41e7879181
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/granitemoeshared/configuration_granitemoeshared.py
@@ -0,0 +1,200 @@
+# coding=utf-8
+# Copyright 2024 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""GraniteMoeShared model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...modeling_rope_utils import rope_config_validation
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class GraniteMoeSharedConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`GraniteMoeSharedModel`]. It is used to instantiate a GraniteMoeShared
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+ defaults will yield a similar configuration to that of the [ibm-research/moe-7b-1b-active-shared-experts](https://huggingface.co/ibm-research/moe-7b-1b-active-shared-experts).
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 32000):
+ Vocabulary size of the GraniteMoeShared model. Defines the number of different tokens that can be represented by the
+ `inputs_ids` passed when calling [`GraniteMoeSharedModel`]
+ hidden_size (`int`, *optional*, defaults to 4096):
+ Dimension of the hidden representations.
+ intermediate_size (`int`, *optional*, defaults to 11008):
+ Dimension of the MLP representations.
+ num_hidden_layers (`int`, *optional*, defaults to 32):
+ Number of hidden layers in the Transformer decoder.
+ num_attention_heads (`int`, *optional*, defaults to 32):
+ Number of attention heads for each attention layer in the Transformer decoder.
+ num_key_value_heads (`int`, *optional*):
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA), otherwise GQA is used. When
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+ by meanpooling all the original heads within that group. For more details, check out [this
+ paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
+ `num_attention_heads`.
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+ The non-linear activation function (function or string) in the decoder.
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
+ The maximum sequence length that this model might ever be used with.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
+ The epsilon used by the rms normalization layers.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ pad_token_id (`int`, *optional*):
+ Padding token id.
+ bos_token_id (`int`, *optional*, defaults to 1):
+ Beginning of stream token id.
+ eos_token_id (`int`, *optional*, defaults to 2):
+ End of stream token id.
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether to tie weight embeddings
+ rope_theta (`float`, *optional*, defaults to 10000.0):
+ The base period of the RoPE embeddings.
+ rope_scaling (`Dict`, *optional*):
+ Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
+ strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
+ `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
+ `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
+ these scaling strategies behave:
+ https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
+ experimental feature, subject to breaking API changes in future versions.
+ attention_bias (`bool`, *optional*, defaults to `False`):
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ embedding_multiplier (`float`, *optional*, defaults to 1.0): embedding multiplier
+ logits_scaling (`float`, *optional*, defaults to 1.0): divisor for output logits
+ residual_multiplier (`float`, *optional*, defaults to 1.0): residual multiplier
+ attention_multiplier (`float`, *optional*, defaults to 1.0): attention multiplier
+ num_local_experts (`int`, *optional*, defaults to 8): total number of experts
+ num_experts_per_tok (`int`, *optional*, defaults to 2): number of experts per token
+ output_router_logits (`bool`, *optional*, defaults to `False`):
+ Whether or not the router logits should be returned by the model. Enabling this will also
+ allow the model to output the auxiliary loss.
+ router_aux_loss_coef (`float`, *optional*, defaults to 0.001): router auxiliary loss coefficient
+ shared_intermediate_size (`int`, *optional*, defaults to 0): intermediate size for shared experts. 0 implies
+ no shared experts.
+
+ ```python
+ >>> from transformers import GraniteMoeSharedModel, GraniteMoeSharedConfig
+
+ >>> # Initializing a GraniteMoeShared granitemoe-3b style configuration
+ >>> configuration = GraniteMoeSharedConfig()
+
+ >>> # Initializing a model from the granitemoe-3b style configuration
+ >>> model = GraniteMoeSharedModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "granitemoeshared"
+ keys_to_ignore_at_inference = ["past_key_values"]
+
+ def __init__(
+ self,
+ vocab_size=32000,
+ hidden_size=4096,
+ intermediate_size=11008,
+ num_hidden_layers=32,
+ num_attention_heads=32,
+ num_key_value_heads=None,
+ hidden_act="silu",
+ max_position_embeddings=2048,
+ initializer_range=0.02,
+ rms_norm_eps=1e-6,
+ use_cache=True,
+ pad_token_id=None,
+ bos_token_id=1,
+ eos_token_id=2,
+ tie_word_embeddings=False,
+ rope_theta=10000.0,
+ rope_scaling=None,
+ attention_bias=False,
+ attention_dropout=0.0,
+ embedding_multiplier=1.0,
+ logits_scaling=1.0,
+ residual_multiplier=1.0,
+ attention_multiplier=1.0,
+ num_local_experts=8,
+ num_experts_per_tok=2,
+ output_router_logits=False,
+ router_aux_loss_coef=0.001,
+ shared_intermediate_size=0,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+
+ # for backward compatibility
+ if num_key_value_heads is None:
+ num_key_value_heads = num_attention_heads
+
+ self.num_key_value_heads = num_key_value_heads
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.rms_norm_eps = rms_norm_eps
+ self.use_cache = use_cache
+ self.rope_theta = rope_theta
+ self.rope_scaling = rope_scaling
+ # this model has rope embedding type, hardcoded for BC
+ self.position_embedding_type = "rope"
+
+ self.attention_bias = attention_bias
+ self.attention_dropout = attention_dropout
+
+ self.embedding_multiplier = embedding_multiplier
+ self.logits_scaling = logits_scaling
+ self.residual_multiplier = residual_multiplier
+ self.attention_multiplier = attention_multiplier
+
+ self.num_local_experts = num_local_experts
+ self.num_experts_per_tok = num_experts_per_tok
+ self.output_router_logits = output_router_logits
+ self.router_aux_loss_coef = router_aux_loss_coef
+ self.shared_intermediate_size = shared_intermediate_size
+
+ super().__init__(
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ tie_word_embeddings=tie_word_embeddings,
+ **kwargs,
+ )
+
+ rope_config_validation(self)
+
+
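+ # Illustrative sketch (the scaling values are arbitrary examples, not recommended settings):
+ #
+ #     from transformers import GraniteMoeSharedConfig
+ #
+ #     # Dynamic RoPE scaling with a 2x factor, using the dict format described in the docstring above.
+ #     config = GraniteMoeSharedConfig(rope_scaling={"type": "dynamic", "factor": 2.0})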
+__all__ = ["GraniteMoeSharedConfig"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/granitemoeshared/modeling_granitemoeshared.py b/venv/lib/python3.13/site-packages/transformers/models/granitemoeshared/modeling_granitemoeshared.py
new file mode 100644
index 0000000000000000000000000000000000000000..d9ff21d3ebba535bba58892ba9742985a36ece80
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/granitemoeshared/modeling_granitemoeshared.py
@@ -0,0 +1,1059 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/granitemoeshared/modular_granitemoeshared.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_granitemoeshared.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2024 IBM and the HuggingFace Inc. team. All rights reserved.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Callable, Optional, TypedDict, Union
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache
+from ...generation import GenerationMixin
+from ...modeling_attn_mask_utils import AttentionMaskConverter
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import BaseModelOutputWithPast, MoeCausalLMOutputWithPast, MoeModelOutputWithPast
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import auto_docstring, is_torch_flex_attn_available, logging
+from ...utils.deprecation import deprecate_kwarg
+from .configuration_granitemoeshared import GraniteMoeSharedConfig
+
+
+if is_torch_flex_attn_available():
+ from torch.nn.attention.flex_attention import BlockMask
+
+ from ...integrations.flex_attention import make_flex_block_causal_mask
+
+
+logger = logging.get_logger(__name__)
+
+
+class GraniteFlashAttentionKwargs(TypedDict, total=False):
+ """
+ Keyword arguments for advanced Flash Attention, causal-conv1d, and mamba_ssm kernel usage.
+ Use cases include padding-free training and fewer `torch.compile` graph breaks.
+
+ Attributes:
+ cu_seq_lens_q (`torch.LongTensor`):
+ Cumulative sequence lengths for the query state.
+ cu_seq_lens_k (`torch.LongTensor`):
+ Cumulative sequence lengths for the key state.
+ max_length_q (`int`):
+ Maximum sequence length for the query state.
+ max_length_k (`int`):
+ Maximum sequence length for the key state.
+ seq_idx (`torch.IntTensor`):
+ Index of each packed sequence.
+ """
+
+ cu_seq_lens_q: torch.LongTensor
+ cu_seq_lens_k: torch.LongTensor
+ max_length_q: int
+ max_length_k: int
+ seq_idx: torch.IntTensor
+
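+ # Illustrative example of the padding-free kwargs above (shapes shown are an assumption for the
+ # packed layout): for two sequences of lengths 3 and 5 packed into a single row of 8 tokens,
+ #
+ #     cu_seq_lens_q = cu_seq_lens_k = torch.tensor([0, 3, 8])  # cumulative sequence boundaries
+ #     max_length_q = max_length_k = 5
+ #     seq_idx = torch.tensor([[0, 0, 0, 1, 1, 1, 1, 1]], dtype=torch.int32)  # sequence id per token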
+
+class GraniteMoeSharedMLP(nn.Module):
+ """
+ MLP layer for shared experts
+
+ Args:
+ config:
+ Configuration object with model hyperparameters.
+ """
+
+ def __init__(self, config: GraniteMoeSharedConfig):
+ super().__init__()
+
+ self.input_size = config.hidden_size
+ self.hidden_size = config.shared_intermediate_size
+ self.activation = ACT2FN[config.hidden_act]
+ self.input_linear = nn.Linear(self.input_size, self.hidden_size * 2, bias=False)
+ self.output_linear = nn.Linear(self.hidden_size, self.input_size, bias=False)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.input_linear(hidden_states)
+ chunked_hidden_states = hidden_states.chunk(2, dim=-1)
+ hidden_states = self.activation(chunked_hidden_states[0]) * chunked_hidden_states[1]
+ hidden_states = self.output_linear(hidden_states)
+ return hidden_states
+
+
+class GraniteMoeSharedRMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ GraniteMoeSharedRMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+ def extra_repr(self):
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
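+ # The normalization above computes weight * x / sqrt(mean(x**2, dim=-1) + eps), accumulating the
+ # statistics in float32 before casting back to the input dtype; unlike LayerNorm there is no mean
+ # subtraction and no bias term.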
+
+class GraniteMoeSharedParallelExperts(nn.Module):
+ def __init__(self, num_experts: int, input_size: int, output_size: int) -> None:
+ """
+ Initialize the GraniteMoeSharedParallelExperts module.
+ The expert weights are stored in [num_experts, output_size, input_size] format, so that they are compatible with
+ many MoE libraries, such as [Megablocks](https://github.com/databricks/megablocks) and
+ [ScatterMoE](https://github.com/shawntan/scattermoe), as well as the
+ [MoE kernel](https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/fused_moe/fused_moe.py)
+ used in vLLM.
+
+ Args:
+ num_experts (int):
+ Number of experts.
+ input_size (int):
+ Size of the input.
+ output_size (int):
+ Size of the output.
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.empty(num_experts, output_size, input_size))
+ self.num_experts = num_experts
+ self.input_size = input_size
+ self.output_size = output_size
+
+ def forward(self, inputs, expert_size):
+ """
+ Forward pass of the GraniteMoeSharedParallelExperts module.
+
+ Args:
+ inputs (Tensor):
+ Input tensor.
+ expert_size:
+ Expert size information.
+
+ Returns:
+ Tensor: Output tensor.
+ """
+ input_list = inputs.split(expert_size, dim=0)
+ output_list = []
+ for i in range(self.num_experts):
+ output_list.append(F.linear(input_list[i], self.weight[i]))
+ results = torch.cat(output_list, dim=0)
+ return results
+
+
+class GraniteMoeSharedTopKGating(nn.Module):
+ def __init__(self, input_size: int, num_experts: int, top_k: int):
+ """
+ Initialize the top-k gating mechanism.
+ Args:
+ input_size (`int`):
+ Size of the input.
+ num_experts (`int`):
+ Number of experts.
+ top_k (`int`):
+ Number of top experts to select.
+ """
+ super().__init__()
+
+ self.num_experts = num_experts
+ self.input_size = input_size
+ self.top_k = top_k
+
+ self.layer = nn.Linear(input_size, num_experts, bias=False)
+
+ def forward(self, hidden_states):
+ # compute the top_k routing decision
+ logits = self.layer(hidden_states).float() # [batch_size x seq_len, num_experts]
+ top_k_logits, top_k_indices = logits.topk(self.top_k, dim=1) # [num_tokens, top_k]
+ top_k_gates = torch.softmax(top_k_logits, dim=1).type_as(hidden_states) # [num_tokens, top_k]
+
+ # compute number of input given to each expert
+ zeros = torch.zeros(
+ [top_k_gates.size(0), self.num_experts], dtype=top_k_gates.dtype, device=top_k_gates.device
+ ) # [num_tokens, num_experts]
+ gates = zeros.scatter(1, top_k_indices, 1) # [num_tokens, num_experts]
+ expert_size = gates.long().sum(0) # [num_experts,]
+ # (This causes torch.compile to fail with `torch._dynamo.exc.Unsupported: Backend compiler failed with a fake tensor exception at`
+ # and `DataDependentOutputException`.)
+ expert_size = expert_size.tolist()
+
+ # sort and group input tokens according to expert assignment
+ top_k_experts = top_k_indices.flatten() # [num_tokens * top_k]
+ _, index_sorted_experts = top_k_experts.sort(0) # [num_tokens * top_k]
+ batch_index = index_sorted_experts.div(self.top_k, rounding_mode="trunc") # [num_tokens * top_k]
+
+ # gather the gate values for grouped input tokens
+ top_k_gates = top_k_gates.flatten() # [num_tokens * top_k]
+ batch_gates = top_k_gates[index_sorted_experts] # [num_tokens * top_k]
+
+ return index_sorted_experts, batch_index, batch_gates, expert_size, logits
+
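+ # Worked example of the gating above (illustrative numbers): with num_experts=4 and top_k=2, a token
+ # whose router logits are [0.1, 2.0, 0.3, 1.5] selects experts 1 and 3 with gates
+ # softmax([2.0, 1.5]) ≈ [0.62, 0.38]; `expert_size` then counts how many (token, expert) assignments
+ # each expert receives, and `batch_index` lists the originating token index for every assignment,
+ # sorted by expert id so that each expert's inputs are contiguous.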
+
+class GraniteMoeSharedMoE(nn.Module):
+ """
+ A Sparsely gated mixture of experts layer with 1-layer Feed-Forward networks as experts.
+
+ Args:
+ config:
+ Configuration object with model hyperparameters.
+ """
+
+ def __init__(self, config: GraniteMoeSharedConfig):
+ super().__init__()
+
+ self.input_size = config.hidden_size
+ self.hidden_size = config.intermediate_size
+ self.activation = ACT2FN[config.hidden_act]
+ self.input_linear = GraniteMoeSharedParallelExperts(
+ config.num_local_experts, self.input_size, self.hidden_size * 2
+ )
+ self.output_linear = GraniteMoeSharedParallelExperts(
+ config.num_local_experts, self.hidden_size, self.input_size
+ )
+
+ self.router = GraniteMoeSharedTopKGating(
+ input_size=self.input_size,
+ num_experts=config.num_local_experts,
+ top_k=config.num_experts_per_tok,
+ )
+
+ def forward(self, layer_input):
+ """
+ Forward pass of the mixture of experts layer.
+
+ Args:
+ layer_input (Tensor):
+ Input tensor.
+
+ Returns:
+ Tensor:
+ Output tensor.
+ Tensor:
+ Router logits.
+ """
+ bsz, length, emb_size = layer_input.size()
+ layer_input = layer_input.reshape(-1, emb_size)
+ _, batch_index, batch_gates, expert_size, router_logits = self.router(layer_input)
+
+ expert_inputs = layer_input[batch_index]
+ hidden_states = self.input_linear(expert_inputs, expert_size)
+ chunked_hidden_states = hidden_states.chunk(2, dim=-1)
+ hidden_states = self.activation(chunked_hidden_states[0]) * chunked_hidden_states[1]
+ expert_outputs = self.output_linear(hidden_states, expert_size)
+
+ expert_outputs = expert_outputs * batch_gates[:, None]
+
+ zeros = torch.zeros((bsz * length, self.input_size), dtype=expert_outputs.dtype, device=expert_outputs.device)
+ layer_output = zeros.index_add(0, batch_index, expert_outputs)
+ layer_output = layer_output.view(bsz, length, self.input_size)
+ return layer_output, router_logits
+
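+ # The forward pass above gathers each routed token's hidden state (`layer_input[batch_index]`),
+ # runs the gated MLP on the contiguous per-expert groups, scales the outputs by the router gates,
+ # and scatters them back to their original token positions with `index_add`.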
+
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+ """Applies Rotary Position Embedding to the query and key tensors.
+
+ Args:
+ q (`torch.Tensor`): The query tensor.
+ k (`torch.Tensor`): The key tensor.
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
+ position_ids (`torch.Tensor`, *optional*):
+ Deprecated and unused.
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+ Returns:
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+ """
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
+
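+ # With the half-rotation above, each feature pair (x_i, x_{i + d/2}) is rotated by the angle θ_i
+ # encoded in (cos, sin): x_i -> x_i * cos(θ_i) - x_{i + d/2} * sin(θ_i) and
+ # x_{i + d/2} -> x_{i + d/2} * cos(θ_i) + x_i * sin(θ_i), i.e. a 2D rotation applied per position.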
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+ """
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: float,
+ dropout: float = 0.0,
+ **kwargs,
+):
+ key_states = repeat_kv(key, module.num_key_value_groups)
+ value_states = repeat_kv(value, module.num_key_value_groups)
+
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+ if attention_mask is not None:
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
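+ # The eager path above computes softmax(q @ k^T * scaling + causal_mask) @ v, upcasting the softmax
+ # to float32, after repeating the key/value heads `num_key_value_groups` times so that grouped-query
+ # attention reduces to standard multi-head attention; the result is returned with shape
+ # (batch, seq_len, num_heads, head_dim).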
+
+# copied from transformers.models.granite.modeling_granite.GraniteAttention with Granite->GraniteMoeShared
+# no longer copied after attention refactors
+class GraniteMoeSharedAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config: GraniteMoeSharedConfig, layer_idx: Optional[int] = None):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ if layer_idx is None:
+ logger.warning_once(
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
+ "when creating this class."
+ )
+
+ self.attention_dropout = config.attention_dropout
+ self.hidden_size = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.hidden_size // self.num_heads
+ self.num_key_value_heads = config.num_key_value_heads
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+ self.is_causal = True
+
+ self.scaling = config.attention_multiplier
+
+ if (self.head_dim * self.num_heads) != self.hidden_size:
+ raise ValueError(
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
+ f" and `num_heads`: {self.num_heads})."
+ )
+
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
+ self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias)
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # None or rope embeddings
+ **kwargs,
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ cos, sin = position_embeddings if position_embeddings is not None else (None, None)
+ if position_embeddings is not None:
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_values is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ scaling=self.scaling,
+ **kwargs,
+ )
+
+ attn_output = attn_output.view(bsz, q_len, -1)
+ attn_output = self.o_proj(attn_output)
+
+ return attn_output, attn_weights
+
+
+class GraniteMoeSharedDecoderLayer(GradientCheckpointingLayer):
+ def __init__(self, config: GraniteMoeSharedConfig, layer_idx: int):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+
+ self.self_attn = GraniteMoeSharedAttention(config=config, layer_idx=layer_idx)
+ if config.num_local_experts > 0:
+ self.block_sparse_moe = GraniteMoeSharedMoE(config)
+ self.input_layernorm = GraniteMoeSharedRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_attention_layernorm = GraniteMoeSharedRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ self.residual_multiplier = config.residual_multiplier
+ self.shared_mlp = None if config.shared_intermediate_size == 0 else GraniteMoeSharedMLP(config)
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ output_router_logits: Optional[bool] = False,
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+ **kwargs: Unpack[GraniteFlashAttentionKwargs],
+ ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`, *optional*):
+ attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
+ query_sequence_length, key_sequence_length)` if default attention is used.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ past_key_values (`Cache`, *optional*): cached past key and value projection states
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+ Indices depicting the position of the input sequence tokens in the sequence
+ output_router_logits (`bool`, *optional*):
+ Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
+ should not be returned during inference.
+ position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
+ Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
+ with `head_dim` being the embedding dimension of each attention head.
+ kwargs (`dict`, *optional*):
+ Arbitrary kwargs. Can be used to provide `GraniteFlashAttentionKwargs` for
+ padding-free training and/or improve torch.compile performance.
+ """
+ residual = hidden_states
+
+ hidden_states = self.input_layernorm(hidden_states)
+
+ # Self Attention
+ hidden_states, self_attn_weights = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+
+ hidden_states = residual + hidden_states * self.residual_multiplier
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ moe_hidden_states, router_logits = self.block_sparse_moe(hidden_states)
+
+ if self.shared_mlp is None:
+ hidden_states = moe_hidden_states
+ else:
+ hidden_states = moe_hidden_states + self.shared_mlp(hidden_states)
+
+ del moe_hidden_states
+
+ hidden_states = residual + hidden_states * self.residual_multiplier
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights,)
+
+ if output_router_logits:
+ outputs += (router_logits,)
+
+ return outputs
+
+
+@auto_docstring
+class GraniteMoeSharedPreTrainedModel(PreTrainedModel):
+ config: GraniteMoeSharedConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["GraniteMoeSharedDecoderLayer"]
+ _skip_keys_device_placement = ["past_key_values"]
+ _supports_flash_attn = True
+ _supports_sdpa = True
+
+ _can_compile_fullgraph = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
+
+ def _init_weights(self, module):
+ super()._init_weights(module)
+ if isinstance(module, GraniteMoeSharedParallelExperts):
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+
+
+class GraniteMoeSharedRotaryEmbedding(nn.Module):
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
+
+ def __init__(self, config: GraniteMoeSharedConfig, device=None):
+ super().__init__()
+ # BC: "rope_type" was originally "type"
+ if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+ else:
+ self.rope_type = "default"
+ self.max_seq_len_cached = config.max_position_embeddings
+ self.original_max_seq_len = config.max_position_embeddings
+
+ self.config = config
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+ self.original_inv_freq = self.inv_freq
+
+ @torch.no_grad()
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
+ def forward(self, x, position_ids):
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
+ position_ids_expanded = position_ids[:, None, :].float()
+
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos = emb.cos() * self.attention_scaling
+ sin = emb.sin() * self.attention_scaling
+
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+@auto_docstring
+class GraniteMoeSharedModel(GraniteMoeSharedPreTrainedModel):
+ def __init__(self, config: GraniteMoeSharedConfig):
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+ self.layers = nn.ModuleList(
+ [GraniteMoeSharedDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+ self.norm = GraniteMoeSharedRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.gradient_checkpointing = False
+
+ self.embedding_multiplier = config.embedding_multiplier
+ self.hidden_size = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.hidden_size // self.num_heads
+ self.max_position_embeddings = config.max_position_embeddings
+ self.rope_theta = config.rope_theta
+
+ self.position_embedding_type = config.position_embedding_type
+ self.rotary_emb = GraniteMoeSharedRotaryEmbedding(config) if self.position_embedding_type == "rope" else None
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ output_router_logits: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs,
+ ) -> Union[tuple, BaseModelOutputWithPast]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if self.gradient_checkpointing and self.training and use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
+ )
+ use_cache = False
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ inputs_embeds = inputs_embeds * self.embedding_multiplier
+
+ if use_cache and past_key_values is None:
+ past_key_values = DynamicCache(config=self.config)
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position = torch.arange(
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+ )
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0)
+
+ causal_mask = self._update_causal_mask(
+ attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+ )
+
+ # embed positions
+ hidden_states = inputs_embeds
+
+ position_embeddings = None
+ # create position embeddings to be shared across the decoder layers
+ if self.rotary_emb is not None:
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ all_router_logits = () if output_router_logits else None
+
+ for decoder_layer in self.layers:
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=causal_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ output_router_logits=output_router_logits,
+ position_embeddings=position_embeddings,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ if output_router_logits:
+ all_router_logits += (layer_outputs[-1],)
+
+ hidden_states = self.norm(hidden_states)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ if not return_dict:
+ return tuple(
+ v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns] if v is not None
+ )
+ return MoeModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ router_logits=all_router_logits,
+ )
+
+ def _update_causal_mask(
+ self,
+ attention_mask: Union[torch.Tensor, "BlockMask"],
+ input_tensor: torch.Tensor,
+ cache_position: torch.Tensor,
+ past_key_values: Cache,
+ output_attentions: bool = False,
+ ):
+ if self.config._attn_implementation == "flash_attention_2":
+ if attention_mask is not None and (attention_mask == 0.0).any():
+ return attention_mask
+ return None
+ if self.config._attn_implementation == "flex_attention":
+ if isinstance(attention_mask, torch.Tensor):
+ attention_mask = make_flex_block_causal_mask(attention_mask)
+ return attention_mask
+
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
+ # to infer the attention mask.
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
+
+ # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
+ if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
+ attention_mask,
+ inputs_embeds=input_tensor,
+ past_key_values_length=past_seen_tokens,
+ is_training=self.training,
+ ):
+ return None
+
+ dtype = input_tensor.dtype
+ sequence_length = input_tensor.shape[1]
+ if using_compilable_cache:
+ target_length = past_key_values.get_max_cache_shape()
+ else:
+ target_length = (
+ attention_mask.shape[-1]
+ if isinstance(attention_mask, torch.Tensor)
+ else past_seen_tokens + sequence_length + 1
+ )
+
+ # In case the provided `attention_mask` is 2D, we generate a causal mask here (4D).
+ causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
+ attention_mask,
+ sequence_length=sequence_length,
+ target_length=target_length,
+ dtype=dtype,
+ cache_position=cache_position,
+ batch_size=input_tensor.shape[0],
+ )
+
+ if (
+ self.config._attn_implementation == "sdpa"
+ and attention_mask is not None
+ and attention_mask.device.type in ["cuda", "xpu", "npu"]
+ and not output_attentions
+ ):
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+ # Details: https://github.com/pytorch/pytorch/issues/110213
+ min_dtype = torch.finfo(dtype).min
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
+
+ return causal_mask
+
+ @staticmethod
+ def _prepare_4d_causal_attention_mask_with_cache_position(
+ attention_mask: torch.Tensor,
+ sequence_length: int,
+ target_length: int,
+ dtype: torch.dtype,
+ cache_position: torch.Tensor,
+ batch_size: int,
+ **kwargs,
+ ):
+ """
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+ `(batch_size, key_value_length)`. If the input `attention_mask` is already 4D, it is returned unchanged.
+
+ Args:
+ attention_mask (`torch.Tensor`):
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
+ `(batch_size, 1, query_length, key_value_length)`.
+ sequence_length (`int`):
+ The sequence length being processed.
+ target_length (`int`):
+ The target length: when generating with a static cache, the mask should be as long as the static cache
+ to account for the 0 padding, i.e. the part of the cache that is not filled yet.
+ dtype (`torch.dtype`):
+ The dtype to use for the 4D attention mask.
+ cache_position (`torch.Tensor`):
+ Indices depicting the position of the input sequence tokens in the sequence.
+ batch_size (`int`):
+ Batch size.
+ """
+ if attention_mask is not None and attention_mask.dim() == 4:
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+ causal_mask = attention_mask
+ else:
+ min_dtype = torch.finfo(dtype).min
+ causal_mask = torch.full(
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
+ )
+ if sequence_length != 1:
+ causal_mask = torch.triu(causal_mask, diagonal=1)
+ causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+ if attention_mask is not None:
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
+ mask_length = attention_mask.shape[-1]
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
+ causal_mask.device
+ )
+ padding_mask = padding_mask == 0
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+ padding_mask, min_dtype
+ )
+
+ return causal_mask
+
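+ # Editorial toy example for `_prepare_4d_causal_attention_mask_with_cache_position` above
+ # (assumption, not part of the original file): with sequence_length=2, target_length=4 and
+ # cache_position=[2, 3], each batch element gets the additive mask
+ #   [[0, 0, 0, min_dtype],
+ #    [0, 0, 0, 0        ]]
+ # broadcast to (batch_size, 1, 2, 4); key positions padded out by `attention_mask` are
+ # additionally filled with min_dtype.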
+
+def load_balancing_loss_func(
+ gate_logits: Union[torch.Tensor, tuple[torch.Tensor], None],
+ num_experts: Optional[int] = None,
+ top_k=2,
+ attention_mask: Optional[torch.Tensor] = None,
+) -> Union[torch.Tensor, int]:
+ r"""
+ Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
+
+ See Switch Transformer (https://huggingface.co/papers/2101.03961) for more details. This function implements the loss
+ function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
+ experts is too unbalanced.
+
+ Args:
+ gate_logits:
+ Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
+ shape [batch_size X sequence_length, num_experts].
+ num_experts:
+ Number of experts.
+ top_k:
+ The number of experts to route per token; can also be interpreted as the `top-k` routing
+ parameter.
+ attention_mask (`torch.Tensor`, *optional*):
+ The attention_mask used in the forward function, of shape
+ [batch_size X sequence_length] if not None.
+
+ Returns:
+ The auxiliary loss.
+ """
+ if gate_logits is None or not isinstance(gate_logits, tuple):
+ return 0
+
+ if isinstance(gate_logits, tuple):
+ compute_device = gate_logits[0].device
+ concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)
+
+ routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
+
+ _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
+
+ expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
+
+ if attention_mask is None:
+ # Compute the percentage of tokens routed to each expert
+ tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
+
+ # Compute the average probability of routing to these experts
+ router_prob_per_expert = torch.mean(routing_weights, dim=0)
+ else:
+ batch_size, sequence_length = attention_mask.shape
+ num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)
+
+ # Compute the mask that masks all padding tokens as 0, with the same shape as expert_mask
+ expert_attention_mask = (
+ attention_mask[None, :, :, None, None]
+ .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
+ .reshape(-1, top_k, num_experts)
+ .to(compute_device)
+ )
+
+ # Compute the percentage of tokens routed to each expert
+ tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
+ expert_attention_mask, dim=0
+ )
+
+ # Compute the mask that masks all padding tokens as 0, with the same shape as tokens_per_expert
+ router_per_expert_attention_mask = (
+ attention_mask[None, :, :, None]
+ .expand((num_hidden_layers, batch_size, sequence_length, routing_weights.shape[1]))
+ .reshape(-1, routing_weights.shape[1])
+ .to(compute_device)
+ )
+
+ # Compute the average probability of routing to these experts
+ router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
+ router_per_expert_attention_mask, dim=0
+ )
+
+ device_index = routing_weights.device.index if routing_weights.device.index is not None else 0
+ rank = routing_weights.shape[1] * int(device_index)
+ overall_loss = torch.sum(
+ tokens_per_expert[:, rank : rank + routing_weights.shape[1]] * router_prob_per_expert.unsqueeze(0)
+ )
+ return overall_loss * num_experts
+
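+ # Editorial toy example for `load_balancing_loss_func` above (assumption, not part of the
+ # original file): with num_experts=2 and top_k=1, a perfectly balanced router gives
+ # tokens_per_expert = [0.5, 0.5] and router_prob_per_expert ~ [0.5, 0.5], so the loss is
+ # 2 * (0.5 * 0.5 + 0.5 * 0.5) = 1.0; a router that always picks expert 0 yields ~2.0, so the
+ # auxiliary term grows with routing imbalance.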
+
+class GraniteMoeSharedForCausalLM(GraniteMoeSharedPreTrainedModel, GenerationMixin):
+ _tied_weights_keys = ["lm_head.weight"]
+
+ def __init__(self, config: GraniteMoeSharedConfig):
+ super().__init__(config)
+ self.model = GraniteMoeSharedModel(config)
+ self.vocab_size = config.vocab_size
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ self.router_aux_loss_coef = config.router_aux_loss_coef
+ self.num_experts = config.num_local_experts
+ self.num_experts_per_tok = config.num_experts_per_tok
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ output_router_logits: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ **kwargs,
+ ) -> Union[tuple, MoeCausalLMOutputWithPast]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, GraniteMoeSharedForCausalLM
+
+ >>> model = GraniteMoeSharedForCausalLM.from_pretrained("ibm/PowerMoE-3b")
+ >>> tokenizer = AutoTokenizer.from_pretrained("ibm/PowerMoE-3b")
+
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+ ```"""
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_router_logits = (
+ output_router_logits if output_router_logits is not None else self.config.output_router_logits
+ )
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ output_router_logits=output_router_logits,
+ return_dict=return_dict,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ # Only compute necessary logits
+ hidden_states = outputs[0]
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+ logits = logits / self.config.logits_scaling
+
+ loss = None
+ if labels is not None:
+ # Upcast to float if we need to compute the loss to avoid potential precision issues
+ logits = logits.float()
+ # Flatten the tokens
+ loss = self.loss_function(
+ logits,
+ labels,
+ vocab_size=self.config.vocab_size,
+ **kwargs,
+ )
+
+ aux_loss = None
+ if output_router_logits:
+ aux_loss = load_balancing_loss_func(
+ outputs.router_logits if return_dict else outputs[-1],
+ self.num_experts,
+ self.num_experts_per_tok,
+ attention_mask,
+ )
+ if labels is not None:
+ loss += self.router_aux_loss_coef * aux_loss.to(loss.device)  # make sure it resides on the same device
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ if output_router_logits:
+ output = (aux_loss,) + output
+ return (loss,) + output if loss is not None else output
+
+ return MoeCausalLMOutputWithPast(
+ loss=loss,
+ aux_loss=aux_loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ router_logits=outputs.router_logits,
+ )
+
+
+__all__ = ["GraniteMoeSharedForCausalLM", "GraniteMoeSharedModel", "GraniteMoeSharedPreTrainedModel"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/granitemoeshared/modular_granitemoeshared.py b/venv/lib/python3.13/site-packages/transformers/models/granitemoeshared/modular_granitemoeshared.py
new file mode 100644
index 0000000000000000000000000000000000000000..529a07f0317a0be3ffaea3582c71b687d762a4bd
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/granitemoeshared/modular_granitemoeshared.py
@@ -0,0 +1,200 @@
+# coding=utf-8
+# Copyright 2024 IBM and the HuggingFace Inc. team. All rights reserved.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Optional, TypedDict
+
+import torch
+from torch import nn
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache
+from ...processing_utils import Unpack
+from ...utils import logging
+from ...utils.deprecation import deprecate_kwarg
+from ..granitemoe.modeling_granitemoe import (
+ GraniteMoeDecoderLayer,
+ GraniteMoeForCausalLM,
+ GraniteMoeModel,
+ GraniteMoePreTrainedModel,
+)
+from .configuration_granitemoeshared import GraniteMoeSharedConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+class GraniteFlashAttentionKwargs(TypedDict, total=False):
+ """
+ Keyword arguments for advanced Flash Attention, causal-conv1d, and mamba_ssm kernel usage.
+ Use cases include padding-free training and fewer `torch.compile` graph breaks.
+
+ Attributes:
+ cu_seq_lens_q (`torch.LongTensor`):
+ Cumulative sequence lengths for the query state.
+ cu_seq_lens_k (`torch.LongTensor`):
+ Cumulative sequence lengths for the key state.
+ max_length_q (`int`):
+ Maximum sequence length for the query state.
+ max_length_k (`int`):
+ Maximum sequence length for the key state.
+ seq_idx (`torch.IntTensor`):
+ Index of each packed sequence.
+ """
+
+ cu_seq_lens_q: torch.LongTensor
+ cu_seq_lens_k: torch.LongTensor
+ max_length_q: int
+ max_length_k: int
+ seq_idx: torch.IntTensor
+
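+ # Editorial example for GraniteFlashAttentionKwargs above (assumption, not part of the
+ # original file): two sequences of lengths 3 and 5 packed into one padding-free row would
+ # typically be described as
+ #   cu_seq_lens_q = cu_seq_lens_k = torch.tensor([0, 3, 8], dtype=torch.int64)
+ #   max_length_q = max_length_k = 5
+ #   seq_idx = torch.tensor([[0, 0, 0, 1, 1, 1, 1, 1]], dtype=torch.int32)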
+
+class GraniteMoeSharedMLP(nn.Module):
+ """
+ MLP layer for shared experts
+
+ Args:
+ config:
+ Configuration object with model hyperparameters.
+ """
+
+ def __init__(self, config: GraniteMoeSharedConfig):
+ super().__init__()
+
+ self.input_size = config.hidden_size
+ self.hidden_size = config.shared_intermediate_size
+ self.activation = ACT2FN[config.hidden_act]
+ self.input_linear = nn.Linear(self.input_size, self.hidden_size * 2, bias=False)
+ self.output_linear = nn.Linear(self.hidden_size, self.input_size, bias=False)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.input_linear(hidden_states)
+ chunked_hidden_states = hidden_states.chunk(2, dim=-1)
+ hidden_states = self.activation(chunked_hidden_states[0]) * chunked_hidden_states[1]
+ hidden_states = self.output_linear(hidden_states)
+ return hidden_states
+
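+ # Editorial note on GraniteMoeSharedMLP above (hedged): `input_linear` projects to
+ # 2 * shared_intermediate_size, the result is chunked into (gate, up), and
+ # activation(gate) * up is projected back to hidden_size, i.e. a GLU-style gated MLP
+ # (SwiGLU when `hidden_act` is silu).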
+
+class GraniteMoeSharedDecoderLayer(GraniteMoeDecoderLayer):
+ def __init__(self, config: GraniteMoeSharedConfig, layer_idx: int):
+ super().__init__(config, layer_idx)
+ self.shared_mlp = None if config.shared_intermediate_size == 0 else GraniteMoeSharedMLP(config)
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ output_router_logits: Optional[bool] = False,
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+ **kwargs: Unpack[GraniteFlashAttentionKwargs],
+ ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`, *optional*):
+ attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
+ query_sequence_length, key_sequence_length)` if default attention is used.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ past_key_values (`Cache`, *optional*): cached past key and value projection states
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+ Indices depicting the position of the input sequence tokens in the sequence
+ output_router_logits (`bool`, *optional*):
+ Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
+ should not be returned during inference.
+ position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
+ Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
+ with `head_dim` being the embedding dimension of each attention head.
+ kwargs (`dict`, *optional*):
+ Arbitrary kwargs. Can be used to provide `GraniteFlashAttentionKwargs` for
+ padding-free training and/or to improve torch.compile performance.
+ """
+ residual = hidden_states
+
+ hidden_states = self.input_layernorm(hidden_states)
+
+ # Self Attention
+ hidden_states, self_attn_weights = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+
+ hidden_states = residual + hidden_states * self.residual_multiplier
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ moe_hidden_states, router_logits = self.block_sparse_moe(hidden_states)
+
+ if self.shared_mlp is None:
+ hidden_states = moe_hidden_states
+ else:
+ hidden_states = moe_hidden_states + self.shared_mlp(hidden_states)
+
+ del moe_hidden_states
+
+ hidden_states = residual + hidden_states * self.residual_multiplier
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights,)
+
+ if output_router_logits:
+ outputs += (router_logits,)
+
+ return outputs
+
+
+class GraniteMoeSharedPreTrainedModel(GraniteMoePreTrainedModel):
+ config: GraniteMoeSharedConfig
+ _no_split_modules = ["GraniteMoeSharedDecoderLayer"]
+
+
+class GraniteMoeSharedModel(GraniteMoeModel):
+ def __init__(self, config: GraniteMoeSharedConfig):
+ super().__init__(config)
+ self.layers = nn.ModuleList(
+ [GraniteMoeSharedDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+
+
+class GraniteMoeSharedForCausalLM(GraniteMoeForCausalLM):
+ _tied_weights_keys = ["lm_head.weight"]
+
+ def __init__(self, config: GraniteMoeSharedConfig):
+ super().__init__(config)
+ self.model = GraniteMoeSharedModel(config)
+ # Initialize weights and apply final processing
+ self.post_init()
+
+
+__all__ = ["GraniteMoeSharedForCausalLM", "GraniteMoeSharedModel", "GraniteMoeSharedPreTrainedModel"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/groupvit/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/groupvit/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab7fa27d09d16590d6ba25185c9ef9c4974e2ea1
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/groupvit/__init__.py
@@ -0,0 +1,28 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_groupvit import *
+ from .modeling_groupvit import *
+ from .modeling_tf_groupvit import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/venv/lib/python3.13/site-packages/transformers/models/groupvit/configuration_groupvit.py b/venv/lib/python3.13/site-packages/transformers/models/groupvit/configuration_groupvit.py
new file mode 100644
index 0000000000000000000000000000000000000000..662447e7e98433cb8226d2ce636a7392ba29f214
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/groupvit/configuration_groupvit.py
@@ -0,0 +1,407 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""GroupViT model configuration"""
+
+from collections import OrderedDict
+from collections.abc import Mapping
+from typing import TYPE_CHECKING, Any, Optional
+
+from ...configuration_utils import PretrainedConfig
+from ...onnx import OnnxConfig
+from ...utils import logging
+
+
+if TYPE_CHECKING:
+ from ...processing_utils import ProcessorMixin
+ from ...utils import TensorType
+
+
+logger = logging.get_logger(__name__)
+
+
+class GroupViTTextConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`GroupViTTextModel`]. It is used to instantiate a
+ GroupViT model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of the GroupViT
+ [nvidia/groupvit-gcc-yfcc](https://huggingface.co/nvidia/groupvit-gcc-yfcc) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 49408):
+ Vocabulary size of the GroupViT text model. Defines the number of different tokens that can be represented
+ by the `input_ids` passed when calling [`GroupViTModel`].
+ hidden_size (`int`, *optional*, defaults to 256):
+ Dimensionality of the encoder layers and the pooler layer.
+ intermediate_size (`int`, *optional*, defaults to 1024):
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+ num_hidden_layers (`int`, *optional*, defaults to 12):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 4):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ max_position_embeddings (`int`, *optional*, defaults to 77):
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
+ just in case (e.g., 512 or 1024 or 2048).
+ hidden_act (`str` or `function`, *optional*, defaults to `"quick_gelu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-5):
+ The epsilon used by the layer normalization layers.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ dropout (`float`, *optional*, defaults to 0.0):
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ initializer_factor (`float`, *optional*, defaults to 1.0):
+ A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
+ testing).
+
+ Example:
+
+ ```python
+ >>> from transformers import GroupViTTextConfig, GroupViTTextModel
+
+ >>> # Initializing a GroupViTTextModel with nvidia/groupvit-gcc-yfcc style configuration
+ >>> configuration = GroupViTTextConfig()
+
+ >>> model = GroupViTTextModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "groupvit_text_model"
+ base_config_key = "text_config"
+
+ def __init__(
+ self,
+ vocab_size=49408,
+ hidden_size=256,
+ intermediate_size=1024,
+ num_hidden_layers=12,
+ num_attention_heads=4,
+ max_position_embeddings=77,
+ hidden_act="quick_gelu",
+ layer_norm_eps=1e-5,
+ dropout=0.0,
+ attention_dropout=0.0,
+ initializer_range=0.02,
+ initializer_factor=1.0,
+ pad_token_id=1,
+ bos_token_id=49406,
+ eos_token_id=49407,
+ **kwargs,
+ ):
+ super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
+
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.dropout = dropout
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.max_position_embeddings = max_position_embeddings
+ self.layer_norm_eps = layer_norm_eps
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.initializer_factor = initializer_factor
+ self.attention_dropout = attention_dropout
+
+
+class GroupViTVisionConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`GroupViTVisionModel`]. It is used to instantiate
+ a GroupViT model according to the specified arguments, defining the model architecture. Instantiating a
+ configuration with the defaults will yield a similar configuration to that of the GroupViT
+ [nvidia/groupvit-gcc-yfcc](https://huggingface.co/nvidia/groupvit-gcc-yfcc) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ hidden_size (`int`, *optional*, defaults to 384):
+ Dimensionality of the encoder layers and the pooler layer.
+ intermediate_size (`int`, *optional*, defaults to 1536):
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+ depths (`list[int]`, *optional*, defaults to [6, 3, 3]):
+ The number of layers in each encoder block.
+ num_group_tokens (`list[int]`, *optional*, defaults to [64, 8, 0]):
+ The number of group tokens for each stage.
+ num_output_groups (`list[int]`, *optional*, defaults to [64, 8, 8]):
+ The number of output groups for each stage, 0 means no group.
+ num_attention_heads (`int`, *optional*, defaults to 6):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ image_size (`int`, *optional*, defaults to 224):
+ The size (resolution) of each image.
+ patch_size (`int`, *optional*, defaults to 16):
+ The size (resolution) of each patch.
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-5):
+ The epsilon used by the layer normalization layers.
+ dropout (`float`, *optional*, defaults to 0.0):
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ initializer_factor (`float`, *optional*, defaults to 1.0):
+ A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
+ testing).
+
+ Example:
+
+ ```python
+ >>> from transformers import GroupViTVisionConfig, GroupViTVisionModel
+
+ >>> # Initializing a GroupViTVisionModel with nvidia/groupvit-gcc-yfcc style configuration
+ >>> configuration = GroupViTVisionConfig()
+
+ >>> model = GroupViTVisionModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "groupvit_vision_model"
+ base_config_key = "vision_config"
+
+ def __init__(
+ self,
+ hidden_size=384,
+ intermediate_size=1536,
+ depths=[6, 3, 3],
+ num_hidden_layers=12,
+ num_group_tokens=[64, 8, 0],
+ num_output_groups=[64, 8, 8],
+ num_attention_heads=6,
+ image_size=224,
+ patch_size=16,
+ num_channels=3,
+ hidden_act="gelu",
+ layer_norm_eps=1e-5,
+ dropout=0.0,
+ attention_dropout=0.0,
+ initializer_range=0.02,
+ initializer_factor=1.0,
+ assign_eps=1.0,
+ assign_mlp_ratio=[0.5, 4],
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.depths = depths
+ if num_hidden_layers != sum(depths):
+ logger.warning(
+ f"Manually setting num_hidden_layers to {num_hidden_layers}, but we expect num_hidden_layers ="
+ f" sum(depth) = {sum(depths)}"
+ )
+ self.num_hidden_layers = num_hidden_layers
+ self.num_group_tokens = num_group_tokens
+ self.num_output_groups = num_output_groups
+ self.num_attention_heads = num_attention_heads
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.num_channels = num_channels
+ self.hidden_act = hidden_act
+ self.layer_norm_eps = layer_norm_eps
+ self.dropout = dropout
+ self.attention_dropout = attention_dropout
+ self.initializer_range = initializer_range
+ self.initializer_factor = initializer_factor
+ self.assign_eps = assign_eps
+ self.assign_mlp_ratio = assign_mlp_ratio
+
+
+class GroupViTConfig(PretrainedConfig):
+ r"""
+ [`GroupViTConfig`] is the configuration class to store the configuration of a [`GroupViTModel`]. It is used to
+ instantiate a GroupViT model according to the specified arguments, defining the text model and vision model
+ configs. Instantiating a configuration with the defaults will yield a similar configuration to that of the GroupViT
+ [nvidia/groupvit-gcc-yfcc](https://huggingface.co/nvidia/groupvit-gcc-yfcc) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ text_config (`dict`, *optional*):
+ Dictionary of configuration options used to initialize [`GroupViTTextConfig`].
+ vision_config (`dict`, *optional*):
+ Dictionary of configuration options used to initialize [`GroupViTVisionConfig`].
+ projection_dim (`int`, *optional*, defaults to 256):
+ Dimensionality of text and vision projection layers.
+ projection_intermediate_dim (`int`, *optional*, defaults to 4096):
+ Dimensionality of intermediate layer of text and vision projection layers.
+ logit_scale_init_value (`float`, *optional*, defaults to 2.6592):
+ The initial value of the *logit_scale* parameter. Default is used as per the original GroupViT
+ implementation.
+ kwargs (*optional*):
+ Dictionary of keyword arguments.
+ """
+
+ model_type = "groupvit"
+ sub_configs = {"text_config": GroupViTTextConfig, "vision_config": GroupViTVisionConfig}
+
+ def __init__(
+ self,
+ text_config=None,
+ vision_config=None,
+ projection_dim=256,
+ projection_intermediate_dim=4096,
+ logit_scale_init_value=2.6592,
+ **kwargs,
+ ):
+ # If `_config_dict` kwargs exist, we use them for backward compatibility.
+ # We pop out these 2 attributes before calling `super().__init__` to avoid them being saved (which causes a lot
+ # of confusion!).
+ text_config_dict = kwargs.pop("text_config_dict", None)
+ vision_config_dict = kwargs.pop("vision_config_dict", None)
+
+ super().__init__(**kwargs)
+
+ # Instead of simply assigning `[text|vision]_config_dict` to `[text|vision]_config`, we use the values in
+ # `[text|vision]_config_dict` to update the values in `[text|vision]_config`. The values should be same in most
+ # cases, but we don't want to break anything regarding `_config_dict` that existed before commit `8827e1b2`.
+ if text_config_dict is not None:
+ if text_config is None:
+ text_config = {}
+
+ # This is the complete result when using `text_config_dict`.
+ _text_config_dict = GroupViTTextConfig(**text_config_dict).to_dict()
+
+ # Give a warning if the values exist in both `_text_config_dict` and `text_config` but with different values.
+ for key, value in _text_config_dict.items():
+ if key in text_config and value != text_config[key] and key != "transformers_version":
+ # If specified in `text_config_dict`
+ if key in text_config_dict:
+ message = (
+ f"`{key}` is found in both `text_config_dict` and `text_config` but with different values. "
+ f'The value `text_config_dict["{key}"]` will be used instead.'
+ )
+ # If inferred from default argument values (just to be super careful)
+ else:
+ message = (
+ f"`text_config_dict` is provided which will be used to initialize `GroupViTTextConfig`. "
+ f'The value `text_config["{key}"]` will be overridden.'
+ )
+ logger.info(message)
+
+ # Update all values in `text_config` with the ones in `_text_config_dict`.
+ text_config.update(_text_config_dict)
+
+ if vision_config_dict is not None:
+ if vision_config is None:
+ vision_config = {}
+
+ # This is the complete result when using `vision_config_dict`.
+ _vision_config_dict = GroupViTVisionConfig(**vision_config_dict).to_dict()
+ # convert keys to string instead of integer
+ if "id2label" in _vision_config_dict:
+ _vision_config_dict["id2label"] = {
+ str(key): value for key, value in _vision_config_dict["id2label"].items()
+ }
+
+ # Give a warning if the values exist in both `_vision_config_dict` and `vision_config` but with different values.
+ for key, value in _vision_config_dict.items():
+ if key in vision_config and value != vision_config[key] and key != "transformers_version":
+ # If specified in `vision_config_dict`
+ if key in vision_config_dict:
+ message = (
+ f"`{key}` is found in both `vision_config_dict` and `vision_config` but with different "
+ f'values. The value `vision_config_dict["{key}"]` will be used instead.'
+ )
+ # If inferred from default argument values (just to be super careful)
+ else:
+ message = (
+ f"`vision_config_dict` is provided which will be used to initialize `GroupViTVisionConfig`."
+ f' The value `vision_config["{key}"]` will be overridden.'
+ )
+ logger.info(message)
+
+ # Update all values in `vision_config` with the ones in `_vision_config_dict`.
+ vision_config.update(_vision_config_dict)
+
+ if text_config is None:
+ text_config = {}
+ logger.info("`text_config` is `None`. Initializing the `GroupViTTextConfig` with default values.")
+
+ if vision_config is None:
+ vision_config = {}
+ logger.info("`vision_config` is `None`. initializing the `GroupViTVisionConfig` with default values.")
+
+ self.text_config = GroupViTTextConfig(**text_config)
+ self.vision_config = GroupViTVisionConfig(**vision_config)
+
+ self.projection_dim = projection_dim
+ self.projection_intermediate_dim = projection_intermediate_dim
+ self.logit_scale_init_value = logit_scale_init_value
+ self.initializer_range = 0.02
+ self.initializer_factor = 1.0
+ self.output_segmentation = False
+
+
+class GroupViTOnnxConfig(OnnxConfig):
+ @property
+ def inputs(self) -> Mapping[str, Mapping[int, str]]:
+ return OrderedDict(
+ [
+ ("input_ids", {0: "batch", 1: "sequence"}),
+ ("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}),
+ ("attention_mask", {0: "batch", 1: "sequence"}),
+ ]
+ )
+
+ @property
+ def outputs(self) -> Mapping[str, Mapping[int, str]]:
+ return OrderedDict(
+ [
+ ("logits_per_image", {0: "batch"}),
+ ("logits_per_text", {0: "batch"}),
+ ("text_embeds", {0: "batch"}),
+ ("image_embeds", {0: "batch"}),
+ ]
+ )
+
+ @property
+ def atol_for_validation(self) -> float:
+ return 1e-4
+
+ def generate_dummy_inputs(
+ self,
+ processor: "ProcessorMixin",
+ batch_size: int = -1,
+ seq_length: int = -1,
+ framework: Optional["TensorType"] = None,
+ ) -> Mapping[str, Any]:
+ text_input_dict = super().generate_dummy_inputs(
+ processor.tokenizer, batch_size=batch_size, seq_length=seq_length, framework=framework
+ )
+ image_input_dict = super().generate_dummy_inputs(
+ processor.image_processor, batch_size=batch_size, framework=framework
+ )
+ return {**text_input_dict, **image_input_dict}
+
+ @property
+ def default_onnx_opset(self) -> int:
+ return 14
+
+
+__all__ = ["GroupViTConfig", "GroupViTOnnxConfig", "GroupViTTextConfig", "GroupViTVisionConfig"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/groupvit/modeling_groupvit.py b/venv/lib/python3.13/site-packages/transformers/models/groupvit/modeling_groupvit.py
new file mode 100644
index 0000000000000000000000000000000000000000..3335df375da9997c0826add782822328a8b142ad
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/groupvit/modeling_groupvit.py
@@ -0,0 +1,1431 @@
+# coding=utf-8
+# Copyright 2022 NVIDIA and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch GroupViT model."""
+
+import collections.abc
+from dataclasses import dataclass
+from typing import Any, Optional, Union
+
+import numpy as np
+import torch
+from torch import nn
+
+from ...activations import ACT2FN
+from ...modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
+from ...modeling_utils import PreTrainedModel
+from ...utils import ModelOutput, auto_docstring, filter_out_non_signature_kwargs, logging, torch_int
+from .configuration_groupvit import GroupViTConfig, GroupViTTextConfig, GroupViTVisionConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+# contrastive loss function, adapted from
+# https://sachinruk.github.io/blog/pytorch/pytorch%20lightning/loss%20function/gpu/2021/03/07/CLIP.html
+def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
+ return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device))
+
+
+# Copied from transformers.models.clip.modeling_clip.clip_loss with clip->groupvit
+def groupvit_loss(similarity: torch.Tensor) -> torch.Tensor:
+ caption_loss = contrastive_loss(similarity)
+ image_loss = contrastive_loss(similarity.t())
+ return (caption_loss + image_loss) / 2.0
+
+
+def hard_softmax(logits: torch.Tensor, dim: int):
+ y_soft = logits.softmax(dim)
+ # Straight through.
+ index = y_soft.max(dim, keepdim=True)[1]
+ y_hard = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0)
+ ret = y_hard - y_soft.detach() + y_soft
+
+ return ret
+
+
+def gumbel_softmax(logits: torch.Tensor, tau: float = 1, hard: bool = False, dim: int = -1) -> torch.Tensor:
+ # more stable https://github.com/pytorch/pytorch/issues/41663
+ gumbel_dist = torch.distributions.gumbel.Gumbel(
+ torch.tensor(0.0, device=logits.device, dtype=logits.dtype),
+ torch.tensor(1.0, device=logits.device, dtype=logits.dtype),
+ )
+ gumbels = gumbel_dist.sample(logits.shape)
+
+ gumbels = (logits + gumbels) / tau # ~Gumbel(logits,tau)
+ y_soft = gumbels.softmax(dim)
+
+ if hard:
+ # Straight through.
+ index = y_soft.max(dim, keepdim=True)[1]
+ y_hard = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0)
+ ret = y_hard - y_soft.detach() + y_soft
+ else:
+ # Reparameterization trick.
+ ret = y_soft
+ return ret
+
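+ # Editorial note on hard_softmax/gumbel_softmax above (hedged): with hard assignments,
+ # `y_hard - y_soft.detach() + y_soft` equals the one-hot `y_hard` in the forward pass while
+ # gradients flow through `y_soft`, i.e. a straight-through estimator.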
+
+def resize_attention_map(attentions, height, width, align_corners=False):
+ """
+ Args:
+ attentions (`torch.Tensor`): attention map of shape [batch_size, groups, feat_height*feat_width]
+ height (`int`): height of the output attention map
+ width (`int`): width of the output attention map
+ align_corners (`bool`, *optional*): the `align_corner` argument for `nn.functional.interpolate`.
+
+ Returns:
+ `torch.Tensor`: resized attention map of shape [batch_size, groups, height, width]
+ """
+
+ scale = (height * width // attentions.shape[2]) ** 0.5
+ if height > width:
+ feat_width = int(np.round(width / scale))
+ feat_height = attentions.shape[2] // feat_width
+ else:
+ feat_height = int(np.round(height / scale))
+ feat_width = attentions.shape[2] // feat_height
+
+ batch_size = attentions.shape[0]
+ groups = attentions.shape[1] # number of group token
+ # [batch_size, groups, feat_height*feat_width] -> [batch_size, groups, feat_height, feat_width]
+ attentions = attentions.reshape(batch_size, groups, feat_height, feat_width)
+ attentions = nn.functional.interpolate(
+ attentions, size=(height, width), mode="bilinear", align_corners=align_corners
+ )
+ return attentions
+
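+ # Editorial example for resize_attention_map above (assumption, not part of the original
+ # file): for a 224x224 output and attentions with 196 spatial entries,
+ # scale = (224 * 224 // 196) ** 0.5 = 16, so the 14x14 feature grid is recovered and then
+ # bilinearly upsampled to (height, width).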
+
+def get_grouping_from_attentions(attentions, hw_shape):
+ """
+ Args:
+ attentions (`tuple(torch.FloatTensor)`): tuple of attention maps returned by `GroupViTVisionTransformer`
+ hw_shape (`tuple(int)`): height and width of the output attention map
+ Returns:
+ `torch.Tensor`: the attention map of shape [batch_size, groups, height, width]
+ """
+
+ attn_maps = []
+ with torch.no_grad():
+ prev_attn_masks = None
+ for attn_masks in attentions:
+ # [batch_size, num_groups, height x width] -> [batch_size, height x width, num_groups]
+ attn_masks = attn_masks.permute(0, 2, 1).contiguous()
+ if prev_attn_masks is None:
+ prev_attn_masks = attn_masks
+ else:
+ prev_attn_masks = prev_attn_masks @ attn_masks
+ # [batch_size, heightxwidth, num_groups] -> [batch_size, num_groups, heightxwidth] -> [batch_size, num_groups, height, width]
+ cur_attn_map = resize_attention_map(prev_attn_masks.permute(0, 2, 1).contiguous(), *hw_shape)
+ attn_maps.append(cur_attn_map)
+
+ # [batch_size, num_groups, height, width]
+ final_grouping = attn_maps[-1]
+
+ return final_grouping
+
+
+class GroupViTCrossAttentionLayer(nn.Module):
+ def __init__(self, config: GroupViTVisionConfig):
+ super().__init__()
+ self.attn = GroupViTAttention(config)
+ self.norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.mlp = GroupViTMLP(config)
+ self.norm_post = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+ def forward(self, query, key):
+ x = query
+ x = x + self.attn(query, encoder_hidden_states=key)[0]
+ x = x + self.mlp(self.norm2(x))
+ x = self.norm_post(x)
+ return x
+
+
+class GroupViTAssignAttention(nn.Module):
+ def __init__(self, config: GroupViTVisionConfig):
+ super().__init__()
+ self.scale = config.hidden_size**-0.5
+
+ self.q_proj = nn.Linear(config.hidden_size, config.hidden_size)
+ self.k_proj = nn.Linear(config.hidden_size, config.hidden_size)
+ self.v_proj = nn.Linear(config.hidden_size, config.hidden_size)
+ self.proj = nn.Linear(config.hidden_size, config.hidden_size)
+ self.assign_eps = config.assign_eps
+
+ def get_attn(self, attn, gumbel=True, hard=True):
+ if gumbel and self.training:
+ attn = gumbel_softmax(attn, dim=-2, hard=hard)
+ else:
+ if hard:
+ attn = hard_softmax(attn, dim=-2)
+ else:
+ attn = nn.functional.softmax(attn, dim=-2)
+
+ return attn
+
+ def forward(self, query, key):
+ value = key
+ # [batch_size, query_length, channels]
+ query = self.q_proj(query)
+
+ # [batch_size, key_length, channels]
+ key = self.k_proj(key)
+
+ # [batch_size, key_length, channels]
+ value = self.v_proj(value)
+
+ # [batch_size, query_length, key_length]
+ raw_attn = (query @ key.transpose(-2, -1)) * self.scale
+
+ attn = self.get_attn(raw_attn)
+ soft_attn = self.get_attn(raw_attn, gumbel=False, hard=False)
+
+ attn = attn / (attn.sum(dim=-1, keepdim=True) + self.assign_eps)
+
+ out = attn @ value
+
+ out = self.proj(out)
+
+ return out, soft_attn
+
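+ # Editorial note on GroupViTAssignAttention above (hedged): the softmax over dim=-2 gives
+ # each image token (key) a distribution over group tokens (queries); with hard/Gumbel
+ # assignment each token picks a single group, and the subsequent normalization over dim=-1
+ # averages the features of the tokens assigned to each group.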
+
+class GroupViTTokenAssign(nn.Module):
+ def __init__(self, config: GroupViTVisionConfig, num_group_token, num_output_group):
+ super().__init__()
+ self.num_output_group = num_output_group
+ # norm on group_tokens
+ self.norm_tokens = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ assign_mlp_ratio = (
+ config.assign_mlp_ratio
+ if isinstance(config.assign_mlp_ratio, collections.abc.Iterable)
+ else (config.assign_mlp_ratio, config.assign_mlp_ratio)
+ )
+ tokens_dim, channels_dim = [int(x * config.hidden_size) for x in assign_mlp_ratio]
+ self.mlp_inter = GroupViTMixerMLP(config, num_group_token, tokens_dim, num_output_group)
+ self.norm_post_tokens = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ # norm on x
+ self.norm_x = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.pre_assign_attn = GroupViTCrossAttentionLayer(config)
+
+ self.assign = GroupViTAssignAttention(config)
+ self.norm_new_x = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.mlp_channels = GroupViTMLP(config, config.hidden_size, channels_dim, config.hidden_size)
+
+ def project_group_token(self, group_tokens):
+ """
+ Args:
+ group_tokens (torch.Tensor): group tokens, [batch_size, num_group_tokens, channels]
+
+ Returns:
+ projected_group_tokens (torch.Tensor): [batch_size, num_output_groups, channels]
+ """
+ # [B, num_output_groups, C] <- [B, num_group_tokens, C]
+ projected_group_tokens = self.mlp_inter(group_tokens)
+ projected_group_tokens = self.norm_post_tokens(projected_group_tokens)
+ return projected_group_tokens
+
+ def forward(self, image_tokens, group_tokens):
+ """
+ Args:
+ image_tokens (`torch.Tensor`): image tokens, of shape [batch_size, input_length, channels]
+ group_tokens (`torch.Tensor`): group tokens, [batch_size, num_group_tokens, channels]
+ """
+
+ group_tokens = self.norm_tokens(group_tokens)
+ image_tokens = self.norm_x(image_tokens)
+ # [batch_size, num_output_groups, channels]
+ projected_group_tokens = self.project_group_token(group_tokens)
+ projected_group_tokens = self.pre_assign_attn(projected_group_tokens, image_tokens)
+ new_image_tokens, attention = self.assign(projected_group_tokens, image_tokens)
+ new_image_tokens += projected_group_tokens
+
+ new_image_tokens = new_image_tokens + self.mlp_channels(self.norm_new_x(new_image_tokens))
+
+ return new_image_tokens, attention
+
+
+@dataclass
+@auto_docstring
+class GroupViTModelOutput(ModelOutput):
+ r"""
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
+ Contrastive loss for image-text similarity.
+ logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
+ The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
+ similarity scores.
+ logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
+ The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
+ similarity scores.
+ segmentation_logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels, logits_height, logits_width)`):
+ Classification scores for each pixel.
+
+
+ The logits returned do not necessarily have the same size as the `pixel_values` passed as inputs. This is
+ to avoid doing two interpolations and losing some quality when a user needs to resize the logits to the
+ original image size as post-processing. You should always check your logits shape and resize as needed.
+
+ text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
+ The text embeddings obtained by applying the projection layer to the pooled output of
+ [`GroupViTTextModel`].
+ image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
+ The image embeddings obtained by applying the projection layer to the pooled output of
+ [`GroupViTVisionModel`].
+ text_model_output (`BaseModelOutputWithPooling`):
+ The output of the [`GroupViTTextModel`].
+ vision_model_output (`BaseModelOutputWithPooling`):
+ The output of the [`GroupViTVisionModel`].
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits_per_image: Optional[torch.FloatTensor] = None
+ logits_per_text: Optional[torch.FloatTensor] = None
+ segmentation_logits: Optional[torch.FloatTensor] = None
+ text_embeds: Optional[torch.FloatTensor] = None
+ image_embeds: Optional[torch.FloatTensor] = None
+ text_model_output: BaseModelOutputWithPooling = None
+ vision_model_output: BaseModelOutputWithPooling = None
+
+ def to_tuple(self) -> tuple[Any]:
+ return tuple(
+ self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
+ for k in self.keys()
+ )
+
+
+class GroupViTPatchEmbeddings(nn.Module):
+ """
+ Image to Patch Embedding.
+ """
+
+ def __init__(
+ self,
+ image_size: int = 224,
+ patch_size: Union[int, tuple[int, int]] = 16,
+ num_channels: int = 3,
+ embed_dim: int = 768,
+ ):
+ super().__init__()
+ image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
+ patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
+ num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.num_patches = num_patches
+
+ self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
+
+ def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
+ batch_size, num_channels, height, width = pixel_values.shape
+ if not interpolate_pos_encoding:
+ if height != self.image_size[0] or width != self.image_size[1]:
+ raise ValueError(
+ f"Input image size ({height}*{width}) doesn't match model"
+ f" ({self.image_size[0]}*{self.image_size[1]})."
+ )
+ x = self.projection(pixel_values).flatten(2).transpose(1, 2)
+ return x
+
+
+class GroupViTVisionEmbeddings(nn.Module):
+ def __init__(self, config: GroupViTVisionConfig):
+ super().__init__()
+
+ self.patch_embeddings = GroupViTPatchEmbeddings(
+ image_size=config.image_size,
+ patch_size=config.patch_size,
+ num_channels=config.num_channels,
+ embed_dim=config.hidden_size,
+ )
+ num_patches = self.patch_embeddings.num_patches
+ self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches, config.hidden_size))
+ self.dropout = nn.Dropout(config.dropout)
+ self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.patch_size = config.patch_size
+ self.config = config
+
+ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
+ """
+ This method allows interpolating the pre-trained position encodings so that the model can be used on higher-resolution
+ images. It is also adapted to support torch.jit tracing and the absence of class embeddings.
+
+ Adapted from:
+ - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
+ - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
+ """
+
+ num_patches = embeddings.shape[1]
+ num_positions = self.position_embeddings.shape[1]
+
+ # always interpolate when tracing to ensure the exported model works for dynamic input shapes
+ if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
+ return self.position_embeddings
+
+ patch_pos_embed = self.position_embeddings
+
+ dim = embeddings.shape[-1]
+
+ new_height = height // self.patch_size
+ new_width = width // self.patch_size
+
+ sqrt_num_positions = torch_int(num_positions**0.5)
+ patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
+ patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
+
+ patch_pos_embed = nn.functional.interpolate(
+ patch_pos_embed,
+ size=(new_height, new_width),
+ mode="bicubic",
+ align_corners=False,
+ )
+
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+ return patch_pos_embed
+
+ def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
+ batch_size, num_channels, height, width = pixel_values.shape
+ embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
+
+ embeddings = self.layernorm(embeddings)
+
+ batch_size, seq_len, _ = embeddings.size()
+
+ # add positional encoding to each token
+ if interpolate_pos_encoding:
+ embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
+ else:
+ embeddings = embeddings + self.position_embeddings
+
+ embeddings = self.dropout(embeddings)
+
+ return embeddings
+
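+# Resolution sketch (illustrative note): with interpolate_pos_encoding=True the embeddings accept images whose
+# size differs from config.image_size; assuming the usual 224/16 configuration, a 288x288 input gives
+# (288 // 16) ** 2 = 324 tokens and the 14x14 position grid is bicubically resized to 18x18 before being added.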
+
+# Copied from transformers.models.clip.modeling_clip.CLIPTextEmbeddings with CLIP->GroupViT
+class GroupViTTextEmbeddings(nn.Module):
+ def __init__(self, config: GroupViTTextConfig):
+ super().__init__()
+ embed_dim = config.hidden_size
+
+ self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
+ self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)
+
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
+ self.register_buffer(
+ "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
+ )
+
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ ) -> torch.Tensor:
+ seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
+ max_position_embedding = self.position_embedding.weight.shape[0]
+
+ if seq_length > max_position_embedding:
+ raise ValueError(
+ f"Sequence length must be less than max_position_embeddings (got `sequence length`: "
+ f"{seq_length} and max_position_embeddings: {max_position_embedding}"
+ )
+
+ if position_ids is None:
+ position_ids = self.position_ids[:, :seq_length]
+
+ if inputs_embeds is None:
+ inputs_embeds = self.token_embedding(input_ids)
+
+ position_embeddings = self.position_embedding(position_ids)
+ embeddings = inputs_embeds + position_embeddings
+
+ return embeddings
+
+
+class GroupViTStage(nn.Module):
+ """This corresponds to the `GroupingLayer` class in the GroupViT implementation."""
+
+ def __init__(
+ self,
+ config: GroupViTVisionConfig,
+ depth: int,
+ num_prev_group_token: int,
+ num_group_token: int,
+ num_output_group: int,
+ ):
+ super().__init__()
+ self.depth = depth
+ self.num_group_token = num_group_token
+ if num_group_token > 0:
+ self.group_token = nn.Parameter(torch.zeros(1, num_group_token, config.hidden_size))
+ else:
+ self.group_token = None
+ self.layers = nn.ModuleList([GroupViTEncoderLayer(config) for _ in range(depth)])
+
+ if num_group_token > 0:
+ self.downsample = GroupViTTokenAssign(
+ config=config,
+ num_group_token=num_group_token,
+ num_output_group=num_output_group,
+ )
+ else:
+ self.downsample = None
+
+ if num_prev_group_token > 0 and num_group_token > 0:
+ self.group_projector = nn.Sequential(
+ nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps),
+ GroupViTMixerMLP(config, num_prev_group_token, config.hidden_size // 2, num_group_token),
+ )
+ else:
+ self.group_projector = None
+
+ @property
+ def with_group_token(self):
+ return self.group_token is not None
+
+ def split_x(self, x):
+ if self.with_group_token:
+ return x[:, : -self.num_group_token], x[:, -self.num_group_token :]
+ else:
+ return x, None
+
+ def concat_x(self, x: torch.Tensor, group_token: Optional[torch.Tensor] = None) -> torch.Tensor:
+ if group_token is None:
+ return x
+ return torch.cat([x, group_token], dim=1)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ prev_group_token: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = False,
+ ) -> tuple[torch.FloatTensor]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`): attention mask of size
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+ `(config.encoder_attention_heads,)`.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the grouping tensors of Grouping block.
+ """
+ if self.with_group_token:
+ group_token = self.group_token.expand(hidden_states.size(0), -1, -1)
+ if self.group_projector is not None:
+ group_token = group_token + self.group_projector(prev_group_token)
+ else:
+ group_token = None
+
+ x = hidden_states
+
+ cat_x = self.concat_x(x, group_token)
+ for layer in self.layers:
+ layer_out = layer(cat_x, attention_mask=None, causal_attention_mask=None)
+ cat_x = layer_out[0]
+
+ x, group_token = self.split_x(cat_x)
+
+ attention = None
+ if self.downsample is not None:
+ x, attention = self.downsample(x, group_token)
+
+ outputs = (x, group_token)
+ if output_attentions:
+ outputs = outputs + (attention,)
+
+ return outputs
+
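+# Shape sketch (illustrative note): with G learnable group tokens and S output groups, a stage maps image tokens
+# (B, N, C) -> concat group tokens -> (B, N + G, C) -> `depth` encoder layers -> split -> GroupViTTokenAssign
+# -> new image tokens of shape (B, S, C), plus the soft assignment returned as the stage "attention".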
+
+class GroupViTMLP(nn.Module):
+ def __init__(
+ self,
+ config: GroupViTVisionConfig,
+ hidden_size: Optional[int] = None,
+ intermediate_size: Optional[int] = None,
+ output_size: Optional[int] = None,
+ ):
+ super().__init__()
+ self.config = config
+ self.activation_fn = ACT2FN[config.hidden_act]
+ hidden_size = hidden_size if hidden_size is not None else config.hidden_size
+ intermediate_size = intermediate_size if intermediate_size is not None else config.intermediate_size
+ output_size = output_size if output_size is not None else hidden_size
+ self.fc1 = nn.Linear(hidden_size, intermediate_size)
+ self.fc2 = nn.Linear(intermediate_size, output_size)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.fc1(hidden_states)
+ hidden_states = self.activation_fn(hidden_states)
+ hidden_states = self.fc2(hidden_states)
+ return hidden_states
+
+
+class GroupViTMixerMLP(GroupViTMLP):
+ def forward(self, x):
+ x = super().forward(x.transpose(1, 2))
+ return x.transpose(1, 2)
+
+
+class GroupViTAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.embed_dim = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.embed_dim // self.num_heads
+ if self.head_dim * self.num_heads != self.embed_dim:
+ raise ValueError(
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
+ f" {self.num_heads})."
+ )
+ self.scale = self.head_dim**-0.5
+ self.dropout = config.attention_dropout
+
+ self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
+ self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
+ self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
+
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ causal_attention_mask: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = False,
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+ """Input shape: Batch x Time x Channel"""
+
+ bsz, tgt_len, embed_dim = hidden_states.size()
+ is_cross_attention = encoder_hidden_states is not None
+
+ # get query proj
+ query_states = self.q_proj(hidden_states) * self.scale
+ if is_cross_attention:
+ key_states = self._shape(self.k_proj(encoder_hidden_states), -1, bsz)
+ value_states = self._shape(self.v_proj(encoder_hidden_states), -1, bsz)
+ else:
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+
+ proj_shape = (bsz * self.num_heads, -1, self.head_dim)
+ query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
+ key_states = key_states.view(*proj_shape)
+ value_states = value_states.view(*proj_shape)
+
+ src_len = key_states.size(1)
+ attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
+
+ if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
+ raise ValueError(
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+ f" {attn_weights.size()}"
+ )
+
+ # apply the causal_attention_mask first
+ if causal_attention_mask is not None:
+ if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
+ raise ValueError(
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
+ f" {causal_attention_mask.size()}"
+ )
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+ if attention_mask is not None:
+ if attention_mask.size() != (bsz, 1, tgt_len, src_len):
+ raise ValueError(
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+ )
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+ if output_attentions:
+ # this operation is a bit awkward, but it's required to
+ # make sure that attn_weights keeps its gradient.
+ # In order to do so, attn_weights have to be reshaped
+ # twice and reused in the following computation
+ attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+ attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
+ else:
+ attn_weights_reshaped = None
+
+ attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+
+ attn_output = torch.bmm(attn_probs, value_states)
+
+ if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
+ attn_output = attn_output.transpose(1, 2)
+ attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
+
+ attn_output = self.out_proj(attn_output)
+
+ return attn_output, attn_weights_reshaped
+
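+# Shape sketch (illustrative note): `_shape` splits the embedding dimension into heads, so the attention math runs
+# on q, k, v of shape (bsz * num_heads, seq_len, head_dim); attn_weights = softmax(q @ k^T) has shape
+# (bsz * num_heads, tgt_len, src_len) and attn_weights @ v is merged back to (bsz, tgt_len, embed_dim).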
+
+# Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoderLayer with AltCLIP->GroupViT
+class GroupViTEncoderLayer(GradientCheckpointingLayer):
+ def __init__(self, config: GroupViTConfig):
+ super().__init__()
+ self.embed_dim = config.hidden_size
+ self.self_attn = GroupViTAttention(config)
+ self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+ self.mlp = GroupViTMLP(config)
+ self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: torch.Tensor,
+ causal_attention_mask: torch.Tensor,
+ output_attentions: Optional[bool] = False,
+ ) -> tuple[torch.FloatTensor]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`): attention mask of size
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ """
+ residual = hidden_states
+
+ hidden_states = self.layer_norm1(hidden_states)
+ hidden_states, attn_weights = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ causal_attention_mask=causal_attention_mask,
+ output_attentions=output_attentions,
+ )
+ hidden_states = residual + hidden_states
+
+ residual = hidden_states
+ hidden_states = self.layer_norm2(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (attn_weights,)
+
+ return outputs
+
+
+@auto_docstring
+class GroupViTPreTrainedModel(PreTrainedModel):
+ config: GroupViTConfig
+ base_model_prefix = "groupvit"
+ supports_gradient_checkpointing = True
+
+ def _init_weights(self, module):
+ """Initialize the weights"""
+
+ init_range = self.config.initializer_range
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
+ # Slightly different from the TF version which uses truncated_normal for initialization
+ # cf https://github.com/pytorch/pytorch/pull/5617
+ module.weight.data.normal_(mean=0.0, std=init_range)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+
+ factor = self.config.initializer_factor
+ if isinstance(module, GroupViTTextEmbeddings):
+ module.token_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02)
+ module.position_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02)
+ elif isinstance(module, GroupViTAttention):
+ factor = self.config.initializer_factor
+ in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
+ out_proj_std = (module.embed_dim**-0.5) * factor
+ nn.init.normal_(module.q_proj.weight, std=in_proj_std)
+ nn.init.normal_(module.k_proj.weight, std=in_proj_std)
+ nn.init.normal_(module.v_proj.weight, std=in_proj_std)
+ nn.init.normal_(module.out_proj.weight, std=out_proj_std)
+ elif isinstance(module, GroupViTMLP):
+ factor = self.config.initializer_factor
+ in_proj_std = (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
+ fc_std = (2 * module.config.hidden_size) ** -0.5 * factor
+ nn.init.normal_(module.fc1.weight, std=fc_std)
+ nn.init.normal_(module.fc2.weight, std=in_proj_std)
+
+
+class GroupViTVisionEncoder(nn.Module):
+ def __init__(self, config: GroupViTVisionConfig) -> None:
+ super().__init__()
+ self.config = config
+ self.stages = nn.ModuleList(
+ [
+ GroupViTStage(
+ config=config,
+ depth=config.depths[i],
+ num_group_token=config.num_group_tokens[i],
+ num_output_group=config.num_output_groups[i],
+ num_prev_group_token=config.num_output_groups[i - 1] if i > 0 else 0,
+ )
+ for i in range(len(config.depths))
+ ]
+ )
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ output_hidden_states: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, BaseModelOutput]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ all_hidden_states = () if output_hidden_states else None
+ all_groupings = () if output_attentions else None
+
+ group_tokens = None
+
+ for i, stage in enumerate(self.stages):
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ layer_outputs = stage(hidden_states, group_tokens, output_attentions)
+
+ hidden_states = layer_outputs[0]
+ group_tokens = layer_outputs[1]
+
+ if output_attentions and layer_outputs[2] is not None:
+ all_groupings = all_groupings + (layer_outputs[2],)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, all_hidden_states, all_groupings] if v is not None)
+ return BaseModelOutput(
+ last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_groupings
+ )
+
+
+class GroupViTTextEncoder(nn.Module):
+ """
+ Transformer encoder consisting of `config.num_hidden_layers` self-attention layers. Each layer is a
+ [`GroupViTEncoderLayer`].
+
+ Args:
+ config: GroupViTTextConfig
+ """
+
+ def __init__(self, config: GroupViTTextConfig):
+ super().__init__()
+ self.config = config
+ self.layers = nn.ModuleList([GroupViTEncoderLayer(config) for _ in range(config.num_hidden_layers)])
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ inputs_embeds,
+ attention_mask: Optional[torch.Tensor] = None,
+ causal_attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, BaseModelOutput]:
+ r"""
+ Args:
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+ than the model's internal embedding lookup matrix.
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Causal mask for the text model. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+ for more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ encoder_states = () if output_hidden_states else None
+ all_attentions = () if output_attentions else None
+
+ hidden_states = inputs_embeds
+ for idx, encoder_layer in enumerate(self.layers):
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+ layer_outputs = encoder_layer(
+ hidden_states,
+ attention_mask,
+ causal_attention_mask,
+ output_attentions=output_attentions,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_attentions = all_attentions + (layer_outputs[1],)
+
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
+ return BaseModelOutput(
+ last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
+ )
+
+
+class GroupViTTextTransformer(nn.Module):
+ def __init__(self, config: GroupViTTextConfig):
+ super().__init__()
+ self.config = config
+ embed_dim = config.hidden_size
+ self.embeddings = GroupViTTextEmbeddings(config)
+ self.encoder = GroupViTTextEncoder(config)
+ self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
+
+ # For `pooled_output` computation
+ self.eos_token_id = config.eos_token_id
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, BaseModelOutputWithPooling]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if input_ids is None:
+ raise ValueError("You have to specify input_ids")
+
+ input_shape = input_ids.size()
+ input_ids = input_ids.view(-1, input_shape[-1])
+
+ hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)
+
+ # CLIP's text model uses causal mask, prepare it here.
+ # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
+ causal_attention_mask = _create_4d_causal_attention_mask(
+ input_shape, hidden_states.dtype, device=hidden_states.device
+ )
+
+ # expand attention_mask
+ if attention_mask is not None:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
+
+ encoder_outputs = self.encoder(
+ inputs_embeds=hidden_states,
+ attention_mask=attention_mask,
+ causal_attention_mask=causal_attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ last_hidden_state = encoder_outputs[0]
+ last_hidden_state = self.final_layer_norm(last_hidden_state)
+
+ if self.eos_token_id == 2:
+ # The `eos_token_id` was incorrect before PR #24773: let's keep what has been done here.
+ # A CLIP model with such `eos_token_id` in the config can't work correctly with extra new tokens added
+ # ------------------------------------------------------------
+ # text_embeds.shape = [batch_size, sequence_length, transformer.width]
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
+ # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
+ pooled_output = last_hidden_state[
+ torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
+ input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
+ ]
+ else:
+ # The config gets updated `eos_token_id` from PR #24773 (so the use of extra new tokens is possible)
+ pooled_output = last_hidden_state[
+ torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
+ # We need to get the first position of `eos_token_id` value (`pad_token_ids` might equal to `eos_token_id`)
+ # Note: we assume each sequence (along batch dim.) contains an `eos_token_id` (e.g. prepared by the tokenizer)
+ (input_ids.to(dtype=torch.int, device=last_hidden_state.device) == self.eos_token_id)
+ .int()
+ .argmax(dim=-1),
+ ]
+
+ if not return_dict:
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
+
+ return BaseModelOutputWithPooling(
+ last_hidden_state=last_hidden_state,
+ pooler_output=pooled_output,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
+
+
+class GroupViTTextModel(GroupViTPreTrainedModel):
+ config: GroupViTTextConfig
+
+ def __init__(self, config: GroupViTTextConfig):
+ super().__init__(config)
+ self.text_model = GroupViTTextTransformer(config)
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self) -> nn.Module:
+ return self.text_model.embeddings.token_embedding
+
+ def set_input_embeddings(self, value):
+ self.text_model.embeddings.token_embedding = value
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, BaseModelOutputWithPooling]:
+ r"""
+ Examples:
+
+ ```python
+ >>> from transformers import CLIPTokenizer, GroupViTTextModel
+
+ >>> tokenizer = CLIPTokenizer.from_pretrained("nvidia/groupvit-gcc-yfcc")
+ >>> model = GroupViTTextModel.from_pretrained("nvidia/groupvit-gcc-yfcc")
+
+ >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
+
+ >>> outputs = model(**inputs)
+ >>> last_hidden_state = outputs.last_hidden_state
+ >>> pooled_output = outputs.pooler_output # pooled (EOS token) states
+ ```"""
+ return self.text_model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+
+class GroupViTVisionTransformer(nn.Module):
+ def __init__(self, config: GroupViTVisionConfig):
+ super().__init__()
+ self.config = config
+ embed_dim = config.hidden_size
+
+ self.embeddings = GroupViTVisionEmbeddings(config)
+ self.encoder = GroupViTVisionEncoder(config)
+ self.layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
+
+ @auto_docstring
+ def forward(
+ self,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ output_hidden_states: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, BaseModelOutputWithPooling]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if pixel_values is None:
+ raise ValueError("You have to specify pixel_values")
+
+ hidden_states = self.embeddings(pixel_values)
+
+ encoder_outputs = self.encoder(
+ hidden_states=hidden_states,
+ output_hidden_states=output_hidden_states,
+ output_attentions=output_attentions,
+ return_dict=return_dict,
+ )
+
+ last_hidden_state = encoder_outputs[0]
+
+ # normalize the last hidden state
+ last_hidden_state = self.layernorm(last_hidden_state)
+ pooled_output = last_hidden_state.mean(dim=1)
+
+ if not return_dict:
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
+
+ return BaseModelOutputWithPooling(
+ last_hidden_state=last_hidden_state,
+ pooler_output=pooled_output,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
+
+
+class GroupViTVisionModel(GroupViTPreTrainedModel):
+ config: GroupViTVisionConfig
+ main_input_name = "pixel_values"
+
+ def __init__(self, config: GroupViTVisionConfig):
+ super().__init__(config)
+ self.vision_model = GroupViTVisionTransformer(config)
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self) -> GroupViTPatchEmbeddings:
+ return self.vision_model.embeddings.patch_embeddings
+
+ @auto_docstring
+ def forward(
+ self,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, BaseModelOutputWithPooling]:
+ r"""
+ Examples:
+
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import AutoProcessor, GroupViTVisionModel
+
+ >>> processor = AutoProcessor.from_pretrained("nvidia/groupvit-gcc-yfcc")
+ >>> model = GroupViTVisionModel.from_pretrained("nvidia/groupvit-gcc-yfcc")
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> inputs = processor(images=image, return_tensors="pt")
+
+ >>> outputs = model(**inputs)
+ >>> last_hidden_state = outputs.last_hidden_state
+ >>> pooled_output = outputs.pooler_output # pooled CLS states
+ ```"""
+ return self.vision_model(
+ pixel_values=pixel_values,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+
+@auto_docstring
+class GroupViTModel(GroupViTPreTrainedModel):
+ config: GroupViTConfig
+
+ def __init__(self, config: GroupViTConfig):
+ super().__init__(config)
+
+ if not isinstance(config.text_config, GroupViTTextConfig):
+ raise TypeError(
+ "config.text_config is expected to be of type GroupViTTextConfig but is of type"
+ f" {type(config.text_config)}."
+ )
+
+ if not isinstance(config.vision_config, GroupViTVisionConfig):
+ raise TypeError(
+ "config.vision_config is expected to be of type GroupViTVisionConfig but is of type"
+ f" {type(config.vision_config)}."
+ )
+
+ text_config = config.text_config
+ vision_config = config.vision_config
+
+ self.projection_dim = config.projection_dim
+ self.projection_intermediate_dim = config.projection_intermediate_dim
+ self.text_embed_dim = text_config.hidden_size
+ self.vision_embed_dim = vision_config.hidden_size
+
+ self.text_model = GroupViTTextTransformer(text_config)
+ self.vision_model = GroupViTVisionTransformer(vision_config)
+
+ self.visual_projection = nn.Sequential(
+ nn.Linear(self.vision_embed_dim, self.projection_intermediate_dim, bias=True),
+ nn.BatchNorm1d(self.projection_intermediate_dim),
+ nn.ReLU(inplace=True),
+ nn.Linear(self.projection_intermediate_dim, self.projection_dim, bias=True),
+ )
+ self.text_projection = nn.Sequential(
+ nn.Linear(self.text_embed_dim, self.projection_intermediate_dim, bias=True),
+ nn.BatchNorm1d(self.projection_intermediate_dim),
+ nn.ReLU(inplace=True),
+ nn.Linear(self.projection_intermediate_dim, self.projection_dim, bias=True),
+ )
+ self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value))
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @filter_out_non_signature_kwargs()
+ @auto_docstring
+ def get_text_features(
+ self,
+ input_ids: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ ) -> torch.FloatTensor:
+ r"""
+ Returns:
+ text_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The text embeddings obtained by
+ applying the projection layer to the pooled output of [`GroupViTTextModel`].
+
+ Examples:
+
+ ```python
+ >>> import torch
+ >>> from transformers import CLIPTokenizer, GroupViTModel
+
+ >>> model = GroupViTModel.from_pretrained("nvidia/groupvit-gcc-yfcc")
+ >>> tokenizer = CLIPTokenizer.from_pretrained("nvidia/groupvit-gcc-yfcc")
+
+ >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
+ >>> with torch.inference_mode():
+ ... text_features = model.get_text_features(**inputs)
+ ```"""
+ text_outputs: BaseModelOutputWithPooling = self.text_model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ )
+ text_features = self.text_projection(text_outputs.pooler_output)
+ return text_features
+
+ @filter_out_non_signature_kwargs()
+ @auto_docstring
+ def get_image_features(self, pixel_values: torch.Tensor) -> torch.FloatTensor:
+ r"""
+ Returns:
+ image_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The image embeddings obtained by
+ applying the projection layer to the pooled output of [`GroupViTVisionModel`].
+
+ Examples:
+
+ ```python
+ >>> import torch
+ >>> from transformers import AutoProcessor, GroupViTModel
+ >>> from transformers.image_utils import load_image
+
+ >>> model = GroupViTModel.from_pretrained("nvidia/groupvit-gcc-yfcc")
+ >>> processor = AutoProcessor.from_pretrained("nvidia/groupvit-gcc-yfcc")
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = load_image(url)
+
+ >>> inputs = processor(images=image, return_tensors="pt")
+
+ >>> with torch.inference_mode():
+ ... image_features = model.get_image_features(**inputs)
+ ```"""
+ vision_outputs: BaseModelOutputWithPooling = self.vision_model(pixel_values)
+ image_features = self.visual_projection(vision_outputs.pooler_output)
+ return image_features
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ return_loss: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ output_segmentation: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, GroupViTModelOutput]:
+ r"""
+ return_loss (`bool`, *optional*):
+ Whether or not to return the contrastive loss.
+ output_segmentation (`bool`, *optional*):
+ Whether or not to return the segmentation logits.
+
+ Examples:
+
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import AutoProcessor, GroupViTModel
+
+ >>> model = GroupViTModel.from_pretrained("nvidia/groupvit-gcc-yfcc")
+ >>> processor = AutoProcessor.from_pretrained("nvidia/groupvit-gcc-yfcc")
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> inputs = processor(
+ ... text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True
+ ... )
+
+ >>> outputs = model(**inputs)
+ >>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score
+ >>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities
+ ```"""
+ # Use GROUPVIT model's config for some fields (if specified) instead of those of vision & text components.
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_segmentation = (
+ output_segmentation if output_segmentation is not None else self.config.output_segmentation
+ )
+ if output_segmentation:
+ output_attentions = True
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ vision_outputs = self.vision_model(
+ pixel_values=pixel_values,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ text_outputs = self.text_model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ image_embeds = vision_outputs[1]
+ image_embeds = self.visual_projection(image_embeds)
+
+ text_embeds = text_outputs[1]
+ text_embeds = self.text_projection(text_embeds)
+
+ # normalized features
+ image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)
+ text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
+
+ # cosine similarity as logits
+ logit_scale = self.logit_scale.exp()
+ logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
+ logits_per_image = logits_per_text.t()
+
+ seg_logits = None
+ if output_segmentation:
+ # grouped features
+ # [batch_size_image, num_group, hidden_size]
+ image_group_embeds = vision_outputs[0]
+ # [batch_size_image*num_group, hidden_size]
+ image_group_embeds = self.visual_projection(image_group_embeds.reshape(-1, image_group_embeds.shape[-1]))
+ if output_hidden_states:
+ attentions = vision_outputs[3]
+ else:
+ attentions = vision_outputs[2]
+ # [batch_size_image, num_group, height, width]
+ grouping = get_grouping_from_attentions(attentions, pixel_values.shape[2:])
+
+ # normalized features
+ image_group_embeds = image_group_embeds / image_group_embeds.norm(dim=-1, keepdim=True)
+ # [batch_size_image x num_group, batch_size_text]
+ logits_per_image_group = torch.matmul(image_group_embeds, text_embeds.t()) * logit_scale
+ # [batch_size_image, batch_size_text, num_group]
+ logits_per_image_group = logits_per_image_group.reshape(
+ image_embeds.shape[0], -1, text_embeds.shape[0]
+ ).permute(0, 2, 1)
+
+ # [batch_size_image, batch_size_text, height x width]
+ flatten_grouping = grouping.reshape(grouping.shape[0], grouping.shape[1], -1)
+
+ # [batch_size_image, batch_size_text, height, width]
+ seg_logits = torch.matmul(logits_per_image_group, flatten_grouping) * logit_scale
+ seg_logits = seg_logits.reshape(
+ seg_logits.shape[0], seg_logits.shape[1], grouping.shape[2], grouping.shape[3]
+ )
+
+ loss = None
+ if return_loss:
+ loss = groupvit_loss(logits_per_text)
+
+ if not return_dict:
+ if seg_logits is not None:
+ output = (
+ logits_per_image,
+ logits_per_text,
+ seg_logits,
+ text_embeds,
+ image_embeds,
+ text_outputs,
+ vision_outputs,
+ )
+ else:
+ output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
+ return ((loss,) + output) if loss is not None else output
+
+ return GroupViTModelOutput(
+ loss=loss,
+ logits_per_image=logits_per_image,
+ logits_per_text=logits_per_text,
+ segmentation_logits=seg_logits,
+ text_embeds=text_embeds,
+ image_embeds=image_embeds,
+ text_model_output=text_outputs,
+ vision_model_output=vision_outputs,
+ )
+
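+# Usage sketch (illustrative note): zero-shot segmentation logits can be requested from the forward pass, e.g.
+#
+#   outputs = model(**inputs, output_segmentation=True)
+#   seg = outputs.segmentation_logits          # (image_batch, text_batch, height, width)
+#   mask = seg.argmax(dim=1)                   # per-pixel index of the best-matching text prompt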
+
+__all__ = ["GroupViTModel", "GroupViTPreTrainedModel", "GroupViTTextModel", "GroupViTVisionModel"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/groupvit/modeling_tf_groupvit.py b/venv/lib/python3.13/site-packages/transformers/models/groupvit/modeling_tf_groupvit.py
new file mode 100644
index 0000000000000000000000000000000000000000..1c999dca5f48faa2d49f0719d1c9ce1397efb72a
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/groupvit/modeling_tf_groupvit.py
@@ -0,0 +1,2141 @@
+# coding=utf-8
+# Copyright 2022 NVIDIA and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""TF 2.0 GroupViT model."""
+
+from __future__ import annotations
+
+import collections.abc
+import math
+from dataclasses import dataclass
+from typing import Any
+
+import numpy as np
+import tensorflow as tf
+
+from ...activations_tf import get_tf_activation
+from ...modeling_tf_outputs import TFBaseModelOutput, TFBaseModelOutputWithPooling
+from ...modeling_tf_utils import (
+ TFModelInputType,
+ TFPreTrainedModel,
+ get_initializer,
+ keras,
+ keras_serializable,
+ unpack_inputs,
+)
+from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax
+from ...utils import (
+ ModelOutput,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ is_tensorflow_probability_available,
+ logging,
+ replace_return_docstrings,
+)
+from .configuration_groupvit import GroupViTConfig, GroupViTTextConfig, GroupViTVisionConfig
+
+
+logger = logging.get_logger(__name__)
+
+# soft dependency
+if is_tensorflow_probability_available():
+ try:
+ import tensorflow_probability as tfp
+
+ # On the first call, check whether a compatible version of TensorFlow is installed
+ # TensorFlow Probability depends on a recent stable release of TensorFlow
+ _ = tfp.distributions.Normal(loc=0.0, scale=1.0)
+ except ImportError:
+ logger.error(
+ "GroupViT models are not usable since `tensorflow_probability` can't be loaded. "
+ "It seems you have `tensorflow_probability` installed with the wrong tensorflow version."
+ "Please try to reinstall it following the instructions here: https://github.com/tensorflow/probability."
+ )
+else:
+ try:
+ import tensorflow_probability as tfp
+
+ # On the first call, check whether a compatible version of TensorFlow is installed
+ # TensorFlow Probability depends on a recent stable release of TensorFlow
+ _ = tfp.distributions.Normal(loc=0.0, scale=1.0)
+ except ImportError:
+ pass
+
+_CHECKPOINT_FOR_DOC = "nvidia/groupvit-gcc-yfcc"
+
+
+LARGE_NEGATIVE = -1e8
+
+
+# Copied from transformers.models.bart.modeling_tf_bart._expand_mask
+def _expand_mask(mask: tf.Tensor, tgt_len: int | None = None):
+ """
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
+ """
+ src_len = shape_list(mask)[1]
+ tgt_len = tgt_len if tgt_len is not None else src_len
+ one_cst = tf.constant(1.0)
+ mask = tf.cast(mask, dtype=one_cst.dtype)
+ expanded_mask = tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1))
+
+ return (one_cst - expanded_mask) * LARGE_NEGATIVE
+
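+# Mask sketch (illustrative note): a padding mask like [1, 1, 0] becomes additive logits [0, 0, -1e8] broadcast to
+# (bsz, 1, tgt_len, src_len), so padded key positions are effectively ignored by the attention softmax.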
+
+# contrastive loss function, adapted from
+# https://sachinruk.github.io/blog/pytorch/pytorch%20lightning/loss%20function/gpu/2021/03/07/CLIP.html
+def contrastive_loss(logits: tf.Tensor) -> tf.Tensor:
+ return tf.math.reduce_mean(
+ keras.metrics.sparse_categorical_crossentropy(
+ y_true=tf.range(shape_list(logits)[0]), y_pred=logits, from_logits=True
+ )
+ )
+
+
+# Copied from transformers.models.clip.modeling_tf_clip.clip_loss with clip->groupvit
+def groupvit_loss(similarity: tf.Tensor) -> tf.Tensor:
+ caption_loss = contrastive_loss(similarity)
+ image_loss = contrastive_loss(tf.transpose(similarity))
+ return (caption_loss + image_loss) / 2.0
+
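+# Loss sketch (illustrative note): assuming paired text/image batches of equal size, `similarity` is the
+# (text, image) logit matrix and the i-th caption should match the i-th image, so both cross-entropy terms use
+# labels tf.range(batch_size); the image term simply scores the transposed matrix.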
+
+def hard_softmax(logits: tf.Tensor, dim: int) -> tf.Tensor:
+ y_soft = stable_softmax(logits, dim)
+ # Straight through.
+ index = tf.argmax(y_soft, dim)
+ y_hard = tf.one_hot(
+ index,
+ depth=shape_list(logits)[dim],
+ # TensorFlow expects axis to be -1 or between [0, 3). But received: -2
+ # This is why the following code snippet is used.
+ axis=range(len(shape_list(logits)))[dim],
+ dtype=y_soft.dtype,
+ )
+ ret = y_hard - tf.stop_gradient(y_soft) + y_soft
+
+ return ret
+
+
+def gumbel_softmax(logits: tf.Tensor, tau: float = 1, hard: bool = False, dim: int = -1) -> tf.Tensor:
+ gumbel_dist = tfp.distributions.Gumbel(0.0, 1.0)
+ gumbels = gumbel_dist.sample(tf.shape(logits), dtype=logits.dtype)
+
+ gumbels = (logits + gumbels) / tau # ~Gumbel(logits,tau)
+ y_soft = stable_softmax(gumbels, dim)
+
+ if hard:
+ # Straight through.
+ index = tf.argmax(y_soft, dim)
+ y_hard = tf.one_hot(
+ index,
+ depth=shape_list(logits)[dim],
+ # TensorFlow expects axis to be -1 or between [0, 3). But received: -2
+ # This is why the following code snippet is used.
+ axis=range(len(shape_list(logits)))[dim],
+ dtype=y_soft.dtype,
+ )
+ ret = y_hard - tf.stop_gradient(y_soft) + y_soft
+ else:
+ # Reparametrization trick.
+ ret = y_soft
+ return ret
+
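+# Straight-through sketch (illustrative note): in the hard path the forward value equals the one-hot `y_hard`,
+# because y_hard - tf.stop_gradient(y_soft) + y_soft == y_hard in value, while its gradient equals that of
+# `y_soft`, keeping the discrete group assignment differentiable.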
+
+def resize_attention_map(attentions: tf.Tensor, height: int, width: int, align_corners: bool = False) -> tf.Tensor:
+ """
+ Args:
+ attentions (`tf.Tensor`): attention map of shape [batch_size, groups, feat_height*feat_width]
+ height (`int`): height of the output attention map
+ width (`int`): width of the output attention map
+ align_corners (`bool`, *optional*): the `align_corners` argument used when resizing the attention map.
+
+ Returns:
+ `tf.Tensor`: resized attention map of shape [batch_size, groups, height, width]
+ """
+
+ scale = (height * width // attentions.shape[2]) ** 0.5
+ if height > width:
+ feat_width = int(np.round(width / scale))
+ feat_height = shape_list(attentions)[2] // feat_width
+ else:
+ feat_height = int(np.round(height / scale))
+ feat_width = shape_list(attentions)[2] // feat_height
+
+ batch_size = shape_list(attentions)[0]
+ groups = shape_list(attentions)[1] # number of group token
+ # [batch_size, groups, height x width] -> [batch_size, groups, height, width]
+ attentions = tf.reshape(attentions, (batch_size, groups, feat_height, feat_width))
+ attentions = tf.transpose(attentions, perm=(0, 2, 3, 1))
+ if align_corners:
+ attentions = tf.compat.v1.image.resize(
+ attentions,
+ size=(height, width),
+ method="bilinear",
+ align_corners=align_corners,
+ )
+ else:
+ attentions = tf.image.resize(attentions, size=(height, width), method="bilinear")
+ attentions = tf.transpose(attentions, perm=(0, 3, 1, 2))
+ return attentions
+
+
+def get_grouping_from_attentions(attentions: tuple[tf.Tensor], hw_shape: tuple[int]) -> tf.Tensor:
+ """
+ Args:
+ attentions (`tuple(tf.Tensor)`): tuple of attention maps returned by `TFGroupViTVisionTransformer`
+ hw_shape (`tuple(int)`): height and width of the output attention map
+ Returns:
+ `tf.Tensor`: the attention map of shape [batch_size, groups, height, width]
+ """
+
+ attn_maps = []
+ prev_attn_masks = None
+ for attn_masks in attentions:
+ # [batch_size, num_groups, height x width] -> [batch_size, height x width, num_groups]
+ attn_masks = tf.transpose(attn_masks, perm=(0, 2, 1))
+ if prev_attn_masks is None:
+ prev_attn_masks = attn_masks
+ else:
+ prev_attn_masks = tf.matmul(prev_attn_masks, attn_masks)
+ # [batch_size, height x width, num_groups] -> [batch_size, num_groups, height x width] -> [batch_size, num_groups, height, width]
+ cur_attn_map = resize_attention_map(tf.transpose(prev_attn_masks, perm=(0, 2, 1)), *hw_shape)
+ attn_maps.append(cur_attn_map)
+
+ # [batch_size, num_groups, height, width]
+ final_grouping = attn_maps[-1]
+
+ return tf.stop_gradient(final_grouping)
+
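+# Composition sketch (illustrative note): each stage's soft assignment maps tokens to groups; chaining them with
+# matmul ([HW, G1] @ [G1, G2] -> [HW, G2]) gives every patch location its affinity to the final groups, which is
+# then resized to the input resolution by `resize_attention_map`.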
+
+@dataclass
+class TFGroupViTModelOutput(ModelOutput):
+ """
+ Args:
+ loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
+ Contrastive loss for image-text similarity.
+ logits_per_image (`tf.Tensor` of shape `(image_batch_size, text_batch_size)`):
+ The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
+ similarity scores.
+ logits_per_text (`tf.Tensor` of shape `(text_batch_size, image_batch_size)`):
+ The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
+ similarity scores.
+ segmentation_logits (`tf.Tensor` of shape `(batch_size, config.num_labels, logits_height, logits_width)`):
+ Classification scores for each pixel.
+
+ The logits returned do not necessarily have the same size as the `pixel_values` passed as inputs. This is
+ to avoid doing two interpolations and losing some quality when a user needs to resize the logits to the
+ original image size as post-processing. You should always check your logits shape and resize as needed.
+
+ text_embeds (`tf.Tensor` of shape `(batch_size, output_dim)`):
+ The text embeddings obtained by applying the projection layer to the pooled output of
+ [`TFGroupViTTextModel`].
+ image_embeds (`tf.Tensor` of shape `(batch_size, output_dim)`):
+ The image embeddings obtained by applying the projection layer to the pooled output of
+ [`TFGroupViTVisionModel`].
+ text_model_output (`TFBaseModelOutputWithPooling`):
+ The output of the [`TFGroupViTTextModel`].
+ vision_model_output (`TFBaseModelOutputWithPooling`):
+ The output of the [`TFGroupViTVisionModel`].
+ """
+
+ loss: tf.Tensor | None = None
+ logits_per_image: tf.Tensor | None = None
+ logits_per_text: tf.Tensor | None = None
+ segmentation_logits: tf.Tensor | None = None
+ text_embeds: tf.Tensor | None = None
+ image_embeds: tf.Tensor | None = None
+ text_model_output: TFBaseModelOutputWithPooling = None
+ vision_model_output: TFBaseModelOutputWithPooling = None
+
+ def to_tuple(self) -> tuple[Any]:
+ return tuple(
+ self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
+ for k in self.keys()
+ )
+
+
+class TFGroupViTCrossAttentionLayer(keras.layers.Layer):
+ def __init__(self, config: GroupViTVisionConfig, **kwargs):
+ super().__init__(**kwargs)
+ self.attn = TFGroupViTAttention(config, name="attn")
+ self.norm2 = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="norm2")
+ self.mlp = TFGroupViTMLP(config, name="mlp")
+ self.norm_post = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="norm_post")
+ self.config = config
+
+ def call(self, query: tf.Tensor, key: tf.Tensor, training: bool = False) -> tf.Tensor:
+ x = query
+ x = x + self.attn(query, encoder_hidden_states=key)[0]
+ x = x + self.mlp(self.norm2(x))
+ x = self.norm_post(x)
+ return x
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "attn", None) is not None:
+ with tf.name_scope(self.attn.name):
+ self.attn.build(None)
+ if getattr(self, "norm2", None) is not None:
+ with tf.name_scope(self.norm2.name):
+ self.norm2.build([None, None, self.config.hidden_size])
+ if getattr(self, "mlp", None) is not None:
+ with tf.name_scope(self.mlp.name):
+ self.mlp.build(None)
+ if getattr(self, "norm_post", None) is not None:
+ with tf.name_scope(self.norm_post.name):
+ self.norm_post.build([None, None, self.config.hidden_size])
+
+
+class TFGroupViTAssignAttention(keras.layers.Layer):
+ def __init__(self, config: GroupViTVisionConfig, **kwargs):
+ super().__init__(**kwargs)
+ self.scale = config.hidden_size**-0.5
+
+ self.q_proj = keras.layers.Dense(config.hidden_size, name="q_proj")
+ self.k_proj = keras.layers.Dense(config.hidden_size, name="k_proj")
+ self.v_proj = keras.layers.Dense(config.hidden_size, name="v_proj")
+ self.proj = keras.layers.Dense(config.hidden_size, name="proj")
+ self.assign_eps = config.assign_eps
+ self.config = config
+
+ def get_attn(self, attn: tf.Tensor, gumbel: bool = True, hard: bool = True, training: bool = False) -> tf.Tensor:
+ if gumbel and training:
+ attn = gumbel_softmax(attn, dim=-2, hard=hard)
+ else:
+ if hard:
+ attn = hard_softmax(attn, dim=-2)
+ else:
+ attn = stable_softmax(attn, axis=-2)
+
+ return attn
+
+ def call(self, query: tf.Tensor, key: tf.Tensor, training: bool = False):
+ value = key
+ # [batch_size, query_length, channels]
+ query = self.q_proj(query)
+
+ # [batch_size, key_length, channels]
+ key = self.k_proj(key)
+
+ # [batch_size, key_length, channels]
+ value = self.v_proj(value)
+
+ # [batch_size, query_length, key_length]
+ raw_attn = tf.matmul(query, key, transpose_b=True) * self.scale
+
+ attn = self.get_attn(raw_attn, training=training)
+ soft_attn = self.get_attn(raw_attn, training=training, gumbel=False, hard=False)
+
+ attn = attn / (tf.math.reduce_sum(attn, axis=-1, keepdims=True) + self.assign_eps)
+
+ out = tf.matmul(attn, value)
+
+ out = self.proj(out)
+
+ return out, soft_attn
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "q_proj", None) is not None:
+ with tf.name_scope(self.q_proj.name):
+ self.q_proj.build([None, None, self.config.hidden_size])
+ if getattr(self, "k_proj", None) is not None:
+ with tf.name_scope(self.k_proj.name):
+ self.k_proj.build([None, None, self.config.hidden_size])
+ if getattr(self, "v_proj", None) is not None:
+ with tf.name_scope(self.v_proj.name):
+ self.v_proj.build([None, None, self.config.hidden_size])
+ if getattr(self, "proj", None) is not None:
+ with tf.name_scope(self.proj.name):
+ self.proj.build([None, None, self.config.hidden_size])
+
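+# Assignment sketch (illustrative note): the softmax runs over the group axis (dim=-2), so each image token
+# distributes its mass across the group queries; during training a hard Gumbel-softmax keeps that choice
+# differentiable, and the result is re-normalized over keys before aggregating the values.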
+
+class TFGroupViTTokenAssign(keras.layers.Layer):
+ def __init__(self, config: GroupViTVisionConfig, num_group_token: int, num_output_group: int, **kwargs):
+ super().__init__(**kwargs)
+ self.num_output_group = num_output_group
+ # norm on group_tokens
+ self.norm_tokens = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="norm_tokens")
+ assign_mlp_ratio = (
+ config.assign_mlp_ratio
+ if isinstance(config.assign_mlp_ratio, collections.abc.Iterable)
+ else (config.assign_mlp_ratio, config.assign_mlp_ratio)
+ )
+ tokens_dim, channels_dim = [int(x * config.hidden_size) for x in assign_mlp_ratio]
+ self.mlp_inter = TFGroupViTMixerMLP(config, num_group_token, tokens_dim, num_output_group, name="mlp_inter")
+ self.norm_post_tokens = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="norm_post_tokens")
+ # norm on x
+ self.norm_x = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="norm_x")
+ self.pre_assign_attn = TFGroupViTCrossAttentionLayer(config, name="pre_assign_attn")
+
+ self.assign = TFGroupViTAssignAttention(config, name="assign")
+ self.norm_new_x = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="norm_new_x")
+ self.mlp_channels = TFGroupViTMLP(
+ config, config.hidden_size, channels_dim, config.hidden_size, name="mlp_channels"
+ )
+ self.config = config
+
+ def project_group_token(self, group_tokens: tf.Tensor) -> tf.Tensor:
+ """
+ Args:
+ group_tokens (tf.Tensor): group tokens, [batch_size, num_group_tokens, channels]
+
+ Returns:
+ projected_group_tokens (tf.Tensor): [batch_size, num_output_groups, channels]
+ """
+ # [B, num_output_groups, C] <- [B, num_group_tokens, C]
+ projected_group_tokens = self.mlp_inter(group_tokens)
+ projected_group_tokens = self.norm_post_tokens(projected_group_tokens)
+ return projected_group_tokens
+
+ def call(self, image_tokens: tf.Tensor, group_tokens: tf.Tensor, training: bool = False):
+ """
+ Args:
+ image_tokens (`tf.Tensor`): image tokens, of shape [batch_size, input_length, channels]
+ group_tokens (`tf.Tensor`): group tokens, [batch_size, num_group_tokens, channels]
+ """
+
+ group_tokens = self.norm_tokens(group_tokens)
+ image_tokens = self.norm_x(image_tokens)
+ # [batch_size, num_output_groups, channels]
+ projected_group_tokens = self.project_group_token(group_tokens)
+ projected_group_tokens = self.pre_assign_attn(projected_group_tokens, image_tokens)
+ new_image_tokens, attention = self.assign(projected_group_tokens, image_tokens)
+ new_image_tokens += projected_group_tokens
+
+ new_image_tokens = new_image_tokens + self.mlp_channels(self.norm_new_x(new_image_tokens))
+
+ return new_image_tokens, attention
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "norm_tokens", None) is not None:
+ with tf.name_scope(self.norm_tokens.name):
+ self.norm_tokens.build([None, None, self.config.hidden_size])
+ if getattr(self, "mlp_inter", None) is not None:
+ with tf.name_scope(self.mlp_inter.name):
+ self.mlp_inter.build(None)
+ if getattr(self, "norm_post_tokens", None) is not None:
+ with tf.name_scope(self.norm_post_tokens.name):
+ self.norm_post_tokens.build([None, None, self.config.hidden_size])
+ if getattr(self, "norm_x", None) is not None:
+ with tf.name_scope(self.norm_x.name):
+ self.norm_x.build([None, None, self.config.hidden_size])
+ if getattr(self, "pre_assign_attn", None) is not None:
+ with tf.name_scope(self.pre_assign_attn.name):
+ self.pre_assign_attn.build(None)
+ if getattr(self, "assign", None) is not None:
+ with tf.name_scope(self.assign.name):
+ self.assign.build(None)
+ if getattr(self, "norm_new_x", None) is not None:
+ with tf.name_scope(self.norm_new_x.name):
+ self.norm_new_x.build([None, None, self.config.hidden_size])
+ if getattr(self, "mlp_channels", None) is not None:
+ with tf.name_scope(self.mlp_channels.name):
+ self.mlp_channels.build(None)
+
+
+# Adapted from transformers.models.vit.modeling_tf_vit.TFViTPatchEmbeddings with ViT->GroupViT
+class TFGroupViTPatchEmbeddings(keras.layers.Layer):
+ """
+ This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
+ `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
+ Transformer.
+ """
+
+ def __init__(self, config: GroupViTConfig, **kwargs):
+ super().__init__(**kwargs)
+ image_size, patch_size = config.image_size, config.patch_size
+ num_channels = config.num_channels
+ # hidden_size is a member as it will be required in the call method
+ self.hidden_size = config.hidden_size
+
+ image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
+ patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
+ num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.num_patches = num_patches
+ self.num_channels = num_channels
+ self.config = config
+
+ self.projection = keras.layers.Conv2D(
+ filters=self.hidden_size,
+ kernel_size=patch_size,
+ strides=patch_size,
+ padding="valid",
+ data_format="channels_last",
+ use_bias=True,
+ kernel_initializer=get_initializer(self.config.initializer_range),
+ bias_initializer="zeros",
+ name="projection",
+ )
+
+ def call(
+ self, pixel_values: tf.Tensor, interpolate_pos_encoding: bool = False, training: bool = False
+ ) -> tf.Tensor:
+ batch_size, num_channels, height, width = shape_list(pixel_values)
+ if tf.executing_eagerly() and num_channels != self.num_channels:
+ raise ValueError(
+ "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
+ )
+ if (
+ not interpolate_pos_encoding
+ and tf.executing_eagerly()
+ and (height != self.image_size[0] or width != self.image_size[1])
+ ):
+ raise ValueError(
+ f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
+ )
+
+ # When running on CPU, `keras.layers.Conv2D` doesn't support `NCHW` format.
+ # So change the input format from `NCHW` to `NHWC`.
+ # shape = (batch_size, in_height, in_width, in_channels=num_channels)
+ pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1))
+
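+ # a Conv2D whose kernel size equals its stride embeds each non-overlapping patch independently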
+ projection = self.projection(pixel_values)
+
+ # Change the 2D spatial dimensions to a single temporal dimension.
+ # shape = (batch_size, num_patches, out_channels=embed_dim)
+ num_patches = (width // self.patch_size[1]) * (height // self.patch_size[0])
+ # In the TFGroupViTVisionEmbeddings the embeddings from this layer will be layer normalized
+ # LayerNormalization layer needs to have static last dimension (otherwise the test_keras_save_load fails with symbolic tensors)
+ # This is why we have used the hidden_size in the reshape method
+ embeddings = tf.reshape(tensor=projection, shape=(batch_size, num_patches, self.hidden_size))
+
+ return embeddings
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "projection", None) is not None:
+ with tf.name_scope(self.projection.name):
+ self.projection.build([None, None, None, self.num_channels])
+
+
+ # Adapted from transformers.models.vit.modeling_tf_vit.TFViTEmbeddings
+class TFGroupViTVisionEmbeddings(keras.layers.Layer):
+ """
+ Construct the position and patch embeddings.
+
+ """
+
+ def __init__(self, config: GroupViTVisionConfig, **kwargs):
+ super().__init__(**kwargs)
+
+ self.patch_embeddings = TFGroupViTPatchEmbeddings(config, name="patch_embeddings")
+ self.dropout = keras.layers.Dropout(rate=config.dropout, name="dropout")
+ self.layernorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm")
+ self.config = config
+
+ def build(self, input_shape=None):
+ num_patches = self.patch_embeddings.num_patches
+ self.position_embeddings = self.add_weight(
+ shape=(1, num_patches, self.config.hidden_size),
+ initializer="zeros",
+ trainable=True,
+ name="position_embeddings",
+ )
+
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "patch_embeddings", None) is not None:
+ with tf.name_scope(self.patch_embeddings.name):
+ self.patch_embeddings.build(None)
+ if getattr(self, "dropout", None) is not None:
+ with tf.name_scope(self.dropout.name):
+ self.dropout.build(None)
+ if getattr(self, "layernorm", None) is not None:
+ with tf.name_scope(self.layernorm.name):
+ self.layernorm.build([None, None, self.config.hidden_size])
+
+ def interpolate_pos_encoding(self, embeddings, height, width) -> tf.Tensor:
+ """
+ This method interpolates the pre-trained position encodings so that the model can be used on higher-resolution
+ images.
+
+ Source:
+ https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
+ """
+
+ batch_size, num_patches, dim = shape_list(embeddings)
+ num_positions = shape_list(self.position_embeddings)[1]
+
+ if num_patches == num_positions and height == width:
+ return self.position_embeddings
+ patch_pos_embed = self.position_embeddings
+ h0 = height // self.config.patch_size
+ w0 = width // self.config.patch_size
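+ # reshape the flat position embeddings into a square grid, bicubic-resize that grid to the new
+ # (h0, w0) patch layout, then flatten it back into a sequence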
+ patch_pos_embed = tf.image.resize(
+ images=tf.reshape(
+ patch_pos_embed, shape=(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
+ ),
+ size=(h0, w0),
+ method="bicubic",
+ )
+ patch_pos_embed = tf.reshape(tensor=patch_pos_embed, shape=(1, -1, dim))
+ return patch_pos_embed
+
+ def call(
+ self, pixel_values: tf.Tensor, interpolate_pos_encoding: bool = False, training: bool = False
+ ) -> tf.Tensor:
+ _, _, height, width = shape_list(pixel_values)
+ embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
+ embeddings = self.layernorm(embeddings)
+
+ # add positional encoding to each token
+ if interpolate_pos_encoding:
+ embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
+ else:
+ embeddings = embeddings + self.position_embeddings
+
+ embeddings = self.dropout(embeddings)
+
+ return embeddings
+
+
+# Copied from transformers.models.clip.modeling_tf_clip.TFCLIPTextEmbeddings with CLIP->GroupViT
+class TFGroupViTTextEmbeddings(keras.layers.Layer):
+ def __init__(self, config: GroupViTTextConfig, **kwargs):
+ super().__init__(**kwargs)
+
+ self.embed_dim = config.hidden_size
+
+ self.config = config
+
+ def build(self, input_shape: tf.TensorShape = None):
+ with tf.name_scope("token_embedding"):
+ self.weight = self.add_weight(
+ shape=(self.config.vocab_size, self.embed_dim),
+ initializer=get_initializer(self.config.initializer_factor * self.config.initializer_range),
+ trainable=True,
+ name="weight",
+ )
+
+ with tf.name_scope("position_embedding"):
+ self.position_embedding = self.add_weight(
+ shape=(self.config.max_position_embeddings, self.embed_dim),
+ initializer=get_initializer(self.config.initializer_factor * self.config.initializer_range),
+ trainable=True,
+ name="embeddings",
+ )
+
+ super().build(input_shape)
+
+ def call(
+ self,
+ input_ids: tf.Tensor | None = None,
+ position_ids: tf.Tensor | None = None,
+ inputs_embeds: tf.Tensor | None = None,
+ ) -> tf.Tensor:
+ """
+ Applies embedding based on inputs tensor.
+
+ Returns:
+ final_embeddings (`tf.Tensor`): output embedding tensor.
+ """
+ if input_ids is None and inputs_embeds is None:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ if inputs_embeds is None:
+ check_embeddings_within_bounds(input_ids, self.config.vocab_size)
+ inputs_embeds = tf.gather(params=self.weight, indices=input_ids)
+
+ input_shape = shape_list(inputs_embeds)[:-1]
+
+ if position_ids is None:
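+ # default to consecutive absolute positions `[0, seq_len)`, shared by every example in the batch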
+ position_ids = tf.expand_dims(tf.range(start=0, limit=input_shape[-1]), axis=0)
+
+ position_embeds = tf.gather(params=self.position_embedding, indices=position_ids)
+ position_embeds = tf.tile(input=position_embeds, multiples=(input_shape[0], 1, 1))
+ final_embeddings = inputs_embeds + position_embeds
+
+ return final_embeddings
+
+
+class TFGroupViTStage(keras.layers.Layer):
+ """This corresponds to the `GroupingLayer` class in the GroupViT implementation."""
+
+ def __init__(
+ self,
+ config: GroupViTVisionConfig,
+ depth: int,
+ num_prev_group_token: int,
+ num_group_token: int,
+ num_output_group: int,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.config = config
+ self.depth = depth
+ self.num_group_token = num_group_token
+ self.layers = [TFGroupViTEncoderLayer(config, name=f"layers_._{i}") for i in range(depth)]
+
+ if num_group_token > 0:
+ self.downsample = TFGroupViTTokenAssign(
+ config=config,
+ num_group_token=num_group_token,
+ num_output_group=num_output_group,
+ name="downsample",
+ )
+ else:
+ self.downsample = None
+
+ if num_prev_group_token > 0 and num_group_token > 0:
+ self.group_projector = [
+ keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="group_projector.0"),
+ TFGroupViTMixerMLP(
+ config, num_prev_group_token, config.hidden_size // 2, num_group_token, name="group_projector.1"
+ ),
+ ]
+ else:
+ self.group_projector = None
+
+ def build(self, input_shape=None):
+ if self.num_group_token > 0:
+ self.group_token = self.add_weight(
+ shape=(1, self.num_group_token, self.config.hidden_size),
+ initializer="zeros",
+ trainable=True,
+ name="group_token",
+ )
+ else:
+ self.group_token = None
+
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "downsample", None) is not None:
+ with tf.name_scope(self.downsample.name):
+ self.downsample.build(None)
+ if getattr(self, "layers", None) is not None:
+ for layer in self.layers:
+ with tf.name_scope(layer.name):
+ layer.build(None)
+ if getattr(self, "group_projector", None) is not None:
+ with tf.name_scope(self.group_projector[0].name):
+ self.group_projector[0].build([None, None, self.config.hidden_size])
+ with tf.name_scope(self.group_projector[1].name):
+ self.group_projector[1].build(None)
+
+ @property
+ def with_group_token(self):
+ return self.group_token is not None
+
+ def split_x(self, x: tf.Tensor) -> tuple[tf.Tensor, tf.Tensor | None]:
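+ # group tokens are concatenated after the image tokens (see `concat_x`), so split them off from the end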
+ if self.with_group_token:
+ return x[:, : -self.num_group_token], x[:, -self.num_group_token :]
+ else:
+ return x, None
+
+ def concat_x(self, x: tf.Tensor, group_token: tf.Tensor | None = None) -> tf.Tensor:
+ if group_token is None:
+ return x
+ return tf.concat([x, group_token], axis=1)
+
+ def call(
+ self,
+ hidden_states: tf.Tensor,
+ prev_group_token: tf.Tensor | None = None,
+ output_attentions: bool = False,
+ training: bool = False,
+ ) -> tuple[tf.Tensor]:
+ """
+ Args:
+ hidden_states (`tf.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ prev_group_token (`tf.Tensor`, *optional*): group tokens propagated from the previous stage, of shape
+ `(batch, num_prev_group_token, embed_dim)`.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the grouping tensors of the Grouping block.
+ """
+ if self.with_group_token:
+ group_token = tf.tile(self.group_token, multiples=(shape_list(hidden_states)[0], 1, 1))
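+ # project the previous stage's group tokens to this stage's group-token count and add them as a prior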
+ if self.group_projector is not None:
+ for layer in self.group_projector:
+ prev_group_token = layer(prev_group_token)
+ group_token = group_token + prev_group_token
+ else:
+ group_token = None
+
+ x = hidden_states
+
+ cat_x = self.concat_x(x, group_token)
+ for layer in self.layers:
+ layer_out = layer(
+ cat_x,
+ attention_mask=None,
+ causal_attention_mask=None,
+ output_attentions=None,
+ )
+ cat_x = layer_out[0]
+
+ x, group_token = self.split_x(cat_x)
+
+ attention = None
+ if self.downsample is not None:
+ x, attention = self.downsample(x, group_token)
+
+ outputs = (x, group_token)
+ if output_attentions:
+ outputs = outputs + (attention,)
+
+ return outputs
+
+
+class TFGroupViTMLP(keras.layers.Layer):
+ def __init__(
+ self,
+ config: GroupViTVisionConfig,
+ hidden_size: int | None = None,
+ intermediate_size: int | None = None,
+ output_size: int | None = None,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.config = config
+ self.activation_fn = get_tf_activation(config.hidden_act)
+ hidden_size = hidden_size if hidden_size is not None else config.hidden_size
+ intermediate_size = intermediate_size if intermediate_size is not None else config.intermediate_size
+ output_size = output_size if output_size is not None else hidden_size
+ self.fc1 = keras.layers.Dense(intermediate_size, name="fc1")
+ self.fc2 = keras.layers.Dense(output_size, name="fc2")
+ self.intermediate_size = intermediate_size
+ self.hidden_size = hidden_size
+
+ def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor:
+ hidden_states = self.fc1(hidden_states)
+ hidden_states = self.activation_fn(hidden_states)
+ hidden_states = self.fc2(hidden_states)
+ return hidden_states
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "fc1", None) is not None:
+ with tf.name_scope(self.fc1.name):
+ self.fc1.build([None, None, self.hidden_size])
+ if getattr(self, "fc2", None) is not None:
+ with tf.name_scope(self.fc2.name):
+ self.fc2.build([None, None, self.intermediate_size])
+
+
+class TFGroupViTMixerMLP(TFGroupViTMLP):
+ def call(self, x, training: bool = False):
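+ # transpose so the MLP mixes information across the token dimension instead of the channel dimension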
+ x = super().call(hidden_states=tf.transpose(x, perm=(0, 2, 1)))
+ return tf.transpose(x, perm=(0, 2, 1))
+
+
+# Adapted from transformers.models.clip.modeling_tf_clip.TFCLIPAttention
+class TFGroupViTAttention(keras.layers.Layer):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config: GroupViTConfig, **kwargs):
+ super().__init__(**kwargs)
+
+ self.embed_dim = config.hidden_size
+ self.num_attention_heads = config.num_attention_heads
+ self.attention_head_size = self.embed_dim // self.num_attention_heads
+ if self.attention_head_size * self.num_attention_heads != self.embed_dim:
+ raise ValueError(
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
+ f" {self.num_attention_heads})."
+ )
+
+ factor = config.initializer_factor
+ in_proj_std = (self.embed_dim**-0.5) * ((2 * config.num_hidden_layers) ** -0.5) * factor
+ out_proj_std = (self.embed_dim**-0.5) * factor
+
+ self.sqrt_att_head_size = math.sqrt(self.attention_head_size)
+
+ self.q_proj = keras.layers.Dense(
+ units=self.embed_dim, kernel_initializer=get_initializer(in_proj_std), name="q_proj"
+ )
+ self.k_proj = keras.layers.Dense(
+ units=self.embed_dim, kernel_initializer=get_initializer(in_proj_std), name="k_proj"
+ )
+ self.v_proj = keras.layers.Dense(
+ units=self.embed_dim, kernel_initializer=get_initializer(in_proj_std), name="v_proj"
+ )
+
+ self.dropout = keras.layers.Dropout(rate=config.attention_dropout)
+
+ self.out_proj = keras.layers.Dense(
+ units=self.embed_dim, kernel_initializer=get_initializer(out_proj_std), name="out_proj"
+ )
+
+ # Copied from transformers.models.bert.modeling_tf_bert.TFBertSelfAttention.transpose_for_scores
+ def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor:
+ # Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size]
+ tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size))
+
+ # Transpose the tensor from [batch_size, seq_length, num_attention_heads, attention_head_size] to [batch_size, num_attention_heads, seq_length, attention_head_size]
+ return tf.transpose(tensor, perm=[0, 2, 1, 3])
+
+ def call(
+ self,
+ hidden_states: tf.Tensor,
+ attention_mask: tf.Tensor | None = None,
+ causal_attention_mask: tf.Tensor | None = None,
+ output_attentions: bool | None = None,
+ encoder_hidden_states: tf.Tensor | None = None,
+ training: bool = False,
+ ) -> tuple[tf.Tensor]:
+ """Input shape: Batch x Time x Channel"""
+
+ batch_size = shape_list(hidden_states)[0]
+ is_cross_attention = encoder_hidden_states is not None
+
+ mixed_query_layer = self.q_proj(inputs=hidden_states)
+ if is_cross_attention:
+ mixed_key_layer = self.k_proj(inputs=encoder_hidden_states)
+ mixed_value_layer = self.v_proj(inputs=encoder_hidden_states)
+ else:
+ mixed_key_layer = self.k_proj(inputs=hidden_states)
+ mixed_value_layer = self.v_proj(inputs=hidden_states)
+
+ query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)
+ key_layer = self.transpose_for_scores(mixed_key_layer, batch_size)
+ value_layer = self.transpose_for_scores(mixed_value_layer, batch_size)
+
+ # Take the dot product between "query" and "key" to get the raw attention scores.
+ # (batch size, num_heads, seq_len_q, seq_len_k)
+ attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
+ dk = tf.cast(self.sqrt_att_head_size, dtype=attention_scores.dtype)
+ attention_scores = tf.divide(attention_scores, dk)
+
+ # apply the causal_attention_mask first
+ if causal_attention_mask is not None:
+ # Apply the causal attention mask (precomputed for all layers in TFCLIPModel call() function)
+ attention_scores = tf.add(attention_scores, causal_attention_mask)
+
+ if attention_mask is not None:
+ # Apply the attention mask (precomputed for all layers in TFCLIPModel call() function)
+ attention_scores = tf.add(attention_scores, attention_mask)
+
+ # Normalize the attention scores to probabilities.
+ _attention_probs = stable_softmax(logits=attention_scores, axis=-1)
+
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ attention_probs = self.dropout(inputs=_attention_probs)
+
+ attention_output = tf.matmul(attention_probs, value_layer)
+ attention_output = tf.transpose(attention_output, perm=[0, 2, 1, 3])
+
+ # (batch_size, seq_len_q, embed_dim)
+ attention_output = tf.reshape(tensor=attention_output, shape=(batch_size, -1, self.embed_dim))
+
+ attention_output = self.out_proj(attention_output)
+ # In TFBert, attention weights are returned after dropout.
+ # However, in CLIP, they are returned before dropout.
+ outputs = (attention_output, _attention_probs) if output_attentions else (attention_output,)
+
+ return outputs
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "q_proj", None) is not None:
+ with tf.name_scope(self.q_proj.name):
+ self.q_proj.build([None, None, self.embed_dim])
+ if getattr(self, "k_proj", None) is not None:
+ with tf.name_scope(self.k_proj.name):
+ self.k_proj.build([None, None, self.embed_dim])
+ if getattr(self, "v_proj", None) is not None:
+ with tf.name_scope(self.v_proj.name):
+ self.v_proj.build([None, None, self.embed_dim])
+ if getattr(self, "out_proj", None) is not None:
+ with tf.name_scope(self.out_proj.name):
+ self.out_proj.build([None, None, self.embed_dim])
+
+
+# Copied from transformers.models.clip.modeling_tf_clip.TFCLIPEncoderLayer with CLIP->GroupViT
+class TFGroupViTEncoderLayer(keras.layers.Layer):
+ def __init__(self, config: GroupViTConfig, **kwargs):
+ super().__init__(**kwargs)
+
+ self.embed_dim = config.hidden_size
+ self.self_attn = TFGroupViTAttention(config, name="self_attn")
+ self.layer_norm1 = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm1")
+ self.mlp = TFGroupViTMLP(config, name="mlp")
+ self.layer_norm2 = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm2")
+
+ def call(
+ self,
+ hidden_states: tf.Tensor,
+ attention_mask: tf.Tensor,
+ causal_attention_mask: tf.Tensor,
+ output_attentions: bool,
+ training: bool = False,
+ ) -> tuple[tf.Tensor]:
+ """
+ Args:
+ hidden_states (`tf.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`tf.Tensor`): attention mask of size
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+ causal_attention_mask (`tf.Tensor`): causal attention mask of size
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+ output_attentions (`bool`):
+ Whether or not to return the attentions tensors of all attention layers. See `outputs` under returned
+ tensors for more detail.
+ """
+ residual = hidden_states
+
+ hidden_states = self.layer_norm1(inputs=hidden_states)
+ attention_outputs = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ causal_attention_mask=causal_attention_mask,
+ output_attentions=output_attentions,
+ training=training,
+ )
+ hidden_states = attention_outputs[0]
+ hidden_states = residual + hidden_states
+
+ residual = hidden_states
+ hidden_states = self.layer_norm2(inputs=hidden_states)
+ hidden_states = self.mlp(hidden_states=hidden_states)
+ hidden_states = residual + hidden_states
+
+ outputs = (hidden_states,) + attention_outputs[1:] # add attentions if we output them
+
+ return outputs
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "self_attn", None) is not None:
+ with tf.name_scope(self.self_attn.name):
+ self.self_attn.build(None)
+ if getattr(self, "layer_norm1", None) is not None:
+ with tf.name_scope(self.layer_norm1.name):
+ self.layer_norm1.build([None, None, self.embed_dim])
+ if getattr(self, "mlp", None) is not None:
+ with tf.name_scope(self.mlp.name):
+ self.mlp.build(None)
+ if getattr(self, "layer_norm2", None) is not None:
+ with tf.name_scope(self.layer_norm2.name):
+ self.layer_norm2.build([None, None, self.embed_dim])
+
+
+ # Adapted from transformers.models.clip.modeling_tf_clip.TFCLIPTextEncoder
+class TFGroupViTTextEncoder(keras.layers.Layer):
+ def __init__(self, config: GroupViTTextConfig, **kwargs):
+ super().__init__(**kwargs)
+
+ self.layers = [TFGroupViTEncoderLayer(config, name=f"layers_._{i}") for i in range(config.num_hidden_layers)]
+
+ def call(
+ self,
+ hidden_states,
+ attention_mask: tf.Tensor,
+ causal_attention_mask: tf.Tensor,
+ output_attentions: bool,
+ output_hidden_states: bool,
+ return_dict: bool,
+ training: bool = False,
+ ) -> tuple | TFBaseModelOutput:
+ encoder_states = () if output_hidden_states else None
+ all_attentions = () if output_attentions else None
+
+ for idx, encoder_layer in enumerate(self.layers):
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+
+ layer_outputs = encoder_layer(
+ hidden_states,
+ attention_mask,
+ causal_attention_mask,
+ output_attentions=output_attentions,
+ )
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_attentions = all_attentions + (layer_outputs[1],)
+
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
+ return TFBaseModelOutput(
+ last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "layers", None) is not None:
+ for layer in self.layers:
+ with tf.name_scope(layer.name):
+ layer.build(None)
+
+
+class TFGroupViTVisionEncoder(keras.layers.Layer):
+ def __init__(self, config: GroupViTVisionConfig, **kwargs) -> None:
+ super().__init__(**kwargs)
+
+ self.stages = [
+ TFGroupViTStage(
+ config=config,
+ depth=config.depths[i],
+ num_group_token=config.num_group_tokens[i],
+ num_output_group=config.num_output_groups[i],
+ num_prev_group_token=config.num_output_groups[i - 1] if i > 0 else 0,
+ name=f"stages_._{i}",
+ )
+ for i in range(len(config.depths))
+ ]
+
+ def call(
+ self,
+ hidden_states: tf.Tensor,
+ output_hidden_states: bool,
+ output_attentions: bool,
+ return_dict: bool,
+ training: bool = False,
+ ) -> tuple | TFBaseModelOutput:
+ all_hidden_states = () if output_hidden_states else None
+ all_groupings = () if output_attentions else None
+
+ group_tokens = None
+
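+ # each stage consumes the group tokens produced by the previous stage (None for the first stage)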
+ for stage in self.stages:
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ layer_outputs = stage(hidden_states, group_tokens, output_attentions)
+
+ hidden_states = layer_outputs[0]
+ group_tokens = layer_outputs[1]
+
+ if output_attentions and layer_outputs[2] is not None:
+ all_groupings = all_groupings + (layer_outputs[2],)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, all_hidden_states, all_groupings] if v is not None)
+ return TFBaseModelOutput(
+ last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_groupings
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "stages", None) is not None:
+ for layer in self.stages:
+ with tf.name_scope(layer.name):
+ layer.build(None)
+
+
+# Copied from transformers.models.clip.modeling_tf_clip.TFCLIPTextTransformer with CLIPText->GroupViTText, CLIPEncoder->GroupViTTextEncoder
+class TFGroupViTTextTransformer(keras.layers.Layer):
+ def __init__(self, config: GroupViTTextConfig, **kwargs):
+ super().__init__(**kwargs)
+
+ self.embeddings = TFGroupViTTextEmbeddings(config, name="embeddings")
+ self.encoder = TFGroupViTTextEncoder(config, name="encoder")
+ self.final_layer_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="final_layer_norm")
+
+ # For `pooled_output` computation
+ self.eos_token_id = config.eos_token_id
+ self.embed_dim = config.hidden_size
+
+ def call(
+ self,
+ input_ids: TFModelInputType,
+ attention_mask: tf.Tensor,
+ position_ids: tf.Tensor,
+ output_attentions: bool,
+ output_hidden_states: bool,
+ return_dict: bool,
+ training: bool = False,
+ ) -> TFBaseModelOutputWithPooling | tuple[tf.Tensor]:
+ input_shape = shape_list(input_ids)
+
+ embedding_output = self.embeddings(input_ids=input_ids, position_ids=position_ids)
+
+ batch_size, seq_length = input_shape
+ # CLIP's text model uses causal mask, prepare it here.
+ # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
+ causal_attention_mask = self._build_causal_attention_mask(batch_size, seq_length, dtype=embedding_output.dtype)
+
+ # check attention mask and invert
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _expand_mask(attention_mask)
+
+ encoder_outputs = self.encoder(
+ hidden_states=embedding_output,
+ attention_mask=attention_mask,
+ causal_attention_mask=causal_attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+
+ sequence_output = encoder_outputs[0]
+ sequence_output = self.final_layer_norm(inputs=sequence_output)
+
+ if self.eos_token_id == 2:
+ # The `eos_token_id` was incorrect before PR #24773: let's keep what has been done here.
+ # A CLIP model with such `eos_token_id` in the config can't work correctly with extra new tokens added
+ # ------------------------------------------------------------
+ # text_embeds.shape = [batch_size, n_ctx, transformer.width]
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
+ pooled_output = tf.gather_nd(
+ params=sequence_output,
+ indices=tf.stack(
+ values=(tf.range(input_shape[0], dtype=tf.int64), tf.math.argmax(input_ids, axis=-1)), axis=1
+ ),
+ )
+ else:
+ # The config gets updated `eos_token_id` from PR #24773 (so the use of extra new tokens is possible)
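+ # `argmax` over the boolean equality mask returns the position of the first `eos_token_id` in each sequence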
+ pooled_output = tf.gather_nd(
+ params=sequence_output,
+ indices=tf.stack(
+ values=(
+ tf.range(input_shape[0], dtype=tf.int64),
+ tf.math.argmax(tf.cast(input_ids == self.eos_token_id, dtype=tf.int8), axis=-1),
+ ),
+ axis=1,
+ ),
+ )
+
+ if not return_dict:
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
+
+ return TFBaseModelOutputWithPooling(
+ last_hidden_state=sequence_output,
+ pooler_output=pooled_output,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
+
+ def _build_causal_attention_mask(self, batch_size, seq_length, dtype=tf.float32):
+ # It is possible with an unspecified sequence length for seq_length to be
+ # a runtime value, which is unsupported by tf.constant. Per the TensorFlow
+ # docs, tf.fill can handle runtime dynamic shapes:
+ # https://www.tensorflow.org/api_docs/python/tf/fill
+ diag = tf.cast(tf.fill((seq_length,), 0.0), dtype)
+
+ # set an additive 2D attention mask with all places being masked
+ to_mask = tf.cast(tf.fill((seq_length, seq_length), -10000.0), dtype)
+
+ # set diagonal & lower triangular parts to 0 (i.e. the places not to be masked)
+ # TIP: think the 2D matrix as the space of (query_seq, key_seq)
+ to_mask = tf.linalg.band_part(to_mask, 0, -1)
+ # to_mask = tf.linalg.band_part(to_mask, -1, 0)
+ to_mask = tf.linalg.set_diag(to_mask, diagonal=diag)
+
+ return tf.broadcast_to(input=to_mask, shape=(batch_size, 1, seq_length, seq_length))
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "embeddings", None) is not None:
+ with tf.name_scope(self.embeddings.name):
+ self.embeddings.build(None)
+ if getattr(self, "encoder", None) is not None:
+ with tf.name_scope(self.encoder.name):
+ self.encoder.build(None)
+ if getattr(self, "final_layer_norm", None) is not None:
+ with tf.name_scope(self.final_layer_norm.name):
+ self.final_layer_norm.build([None, None, self.embed_dim])
+
+
+# Adapted from transformers.models.clip.modeling_tf_clip.TFCLIPVisionTransformer
+class TFGroupViTVisionTransformer(keras.layers.Layer):
+ def __init__(self, config: GroupViTVisionConfig, **kwargs):
+ super().__init__(**kwargs)
+
+ self.embeddings = TFGroupViTVisionEmbeddings(config, name="embeddings")
+ self.encoder = TFGroupViTVisionEncoder(config, name="encoder")
+ self.layernorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm")
+ self.embed_dim = config.hidden_size
+
+ def call(
+ self,
+ pixel_values: TFModelInputType,
+ output_attentions: bool,
+ output_hidden_states: bool,
+ return_dict: bool,
+ training: bool = False,
+ ) -> tuple | TFBaseModelOutputWithPooling:
+ embedding_output = self.embeddings(pixel_values)
+
+ encoder_outputs = self.encoder(
+ hidden_states=embedding_output,
+ output_hidden_states=output_hidden_states,
+ output_attentions=output_attentions,
+ return_dict=return_dict,
+ )
+
+ last_hidden_state = encoder_outputs[0]
+
+ # normalize the last hidden state
+ last_hidden_state = self.layernorm(last_hidden_state)
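+ # GroupViT pools by averaging over the token dimension rather than using a dedicated CLS token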
+ pooled_output = tf.math.reduce_mean(last_hidden_state, axis=1)
+
+ if not return_dict:
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
+
+ return TFBaseModelOutputWithPooling(
+ last_hidden_state=last_hidden_state,
+ pooler_output=pooled_output,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "embeddings", None) is not None:
+ with tf.name_scope(self.embeddings.name):
+ self.embeddings.build(None)
+ if getattr(self, "encoder", None) is not None:
+ with tf.name_scope(self.encoder.name):
+ self.encoder.build(None)
+ if getattr(self, "layernorm", None) is not None:
+ with tf.name_scope(self.layernorm.name):
+ self.layernorm.build([None, None, self.embed_dim])
+
+
+@keras_serializable
+# Copied from transformers.models.clip.modeling_tf_clip.TFCLIPTextMainLayer with CLIP->GroupViT
+class TFGroupViTTextMainLayer(keras.layers.Layer):
+ config_class = GroupViTTextConfig
+
+ def __init__(self, config: GroupViTTextConfig, **kwargs):
+ super().__init__(**kwargs)
+ self.config = config
+ self.text_model = TFGroupViTTextTransformer(config, name="text_model")
+
+ def get_input_embeddings(self) -> keras.layers.Layer:
+ return self.text_model.embeddings
+
+ def set_input_embeddings(self, value: tf.Variable):
+ self.text_model.embeddings.weight = value
+ self.text_model.embeddings.vocab_size = shape_list(value)[0]
+
+ @unpack_inputs
+ def call(
+ self,
+ input_ids: TFModelInputType | None = None,
+ attention_mask: np.ndarray | tf.Tensor | None = None,
+ position_ids: np.ndarray | tf.Tensor | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ training: bool = False,
+ ) -> TFBaseModelOutputWithPooling | tuple[tf.Tensor]:
+ if input_ids is None:
+ raise ValueError("You have to specify input_ids")
+
+ input_shape = shape_list(input_ids)
+
+ if attention_mask is None:
+ attention_mask = tf.fill(dims=input_shape, value=1)
+
+ text_model_outputs = self.text_model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+
+ return text_model_outputs
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "text_model", None) is not None:
+ with tf.name_scope(self.text_model.name):
+ self.text_model.build(None)
+
+
+@keras_serializable
+# Copied from transformers.models.clip.modeling_tf_clip.TFCLIPVisionMainLayer with CLIP->GroupViT
+class TFGroupViTVisionMainLayer(keras.layers.Layer):
+ config_class = GroupViTVisionConfig
+
+ def __init__(self, config: GroupViTVisionConfig, **kwargs):
+ super().__init__(**kwargs)
+ self.config = config
+ self.vision_model = TFGroupViTVisionTransformer(config, name="vision_model")
+
+ def get_input_embeddings(self) -> keras.layers.Layer:
+ return self.vision_model.embeddings
+
+ @unpack_inputs
+ def call(
+ self,
+ pixel_values: TFModelInputType | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ training: bool = False,
+ ) -> TFBaseModelOutputWithPooling | tuple[tf.Tensor]:
+ if pixel_values is None:
+ raise ValueError("You have to specify pixel_values")
+
+ vision_model_outputs = self.vision_model(
+ pixel_values=pixel_values,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+
+ return vision_model_outputs
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "vision_model", None) is not None:
+ with tf.name_scope(self.vision_model.name):
+ self.vision_model.build(None)
+
+
+@keras_serializable
+# Adapted from transformers.models.clip.modeling_tf_clip.TFCLIPMainLayer
+class TFGroupViTMainLayer(keras.layers.Layer):
+ config_class = GroupViTConfig
+
+ def __init__(self, config: GroupViTConfig, **kwargs):
+ super().__init__(**kwargs)
+
+ if not isinstance(config.text_config, GroupViTTextConfig):
+ raise TypeError(
+ "config.text_config is expected to be of type GroupViTTextConfig but is of type"
+ f" {type(config.text_config)}."
+ )
+
+ if not isinstance(config.vision_config, GroupViTVisionConfig):
+ raise TypeError(
+ "config.vision_config is expected to be of type GroupViTVisionConfig but is of type"
+ f" {type(config.vision_config)}."
+ )
+
+ self.config = config
+
+ text_config = config.text_config
+ vision_config = config.vision_config
+
+ self.projection_dim = config.projection_dim
+ self.projection_intermediate_dim = config.projection_intermediate_dim
+ self.text_embed_dim = text_config.hidden_size
+ self.vision_embed_dim = vision_config.hidden_size
+
+ self.text_model = TFGroupViTTextTransformer(text_config, name="text_model")
+ self.vision_model = TFGroupViTVisionTransformer(vision_config, name="vision_model")
+
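+ # unlike CLIP's single linear projection, GroupViT projects the pooled features with a small MLP:
+ # Dense -> BatchNorm -> ReLU -> Dense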
+ self.visual_projection = [
+ keras.layers.Dense(self.projection_intermediate_dim, name="visual_projection.0"),
+ keras.layers.BatchNormalization(name="visual_projection.1", momentum=0.9, epsilon=1e-5),
+ keras.layers.ReLU(name="visual_projection.2"),
+ keras.layers.Dense(self.projection_dim, name="visual_projection.3"),
+ ]
+ self.text_projection = [
+ keras.layers.Dense(self.projection_intermediate_dim, name="text_projection.0"),
+ keras.layers.BatchNormalization(name="text_projection.1", momentum=0.9, epsilon=1e-5),
+ keras.layers.ReLU(name="text_projection.2"),
+ keras.layers.Dense(self.projection_dim, name="text_projection.3"),
+ ]
+
+ def build(self, input_shape=None):
+ self.logit_scale = self.add_weight(
+ shape=(1,),
+ initializer=keras.initializers.Constant(self.config.logit_scale_init_value),
+ trainable=True,
+ name="logit_scale",
+ )
+
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "text_model", None) is not None:
+ with tf.name_scope(self.text_model.name):
+ self.text_model.build(None)
+ if getattr(self, "vision_model", None) is not None:
+ with tf.name_scope(self.vision_model.name):
+ self.vision_model.build(None)
+ if getattr(self, "visual_projection", None) is not None:
+ with tf.name_scope(self.visual_projection[0].name):
+ self.visual_projection[0].build([None, None, None, self.vision_embed_dim])
+ with tf.name_scope(self.visual_projection[1].name):
+ self.visual_projection[1].build((None, self.projection_intermediate_dim))
+ with tf.name_scope(self.visual_projection[3].name):
+ self.visual_projection[3].build([None, None, None, self.projection_intermediate_dim])
+ if getattr(self, "text_projection", None) is not None:
+ with tf.name_scope(self.text_projection[0].name):
+ self.text_projection[0].build([None, None, None, self.text_embed_dim])
+ with tf.name_scope(self.text_projection[1].name):
+ self.text_projection[1].build((None, self.projection_intermediate_dim))
+ with tf.name_scope(self.text_projection[3].name):
+ self.text_projection[3].build([None, None, None, self.projection_intermediate_dim])
+
+ @unpack_inputs
+ def get_text_features(
+ self,
+ input_ids: TFModelInputType | None = None,
+ attention_mask: np.ndarray | tf.Tensor | None = None,
+ position_ids: np.ndarray | tf.Tensor | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ training: bool = False,
+ ) -> tf.Tensor:
+ if input_ids is None:
+ raise ValueError("You have to specify either input_ids")
+
+ input_shape = shape_list(input_ids)
+
+ if attention_mask is None:
+ attention_mask = tf.fill(dims=input_shape, value=1)
+
+ text_outputs = self.text_model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+
+ pooled_output = text_outputs[1]
+ for layer in self.text_projection:
+ pooled_output = layer(pooled_output)
+
+ text_features = pooled_output
+ return text_features
+
+ @unpack_inputs
+ def get_image_features(
+ self,
+ pixel_values: TFModelInputType | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ training: bool = False,
+ ) -> tf.Tensor:
+ if pixel_values is None:
+ raise ValueError("You have to specify pixel_values")
+
+ vision_outputs = self.vision_model(
+ pixel_values=pixel_values,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+
+ pooled_output = vision_outputs[1]
+ for layer in self.visual_projection:
+ pooled_output = layer(pooled_output)
+
+ image_features = pooled_output
+ return image_features
+
+ @unpack_inputs
+ def call(
+ self,
+ input_ids: TFModelInputType | None = None,
+ pixel_values: TFModelInputType | None = None,
+ attention_mask: np.ndarray | tf.Tensor | None = None,
+ position_ids: np.ndarray | tf.Tensor | None = None,
+ return_loss: bool | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ output_segmentation: bool | None = None,
+ return_dict: bool | None = None,
+ training: bool = False,
+ ) -> TFGroupViTModelOutput | tuple[tf.Tensor]:
+ if input_ids is None:
+ raise ValueError("You have to specify either input_ids")
+ if pixel_values is None:
+ raise ValueError("You have to specify pixel_values")
+
+ input_shape = shape_list(input_ids)
+
+ if attention_mask is None:
+ attention_mask = tf.fill(dims=input_shape, value=1)
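+ # segmentation logits are derived from the grouping attentions, so force them on when segmentation is requested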
+ if output_segmentation:
+ output_attentions = True
+ vision_outputs = self.vision_model(
+ pixel_values=pixel_values,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+
+ text_outputs = self.text_model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+
+ image_embeds = vision_outputs[1]
+ for layer in self.visual_projection:
+ image_embeds = layer(image_embeds)
+
+ text_embeds = text_outputs[1]
+ for layer in self.text_projection:
+ text_embeds = layer(text_embeds)
+
+ # normalized features
+ image_embeds = image_embeds / tf.norm(image_embeds, axis=-1, keepdims=True)
+ text_embeds = text_embeds / tf.norm(text_embeds, axis=-1, keepdims=True)
+
+ # cosine similarity as logits
+ logit_scale = tf.math.exp(self.logit_scale)
+ logits_per_text = tf.matmul(text_embeds, image_embeds, transpose_b=True) * logit_scale
+ logits_per_image = tf.transpose(logits_per_text)
+
+ seg_logits = None
+ if output_segmentation:
+ # grouped features
+ # [batch_size_image, num_group, hidden_size]
+ image_group_embeds = vision_outputs[0]
+ # [batch_size_image*num_group, hidden_size]
+ image_group_embeds = tf.reshape(image_group_embeds, shape=(-1, shape_list(image_group_embeds)[-1]))
+ for layer in self.visual_projection:
+ image_group_embeds = layer(image_group_embeds)
+ if output_hidden_states:
+ attentions = vision_outputs[3]
+ else:
+ attentions = vision_outputs[2]
+ # [batch_size_image, num_group, height, width]
+ grouping = get_grouping_from_attentions(attentions, pixel_values.shape[2:])
+
+ # normalized features
+ image_group_embeds = image_group_embeds / tf.norm(
+ tensor=image_group_embeds, ord="euclidean", axis=-1, keepdims=True
+ )
+ # [batch_size_image x num_group, batch_size_text]
+ logits_per_image_group = tf.matmul(image_group_embeds, text_embeds, transpose_b=True) * logit_scale
+ # [batch_size_image, batch_size_text, num_group]
+ logits_per_image_group = tf.reshape(
+ logits_per_image_group, shape=(image_embeds.shape[0], -1, text_embeds.shape[0])
+ )
+ logits_per_image_group = tf.transpose(logits_per_image_group, perm=(0, 2, 1))
+
+ # [batch_size_image, batch_size_text, height x width]
+ flatten_grouping = tf.reshape(grouping, shape=(shape_list(grouping)[0], shape_list(grouping)[1], -1))
+
+ # [batch_size_image, batch_size_text, height, width]
+ seg_logits = tf.matmul(logits_per_image_group, flatten_grouping) * logit_scale
+ seg_logits = tf.reshape(
+ seg_logits, shape=(seg_logits.shape[0], seg_logits.shape[1], grouping.shape[2], grouping.shape[3])
+ )
+
+ loss = None
+ if return_loss:
+ loss = groupvit_loss(logits_per_text)[None, ...]
+
+ if not return_dict:
+ if seg_logits is not None:
+ output = (
+ logits_per_image,
+ logits_per_text,
+ seg_logits,
+ text_embeds,
+ image_embeds,
+ text_outputs,
+ vision_outputs,
+ )
+ else:
+ output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
+ return ((loss,) + output) if loss is not None else output
+
+ return TFGroupViTModelOutput(
+ loss=loss,
+ logits_per_image=logits_per_image,
+ logits_per_text=logits_per_text,
+ segmentation_logits=seg_logits,
+ text_embeds=text_embeds,
+ image_embeds=image_embeds,
+ text_model_output=text_outputs,
+ vision_model_output=vision_outputs,
+ )
+
+
+class TFGroupViTPreTrainedModel(TFPreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = GroupViTConfig
+ base_model_prefix = "groupvit"
+
+
+GROUPVIT_START_DOCSTRING = r"""
+ This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads,
+ etc.)
+
+ This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
+ as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and
+ behavior.
+
+
+
+ TF 2.0 models accept two formats as inputs:
+
+ - having all inputs as keyword arguments (like PyTorch models), or
+ - having all inputs as a list, tuple or dict in the first positional arguments.
+
+ This second option is useful when using the [`keras.Model.fit`] method, which currently requires having all the
+ tensors in the first argument of the model call function: `model(inputs)`.
+
+ If you choose this second option, there are three possibilities you can use to gather all the input Tensors in the
+ first positional argument:
+
+ - a single Tensor with `input_ids` only and nothing else: `model(input_ids)`
+ - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
+ `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])`
+ - a dictionary with one or several input Tensors associated to the input names given in the docstring:
+ `model({"input_ids": input_ids, "token_type_ids": token_type_ids})`
+
+
+
+ Args:
+ config ([`GroupViTConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+GROUPVIT_TEXT_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`np.ndarray`, `tf.Tensor`, `list[tf.Tensor]`, `dict[str, tf.Tensor]` or `dict[str, np.ndarray]` and each example must have the shape `({0})`):
+ Indices of input sequence tokens in the vocabulary.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and
+ [`PreTrainedTokenizer.encode`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ position_ids (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+ config.max_position_embeddings - 1]`.
+
+ [What are position IDs?](../glossary#position-ids)
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the
+ config will be used instead.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail. This argument can be used only in eager mode, in graph mode the value in the config will be
+ used instead.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in
+ eager mode, in graph mode the value will always be set to True.
+ training (`bool`, *optional*, defaults to `False`):
+ Whether or not to use the model in training mode (some modules like dropout modules have different
+ behaviors between training and evaluation).
+"""
+
+GROUPVIT_VISION_INPUTS_DOCSTRING = r"""
+ Args:
+ pixel_values (`np.ndarray`, `tf.Tensor`, `list[tf.Tensor]`, `dict[str, tf.Tensor]` or `dict[str, np.ndarray]` and each example must have the shape `(batch_size, num_channels, height, width)`):
+ Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
+ [`CLIPImageProcessor.__call__`] for details.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the
+ config will be used instead.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail. This argument can be used only in eager mode, in graph mode the value in the config will be
+ used instead.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in
+ eager mode, in graph mode the value will always be set to True.
+ training (`bool`, *optional*, defaults to `False`):
+ Whether or not to use the model in training mode (some modules like dropout modules have different
+ behaviors between training and evaluation).
+"""
+
+GROUPVIT_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`np.ndarray`, `tf.Tensor`, `list[tf.Tensor]`, `dict[str, tf.Tensor]` or `dict[str, np.ndarray]` and each example must have the shape `({0})`):
+ Indices of input sequence tokens in the vocabulary.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and
+ [`PreTrainedTokenizer.encode`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ pixel_values (`np.ndarray`, `tf.Tensor`, `list[tf.Tensor]`, `dict[str, tf.Tensor]` or `dict[str, np.ndarray]` and each example must have the shape `(batch_size, num_channels, height, width)`):
+ Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
+ [`CLIPImageProcessor.__call__`] for details.
+ attention_mask (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ position_ids (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+ config.max_position_embeddings - 1]`.
+
+ [What are position IDs?](../glossary#position-ids)
+ return_loss (`bool`, *optional*):
+ Whether or not to return the contrastive loss.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the
+ config will be used instead.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail. This argument can be used only in eager mode, in graph mode the value in the config will be
+ used instead.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in
+ eager mode, in graph mode the value will always be set to True.
+ training (`bool`, *optional*, defaults to `False`):
+ Whether or not to use the model in training mode (some modules like dropout modules have different
+ behaviors between training and evaluation).
+"""
+
+
+class TFGroupViTTextModel(TFGroupViTPreTrainedModel):
+ config_class = GroupViTTextConfig
+ main_input_name = "input_ids"
+
+ def __init__(self, config: GroupViTTextConfig, *inputs, **kwargs):
+ super().__init__(config, *inputs, **kwargs)
+
+ self.groupvit = TFGroupViTTextMainLayer(config, name="groupvit")
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(GROUPVIT_TEXT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @replace_return_docstrings(output_type=TFBaseModelOutputWithPooling, config_class=GroupViTTextConfig)
+ def call(
+ self,
+ input_ids: TFModelInputType | None = None,
+ attention_mask: np.ndarray | tf.Tensor | None = None,
+ position_ids: np.ndarray | tf.Tensor | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ training: bool = False,
+ ) -> TFBaseModelOutputWithPooling | tuple[tf.Tensor]:
+ r"""
+ Returns:
+
+ Examples:
+
+ ```python
+ >>> from transformers import CLIPTokenizer, TFGroupViTTextModel
+
+ >>> tokenizer = CLIPTokenizer.from_pretrained("nvidia/groupvit-gcc-yfcc")
+ >>> model = TFGroupViTTextModel.from_pretrained("nvidia/groupvit-gcc-yfcc")
+
+ >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="tf")
+
+ >>> outputs = model(**inputs)
+ >>> last_hidden_state = outputs.last_hidden_state
+ >>> pooled_output = outputs.pooler_output # pooled (EOS token) states
+ ```"""
+
+ outputs = self.groupvit(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+
+ return outputs
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "groupvit", None) is not None:
+ with tf.name_scope(self.groupvit.name):
+ self.groupvit.build(None)
+
+
+class TFGroupViTVisionModel(TFGroupViTPreTrainedModel):
+ config_class = GroupViTVisionConfig
+ main_input_name = "pixel_values"
+
+ def __init__(self, config: GroupViTVisionConfig, *inputs, **kwargs):
+ super().__init__(config, *inputs, **kwargs)
+
+ self.groupvit = TFGroupViTVisionMainLayer(config, name="groupvit")
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(GROUPVIT_VISION_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=TFBaseModelOutputWithPooling, config_class=GroupViTVisionConfig)
+ def call(
+ self,
+ pixel_values: TFModelInputType | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ training: bool = False,
+ ) -> TFBaseModelOutputWithPooling | tuple[tf.Tensor]:
+ r"""
+ Returns:
+
+ Examples:
+
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import AutoProcessor, TFGroupViTVisionModel
+
+ >>> processor = AutoProcessor.from_pretrained("nvidia/groupvit-gcc-yfcc")
+ >>> model = TFGroupViTVisionModel.from_pretrained("nvidia/groupvit-gcc-yfcc")
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> inputs = processor(images=image, return_tensors="tf")
+
+ >>> outputs = model(**inputs)
+ >>> last_hidden_state = outputs.last_hidden_state
+ >>> pooled_output = outputs.pooler_output # pooled CLS states
+ ```"""
+
+ outputs = self.groupvit(
+ pixel_values=pixel_values,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+
+ return outputs
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "groupvit", None) is not None:
+ with tf.name_scope(self.groupvit.name):
+ self.groupvit.build(None)
+
+
+@add_start_docstrings(GROUPVIT_START_DOCSTRING)
+class TFGroupViTModel(TFGroupViTPreTrainedModel):
+ config_class = GroupViTConfig
+
+ def __init__(self, config: GroupViTConfig, *inputs, **kwargs):
+ super().__init__(config, *inputs, **kwargs)
+
+ self.groupvit = TFGroupViTMainLayer(config, name="groupvit")
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(GROUPVIT_TEXT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ def get_text_features(
+ self,
+ input_ids: TFModelInputType | None = None,
+ attention_mask: np.ndarray | tf.Tensor | None = None,
+ position_ids: np.ndarray | tf.Tensor | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ training: bool = False,
+ ) -> tf.Tensor:
+ r"""
+ Returns:
+ text_features (`tf.Tensor` of shape `(batch_size, output_dim)`): The text embeddings obtained by applying
+ the projection layer to the pooled output of [`TFGroupViTTextModel`].
+
+ Examples:
+
+ ```python
+ >>> from transformers import CLIPTokenizer, TFGroupViTModel
+
+ >>> model = TFGroupViTModel.from_pretrained("nvidia/groupvit-gcc-yfcc")
+ >>> tokenizer = CLIPTokenizer.from_pretrained("nvidia/groupvit-gcc-yfcc")
+
+ >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="tf")
+ >>> text_features = model.get_text_features(**inputs)
+ ```"""
+
+ text_features = self.groupvit.get_text_features(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+
+ return text_features
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(GROUPVIT_VISION_INPUTS_DOCSTRING)
+ def get_image_features(
+ self,
+ pixel_values: TFModelInputType | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ training: bool = False,
+ ) -> tf.Tensor:
+ r"""
+ Returns:
+ image_features (`tf.Tensor` of shape `(batch_size, output_dim)`): The image embeddings obtained by applying
+ the projection layer to the pooled output of [`TFGroupViTVisionModel`].
+
+ Examples:
+
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import AutoProcessor, TFGroupViTModel
+
+ >>> model = TFGroupViTModel.from_pretrained("nvidia/groupvit-gcc-yfcc")
+ >>> processor = AutoProcessor.from_pretrained("nvidia/groupvit-gcc-yfcc")
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> inputs = processor(images=image, return_tensors="tf")
+
+ >>> image_features = model.get_image_features(**inputs)
+ ```"""
+
+ image_features = self.groupvit.get_image_features(
+ pixel_values=pixel_values,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+
+ return image_features
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(GROUPVIT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @replace_return_docstrings(output_type=TFGroupViTModelOutput, config_class=GroupViTConfig)
+ def call(
+ self,
+ input_ids: TFModelInputType | None = None,
+ pixel_values: TFModelInputType | None = None,
+ attention_mask: np.ndarray | tf.Tensor | None = None,
+ position_ids: np.ndarray | tf.Tensor | None = None,
+ return_loss: bool | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ output_segmentation: bool | None = None,
+ return_dict: bool | None = None,
+ training: bool = False,
+ ) -> TFGroupViTModelOutput | tuple[tf.Tensor]:
+ r"""
+ Returns:
+
+ Examples:
+
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import AutoProcessor, TFGroupViTModel
+ >>> import tensorflow as tf
+
+ >>> model = TFGroupViTModel.from_pretrained("nvidia/groupvit-gcc-yfcc")
+ >>> processor = AutoProcessor.from_pretrained("nvidia/groupvit-gcc-yfcc")
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> inputs = processor(
+ ... text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="tf", padding=True
+ ... )
+
+ >>> outputs = model(**inputs)
+ >>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score
+ >>> probs = tf.math.softmax(logits_per_image, axis=1) # we can take the softmax to get the label probabilities
+ ```"""
+
+ outputs = self.groupvit(
+ input_ids=input_ids,
+ pixel_values=pixel_values,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ return_loss=return_loss,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ output_segmentation=output_segmentation,
+ return_dict=return_dict,
+ training=training,
+ )
+
+ return outputs
+
+ def serving_output(self, output: TFGroupViTModelOutput) -> TFGroupViTModelOutput:
+ # TODO: As is this currently fails with saved_model=True, because
+ # TensorFlow cannot trace through nested dataclasses. Reference:
+ # https://github.com/huggingface/transformers/pull/16886
+ return output
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "groupvit", None) is not None:
+ with tf.name_scope(self.groupvit.name):
+ self.groupvit.build(None)
+
+
+__all__ = ["TFGroupViTModel", "TFGroupViTPreTrainedModel", "TFGroupViTTextModel", "TFGroupViTVisionModel"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/hiera/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/hiera/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..841f13be4c0d2f48f54eecc916acd826395449af
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/hiera/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_hiera import *
+ from .modeling_hiera import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/venv/lib/python3.13/site-packages/transformers/models/hiera/configuration_hiera.py b/venv/lib/python3.13/site-packages/transformers/models/hiera/configuration_hiera.py
new file mode 100644
index 0000000000000000000000000000000000000000..2342d7e562a50de0c0937040a8e8279c7860e931
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/hiera/configuration_hiera.py
@@ -0,0 +1,194 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Hiera model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices
+
+
+logger = logging.get_logger(__name__)
+
+
+class HieraConfig(BackboneConfigMixin, PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`HieraModel`]. It is used to instantiate a Hiera
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+ defaults will yield a similar configuration to that of the Hiera
+ [facebook/hiera-base-224](https://huggingface.co/facebook/hiera-base-224) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ embed_dim (`int`, *optional*, defaults to 96):
+ Dimensionality of patch embedding.
+ image_size (`list(int)`, *optional*, defaults to `[224, 224]`):
+ The size (resolution) of input in the format (height, width) for images
+ and (frames, height, width) for videos.
+ patch_size (`list(int)`, *optional*, defaults to `[7, 7]`):
+ The size (resolution) of each patch.
+ patch_stride (`list(int)`, *optional*, defaults to `[4, 4]`):
+ The stride of the patch.
+ patch_padding (`list(int)`, *optional*, defaults to `[3, 3]`):
+ The padding of the patch.
+ mlp_ratio (`float`, *optional*, defaults to 4.0):
+ The ratio of mlp hidden dim to embedding dim.
+ depths (`list(int)`, *optional*, defaults to `[2, 3, 16, 3]`):
+ Depth of each layer in the Transformer encoder.
+ num_heads (`list(int)`, *optional*, defaults to `[1, 2, 4, 8]`):
+ Number of attention heads in each layer of the Transformer encoder.
+ embed_dim_multiplier (`float`, *optional*, defaults to 2.0):
+ The multiplier to the dimensionality of patch embedding in each layer of the Transformer encoder.
+ num_query_pool (`int`, *optional*, defaults to 3):
+ The number of query pool stages.
+ query_stride (`list(int)`, *optional*, defaults to `[2, 2]`):
+ The stride of the query pool.
+ masked_unit_size (`list(int)`, *optional*, defaults to `[8, 8]`):
+ The size of the masked unit.
+ masked_unit_attention (`list(bool)`, *optional*, defaults to `[True, True, False, False]`):
+ Whether to use masked unit attention in each layer of the Transformer encoder.
+ drop_path_rate (`float`, *optional*, defaults to 0.0):
+ The drop path rate.
+ num_channels (`int`, *optional*, defaults to 3):
+ The number of input channels.
+ hidden_act (`str`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the encoder. If string, `"gelu"`, `"relu"`,
+ `"selu"` and `"gelu_new"` are supported.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices and
+ the zero_initializer for initializing all bias vectors.
+ layer_norm_init (`float`, *optional*, defaults to 1.0):
+ The initial weight value for layer normalization layers.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-06):
+ The epsilon used by the layer normalization layers.
+ decoder_hidden_size (`int`, *optional*):
+ Dimensionality of decoder embeddings for MAE pretraining.
+ decoder_depth (`int`, *optional*):
+ Depth of the decoder for MAE pretraining.
+ decoder_num_heads (`int`, *optional*):
+ Number of attention heads in each layer of the decoder for MAE pretraining.
+ normalize_pixel_loss (`bool`, *optional*, defaults to `True`):
+ Whether to normalize the target pixel values of each patch (zero mean, unit variance) when computing the pixel reconstruction loss.
+ mask_ratio (`float`, *optional*, defaults to 0.6):
+ The ratio of masked tokens in the input.
+ out_features (`list[str]`, *optional*):
+ If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
+ (depending on how many stages the model has). If unset and `out_indices` is set, will default to the
+ corresponding stages. If unset and `out_indices` is unset, will default to the last stage. Must be in the
+ same order as defined in the `stage_names` attribute.
+ out_indices (`list[int]`, *optional*):
+ If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how
+ many stages the model has). If unset and `out_features` is set, will default to the corresponding stages.
+ If unset and `out_features` is unset, will default to the last stage. Must be in the
+ same order as defined in the `stage_names` attribute.
+
+
+ Example:
+
+ ```python
+ >>> from transformers import HieraConfig, HieraModel
+
+ >>> # Initializing a Hiera hiera-base-patch16-224 style configuration
+ >>> configuration = HieraConfig()
+
+ >>> # Initializing a model (with random weights) from the hiera-base-patch16-224 style configuration
+ >>> model = HieraModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "hiera"
+
+ attribute_map = {"num_hidden_layers": "num_layers"}
+
+ def __init__(
+ self,
+ embed_dim=96,
+ image_size=[224, 224],
+ patch_size=[7, 7],
+ patch_stride=[4, 4],
+ patch_padding=[3, 3],
+ mlp_ratio=4.0,
+ depths=[2, 3, 16, 3],
+ num_heads=[1, 2, 4, 8],
+ embed_dim_multiplier=2.0,
+ num_query_pool=3,
+ query_stride=[2, 2],
+ masked_unit_size=[8, 8],
+ masked_unit_attention=[True, True, False, False],
+ drop_path_rate=0.0,
+ num_channels=3,
+ hidden_act="gelu",
+ initializer_range=0.02,
+ layer_norm_init=1.0,
+ layer_norm_eps=1e-6,
+ decoder_hidden_size=None,
+ decoder_depth=None,
+ decoder_num_heads=None,
+ normalize_pixel_loss=True,
+ mask_ratio=0.6,
+ out_features=None,
+ out_indices=None,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ if masked_unit_size[0] % query_stride[0] ** (len(depths) - 1) != 0:
+ raise ValueError(
+ f"masked_unit_size[0] ({masked_unit_size[0]}) must be divisible by query_stride[0] ({query_stride[0]}) "
+ f"raised to the power of the number of layers ({len(depths) - 1})"
+ )
+
+ if num_query_pool >= len(depths):
+ raise ValueError(
+ f"num_query_pool ({num_query_pool}) must be less than the number of layers ({len(depths)})"
+ )
+
+ self.embed_dim = embed_dim
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.patch_stride = patch_stride
+ self.patch_padding = patch_padding
+ self.mlp_ratio = mlp_ratio
+ self.depths = depths
+ self.num_heads = num_heads
+ self.num_layers = len(depths)
+ self.embed_dim_multiplier = embed_dim_multiplier
+ self.num_query_pool = num_query_pool
+ self.query_stride = query_stride
+ self.masked_unit_size = masked_unit_size
+ self.masked_unit_attention = masked_unit_attention
+ self.drop_path_rate = drop_path_rate
+ self.num_channels = num_channels
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.layer_norm_init = layer_norm_init
+ self.layer_norm_eps = layer_norm_eps
+ self.decoder_hidden_size = decoder_hidden_size
+ self.decoder_depth = decoder_depth
+ self.decoder_num_heads = decoder_num_heads
+ self.normalize_pixel_loss = normalize_pixel_loss
+ self.mask_ratio = mask_ratio
+ # we set the hidden_size attribute in order to make Hiera work with VisionEncoderDecoderModel
+ # this indicates the channel dimension after the last stage of the model
+ self.hidden_size = int(embed_dim * embed_dim_multiplier ** (len(depths) - 1))
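+ # As a worked arithmetic example (using the defaults above): int(96 * 2.0 ** 3) = 768, so the
+ # default Hiera-Base configuration exposes a final channel dimension of 768.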
+ self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(depths) + 1)]
+ self._out_features, self._out_indices = get_aligned_output_features_output_indices(
+ out_features=out_features, out_indices=out_indices, stage_names=self.stage_names
+ )
+
+
+__all__ = ["HieraConfig"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/hiera/modeling_hiera.py b/venv/lib/python3.13/site-packages/transformers/models/hiera/modeling_hiera.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c084f0f836e18c4891491ea983676e9752e80a2
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/hiera/modeling_hiera.py
@@ -0,0 +1,1439 @@
+# coding=utf-8
+# Copyright 2024 Meta and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch Hiera model."""
+
+import math
+from dataclasses import dataclass
+from typing import Optional, Union
+
+import torch
+from torch import nn
+
+from ...activations import ACT2FN
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import (
+ BackboneOutput,
+ BaseModelOutput,
+ BaseModelOutputWithPooling,
+ ImageClassifierOutput,
+ ModelOutput,
+)
+from ...modeling_utils import PreTrainedModel
+from ...utils import auto_docstring, logging, torch_int
+from ...utils.backbone_utils import BackboneMixin
+from .configuration_hiera import HieraConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Hiera encoder's outputs, with potential hidden states and attentions.
+ """
+)
+class HieraEncoderOutput(ModelOutput):
+ r"""
+ reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
+ shape `(batch_size, height, width, hidden_size)`. These are the reshaped and re-rolled hidden states of the model.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
+ include the spatial dimensions.
+ """
+
+ last_hidden_state: Optional[torch.FloatTensor] = None
+ hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+ reshaped_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Hiera model's outputs that also contains a pooling of the last hidden states.
+ """
+)
+class HieraModelOutput(ModelOutput):
+ r"""
+ pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`, *optional*, returned when `add_pooling_layer=True` is passed):
+ Average pooling of the last layer hidden-state.
+ bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, sequence_length)`):
+ Tensor indicating which patches are masked (0) and which are not (1).
+ ids_restore (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Tensor containing the original index of the (shuffled) masked patches.
+ reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
+ shape `(batch_size, height, width, hidden_size)`. These are the reshaped and re-rolled hidden states of the model.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
+ include the spatial dimensions.
+ """
+
+ last_hidden_state: Optional[torch.FloatTensor] = None
+ pooler_output: Optional[torch.FloatTensor] = None
+ bool_masked_pos: Optional[torch.BoolTensor] = None
+ ids_restore: Optional[torch.LongTensor] = None
+ hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+ reshaped_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Hiera image classification outputs.
+ """
+)
+class HieraForImageClassificationOutput(ImageClassifierOutput):
+ r"""
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*):
+ Loss value for the training task.
+ logits (`torch.FloatTensor` of shape `(batch_size, num_labels)`):
+ Prediction scores of the classification head (logits of the output layer).
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
+ shape `(batch_size, sequence_length, hidden_size)`. These are the unrolled hidden states of the model.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*):
+ Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
+ shape `(batch_size, height, width, hidden_size)`. These are the reshaped and re-rolled hidden states of the model.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
+ include the spatial dimensions.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: Optional[torch.FloatTensor] = None
+ hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+ reshaped_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Class for HieraForPreTraining's outputs, with potential hidden states and attentions.
+ """
+)
+class HieraForPreTrainingOutput(ModelOutput):
+ r"""
+ loss (`torch.FloatTensor` of shape `(1,)`):
+ Pixel reconstruction loss.
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, patch_size ** 2 * num_channels)`):
+ Pixel reconstruction logits.
+ bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, sequence_length)`):
+ Tensor indicating which patches are masked (0) and which are not (1).
+ ids_restore (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Tensor containing the original index of the (shuffled) masked patches.
+ reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
+ shape `(batch_size, height, width, hidden_size)`. Hidden-states of the model at the output of each layer
+ plus the initial embedding outputs reshaped to include the spatial dimensions.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: Optional[torch.FloatTensor] = None
+ bool_masked_pos: Optional[torch.BoolTensor] = None
+ ids_restore: Optional[torch.LongTensor] = None
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
+ attentions: Optional[tuple[torch.FloatTensor]] = None
+ reshaped_hidden_states: Optional[tuple[torch.FloatTensor]] = None
+
+
+class HieraPatchEmbeddings(nn.Module):
+ """
+ This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
+ `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
+ Transformer.
+ """
+
+ def __init__(self, config, is_mae: bool = False):
+ super().__init__()
+
+ # Support any number of spatial dimensions
+ self.spatial_dims = len(config.patch_size)
+ if self.spatial_dims != 2:
+ raise ValueError(f"The number of dimensions of the input image should be 2, but got {self.spatial_dims}.")
+ self.num_channels = config.num_channels
+ self.image_size = config.image_size[-2:]
+ self.tokens_spatial_shape = [i // s for i, s in zip(config.image_size, config.patch_stride)]
+ self.mask_spatial_shape = [i // s for i, s in zip(self.tokens_spatial_shape, config.masked_unit_size)]
+ self.mask_ratio = config.mask_ratio
+ self.is_mae = is_mae
+ self.projection = nn.Conv2d(
+ self.num_channels,
+ config.embed_dim,
+ kernel_size=config.patch_size,
+ stride=config.patch_stride,
+ padding=config.patch_padding,
+ )
+
+ def masked_conv(
+ self, pixel_values: torch.FloatTensor, bool_masked_pos: Optional[torch.BoolTensor] = None
+ ) -> torch.Tensor:
+ """Zero-out the masked regions of the input before conv.
+ Prevents leakage of masked regions when using overlapping kernels.
+ """
+ if bool_masked_pos is None:
+ return self.projection(pixel_values)
+
+ target_size = pixel_values.shape[2:]
+ # Reshape bool_masked_pos to (batch_size, 1, mask_unit_height, mask_unit_width)
+ bool_masked_pos = bool_masked_pos.view(pixel_values.shape[0], 1, *self.mask_spatial_shape)
+
+ bool_masked_pos = nn.functional.interpolate(bool_masked_pos.float(), size=target_size)
+
+ return self.projection(pixel_values * bool_masked_pos)
+
+ def random_masking(
+ self, pixel_values: torch.FloatTensor, noise: Optional[torch.FloatTensor] = None
+ ) -> tuple[torch.BoolTensor, torch.LongTensor]:
+ """
+ Perform per-sample random masking by per-sample shuffling. Per-sample shuffling is done by argsorting
+ random noise.
+
+ Args:
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`)
+ noise (`torch.FloatTensor` of shape `(batch_size, num_mask_units)`, *optional*) which is
+ mainly used for testing purposes to control randomness and maintain the reproducibility
+ """
+ batch_size = pixel_values.shape[0]
+ # Tokens selected for masking at mask unit level
+ num_windows = math.prod(self.mask_spatial_shape)
+ len_keep = int(num_windows * (1 - self.mask_ratio))
+
+ if noise is None:
+ noise = torch.rand(batch_size, num_windows, device=pixel_values.device)
+
+ # Sort noise for each sample
+ ids_shuffle = torch.argsort(noise, dim=1)
+ # ascend: small is keep, large is remove
+ ids_restore = torch.argsort(ids_shuffle, dim=1).to(pixel_values.device)
+
+ # Generate the binary bool_masked_pos: 1 is *keep*, 0 is *remove*
+ # Note this is opposite to original MAE
+ bool_masked_pos = torch.zeros([batch_size, num_windows], device=pixel_values.device)
+ bool_masked_pos[:, :len_keep] = 1
+ # Unshuffle to get the binary bool_masked_pos
+ bool_masked_pos = torch.gather(bool_masked_pos, dim=1, index=ids_restore).bool()
+
+ return bool_masked_pos, ids_restore
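+ # Worked example (illustrative values, not from the reference implementation): with
+ # num_windows = 4, mask_ratio = 0.5 and noise = [[0.9, 0.1, 0.4, 0.7]], we get
+ # ids_shuffle = [[1, 2, 3, 0]] and ids_restore = [[3, 0, 1, 2]]. Keeping the first
+ # len_keep = 2 shuffled positions and unshuffling yields
+ # bool_masked_pos = [[False, True, True, False]], i.e. the two lowest-noise mask units
+ # are kept (True) and the remaining ones are masked out (False).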
+
+ def forward(
+ self,
+ pixel_values: torch.FloatTensor,
+ noise: Optional[torch.FloatTensor] = None,
+ ) -> tuple[torch.Tensor, Optional[torch.BoolTensor], Optional[torch.LongTensor]]:
+ (bool_masked_pos, ids_restore) = (
+ self.random_masking(pixel_values, noise=noise) if self.is_mae else (None, None)
+ )
+
+ embeddings = self.masked_conv(pixel_values, bool_masked_pos)
+ embeddings = embeddings.flatten(2).transpose(2, 1)
+
+ return embeddings, bool_masked_pos, ids_restore
+
+
+class HieraEmbeddings(nn.Module):
+ """
+ Construct position and patch embeddings.
+ """
+
+ def __init__(self, config: HieraConfig, is_mae: bool = False) -> None:
+ super().__init__()
+ self.patch_stride = config.patch_stride
+ tokens_spatial_shape = [i // s for i, s in zip(config.image_size, config.patch_stride)]
+ self.mask_spatial_shape = [i // s for i, s in zip(tokens_spatial_shape, config.masked_unit_size)]
+ self.num_tokens = math.prod(tokens_spatial_shape)
+ self.is_mae = is_mae
+
+ self.patch_embeddings = HieraPatchEmbeddings(config, is_mae=is_mae)
+
+ self.position_embeddings = nn.Parameter(torch.zeros(1, self.num_tokens, config.embed_dim))
+
+ def interpolate_pos_encoding(
+ self, embeddings: torch.Tensor, pos_embeds: torch.Tensor, height: int, width: int
+ ) -> torch.Tensor:
+ """
+ This method interpolates the pre-trained position encodings so that the model can be used on higher resolution
+ images. It is also adapted to support torch.jit tracing, the absence of class embeddings, and different patch strides.
+
+ Adapted from:
+ - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
+ - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
+ """
+
+ num_patches = embeddings.shape[1]
+ num_positions = pos_embeds.shape[1]
+
+ # always interpolate when tracing to ensure the exported model works for dynamic input shapes
+ if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
+ return pos_embeds
+
+ dim = embeddings.shape[-1]
+
+ new_height = height // self.patch_stride[0]
+ new_width = width // self.patch_stride[1]
+
+ sqrt_num_positions = torch_int(num_positions**0.5)
+ pos_embeds = pos_embeds.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
+ pos_embeds = pos_embeds.permute(0, 3, 1, 2)
+
+ pos_embeds = nn.functional.interpolate(
+ pos_embeds,
+ size=(new_height, new_width),
+ mode="bicubic",
+ align_corners=False,
+ )
+
+ pos_embeds = pos_embeds.permute(0, 2, 3, 1).view(1, -1, dim)
+ return pos_embeds
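+ # Shape sketch (assuming the default patch_stride of [4, 4]): for a 448 x 448 input,
+ # new_height = new_width = 112, so the 1 x 3136 x 96 position embeddings stored for
+ # 224 x 224 inputs (56 x 56 tokens) are reshaped to 1 x 96 x 56 x 56, bicubically
+ # resized to 1 x 96 x 112 x 112 and flattened back to 1 x 12544 x 96.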
+
+ def get_position_embedding(
+ self, embeddings: torch.Tensor, height: int, width: int, interpolate_pos_encoding: bool
+ ) -> torch.FloatTensor:
+ return (
+ self.interpolate_pos_encoding(embeddings, self.position_embeddings, height, width)
+ if interpolate_pos_encoding
+ else self.position_embeddings
+ )
+
+ def forward(
+ self,
+ pixel_values: torch.FloatTensor,
+ noise: Optional[torch.FloatTensor] = None,
+ interpolate_pos_encoding: bool = False,
+ ) -> tuple[torch.Tensor, Optional[torch.BoolTensor], Optional[torch.LongTensor]]:
+ height, width = pixel_values.shape[-2:]
+ embeddings, bool_masked_pos, ids_restore = self.patch_embeddings(pixel_values, noise=noise)
+ embeddings = embeddings + self.get_position_embedding(embeddings, height, width, interpolate_pos_encoding)
+ return embeddings, bool_masked_pos, ids_restore
+
+
+class HieraMaskUnitAttention(nn.Module):
+ """
+ Computes either Mask Unit or Global Attention. It can also perform query pooling.
+
+ Note: this assumes the tokens have already been flattened and unrolled into mask units.
+ """
+
+ def __init__(
+ self,
+ hidden_size: int,
+ hidden_size_output: int,
+ num_heads: int,
+ query_stride: int = 1,
+ window_size: int = 0,
+ use_mask_unit_attn: bool = False,
+ ) -> None:
+ super().__init__()
+ self.num_heads = num_heads
+ self.query_stride = query_stride
+ self.hidden_size_output = hidden_size_output
+
+ self.head_dim = hidden_size_output // num_heads
+ self.scale = (self.head_dim) ** -0.5
+
+ self.qkv = nn.Linear(hidden_size, 3 * hidden_size_output)
+ self.proj = nn.Linear(hidden_size_output, hidden_size_output)
+
+ self.window_size = window_size
+ self.use_mask_unit_attn = use_mask_unit_attn
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ head_mask: Optional[torch.FloatTensor] = None,
+ output_attentions: bool = False,
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+ """Input should be of shape [batch, tokens, channels]."""
+ batch_size, seq_len, _ = hidden_states.shape
+
+ num_windows = 1
+ if self.use_mask_unit_attn:
+ num_windows = seq_len // (self.query_stride * self.window_size)
+
+ qkv = self.qkv(hidden_states)
+ qkv = qkv.reshape(batch_size, -1, num_windows, 3, self.num_heads, self.head_dim)
+ qkv = qkv.permute(3, 0, 4, 2, 1, 5)
+
+ query, key, value = qkv.unbind(0)
+
+ if self.query_stride > 1:
+ # Refer to unroll to see how this performs a maxpool-Nd
+ query = query.view(batch_size, self.num_heads, num_windows, self.query_stride, -1, self.head_dim)
+ query = query.max(dim=3).values
+
+ attn_weights = (query * self.scale) @ key.transpose(-1, -2)
+ attn_weights = attn_weights.softmax(dim=-1)
+
+ # Mask heads if we want to
+ if head_mask is not None:
+ attn_weights = attn_weights * head_mask
+
+ attn_output = attn_weights @ value
+ attn_output = attn_output.transpose(1, 3).reshape(batch_size, -1, self.hidden_size_output)
+ attn_output = self.proj(attn_output)
+
+ return (attn_output, attn_weights) if output_attentions else (attn_output, None)
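+ # Worked example (illustrative, default config, first stage, no masking): seq_len = 3136,
+ # window_size = 64 and query_stride = 1, so num_windows = 49 and attention is restricted
+ # to the 64 tokens of each mask unit; with use_mask_unit_attn = False a single global
+ # window covers the whole sequence instead.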
+
+
+# Copied from transformers.models.beit.modeling_beit.drop_path
+def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
+ """
+ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+
+ Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
+ however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
+ layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
+ argument.
+ """
+ if drop_prob == 0.0 or not training:
+ return input
+ keep_prob = 1 - drop_prob
+ shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
+ random_tensor.floor_() # binarize
+ output = input.div(keep_prob) * random_tensor
+ return output
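+
+# Minimal usage sketch (hypothetical values, not part of the module):
+#
+# x = torch.ones(4, 3, 8) # 4 samples in the batch
+# y = drop_path(x, drop_prob=0.5, training=True) # each sample survives with probability 0.5
+# # survivors are scaled by 1 / 0.5 = 2, so the output matches the input in expectation
+# drop_path(x, drop_prob=0.5, training=False) # identity at evaluation time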
+
+
+# Copied from transformers.models.beit.modeling_beit.BeitDropPath with Beit->Hiera
+class HieraDropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+ def __init__(self, drop_prob: Optional[float] = None) -> None:
+ super().__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ return drop_path(hidden_states, self.drop_prob, self.training)
+
+ def extra_repr(self) -> str:
+ return f"p={self.drop_prob}"
+
+
+class HieraMlp(nn.Module):
+ def __init__(self, config, dim: int) -> None:
+ super().__init__()
+ self.activation_fn = ACT2FN[config.hidden_act]
+ self.fc1 = nn.Linear(dim, int(dim * config.mlp_ratio))
+ self.fc2 = nn.Linear(int(dim * config.mlp_ratio), dim)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.fc1(hidden_states)
+ hidden_states = self.activation_fn(hidden_states)
+ hidden_states = self.fc2(hidden_states)
+ return hidden_states
+
+
+class HieraLayer(nn.Module):
+ def __init__(
+ self,
+ config,
+ hidden_size: int,
+ hidden_size_output: int,
+ num_heads: int,
+ drop_path: float = 0.0,
+ query_stride: int = 1,
+ window_size: int = 0,
+ use_mask_unit_attn: bool = False,
+ ) -> None:
+ super().__init__()
+
+ self.hidden_size = hidden_size
+ self.hidden_size_output = hidden_size_output
+ self.query_stride = query_stride
+
+ self.layernorm_before = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
+ self.attn = HieraMaskUnitAttention(
+ hidden_size=hidden_size,
+ hidden_size_output=hidden_size_output,
+ num_heads=num_heads,
+ query_stride=query_stride,
+ window_size=window_size,
+ use_mask_unit_attn=use_mask_unit_attn,
+ )
+
+ self.layernorm_after = nn.LayerNorm(hidden_size_output, eps=config.layer_norm_eps)
+ self.mlp = HieraMlp(config, hidden_size_output)
+
+ self.drop_path = HieraDropPath(drop_path) if drop_path > 0 else nn.Identity()
+ if hidden_size != hidden_size_output:
+ self.proj = nn.Linear(hidden_size, hidden_size_output)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ head_mask: Optional[torch.FloatTensor] = None,
+ output_attentions: bool = False,
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+ batch_size, seq_len, _ = hidden_states.shape
+ # Attention + Q Pooling
+ hidden_states_norm = self.layernorm_before(hidden_states)
+ if self.hidden_size != self.hidden_size_output:
+ hidden_states = self.proj(hidden_states_norm)
+ # Refer to unroll to see how this performs a maxpool-Nd
+ hidden_states = (
+ hidden_states.view(batch_size, self.query_stride, -1, self.hidden_size_output).max(dim=1).values
+ )
+
+ (hidden_states_norm, attn_weights) = self.attn(
+ hidden_states_norm, head_mask, output_attentions=output_attentions
+ )
+ hidden_states = hidden_states + self.drop_path(hidden_states_norm)
+
+ residual = hidden_states
+ hidden_states = self.layernorm_after(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + self.drop_path(hidden_states)
+
+ return (hidden_states, attn_weights)
+
+
+class HieraStage(GradientCheckpointingLayer):
+ def __init__(
+ self,
+ config,
+ depth: int,
+ hidden_size: int,
+ hidden_size_output: int,
+ num_heads: int,
+ drop_path: list[float],
+ query_stride: list[int],
+ window_size: int,
+ use_mask_unit_attn: bool,
+ stage_num: Optional[int] = None,
+ ) -> None:
+ super().__init__()
+ # We need to know whether the previous stage used masked (mask unit) or global attention.
+ # This lags by one layer so that global attention is applied post-pooling, on the lower resolution.
+ previous_stage_used_masked_attention = False
+ if stage_num is not None:
+ previous_stage_used_masked_attention = config.masked_unit_attention[stage_num - 1 if stage_num > 0 else 0]
+ self.layers = nn.ModuleList(
+ [
+ HieraLayer(
+ config=config,
+ hidden_size=hidden_size if i == 0 else hidden_size_output,
+ hidden_size_output=hidden_size_output,
+ num_heads=num_heads,
+ drop_path=drop_path[i],
+ query_stride=query_stride[i],
+ window_size=window_size,
+ use_mask_unit_attn=use_mask_unit_attn or (previous_stage_used_masked_attention and i == 0),
+ )
+ for i in range(depth)
+ ]
+ )
+
+ def forward(
+ self, hidden_states: torch.Tensor, head_mask: Optional[torch.FloatTensor], output_attentions: bool = False
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+ for i, layer_module in enumerate(self.layers):
+ layer_head_mask = head_mask[i] if head_mask is not None else None
+ (hidden_states, attn_weights) = layer_module(
+ hidden_states, layer_head_mask, output_attentions=output_attentions
+ )
+
+ return hidden_states, attn_weights
+
+
+def undo_windowing(hidden_states: torch.Tensor, shape: list[int], mask_unit_shape: list[int]) -> torch.Tensor:
+ """
+ Restore spatial organization by undoing windowed organization of mask units.
+
+ Args:
+ hidden_states (`torch.Tensor`): The hidden states tensor of shape `[batch_size, num_mask_unit_height*num_mask_unit_width, hidden_size]`.
+ shape (`list[int]`): The original shape of the hidden states tensor before windowing.
+ mask_unit_shape (`list[int]`): The shape of the mask units used for windowing.
+
+ Returns:
+ torch.Tensor: The restored hidden states tensor of shape [batch_size, num_mask_unit_height*mask_unit_height, num_mask_unit_width*mask_unit_width, hidden_size].
+ """
+ batch_size, hidden_size = hidden_states.shape[0], hidden_states.shape[-1]
+ # From: [batch_size, num_mask_unit_height*num_mask_unit_width, hidden_size]
+ # To: [batch_size, num_mask_unit_height, num_mask_unit_width, mask_unit_height, mask_unit_width, hidden_size]
+ num_mask_units = [s // mu for s, mu in zip(shape, mask_unit_shape)]
+ hidden_states = hidden_states.view(batch_size, *num_mask_units, *mask_unit_shape, hidden_size)
+
+ # From: [batch_size, num_mask_unit_height, num_mask_unit_width, mask_unit_height, mask_unit_width, hidden_size]
+ # To: [batch_size, num_mask_unit_height*mask_unit_height, num_mask_unit_width*mask_unit_width, hidden_size]
+ hidden_states = hidden_states.permute(0, 1, 3, 2, 4, 5)
+ hidden_states = hidden_states.reshape(batch_size, *shape, hidden_size)
+
+ return hidden_states
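+
+# Worked shape example (illustrative): with shape = [56, 56] and mask_unit_shape = [8, 8],
+# a [batch_size, 49, 8, 8, hidden_size] windowed tensor (49 mask units of 8 x 8 tokens) is
+# first viewed as [batch_size, 7, 7, 8, 8, hidden_size] and then permuted and reshaped back
+# to the spatial layout [batch_size, 56, 56, hidden_size].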
+
+
+class HieraEncoder(nn.Module):
+ def __init__(self, config: HieraConfig) -> None:
+ super().__init__()
+ total_depth = sum(config.depths)
+ # stochastic depth decay rule
+ dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, total_depth, device="cpu")]
+ # query strides rule
+ cumulative_depths = torch.tensor(config.depths, device="cpu").cumsum(0).tolist()
+ query_pool_layer = cumulative_depths[: config.num_query_pool]
+ query_strides = [math.prod(config.query_stride) if i in query_pool_layer else 1 for i in range(total_depth)]
+
+ # Transformer blocks
+ self.stages = nn.ModuleList()
+ hidden_size = config.embed_dim
+ stage_ends = [0] + cumulative_depths
+ masked_unit_area = math.prod(config.masked_unit_size)
+ query_stride_area = math.prod(config.query_stride)
+ for idx_stage, depth in enumerate(config.depths):
+ hidden_size_output = int(config.embed_dim * config.embed_dim_multiplier**idx_stage)
+
+ stage = HieraStage(
+ config=config,
+ depth=depth,
+ hidden_size=hidden_size,
+ hidden_size_output=hidden_size_output,
+ num_heads=config.num_heads[idx_stage],
+ drop_path=dpr[stage_ends[idx_stage] : stage_ends[idx_stage + 1]],
+ query_stride=query_strides[stage_ends[idx_stage] : stage_ends[idx_stage + 1]],
+ window_size=int(masked_unit_area * query_stride_area**-idx_stage),
+ use_mask_unit_attn=config.masked_unit_attention[idx_stage],
+ stage_num=idx_stage,
+ )
+
+ hidden_size = hidden_size_output
+ self.stages.append(stage)
+
+ # Setting reroll schedule
+ # The first stage has to reverse everything
+ # The next stage has to reverse all but the first unroll, etc.
+ stage_size = [i // s for i, s in zip(config.image_size, config.patch_stride)]
+ unroll_schedule = [config.query_stride] * len(config.depths[:-1])
+
+ self.schedule = {}
+ for idx_stage in range(len(config.depths)):
+ self.schedule[idx_stage] = unroll_schedule, stage_size
+ if idx_stage < config.num_query_pool:
+ stage_size = [i // s for i, s in zip(stage_size, config.query_stride)]
+ unroll_schedule = unroll_schedule[1:]
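+ # Reroll schedule sketch for the default config (illustrative): stage 0 must undo all three
+ # (2, 2) unroll steps at the full 56 x 56 token resolution, stage 1 undoes two steps at
+ # 28 x 28, stage 2 undoes one step at 14 x 14, and stage 3 undoes nothing at 7 x 7.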
+
+ self.gradient_checkpointing = False
+
+ def reroll(
+ self, hidden_states: torch.Tensor, stage_idx: int, bool_masked_pos: Optional[torch.BoolTensor] = None
+ ) -> torch.Tensor:
+ """
+ Roll the given tensor back up to spatial order assuming it's from the given block.
+
+ If no bool_masked_pos is provided returns:
+ - [batch_size, height, width, hidden_size]
+ If a bool_masked_pos is provided returns:
+ - [batch_size, num_mask_units, mask_unit_height, mask_unit_width, hidden_size]
+ """
+ schedule, size = self.schedule[stage_idx]
+ batch_size, seq_len, hidden_size = hidden_states.shape
+
+ num_dim = len(size)
+ mask_unit_shape = [1] * num_dim
+
+ for strides in schedule:
+ # Extract the current patch from seq_len
+ hidden_states = hidden_states.view(
+ batch_size, *strides, seq_len // math.prod(strides), *mask_unit_shape, hidden_size
+ )
+
+ # Move that patch into the current MU
+ # Input: [batch_size, stride, stride, seq_len//(stride*stride), mask_unit_height, mask_unit_width, hidden_size]
+ # Output: [batch_size, seq_len//(stride*stride), stride, mask_unit_height, stride, mask_unit_width, hidden_size]
+ hidden_states = hidden_states.permute(0, 3, 1, 4, 2, 5, 6)
+
+ # Reshape to [batch_size, seq_len//(stride*stride), *mask_units, hidden_size]
+ for i in range(num_dim):
+ mask_unit_shape[i] *= strides[i]
+ hidden_states = hidden_states.reshape(batch_size, -1, *mask_unit_shape, hidden_size)
+ seq_len = hidden_states.shape[1]
+
+ # Current shape (e.g., 2d: [batch_size, #num_mask_units_height*#num_mask_units_width, mask_unit_height, mask_unit_width, hidden_size])
+ hidden_states = hidden_states.view(batch_size, seq_len, *mask_unit_shape, hidden_size)
+
+ # If masked, return [batch_size, num_mask_units, mask_unit_height, mask_unit_width, hidden_size]
+ if bool_masked_pos is not None:
+ return hidden_states
+
+ # If not masked, we can return [batch_size, height, width, hidden_size]
+ hidden_states = undo_windowing(hidden_states, size, mask_unit_shape)
+
+ return hidden_states
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ bool_masked_pos: Optional[torch.BoolTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ return_dict: bool = True,
+ ) -> Union[tuple, BaseModelOutput]:
+ all_hidden_states = () if output_hidden_states else None
+ all_reshaped_hidden_states = () if output_hidden_states else None
+ all_self_attentions = () if output_attentions else None
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+ reshaped_hidden_states = self.reroll(hidden_states, stage_idx=0, bool_masked_pos=bool_masked_pos)
+ all_reshaped_hidden_states = all_reshaped_hidden_states + (reshaped_hidden_states,)
+
+ for i, stage_module in enumerate(self.stages):
+ layer_head_mask = head_mask[i] if head_mask is not None else None
+
+ layer_outputs = stage_module(hidden_states, layer_head_mask, output_attentions)
+
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+ reshaped_hidden_states = self.reroll(hidden_states, stage_idx=i, bool_masked_pos=bool_masked_pos)
+ all_reshaped_hidden_states = all_reshaped_hidden_states + (reshaped_hidden_states,)
+
+ if not return_dict:
+ return tuple(
+ v
+ for v in [hidden_states, all_hidden_states, all_self_attentions, all_reshaped_hidden_states]
+ if v is not None
+ )
+ return HieraEncoderOutput(
+ last_hidden_state=hidden_states,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ reshaped_hidden_states=all_reshaped_hidden_states,
+ )
+
+
+def unroll(
+ hidden_states: torch.Tensor, image_shape: tuple[int, int], patch_stride: tuple[int, int], schedule: list[list[int]]
+) -> torch.Tensor:
+ """
+ Reorders the tokens such that patches are contiguous in memory.
+ E.g., given [batch_size, (height, width), hidden_size] and stride of (stride, stride), this will re-order the tokens as
+ [batch_size, (stride, stride, height // stride, width // stride), hidden_size]
+
+ This allows operations like Max2d to be computed as x.view(batch_size, stride*stride, -1, hidden_size).max(dim=1).
+ Not only is this faster, but it also makes it easy to support inputs of arbitrary
+ dimensions in addition to patch-wise sparsity.
+
+ Performing this operation multiple times in sequence puts entire windows as contiguous
+ in memory. For instance, if you applied the stride (2, 2) 3 times, entire windows of
+ size 8x8 would be contiguous in memory, allowing operations like mask unit attention
+ to be computed easily and efficiently, while also allowing max to be applied sequentially.
+
+ Note: This means that intermediate values of the model are not in height x width order, so they
+ need to be re-rolled if you want to use the intermediate values as a height x width feature map.
+ The last block of the network is fine though, since by then the strides are all consumed.
+ """
+ batch_size, _, hidden_size = hidden_states.shape
+
+ size = [i // s for i, s in zip(image_shape, patch_stride)]
+
+ current_size = size
+ hidden_states = hidden_states.view(*([batch_size] + current_size + [hidden_size]))
+
+ for strides in schedule:
+ # Move patches with the given strides to the batch dimension
+
+ # Create a view of the tensor with the patch stride as separate dims
+ # For example in 2d: [batch_size, height // stride, stride, width // stride, stride, C]
+ current_size = [i // s for i, s in zip(current_size, strides)]
+ # initialize new_shape with [height // stride, stride, width // stride, stride]
+ new_shape = [item for pair in zip(current_size, strides) for item in pair]
+ # add batch_size and hidden_size to new_shape
+ new_shape = [batch_size] + new_shape + [hidden_size]
+ hidden_states = hidden_states.view(new_shape)
+
+ # Move the patch stride into the batch dimension
+ # For example in 2d: [batch_size, stride, stride, height // stride, width // stride, hidden_size]
+ num_dims = len(new_shape)
+ permute = [0] + list(range(2, num_dims - 1, 2)) + list(range(1, num_dims - 1, 2)) + [num_dims - 1]
+ hidden_states = hidden_states.permute(permute)
+
+ # Now finally flatten the relevant dims into the batch dimension
+ hidden_states = hidden_states.flatten(0, len(strides))
+ batch_size *= math.prod(strides)
+
+ hidden_states = hidden_states.reshape(-1, math.prod(size), hidden_size)
+ return hidden_states
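+
+# Shape sketch for the default configuration (illustrative): the patch embeddings form a
+# [batch_size, 56 * 56, 96] tensor and schedule = [[2, 2], [2, 2], [2, 2]]. Each loop
+# iteration folds one (2, 2) stride into the batch dimension,
+# [batch_size * 4, 28, 28, 96] -> [batch_size * 16, 14, 14, 96] -> [batch_size * 64, 7, 7, 96],
+# and the final reshape returns [batch_size, 3136, 96] with tokens reordered so that query
+# pooling and mask unit attention can be expressed as simple view/max operations.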
+
+
+@auto_docstring
+class HieraPreTrainedModel(PreTrainedModel):
+ config: HieraConfig
+ base_model_prefix = "hiera"
+ main_input_name = "pixel_values"
+ supports_gradient_checkpointing = True
+
+ def _init_weights(self, module) -> None:
+ """Initialize the weights"""
+ std = self.config.initializer_range
+
+ if isinstance(module, HieraEmbeddings):
+ nn.init.trunc_normal_(module.position_embeddings, std=std)
+
+ elif isinstance(module, HieraDecoder):
+ nn.init.trunc_normal_(module.mask_token, std=std)
+ nn.init.trunc_normal_(module.decoder_position_embeddings, std=std)
+
+ elif isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d)):
+ nn.init.trunc_normal_(module.weight, std=std)
+ if module.bias is not None:
+ nn.init.constant_(module.bias, std)
+
+ elif isinstance(module, nn.LayerNorm):
+ nn.init.constant_(module.bias, std)
+ nn.init.constant_(module.weight, self.config.layer_norm_init)
+
+
+class HieraPooler(nn.Module):
+ def __init__(self, config: HieraConfig):
+ super().__init__()
+ num_features = int(config.embed_dim * config.embed_dim_multiplier ** (len(config.depths) - 1))
+ self.layernorm = nn.LayerNorm(num_features, eps=config.layer_norm_eps)
+ self.pooler = nn.AdaptiveAvgPool1d(1)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = hidden_states.transpose(1, 2)
+ pooled_output = self.pooler(hidden_states)
+ pooled_output = torch.flatten(pooled_output, 1)
+ pooled_output = self.layernorm(pooled_output)
+ return pooled_output
+
+
+@auto_docstring
+class HieraModel(HieraPreTrainedModel):
+ def __init__(self, config: HieraConfig, add_pooling_layer: bool = True, is_mae: bool = False):
+ r"""
+ add_pooling_layer (`bool`, *optional*, defaults to `True`):
+ Whether or not to apply pooling layer.
+ is_mae (`bool`, *optional*, defaults to `False`):
+ Whether or not to run the model on MAE mode.
+ """
+ super().__init__(config)
+ self.num_features = int(config.embed_dim * config.embed_dim_multiplier ** (len(config.depths) - 1))
+
+ self.embeddings = HieraEmbeddings(config, is_mae=is_mae)
+ self.encoder = HieraEncoder(config)
+
+ self.unroll_schedule = [config.query_stride] * len(config.depths[:-1])
+
+ self.pooler = HieraPooler(config) if add_pooling_layer else None
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self) -> HieraPatchEmbeddings:
+ return self.embeddings.patch_embeddings
+
+ def _prune_heads(self, heads_to_prune: dict[int, list[int]]) -> None:
+ """
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
+ class PreTrainedModel
+ """
+ for layer, heads in heads_to_prune.items():
+ self.encoder.layer[layer].attention.prune_heads(heads)
+
+ @auto_docstring
+ def forward(
+ self,
+ pixel_values: Optional[torch.Tensor] = None,
+ noise: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ interpolate_pos_encoding: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, BaseModelOutputWithPooling]:
+ r"""
+ noise (`torch.FloatTensor` of shape `(batch_size, num_mask_units)`, *optional*):
+ Mainly used for testing purposes to control randomness and maintain reproducibility
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if pixel_values is None:
+ raise ValueError("You have to specify pixel_values")
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+ head_mask = self.get_head_mask(head_mask, len(self.config.depths))
+
+ embedding_output, bool_masked_pos, ids_restore = self.embeddings(
+ pixel_values, interpolate_pos_encoding=interpolate_pos_encoding, noise=noise
+ )
+
+ image_shape = (pixel_values.shape[-2], pixel_values.shape[-1])
+ hidden_states = unroll(
+ embedding_output,
+ image_shape=image_shape,
+ patch_stride=self.config.patch_stride,
+ schedule=self.unroll_schedule,
+ )
+
+ # Discard masked tokens if bool_masked_pos is provided
+ if bool_masked_pos is not None:
+ mask_unit_area = math.prod(self.config.masked_unit_size)
+ batch_size, _, hidden_size = hidden_states.shape
+ positions = bool_masked_pos.unsqueeze(-1).tile(1, mask_unit_area, hidden_size)
+ hidden_states = hidden_states[positions]
+ hidden_states = hidden_states.view(batch_size, -1, hidden_size)
+
+ encoder_outputs = self.encoder(
+ hidden_states,
+ bool_masked_pos=bool_masked_pos,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ sequence_output = encoder_outputs[0]
+ pooled_output = None
+ if self.pooler is not None:
+ pooled_output = self.pooler(sequence_output)
+
+ if not return_dict:
+ head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,)
+ head_outputs = (
+ head_outputs + (bool_masked_pos, ids_restore) if bool_masked_pos is not None else head_outputs
+ )
+ return head_outputs + encoder_outputs[1:]
+
+ return HieraModelOutput(
+ last_hidden_state=sequence_output,
+ pooler_output=pooled_output,
+ bool_masked_pos=bool_masked_pos,
+ ids_restore=ids_restore,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ reshaped_hidden_states=encoder_outputs.reshaped_hidden_states,
+ )
+
+
+class HieraDecoder(nn.Module):
+ def __init__(self, config: HieraConfig):
+ super().__init__()
+ num_features = int(config.embed_dim * config.embed_dim_multiplier ** (len(config.depths) - 1))
+ tokens_spatial_shape = [i // s for i, s in zip(config.image_size, config.patch_stride)]
+ self.tokens_spatial_shape_final = [
+ i // s ** (config.num_query_pool) for i, s in zip(tokens_spatial_shape, config.query_stride)
+ ]
+ self.mask_unit_spatial_shape_final = [
+ i // s ** (config.num_query_pool) for i, s in zip(config.masked_unit_size, config.query_stride)
+ ]
+
+ self.decoder_embeddings = nn.Linear(num_features, config.decoder_hidden_size)
+
+ self.mask_token = nn.Parameter(torch.zeros(1, 1, config.decoder_hidden_size))
+
+ self.decoder_position_embeddings = nn.Parameter(
+ torch.zeros(1, math.prod(self.tokens_spatial_shape_final), config.decoder_hidden_size)
+ )
+
+ self.decoder_block = HieraStage(
+ config=config,
+ hidden_size=config.decoder_hidden_size,
+ hidden_size_output=config.decoder_hidden_size,
+ num_heads=config.decoder_num_heads,
+ depth=config.decoder_depth,
+ use_mask_unit_attn=False,
+ drop_path=[0.0] * config.decoder_depth,
+ query_stride=[1] * config.decoder_depth,
+ window_size=0,
+ )
+
+ self.decoder_norm = nn.LayerNorm(config.decoder_hidden_size, eps=config.layer_norm_eps)
+
+ # patch stride of prediction
+ self.pred_stride = config.patch_stride[-1] * (config.query_stride[-1] ** config.num_query_pool)
+ pred_dim = (self.pred_stride ** len(config.query_stride)) * config.num_channels
+
+ self.decoder_pred = nn.Linear(config.decoder_hidden_size, pred_dim)
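+ # Worked example with the default strides (illustrative; the decoder sizes themselves must be
+ # set explicitly): pred_stride = 4 * 2 ** 3 = 32, so each decoder token predicts a 32 x 32 pixel
+ # patch and pred_dim = 32 ** 2 * num_channels = 3072 for RGB inputs.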
+
+ def forward(
+ self,
+ encoder_hidden_states: torch.Tensor,
+ bool_masked_pos: torch.BoolTensor,
+ head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ ) -> tuple[torch.Tensor, torch.BoolTensor]:
+ # Embed tokens
+ hidden_states = self.decoder_embeddings(encoder_hidden_states)
+
+ # Combine the visible tokens with learned mask tokens, according to bool_masked_pos
+
+ # hidden_states : [batch_size, num_mask_units_visible, *mask_unit_spatial_shape_final, decoder_hidden_size]
+ # bool_masked_pos: [batch_size, num_mask_units]
+ mask_unit_height, mask_unit_width, decoder_hidden_size = hidden_states.shape[2:]
+ batch_size, num_mask_units = bool_masked_pos.shape
+
+ decoder_hidden_states = torch.zeros(
+ batch_size,
+ num_mask_units,
+ mask_unit_height,
+ mask_unit_width,
+ decoder_hidden_size,
+ device=hidden_states.device,
+ dtype=hidden_states.dtype,
+ )
+ mask_tokens = self.mask_token.view(1, 1, 1, 1, -1)
+ bool_masked_pos = bool_masked_pos.reshape(batch_size, num_mask_units, 1, 1, 1)
+ bool_masked_pos = bool_masked_pos.expand(-1, -1, mask_unit_height, mask_unit_width, decoder_hidden_size)
+ decoder_hidden_states[bool_masked_pos] = hidden_states.flatten()
+ decoder_hidden_states = (
+ 1 - bool_masked_pos.float()
+ ) * mask_tokens + bool_masked_pos.float() * decoder_hidden_states
+
+ # Get back spatial order
+ hidden_states = undo_windowing(
+ decoder_hidden_states,
+ self.tokens_spatial_shape_final,
+ self.mask_unit_spatial_shape_final,
+ )
+ bool_masked_pos = undo_windowing(
+ bool_masked_pos[..., 0:1],
+ self.tokens_spatial_shape_final,
+ self.mask_unit_spatial_shape_final,
+ )
+
+ # Flatten
+ hidden_states = hidden_states.reshape(hidden_states.shape[0], -1, hidden_states.shape[-1])
+ bool_masked_pos = bool_masked_pos.view(hidden_states.shape[0], -1)
+
+ # Add pos embed
+ hidden_states = hidden_states + self.decoder_position_embeddings
+
+ # Apply decoder blocks
+ hidden_states, attn_weights = self.decoder_block(
+ hidden_states, head_mask=head_mask, output_attentions=output_attentions
+ )
+ hidden_states = self.decoder_norm(hidden_states)
+
+ # Predictor projection
+ hidden_states = self.decoder_pred(hidden_states)
+
+ return hidden_states, bool_masked_pos
+
+
+class HieraMultiScaleHead(nn.Module):
+ def __init__(self, config: HieraConfig):
+ super().__init__()
+ self.mask_unit_spatial_shape_final = [
+ i // s ** (config.num_query_pool) for i, s in zip(config.masked_unit_size, config.query_stride)
+ ]
+ self.stage_dimensions = [
+ int(config.embed_dim * config.embed_dim_multiplier**i) for i in range(len(config.depths))
+ ]
+ current_masked_unit_size = config.masked_unit_size
+ self.multi_scale_fusion_heads = nn.ModuleList()
+
+ for idx in range(config.num_query_pool):
+ kernel = [i // s for i, s in zip(current_masked_unit_size, self.mask_unit_spatial_shape_final)]
+ current_masked_unit_size = [i // s for i, s in zip(current_masked_unit_size, config.query_stride)]
+ self.multi_scale_fusion_heads.append(
+ nn.Conv2d(
+ self.stage_dimensions[idx],
+ self.stage_dimensions[-1],
+ kernel_size=kernel,
+ stride=kernel,
+ )
+ )
+ self.multi_scale_fusion_heads.append(nn.Identity())
+
+ def apply_fusion_head(self, head: nn.Module, hidden_states: torch.Tensor) -> torch.Tensor:
+ if isinstance(head, nn.Identity):
+ return hidden_states
+
+ # Do this reshaping explicitly to avoid problems with torch.fx
+ batch_size, num_mask_units, mask_unit_height, mask_unit_width, hidden_size = hidden_states.shape
+ # From: [batch_size, num_mask_units, mask_unit_height, mask_unit_width, hidden_size]
+ # To: head([batch_size * num_mask_units, hidden_size, mask_unit_height, mask_unit_width])
+ hidden_states = hidden_states.reshape(
+ batch_size * num_mask_units, mask_unit_height, mask_unit_width, hidden_size
+ )
+ hidden_states = hidden_states.permute(0, 3, 1, 2)
+ hidden_states = head(hidden_states)
+
+ # Restore original layout
+ hidden_states = hidden_states.permute(0, 2, 3, 1)
+ mask_unit_height_final, mask_unit_width_final, hidden_size = hidden_states.shape[1:]
+ hidden_states = hidden_states.reshape(
+ batch_size, num_mask_units, mask_unit_height_final, mask_unit_width_final, hidden_size
+ )
+
+ return hidden_states
+
+ def forward(self, feature_maps: list[torch.Tensor]) -> torch.Tensor:
+ # Multi-scale fusion
+ hidden_states = 0.0
+ for head, feature_map in zip(self.multi_scale_fusion_heads, feature_maps):
+ hidden_states = hidden_states + self.apply_fusion_head(head, feature_map)
+
+ return hidden_states
+
+
+@auto_docstring(
+ custom_intro="""
+ The Hiera Model transformer with the decoder on top for self-supervised pre-training.
+
+ Note that we provide a script to pre-train this model on custom data in our [examples
+ directory](https://github.com/huggingface/transformers/tree/main/examples/pytorch/image-pretraining).
+
+
+ """
+)
+class HieraForPreTraining(HieraPreTrainedModel):
+ def __init__(self, config: HieraConfig) -> None:
+ super().__init__(config)
+ # Encoder
+ self.hiera = HieraModel(config, add_pooling_layer=False, is_mae=True)
+ self.encoder_norm = nn.LayerNorm(self.hiera.num_features, eps=config.layer_norm_eps)
+ # Multi-scale fusion heads
+ self.multiscale_fusion = HieraMultiScaleHead(config)
+ # Decoder
+ self.decoder = HieraDecoder(config)
+ self.pred_stride = self.decoder.pred_stride
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_pixel_label_2d(self, pixel_values: torch.Tensor, bool_masked_pos: torch.BoolTensor) -> torch.Tensor:
+ # bool_masked_pos (boolean tensor): True means *masked*
+ pixel_values = pixel_values.permute(0, 2, 3, 1)
+
+ size = self.pred_stride
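+ # Unfold the image into non-overlapping (size x size) patches:
+ # (batch_size, height, width, channels) -> (batch_size, num_patches, channels * size * size)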
+ label = pixel_values.unfold(1, size, size).unfold(2, size, size)
+ label = label.flatten(1, 2).flatten(2)
+ label = label[bool_masked_pos]
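+ # Optionally normalize each target patch by its own mean and variance (MAE-style per-patch normalization)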
+ if self.config.normalize_pixel_loss:
+ mean = label.mean(dim=-1, keepdim=True)
+ var = label.var(dim=-1, keepdim=True)
+ label = (label - mean) / (var + 1.0e-6) ** 0.5
+
+ return label
+
+ def forward_loss(self, pixel_values: torch.Tensor, logits: torch.Tensor, bool_masked_pos: torch.BoolTensor):
+ # We invert the bool_masked_pos such that 1.0 is *masked*
+ bool_masked_pos = ~bool_masked_pos
+ label = self.get_pixel_label_2d(pixel_values, bool_masked_pos)
+
+ logits = logits[bool_masked_pos]
+ loss = (logits - label) ** 2
+ loss = loss.mean()
+
+ return loss
+
+ @auto_docstring
+ def forward(
+ self,
+ pixel_values: Optional[torch.Tensor] = None,
+ noise: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ interpolate_pos_encoding: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, HieraForPreTrainingOutput]:
+ r"""
+ noise (`torch.FloatTensor` of shape `(batch_size, num_mask_units)`, *optional*):
+ Mainly used for testing purposes to control randomness and maintain reproducibility.
+
+ Examples:
+ ```python
+ >>> from transformers import AutoImageProcessor, HieraForPreTraining
+ >>> import torch
+ >>> from PIL import Image
+ >>> import requests
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> image_processor = AutoImageProcessor.from_pretrained("facebook/hiera-tiny-224-mae-hf")
+ >>> model = HieraForPreTraining.from_pretrained("facebook/hiera-tiny-224-mae-hf")
+
+ >>> inputs = image_processor(images=image, return_tensors="pt")
+
+ >>> outputs = model(**inputs)
+ >>> logits = outputs.logits
+ >>> loss = outputs.loss
+ >>> print(list(logits.shape))
+ [1, 196, 768]
+ ```"""
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+
+ outputs = self.hiera(
+ pixel_values,
+ noise=noise,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=True,
+ interpolate_pos_encoding=interpolate_pos_encoding,
+ return_dict=return_dict,
+ )
+
+ feature_maps = outputs[-1]
+ bool_masked_pos = outputs[1]
+ ids_to_restore = outputs[2]
+ # Take only the query pooled and last hidden states
+ feature_maps = feature_maps[1 : self.hiera.config.num_query_pool + 1] + (feature_maps[-1],)
+ fused_hidden_states = self.multiscale_fusion(feature_maps)
+ fused_hidden_states = self.encoder_norm(fused_hidden_states)
+
+ # Reconstruct pixel values
+ logits, bool_masked_pos = self.decoder(
+ fused_hidden_states,
+ bool_masked_pos=bool_masked_pos,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ )
+
+ loss = self.forward_loss(pixel_values, logits, bool_masked_pos)
+
+ if not return_dict:
+ output = (logits, bool_masked_pos, ids_to_restore)
+ if output_hidden_states:
+ output = output + (outputs[3],)
+ if output_attentions:
+ output = output + (outputs[4],)
+ if output_hidden_states:
+ output = output + (outputs[-1],)
+ return ((loss,) + output) if loss is not None else output
+
+ return HieraForPreTrainingOutput(
+ loss=loss,
+ logits=logits,
+ bool_masked_pos=bool_masked_pos,
+ ids_restore=ids_to_restore,
+ hidden_states=outputs.hidden_states if output_hidden_states else None,
+ attentions=outputs.attentions,
+ reshaped_hidden_states=outputs.reshaped_hidden_states if output_hidden_states else None,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ Hiera Model transformer with an image classification head on top (a linear layer on top of the final hidden state with
+ average pooling) e.g. for ImageNet.
+
+ Note that it's possible to fine-tune Hiera on higher resolution images than the ones it has been trained on, by
+ setting `interpolate_pos_encoding` to `True` in the forward of the model. This will interpolate the pre-trained
+ position embeddings to the higher resolution.
+ """
+)
+class HieraForImageClassification(HieraPreTrainedModel):
+ def __init__(self, config: HieraConfig) -> None:
+ super().__init__(config)
+
+ self.num_labels = config.num_labels
+ self.hiera = HieraModel(config, add_pooling_layer=True, is_mae=False)
+
+ # Classifier head
+ self.classifier = (
+ nn.Linear(self.hiera.num_features, config.num_labels) if config.num_labels > 0 else nn.Identity()
+ )
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ pixel_values,
+ head_mask: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ interpolate_pos_encoding: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, HieraForImageClassificationOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss). If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+
+ outputs = self.hiera(
+ pixel_values,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ interpolate_pos_encoding=interpolate_pos_encoding,
+ return_dict=return_dict,
+ )
+
+ pooled_output = outputs[1]
+
+ logits = self.classifier(pooled_output)
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(labels, logits, self.config)
+
+ if not return_dict:
+ output = (logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return HieraForImageClassificationOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ reshaped_hidden_states=outputs.reshaped_hidden_states,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ Hiera backbone, to be used with frameworks like DETR and MaskFormer.
+ """
+)
+class HieraBackbone(HieraPreTrainedModel, BackboneMixin):
+ def __init__(self, config: HieraConfig):
+ super().__init__(config)
+ super()._init_backbone(config)
+
+ self.num_features = [config.embed_dim] + [
+ int(config.embed_dim * config.embed_dim_multiplier**i) for i in range(len(config.depths))
+ ]
+ self.embeddings = HieraEmbeddings(config, is_mae=False)
+ self.encoder = HieraEncoder(config)
+
+ # Add layer norms to hidden states of out_features
+ hidden_states_norms = {}
+ for stage, num_channels in zip(self._out_features, self.channels):
+ hidden_states_norms[stage] = nn.LayerNorm(num_channels)
+ self.hidden_states_norms = nn.ModuleDict(hidden_states_norms)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embeddings.patch_embeddings
+
+ def forward(
+ self,
+ pixel_values: torch.Tensor,
+ output_hidden_states: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> BackboneOutput:
+ """
+ Returns:
+
+ Examples:
+
+ ```python
+ >>> from transformers import AutoImageProcessor, AutoBackbone
+ >>> import torch
+ >>> from PIL import Image
+ >>> import requests
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> processor = AutoImageProcessor.from_pretrained("facebook/hiera-tiny-224-hf")
+ >>> model = AutoBackbone.from_pretrained(
+ ... "facebook/hiera-tiny-224-hf", out_features=["stage1", "stage2", "stage3", "stage4"]
+ ... )
+
+ >>> inputs = processor(image, return_tensors="pt")
+ >>> outputs = model(**inputs)
+ >>> feature_maps = outputs.feature_maps
+ >>> list(feature_maps[-1].shape)
+ [1, 768, 7, 7]
+ ```"""
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+
+ embedding_output, _, _ = self.embeddings(pixel_values)
+
+ outputs = self.encoder(
+ embedding_output,
+ head_mask=None,
+ output_attentions=output_attentions,
+ output_hidden_states=True,
+ return_dict=return_dict,
+ )
+
+ hidden_states = outputs[-1]
+
+ feature_maps = ()
+ for stage, hidden_state in zip(self.stage_names, hidden_states):
+ if stage in self.out_features:
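+ # Apply LayerNorm over the channel dimension in (batch, seq, channels) layout, then restore the
+ # (batch, channels, height, width) feature-map layout expected by downstream frameworks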
+ batch_size, height, width, num_channels = hidden_state.shape
+ hidden_state = hidden_state.view(batch_size, height * width, num_channels)
+ hidden_state = self.hidden_states_norms[stage](hidden_state)
+ hidden_state = hidden_state.view(batch_size, height, width, num_channels)
+ hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
+ feature_maps += (hidden_state,)
+
+ if not return_dict:
+ output = (feature_maps,)
+ if output_hidden_states:
+ output += (outputs[1],)
+ if output_attentions:
+ output += (outputs[2],)
+ return output
+
+ return BackboneOutput(
+ feature_maps=feature_maps,
+ hidden_states=outputs[1] if output_hidden_states else None,
+ attentions=outputs[2] if output_attentions else None,
+ )
+
+
+__all__ = ["HieraForImageClassification", "HieraForPreTraining", "HieraBackbone", "HieraModel", "HieraPreTrainedModel"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/hubert/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/hubert/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d975dabc689a73c83818ced8bed5ad86072df9b2
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/hubert/__init__.py
@@ -0,0 +1,28 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_hubert import *
+ from .modeling_hubert import *
+ from .modeling_tf_hubert import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/venv/lib/python3.13/site-packages/transformers/models/hubert/configuration_hubert.py b/venv/lib/python3.13/site-packages/transformers/models/hubert/configuration_hubert.py
new file mode 100644
index 0000000000000000000000000000000000000000..c8af7b5a0f3cb712bb112bb5d7144ea0d9da29e0
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/hubert/configuration_hubert.py
@@ -0,0 +1,265 @@
+# coding=utf-8
+# Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Hubert model configuration"""
+
+import functools
+import operator
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class HubertConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`HubertModel`]. It is used to instantiate a
+ Hubert model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of the Hubert
+ [facebook/hubert-base-ls960](https://huggingface.co/facebook/hubert-base-ls960) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 32):
+ Vocabulary size of the Hubert model. Defines the number of different tokens that can be represented by the
+ `inputs_ids` passed when calling [`HubertModel`].
+ hidden_size (`int`, *optional*, defaults to 768):
+ Dimensionality of the encoder layers and the pooler layer.
+ num_hidden_layers (`int`, *optional*, defaults to 12):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 12):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ intermediate_size (`int`, *optional*, defaults to 3072):
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"selu"` and `"gelu_new"` are supported.
+ hidden_dropout (`float`, *optional*, defaults to 0.1):
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+ activation_dropout (`float`, *optional*, defaults to 0.1):
+ The dropout ratio for activations inside the fully connected layer.
+ attention_dropout (`float`, *optional*, defaults to 0.1):
+ The dropout ratio for the attention probabilities.
+ final_dropout (`float`, *optional*, defaults to 0.1):
+ The dropout probability for the final projection layer of [`HubertForCTC`].
+ layerdrop (`float`, *optional*, defaults to 0.1):
+ The LayerDrop probability. See the [LayerDrop paper](https://huggingface.co/papers/1909.11556) for more
+ details.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12):
+ The epsilon used by the layer normalization layers.
+ feat_extract_norm (`str`, *optional*, defaults to `"group"`):
+ The norm to be applied to 1D convolutional layers in the feature encoder. One of `"group"` for group
+ normalization of only the first 1D convolutional layer or `"layer"` for layer normalization of all 1D
+ convolutional layers.
+ feat_proj_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout probability for output of the feature encoder.
+ feat_proj_layer_norm (`bool`, *optional*, defaults to `True`):
+ Whether to apply LayerNorm to the output of the feature encoder.
+ feat_extract_activation (`str`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the 1D convolutional layers of the feature
+ extractor. If string, `"gelu"`, `"relu"`, `"selu"` and `"gelu_new"` are supported.
+ conv_dim (`tuple[int]`, *optional*, defaults to `(512, 512, 512, 512, 512, 512, 512)`):
+ A tuple of integers defining the number of input and output channels of each 1D convolutional layer in the
+ feature encoder. The length of *conv_dim* defines the number of 1D convolutional layers.
+ conv_stride (`tuple[int]`, *optional*, defaults to `(5, 2, 2, 2, 2, 2, 2)`):
+ A tuple of integers defining the stride of each 1D convolutional layer in the feature encoder. The length
+ of *conv_stride* defines the number of convolutional layers and has to match the length of *conv_dim*.
+ conv_kernel (`tuple[int]`, *optional*, defaults to `(10, 3, 3, 3, 3, 2, 2)`):
+ A tuple of integers defining the kernel size of each 1D convolutional layer in the feature encoder. The
+ length of *conv_kernel* defines the number of convolutional layers and has to match the length of
+ *conv_dim*.
+ conv_bias (`bool`, *optional*, defaults to `False`):
+ Whether the 1D convolutional layers have a bias.
+ num_conv_pos_embeddings (`int`, *optional*, defaults to 128):
+ Number of convolutional positional embeddings. Defines the kernel size of 1D convolutional positional
+ embeddings layer.
+ num_conv_pos_embedding_groups (`int`, *optional*, defaults to 16):
+ Number of groups of 1D convolutional positional embeddings layer.
+ conv_pos_batch_norm (`bool`, *optional*, defaults to `False`):
+ Whether to use batch norm instead of weight norm in the convolutional positional embedding layer (`conv_pos`).
+ do_stable_layer_norm (`bool`, *optional*, defaults to `False`):
+ Whether to apply the *stable* layer norm architecture of the Transformer encoder. `do_stable_layer_norm is
+ True` corresponds to applying layer norm before the attention layer, whereas `do_stable_layer_norm is
+ False` corresponds to applying layer norm after the attention layer.
+ apply_spec_augment (`bool`, *optional*, defaults to `True`):
+ Whether to apply *SpecAugment* data augmentation to the outputs of the feature encoder. For reference see
+ [SpecAugment: A Simple Data Augmentation Method for Automatic Speech
+ Recognition](https://huggingface.co/papers/1904.08779).
+ mask_time_prob (`float`, *optional*, defaults to 0.05):
+ Percentage (between 0 and 1) of all feature vectors along the time axis which will be masked. The masking
+ procedure generates `mask_time_prob * len(time_axis) / mask_time_length` independent masks over the axis. If
+ reasoning from the probability of each feature vector being chosen as the start of the span to be masked,
+ *mask_time_prob* should be `prob_vector_start * mask_time_length`. Note that overlap may decrease the
+ actual percentage of masked vectors. This is only relevant if `apply_spec_augment is True`.
+ mask_time_length (`int`, *optional*, defaults to 10):
+ Length of vector span along the time axis.
+ mask_time_min_masks (`int`, *optional*, defaults to 2):
+ The minimum number of masks of length `mask_time_length` generated along the time axis, each time step,
+ irrespective of `mask_time_prob`. Only relevant if `mask_time_prob * len(time_axis) / mask_time_length <
+ mask_time_min_masks`.
+ mask_feature_prob (`float`, *optional*, defaults to 0.0):
+ Percentage (between 0 and 1) of all feature vectors along the feature axis which will be masked. The
+ masking procedure generates `mask_feature_prob * len(feature_axis) / mask_feature_length` independent masks
+ over the axis. If reasoning from the probability of each feature vector being chosen as the start of the
+ span to be masked, *mask_feature_prob* should be `prob_vector_start * mask_feature_length`. Note that
+ overlap may decrease the actual percentage of masked vectors. This is only relevant if `apply_spec_augment
+ is True`.
+ mask_feature_length (`int`, *optional*, defaults to 10):
+ Length of vector span along the feature axis.
+ mask_feature_min_masks (`int`, *optional*, defaults to 0):
+ The minimum number of masks of length `mask_feature_length` generated along the feature axis, each time
+ step, irrespective of `mask_feature_prob`. Only relevant if
+ `mask_feature_prob * len(feature_axis) / mask_feature_length < mask_feature_min_masks`.
+ ctc_loss_reduction (`str`, *optional*, defaults to `"sum"`):
+ Specifies the reduction to apply to the output of `torch.nn.CTCLoss`. Only relevant when training an
+ instance of [`HubertForCTC`].
+ ctc_zero_infinity (`bool`, *optional*, defaults to `False`):
+ Whether to zero infinite losses and the associated gradients of `torch.nn.CTCLoss`. Infinite losses mainly
+ occur when the inputs are too short to be aligned to the targets. Only relevant when training an instance
+ of [`HubertForCTC`].
+ use_weighted_layer_sum (`bool`, *optional*, defaults to `False`):
+ Whether to use a weighted average of layer outputs with learned weights. Only relevant when using an
+ instance of [`HubertForSequenceClassification`].
+ classifier_proj_size (`int`, *optional*, defaults to 256):
+ Dimensionality of the projection before token mean-pooling for classification.
+
+ Example:
+
+ ```python
+ >>> from transformers import HubertModel, HubertConfig
+
+ >>> # Initializing a Hubert facebook/hubert-base-ls960 style configuration
+ >>> configuration = HubertConfig()
+
+ >>> # Initializing a model from the facebook/hubert-base-ls960 style configuration
+ >>> model = HubertModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "hubert"
+
+ def __init__(
+ self,
+ vocab_size=32,
+ hidden_size=768,
+ num_hidden_layers=12,
+ num_attention_heads=12,
+ intermediate_size=3072,
+ hidden_act="gelu",
+ hidden_dropout=0.1,
+ activation_dropout=0.1,
+ attention_dropout=0.1,
+ feat_proj_layer_norm=True,
+ feat_proj_dropout=0.0,
+ final_dropout=0.1,
+ layerdrop=0.1,
+ initializer_range=0.02,
+ layer_norm_eps=1e-5,
+ feat_extract_norm="group",
+ feat_extract_activation="gelu",
+ conv_dim=(512, 512, 512, 512, 512, 512, 512),
+ conv_stride=(5, 2, 2, 2, 2, 2, 2),
+ conv_kernel=(10, 3, 3, 3, 3, 2, 2),
+ conv_bias=False,
+ num_conv_pos_embeddings=128,
+ num_conv_pos_embedding_groups=16,
+ conv_pos_batch_norm=False,
+ do_stable_layer_norm=False,
+ apply_spec_augment=True,
+ mask_time_prob=0.05,
+ mask_time_length=10,
+ mask_time_min_masks=2,
+ mask_feature_prob=0.0,
+ mask_feature_length=10,
+ mask_feature_min_masks=0,
+ ctc_loss_reduction="sum",
+ ctc_zero_infinity=False,
+ use_weighted_layer_sum=False,
+ classifier_proj_size=256,
+ pad_token_id=0,
+ bos_token_id=1,
+ eos_token_id=2,
+ **kwargs,
+ ):
+ super().__init__(**kwargs, pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id)
+ self.hidden_size = hidden_size
+ self.feat_extract_norm = feat_extract_norm
+ self.feat_extract_activation = feat_extract_activation
+ self.conv_dim = list(conv_dim)
+ self.conv_stride = list(conv_stride)
+ self.conv_kernel = list(conv_kernel)
+ self.conv_bias = conv_bias
+ self.num_conv_pos_embeddings = num_conv_pos_embeddings
+ self.num_conv_pos_embedding_groups = num_conv_pos_embedding_groups
+ self.conv_pos_batch_norm = conv_pos_batch_norm
+ self.num_feat_extract_layers = len(self.conv_dim)
+ self.num_hidden_layers = num_hidden_layers
+ self.intermediate_size = intermediate_size
+ self.hidden_act = hidden_act
+ self.num_attention_heads = num_attention_heads
+ self.hidden_dropout = hidden_dropout
+ self.attention_dropout = attention_dropout
+ self.activation_dropout = activation_dropout
+ self.feat_proj_layer_norm = feat_proj_layer_norm
+ self.feat_proj_dropout = feat_proj_dropout
+ self.final_dropout = final_dropout
+ self.layerdrop = layerdrop
+ self.layer_norm_eps = layer_norm_eps
+ self.initializer_range = initializer_range
+ self.vocab_size = vocab_size
+ self.do_stable_layer_norm = do_stable_layer_norm
+ self.use_weighted_layer_sum = use_weighted_layer_sum
+ self.classifier_proj_size = classifier_proj_size
+
+ if (
+ (len(self.conv_stride) != self.num_feat_extract_layers)
+ or (len(self.conv_kernel) != self.num_feat_extract_layers)
+ or (len(self.conv_dim) != self.num_feat_extract_layers)
+ ):
+ raise ValueError(
+ "Configuration for convolutional layers is incorrect. It is required that `len(config.conv_dim)` =="
+ " `len(config.conv_stride)` == `len(config.conv_kernel)`, but is `len(config.conv_dim) ="
+ f" {len(self.conv_dim)}`, `len(config.conv_stride) = {len(self.conv_stride)}`,"
+ f" `len(config.conv_kernel) = {len(self.conv_kernel)}`."
+ )
+
+ # fine-tuning config parameters for SpecAugment: https://huggingface.co/papers/1904.08779
+ self.apply_spec_augment = apply_spec_augment
+ self.mask_time_prob = mask_time_prob
+ self.mask_time_length = mask_time_length
+ self.mask_time_min_masks = mask_time_min_masks
+ self.mask_feature_prob = mask_feature_prob
+ self.mask_feature_length = mask_feature_length
+ self.mask_feature_min_masks = mask_feature_min_masks
+
+ # ctc loss
+ self.ctc_loss_reduction = ctc_loss_reduction
+ self.ctc_zero_infinity = ctc_zero_infinity
+
+ @property
+ def inputs_to_logits_ratio(self):
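+ # Total temporal downsampling factor of the feature encoder: the product of all convolutional strides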
+ return functools.reduce(operator.mul, self.conv_stride, 1)
+
+
+__all__ = ["HubertConfig"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/hubert/modeling_hubert.py b/venv/lib/python3.13/site-packages/transformers/models/hubert/modeling_hubert.py
new file mode 100644
index 0000000000000000000000000000000000000000..060b715e8d499a13906092133dab2bdddae216df
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/hubert/modeling_hubert.py
@@ -0,0 +1,1285 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/hubert/modular_hubert.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_hubert.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2021 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import warnings
+from typing import Callable, Optional, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+from torch.nn import CrossEntropyLoss
+
+from ...activations import ACT2FN
+from ...integrations.deepspeed import is_deepspeed_zero3_enabled
+from ...integrations.fsdp import is_fsdp_managed_module
+from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import auto_docstring, is_torch_flex_attn_available, logging
+from .configuration_hubert import HubertConfig
+
+
+if is_torch_flex_attn_available():
+ from ...integrations.flex_attention import make_flex_block_causal_mask
+
+
+logger = logging.get_logger(__name__)
+
+
+class HubertPositionalConvEmbedding(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.conv = nn.Conv1d(
+ config.hidden_size,
+ config.hidden_size,
+ kernel_size=config.num_conv_pos_embeddings,
+ padding=config.num_conv_pos_embeddings // 2,
+ groups=config.num_conv_pos_embedding_groups,
+ )
+
+ self.batch_norm = None
+ if config.conv_pos_batch_norm:
+ self.batch_norm = nn.BatchNorm1d(config.hidden_size)
+ else:
+ weight_norm = nn.utils.weight_norm
+ if hasattr(nn.utils.parametrizations, "weight_norm"):
+ weight_norm = nn.utils.parametrizations.weight_norm
+
+ if is_deepspeed_zero3_enabled():
+ import deepspeed
+
+ with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
+ self.conv = weight_norm(self.conv, name="weight", dim=2)
+ if hasattr(self.conv, "parametrizations"):
+ weight_g = self.conv.parametrizations.weight.original0
+ weight_v = self.conv.parametrizations.weight.original1
+ else:
+ weight_g = self.conv.weight_g
+ weight_v = self.conv.weight_v
+ deepspeed.zero.register_external_parameter(self, weight_v)
+ deepspeed.zero.register_external_parameter(self, weight_g)
+ else:
+ self.conv = weight_norm(self.conv, name="weight", dim=2)
+
+ self.padding = HubertSamePadLayer(config.num_conv_pos_embeddings)
+ self.activation = ACT2FN[config.feat_extract_activation]
+
+ def forward(self, hidden_states):
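+ # nn.Conv1d expects (batch, channels, time); transpose into and back out of that layout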
+ hidden_states = hidden_states.transpose(1, 2)
+ if self.batch_norm is not None:
+ hidden_states = self.batch_norm(hidden_states)
+ hidden_states = self.conv(hidden_states)
+ hidden_states = self.padding(hidden_states)
+ hidden_states = self.activation(hidden_states)
+
+ hidden_states = hidden_states.transpose(1, 2)
+ return hidden_states
+
+
+class HubertSamePadLayer(nn.Module):
+ def __init__(self, num_conv_pos_embeddings):
+ super().__init__()
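+ # With padding = num_conv_pos_embeddings // 2, an even kernel size produces one extra output frame,
+ # which is trimmed in forward so the output length matches the input length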
+ self.num_pad_remove = 1 if num_conv_pos_embeddings % 2 == 0 else 0
+
+ def forward(self, hidden_states):
+ if self.num_pad_remove > 0:
+ hidden_states = hidden_states[:, :, : -self.num_pad_remove]
+ return hidden_states
+
+
+class HubertNoLayerNormConvLayer(GradientCheckpointingLayer):
+ def __init__(self, config, layer_id=0):
+ super().__init__()
+ self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
+ self.out_conv_dim = config.conv_dim[layer_id]
+
+ self.conv = nn.Conv1d(
+ self.in_conv_dim,
+ self.out_conv_dim,
+ kernel_size=config.conv_kernel[layer_id],
+ stride=config.conv_stride[layer_id],
+ bias=config.conv_bias,
+ )
+ self.activation = ACT2FN[config.feat_extract_activation]
+
+ def forward(self, hidden_states):
+ hidden_states = self.conv(hidden_states)
+ hidden_states = self.activation(hidden_states)
+ return hidden_states
+
+
+class HubertLayerNormConvLayer(GradientCheckpointingLayer):
+ def __init__(self, config, layer_id=0):
+ super().__init__()
+ self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
+ self.out_conv_dim = config.conv_dim[layer_id]
+
+ self.conv = nn.Conv1d(
+ self.in_conv_dim,
+ self.out_conv_dim,
+ kernel_size=config.conv_kernel[layer_id],
+ stride=config.conv_stride[layer_id],
+ bias=config.conv_bias,
+ )
+ self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True)
+ self.activation = ACT2FN[config.feat_extract_activation]
+
+ def forward(self, hidden_states):
+ hidden_states = self.conv(hidden_states)
+
+ hidden_states = hidden_states.transpose(-2, -1)
+ hidden_states = self.layer_norm(hidden_states)
+ hidden_states = hidden_states.transpose(-2, -1)
+
+ hidden_states = self.activation(hidden_states)
+ return hidden_states
+
+
+class HubertGroupNormConvLayer(GradientCheckpointingLayer):
+ def __init__(self, config, layer_id=0):
+ super().__init__()
+ self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
+ self.out_conv_dim = config.conv_dim[layer_id]
+
+ self.conv = nn.Conv1d(
+ self.in_conv_dim,
+ self.out_conv_dim,
+ kernel_size=config.conv_kernel[layer_id],
+ stride=config.conv_stride[layer_id],
+ bias=config.conv_bias,
+ )
+ self.activation = ACT2FN[config.feat_extract_activation]
+
+ self.layer_norm = nn.GroupNorm(num_groups=self.out_conv_dim, num_channels=self.out_conv_dim, affine=True)
+
+ def forward(self, hidden_states):
+ hidden_states = self.conv(hidden_states)
+ hidden_states = self.layer_norm(hidden_states)
+ hidden_states = self.activation(hidden_states)
+ return hidden_states
+
+
+class HubertFeatureEncoder(nn.Module):
+ """Construct the features from raw audio waveform"""
+
+ def __init__(self, config):
+ super().__init__()
+
+ if config.feat_extract_norm == "group":
+ conv_layers = [HubertGroupNormConvLayer(config, layer_id=0)] + [
+ HubertNoLayerNormConvLayer(config, layer_id=i + 1) for i in range(config.num_feat_extract_layers - 1)
+ ]
+ elif config.feat_extract_norm == "layer":
+ conv_layers = [HubertLayerNormConvLayer(config, layer_id=i) for i in range(config.num_feat_extract_layers)]
+ else:
+ raise ValueError(
+ f"`config.feat_extract_norm` is {config.feat_extract_norm}, but has to be one of ['group', 'layer']"
+ )
+ self.conv_layers = nn.ModuleList(conv_layers)
+ self.gradient_checkpointing = False
+ self._requires_grad = True
+
+ def _freeze_parameters(self):
+ for param in self.parameters():
+ param.requires_grad = False
+ self._requires_grad = False
+
+ def forward(self, input_values):
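+ # Add a channel dimension: (batch_size, num_samples) -> (batch_size, 1, num_samples)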
+ hidden_states = input_values[:, None]
+
+ # make sure hidden_states require grad for gradient_checkpointing
+ if self._requires_grad and self.training:
+ hidden_states.requires_grad = True
+
+ for conv_layer in self.conv_layers:
+ hidden_states = conv_layer(hidden_states)
+
+ return hidden_states
+
+
+class HubertFeatureProjection(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.feat_proj_layer_norm = config.feat_proj_layer_norm
+ if self.feat_proj_layer_norm:
+ self.layer_norm = nn.LayerNorm(config.conv_dim[-1], eps=config.layer_norm_eps)
+ self.projection = nn.Linear(config.conv_dim[-1], config.hidden_size)
+ self.dropout = nn.Dropout(config.feat_proj_dropout)
+
+ def forward(self, hidden_states):
+ # non-projected hidden states are needed for quantization
+ if self.feat_proj_layer_norm:
+ hidden_states = self.layer_norm(hidden_states)
+ hidden_states = self.projection(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ return hidden_states
+
+
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: Optional[float] = None,
+ dropout: float = 0.0,
+ head_mask: Optional[torch.Tensor] = None,
+ **kwargs,
+):
+ if scaling is None:
+ scaling = query.size(-1) ** -0.5
+
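+ # (batch, heads, tgt_len, head_dim) @ (batch, heads, head_dim, src_len) -> (batch, heads, tgt_len, src_len)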
+ attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
+ if attention_mask is not None:
+ attn_weights = attn_weights + attention_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+ if head_mask is not None:
+ attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)
+
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
+class HubertAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(
+ self,
+ embed_dim: int,
+ num_heads: int,
+ dropout: float = 0.0,
+ is_decoder: bool = False,
+ bias: bool = True,
+ is_causal: bool = False,
+ config: Optional[HubertConfig] = None,
+ ):
+ super().__init__()
+ self.embed_dim = embed_dim
+ self.num_heads = num_heads
+ self.dropout = dropout
+ self.head_dim = embed_dim // num_heads
+ self.config = config
+
+ if (self.head_dim * num_heads) != self.embed_dim:
+ raise ValueError(
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
+ f" and `num_heads`: {num_heads})."
+ )
+ self.scaling = self.head_dim**-0.5
+ self.is_decoder = is_decoder
+ self.is_causal = is_causal
+
+ self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+ self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+ self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ key_value_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ layer_head_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = False,
+ # TODO: we need a refactor so that the different attention modules can get their specific kwargs
+ # At the moment, encoder, decoder, and encoder-decoder attention are mixed together here
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
+ """Input shape: Batch x Time x Channel"""
+
+ # if key_value_states are provided this layer is used as a cross-attention layer
+ # for the decoder
+ is_cross_attention = key_value_states is not None
+
+ # determine input shapes
+ bsz, tgt_len = hidden_states.shape[:-1]
+ src_len = key_value_states.shape[1] if is_cross_attention else tgt_len
+
+ q_input_shape = (bsz, tgt_len, -1, self.head_dim)
+ kv_input_shape = (bsz, src_len, -1, self.head_dim)
+
+ # get query proj
+ query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2)
+
+ current_states = key_value_states if is_cross_attention else hidden_states
+ key_states = self.k_proj(current_states).view(*kv_input_shape).transpose(1, 2)
+ value_states = self.v_proj(current_states).view(*kv_input_shape).transpose(1, 2)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.dropout,
+ scaling=self.scaling,
+ output_attentions=output_attentions,
+ head_mask=layer_head_mask,
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
+ attn_output = self.out_proj(attn_output)
+
+ return attn_output, attn_weights, None
+
+
+class HubertFeedForward(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.intermediate_dropout = nn.Dropout(config.activation_dropout)
+
+ self.intermediate_dense = nn.Linear(config.hidden_size, config.intermediate_size)
+ if isinstance(config.hidden_act, str):
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.intermediate_act_fn = config.hidden_act
+
+ self.output_dense = nn.Linear(config.intermediate_size, config.hidden_size)
+ self.output_dropout = nn.Dropout(config.hidden_dropout)
+
+ def forward(self, hidden_states):
+ hidden_states = self.intermediate_dense(hidden_states)
+ hidden_states = self.intermediate_act_fn(hidden_states)
+ hidden_states = self.intermediate_dropout(hidden_states)
+
+ hidden_states = self.output_dense(hidden_states)
+ hidden_states = self.output_dropout(hidden_states)
+ return hidden_states
+
+
+class HubertEncoderLayer(GradientCheckpointingLayer):
+ def __init__(self, config):
+ super().__init__()
+ self.attention = HubertAttention(
+ embed_dim=config.hidden_size,
+ num_heads=config.num_attention_heads,
+ dropout=config.attention_dropout,
+ is_decoder=False,
+ config=config,
+ )
+
+ self.dropout = nn.Dropout(config.hidden_dropout)
+ self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.feed_forward = HubertFeedForward(config)
+ self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+ def forward(self, hidden_states, attention_mask=None, output_attentions=False):
+ attn_residual = hidden_states
+ hidden_states, attn_weights, _ = self.attention(
+ hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
+ )
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = attn_residual + hidden_states
+
+ hidden_states = self.layer_norm(hidden_states)
+ hidden_states = hidden_states + self.feed_forward(hidden_states)
+ hidden_states = self.final_layer_norm(hidden_states)
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (attn_weights,)
+
+ return outputs
+
+
+class HubertEncoder(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.pos_conv_embed = HubertPositionalConvEmbedding(config)
+ self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout)
+ self.layers = nn.ModuleList([HubertEncoderLayer(config) for _ in range(config.num_hidden_layers)])
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ return_dict: bool = True,
+ ):
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attentions = () if output_attentions else None
+
+ if attention_mask is not None:
+ # make sure padded tokens output 0
+ expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
+ hidden_states[~expand_attention_mask] = 0
+
+ attention_mask = self._update_full_mask(
+ attention_mask,
+ hidden_states,
+ )
+
+ position_embeddings = self.pos_conv_embed(hidden_states)
+ hidden_states = hidden_states + position_embeddings
+ hidden_states = self.layer_norm(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+
+ synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
+
+ for layer in self.layers:
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
+ dropout_probability = torch.rand([])
+
+ skip_the_layer = self.training and dropout_probability < self.config.layerdrop
+ if not skip_the_layer or synced_gpus:
+ # under fsdp or deepspeed zero3 all gpus must run in sync
+ layer_outputs = layer(
+ hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
+ )
+ hidden_states = layer_outputs[0]
+
+ if skip_the_layer:
+ layer_outputs = (None, None)
+
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
+ return BaseModelOutput(
+ last_hidden_state=hidden_states,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ )
+
+ def _update_full_mask(
+ self,
+ attention_mask: Union[torch.Tensor, None],
+ inputs_embeds: torch.Tensor,
+ ):
+ if attention_mask is not None:
+ if self.config._attn_implementation == "flash_attention_2":
+ attention_mask = attention_mask if 0 in attention_mask else None
+ elif self.config._attn_implementation == "sdpa":
+ # output_attentions=True & head_mask cannot be supported when using SDPA; fall back to
+ # the manual implementation that requires a 4D attention mask in all cases.
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
+ elif self.config._attn_implementation == "flex_attention":
+ if isinstance(attention_mask, torch.Tensor):
+ attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False)
+ else:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
+
+ return attention_mask
+
+
+class HubertAttnAdapterLayer(nn.Module):
+ def __init__(self, config):
+ """
+ A small bottleneck adapter (LayerNorm -> down-projection -> ReLU -> up-projection) applied on top of the
+ attention output, built from plain linear layers rather than a ModuleList to keep training throughput high.
+ """
+ super().__init__()
+ self.input_dim = config.adapter_attn_dim
+ self.hidden_dim = config.hidden_size
+
+ self.norm = nn.LayerNorm(self.hidden_dim)
+ self.linear_1 = nn.Linear(self.hidden_dim, self.input_dim)
+ self.act_fn = nn.ReLU()
+ self.linear_2 = nn.Linear(self.input_dim, self.hidden_dim)
+
+ def forward(self, hidden_states: torch.FloatTensor):
+ hidden_states = self.norm(hidden_states)
+
+ hidden_states = self.linear_1(hidden_states)
+ hidden_states = self.act_fn(hidden_states)
+ hidden_states = self.linear_2(hidden_states)
+
+ return hidden_states
+
+
+class HubertEncoderLayerStableLayerNorm(GradientCheckpointingLayer):
+ def __init__(self, config):
+ super().__init__()
+ self.attention = HubertAttention(
+ embed_dim=config.hidden_size,
+ num_heads=config.num_attention_heads,
+ dropout=config.attention_dropout,
+ is_decoder=False,
+ config=config,
+ )
+ self.dropout = nn.Dropout(config.hidden_dropout)
+ self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.feed_forward = HubertFeedForward(config)
+ self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+ if getattr(config, "adapter_attn_dim", None) is not None:
+ self.adapter_layer = HubertAttnAdapterLayer(config)
+ else:
+ self.adapter_layer = None
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ ):
+ attn_residual = hidden_states
+ hidden_states = self.layer_norm(hidden_states)
+ hidden_states, attn_weights, _ = self.attention(
+ hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
+ )
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = attn_residual + hidden_states
+ hidden_states = hidden_states + self.feed_forward(self.final_layer_norm(hidden_states))
+
+ if self.adapter_layer is not None:
+ hidden_states = hidden_states + self.adapter_layer(hidden_states)
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (attn_weights,)
+
+ return outputs
+
+
+class HubertEncoderStableLayerNorm(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.pos_conv_embed = HubertPositionalConvEmbedding(config)
+ self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout)
+ self.layers = nn.ModuleList(
+ [HubertEncoderLayerStableLayerNorm(config) for _ in range(config.num_hidden_layers)]
+ )
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ output_attentions=False,
+ output_hidden_states=False,
+ return_dict=True,
+ ):
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attentions = () if output_attentions else None
+
+ if attention_mask is not None:
+ # make sure padded tokens output 0
+ expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
+ hidden_states[~expand_attention_mask] = 0
+
+ attention_mask = self._update_full_mask(
+ attention_mask,
+ hidden_states,
+ )
+
+ position_embeddings = self.pos_conv_embed(hidden_states)
+ hidden_states = hidden_states + position_embeddings
+ hidden_states = self.dropout(hidden_states)
+
+ synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
+
+ for layer in self.layers:
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
+ dropout_probability = torch.rand([])
+
+ skip_the_layer = self.training and dropout_probability < self.config.layerdrop
+ if not skip_the_layer or synced_gpus:
+ # under fsdp or deepspeed zero3 all gpus must run in sync
+ # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication
+ layer_outputs = layer(
+ hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
+ )
+ hidden_states = layer_outputs[0]
+
+ if skip_the_layer:
+ layer_outputs = (None, None)
+
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
+
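+ # In the "stable layer norm" variant the final LayerNorm is applied once after all encoder layers
+ # (each layer normalizes its inputs instead of its outputs)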
+ hidden_states = self.layer_norm(hidden_states)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
+ return BaseModelOutput(
+ last_hidden_state=hidden_states,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ )
+
+ def _update_full_mask(
+ self,
+ attention_mask: Union[torch.Tensor, None],
+ inputs_embeds: torch.Tensor,
+ ):
+ if attention_mask is not None:
+ if self.config._attn_implementation == "flash_attention_2":
+ attention_mask = attention_mask if 0 in attention_mask else None
+ elif self.config._attn_implementation == "sdpa":
+ # output_attentions=True & head_mask cannot be supported when using SDPA; fall back to
+ # the manual implementation that requires a 4D attention mask in all cases.
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
+ elif self.config._attn_implementation == "flex_attention":
+ if isinstance(attention_mask, torch.Tensor):
+ attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False)
+ else:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
+
+ return attention_mask
+
+
+@auto_docstring
+class HubertPreTrainedModel(PreTrainedModel):
+ config: HubertConfig
+ base_model_prefix = "hubert"
+ main_input_name = "input_values"
+ supports_gradient_checkpointing = True
+ _supports_flash_attn = True
+ _supports_sdpa = True
+ _supports_flex_attn = True
+
+ def _init_weights(self, module):
+ """Initialize the weights"""
+ if isinstance(module, nn.Linear):
+ # Slightly different from the TF version which uses truncated_normal for initialization
+ # cf https://github.com/pytorch/pytorch/pull/5617
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm1d)):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+ elif isinstance(module, nn.Conv1d):
+ if is_deepspeed_zero3_enabled():
+ import deepspeed
+
+ if hasattr(module, "weight_v") and hasattr(module, "weight_g"):
+ with deepspeed.zero.GatheredParameters([module.weight_v, module.weight_g], modifier_rank=0):
+ nn.init.kaiming_normal_(module.weight.data)
+ else:
+ with deepspeed.zero.GatheredParameters(module.weight, modifier_rank=0):
+ nn.init.kaiming_normal_(module.weight.data)
+ else:
+ nn.init.kaiming_normal_(module.weight.data)
+
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, HubertModel):
+ if hasattr(module, "masked_spec_embed"):
+ module.masked_spec_embed.data.uniform_()
+ elif isinstance(module, HubertForSequenceClassification):
+ if hasattr(module, "layer_weights"):
+ module.layer_weights.data.fill_(1.0 / (self.config.num_hidden_layers + 1))
+
+ def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]):
+ """
+ Computes the output length of the convolutional layers
+ """
+
+ def _conv_out_length(input_length, kernel_size, stride):
+ # 1D convolutional layer output length formula taken
+ # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
+ return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1
+
+ for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
+ input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
+
+ return input_lengths
+
+ def _get_feature_vector_attention_mask(self, feature_vector_length: int, attention_mask: torch.LongTensor):
+ output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
+ batch_size = attention_mask.shape[0]
+
+ attention_mask = torch.zeros(
+ (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
+ )
+ # these two operations make sure that all values before the output length indices are attended to
+ attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1
+ attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
+ return attention_mask
+
+
+def _compute_mask_indices(
+ shape: tuple[int, int],
+ mask_prob: float,
+ mask_length: int,
+ attention_mask: Optional[torch.LongTensor] = None,
+ min_masks: int = 0,
+) -> np.ndarray:
+ """
+ Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
+ ASR](https://huggingface.co/papers/1904.08779). Note that this method is not optimized to run on TPU and should be run on
+ CPU as part of the preprocessing during training.
+
+ Args:
+ shape: The shape for which to compute masks. This should be of a tuple of size 2 where
+ the first element is the batch size and the second element is the length of the axis to span.
+ mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of
+ independently generated mask spans of length `mask_length` is computed by
+ `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
+ actual percentage will be smaller.
+ mask_length: size of the mask
+ min_masks: minimum number of masked spans
+ attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
+ each batch dimension.
+ """
+ batch_size, sequence_length = shape
+
+ if mask_length < 1:
+ raise ValueError("`mask_length` has to be bigger than 0.")
+
+ if mask_length > sequence_length:
+ raise ValueError(
+ f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
+ f" and `sequence_length`: {sequence_length}`"
+ )
+
+ # epsilon is used for probabilistic rounding
+ epsilon = np.random.rand(1).item()
+
+ def compute_num_masked_span(input_length):
+ """Given input length, compute how many spans should be masked"""
+ num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
+ num_masked_span = max(num_masked_span, min_masks)
+
+ # make sure num masked span <= sequence_length
+ if num_masked_span * mask_length > sequence_length:
+ num_masked_span = sequence_length // mask_length
+
+ # make sure num_masked span is also <= input_length - (mask_length - 1)
+ if input_length - (mask_length - 1) < num_masked_span:
+ num_masked_span = max(input_length - (mask_length - 1), 0)
+
+ return num_masked_span
+
+ # compute number of masked spans in batch
+ input_lengths = (
+ attention_mask.detach().sum(-1).tolist()
+ if attention_mask is not None
+ else [sequence_length for _ in range(batch_size)]
+ )
+
+ # SpecAugment mask to fill
+ spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
+ spec_aug_mask_idxs = []
+
+ max_num_masked_span = compute_num_masked_span(sequence_length)
+
+ if max_num_masked_span == 0:
+ return spec_aug_mask
+
+ for input_length in input_lengths:
+ # compute num of masked spans for this input
+ num_masked_span = compute_num_masked_span(input_length)
+
+ # get random indices to mask
+ spec_aug_mask_idx = np.random.choice(
+ np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
+ )
+
+ # pick first sampled index that will serve as a dummy index to pad vector
+ # to ensure same dimension for all batches due to probabilistic rounding
+ # Picking first sample just pads those vectors twice.
+ if len(spec_aug_mask_idx) == 0:
+ # this case can only happen if `input_length` is strictly smaller than
+ # `sequence_length` in which case the last token has to be a padding
+ # token which we can use as a dummy mask id
+ dummy_mask_idx = sequence_length - 1
+ else:
+ dummy_mask_idx = spec_aug_mask_idx[0]
+
+ spec_aug_mask_idx = np.concatenate(
+ [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
+ )
+ spec_aug_mask_idxs.append(spec_aug_mask_idx)
+
+ spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
+
+ # expand masked indices to masked spans
+ spec_aug_mask_idxs = np.broadcast_to(
+ spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
+ )
+ spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
+
+ # add offset to the starting indexes so that indexes now create a span
+ offsets = np.arange(mask_length)[None, None, :]
+ offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
+ batch_size, max_num_masked_span * mask_length
+ )
+ spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
+
+ # ensure that we cannot have indices larger than sequence_length
+ if spec_aug_mask_idxs.max() > sequence_length - 1:
+ spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1
+
+ # scatter indices to mask
+ np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
+
+ return spec_aug_mask
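+
+# Illustrative sketch of the returned mask (span starts are sampled randomly, so the values shown are an
+# assumed example rather than a fixed output):
+#
+#     _compute_mask_indices(shape=(1, 10), mask_prob=0.4, mask_length=2)
+#     # e.g. array([[False,  True,  True, False, False, False,  True,  True, False, False]])
+#
+# i.e. roughly `mask_prob * 10 / mask_length = 2` spans of length 2 per row, fewer when sampled spans overlap.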
+
+
+@auto_docstring
+class HubertModel(HubertPreTrainedModel):
+ def __init__(self, config: HubertConfig):
+ super().__init__(config)
+ self.config = config
+ self.feature_extractor = HubertFeatureEncoder(config)
+ self.feature_projection = HubertFeatureProjection(config)
+
+ # model only needs masking vector if mask prob is > 0.0
+ if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
+ self.masked_spec_embed = nn.Parameter(torch.Tensor(config.hidden_size).uniform_())
+
+ if config.do_stable_layer_norm:
+ self.encoder = HubertEncoderStableLayerNorm(config)
+ else:
+ self.encoder = HubertEncoder(config)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def _mask_hidden_states(
+ self,
+ hidden_states: torch.FloatTensor,
+ mask_time_indices: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.LongTensor] = None,
+ ):
+ """
+ Masks extracted features along time axis and/or along feature axis according to
+ [SpecAugment](https://huggingface.co/papers/1904.08779).
+ """
+
+ # `config.apply_spec_augment` can set masking to False
+ if not getattr(self.config, "apply_spec_augment", True):
+ return hidden_states
+
+ # generate indices & apply SpecAugment along time axis
+ batch_size, sequence_length, hidden_size = hidden_states.size()
+
+ if mask_time_indices is not None:
+ # apply SpecAugment along time axis with given mask_time_indices
+ hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
+ elif self.config.mask_time_prob > 0 and self.training:
+ mask_time_indices = _compute_mask_indices(
+ (batch_size, sequence_length),
+ mask_prob=self.config.mask_time_prob,
+ mask_length=self.config.mask_time_length,
+ attention_mask=attention_mask,
+ min_masks=self.config.mask_time_min_masks,
+ )
+ mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool)
+ hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
+
+ if self.config.mask_feature_prob > 0 and self.training:
+ # generate indices & apply SpecAugment along feature axis
+ mask_feature_indices = _compute_mask_indices(
+ (batch_size, hidden_size),
+ mask_prob=self.config.mask_feature_prob,
+ mask_length=self.config.mask_feature_length,
+ min_masks=self.config.mask_feature_min_masks,
+ )
+ mask_feature_indices = torch.tensor(mask_feature_indices, device=hidden_states.device, dtype=torch.bool)
+ mask_feature_indices = mask_feature_indices[:, None].expand(-1, sequence_length, -1)
+ hidden_states[mask_feature_indices] = 0
+
+ return hidden_states
+
+ @auto_docstring
+ def forward(
+ self,
+ input_values: Optional[torch.Tensor],
+ attention_mask: Optional[torch.Tensor] = None,
+ mask_time_indices: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, BaseModelOutput]:
+ r"""
+ mask_time_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Indices to mask extracted features for contrastive loss. When in training mode, model learns to predict
+ masked extracted features in *config.proj_codevector_dim* space.
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoProcessor, HubertModel
+ >>> from datasets import load_dataset
+
+ >>> processor = AutoProcessor.from_pretrained("facebook/hubert-large-ls960-ft")
+ >>> model = HubertModel.from_pretrained("facebook/hubert-large-ls960-ft")
+
+
+ >>> def map_to_array(example):
+ ... example["speech"] = example["audio"]["array"]
+ ... return example
+
+
+ >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+ >>> ds = ds.map(map_to_array)
+
+ >>> input_values = processor(ds["speech"][0], return_tensors="pt").input_values # Batch size 1
+ >>> hidden_states = model(input_values).last_hidden_state
+ ```"""
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ extract_features = self.feature_extractor(input_values)
+ extract_features = extract_features.transpose(1, 2)
+
+ if attention_mask is not None:
+ # compute reduced attention_mask corresponding to feature vectors
+ attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask)
+
+ hidden_states = self.feature_projection(extract_features)
+ hidden_states = self._mask_hidden_states(hidden_states, mask_time_indices=mask_time_indices)
+
+ encoder_outputs = self.encoder(
+ hidden_states,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = encoder_outputs[0]
+
+ if not return_dict:
+ return (hidden_states,) + encoder_outputs[1:]
+
+ return BaseModelOutput(
+ last_hidden_state=hidden_states,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
+
+
+_HIDDEN_STATES_START_POSITION = 1
+
+
+@auto_docstring(
+ custom_intro="""
+ Hubert Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).
+ """
+)
+class HubertForCTC(HubertPreTrainedModel):
+ def __init__(self, config, target_lang: Optional[str] = None):
+ r"""
+ target_lang (`str`, *optional*):
+ Language id of adapter weights. Adapter weights are stored in the format adapter.<lang>.safetensors or
+ adapter.<lang>.bin. Only relevant when using an instance of [`HubertForCTC`] with adapters. Uses 'eng' by
+ default.
+ """
+ super().__init__(config)
+
+ self.hubert = HubertModel(config)
+ self.dropout = nn.Dropout(config.final_dropout)
+
+ self.target_lang = target_lang
+
+ if config.vocab_size is None:
+ raise ValueError(
+ f"You are trying to instantiate {self.__class__} with a configuration that "
+ "does not define the vocabulary size of the language model head. Please "
+ "instantiate the model as follows: `HubertForCTC.from_pretrained(..., vocab_size=vocab_size)`. "
+ "or define `vocab_size` of your model's configuration."
+ )
+ output_hidden_size = (
+ config.output_hidden_size if hasattr(config, "add_adapter") and config.add_adapter else config.hidden_size
+ )
+ self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def tie_weights(self):
+ """
+ This method overwrites [`~PreTrainedModel.tie_weights`] so that adapter weights can be correctly loaded when
+ passing `target_lang=...` to `from_pretrained(...)`.
+
+ This method is **not** supposed to be called by the user and is prone to be changed in the future.
+ """
+
+ # Note that `tie_weights` is usually used to tie input and output embedding weights. The method is re-purposed to
+ # correctly load adapter layers for Hubert so that we do not have to introduce a new API to
+ # [`PreTrainedModel`]. While slightly hacky, Hubert never has to tie input and output embeddings, so that it is
+ # ok to repurpose this function here.
+ target_lang = self.target_lang
+
+ if target_lang is not None and getattr(self.config, "adapter_attn_dim", None) is None:
+ raise ValueError(f"Cannot pass `target_lang`: {target_lang} if `config.adapter_attn_dim` is not defined.")
+ elif target_lang is None and getattr(self.config, "adapter_attn_dim", None) is not None:
+ logger.info("By default `target_lang` is set to 'eng'.")
+ elif target_lang is not None:
+ self.load_adapter(target_lang, force_load=True)
+
+ def freeze_feature_extractor(self):
+ """
+ Calling this function will disable the gradient computation for the feature encoder so that its parameters will
+ not be updated during training.
+ """
+ warnings.warn(
+ "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
+ "Please use the equivalent `freeze_feature_encoder` method instead.",
+ FutureWarning,
+ )
+ self.freeze_feature_encoder()
+
+ def freeze_feature_encoder(self):
+ """
+ Calling this function will disable the gradient computation for the feature encoder so that its parameters will
+ not be updated during training.
+ """
+ self.hubert.feature_extractor._freeze_parameters()
+
+ def freeze_base_model(self):
+ """
+ Calling this function will disable the gradient computation for the base model so that its parameters will not
+ be updated during training. Only the classification head will be updated.
+ """
+ for param in self.hubert.parameters():
+ param.requires_grad = False
+
+ @auto_docstring
+ def forward(
+ self,
+ input_values: Optional[torch.Tensor],
+ attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ labels: Optional[torch.Tensor] = None,
+ ) -> Union[tuple, CausalLMOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*):
+ Labels for connectionist temporal classification. Note that `target_length` has to be smaller than or equal to
+ the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`.
+ All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
+ config.vocab_size - 1]`.
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if labels is not None and labels.max() >= self.config.vocab_size:
+ raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")
+
+ outputs = self.hubert(
+ input_values,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = outputs[0]
+ hidden_states = self.dropout(hidden_states)
+
+ logits = self.lm_head(hidden_states)
+
+ loss = None
+ if labels is not None:
+ # retrieve loss input_lengths from attention_mask
+ attention_mask = (
+ attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long)
+ )
+ input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
+
+ # assuming that padded tokens are filled with -100
+ # when not being attended to
+ labels_mask = labels >= 0
+ target_lengths = labels_mask.sum(-1)
+ flattened_targets = labels.masked_select(labels_mask)
+
+ # ctc_loss doesn't support fp16
+ log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)
+
+ with torch.backends.cudnn.flags(enabled=False):
+ loss = nn.functional.ctc_loss(
+ log_probs,
+ flattened_targets,
+ input_lengths,
+ target_lengths,
+ blank=self.config.pad_token_id,
+ reduction=self.config.ctc_loss_reduction,
+ zero_infinity=self.config.ctc_zero_infinity,
+ )
+
+ if not return_dict:
+ output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
+ return ((loss,) + output) if loss is not None else output
+
+ return CausalLMOutput(
+ loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
+ )
+
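+# Hedged usage sketch for `HubertForCTC`, mirroring the `HubertModel` docstring example above (the checkpoint
+# and greedy argmax decoding are illustrative assumptions, not the only way to use the model):
+#
+#     processor = AutoProcessor.from_pretrained("facebook/hubert-large-ls960-ft")
+#     model = HubertForCTC.from_pretrained("facebook/hubert-large-ls960-ft")
+#     inputs = processor(ds["speech"][0], sampling_rate=16_000, return_tensors="pt")
+#     logits = model(inputs.input_values).logits
+#     predicted_ids = torch.argmax(logits, dim=-1)
+#     transcription = processor.batch_decode(predicted_ids)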
+
+@auto_docstring(
+ custom_intro="""
+ Hubert Model with a sequence classification head on top (a linear layer over the pooled output) for tasks like
+ SUPERB Keyword Spotting.
+ """
+)
+class HubertForSequenceClassification(HubertPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+
+ if hasattr(config, "add_adapter") and config.add_adapter:
+ raise ValueError(
+ "Sequence classification does not support the use of Hubert adapters (config.add_adapter=True)"
+ )
+ self.hubert = HubertModel(config)
+ num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
+ if config.use_weighted_layer_sum:
+ self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
+ self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size)
+ self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def freeze_feature_extractor(self):
+ """
+ Calling this function will disable the gradient computation for the feature encoder so that its parameters will
+ not be updated during training.
+ """
+ warnings.warn(
+ "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
+ "Please use the equivalent `freeze_feature_encoder` method instead.",
+ FutureWarning,
+ )
+ self.freeze_feature_encoder()
+
+ def freeze_feature_encoder(self):
+ """
+ Calling this function will disable the gradient computation for the feature encoder so that its parameters will
+ not be updated during training.
+ """
+ self.hubert.feature_extractor._freeze_parameters()
+
+ def freeze_base_model(self):
+ """
+ Calling this function will disable the gradient computation for the base model so that its parameters will not
+ be updated during training. Only the classification head will be updated.
+ """
+ for param in self.hubert.parameters():
+ param.requires_grad = False
+
+ @auto_docstring
+ def forward(
+ self,
+ input_values: Optional[torch.Tensor],
+ attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ labels: Optional[torch.Tensor] = None,
+ ) -> Union[tuple, SequenceClassifierOutput]:
+ r"""
+ input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
+ Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
+ into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library
+ (`pip install torchcodec`) or the soundfile library (`pip install soundfile`).
+ To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and conversion
+ into a tensor of type `torch.FloatTensor`. See [`HubertProcessor.__call__`] for details.
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss). If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
+
+ outputs = self.hubert(
+ input_values,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ if self.config.use_weighted_layer_sum:
+ hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
+ hidden_states = torch.stack(hidden_states, dim=1)
+ norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
+ hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
+ else:
+ hidden_states = outputs[0]
+
+ hidden_states = self.projector(hidden_states)
+ if attention_mask is None:
+ pooled_output = hidden_states.mean(dim=1)
+ else:
+ padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
+ expand_padding_mask = padding_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
+ hidden_states[~expand_padding_mask] = 0.0
+ pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
+
+ logits = self.classifier(pooled_output)
+
+ loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
+
+ if not return_dict:
+ output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
+ return ((loss,) + output) if loss is not None else output
+
+ return SequenceClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
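+
+# Hedged usage sketch for `HubertForSequenceClassification` (the keyword-spotting checkpoint name below is an
+# assumption used for illustration; any compatible audio-classification checkpoint works the same way):
+#
+#     feature_extractor = AutoFeatureExtractor.from_pretrained("superb/hubert-base-superb-ks")
+#     model = HubertForSequenceClassification.from_pretrained("superb/hubert-base-superb-ks")
+#     inputs = feature_extractor(ds["speech"][0], sampling_rate=16_000, return_tensors="pt")
+#     logits = model(**inputs).logits
+#     predicted_label = model.config.id2label[int(torch.argmax(logits, dim=-1))]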
+
+
+__all__ = ["HubertForCTC", "HubertForSequenceClassification", "HubertModel", "HubertPreTrainedModel"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/hubert/modeling_tf_hubert.py b/venv/lib/python3.13/site-packages/transformers/models/hubert/modeling_tf_hubert.py
new file mode 100644
index 0000000000000000000000000000000000000000..45c05ff3073762a73ba1b8084af9899616ebe446
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/hubert/modeling_tf_hubert.py
@@ -0,0 +1,1671 @@
+# coding=utf-8
+# Copyright 2021 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""TensorFlow Hubert model."""
+
+from __future__ import annotations
+
+import warnings
+from typing import Any
+
+import numpy as np
+import tensorflow as tf
+
+from ...activations_tf import get_tf_activation
+from ...modeling_tf_outputs import TFBaseModelOutput, TFCausalLMOutput
+from ...modeling_tf_utils import (
+ TFPreTrainedModel,
+ get_initializer,
+ keras,
+ keras_serializable,
+ unpack_inputs,
+)
+from ...tf_utils import shape_list, stable_softmax
+from ...utils import (
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+ replace_return_docstrings,
+)
+from .configuration_hubert import HubertConfig
+
+
+logger = logging.get_logger(__name__)
+
+_CONFIG_FOR_DOC = "HubertConfig"
+
+
+LARGE_NEGATIVE = -1e8
+
+
+# Copied from transformers.models.wav2vec2.modeling_tf_wav2vec2._sample_without_replacement
+def _sample_without_replacement(distribution, num_samples):
+ """
+ Categorical sampling without replacement is currently not implemented. The gumbel-max trick will do for now - see
+ https://github.com/tensorflow/tensorflow/issues/9260 for more info
+ """
+ z = -tf.math.log(tf.random.uniform(shape_list(distribution), 0, 1))
+ _, indices = tf.nn.top_k(distribution + z, num_samples)
+ return indices
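+
+# Sketch of why this works: adding independent random noise `z = -log(u)` to the scores and taking the top-k
+# indices yields k distinct indices; with uniform (all-ones) scores this amounts to drawing `num_samples`
+# distinct positions at random, e.g. (shapes only, the sampled values are random):
+#
+#     _sample_without_replacement(tf.ones((2, 8)), num_samples=3)  # -> int32 indices of shape (2, 3)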
+
+
+# Copied from transformers.models.wav2vec2.modeling_tf_wav2vec2._scatter_values_on_batch_indices
+def _scatter_values_on_batch_indices(values, batch_indices, output_shape):
+ """
+ Scatter function as in PyTorch with indices in format (batch_dim, indices)
+ """
+ indices_shape = shape_list(batch_indices)
+ # broadcast batch dim to indices_shape
+ broad_casted_batch_dims = tf.reshape(
+ tf.broadcast_to(tf.expand_dims(tf.range(indices_shape[0]), axis=-1), indices_shape), [1, -1]
+ )
+ # transform batch_indices to pair_indices
+ pair_indices = tf.transpose(tf.concat([broad_casted_batch_dims, tf.reshape(batch_indices, [1, -1])], 0))
+ # scatter values to pair indices
+ return tf.scatter_nd(pair_indices, tf.reshape(values, [-1]), output_shape)
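+
+# Shape/value sketch for the scatter above (numbers chosen purely for illustration): each `values[b, i]` is
+# written to `output[b, batch_indices[b, i]]`, so
+#
+#     _scatter_values_on_batch_indices(tf.ones((1, 2)), tf.constant([[1, 3]]), (1, 5))
+#     # -> [[0., 1., 0., 1., 0.]]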
+
+
+# Copied from transformers.models.wav2vec2.modeling_tf_wav2vec2._compute_mask_indices
+def _compute_mask_indices(
+ shape: tuple[int, int],
+ mask_prob: float,
+ mask_length: int,
+ min_masks: int = 0,
+) -> tf.Tensor:
+ """
+ Computes random mask spans for a given shape
+
+ Args:
+ shape: the shape for which to compute masks.
+ should be of size 2 where the first element is the batch size and the second is the number of timesteps
+ mask_prob:
+ probability for each token to be chosen as the start of the span to be masked. This will be multiplied by
+ the number of timesteps divided by the length of the mask span to mask approximately this percentage of all
+ elements. However, due to overlaps, the actual number will be smaller.
+ mask_length: size of the mask
+ min_masks: minimum number of masked spans
+
+ Adapted from [fairseq's
+ data_utils.py](https://github.com/pytorch/fairseq/blob/e0788f7007a8473a76db573985031f3c94201e79/fairseq/data/data_utils.py#L376).
+ """
+ batch_size, sequence_length = shape
+
+ if mask_length < 1:
+ raise ValueError("`mask_length` has to be bigger than 0.")
+
+ tf.debugging.assert_less(
+ mask_length,
+ sequence_length,
+ message=(
+ f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length} and"
+ f" `sequence_length`: {sequence_length}`"
+ ),
+ )
+
+ # compute number of masked spans in batch
+ num_masked_spans = mask_prob * tf.cast(sequence_length, tf.float32) / mask_length + tf.random.uniform((1,))
+ num_masked_spans = tf.maximum(num_masked_spans, min_masks)
+ num_masked_spans = tf.cast(num_masked_spans, tf.int32)
+
+ # make sure num masked indices <= sequence_length
+ num_masked_spans = tf.math.minimum(sequence_length // mask_length, num_masked_spans)
+ num_masked_spans = tf.squeeze(num_masked_spans)
+
+ # SpecAugment mask to fill
+ spec_aug_mask = tf.zeros((batch_size, sequence_length), dtype=tf.int32)
+
+ # uniform distribution to sample from, make sure that offset samples are < sequence_length
+ uniform_dist = tf.ones((batch_size, sequence_length - (mask_length - 1)))
+
+ # get random indices to mask
+ spec_aug_mask_idxs = _sample_without_replacement(uniform_dist, num_masked_spans)
+
+ # expand masked indices to masked spans
+ spec_aug_mask_idxs = tf.expand_dims(spec_aug_mask_idxs, -1)
+ spec_aug_mask_idxs = tf.tile(spec_aug_mask_idxs, (1, 1, mask_length))
+ spec_aug_mask_idxs = tf.reshape(spec_aug_mask_idxs, (batch_size, num_masked_spans * mask_length))
+
+ offsets = tf.range(mask_length)[tf.newaxis, tf.newaxis, :]
+ offsets = tf.tile(offsets, (batch_size, num_masked_spans, 1))
+ offsets = tf.reshape(offsets, (batch_size, num_masked_spans * mask_length))
+
+ spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
+
+ # scatter indices to mask
+ spec_aug_mask = _scatter_values_on_batch_indices(
+ tf.ones_like(spec_aug_mask_idxs), spec_aug_mask_idxs, tf.shape(spec_aug_mask)
+ )
+
+ return spec_aug_mask
+
+
+# Copied from transformers.models.bart.modeling_tf_bart._expand_mask
+def _expand_mask(mask: tf.Tensor, tgt_len: int | None = None):
+ """
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
+ """
+ src_len = shape_list(mask)[1]
+ tgt_len = tgt_len if tgt_len is not None else src_len
+ one_cst = tf.constant(1.0)
+ mask = tf.cast(mask, dtype=one_cst.dtype)
+ expanded_mask = tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1))
+
+ return (one_cst - expanded_mask) * LARGE_NEGATIVE
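+
+# Illustrative sketch: a padding mask `[[1, 1, 0]]` becomes an additive bias of shape (1, 1, 3, 3) that is
+# 0 for attended positions and LARGE_NEGATIVE for padded ones, so padded keys are suppressed by the softmax:
+#
+#     _expand_mask(tf.constant([[1.0, 1.0, 0.0]]))
+#     # -> [[[[0., 0., -1e8], [0., 0., -1e8], [0., 0., -1e8]]]]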
+
+
+# Copied from transformers.models.wav2vec2.modeling_tf_wav2vec2.TFWav2Vec2GroupNorm with Wav2Vec2->Hubert
+class TFHubertGroupNorm(keras.layers.Layer):
+ """
+ From tensorflow-addons https://www.tensorflow.org/addons/api_docs/python/tfa/layers/GroupNormalization
+ """
+
+ def __init__(
+ self,
+ groups: int = 32,
+ axis: int = -1,
+ epsilon: float = 1e-3,
+ center: bool = True,
+ scale: bool = True,
+ beta_initializer: keras.initializers.Initializer = "zeros",
+ gamma_initializer: keras.initializers.Initializer = "ones",
+ beta_regularizer: keras.regularizers.Regularizer = None,
+ gamma_regularizer: keras.regularizers.Regularizer = None,
+ beta_constraint: keras.constraints.Constraint = None,
+ gamma_constraint: keras.constraints.Constraint = None,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.supports_masking = True
+ self.groups = groups
+ self.axis = axis
+ self.epsilon = epsilon
+ self.center = center
+ self.scale = scale
+ self.beta_initializer = keras.initializers.get(beta_initializer)
+ self.gamma_initializer = keras.initializers.get(gamma_initializer)
+ self.beta_regularizer = keras.regularizers.get(beta_regularizer)
+ self.gamma_regularizer = keras.regularizers.get(gamma_regularizer)
+ self.beta_constraint = keras.constraints.get(beta_constraint)
+ self.gamma_constraint = keras.constraints.get(gamma_constraint)
+ self._check_axis()
+
+ def build(self, input_shape):
+ self._check_if_input_shape_is_none(input_shape)
+ self._set_number_of_groups_for_instance_norm(input_shape)
+ self._check_size_of_dimensions(input_shape)
+ self._create_input_spec(input_shape)
+
+ self._add_gamma_weight(input_shape)
+ self._add_beta_weight(input_shape)
+ self.built = True
+ super().build(input_shape)
+
+ def call(self, inputs):
+ input_shape = keras.backend.int_shape(inputs)
+ tensor_input_shape = tf.shape(inputs)
+
+ reshaped_inputs, group_shape = self._reshape_into_groups(inputs, input_shape, tensor_input_shape)
+
+ normalized_inputs = self._apply_normalization(reshaped_inputs, input_shape)
+
+ is_instance_norm = (input_shape[self.axis] // self.groups) == 1
+ if not is_instance_norm:
+ outputs = tf.reshape(normalized_inputs, tensor_input_shape)
+ else:
+ outputs = normalized_inputs
+
+ return outputs
+
+ def get_config(self):
+ config = {
+ "groups": self.groups,
+ "axis": self.axis,
+ "epsilon": self.epsilon,
+ "center": self.center,
+ "scale": self.scale,
+ "beta_initializer": keras.initializers.serialize(self.beta_initializer),
+ "gamma_initializer": keras.initializers.serialize(self.gamma_initializer),
+ "beta_regularizer": keras.regularizers.serialize(self.beta_regularizer),
+ "gamma_regularizer": keras.regularizers.serialize(self.gamma_regularizer),
+ "beta_constraint": keras.constraints.serialize(self.beta_constraint),
+ "gamma_constraint": keras.constraints.serialize(self.gamma_constraint),
+ }
+ base_config = super().get_config()
+ return {**base_config, **config}
+
+ def compute_output_shape(self, input_shape):
+ return input_shape
+
+ def _reshape_into_groups(self, inputs, input_shape, tensor_input_shape):
+ group_shape = [tensor_input_shape[i] for i in range(len(input_shape))]
+ is_instance_norm = (input_shape[self.axis] // self.groups) == 1
+ if not is_instance_norm:
+ group_shape[self.axis] = input_shape[self.axis] // self.groups
+ group_shape.insert(self.axis, self.groups)
+ group_shape = tf.stack(group_shape)
+ reshaped_inputs = tf.reshape(inputs, group_shape)
+ return reshaped_inputs, group_shape
+ else:
+ return inputs, group_shape
+
+ def _apply_normalization(self, reshaped_inputs, input_shape):
+ group_shape = keras.backend.int_shape(reshaped_inputs)
+ group_reduction_axes = list(range(1, len(group_shape)))
+ is_instance_norm = (input_shape[self.axis] // self.groups) == 1
+ if not is_instance_norm:
+ axis = -2 if self.axis == -1 else self.axis - 1
+ else:
+ axis = -1 if self.axis == -1 else self.axis - 1
+ group_reduction_axes.pop(axis)
+
+ mean, variance = tf.nn.moments(reshaped_inputs, group_reduction_axes, keepdims=True)
+
+ gamma, beta = self._get_reshaped_weights(input_shape)
+ normalized_inputs = tf.nn.batch_normalization(
+ reshaped_inputs,
+ mean=mean,
+ variance=variance,
+ scale=gamma,
+ offset=beta,
+ variance_epsilon=self.epsilon,
+ )
+ return normalized_inputs
+
+ def _get_reshaped_weights(self, input_shape):
+ broadcast_shape = self._create_broadcast_shape(input_shape)
+ gamma = None
+ beta = None
+ if self.scale:
+ gamma = tf.reshape(self.gamma, broadcast_shape)
+
+ if self.center:
+ beta = tf.reshape(self.beta, broadcast_shape)
+ return gamma, beta
+
+ def _check_if_input_shape_is_none(self, input_shape):
+ dim = input_shape[self.axis]
+ if dim is None:
+ raise ValueError(
+ "Axis "
+ + str(self.axis)
+ + " of input tensor should have a defined dimension but the layer received an input with shape "
+ + str(input_shape)
+ + "."
+ )
+
+ def _set_number_of_groups_for_instance_norm(self, input_shape):
+ dim = input_shape[self.axis]
+
+ if self.groups == -1:
+ self.groups = dim
+
+ def _check_size_of_dimensions(self, input_shape):
+ dim = input_shape[self.axis]
+ if dim < self.groups:
+ raise ValueError(
+ "Number of groups ("
+ + str(self.groups)
+ + ") cannot be more than the number of channels ("
+ + str(dim)
+ + ")."
+ )
+
+ if dim % self.groups != 0:
+ raise ValueError(
+ "Number of groups ("
+ + str(self.groups)
+ + ") must be a multiple of the number of channels ("
+ + str(dim)
+ + ")."
+ )
+
+ def _check_axis(self):
+ if self.axis == 0:
+ raise ValueError(
+ "You are trying to normalize your batch axis. Do you want to use tf.layer.batch_normalization instead"
+ )
+
+ def _create_input_spec(self, input_shape):
+ dim = input_shape[self.axis]
+ self.input_spec = keras.layers.InputSpec(ndim=len(input_shape), axes={self.axis: dim})
+
+ def _add_gamma_weight(self, input_shape):
+ dim = input_shape[self.axis]
+ shape = (dim,)
+
+ if self.scale:
+ self.gamma = self.add_weight(
+ shape=shape,
+ name="gamma",
+ initializer=self.gamma_initializer,
+ regularizer=self.gamma_regularizer,
+ constraint=self.gamma_constraint,
+ )
+ else:
+ self.gamma = None
+
+ def _add_beta_weight(self, input_shape):
+ dim = input_shape[self.axis]
+ shape = (dim,)
+
+ if self.center:
+ self.beta = self.add_weight(
+ shape=shape,
+ name="beta",
+ initializer=self.beta_initializer,
+ regularizer=self.beta_regularizer,
+ constraint=self.beta_constraint,
+ )
+ else:
+ self.beta = None
+
+ def _create_broadcast_shape(self, input_shape):
+ broadcast_shape = [1] * len(input_shape)
+ is_instance_norm = (input_shape[self.axis] // self.groups) == 1
+ if not is_instance_norm:
+ broadcast_shape[self.axis] = input_shape[self.axis] // self.groups
+ broadcast_shape.insert(self.axis, self.groups)
+ else:
+ broadcast_shape[self.axis] = self.groups
+ return broadcast_shape
+
+
+# Copied from transformers.models.wav2vec2.modeling_tf_wav2vec2.TFWav2Vec2WeightNormConv1D with Wav2Vec2->Hubert
+class TFHubertWeightNormConv1D(keras.layers.Conv1D):
+ """Adapted from https://www.tensorflow.org/probability/api_docs/python/tfp/layers/weight_norm/WeightNorm"""
+
+ def __init__(self, filters, kernel_size, groups, explicit_padding, **kwargs):
+ super().__init__(
+ filters=filters,
+ kernel_size=kernel_size,
+ groups=groups,
+ padding="valid",
+ use_bias=True,
+ bias_initializer="he_normal",
+ **kwargs,
+ )
+ self.explicit_padding = explicit_padding
+ self.filter_axis = 2
+ self.kernel_norm_axes = tf.constant([0, 1])
+
+ def _init_norm(self):
+ """Set the norm of the weight vector."""
+ kernel_norm = tf.sqrt(tf.reduce_sum(tf.square(self.weight_v), axis=self.kernel_norm_axes))
+ self.weight_g.assign(kernel_norm[:, tf.newaxis, tf.newaxis])
+
+ def _normalize_kernel(self):
+ """Generate normalized weights."""
+ kernel = tf.nn.l2_normalize(self.weight_v, axis=self.kernel_norm_axes) * tf.transpose(self.weight_g)
+ self.kernel = tf.transpose(kernel)
+
+ def build(self, input_shape):
+ if not self.built:
+ super().build(input_shape)
+
+ self.kernel = tf.Variable(tf.transpose(self.kernel), name="weight_v", trainable=True)
+ self.weight_v = self.kernel
+
+ self.weight_g = self.add_weight(
+ name="weight_g",
+ shape=(int(self.weight_v.shape[self.filter_axis]), 1, 1),
+ initializer="ones",
+ dtype=self.weight_v.dtype,
+ trainable=True,
+ )
+ self._init_norm()
+ self.bias = self.add_weight(name="bias", shape=(self.filters,), initializer="zeros", trainable=True)
+
+ def call(self, inputs):
+ # TODO Matt: Assigning to attributes in call() is deeply sinful in TensorFlow, as it should be idempotent.
+ # This whole layer should be replaced by a layer that doesn't inherit from Conv1D, but instead calls
+ # a functional 1d convolution with normalized weights that it generates (but does not store!)
+ self._normalize_kernel()
+
+ padded_inputs = tf.pad(inputs, ((0, 0), (self.explicit_padding, self.explicit_padding), (0, 0)))
+ output = super().call(padded_inputs)
+
+ return output
+
+
+# Copied from transformers.models.wav2vec2.modeling_tf_wav2vec2.TFWav2Vec2NoLayerNormConvLayer with Wav2Vec2->Hubert
+class TFHubertNoLayerNormConvLayer(keras.layers.Layer):
+ def __init__(self, config: HubertConfig, layer_id: int = 0, **kwargs: Any) -> None:
+ super().__init__(**kwargs)
+ self.in_conv_dim = config.conv_dim[layer_id] if layer_id > 0 else 1
+ self.out_conv_dim = config.conv_dim[layer_id]
+
+ self.conv = keras.layers.Conv1D(
+ filters=self.out_conv_dim,
+ kernel_size=config.conv_kernel[layer_id],
+ strides=config.conv_stride[layer_id],
+ use_bias=config.conv_bias,
+ name="conv",
+ )
+ self.activation = get_tf_activation(config.feat_extract_activation)
+
+ def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
+ hidden_states = self.conv(hidden_states)
+ hidden_states = self.activation(hidden_states)
+ return hidden_states
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "conv", None) is not None:
+ with tf.name_scope(self.conv.name):
+ self.conv.build([None, None, self.in_conv_dim])
+
+
+# Copied from transformers.models.wav2vec2.modeling_tf_wav2vec2.TFWav2Vec2LayerNormConvLayer with Wav2Vec2->Hubert
+class TFHubertLayerNormConvLayer(keras.layers.Layer):
+ def __init__(self, config: HubertConfig, layer_id: int = 0, **kwargs: Any) -> None:
+ super().__init__(**kwargs)
+ self.in_conv_dim = config.conv_dim[layer_id] if layer_id > 0 else 1
+ self.out_conv_dim = config.conv_dim[layer_id]
+
+ self.conv = keras.layers.Conv1D(
+ filters=self.out_conv_dim,
+ kernel_size=config.conv_kernel[layer_id],
+ strides=config.conv_stride[layer_id],
+ use_bias=config.conv_bias,
+ name="conv",
+ )
+ self.layer_norm = keras.layers.LayerNormalization(name="layer_norm", epsilon=config.layer_norm_eps)
+ self.activation = get_tf_activation(config.feat_extract_activation)
+
+ def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
+ hidden_states = self.conv(hidden_states)
+ hidden_states = self.layer_norm(hidden_states)
+ hidden_states = self.activation(hidden_states)
+ return hidden_states
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "conv", None) is not None:
+ with tf.name_scope(self.conv.name):
+ self.conv.build([None, None, self.in_conv_dim])
+ if getattr(self, "layer_norm", None) is not None:
+ with tf.name_scope(self.layer_norm.name):
+ self.layer_norm.build([None, None, self.out_conv_dim])
+
+
+# Copied from transformers.models.wav2vec2.modeling_tf_wav2vec2.TFWav2Vec2GroupNormConvLayer with Wav2Vec2->Hubert
+class TFHubertGroupNormConvLayer(keras.layers.Layer):
+ def __init__(self, config: HubertConfig, layer_id: int = 0, **kwargs: Any) -> None:
+ super().__init__(**kwargs)
+ self.in_conv_dim = config.conv_dim[layer_id] if layer_id > 0 else 1
+ self.out_conv_dim = config.conv_dim[layer_id]
+
+ self.conv = keras.layers.Conv1D(
+ filters=self.out_conv_dim,
+ kernel_size=config.conv_kernel[layer_id],
+ strides=config.conv_stride[layer_id],
+ use_bias=config.conv_bias,
+ name="conv",
+ )
+ self.activation = get_tf_activation(config.feat_extract_activation)
+ self.layer_norm = TFHubertGroupNorm(groups=self.out_conv_dim, epsilon=config.layer_norm_eps, name="layer_norm")
+
+ def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
+ hidden_states = self.conv(hidden_states)
+ hidden_states = self.layer_norm(hidden_states)
+ hidden_states = self.activation(hidden_states)
+ return hidden_states
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "conv", None) is not None:
+ with tf.name_scope(self.conv.name):
+ self.conv.build([None, None, self.in_conv_dim])
+ if getattr(self, "layer_norm", None) is not None:
+ with tf.name_scope(self.layer_norm.name):
+ self.layer_norm.build([None, None, self.out_conv_dim])
+
+
+# Copied from transformers.models.wav2vec2.modeling_tf_wav2vec2.TFWav2Vec2PositionalConvEmbedding with Wav2Vec2->Hubert
+class TFHubertPositionalConvEmbedding(keras.layers.Layer):
+ def __init__(self, config: HubertConfig, **kwargs: Any) -> None:
+ super().__init__(**kwargs)
+ self.conv = TFHubertWeightNormConv1D(
+ filters=config.hidden_size,
+ kernel_size=config.num_conv_pos_embeddings,
+ groups=config.num_conv_pos_embedding_groups,
+ explicit_padding=config.num_conv_pos_embeddings // 2,
+ name="conv",
+ )
+ self.padding = TFHubertSamePadLayer(config.num_conv_pos_embeddings)
+ self.activation = get_tf_activation(config.feat_extract_activation)
+ self.config = config
+
+ def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
+ hidden_states = self.conv(hidden_states)
+ hidden_states = self.padding(hidden_states)
+ hidden_states = self.activation(hidden_states)
+ return hidden_states
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "conv", None) is not None:
+ with tf.name_scope(self.conv.name):
+ self.conv.build([None, None, self.config.hidden_size])
+
+
+# Copied from transformers.models.wav2vec2.modeling_tf_wav2vec2.TFWav2Vec2SamePadLayer with Wav2Vec2->Hubert
+class TFHubertSamePadLayer(keras.layers.Layer):
+ def __init__(self, num_conv_pos_embeddings, **kwargs):
+ super().__init__(**kwargs)
+ self.num_pad_remove = 1 if num_conv_pos_embeddings % 2 == 0 else 0
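+ # The positional conv pads by `num_conv_pos_embeddings // 2` on both sides, so an even kernel size makes
+ # the output one frame longer than the input; that trailing frame is dropped in `call` to keep the
+ # sequence length unchanged.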
+
+ def call(self, hidden_states):
+ if self.num_pad_remove > 0:
+ hidden_states = hidden_states[:, : -self.num_pad_remove, :]
+ return hidden_states
+
+
+class TFHubertFeatureEncoder(keras.layers.Layer):
+ def __init__(self, config: HubertConfig, **kwargs: Any) -> None:
+ super().__init__(**kwargs)
+
+ if config.feat_extract_norm == "group":
+ conv_layers = [TFHubertGroupNormConvLayer(config, layer_id=0, name=f"conv_layers.{0}")] + [
+ TFHubertNoLayerNormConvLayer(config, layer_id=i + 1, name=f"conv_layers.{i + 1}")
+ for i in range(config.num_feat_extract_layers - 1)
+ ]
+ elif config.feat_extract_norm == "layer":
+ conv_layers = [
+ TFHubertLayerNormConvLayer(config, layer_id=i, name=f"conv_layers.{i}")
+ for i in range(config.num_feat_extract_layers)
+ ]
+ else:
+ raise ValueError(
+ f"`config.feat_extract_norm` is {config.feat_extract_norm}, but has to be one of ['group', 'layer']"
+ )
+ self.conv_layers = conv_layers
+
+ def call(self, input_values):
+ hidden_states = tf.expand_dims(input_values, -1)
+ for conv_layer in self.conv_layers:
+ hidden_states = conv_layer(hidden_states)
+ return hidden_states
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ for conv_layer in self.conv_layers:
+ with tf.name_scope(conv_layer.name):
+ conv_layer.build(None)
+
+
+class TFHubertFeatureExtractor(TFHubertFeatureEncoder):
+ def __init__(self, config, **kwargs):
+ super().__init__(config, **kwargs)
+ warnings.warn(
+ f"The class `{self.__class__.__name__}` has been depreciated "
+ "and will be removed in Transformers v5. "
+ f"Use `{self.__class__.__bases__[0].__name__}` instead.",
+ FutureWarning,
+ )
+
+
+class TFHubertFeatureProjection(keras.layers.Layer):
+ def __init__(self, config: HubertConfig, **kwargs):
+ super().__init__(**kwargs)
+
+ self.layer_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm")
+ self.projection = keras.layers.Dense(
+ units=config.hidden_size,
+ kernel_initializer=get_initializer(config.initializer_range),
+ bias_initializer="zeros",
+ name="projection",
+ )
+ self.dropout = keras.layers.Dropout(rate=config.feat_proj_dropout)
+ self.config = config
+
+ def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor:
+ hidden_states = self.layer_norm(hidden_states)
+ hidden_states = self.projection(hidden_states)
+ hidden_states = self.dropout(hidden_states, training=training)
+ return hidden_states
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "layer_norm", None) is not None:
+ with tf.name_scope(self.layer_norm.name):
+ self.layer_norm.build([None, None, self.config.conv_dim[-1]])
+ if getattr(self, "projection", None) is not None:
+ with tf.name_scope(self.projection.name):
+ self.projection.build([None, None, self.config.conv_dim[-1]])
+
+
+# Copied from transformers.models.bart.modeling_tf_bart.TFBartAttention with TFBart->TFHubert
+class TFHubertAttention(keras.layers.Layer):
+ """Multi-headed attention from "Attention Is All You Need"""
+
+ def __init__(
+ self,
+ embed_dim: int,
+ num_heads: int,
+ dropout: float = 0.0,
+ is_decoder: bool = False,
+ bias: bool = True,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.embed_dim = embed_dim
+
+ self.num_heads = num_heads
+ self.dropout = keras.layers.Dropout(dropout)
+ self.head_dim = embed_dim // num_heads
+ if (self.head_dim * num_heads) != self.embed_dim:
+ raise ValueError(
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
+ f" and `num_heads`: {num_heads})."
+ )
+ self.scaling = self.head_dim**-0.5
+ self.is_decoder = is_decoder
+
+ self.k_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="k_proj")
+ self.q_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="q_proj")
+ self.v_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="v_proj")
+ self.out_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="out_proj")
+
+ def _shape(self, tensor: tf.Tensor, seq_len: int, bsz: int):
+ return tf.transpose(tf.reshape(tensor, (bsz, seq_len, self.num_heads, self.head_dim)), (0, 2, 1, 3))
+
+ def call(
+ self,
+ hidden_states: tf.Tensor,
+ key_value_states: tf.Tensor | None = None,
+ past_key_value: tuple[tuple[tf.Tensor]] | None = None,
+ attention_mask: tf.Tensor | None = None,
+ layer_head_mask: tf.Tensor | None = None,
+ training: bool | None = False,
+ ) -> tuple[tf.Tensor, tf.Tensor | None]:
+ """Input shape: Batch x Time x Channel"""
+
+ # if key_value_states are provided this layer is used as a cross-attention layer
+ # for the decoder
+ is_cross_attention = key_value_states is not None
+ bsz, tgt_len, embed_dim = shape_list(hidden_states)
+
+ # get query proj
+ query_states = self.q_proj(hidden_states) * self.scaling
+ # get key, value proj
+ if is_cross_attention and past_key_value is not None:
+ # reuse k,v, cross_attentions
+ key_states = past_key_value[0]
+ value_states = past_key_value[1]
+ elif is_cross_attention:
+ # cross_attentions
+ key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
+ value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
+ elif past_key_value is not None:
+ # reuse k, v, self_attention
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+ key_states = tf.concat([past_key_value[0], key_states], axis=2)
+ value_states = tf.concat([past_key_value[1], value_states], axis=2)
+ else:
+ # self_attention
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+
+ if self.is_decoder:
+ # if cross_attention save Tuple(tf.Tensor, tf.Tensor) of all cross attention key/value_states.
+ # Further calls to cross_attention layer can then reuse all cross-attention
+ # key/value_states (first "if" case)
+ # if uni-directional self-attention (decoder) save Tuple(tf.Tensor, tf.Tensor) of
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
+ past_key_value = (key_states, value_states)
+
+ proj_shape = (bsz * self.num_heads, -1, self.head_dim)
+ query_states = tf.reshape(self._shape(query_states, tgt_len, bsz), proj_shape)
+ key_states = tf.reshape(key_states, proj_shape)
+ value_states = tf.reshape(value_states, proj_shape)
+
+ src_len = shape_list(key_states)[1]
+ attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
+
+ tf.debugging.assert_equal(
+ shape_list(attn_weights),
+ [bsz * self.num_heads, tgt_len, src_len],
+ message=(
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+ f" {shape_list(attn_weights)}"
+ ),
+ )
+
+ if attention_mask is not None:
+ tf.debugging.assert_equal(
+ shape_list(attention_mask),
+ [bsz, 1, tgt_len, src_len],
+ message=(
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
+ f" {shape_list(attention_mask)}"
+ ),
+ )
+
+ attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
+ attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
+ attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
+
+ attn_weights = stable_softmax(attn_weights, axis=-1)
+
+ if layer_head_mask is not None:
+ tf.debugging.assert_equal(
+ shape_list(layer_head_mask),
+ [self.num_heads],
+ message=(
+ f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
+ f" {shape_list(layer_head_mask)}"
+ ),
+ )
+
+ attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
+ attn_weights, (bsz, self.num_heads, tgt_len, src_len)
+ )
+ attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
+
+ attn_probs = self.dropout(attn_weights, training=training)
+ attn_output = tf.matmul(attn_probs, value_states)
+
+ tf.debugging.assert_equal(
+ shape_list(attn_output),
+ [bsz * self.num_heads, tgt_len, self.head_dim],
+ message=(
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+ f" {shape_list(attn_output)}"
+ ),
+ )
+
+ attn_output = tf.transpose(
+ tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
+ )
+ attn_output = tf.reshape(attn_output, (bsz, tgt_len, embed_dim))
+
+ attn_output = self.out_proj(attn_output)
+ attn_weights: tf.Tensor = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len))
+
+ return attn_output, attn_weights, past_key_value
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "k_proj", None) is not None:
+ with tf.name_scope(self.k_proj.name):
+ self.k_proj.build([None, None, self.embed_dim])
+ if getattr(self, "q_proj", None) is not None:
+ with tf.name_scope(self.q_proj.name):
+ self.q_proj.build([None, None, self.embed_dim])
+ if getattr(self, "v_proj", None) is not None:
+ with tf.name_scope(self.v_proj.name):
+ self.v_proj.build([None, None, self.embed_dim])
+ if getattr(self, "out_proj", None) is not None:
+ with tf.name_scope(self.out_proj.name):
+ self.out_proj.build([None, None, self.embed_dim])
+
+
+# Copied from transformers.models.wav2vec2.modeling_tf_wav2vec2.TFWav2Vec2FeedForward with Wav2Vec2->Hubert
+class TFHubertFeedForward(keras.layers.Layer):
+ def __init__(self, config: HubertConfig, **kwargs):
+ super().__init__(**kwargs)
+
+ self.intermediate_dropout = keras.layers.Dropout(config.activation_dropout)
+
+ self.intermediate_dense = keras.layers.Dense(
+ units=config.intermediate_size,
+ kernel_initializer=get_initializer(config.initializer_range),
+ bias_initializer="zeros",
+ name="intermediate_dense",
+ )
+ self.intermediate_act_fn = get_tf_activation(config.hidden_act)
+
+ self.output_dense = keras.layers.Dense(
+ units=config.hidden_size,
+ kernel_initializer=get_initializer(config.initializer_range),
+ bias_initializer="zeros",
+ name="output_dense",
+ )
+ self.output_dropout = keras.layers.Dropout(config.hidden_dropout)
+ self.config = config
+
+ def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor:
+ hidden_states = self.intermediate_dense(hidden_states)
+ hidden_states = self.intermediate_act_fn(hidden_states)
+ hidden_states = self.intermediate_dropout(hidden_states, training=training)
+
+ hidden_states = self.output_dense(hidden_states)
+ hidden_states = self.output_dropout(hidden_states, training=training)
+ return hidden_states
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "intermediate_dense", None) is not None:
+ with tf.name_scope(self.intermediate_dense.name):
+ self.intermediate_dense.build([None, None, self.config.hidden_size])
+ if getattr(self, "output_dense", None) is not None:
+ with tf.name_scope(self.output_dense.name):
+ self.output_dense.build([None, None, self.config.intermediate_size])
+
+
+# Copied from transformers.models.wav2vec2.modeling_tf_wav2vec2.TFWav2Vec2EncoderLayer with Wav2Vec2->Hubert
+class TFHubertEncoderLayer(keras.layers.Layer):
+ def __init__(self, config: HubertConfig, **kwargs):
+ super().__init__(**kwargs)
+ self.attention = TFHubertAttention(
+ embed_dim=config.hidden_size,
+ num_heads=config.num_attention_heads,
+ dropout=config.attention_dropout,
+ is_decoder=False,
+ name="attention",
+ )
+ self.dropout = keras.layers.Dropout(config.hidden_dropout)
+ self.layer_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm")
+ self.feed_forward = TFHubertFeedForward(config, name="feed_forward")
+ self.final_layer_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="final_layer_norm")
+ self.config = config
+
+ def call(
+ self,
+ hidden_states: tf.Tensor,
+ attention_mask: tf.Tensor | None = None,
+ output_attentions: bool | None = False,
+ training: bool = False,
+ ) -> tuple[tf.Tensor]:
+ attn_residual = hidden_states
+ hidden_states, attn_weights, _ = self.attention(
+ hidden_states, attention_mask=attention_mask, training=training
+ )
+ hidden_states = self.dropout(hidden_states, training=training)
+ hidden_states = attn_residual + hidden_states
+
+ hidden_states = self.layer_norm(hidden_states)
+ hidden_states = hidden_states + self.feed_forward(hidden_states)
+ hidden_states = self.final_layer_norm(hidden_states)
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (attn_weights,)
+
+ return outputs
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "attention", None) is not None:
+ with tf.name_scope(self.attention.name):
+ self.attention.build(None)
+ if getattr(self, "layer_norm", None) is not None:
+ with tf.name_scope(self.layer_norm.name):
+ self.layer_norm.build([None, None, self.config.hidden_size])
+ if getattr(self, "feed_forward", None) is not None:
+ with tf.name_scope(self.feed_forward.name):
+ self.feed_forward.build(None)
+ if getattr(self, "final_layer_norm", None) is not None:
+ with tf.name_scope(self.final_layer_norm.name):
+ self.final_layer_norm.build([None, None, self.config.hidden_size])
+
+
+# Copied from transformers.models.wav2vec2.modeling_tf_wav2vec2.TFWav2Vec2EncoderLayerStableLayerNorm with Wav2Vec2->Hubert
+class TFHubertEncoderLayerStableLayerNorm(keras.layers.Layer):
+ def __init__(self, config: HubertConfig, **kwargs):
+ super().__init__(**kwargs)
+ self.attention = TFHubertAttention(
+ embed_dim=config.hidden_size,
+ num_heads=config.num_attention_heads,
+ dropout=config.attention_dropout,
+ is_decoder=False,
+ name="attention",
+ )
+ self.dropout = keras.layers.Dropout(config.hidden_dropout)
+ self.layer_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm")
+ self.feed_forward = TFHubertFeedForward(config, name="feed_forward")
+ self.final_layer_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="final_layer_norm")
+ self.config = config
+
+ def call(
+ self,
+ hidden_states: tf.Tensor,
+ attention_mask: tf.Tensor | None = None,
+ output_attentions: bool | None = False,
+ training: bool = False,
+ ) -> tuple[tf.Tensor]:
+ attn_residual = hidden_states
+ hidden_states = self.layer_norm(hidden_states)
+ hidden_states, attn_weights, _ = self.attention(
+ hidden_states, attention_mask=attention_mask, training=training
+ )
+ hidden_states = self.dropout(hidden_states, training=training)
+ hidden_states = attn_residual + hidden_states
+ hidden_states = hidden_states + self.feed_forward(self.final_layer_norm(hidden_states))
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (attn_weights,)
+
+ return outputs
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "attention", None) is not None:
+ with tf.name_scope(self.attention.name):
+ self.attention.build(None)
+ if getattr(self, "layer_norm", None) is not None:
+ with tf.name_scope(self.layer_norm.name):
+ self.layer_norm.build([None, None, self.config.hidden_size])
+ if getattr(self, "feed_forward", None) is not None:
+ with tf.name_scope(self.feed_forward.name):
+ self.feed_forward.build(None)
+ if getattr(self, "final_layer_norm", None) is not None:
+ with tf.name_scope(self.final_layer_norm.name):
+ self.final_layer_norm.build([None, None, self.config.hidden_size])
+
+
+# Copied from transformers.models.wav2vec2.modeling_tf_wav2vec2.TFWav2Vec2Encoder with Wav2Vec2->Hubert
+class TFHubertEncoder(keras.layers.Layer):
+ def __init__(self, config: HubertConfig, **kwargs):
+ super().__init__(**kwargs)
+ self.config = config
+ self.pos_conv_embed = TFHubertPositionalConvEmbedding(config, name="pos_conv_embed")
+ self.layer_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm")
+ self.dropout = keras.layers.Dropout(config.hidden_dropout)
+ self.layer = [TFHubertEncoderLayer(config, name=f"layers.{i}") for i in range(config.num_hidden_layers)]
+
+ def call(
+ self,
+ hidden_states: tf.Tensor,
+ attention_mask: tf.Tensor | None = None,
+ output_attentions: bool | None = False,
+ output_hidden_states: bool | None = False,
+ return_dict: bool | None = True,
+ training: bool | None = False,
+ ) -> TFBaseModelOutput | tuple[tf.Tensor]:
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attentions = () if output_attentions else None
+
+ if attention_mask is not None:
+ hidden_states = hidden_states * tf.expand_dims(attention_mask, -1)
+ attention_mask = _expand_mask(attention_mask)
+ else:
+ attention_mask = None
+
+ position_embeddings = self.pos_conv_embed(hidden_states)
+ hidden_states = hidden_states + position_embeddings
+ hidden_states = self.layer_norm(hidden_states)
+ hidden_states = self.dropout(hidden_states, training=training)
+
+ for i, layer_module in enumerate(self.layer):
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
+ dropout_probability = np.random.uniform(0, 1)
+ if training and (dropout_probability < self.config.layerdrop): # skip the layer
+ continue
+
+ layer_outputs = layer_module(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ training=training,
+ )
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
+
+ # Add last layer
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
+ return TFBaseModelOutput(
+ last_hidden_state=hidden_states,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "pos_conv_embed", None) is not None:
+ with tf.name_scope(self.pos_conv_embed.name):
+ self.pos_conv_embed.build(None)
+ if getattr(self, "layer_norm", None) is not None:
+ with tf.name_scope(self.layer_norm.name):
+ self.layer_norm.build([None, None, self.config.hidden_size])
+ if getattr(self, "layer", None) is not None:
+ for layer in self.layer:
+ with tf.name_scope(layer.name):
+ layer.build(None)
+
+
+# Copied from transformers.models.wav2vec2.modeling_tf_wav2vec2.TFWav2Vec2EncoderStableLayerNorm with Wav2Vec2->Hubert
+class TFHubertEncoderStableLayerNorm(keras.layers.Layer):
+ def __init__(self, config: HubertConfig, **kwargs):
+ super().__init__(**kwargs)
+ self.config = config
+ self.pos_conv_embed = TFHubertPositionalConvEmbedding(config, name="pos_conv_embed")
+ self.layer_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm")
+ self.dropout = keras.layers.Dropout(config.hidden_dropout)
+ self.layer = [
+ TFHubertEncoderLayerStableLayerNorm(config, name=f"layers.{i}") for i in range(config.num_hidden_layers)
+ ]
+
+ def call(
+ self,
+ hidden_states: tf.Tensor,
+ attention_mask: tf.Tensor | None = None,
+ output_attentions: bool | None = False,
+ output_hidden_states: bool | None = False,
+ return_dict: bool | None = True,
+ training: bool | None = False,
+ ) -> TFBaseModelOutput | tuple[tf.Tensor]:
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attentions = () if output_attentions else None
+
+ if attention_mask is not None:
+ hidden_states = hidden_states * tf.expand_dims(attention_mask, -1)
+ attention_mask = _expand_mask(attention_mask)
+ else:
+ attention_mask = None
+
+ position_embeddings = self.pos_conv_embed(hidden_states)
+ hidden_states = hidden_states + position_embeddings
+ hidden_states = self.dropout(hidden_states, training=training)
+
+ for i, layer_module in enumerate(self.layer):
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
+ dropout_probability = np.random.uniform(0, 1)
+ if training and (dropout_probability < self.config.layerdrop): # skip the layer
+ continue
+
+ layer_outputs = layer_module(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ training=training,
+ )
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
+
+ hidden_states = self.layer_norm(hidden_states)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
+ return TFBaseModelOutput(
+ last_hidden_state=hidden_states,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "pos_conv_embed", None) is not None:
+ with tf.name_scope(self.pos_conv_embed.name):
+ self.pos_conv_embed.build(None)
+ if getattr(self, "layer_norm", None) is not None:
+ with tf.name_scope(self.layer_norm.name):
+ self.layer_norm.build([None, None, self.config.hidden_size])
+ if getattr(self, "layer", None) is not None:
+ for layer in self.layer:
+ with tf.name_scope(layer.name):
+ layer.build(None)
+
+
+@keras_serializable
+class TFHubertMainLayer(keras.layers.Layer):
+ config_class = HubertConfig
+
+ def __init__(self, config: HubertConfig, **kwargs):
+ super().__init__(**kwargs)
+ self.config = config
+ self.feature_extractor = TFHubertFeatureEncoder(config, name="feature_extractor")
+ self.feature_projection = TFHubertFeatureProjection(config, name="feature_projection")
+
+ if config.do_stable_layer_norm:
+ self.encoder = TFHubertEncoderStableLayerNorm(config, name="encoder")
+ else:
+ self.encoder = TFHubertEncoder(config, name="encoder")
+
+ def build(self, input_shape=None):
+ self.masked_spec_embed = self.add_weight(
+ shape=(self.config.hidden_size,), initializer="uniform", trainable=True, name="masked_spec_embed"
+ )
+
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "feature_extractor", None) is not None:
+ with tf.name_scope(self.feature_extractor.name):
+ self.feature_extractor.build(None)
+ if getattr(self, "feature_projection", None) is not None:
+ with tf.name_scope(self.feature_projection.name):
+ self.feature_projection.build(None)
+ if getattr(self, "encoder", None) is not None:
+ with tf.name_scope(self.encoder.name):
+ self.encoder.build(None)
+
+ def _get_feat_extract_output_lengths(self, input_lengths: tf.Tensor):
+ """
+ Computes the output length of the convolutional layers
+ """
+
+ def _conv_out_length(input_length, kernel_size, stride):
+ # 1D convolutional layer output length formula taken
+ # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
+ return (input_length - kernel_size) // stride + 1
+
+ for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
+ input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
+
+ return input_lengths
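+
+ # Worked example (illustrative, assuming the default HubertConfig with conv_kernel=(10, 3, 3, 3, 3, 2, 2)
+ # and conv_stride=(5, 2, 2, 2, 2, 2, 2)): one second of 16 kHz audio maps to 49 feature frames,
+ # each convolution shrinking the length via (length - kernel_size) // stride + 1:
+ #
+ #     lengths = tf.constant([16000])
+ #     frames = self._get_feat_extract_output_lengths(lengths)
+ #     # 16000 -> 3199 -> 1599 -> 799 -> 399 -> 199 -> 99 -> 49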
+
+ def _mask_hidden_states(self, hidden_states: tf.Tensor, mask_time_indices: tf.Tensor | None = None):
+ """
+ Masks extracted features along time axis and/or along feature axis according to
+ [SpecAugment](https://huggingface.co/papers/1904.08779).
+ """
+ batch_size, sequence_length, hidden_size = shape_list(hidden_states)
+
+ # `config.apply_spec_augment` can set masking to False
+ if not getattr(self.config, "apply_spec_augment", True):
+ return hidden_states
+
+ if mask_time_indices is not None:
+ # apply SpecAugment along time axis with given mask_time_indices
+ hidden_states = tf.where(
+ tf.cast(mask_time_indices[:, :, tf.newaxis], tf.bool),
+ self.masked_spec_embed[tf.newaxis, tf.newaxis, :],
+ hidden_states,
+ )
+
+ elif self.config.mask_time_prob > 0:
+ # generate indices & apply SpecAugment along time axis
+ mask_time_indices = _compute_mask_indices(
+ (batch_size, sequence_length),
+ mask_prob=self.config.mask_time_prob,
+ mask_length=self.config.mask_time_length,
+ min_masks=2,
+ )
+ hidden_states = tf.where(
+ tf.cast(mask_time_indices[:, :, tf.newaxis], tf.bool),
+ self.masked_spec_embed[tf.newaxis, tf.newaxis, :],
+ hidden_states,
+ )
+
+ # apply SpecAugment along feature axis
+ if self.config.mask_feature_prob > 0:
+ mask_feature_indices = _compute_mask_indices(
+ (batch_size, hidden_size),
+ mask_prob=self.config.mask_feature_prob,
+ mask_length=self.config.mask_feature_length,
+ )
+ hidden_states = tf.where(mask_feature_indices[:, tf.newaxis, :], hidden_states, 0)
+
+ return hidden_states
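+
+ # Usage sketch (illustrative; `layer` stands for a built TFHubertMainLayer and `hidden_states` has shape
+ # (1, 4, hidden_size)): frames flagged in `mask_time_indices` are replaced by the learned
+ # `masked_spec_embed` vector rather than zeroed out:
+ #
+ #     mask = tf.constant([[1, 1, 0, 0]])
+ #     masked = layer._mask_hidden_states(hidden_states, mask_time_indices=mask)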
+
+ @unpack_inputs
+ def call(
+ self,
+ input_values: tf.Tensor,
+ attention_mask: tf.Tensor | None = None,
+ token_type_ids: tf.Tensor | None = None,
+ position_ids: tf.Tensor | None = None,
+ head_mask: tf.Tensor | None = None,
+ inputs_embeds: tf.Tensor | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ training: bool = False,
+ **kwargs: Any,
+ ):
+ hidden_states = self.feature_extractor(tf.cast(input_values, tf.float32), training=training)
+
+ if attention_mask is not None:
+ # compute real output lengths according to convolution formula
+ output_lengths = self._get_feat_extract_output_lengths(tf.reduce_sum(attention_mask, -1))
+
+ attention_mask = tf.sequence_mask(
+ output_lengths, maxlen=shape_list(hidden_states)[1], dtype=hidden_states.dtype
+ )
+
+ hidden_states = self.feature_projection(hidden_states, training=training)
+
+ mask_time_indices = kwargs.get("mask_time_indices")
+ if training:
+ hidden_states = self._mask_hidden_states(hidden_states, mask_time_indices=mask_time_indices)
+
+ encoder_outputs = self.encoder(
+ hidden_states,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+ hidden_states = encoder_outputs[0]
+
+ if not return_dict:
+ return (hidden_states,) + encoder_outputs[1:]
+
+ return TFBaseModelOutput(
+ last_hidden_state=hidden_states,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
+
+
+class TFHubertPreTrainedModel(TFPreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = HubertConfig
+ base_model_prefix = "hubert"
+ main_input_name = "input_values"
+
+ @property
+ def input_signature(self):
+ return {
+ "input_values": tf.TensorSpec((None, 16000), tf.float32, name="input_values"),
+ "attention_mask": tf.TensorSpec((None, None), tf.int32, name="attention_mask"),
+ "token_type_ids": tf.TensorSpec((None, None), tf.int32, name="token_type_ids"),
+ }
+
+ def __init__(self, config, *inputs, **kwargs):
+ super().__init__(config, *inputs, **kwargs)
+ logger.warning(
+ f"\n{self.__class__.__name__} has backpropagation operations that are NOT supported on CPU. If you wish "
+ "to train/fine-tune this model, you need a GPU or a TPU"
+ )
+
+
+HUBERT_START_DOCSTRING = r"""
+
+ This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
+ etc.)
+
+ This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
+ as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and
+ behavior.
+
+
+
+ TensorFlow models and layers in `transformers` accept two formats as input:
+
+ - having all inputs as keyword arguments (like PyTorch models), or
+ - having all inputs as a list, tuple or dict in the first positional argument.
+
+ The reason the second format is supported is that Keras methods prefer this format when passing inputs to models
+ and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just
+ pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second
+ format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with
+ the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first
+ positional argument:
+
+ - a single Tensor with `input_values` only and nothing else: `model(input_values)`
+ - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
+ `model([input_values, attention_mask])` or `model([input_values, attention_mask, token_type_ids])`
+ - a dictionary with one or several input Tensors associated to the input names given in the docstring:
+ `model({"input_values": input_values, "token_type_ids": token_type_ids})`
+
+ Note that when creating models and layers with
+ [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry
+ about any of this, as you can just pass inputs like you would to any other Python function!
+
+
+
+ Args:
+ config ([`HubertConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
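+# Calling-convention sketch (illustrative; `model`, `input_values` and `attention_mask` are assumed to be an
+# instantiated TFHubertModel and precomputed tensors). The three equivalent formats described above are:
+#
+#     outputs = model(input_values)
+#     outputs = model([input_values, attention_mask])
+#     outputs = model({"input_values": input_values, "attention_mask": attention_mask})
+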
+HUBERT_INPUTS_DOCSTRING = r"""
+ Args:
+ input_values (`np.ndarray`, `tf.Tensor`, `list[tf.Tensor]`, `dict[str, tf.Tensor]` or `dict[str, np.ndarray]`, and each example must have the shape `({0})`):
+ Float values of the raw speech waveform. Values can be obtained by loading an audio file into an
+ array and processing it with [`AutoProcessor`], as shown in the usage examples below.
+ attention_mask (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ token_type_ids (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*):
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
+ 1]`:
+
+ - 0 corresponds to a *sentence A* token,
+ - 1 corresponds to a *sentence B* token.
+
+ [What are token type IDs?](../glossary#token-type-ids)
+ position_ids (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+ config.max_position_embeddings - 1]`.
+
+ [What are position IDs?](../glossary#position-ids)
+ head_mask (`np.ndarray` or `tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ inputs_embeds (`np.ndarray` or `tf.Tensor` of shape `({0}, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_values` you can choose to directly pass an embedded representation.
+ This is useful if you want more control over how to convert `input_values` indices into associated vectors
+ than the model's internal embedding lookup matrix.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the
+ config will be used instead.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail. This argument can be used only in eager mode, in graph mode the value in the config will be
+ used instead.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in
+ eager mode, in graph mode the value will always be set to True.
+ training (`bool`, *optional*, defaults to `False`):
+ Whether or not to use the model in training mode (some modules like dropout modules have different
+ behaviors between training and evaluation).
+"""
+
+
+@add_start_docstrings(
+ "The bare TFHubert Model transformer outputting raw hidden-states without any specific head on top.",
+ HUBERT_START_DOCSTRING,
+)
+class TFHubertModel(TFHubertPreTrainedModel):
+ def __init__(self, config: HubertConfig, *inputs, **kwargs):
+ super().__init__(config, *inputs, **kwargs)
+ self.config = config
+ self.hubert = TFHubertMainLayer(config, name="hubert")
+
+ @add_start_docstrings_to_model_forward(HUBERT_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=TFBaseModelOutput, config_class=_CONFIG_FOR_DOC)
+ @unpack_inputs
+ def call(
+ self,
+ input_values: tf.Tensor,
+ attention_mask: tf.Tensor | None = None,
+ token_type_ids: tf.Tensor | None = None,
+ position_ids: tf.Tensor | None = None,
+ head_mask: tf.Tensor | None = None,
+ inputs_embeds: tf.Tensor | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ training: bool = False,
+ ) -> TFBaseModelOutput | tuple[tf.Tensor]:
+ """
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoProcessor, TFHubertModel
+ >>> from datasets import load_dataset
+
+ >>> processor = AutoProcessor.from_pretrained("facebook/hubert-large-ls960-ft")
+ >>> model = TFHubertModel.from_pretrained("facebook/hubert-large-ls960-ft")
+
+
+ >>> def map_to_array(example):
+ ... example["speech"] = example["audio"]["array"]
+ ... return example
+
+
+ >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+ >>> ds = ds.map(map_to_array)
+
+ >>> input_values = processor(ds["speech"][0], return_tensors="tf").input_values # Batch size 1
+ >>> hidden_states = model(input_values).last_hidden_state
+ ```"""
+
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
+
+ outputs = self.hubert(
+ input_values=input_values,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+
+ return outputs
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "hubert", None) is not None:
+ with tf.name_scope(self.hubert.name):
+ self.hubert.build(None)
+
+
+@add_start_docstrings(
+ """TFHubert Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""",
+ HUBERT_START_DOCSTRING,
+)
+class TFHubertForCTC(TFHubertPreTrainedModel):
+ def __init__(self, config: HubertConfig, *inputs, **kwargs):
+ super().__init__(config, *inputs, **kwargs)
+
+ self.hubert = TFHubertMainLayer(config, name="hubert")
+ self.dropout = keras.layers.Dropout(config.final_dropout)
+ self.lm_head = keras.layers.Dense(config.vocab_size, name="lm_head")
+ self.output_hidden_size = (
+ config.output_hidden_size if hasattr(config, "add_adapter") and config.add_adapter else config.hidden_size
+ )
+
+ def freeze_feature_extractor(self):
+ """
+ Calling this function will disable the gradient computation for the feature encoder so that its parameters will
+ not be updated during training.
+ """
+ warnings.warn(
+ "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
+ "Please use the equivalent `freeze_feature_encoder` method instead.",
+ FutureWarning,
+ )
+ self.freeze_feature_encoder()
+
+ def freeze_feature_encoder(self):
+ """
+ Calling this function will disable the gradient computation for the feature encoder so that its parameters will
+ not be updated during training.
+ """
+ self.hubert.feature_extractor.trainable = False
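+
+ # Fine-tuning sketch (illustrative; the checkpoint name is only an example):
+ #
+ #     model = TFHubertForCTC.from_pretrained("facebook/hubert-large-ls960-ft")
+ #     model.freeze_feature_encoder()
+ #     # from here on, only the transformer encoder and lm_head receive gradient updates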
+
+ @add_start_docstrings_to_model_forward(HUBERT_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=TFCausalLMOutput, config_class=_CONFIG_FOR_DOC)
+ @unpack_inputs
+ def call(
+ self,
+ input_values: tf.Tensor,
+ attention_mask: tf.Tensor | None = None,
+ token_type_ids: tf.Tensor | None = None,
+ position_ids: tf.Tensor | None = None,
+ head_mask: tf.Tensor | None = None,
+ inputs_embeds: tf.Tensor | None = None,
+ output_attentions: bool | None = None,
+ labels: tf.Tensor | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ training: bool | None = False,
+ ) -> TFCausalLMOutput | tuple[tf.Tensor]:
+ r"""
+ labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
+ config.vocab_size]` (see `input_values` docstring) Tokens with indices set to `-100` are ignored (masked),
+ the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> import tensorflow as tf
+ >>> from transformers import AutoProcessor, TFHubertForCTC
+ >>> from datasets import load_dataset
+
+ >>> processor = AutoProcessor.from_pretrained("facebook/hubert-large-ls960-ft")
+ >>> model = TFHubertForCTC.from_pretrained("facebook/hubert-large-ls960-ft")
+
+
+ >>> def map_to_array(example):
+ ... example["speech"] = example["audio"]["array"]
+ ... return example
+
+
+ >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+ >>> ds = ds.map(map_to_array)
+
+ >>> input_values = processor(ds["speech"][0], return_tensors="tf").input_values # Batch size 1
+ >>> logits = model(input_values).logits
+ >>> predicted_ids = tf.argmax(logits, axis=-1)
+
+ >>> transcription = processor.decode(predicted_ids[0])
+
+ >>> # compute loss
+ >>> target_transcription = "A MAN SAID TO THE UNIVERSE SIR I EXIST"
+
+ >>> # Pass the target transcription as text to encode labels
+ >>> labels = processor(text=target_transcription, return_tensors="tf").input_ids
+
+ >>> loss = model(input_values, labels=labels).loss
+ ```"""
+ if labels is not None and tf.reduce_max(labels) >= self.config.vocab_size:
+ raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")
+
+ outputs = self.hubert(
+ input_values=input_values,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+ hidden_states = outputs[0]
+ hidden_states = self.dropout(hidden_states, training=training)
+
+ logits = self.lm_head(hidden_states)
+
+ if labels is not None:
+ attention_mask = (
+ attention_mask if attention_mask is not None else tf.ones_like(input_values, dtype=tf.float32)
+ )
+ input_lengths = self.hubert._get_feat_extract_output_lengths(tf.reduce_sum(attention_mask, axis=-1))
+
+ # assuming that padded tokens are filled with -100
+ # when not being attended to
+ labels_mask = tf.cast(labels >= 0, tf.int32)
+ target_lengths = tf.reduce_sum(labels_mask, axis=-1)
+
+ loss = tf.nn.ctc_loss(
+ logits=logits,
+ labels=labels,
+ logit_length=input_lengths,
+ label_length=target_lengths,
+ blank_index=self.config.pad_token_id,
+ logits_time_major=False,
+ )
+
+ if self.config.ctc_loss_reduction == "sum":
+ loss = tf.reduce_sum(loss)
+ loss = tf.reshape(loss, (1,))
+ if self.config.ctc_loss_reduction == "mean":
+ loss = tf.reduce_mean(loss)
+ loss = tf.reshape(loss, (1,))
+ else:
+ loss = None
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return TFCausalLMOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "hubert", None) is not None:
+ with tf.name_scope(self.hubert.name):
+ self.hubert.build(None)
+ if getattr(self, "lm_head", None) is not None:
+ with tf.name_scope(self.lm_head.name):
+ self.lm_head.build([None, None, self.output_hidden_size])
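+
+ # Label padding sketch (illustrative): shorter transcriptions in a batch are padded with -100, which the
+ # CTC loss above excludes when computing `target_lengths`:
+ #
+ #     labels = tf.constant([[12, 7, 7, 30, -100, -100],
+ #                           [5, 19, 2, 8, 11, 4]])
+ #     # labels_mask -> [[1, 1, 1, 1, 0, 0], [1, 1, 1, 1, 1, 1]], target_lengths -> [4, 6]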
+
+
+__all__ = ["TFHubertForCTC", "TFHubertModel", "TFHubertPreTrainedModel"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/hubert/modular_hubert.py b/venv/lib/python3.13/site-packages/transformers/models/hubert/modular_hubert.py
new file mode 100644
index 0000000000000000000000000000000000000000..facebcf445e6bdfd75c7880aefad84e136f4da88
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/hubert/modular_hubert.py
@@ -0,0 +1,302 @@
+# coding=utf-8
+# Copyright 2021 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch Hubert model."""
+
+from typing import Optional, Union
+
+import torch
+import torch.nn as nn
+
+from ...activations import ACT2FN
+from ...integrations.deepspeed import is_deepspeed_zero3_enabled
+from ...modeling_outputs import BaseModelOutput
+from ...modeling_utils import PreTrainedModel
+from ...utils import auto_docstring
+from ..wav2vec2.modeling_wav2vec2 import (
+ Wav2Vec2Encoder,
+ Wav2Vec2EncoderStableLayerNorm,
+ Wav2Vec2FeatureEncoder,
+ Wav2Vec2ForCTC,
+ Wav2Vec2ForSequenceClassification,
+ Wav2Vec2Model,
+ Wav2Vec2SamePadLayer,
+)
+from .configuration_hubert import HubertConfig
+
+
+_HIDDEN_STATES_START_POSITION = 1
+
+
+class HubertPositionalConvEmbedding(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.conv = nn.Conv1d(
+ config.hidden_size,
+ config.hidden_size,
+ kernel_size=config.num_conv_pos_embeddings,
+ padding=config.num_conv_pos_embeddings // 2,
+ groups=config.num_conv_pos_embedding_groups,
+ )
+
+ self.batch_norm = None
+ if config.conv_pos_batch_norm:
+ self.batch_norm = nn.BatchNorm1d(config.hidden_size)
+ else:
+ weight_norm = nn.utils.weight_norm
+ if hasattr(nn.utils.parametrizations, "weight_norm"):
+ weight_norm = nn.utils.parametrizations.weight_norm
+
+ if is_deepspeed_zero3_enabled():
+ import deepspeed
+
+ with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
+ self.conv = weight_norm(self.conv, name="weight", dim=2)
+ if hasattr(self.conv, "parametrizations"):
+ weight_g = self.conv.parametrizations.weight.original0
+ weight_v = self.conv.parametrizations.weight.original1
+ else:
+ weight_g = self.conv.weight_g
+ weight_v = self.conv.weight_v
+ deepspeed.zero.register_external_parameter(self, weight_v)
+ deepspeed.zero.register_external_parameter(self, weight_g)
+ else:
+ self.conv = weight_norm(self.conv, name="weight", dim=2)
+
+ self.padding = HubertSamePadLayer(config.num_conv_pos_embeddings)
+ self.activation = ACT2FN[config.feat_extract_activation]
+
+ def forward(self, hidden_states):
+ hidden_states = hidden_states.transpose(1, 2)
+ if self.batch_norm is not None:
+ hidden_states = self.batch_norm(hidden_states)
+ hidden_states = self.conv(hidden_states)
+ hidden_states = self.padding(hidden_states)
+ hidden_states = self.activation(hidden_states)
+
+ hidden_states = hidden_states.transpose(1, 2)
+ return hidden_states
+
+
+class HubertSamePadLayer(Wav2Vec2SamePadLayer):
+ pass
+
+
+class HubertFeatureEncoder(Wav2Vec2FeatureEncoder):
+ pass
+
+
+class HubertFeatureProjection(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.feat_proj_layer_norm = config.feat_proj_layer_norm
+ if self.feat_proj_layer_norm:
+ self.layer_norm = nn.LayerNorm(config.conv_dim[-1], eps=config.layer_norm_eps)
+ self.projection = nn.Linear(config.conv_dim[-1], config.hidden_size)
+ self.dropout = nn.Dropout(config.feat_proj_dropout)
+
+ def forward(self, hidden_states):
+ # non-projected hidden states are needed for quantization
+ if self.feat_proj_layer_norm:
+ hidden_states = self.layer_norm(hidden_states)
+ hidden_states = self.projection(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ return hidden_states
+
+
+class HubertEncoder(Wav2Vec2Encoder):
+ pass
+
+
+class HubertEncoderStableLayerNorm(Wav2Vec2EncoderStableLayerNorm):
+ pass
+
+
+@auto_docstring
+class HubertPreTrainedModel(PreTrainedModel):
+ config: HubertConfig
+ base_model_prefix = "hubert"
+ main_input_name = "input_values"
+ supports_gradient_checkpointing = True
+ _supports_flash_attn = True
+ _supports_sdpa = True
+ _supports_flex_attn = True
+
+ def _init_weights(self, module):
+ """Initialize the weights"""
+ if isinstance(module, nn.Linear):
+ # Slightly different from the TF version which uses truncated_normal for initialization
+ # cf https://github.com/pytorch/pytorch/pull/5617
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm1d)):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+ elif isinstance(module, nn.Conv1d):
+ if is_deepspeed_zero3_enabled():
+ import deepspeed
+
+ if hasattr(module, "weight_v") and hasattr(module, "weight_g"):
+ with deepspeed.zero.GatheredParameters([module.weight_v, module.weight_g], modifier_rank=0):
+ nn.init.kaiming_normal_(module.weight.data)
+ else:
+ with deepspeed.zero.GatheredParameters(module.weight, modifier_rank=0):
+ nn.init.kaiming_normal_(module.weight.data)
+ else:
+ nn.init.kaiming_normal_(module.weight.data)
+
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, HubertModel):
+ if hasattr(module, "masked_spec_embed"):
+ module.masked_spec_embed.data.uniform_()
+ elif isinstance(module, HubertForSequenceClassification):
+ if hasattr(module, "layer_weights"):
+ module.layer_weights.data.fill_(1.0 / (self.config.num_hidden_layers + 1))
+
+ def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]):
+ """
+ Computes the output length of the convolutional layers
+ """
+
+ def _conv_out_length(input_length, kernel_size, stride):
+ # 1D convolutional layer output length formula taken
+ # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
+ return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1
+
+ for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
+ input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
+
+ return input_lengths
+
+ def _get_feature_vector_attention_mask(self, feature_vector_length: int, attention_mask: torch.LongTensor):
+ output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
+ batch_size = attention_mask.shape[0]
+
+ attention_mask = torch.zeros(
+ (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
+ )
+ # these two operations makes sure that all values before the output lengths idxs are attended to
+ attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1
+ attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
+ return attention_mask
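+
+ # Worked example (illustrative): with feature_vector_length=5 and output_lengths=[2, 4], the
+ # scatter + flip/cumsum/flip above yields
+ #
+ #     [[1, 1, 0, 0, 0],
+ #      [1, 1, 1, 1, 0]]
+ #
+ # i.e. every frame up to and including the last valid one is attended to.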
+
+
+class HubertModel(Wav2Vec2Model, HubertPreTrainedModel):
+ def __init__(self, config: HubertConfig):
+ super().__init__(config)
+ self.config = config
+ self.feature_extractor = HubertFeatureEncoder(config)
+ self.feature_projection = HubertFeatureProjection(config)
+
+ if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
+ self.masked_spec_embed = nn.Parameter(torch.Tensor(config.hidden_size).uniform_())
+
+ if config.do_stable_layer_norm:
+ self.encoder = HubertEncoderStableLayerNorm(config)
+ else:
+ self.encoder = HubertEncoder(config)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ del self.adapter
+
+ def freeze_feature_extractor(self):
+ raise AttributeError("Not needed for Hubert")
+
+ def freeze_feature_encoder(self):
+ raise AttributeError("Not needed for Hubert")
+
+ def forward(
+ self,
+ input_values: Optional[torch.Tensor],
+ attention_mask: Optional[torch.Tensor] = None,
+ mask_time_indices: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, BaseModelOutput]:
+ r"""
+ mask_time_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Indices to mask extracted features for contrastive loss. When in training mode, model learns to predict
+ masked extracted features in *config.proj_codevector_dim* space.
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoProcessor, HubertModel
+ >>> from datasets import load_dataset
+
+ >>> processor = AutoProcessor.from_pretrained("facebook/hubert-large-ls960-ft")
+ >>> model = HubertModel.from_pretrained("facebook/hubert-large-ls960-ft")
+
+
+ >>> def map_to_array(example):
+ ... example["speech"] = example["audio"]["array"]
+ ... return example
+
+
+ >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+ >>> ds = ds.map(map_to_array)
+
+ >>> input_values = processor(ds["speech"][0], return_tensors="pt").input_values # Batch size 1
+ >>> hidden_states = model(input_values).last_hidden_state
+ ```"""
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ extract_features = self.feature_extractor(input_values)
+ extract_features = extract_features.transpose(1, 2)
+
+ if attention_mask is not None:
+ # compute reduced attention_mask corresponding to feature vectors
+ attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask)
+
+ hidden_states = self.feature_projection(extract_features)
+ hidden_states = self._mask_hidden_states(hidden_states, mask_time_indices=mask_time_indices)
+
+ encoder_outputs = self.encoder(
+ hidden_states,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = encoder_outputs[0]
+
+ if not return_dict:
+ return (hidden_states,) + encoder_outputs[1:]
+
+ return BaseModelOutput(
+ last_hidden_state=hidden_states,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
+
+
+class HubertForCTC(Wav2Vec2ForCTC):
+ pass
+
+
+class HubertForSequenceClassification(Wav2Vec2ForSequenceClassification):
+ pass
+
+
+__all__ = ["HubertForCTC", "HubertForSequenceClassification", "HubertModel", "HubertPreTrainedModel"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/hunyuan_v1_dense/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/hunyuan_v1_dense/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..27de691c845369b97805fb53a8e509b7e948b386
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/hunyuan_v1_dense/__init__.py
@@ -0,0 +1,15 @@
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_hunyuan_v1_dense import *
+ from .modeling_hunyuan_v1_dense import *
+ from .tokenization_hy import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/venv/lib/python3.13/site-packages/transformers/models/hunyuan_v1_dense/configuration_hunyuan_v1_dense.py b/venv/lib/python3.13/site-packages/transformers/models/hunyuan_v1_dense/configuration_hunyuan_v1_dense.py
new file mode 100644
index 0000000000000000000000000000000000000000..064b0a9702ccc1a1bd38dfe889bd0eb88291fac0
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/hunyuan_v1_dense/configuration_hunyuan_v1_dense.py
@@ -0,0 +1,189 @@
+# coding=utf-8
+# Copyright (C) 2025 THL A29 Limited, a Tencent company and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""HunYuanDenseV1 model configuration"""
+
+from transformers.configuration_utils import PretrainedConfig
+from transformers.utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class HunYuanDenseV1Config(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`HunYuanDenseV1Model`]. It is used to instantiate a
+ HunYuan model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of the HunYuan-7B, e.g.
+ [tencent/Hunyuan-7B-Instruct](https://huggingface.co/tencent/Hunyuan-7B-Instruct).
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 290943):
+ Vocabulary size of the HunYuan model. Defines the number of different tokens that can be represented by the
+ `input_ids` passed when calling [`HunYuanDenseV1Model`].
+ hidden_size (`int`, *optional*, defaults to 4096):
+ Dimension of the hidden representations.
+ intermediate_size (`int`, *optional*, defaults to 11008):
+ Dimension of the MLP representations or shared MLP representations.
+ num_hidden_layers (`int`, *optional*, defaults to 32):
+ Number of hidden layers in the Transformer decoder.
+ num_attention_heads (`int`, *optional*, defaults to 32):
+ Number of attention heads for each attention layer in the Transformer decoder.
+ num_key_value_heads (`int`, *optional*):
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+ `num_key_value_heads=1`, the model will use Multi Query Attention (MQA), otherwise GQA is used. When
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+ by meanpooling all the original heads within that group. For more details checkout [this
+ paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
+ `num_attention_heads`.
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+ The non-linear activation function (function or string) in the decoder.
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
+ The maximum sequence length that this model might ever be used with.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-05):
+ The epsilon used by the rms normalization layers.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ pad_token_id (`int`, *optional*, defaults to 0):
+ Padding token id.
+ bos_token_id (`int`, *optional*, defaults to 1):
+ Beginning of stream token id.
+ eos_token_id (`int`, *optional*, defaults to 2):
+ End of stream token id.
+ eod_token_id (`int`, *optional*, defaults to 3):
+ Token ID representing the end-of-document marker. Used to indicate the termination of a text sequence.
+ Example: In multi-document processing, this token helps the model distinguish between separate documents.
+ pretraining_tp (`int`, *optional*, defaults to 1):
+ Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
+ document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is
+ necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
+ issue](https://github.com/pytorch/pytorch/issues/76232).
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether to tie weight embeddings
+ rope_theta (`float`, *optional*, defaults to 10000.0):
+ The base period of the RoPE embeddings.
+ rope_scaling (`Dict`, *optional*):
+ Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
+ strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
+ `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
+ `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
+ these scaling strategies behave:
+ https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
+ experimental feature, subject to breaking API changes in future versions.
+ attention_bias (`bool`, *optional*, defaults to `False`):
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ head_dim (`int`, *optional*, defaults to 128):
+ The attention head dimension.
+ """
+
+ model_type = "hunyuan_v1_dense"
+ keys_to_ignore_at_inference = ["past_key_values"]
+
+ def __init__(
+ self,
+ vocab_size=290943,
+ hidden_size=4096,
+ intermediate_size: int = 11008,
+ num_hidden_layers=32,
+ num_attention_heads=32,
+ num_key_value_heads=None,
+ hidden_act="silu",
+ max_position_embeddings=2048,
+ initializer_range=0.02,
+ rms_norm_eps=1e-5,
+ use_cache=True,
+ pad_token_id=0,
+ bos_token_id=1,
+ eos_token_id=2,
+ eod_token_id=3,
+ pretraining_tp=1,
+ tie_word_embeddings=False,
+ rope_theta=10000.0,
+ rope_scaling=None,
+ attention_bias=False,
+ attention_dropout=0.0,
+ head_dim=None,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.head_dim = head_dim
+ # for backward compatibility
+ if num_key_value_heads is None:
+ num_key_value_heads = num_attention_heads
+
+ self.num_key_value_heads = num_key_value_heads
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.rms_norm_eps = rms_norm_eps
+ self.pretraining_tp = pretraining_tp
+ self.use_cache = use_cache
+ self.rope_theta = rope_theta
+ self.rope_scaling = rope_scaling
+ # self._rope_scaling_validation() # TODO: Need validation?
+ self.attention_bias = attention_bias
+ self.attention_dropout = attention_dropout
+
+ super().__init__(
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ tie_word_embeddings=tie_word_embeddings,
+ **kwargs,
+ )
+
+ def _rope_scaling_validation(self):
+ """
+ Validate the `rope_scaling` configuration.
+ """
+ if self.rope_scaling is None:
+ return
+
+ if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
+ raise ValueError(
+ "`rope_scaling` must be a dictionary with two fields, `type` and `factor` or `type` and `alpha`, "
+ f"got {self.rope_scaling}"
+ )
+ rope_scaling_type = self.rope_scaling.get("type", None)
+ rope_scaling_factor = self.rope_scaling.get("factor", None)
+ rope_scaling_alpha = self.rope_scaling.get("alpha", None)
+ if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
+ raise ValueError(
+ f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
+ )
+ if rope_scaling_factor is None and rope_scaling_alpha is None:
+ raise ValueError("`rope_scaling` must contain a `factor` or an `alpha` field, got neither")
+ if rope_scaling_factor is not None:
+ if not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
+ raise ValueError(f"`rope_scaling`'s factor field must be a float > 1.0, got {rope_scaling_factor}")
+ if rope_scaling_alpha is not None:
+ if not isinstance(rope_scaling_alpha, float) or rope_scaling_alpha <= 1.0:
+ raise ValueError(f"`rope_scaling`'s alpha field must be a float > 1.0, got {rope_scaling_alpha}")
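+
+ # Accepted `rope_scaling` shapes (illustrative values only):
+ #
+ #     HunYuanDenseV1Config(rope_scaling={"type": "linear", "factor": 4.0})
+ #     HunYuanDenseV1Config(rope_scaling={"type": "dynamic", "alpha": 1000.0})
+ #
+ # i.e. exactly two fields, a `type` of "linear" or "dynamic", and a float `factor` or `alpha` > 1.0.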
+
+
+__all__ = ["HunYuanDenseV1Config"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py b/venv/lib/python3.13/site-packages/transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py
new file mode 100644
index 0000000000000000000000000000000000000000..acd6f926ea436904611db8dc43eca5852c9b966d
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py
@@ -0,0 +1,514 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_hunyuan_v1_dense.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright (C) 2025 THL A29 Limited, a Tencent company and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Callable, Optional, Union
+
+import torch
+from torch import nn
+
+from transformers.cache_utils import Cache
+
+from ...activations import ACT2FN
+from ...cache_utils import DynamicCache
+from ...generation import GenerationMixin
+from ...integrations import use_kernel_forward_from_hub
+from ...masking_utils import create_causal_mask
+from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer
+from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
+from ...utils.deprecation import deprecate_kwarg
+from ...utils.generic import check_model_inputs
+from .configuration_hunyuan_v1_dense import HunYuanDenseV1Config
+
+
+@use_kernel_forward_from_hub("RMSNorm")
+class HunYuanDenseV1RMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ HunYuanDenseV1RMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+ def extra_repr(self):
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
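+
+ # In short (illustrative; `hidden_size` is a placeholder): each token vector is scaled by
+ # 1 / sqrt(mean(x ** 2) + eps) and then by the learned `weight`, with no mean subtraction:
+ #
+ #     x = torch.randn(2, 5, hidden_size)
+ #     y = HunYuanDenseV1RMSNorm(hidden_size)(x)  # same shape, per-token RMS ~ 1 before `weight`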
+
+
+class HunYuanDenseV1MLP(nn.Module):
+ def __init__(self, config: HunYuanDenseV1Config, layer_idx=None, is_shared_mlp=False):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.intermediate_size = config.intermediate_size
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+ self.act_fn = ACT2FN[config.hidden_act]
+ self.layer_idx = layer_idx
+
+ def forward(self, x):
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+ return down_proj
+
+
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
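+
+# Example (illustrative): rotate_half(torch.tensor([1.0, 2.0, 3.0, 4.0])) -> tensor([-3., -4., 1., 2.])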
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+ """Applies Rotary Position Embedding to the query and key tensors.
+
+ Args:
+ q (`torch.Tensor`): The query tensor.
+ k (`torch.Tensor`): The key tensor.
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
+ position_ids (`torch.Tensor`, *optional*):
+ Deprecated and unused.
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+ Returns:
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+ """
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+ """
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: float,
+ dropout: float = 0.0,
+ **kwargs: Unpack[TransformersKwargs],
+):
+ key_states = repeat_kv(key, module.num_key_value_groups)
+ value_states = repeat_kv(value, module.num_key_value_groups)
+
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+ if attention_mask is not None:
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
+class HunYuanDenseV1Attention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config: HunYuanDenseV1Config, layer_idx: int):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
+ self.scaling = self.head_dim**-0.5
+ self.attention_dropout = config.attention_dropout
+ self.is_causal = True
+
+ self.q_proj = nn.Linear(
+ config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.k_proj = nn.Linear(
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.v_proj = nn.Linear(
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.o_proj = nn.Linear(
+ config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
+ )
+ self.query_layernorm = HunYuanDenseV1RMSNorm(self.head_dim, eps=config.rms_norm_eps)
+ self.key_layernorm = HunYuanDenseV1RMSNorm(self.head_dim, eps=config.rms_norm_eps)
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
+ attention_mask: Optional[torch.Tensor],
+ past_key_values: Optional[Cache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ input_shape = hidden_states.shape[:-1]
+ hidden_shape = (*input_shape, -1, self.head_dim)
+
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+ query_states = self.query_layernorm(query_states)
+ key_states = self.key_layernorm(key_states)
+
+ if past_key_values is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ scaling=self.scaling,
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+ return attn_output, attn_weights
+
+
+class HunYuanDenseV1DecoderLayer(GradientCheckpointingLayer):
+ def __init__(self, config: HunYuanDenseV1Config, layer_idx: int):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+
+ self.self_attn = HunYuanDenseV1Attention(config=config, layer_idx=layer_idx)
+
+ self.mlp = HunYuanDenseV1MLP(config)
+ self.input_layernorm = HunYuanDenseV1RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_attention_layernorm = HunYuanDenseV1RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.layer_idx = layer_idx
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> torch.Tensor:
+ residual = hidden_states
+ hidden_states = self.input_layernorm(hidden_states)
+ # Self Attention
+ hidden_states, _ = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+ hidden_states = residual + hidden_states
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+ return hidden_states
+
+
+@auto_docstring
+class HunYuanDenseV1PreTrainedModel(PreTrainedModel):
+ config: HunYuanDenseV1Config
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["HunYuanDenseV1DecoderLayer"]
+ _skip_keys_device_placement = ["past_key_values"]
+ _supports_flash_attn = True
+ _supports_sdpa = True
+ _supports_flex_attn = True
+
+ _can_compile_fullgraph = True
+ _supports_attention_backend = True
+ _can_record_outputs = {
+ "hidden_states": HunYuanDenseV1DecoderLayer,
+ "attentions": HunYuanDenseV1Attention,
+ }
+
+ def _init_weights(self, module):
+ std = self.config.initializer_range
+ if isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+
+
+class HunYuanDenseV1RotaryEmbedding(nn.Module):
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
+
+ def __init__(self, config: HunYuanDenseV1Config, device=None):
+ super().__init__()
+ # BC: "rope_type" was originally "type"
+ if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+ else:
+ self.rope_type = "default"
+ self.max_seq_len_cached = config.max_position_embeddings
+ self.original_max_seq_len = config.max_position_embeddings
+
+ self.config = config
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+ if self.rope_type == "dynamic" and config.rope_scaling.get("alpha"):
+ # DynamicNTKAlphaRotary
+ self.dim = config.head_dim
+ base = config.rope_theta * config.rope_scaling.get("alpha") ** (self.dim / (self.dim - 2))
+ inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
+ self.attention_scaling = 1.0
+ else:
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+ self.original_inv_freq = self.inv_freq
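+
+ # DynamicNTKAlpha sketch (illustrative numbers): with rope_theta=10000.0, head_dim=128 and
+ # rope_scaling={"type": "dynamic", "alpha": 1000.0}, the effective base becomes
+ # 10000.0 * 1000.0 ** (128 / 126) ~ 1.1e7, stretching the RoPE wavelengths for longer contexts.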
+
+ @torch.no_grad()
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
+ def forward(self, x, position_ids):
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
+ position_ids_expanded = position_ids[:, None, :].float()
+
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos = emb.cos() * self.attention_scaling
+ sin = emb.sin() * self.attention_scaling
+
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+@auto_docstring
+class HunYuanDenseV1Model(HunYuanDenseV1PreTrainedModel):
+ def __init__(self, config: HunYuanDenseV1Config):
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+ self.layers = nn.ModuleList(
+ [HunYuanDenseV1DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+ self.norm = HunYuanDenseV1RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.rotary_emb = HunYuanDenseV1RotaryEmbedding(config=config)
+ self.gradient_checkpointing = False
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @check_model_inputs()
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> BaseModelOutputWithPast:
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if inputs_embeds is None:
+ inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)
+
+ if use_cache and past_key_values is None:
+ past_key_values = DynamicCache(config=self.config)
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position: torch.Tensor = torch.arange(
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+ )
+
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0)
+
+ causal_mask = create_causal_mask(
+ config=self.config,
+ input_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ cache_position=cache_position,
+ past_key_values=past_key_values,
+ position_ids=position_ids,
+ )
+
+ hidden_states = inputs_embeds
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+ for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+ hidden_states = decoder_layer(
+ hidden_states,
+ attention_mask=causal_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+
+ hidden_states = self.norm(hidden_states)
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values,
+ )
+
+
+@auto_docstring
+class HunYuanDenseV1ForCausalLM(HunYuanDenseV1PreTrainedModel, GenerationMixin):
+ _tied_weights_keys = ["lm_head.weight"]
+ _tp_plan = {"lm_head": "colwise_rep"}
+ _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = HunYuanDenseV1Model(config)
+ self.vocab_size = config.vocab_size
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> CausalLMOutputWithPast:
+ r"""
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, HunYuanDenseV1ForCausalLM
+
+ >>> model = HunYuanDenseV1ForCausalLM.from_pretrained("meta-hunyuan_v1_dense/HunYuanDenseV1-2-7b-hf")
+ >>> tokenizer = AutoTokenizer.from_pretrained("meta-hunyuan_v1_dense/HunYuanDenseV1-2-7b-hf")
+
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+ ```"""
+ outputs: BaseModelOutputWithPast = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ hidden_states = outputs.last_hidden_state
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+class HunYuanDenseV1ForSequenceClassification(GenericForSequenceClassification, HunYuanDenseV1PreTrainedModel):
+ pass
+
+
+__all__ = [
+ "HunYuanDenseV1ForCausalLM",
+ "HunYuanDenseV1Model",
+ "HunYuanDenseV1PreTrainedModel",
+ "HunYuanDenseV1ForSequenceClassification",
+]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py b/venv/lib/python3.13/site-packages/transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py
new file mode 100644
index 0000000000000000000000000000000000000000..d527abc08f93a581810d5dcb06b7dc5538721fb5
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py
@@ -0,0 +1,193 @@
+# coding=utf-8
+# Copyright (C) 2025 THL A29 Limited, a Tencent company and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch HunYuanDenseV1 model."""
+
+from typing import Callable, Optional
+
+import torch
+from torch import nn
+
+from transformers.cache_utils import Cache
+from transformers.utils import (
+ logging,
+)
+
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
+from ...processing_utils import Unpack
+from ...utils import TransformersKwargs
+from ..llama.modeling_llama import (
+ LlamaAttention,
+ LlamaDecoderLayer,
+ LlamaForCausalLM,
+ LlamaForSequenceClassification,
+ LlamaMLP,
+ LlamaModel,
+ LlamaPreTrainedModel,
+ LlamaRMSNorm,
+ apply_rotary_pos_emb,
+ eager_attention_forward,
+)
+from .configuration_hunyuan_v1_dense import HunYuanDenseV1Config
+
+
+logger = logging.get_logger(__name__)
+
+
+class HunYuanDenseV1RMSNorm(LlamaRMSNorm):
+ pass
+
+
+class HunYuanDenseV1MLP(LlamaMLP):
+ def __init__(self, config: HunYuanDenseV1Config, layer_idx=None, is_shared_mlp=False):
+ super().__init__(config)
+ self.layer_idx = layer_idx
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+
+
+class HunYuanDenseV1Attention(LlamaAttention):
+ def __init__(self, config: HunYuanDenseV1Config, layer_idx: int):
+ super().__init__(config, layer_idx)
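+        # Per-head RMSNorm is applied to the query/key states after RoPE (see forward below)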
+ self.query_layernorm = HunYuanDenseV1RMSNorm(self.head_dim, eps=config.rms_norm_eps)
+ self.key_layernorm = HunYuanDenseV1RMSNorm(self.head_dim, eps=config.rms_norm_eps)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
+ attention_mask: Optional[torch.Tensor],
+ past_key_values: Optional[Cache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ input_shape = hidden_states.shape[:-1]
+ hidden_shape = (*input_shape, -1, self.head_dim)
+
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+ query_states = self.query_layernorm(query_states)
+ key_states = self.key_layernorm(key_states)
+
+ if past_key_values is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ scaling=self.scaling,
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+ return attn_output, attn_weights
+
+
+class HunYuanDenseV1DecoderLayer(LlamaDecoderLayer):
+ def __init__(self, config: HunYuanDenseV1Config, layer_idx: int):
+ super().__init__(config, layer_idx)
+ self.layer_idx = layer_idx
+
+
+class HunYuanDenseV1PreTrainedModel(LlamaPreTrainedModel):
+ def _init_weights(self, module):
+ std = self.config.initializer_range
+ if isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+
+
+class HunYuanDenseV1RotaryEmbedding(nn.Module):
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
+
+ def __init__(self, config: HunYuanDenseV1Config, device=None):
+ super().__init__()
+ # BC: "rope_type" was originally "type"
+ if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+ else:
+ self.rope_type = "default"
+ self.max_seq_len_cached = config.max_position_embeddings
+ self.original_max_seq_len = config.max_position_embeddings
+
+ self.config = config
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+        if self.rope_type == "dynamic" and config.rope_scaling.get("alpha"):
+ # DynamicNTKAlphaRotary
+ self.dim = config.head_dim
+ base = config.rope_theta * config.rope_scaling.get("alpha") ** (self.dim / (self.dim - 2))
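+            # NTK-alpha scaling: enlarging the base by alpha ** (dim / (dim - 2)) lengthens the RoPE
+            # wavelengths so the embedding covers a longer context; cos/sin are left unscaled (attention_scaling=1.0)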
+ inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
+ self.attention_scaling = 1.0
+ else:
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+ self.original_inv_freq = self.inv_freq
+
+ @torch.no_grad()
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
+ def forward(self, x, position_ids):
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
+ position_ids_expanded = position_ids[:, None, :].float()
+
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos = emb.cos() * self.attention_scaling
+ sin = emb.sin() * self.attention_scaling
+
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+class HunYuanDenseV1Model(LlamaModel):
+ pass
+
+
+class HunYuanDenseV1ForCausalLM(LlamaForCausalLM):
+ pass
+
+
+class HunYuanDenseV1ForSequenceClassification(LlamaForSequenceClassification):
+ pass
+
+
+__all__ = [
+ "HunYuanDenseV1ForCausalLM",
+ "HunYuanDenseV1Model",
+ "HunYuanDenseV1PreTrainedModel",
+ "HunYuanDenseV1ForSequenceClassification",
+]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/hunyuan_v1_moe/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/hunyuan_v1_moe/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd107ee7a3c16d2527806035f67702e9220cae51
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/hunyuan_v1_moe/__init__.py
@@ -0,0 +1,14 @@
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_hunyuan_v1_moe import *
+    from .modeling_hunyuan_v1_moe import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/venv/lib/python3.13/site-packages/transformers/models/hunyuan_v1_moe/configuration_hunyuan_v1_moe.py b/venv/lib/python3.13/site-packages/transformers/models/hunyuan_v1_moe/configuration_hunyuan_v1_moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..386ddac1d3ebb37de330a3940ad5ac556be5bcf6
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/hunyuan_v1_moe/configuration_hunyuan_v1_moe.py
@@ -0,0 +1,204 @@
+# coding=utf-8
+# Copyright (C) 2025 THL A29 Limited, a Tencent company and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""HunYuanMoEV1 model configuration"""
+
+from typing import Union
+
+from transformers.configuration_utils import PretrainedConfig
+from transformers.utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class HunYuanMoEV1Config(PretrainedConfig):
+ r"""
+    This is the configuration class to store the configuration of a [`HunYuanMoEV1Model`]. It is used to instantiate a
+    HunYuan MoE model according to the specified arguments, defining the model architecture. Instantiating a
+    configuration with the defaults will yield a similar configuration to that of the HunYuan A13B model
+    [tencent/Hunyuan-A13B-Instruct](https://huggingface.co/tencent/Hunyuan-A13B-Instruct).
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 290943):
+ Vocabulary size of the HunYuan model. Defines the number of different tokens that can be represented by the
+ `inputs_ids` passed when calling [`HunYuanMoEV1Model`]
+ hidden_size (`int`, *optional*, defaults to 4096):
+ Dimension of the hidden representations.
+ intermediate_size (`int`, *optional*, defaults to 11008):
+ Dimension of the MLP representations or shared MLP representations.
+ num_hidden_layers (`int`, *optional*, defaults to 32):
+ Number of hidden layers in the Transformer decoder.
+ num_attention_heads (`int`, *optional*, defaults to 32):
+ Number of attention heads for each attention layer in the Transformer decoder.
+ num_key_value_heads (`int`, *optional*):
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA); if
+            `num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used. When
+            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+            by meanpooling all the original heads within that group. For more details, check out [this
+            paper](https://huggingface.co/papers/2305.13245). If it is not specified, it will default to
+ `num_attention_heads`.
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+ The non-linear activation function (function or string) in the decoder.
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
+ The maximum sequence length that this model might ever be used with.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-05):
+ The epsilon used by the rms normalization layers.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ pad_token_id (`int`, *optional*, defaults to 0):
+ Padding token id.
+ bos_token_id (`int`, *optional*, defaults to 1):
+ Beginning of stream token id.
+ eos_token_id (`int`, *optional*, defaults to 2):
+ End of stream token id.
+        eod_token_id (`int`, *optional*, defaults to 3):
+ Token ID representing the end-of-document marker. Used to indicate the termination of a text sequence.
+ Example: In multi-document processing, this token helps the model distinguish between separate documents.
+ sep_token_id (`int`, *optional*, defaults to 4):
+ Token ID representing the separator token (`[SEP]`), used to demarcate boundaries between different text segments.
+ pretraining_tp (`int`, *optional*, defaults to 1):
+ Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
+ document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is
+ necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
+ issue](https://github.com/pytorch/pytorch/issues/76232).
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether to tie weight embeddings
+ rope_theta (`float`, *optional*, defaults to 10000.0):
+ The base period of the RoPE embeddings.
+ rope_scaling (`Dict`, *optional*):
+ Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
+ strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
+ `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
+ `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
+ these scaling strategies behave:
+ https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
+ experimental feature, subject to breaking API changes in future versions.
+        attention_bias (`bool`, *optional*, defaults to `False`):
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+        num_experts (`int` or `List`, *optional*, defaults to 1):
+            The number of experts for MoE. If it is a list, it is used as the per-layer number of experts.
+        moe_topk (`int` or `List`, *optional*, defaults to 1):
+            Number of experts selected per token (top-k routing). A list enables layer-wise customization.
+ head_dim (`int`, *optional*, defaults to 128):
+ The attention head dimension.
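+
+    Example (a minimal usage sketch; both classes are assumed to be importable from `transformers` via this
+    package's lazy module exports):
+
+    ```python
+    >>> from transformers import HunYuanMoEV1Config, HunYuanMoEV1Model
+
+    >>> # Initializing a HunYuanMoEV1 style configuration with the default values
+    >>> configuration = HunYuanMoEV1Config()
+
+    >>> # Initializing a model from the configuration (weights are randomly initialized)
+    >>> model = HunYuanMoEV1Model(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```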
+ """
+
+ model_type = "hunyuan_v1_moe"
+ keys_to_ignore_at_inference = ["past_key_values"]
+
+ def __init__(
+ self,
+ vocab_size=290943,
+ hidden_size=4096,
+ intermediate_size: int = 11008,
+ num_hidden_layers=32,
+ num_attention_heads=32,
+ num_key_value_heads=None,
+ hidden_act="silu",
+ max_position_embeddings=2048,
+ initializer_range=0.02,
+ rms_norm_eps=1e-5,
+ use_cache=True,
+ pad_token_id=0,
+ bos_token_id=1,
+ eos_token_id=2,
+ eod_token_id=3,
+ sep_token_id=4,
+ pretraining_tp=1,
+ tie_word_embeddings=False,
+ rope_theta=10000.0,
+ rope_scaling=None,
+ attention_bias=False,
+ attention_dropout=0.0,
+ num_experts: Union[int, list] = 1,
+ moe_topk: Union[int, list] = 1,
+ head_dim=None,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.num_experts = num_experts
+ self.moe_topk = moe_topk
+
+ self.head_dim = head_dim
+ # for backward compatibility
+ if num_key_value_heads is None:
+ num_key_value_heads = num_attention_heads
+
+ self.num_key_value_heads = num_key_value_heads
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.rms_norm_eps = rms_norm_eps
+ self.pretraining_tp = pretraining_tp
+ self.use_cache = use_cache
+ self.rope_theta = rope_theta
+ self.rope_scaling = rope_scaling
+ # self._rope_scaling_validation() # TODO: Need validation?
+ self.attention_bias = attention_bias
+ self.attention_dropout = attention_dropout
+
+ super().__init__(
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ sep_token_id=sep_token_id,
+ tie_word_embeddings=tie_word_embeddings,
+ **kwargs,
+ )
+
+ def _rope_scaling_validation(self):
+ """
+ Validate the `rope_scaling` configuration.
+ """
+ if self.rope_scaling is None:
+ return
+
+ if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
+ raise ValueError(
+ "`rope_scaling` must be a dictionary with with two fields, `type` and `factor` or `type` and `alpha`, "
+ f"got {self.rope_scaling}"
+ )
+ rope_scaling_type = self.rope_scaling.get("type", None)
+ rope_scaling_factor = self.rope_scaling.get("factor", None)
+ rope_scaling_alpha = self.rope_scaling.get("alpha", None)
+ if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
+ raise ValueError(
+ f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
+ )
+ if rope_scaling_factor is None and rope_scaling_alpha is None:
+ raise ValueError("`rope_scaling`'s factor or alpha field must be have one, got both of none")
+ if rope_scaling_factor is not None:
+ if not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
+ raise ValueError(f"`rope_scaling`'s factor field must be a float > 1.0, got {rope_scaling_factor}")
+ if rope_scaling_alpha is not None:
+ if not isinstance(rope_scaling_alpha, float) or rope_scaling_alpha <= 1.0:
+ raise ValueError(f"`rope_scaling`'s alpha field must be a float > 1.0, got {rope_scaling_alpha}")
+
+
+__all__ = ["HunYuanMoEV1Config"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py b/venv/lib/python3.13/site-packages/transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..63d58a77e1bdf3d13780c374ca0dd5b0bb6c0e2d
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py
@@ -0,0 +1,584 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_hunyuan_v1_moe.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright (C) 2025 THL A29 Limited, a Tencent company and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Callable, Optional, Union
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from transformers.cache_utils import Cache
+
+from ...activations import ACT2FN
+from ...cache_utils import DynamicCache
+from ...generation import GenerationMixin
+from ...integrations import use_kernel_forward_from_hub
+from ...masking_utils import create_causal_mask
+from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer
+from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
+from ...utils.deprecation import deprecate_kwarg
+from ...utils.generic import check_model_inputs
+from .configuration_hunyuan_v1_moe import HunYuanMoEV1Config
+
+
+@use_kernel_forward_from_hub("RMSNorm")
+class HunYuanMoEV1RMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ HunYuanMoEV1RMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+ def extra_repr(self):
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+
+class HunYuanMoEV1MLP(nn.Module):
+ def __init__(self, config: HunYuanMoEV1Config, layer_idx=None, is_shared_mlp=False):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.intermediate_size = config.intermediate_size
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+ self.act_fn = ACT2FN[config.hidden_act]
+ self.layer_idx = layer_idx
+
+ def forward(self, x):
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+ return down_proj
+
+
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+ """Applies Rotary Position Embedding to the query and key tensors.
+
+ Args:
+ q (`torch.Tensor`): The query tensor.
+ k (`torch.Tensor`): The key tensor.
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
+ position_ids (`torch.Tensor`, *optional*):
+ Deprecated and unused.
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+ Returns:
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+ """
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+ """
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: float,
+ dropout: float = 0.0,
+ **kwargs: Unpack[TransformersKwargs],
+):
+ key_states = repeat_kv(key, module.num_key_value_groups)
+ value_states = repeat_kv(value, module.num_key_value_groups)
+
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+ if attention_mask is not None:
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
+class HunYuanMoEV1Attention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config: HunYuanMoEV1Config, layer_idx: int):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
+ self.scaling = self.head_dim**-0.5
+ self.attention_dropout = config.attention_dropout
+ self.is_causal = True
+
+ self.q_proj = nn.Linear(
+ config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.k_proj = nn.Linear(
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.v_proj = nn.Linear(
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.o_proj = nn.Linear(
+ config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
+ )
+ self.query_layernorm = HunYuanMoEV1RMSNorm(self.head_dim, eps=config.rms_norm_eps)
+ self.key_layernorm = HunYuanMoEV1RMSNorm(self.head_dim, eps=config.rms_norm_eps)
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
+ attention_mask: Optional[torch.Tensor],
+ past_key_values: Optional[Cache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ input_shape = hidden_states.shape[:-1]
+ hidden_shape = (*input_shape, -1, self.head_dim)
+
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+ query_states = self.query_layernorm(query_states)
+ key_states = self.key_layernorm(key_states)
+
+ if past_key_values is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ scaling=self.scaling,
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+ return attn_output, attn_weights
+
+
+class HunYuanMoEV1Gate(nn.Module):
+ def __init__(self, config: HunYuanMoEV1Config, layer_idx: Optional[int] = None):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ num_experts = config.num_experts if isinstance(config.num_experts, int) else config.num_experts[layer_idx]
+ self.wg = nn.Linear(config.hidden_size, num_experts, bias=False, dtype=torch.float32)
+
+ def forward(self, hidden_states):
+ bsz, seq_len, hidden_size = hidden_states.shape
+ hidden_states = hidden_states.reshape(-1, hidden_size)
+ if self.wg.weight.dtype == torch.float32:
+ hidden_states = hidden_states.float()
+ logits = self.wg(hidden_states)
+ return logits
+
+
+class HunYuanMoEV1Moe(nn.Module):
+ def __init__(self, config: HunYuanMoEV1Config, layer_idx: Optional[int] = None):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ self.num_experts = config.num_experts if isinstance(config.num_experts, int) else config.num_experts[layer_idx]
+ self.top_k = config.moe_topk if isinstance(config.moe_topk, int) else config.moe_topk[layer_idx]
+ self.gate = HunYuanMoEV1Gate(config, layer_idx=layer_idx)
+ # self.wg = nn.Linear(config.hidden_size, config.num_experts, bias=False, dtype=torch.float32)
+ self.experts = nn.ModuleList(
+ [HunYuanMoEV1MLP(config, layer_idx=layer_idx, is_shared_mlp=False) for _ in range(self.num_experts)]
+ )
+
+ self.shared_mlp = HunYuanMoEV1MLP(config, layer_idx=layer_idx, is_shared_mlp=True)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ batch_size, sequence_length, hidden_dim = hidden_states.shape
+ hidden_states_mlp = self.shared_mlp(hidden_states)
+ router_logits = self.gate(hidden_states)
+ hidden_states = hidden_states.view(-1, hidden_dim)
+ # router_logits: (batch * sequence_length, n_experts)
+
+ routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
+ routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
+ routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+ # we cast back to the input dtype
+ routing_weights = routing_weights.to(hidden_states.dtype)
+
+ final_hidden_states = torch.zeros(
+ (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
+ )
+
+ # One hot encode the selected experts to create an expert mask
+        # this will be used to easily index which expert is going to be solicited
+ expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
+
+ # Loop over all available experts in the model and perform the computation on each expert
+ expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
+ for expert_idx in expert_hit:
+ expert_layer = self.experts[expert_idx]
+ idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
+
+ # Index the correct hidden states and compute the expert hidden state for
+ # the current expert. We need to make sure to multiply the output hidden
+            # states by `routing_weights` for the corresponding tokens (the selected top-k experts)
+ current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
+ current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
+
+            # However `index_add_` only supports torch tensors for indexing, so we'll use
+ # the `top_x` tensor here.
+ final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
+ final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
+ return final_hidden_states + hidden_states_mlp
+
+
+class HunYuanMoEV1DecoderLayer(GradientCheckpointingLayer):
+ def __init__(self, config: HunYuanMoEV1Config, layer_idx: int):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+ self.self_attn = HunYuanMoEV1Attention(config=config, layer_idx=layer_idx)
+ self.mlp = HunYuanMoEV1Moe(config, layer_idx=layer_idx)
+ self.input_layernorm = HunYuanMoEV1RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_attention_layernorm = HunYuanMoEV1RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.layer_idx = layer_idx
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> torch.Tensor:
+ residual = hidden_states
+ hidden_states = self.input_layernorm(hidden_states)
+ # Self Attention
+ hidden_states, _ = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+ hidden_states = residual + hidden_states
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+ return hidden_states
+
+
+@auto_docstring
+class HunYuanMoEV1PreTrainedModel(PreTrainedModel):
+ config: HunYuanMoEV1Config
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["HunYuanMoEV1DecoderLayer"]
+ _skip_keys_device_placement = ["past_key_values"]
+ _supports_flash_attn = True
+ _supports_sdpa = True
+ _supports_flex_attn = True
+ _can_compile_fullgraph = False
+ _supports_attention_backend = True
+ _can_record_outputs = {
+ "hidden_states": HunYuanMoEV1DecoderLayer,
+ "attentions": HunYuanMoEV1Attention,
+ }
+
+ def _init_weights(self, module):
+ std = self.config.initializer_range
+ if isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+
+
+class HunYuanMoEV1RotaryEmbedding(nn.Module):
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
+
+ def __init__(self, config: HunYuanMoEV1Config, device=None):
+ super().__init__()
+ # BC: "rope_type" was originally "type"
+ if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+ else:
+ self.rope_type = "default"
+ self.max_seq_len_cached = config.max_position_embeddings
+ self.original_max_seq_len = config.max_position_embeddings
+
+ self.config = config
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+        if self.rope_type == "dynamic" and config.rope_scaling.get("alpha"):
+ # DynamicNTKAlphaRotary
+ self.dim = config.head_dim
+ base = config.rope_theta * config.rope_scaling.get("alpha") ** (self.dim / (self.dim - 2))
+ inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
+ self.attention_scaling = 1.0
+ else:
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+ self.original_inv_freq = self.inv_freq
+
+ @torch.no_grad()
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
+ def forward(self, x, position_ids):
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
+ position_ids_expanded = position_ids[:, None, :].float()
+
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos = emb.cos() * self.attention_scaling
+ sin = emb.sin() * self.attention_scaling
+
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+@auto_docstring
+class HunYuanMoEV1Model(HunYuanMoEV1PreTrainedModel):
+ def __init__(self, config: HunYuanMoEV1Config):
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+ self.layers = nn.ModuleList(
+ [HunYuanMoEV1DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+ self.norm = HunYuanMoEV1RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.rotary_emb = HunYuanMoEV1RotaryEmbedding(config=config)
+ self.gradient_checkpointing = False
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @check_model_inputs()
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> BaseModelOutputWithPast:
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if inputs_embeds is None:
+ inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)
+
+ if use_cache and past_key_values is None:
+ past_key_values = DynamicCache(config=self.config)
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position: torch.Tensor = torch.arange(
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+ )
+
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0)
+
+ causal_mask = create_causal_mask(
+ config=self.config,
+ input_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ cache_position=cache_position,
+ past_key_values=past_key_values,
+ position_ids=position_ids,
+ )
+
+ hidden_states = inputs_embeds
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+ for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+ hidden_states = decoder_layer(
+ hidden_states,
+ attention_mask=causal_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+
+ hidden_states = self.norm(hidden_states)
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values,
+ )
+
+
+@auto_docstring
+class HunYuanMoEV1ForCausalLM(HunYuanMoEV1PreTrainedModel, GenerationMixin):
+ _tied_weights_keys = ["lm_head.weight"]
+ _tp_plan = {"lm_head": "colwise_rep"}
+ _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = HunYuanMoEV1Model(config)
+ self.vocab_size = config.vocab_size
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> CausalLMOutputWithPast:
+ r"""
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, HunYuanMoEV1ForCausalLM
+
+        >>> model = HunYuanMoEV1ForCausalLM.from_pretrained("tencent/Hunyuan-A13B-Instruct")
+        >>> tokenizer = AutoTokenizer.from_pretrained("tencent/Hunyuan-A13B-Instruct")
+
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+ ```"""
+ outputs: BaseModelOutputWithPast = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ hidden_states = outputs.last_hidden_state
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+class HunYuanMoEV1ForSequenceClassification(GenericForSequenceClassification, HunYuanMoEV1PreTrainedModel):
+ pass
+
+
+__all__ = [
+ "HunYuanMoEV1ForCausalLM",
+ "HunYuanMoEV1Model",
+ "HunYuanMoEV1PreTrainedModel",
+ "HunYuanMoEV1ForSequenceClassification",
+]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py b/venv/lib/python3.13/site-packages/transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..a72d6268fe7021b472a141beee79ea3eadca5c84
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py
@@ -0,0 +1,273 @@
+# coding=utf-8
+# Copyright (C) 2025 THL A29 Limited, a Tencent company and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch HunYuanMoEV1 model."""
+
+from typing import Callable, Optional
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from transformers.cache_utils import Cache
+from transformers.utils import (
+ logging,
+)
+
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
+from ...processing_utils import Unpack
+from ...utils import TransformersKwargs
+from ..llama.modeling_llama import (
+ LlamaAttention,
+ LlamaDecoderLayer,
+ LlamaForCausalLM,
+ LlamaForSequenceClassification,
+ LlamaMLP,
+ LlamaModel,
+ LlamaPreTrainedModel,
+ LlamaRMSNorm,
+ apply_rotary_pos_emb,
+ eager_attention_forward,
+)
+from .configuration_hunyuan_v1_moe import HunYuanMoEV1Config
+
+
+logger = logging.get_logger(__name__)
+
+
+class HunYuanMoEV1RMSNorm(LlamaRMSNorm):
+ pass
+
+
+class HunYuanMoEV1MLP(LlamaMLP):
+ def __init__(self, config: HunYuanMoEV1Config, layer_idx=None, is_shared_mlp=False):
+ super().__init__(config)
+ self.layer_idx = layer_idx
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+
+
+class HunYuanMoEV1Attention(LlamaAttention):
+ def __init__(self, config: HunYuanMoEV1Config, layer_idx: int):
+ super().__init__(config, layer_idx)
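+        # Per-head RMSNorm is applied to the query/key states after RoPE (see forward below)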
+ self.query_layernorm = HunYuanMoEV1RMSNorm(self.head_dim, eps=config.rms_norm_eps)
+ self.key_layernorm = HunYuanMoEV1RMSNorm(self.head_dim, eps=config.rms_norm_eps)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
+ attention_mask: Optional[torch.Tensor],
+ past_key_values: Optional[Cache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ input_shape = hidden_states.shape[:-1]
+ hidden_shape = (*input_shape, -1, self.head_dim)
+
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+ query_states = self.query_layernorm(query_states)
+ key_states = self.key_layernorm(key_states)
+
+ if past_key_values is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ scaling=self.scaling,
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+ return attn_output, attn_weights
+
+
+class HunYuanMoEV1Gate(nn.Module):
+ def __init__(self, config: HunYuanMoEV1Config, layer_idx: Optional[int] = None):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ num_experts = config.num_experts if isinstance(config.num_experts, int) else config.num_experts[layer_idx]
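+        # The router projection is kept in float32; hidden states are upcast in forward() before computing expert logits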
+ self.wg = nn.Linear(config.hidden_size, num_experts, bias=False, dtype=torch.float32)
+
+ def forward(self, hidden_states):
+ bsz, seq_len, hidden_size = hidden_states.shape
+ hidden_states = hidden_states.reshape(-1, hidden_size)
+ if self.wg.weight.dtype == torch.float32:
+ hidden_states = hidden_states.float()
+ logits = self.wg(hidden_states)
+ return logits
+
+
+class HunYuanMoEV1Moe(nn.Module):
+ def __init__(self, config: HunYuanMoEV1Config, layer_idx: Optional[int] = None):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ self.num_experts = config.num_experts if isinstance(config.num_experts, int) else config.num_experts[layer_idx]
+ self.top_k = config.moe_topk if isinstance(config.moe_topk, int) else config.moe_topk[layer_idx]
+ self.gate = HunYuanMoEV1Gate(config, layer_idx=layer_idx)
+ # self.wg = nn.Linear(config.hidden_size, config.num_experts, bias=False, dtype=torch.float32)
+ self.experts = nn.ModuleList(
+ [HunYuanMoEV1MLP(config, layer_idx=layer_idx, is_shared_mlp=False) for _ in range(self.num_experts)]
+ )
+
+ self.shared_mlp = HunYuanMoEV1MLP(config, layer_idx=layer_idx, is_shared_mlp=True)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ batch_size, sequence_length, hidden_dim = hidden_states.shape
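+        # The shared MLP runs densely on every token; its output is added to the routed experts' output at the end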
+ hidden_states_mlp = self.shared_mlp(hidden_states)
+ router_logits = self.gate(hidden_states)
+ hidden_states = hidden_states.view(-1, hidden_dim)
+ # router_logits: (batch * sequence_length, n_experts)
+
+ routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
+ routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
+ routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
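+        # e.g. with moe_topk=2, a token whose two largest softmax weights are 0.5 and 0.3 is dispatched to those
+        # two experts with renormalized weights 0.625 and 0.375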
+ # we cast back to the input dtype
+ routing_weights = routing_weights.to(hidden_states.dtype)
+
+ final_hidden_states = torch.zeros(
+ (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
+ )
+
+ # One hot encode the selected experts to create an expert mask
+        # this will be used to easily index which expert is going to be solicited
+ expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
+
+ # Loop over all available experts in the model and perform the computation on each expert
+ expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
+ for expert_idx in expert_hit:
+ expert_layer = self.experts[expert_idx]
+ idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
+
+ # Index the correct hidden states and compute the expert hidden state for
+ # the current expert. We need to make sure to multiply the output hidden
+            # states by `routing_weights` for the corresponding tokens (the selected top-k experts)
+ current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
+ current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
+
+            # However `index_add_` only supports torch tensors for indexing, so we'll use
+ # the `top_x` tensor here.
+ final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
+ final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
+ return final_hidden_states + hidden_states_mlp
+
+
+class HunYuanMoEV1DecoderLayer(LlamaDecoderLayer):
+ def __init__(self, config: HunYuanMoEV1Config, layer_idx: int):
+ super().__init__(config, layer_idx)
+ self.hidden_size = config.hidden_size
+ self.self_attn = HunYuanMoEV1Attention(config=config, layer_idx=layer_idx)
+ self.mlp = HunYuanMoEV1Moe(config, layer_idx=layer_idx)
+ self.input_layernorm = HunYuanMoEV1RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_attention_layernorm = HunYuanMoEV1RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.layer_idx = layer_idx
+
+
+class HunYuanMoEV1PreTrainedModel(LlamaPreTrainedModel):
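+ # Fullgraph compilation is disabled, likely because the MoE routing in HunYuanMoEV1Moe.forward
+ # relies on data-dependent control flow (nonzero / index_add_).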
+ _can_compile_fullgraph = False
+
+ def _init_weights(self, module):
+ std = self.config.initializer_range
+ if isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+
+
+class HunYuanMoEV1RotaryEmbedding(nn.Module):
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
+
+ def __init__(self, config: HunYuanMoEV1Config, device=None):
+ super().__init__()
+ # BC: "rope_type" was originally "type"
+ if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+ else:
+ self.rope_type = "default"
+ self.max_seq_len_cached = config.max_position_embeddings
+ self.original_max_seq_len = config.max_position_embeddings
+
+ self.config = config
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+ if self.rope_type == "dynamic" and config.rope_scaling["alpha"]:
+ # DynamicNTKAlphaRotary
+ self.dim = config.head_dim
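+ # Dynamic NTK-alpha scaling: enlarge the rotary base by alpha ** (dim / (dim - 2)),
+ # which stretches the RoPE wavelengths so that longer contexts can be represented.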
+ base = config.rope_theta * config.rope_scaling.get("alpha") ** (self.dim / (self.dim - 2))
+ inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
+ self.attention_scaling = 1.0
+ else:
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+ self.original_inv_freq = self.inv_freq
+
+ @torch.no_grad()
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
+ def forward(self, x, position_ids):
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
+ position_ids_expanded = position_ids[:, None, :].float()
+
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos = emb.cos() * self.attention_scaling
+ sin = emb.sin() * self.attention_scaling
+
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+class HunYuanMoEV1Model(LlamaModel):
+ pass
+
+
+class HunYuanMoEV1ForCausalLM(LlamaForCausalLM):
+ pass
+
+
+class HunYuanMoEV1ForSequenceClassification(LlamaForSequenceClassification):
+ pass
+
+
+__all__ = [
+ "HunYuanMoEV1ForCausalLM",
+ "HunYuanMoEV1Model",
+ "HunYuanMoEV1PreTrainedModel",
+ "HunYuanMoEV1ForSequenceClassification",
+]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/ibert/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/ibert/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..cf34ec43ac1014d8c153b3aa259e394fc7b73570
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/ibert/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_ibert import *
+ from .modeling_ibert import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/venv/lib/python3.13/site-packages/transformers/models/ibert/configuration_ibert.py b/venv/lib/python3.13/site-packages/transformers/models/ibert/configuration_ibert.py
new file mode 100644
index 0000000000000000000000000000000000000000..963e6e6c9ed00bcb40dce7c4354110ee9b487187
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/ibert/configuration_ibert.py
@@ -0,0 +1,142 @@
+# coding=utf-8
+# Copyright 2021 The I-BERT Authors (Sehoon Kim, Amir Gholami, Zhewei Yao,
+# Michael Mahoney, Kurt Keutzer - UC Berkeley) and The HuggingFace Inc. team.
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""I-BERT configuration"""
+
+from collections import OrderedDict
+from collections.abc import Mapping
+
+from ...configuration_utils import PretrainedConfig
+from ...onnx import OnnxConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class IBertConfig(PretrainedConfig):
+ """
+ This is the configuration class to store the configuration of an [`IBertModel`]. It is used to instantiate an I-BERT
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+ defaults will yield a similar configuration to that of the I-BERT
+ [kssteven/ibert-roberta-base](https://huggingface.co/kssteven/ibert-roberta-base) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 30522):
+ Vocabulary size of the I-BERT model. Defines the number of different tokens that can be represented by the
+ `input_ids` passed when calling [`IBertModel`].
+ hidden_size (`int`, *optional*, defaults to 768):
+ Dimensionality of the encoder layers and the pooler layer.
+ num_hidden_layers (`int`, *optional*, defaults to 12):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 12):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ intermediate_size (`int`, *optional*, defaults to 3072):
+ Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
+ hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
+ The dropout ratio for the attention probabilities.
+ max_position_embeddings (`int`, *optional*, defaults to 512):
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
+ just in case (e.g., 512 or 1024 or 2048).
+ type_vocab_size (`int`, *optional*, defaults to 2):
+ The vocabulary size of the `token_type_ids` passed when calling [`IBertModel`]
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12):
+ The epsilon used by the layer normalization layers.
+ position_embedding_type (`str`, *optional*, defaults to `"absolute"`):
+ Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For
+ positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to
+ [Self-Attention with Relative Position Representations (Shaw et al.)](https://huggingface.co/papers/1803.02155).
+ For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models
+ with Better Relative Position Embeddings (Huang et al.)](https://huggingface.co/papers/2009.13658).
+ quant_mode (`bool`, *optional*, defaults to `False`):
+ Whether to quantize the model or not.
+ force_dequant (`str`, *optional*, defaults to `"none"`):
+ Force dequantization of specific nonlinear layers. Dequantized layers are then executed with full precision.
+ `"none"`, `"gelu"`, `"softmax"`, `"layernorm"` and `"nonlinear"` are supported. By default it is set to
+ `"none"`, which does not dequantize any layers. Specify `"gelu"`, `"softmax"`, or `"layernorm"` to
+ dequantize GELU, Softmax, or LayerNorm, respectively. `"nonlinear"` will dequantize all nonlinear layers,
+ i.e., GELU, Softmax, and LayerNorm.
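+
+ Example (a minimal usage sketch):
+
+ ```python
+ >>> from transformers import IBertConfig, IBertModel
+
+ >>> # Initializing an I-BERT style configuration with the default values
+ >>> configuration = IBertConfig()
+
+ >>> # Initializing a model (with random weights) from that configuration
+ >>> model = IBertModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```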
+ """
+
+ model_type = "ibert"
+
+ def __init__(
+ self,
+ vocab_size=30522,
+ hidden_size=768,
+ num_hidden_layers=12,
+ num_attention_heads=12,
+ intermediate_size=3072,
+ hidden_act="gelu",
+ hidden_dropout_prob=0.1,
+ attention_probs_dropout_prob=0.1,
+ max_position_embeddings=512,
+ type_vocab_size=2,
+ initializer_range=0.02,
+ layer_norm_eps=1e-12,
+ pad_token_id=1,
+ bos_token_id=0,
+ eos_token_id=2,
+ position_embedding_type="absolute",
+ quant_mode=False,
+ force_dequant="none",
+ **kwargs,
+ ):
+ super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
+
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.hidden_act = hidden_act
+ self.intermediate_size = intermediate_size
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.max_position_embeddings = max_position_embeddings
+ self.type_vocab_size = type_vocab_size
+ self.initializer_range = initializer_range
+ self.layer_norm_eps = layer_norm_eps
+ self.position_embedding_type = position_embedding_type
+ self.quant_mode = quant_mode
+ self.force_dequant = force_dequant
+
+
+class IBertOnnxConfig(OnnxConfig):
+ @property
+ def inputs(self) -> Mapping[str, Mapping[int, str]]:
+ if self.task == "multiple-choice":
+ dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
+ else:
+ dynamic_axis = {0: "batch", 1: "sequence"}
+ return OrderedDict(
+ [
+ ("input_ids", dynamic_axis),
+ ("attention_mask", dynamic_axis),
+ ]
+ )
+
+
+__all__ = ["IBertConfig", "IBertOnnxConfig"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/ibert/modeling_ibert.py b/venv/lib/python3.13/site-packages/transformers/models/ibert/modeling_ibert.py
new file mode 100644
index 0000000000000000000000000000000000000000..57b3df2f570babcd1b7d90e8bfee989b89fc99a8
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/ibert/modeling_ibert.py
@@ -0,0 +1,1253 @@
+# coding=utf-8
+# Copyright 2021 The I-BERT Authors (Sehoon Kim, Amir Gholami, Zhewei Yao,
+# Michael Mahoney, Kurt Keutzer - UC Berkeley) and The HuggingFace Inc. team.
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""PyTorch I-BERT model."""
+
+import math
+from typing import Optional, Union
+
+import torch
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from ...activations import gelu
+from ...modeling_outputs import (
+ BaseModelOutputWithPastAndCrossAttentions,
+ BaseModelOutputWithPoolingAndCrossAttentions,
+ MaskedLMOutput,
+ MultipleChoiceModelOutput,
+ QuestionAnsweringModelOutput,
+ SequenceClassifierOutput,
+ TokenClassifierOutput,
+)
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
+from ...utils import auto_docstring, logging
+from .configuration_ibert import IBertConfig
+from .quant_modules import IntGELU, IntLayerNorm, IntSoftmax, QuantAct, QuantEmbedding, QuantLinear
+
+
+logger = logging.get_logger(__name__)
+
+
+class IBertEmbeddings(nn.Module):
+ """
+ Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ self.quant_mode = config.quant_mode
+ self.embedding_bit = 8
+ self.embedding_act_bit = 16
+ self.act_bit = 8
+ self.ln_input_bit = 22
+ self.ln_output_bit = 32
+
+ self.word_embeddings = QuantEmbedding(
+ config.vocab_size,
+ config.hidden_size,
+ padding_idx=config.pad_token_id,
+ weight_bit=self.embedding_bit,
+ quant_mode=self.quant_mode,
+ )
+ self.token_type_embeddings = QuantEmbedding(
+ config.type_vocab_size, config.hidden_size, weight_bit=self.embedding_bit, quant_mode=self.quant_mode
+ )
+
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
+ self.register_buffer(
+ "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
+ )
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
+
+ # End copy
+ self.padding_idx = config.pad_token_id
+ self.position_embeddings = QuantEmbedding(
+ config.max_position_embeddings,
+ config.hidden_size,
+ padding_idx=self.padding_idx,
+ weight_bit=self.embedding_bit,
+ quant_mode=self.quant_mode,
+ )
+
+ # Integer-only addition between embeddings
+ self.embeddings_act1 = QuantAct(self.embedding_act_bit, quant_mode=self.quant_mode)
+ self.embeddings_act2 = QuantAct(self.embedding_act_bit, quant_mode=self.quant_mode)
+
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
+ # any TensorFlow checkpoint file
+ self.LayerNorm = IntLayerNorm(
+ config.hidden_size,
+ eps=config.layer_norm_eps,
+ output_bit=self.ln_output_bit,
+ quant_mode=self.quant_mode,
+ force_dequant=config.force_dequant,
+ )
+ self.output_activation = QuantAct(self.act_bit, quant_mode=self.quant_mode)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(
+ self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
+ ):
+ if position_ids is None:
+ if input_ids is not None:
+ # Create the position ids from the input token ids. Any padded tokens remain padded.
+ position_ids = create_position_ids_from_input_ids(
+ input_ids, self.padding_idx, past_key_values_length
+ ).to(input_ids.device)
+ else:
+ position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
+
+ if input_ids is not None:
+ input_shape = input_ids.size()
+ else:
+ input_shape = inputs_embeds.size()[:-1]
+
+ if token_type_ids is None:
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
+
+ if inputs_embeds is None:
+ inputs_embeds, inputs_embeds_scaling_factor = self.word_embeddings(input_ids)
+ else:
+ inputs_embeds_scaling_factor = None
+ token_type_embeddings, token_type_embeddings_scaling_factor = self.token_type_embeddings(token_type_ids)
+
+ embeddings, embeddings_scaling_factor = self.embeddings_act1(
+ inputs_embeds,
+ inputs_embeds_scaling_factor,
+ identity=token_type_embeddings,
+ identity_scaling_factor=token_type_embeddings_scaling_factor,
+ )
+
+ if self.position_embedding_type == "absolute":
+ position_embeddings, position_embeddings_scaling_factor = self.position_embeddings(position_ids)
+ embeddings, embeddings_scaling_factor = self.embeddings_act1(
+ embeddings,
+ embeddings_scaling_factor,
+ identity=position_embeddings,
+ identity_scaling_factor=position_embeddings_scaling_factor,
+ )
+
+ embeddings, embeddings_scaling_factor = self.LayerNorm(embeddings, embeddings_scaling_factor)
+ embeddings = self.dropout(embeddings)
+ embeddings, embeddings_scaling_factor = self.output_activation(embeddings, embeddings_scaling_factor)
+ return embeddings, embeddings_scaling_factor
+
+ def create_position_ids_from_inputs_embeds(self, inputs_embeds):
+ """
+ We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.
+
+ Args:
+ inputs_embeds: torch.Tensor
+
+ Returns: torch.Tensor
+ """
+ input_shape = inputs_embeds.size()[:-1]
+ sequence_length = input_shape[1]
+
+ position_ids = torch.arange(
+ self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
+ )
+ return position_ids.unsqueeze(0).expand(input_shape)
+
+
+class IBertSelfAttention(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
+ raise ValueError(
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
+ f"heads ({config.num_attention_heads})"
+ )
+ self.quant_mode = config.quant_mode
+ self.weight_bit = 8
+ self.bias_bit = 32
+ self.act_bit = 8
+
+ self.num_attention_heads = config.num_attention_heads
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+ # Q, K, V Linear layers
+ self.query = QuantLinear(
+ config.hidden_size,
+ self.all_head_size,
+ bias=True,
+ weight_bit=self.weight_bit,
+ bias_bit=self.bias_bit,
+ quant_mode=self.quant_mode,
+ per_channel=True,
+ )
+ self.key = QuantLinear(
+ config.hidden_size,
+ self.all_head_size,
+ bias=True,
+ weight_bit=self.weight_bit,
+ bias_bit=self.bias_bit,
+ quant_mode=self.quant_mode,
+ per_channel=True,
+ )
+ self.value = QuantLinear(
+ config.hidden_size,
+ self.all_head_size,
+ bias=True,
+ weight_bit=self.weight_bit,
+ bias_bit=self.bias_bit,
+ quant_mode=self.quant_mode,
+ per_channel=True,
+ )
+
+ # Requantization (32-bit -> 8-bit) for Q, K, V activations
+ self.query_activation = QuantAct(self.act_bit, quant_mode=self.quant_mode)
+ self.key_activation = QuantAct(self.act_bit, quant_mode=self.quant_mode)
+ self.value_activation = QuantAct(self.act_bit, quant_mode=self.quant_mode)
+ self.output_activation = QuantAct(self.act_bit, quant_mode=self.quant_mode)
+
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
+ if self.position_embedding_type != "absolute":
+ raise ValueError("I-BERT only supports 'absolute' for `config.position_embedding_type`")
+
+ self.softmax = IntSoftmax(self.act_bit, quant_mode=self.quant_mode, force_dequant=config.force_dequant)
+
+ def forward(
+ self,
+ hidden_states,
+ hidden_states_scaling_factor,
+ attention_mask=None,
+ head_mask=None,
+ output_attentions=False,
+ ):
+ # Projection
+ mixed_query_layer, mixed_query_layer_scaling_factor = self.query(hidden_states, hidden_states_scaling_factor)
+ mixed_key_layer, mixed_key_layer_scaling_factor = self.key(hidden_states, hidden_states_scaling_factor)
+ mixed_value_layer, mixed_value_layer_scaling_factor = self.value(hidden_states, hidden_states_scaling_factor)
+
+ # Requantization
+ query_layer, query_layer_scaling_factor = self.query_activation(
+ mixed_query_layer, mixed_query_layer_scaling_factor
+ )
+ key_layer, key_layer_scaling_factor = self.key_activation(mixed_key_layer, mixed_key_layer_scaling_factor)
+ value_layer, value_layer_scaling_factor = self.value_activation(
+ mixed_value_layer, mixed_value_layer_scaling_factor
+ )
+
+ # Transpose
+ batch_size, seq_length, _ = hidden_states.shape
+ query_layer = query_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(
+ 1, 2
+ )
+ key_layer = key_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2)
+ value_layer = value_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(
+ 1, 2
+ )
+
+ # Take the dot product between "query" and "key" to get the raw attention scores.
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+ scale = math.sqrt(self.attention_head_size)
+ attention_scores = attention_scores / scale
+ if self.quant_mode:
+ attention_scores_scaling_factor = query_layer_scaling_factor * key_layer_scaling_factor / scale
+ else:
+ attention_scores_scaling_factor = None
+
+ if attention_mask is not None:
+ # Apply the attention mask (precomputed for all layers in IBertModel's forward() function)
+ attention_scores = attention_scores + attention_mask
+
+ # Normalize the attention scores to probabilities.
+ attention_probs, attention_probs_scaling_factor = self.softmax(
+ attention_scores, attention_scores_scaling_factor
+ )
+
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ attention_probs = self.dropout(attention_probs)
+
+ # Mask heads if we want to
+ if head_mask is not None:
+ attention_probs = attention_probs * head_mask
+
+ context_layer = torch.matmul(attention_probs, value_layer)
+ if attention_probs_scaling_factor is not None:
+ context_layer_scaling_factor = attention_probs_scaling_factor * value_layer_scaling_factor
+ else:
+ context_layer_scaling_factor = None
+
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+ context_layer = context_layer.view(*new_context_layer_shape)
+
+ # requantization: 32-bit -> 8-bit
+ context_layer, context_layer_scaling_factor = self.output_activation(
+ context_layer, context_layer_scaling_factor
+ )
+
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
+ output_scaling_factor = (
+ (context_layer_scaling_factor, attention_probs_scaling_factor)
+ if output_attentions
+ else (context_layer_scaling_factor,)
+ )
+
+ return outputs, output_scaling_factor
+
+
+class IBertSelfOutput(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.quant_mode = config.quant_mode
+ self.act_bit = 8
+ self.weight_bit = 8
+ self.bias_bit = 32
+ self.ln_input_bit = 22
+ self.ln_output_bit = 32
+
+ self.dense = QuantLinear(
+ config.hidden_size,
+ config.hidden_size,
+ bias=True,
+ weight_bit=self.weight_bit,
+ bias_bit=self.bias_bit,
+ quant_mode=self.quant_mode,
+ per_channel=True,
+ )
+ self.ln_input_act = QuantAct(self.ln_input_bit, quant_mode=self.quant_mode)
+ self.LayerNorm = IntLayerNorm(
+ config.hidden_size,
+ eps=config.layer_norm_eps,
+ output_bit=self.ln_output_bit,
+ quant_mode=self.quant_mode,
+ force_dequant=config.force_dequant,
+ )
+ self.output_activation = QuantAct(self.act_bit, quant_mode=self.quant_mode)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states, hidden_states_scaling_factor, input_tensor, input_tensor_scaling_factor):
+ hidden_states, hidden_states_scaling_factor = self.dense(hidden_states, hidden_states_scaling_factor)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states, hidden_states_scaling_factor = self.ln_input_act(
+ hidden_states,
+ hidden_states_scaling_factor,
+ identity=input_tensor,
+ identity_scaling_factor=input_tensor_scaling_factor,
+ )
+ hidden_states, hidden_states_scaling_factor = self.LayerNorm(hidden_states, hidden_states_scaling_factor)
+
+ hidden_states, hidden_states_scaling_factor = self.output_activation(
+ hidden_states, hidden_states_scaling_factor
+ )
+ return hidden_states, hidden_states_scaling_factor
+
+
+class IBertAttention(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.quant_mode = config.quant_mode
+ self.self = IBertSelfAttention(config)
+ self.output = IBertSelfOutput(config)
+ self.pruned_heads = set()
+
+ def prune_heads(self, heads):
+ if len(heads) == 0:
+ return
+ heads, index = find_pruneable_heads_and_indices(
+ heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
+ )
+
+ # Prune linear layers
+ self.self.query = prune_linear_layer(self.self.query, index)
+ self.self.key = prune_linear_layer(self.self.key, index)
+ self.self.value = prune_linear_layer(self.self.value, index)
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
+
+ # Update hyper params and store pruned heads
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
+ self.pruned_heads = self.pruned_heads.union(heads)
+
+ def forward(
+ self,
+ hidden_states,
+ hidden_states_scaling_factor,
+ attention_mask=None,
+ head_mask=None,
+ output_attentions=False,
+ ):
+ self_outputs, self_outputs_scaling_factor = self.self(
+ hidden_states,
+ hidden_states_scaling_factor,
+ attention_mask,
+ head_mask,
+ output_attentions,
+ )
+ attention_output, attention_output_scaling_factor = self.output(
+ self_outputs[0], self_outputs_scaling_factor[0], hidden_states, hidden_states_scaling_factor
+ )
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
+ outputs_scaling_factor = (attention_output_scaling_factor,) + self_outputs_scaling_factor[1:]
+ return outputs, outputs_scaling_factor
+
+
+class IBertIntermediate(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.quant_mode = config.quant_mode
+ self.act_bit = 8
+ self.weight_bit = 8
+ self.bias_bit = 32
+ self.dense = QuantLinear(
+ config.hidden_size,
+ config.intermediate_size,
+ bias=True,
+ weight_bit=self.weight_bit,
+ bias_bit=self.bias_bit,
+ quant_mode=self.quant_mode,
+ per_channel=True,
+ )
+ if config.hidden_act != "gelu":
+ raise ValueError("I-BERT only supports 'gelu' for `config.hidden_act`")
+ self.intermediate_act_fn = IntGELU(quant_mode=self.quant_mode, force_dequant=config.force_dequant)
+ self.output_activation = QuantAct(self.act_bit, quant_mode=self.quant_mode)
+
+ def forward(self, hidden_states, hidden_states_scaling_factor):
+ hidden_states, hidden_states_scaling_factor = self.dense(hidden_states, hidden_states_scaling_factor)
+ hidden_states, hidden_states_scaling_factor = self.intermediate_act_fn(
+ hidden_states, hidden_states_scaling_factor
+ )
+
+ # Requantization: 32-bit -> 8-bit
+ hidden_states, hidden_states_scaling_factor = self.output_activation(
+ hidden_states, hidden_states_scaling_factor
+ )
+ return hidden_states, hidden_states_scaling_factor
+
+
+class IBertOutput(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.quant_mode = config.quant_mode
+ self.act_bit = 8
+ self.weight_bit = 8
+ self.bias_bit = 32
+ self.ln_input_bit = 22
+ self.ln_output_bit = 32
+
+ self.dense = QuantLinear(
+ config.intermediate_size,
+ config.hidden_size,
+ bias=True,
+ weight_bit=self.weight_bit,
+ bias_bit=self.bias_bit,
+ quant_mode=self.quant_mode,
+ per_channel=True,
+ )
+ self.ln_input_act = QuantAct(self.ln_input_bit, quant_mode=self.quant_mode)
+ self.LayerNorm = IntLayerNorm(
+ config.hidden_size,
+ eps=config.layer_norm_eps,
+ output_bit=self.ln_output_bit,
+ quant_mode=self.quant_mode,
+ force_dequant=config.force_dequant,
+ )
+ self.output_activation = QuantAct(self.act_bit, quant_mode=self.quant_mode)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states, hidden_states_scaling_factor, input_tensor, input_tensor_scaling_factor):
+ hidden_states, hidden_states_scaling_factor = self.dense(hidden_states, hidden_states_scaling_factor)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states, hidden_states_scaling_factor = self.ln_input_act(
+ hidden_states,
+ hidden_states_scaling_factor,
+ identity=input_tensor,
+ identity_scaling_factor=input_tensor_scaling_factor,
+ )
+ hidden_states, hidden_states_scaling_factor = self.LayerNorm(hidden_states, hidden_states_scaling_factor)
+
+ hidden_states, hidden_states_scaling_factor = self.output_activation(
+ hidden_states, hidden_states_scaling_factor
+ )
+ return hidden_states, hidden_states_scaling_factor
+
+
+class IBertLayer(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.quant_mode = config.quant_mode
+ self.act_bit = 8
+
+ self.seq_len_dim = 1
+ self.attention = IBertAttention(config)
+ self.intermediate = IBertIntermediate(config)
+ self.output = IBertOutput(config)
+
+ self.pre_intermediate_act = QuantAct(self.act_bit, quant_mode=self.quant_mode)
+ self.pre_output_act = QuantAct(self.act_bit, quant_mode=self.quant_mode)
+
+ def forward(
+ self,
+ hidden_states,
+ hidden_states_scaling_factor,
+ attention_mask=None,
+ head_mask=None,
+ output_attentions=False,
+ ):
+ self_attention_outputs, self_attention_outputs_scaling_factor = self.attention(
+ hidden_states,
+ hidden_states_scaling_factor,
+ attention_mask,
+ head_mask,
+ output_attentions=output_attentions,
+ )
+ attention_output = self_attention_outputs[0]
+ attention_output_scaling_factor = self_attention_outputs_scaling_factor[0]
+
+ outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
+
+ layer_output, layer_output_scaling_factor = self.feed_forward_chunk(
+ attention_output, attention_output_scaling_factor
+ )
+ outputs = (layer_output,) + outputs
+
+ return outputs
+
+ def feed_forward_chunk(self, attention_output, attention_output_scaling_factor):
+ attention_output, attention_output_scaling_factor = self.pre_intermediate_act(
+ attention_output, attention_output_scaling_factor
+ )
+ intermediate_output, intermediate_output_scaling_factor = self.intermediate(
+ attention_output, attention_output_scaling_factor
+ )
+
+ intermediate_output, intermediate_output_scaling_factor = self.pre_output_act(
+ intermediate_output, intermediate_output_scaling_factor
+ )
+ layer_output, layer_output_scaling_factor = self.output(
+ intermediate_output, intermediate_output_scaling_factor, attention_output, attention_output_scaling_factor
+ )
+ return layer_output, layer_output_scaling_factor
+
+
+class IBertEncoder(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.quant_mode = config.quant_mode
+ self.layer = nn.ModuleList([IBertLayer(config) for _ in range(config.num_hidden_layers)])
+
+ def forward(
+ self,
+ hidden_states,
+ hidden_states_scaling_factor,
+ attention_mask=None,
+ head_mask=None,
+ output_attentions=False,
+ output_hidden_states=False,
+ return_dict=True,
+ ):
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attentions = () if output_attentions else None
+ all_cross_attentions = None # `config.add_cross_attention` is not supported
+
+ for i, layer_module in enumerate(self.layer):
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ layer_head_mask = head_mask[i] if head_mask is not None else None
+
+ layer_outputs = layer_module(
+ hidden_states,
+ hidden_states_scaling_factor,
+ attention_mask,
+ layer_head_mask,
+ output_attentions,
+ )
+
+ hidden_states = layer_outputs[0]
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(
+ v
+ for v in [
+ hidden_states,
+ all_hidden_states,
+ all_self_attentions,
+ all_cross_attentions,
+ ]
+ if v is not None
+ )
+ return BaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ cross_attentions=all_cross_attentions,
+ )
+
+
+class IBertPooler(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.quant_mode = config.quant_mode
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.activation = nn.Tanh()
+
+ def forward(self, hidden_states):
+ # We "pool" the model by simply taking the hidden state corresponding
+ # to the first token.
+ first_token_tensor = hidden_states[:, 0]
+ pooled_output = self.dense(first_token_tensor)
+ pooled_output = self.activation(pooled_output)
+ return pooled_output
+
+
+@auto_docstring
+class IBertPreTrainedModel(PreTrainedModel):
+ config: IBertConfig
+ base_model_prefix = "ibert"
+
+ def _init_weights(self, module):
+ """Initialize the weights"""
+ if isinstance(module, (QuantLinear, nn.Linear)):
+ # Slightly different from the TF version which uses truncated_normal for initialization
+ # cf https://github.com/pytorch/pytorch/pull/5617
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, (QuantEmbedding, nn.Embedding)):
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+ elif isinstance(module, (IntLayerNorm, nn.LayerNorm)):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+ elif isinstance(module, IBertLMHead):
+ module.bias.data.zero_()
+
+ def resize_token_embeddings(self, new_num_tokens=None):
+ raise NotImplementedError("`resize_token_embeddings` is not supported for I-BERT.")
+
+
+@auto_docstring
+class IBertModel(IBertPreTrainedModel):
+ """
+
+ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
+ cross-attention is added between the self-attention layers, following the architecture described in [Attention is
+ all you need](https://huggingface.co/papers/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
+ Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
+
+ """
+
+ def __init__(self, config, add_pooling_layer=True):
+ r"""
+ add_pooling_layer (bool, *optional*, defaults to `True`):
+ Whether to add a pooling layer
+ """
+ super().__init__(config)
+ self.config = config
+ self.quant_mode = config.quant_mode
+
+ self.embeddings = IBertEmbeddings(config)
+ self.encoder = IBertEncoder(config)
+
+ self.pooler = IBertPooler(config) if add_pooling_layer else None
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embeddings.word_embeddings
+
+ def set_input_embeddings(self, value):
+ self.embeddings.word_embeddings = value
+
+ def _prune_heads(self, heads_to_prune):
+ """
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}. See base
+ class `PreTrainedModel`.
+ """
+ for layer, heads in heads_to_prune.items():
+ self.encoder.layer[layer].attention.prune_heads(heads)
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[BaseModelOutputWithPoolingAndCrossAttentions, tuple[torch.FloatTensor]]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
+ input_shape = input_ids.size()
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ batch_size, seq_length = input_shape
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+
+ if attention_mask is None:
+ attention_mask = torch.ones(((batch_size, seq_length)), device=device)
+ if token_type_ids is None:
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
+
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+ # ourselves in which case we just need to make it broadcastable to all heads.
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+ embedding_output, embedding_output_scaling_factor = self.embeddings(
+ input_ids=input_ids,
+ position_ids=position_ids,
+ token_type_ids=token_type_ids,
+ inputs_embeds=inputs_embeds,
+ )
+ encoder_outputs = self.encoder(
+ embedding_output,
+ embedding_output_scaling_factor,
+ attention_mask=extended_attention_mask,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ sequence_output = encoder_outputs[0]
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
+
+ if not return_dict:
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
+
+ return BaseModelOutputWithPoolingAndCrossAttentions(
+ last_hidden_state=sequence_output,
+ pooler_output=pooled_output,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ cross_attentions=encoder_outputs.cross_attentions,
+ )
+
+
+@auto_docstring
+class IBertForMaskedLM(IBertPreTrainedModel):
+ _tied_weights_keys = ["lm_head.decoder.bias", "lm_head.decoder.weight"]
+
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.ibert = IBertModel(config, add_pooling_layer=False)
+ self.lm_head = IBertLMHead(config)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_output_embeddings(self):
+ return self.lm_head.decoder
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head.decoder = new_embeddings
+ self.lm_head.bias = new_embeddings.bias
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[MaskedLMOutput, tuple[torch.FloatTensor]]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
+ config.vocab_size]` (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked); the
+ loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.ibert(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ sequence_output = outputs[0]
+ prediction_scores = self.lm_head(sequence_output)
+
+ masked_lm_loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
+
+ if not return_dict:
+ output = (prediction_scores,) + outputs[2:]
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
+
+ return MaskedLMOutput(
+ loss=masked_lm_loss,
+ logits=prediction_scores,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+class IBertLMHead(nn.Module):
+ """I-BERT Head for masked language modeling."""
+
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size)
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
+ self.decoder.bias = self.bias
+
+ def forward(self, features, **kwargs):
+ x = self.dense(features)
+ x = gelu(x)
+ x = self.layer_norm(x)
+
+ # project back to size of vocabulary with bias
+ x = self.decoder(x)
+
+ return x
+
+ def _tie_weights(self) -> None:
+ # For accelerate compatibility and to not break backward compatibility
+ if self.decoder.bias.device.type == "meta":
+ self.decoder.bias = self.bias
+ else:
+ # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
+ self.bias = self.decoder.bias
+
+
+@auto_docstring(
+ custom_intro="""
+ I-BERT Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
+ output) e.g. for GLUE tasks.
+ """
+)
+class IBertForSequenceClassification(IBertPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+
+ self.ibert = IBertModel(config, add_pooling_layer=False)
+ self.classifier = IBertClassificationHead(config)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[SequenceClassifierOutput, tuple[torch.FloatTensor]]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.ibert(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ sequence_output = outputs[0]
+ logits = self.classifier(sequence_output)
+
+ loss = None
+ if labels is not None:
+ if self.config.problem_type is None:
+ if self.num_labels == 1:
+ self.config.problem_type = "regression"
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+ self.config.problem_type = "single_label_classification"
+ else:
+ self.config.problem_type = "multi_label_classification"
+
+ if self.config.problem_type == "regression":
+ loss_fct = MSELoss()
+ if self.num_labels == 1:
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
+ else:
+ loss = loss_fct(logits, labels)
+ elif self.config.problem_type == "single_label_classification":
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+ elif self.config.problem_type == "multi_label_classification":
+ loss_fct = BCEWithLogitsLoss()
+ loss = loss_fct(logits, labels)
+ if not return_dict:
+ output = (logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return SequenceClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@auto_docstring
+class IBertForMultipleChoice(IBertPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.ibert = IBertModel(config)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+ self.classifier = nn.Linear(config.hidden_size, 1)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[MultipleChoiceModelOutput, tuple[torch.FloatTensor]]:
+ r"""
+ input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ token_type_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
+ 1]`:
+
+ - 0 corresponds to a *sentence A* token,
+ - 1 corresponds to a *sentence B* token.
+
+ [What are token type IDs?](../glossary#token-type-ids)
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
+ num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
+ `input_ids` above)
+ position_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+ config.max_position_embeddings - 1]`.
+
+ [What are position IDs?](../glossary#position-ids)
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+ model's internal embedding lookup matrix.
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
+
+ flat_input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
+ flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
+ flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
+ flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
+ flat_inputs_embeds = (
+ inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
+ if inputs_embeds is not None
+ else None
+ )
+
+ outputs = self.ibert(
+ flat_input_ids,
+ position_ids=flat_position_ids,
+ token_type_ids=flat_token_type_ids,
+ attention_mask=flat_attention_mask,
+ head_mask=head_mask,
+ inputs_embeds=flat_inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ pooled_output = outputs[1]
+
+ pooled_output = self.dropout(pooled_output)
+ logits = self.classifier(pooled_output)
+ reshaped_logits = logits.view(-1, num_choices)
+
+ loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(reshaped_logits, labels)
+
+ if not return_dict:
+ output = (reshaped_logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return MultipleChoiceModelOutput(
+ loss=loss,
+ logits=reshaped_logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@auto_docstring
+class IBertForTokenClassification(IBertPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+
+ self.ibert = IBertModel(config, add_pooling_layer=False)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[TokenClassifierOutput, tuple[torch.FloatTensor]]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.ibert(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output = outputs[0]
+
+ sequence_output = self.dropout(sequence_output)
+ logits = self.classifier(sequence_output)
+
+ loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+ if not return_dict:
+ output = (logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return TokenClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+class IBertClassificationHead(nn.Module):
+ """Head for sentence-level classification tasks."""
+
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+ self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
+
+ def forward(self, features, **kwargs):
+ hidden_states = features[:, 0, :] # take <s> token (equiv. to [CLS])
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.dense(hidden_states)
+ hidden_states = torch.tanh(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.out_proj(hidden_states)
+ return hidden_states
+
+
+@auto_docstring
+class IBertForQuestionAnswering(IBertPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+
+ self.ibert = IBertModel(config, add_pooling_layer=False)
+ self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ start_positions: Optional[torch.LongTensor] = None,
+ end_positions: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[QuestionAnsweringModelOutput, tuple[torch.FloatTensor]]:
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.ibert(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output = outputs[0]
+
+ logits = self.qa_outputs(sequence_output)
+ start_logits, end_logits = logits.split(1, dim=-1)
+ start_logits = start_logits.squeeze(-1).contiguous()
+ end_logits = end_logits.squeeze(-1).contiguous()
+
+ total_loss = None
+ if start_positions is not None and end_positions is not None:
+ # If we are on multi-GPU, split add a dimension
+ if len(start_positions.size()) > 1:
+ start_positions = start_positions.squeeze(-1)
+ if len(end_positions.size()) > 1:
+ end_positions = end_positions.squeeze(-1)
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
+ ignored_index = start_logits.size(1)
+ start_positions = start_positions.clamp(0, ignored_index)
+ end_positions = end_positions.clamp(0, ignored_index)
+
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
+ start_loss = loss_fct(start_logits, start_positions)
+ end_loss = loss_fct(end_logits, end_positions)
+ total_loss = (start_loss + end_loss) / 2
+
+ if not return_dict:
+ output = (start_logits, end_logits) + outputs[2:]
+ return ((total_loss,) + output) if total_loss is not None else output
+
+ return QuestionAnsweringModelOutput(
+ loss=total_loss,
+ start_logits=start_logits,
+ end_logits=end_logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
+ """
+ Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
+ are ignored. This is modified from fairseq's *utils.make_positions*.
+
+ Args:
+ input_ids (`torch.LongTensor`):
+ Indices of input sequence tokens in the vocabulary.
+
+ Returns: torch.Tensor
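+
+ Example (a small sketch with `padding_idx=1`; the two trailing tokens are padding):
+
+ ```python
+ >>> input_ids = torch.tensor([[0, 31414, 232, 2, 1, 1]])
+ >>> create_position_ids_from_input_ids(input_ids, padding_idx=1)
+ tensor([[2, 3, 4, 5, 1, 1]])
+ ```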
+ """
+ # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
+ mask = input_ids.ne(padding_idx).int()
+ incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask
+ return incremental_indices.long() + padding_idx
+
+
+__all__ = [
+ "IBertForMaskedLM",
+ "IBertForMultipleChoice",
+ "IBertForQuestionAnswering",
+ "IBertForSequenceClassification",
+ "IBertForTokenClassification",
+ "IBertModel",
+ "IBertPreTrainedModel",
+]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/ibert/quant_modules.py b/venv/lib/python3.13/site-packages/transformers/models/ibert/quant_modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..949702a5af97da779cb6dab842b0029d274417dc
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/ibert/quant_modules.py
@@ -0,0 +1,820 @@
+# coding=utf-8
+# Copyright 2021 The I-BERT Authors (Sehoon Kim, Amir Gholami, Zhewei Yao,
+# Michael Mahoney, Kurt Keutzer - UC Berkeley) and The HuggingFace Inc. team.
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import decimal
+
+import numpy as np
+import torch
+from torch import nn
+from torch.autograd import Function
+
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class QuantEmbedding(nn.Module):
+ """
+ Quantized version of `torch.nn.Embedding`. Adds quantization-specific arguments on top of `torch.nn.Embedding`.
+
+ Args:
+ weight_bit (`int`, *optional*, defaults to `8`):
+ Bitwidth for the quantized weight.
+ momentum (`float`, *optional*, defaults to `0.95`):
+ Momentum for updating the activation quantization range.
+ quant_mode (`bool`, *optional*, defaults to `False`):
+ Whether or not the layer is quantized.
+ """
+
+ def __init__(
+ self,
+ num_embeddings,
+ embedding_dim,
+ padding_idx=None,
+ max_norm=None,
+ norm_type=2.0,
+ scale_grad_by_freq=False,
+ sparse=False,
+ _weight=None,
+ weight_bit=8,
+ momentum=0.95,
+ quant_mode=False,
+ ):
+ super().__init__()
+ self.num_ = num_embeddings
+ self.dim = embedding_dim
+ self.padding_idx = padding_idx
+ self.max_norm = max_norm
+ self.norm_type = norm_type
+ self.scale_grad_by_freq = scale_grad_by_freq
+ self.sparse = sparse
+
+ self.weight = nn.Parameter(torch.zeros([num_embeddings, embedding_dim]))
+ self.register_buffer("weight_scaling_factor", torch.zeros(1))
+ self.register_buffer("weight_integer", torch.zeros_like(self.weight))
+
+ self.weight_bit = weight_bit
+ self.momentum = momentum
+ self.quant_mode = quant_mode
+ self.percentile_mode = False
+ self.weight_function = SymmetricQuantFunction.apply
+
+ def forward(self, x, positions=None, incremental_state=None):
+ if not self.quant_mode:
+ return (
+ nn.functional.embedding(
+ x,
+ self.weight,
+ self.padding_idx,
+ self.max_norm,
+ self.norm_type,
+ self.scale_grad_by_freq,
+ self.sparse,
+ ),
+ None,
+ )
+
+ w = self.weight
+ w_transform = w.data.detach()
+ w_min = w_transform.min().expand(1)
+ w_max = w_transform.max().expand(1)
+
+ self.weight_scaling_factor = symmetric_linear_quantization_params(self.weight_bit, w_min, w_max, False)
+ self.weight_integer = self.weight_function(
+ self.weight, self.weight_bit, self.percentile_mode, self.weight_scaling_factor
+ )
+
+ emb_int = nn.functional.embedding(
+ x,
+ self.weight_integer,
+ self.padding_idx,
+ self.max_norm,
+ self.norm_type,
+ self.scale_grad_by_freq,
+ self.sparse,
+ )
+ return emb_int * self.weight_scaling_factor, self.weight_scaling_factor
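+
+    # Editor's note (illustrative, not part of the original implementation): this module does
+    # "fake" quantization: the weight is snapped to integers (`weight_integer`) and multiplied
+    # back by `weight_scaling_factor`, and the scale is returned alongside the embedding so that
+    # downstream quantized modules can keep track of it. A hypothetical usage sketch:
+    #
+    #   emb = QuantEmbedding(30522, 768, padding_idx=1, weight_bit=8, quant_mode=True)
+    #   hidden, scale = emb(input_ids)   # hidden ≈ weight_integer[input_ids] * scale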
+
+
+class QuantAct(nn.Module):
+ """
+ Quantizes the given activation.
+
+ Args:
+ activation_bit (`int`):
+ Bitwidth for the quantized activation.
+ act_range_momentum (`float`, *optional*, defaults to `0.95`):
+ Momentum for updating the activation quantization range.
+ per_channel (`bool`, *optional*, defaults to `False`):
+            Whether or not to use channel-wise quantization.
+ channel_len (`int`, *optional*):
+            Specify the channel length when *per_channel* is set to `True`.
+ quant_mode (`bool`, *optional*, defaults to `False`):
+ Whether or not the layer is quantized.
+ """
+
+ def __init__(self, activation_bit, act_range_momentum=0.95, per_channel=False, channel_len=None, quant_mode=False):
+ super().__init__()
+
+ self.activation_bit = activation_bit
+ self.act_range_momentum = act_range_momentum
+ self.quant_mode = quant_mode
+ self.per_channel = per_channel
+ self.percentile = False
+ self.act_function = SymmetricQuantFunction.apply
+
+ if not self.per_channel:
+ self.register_buffer("x_min", torch.zeros(1))
+ self.register_buffer("x_max", torch.zeros(1))
+ self.register_buffer("act_scaling_factor", torch.zeros(1))
+ self.x_min -= 1e-5
+ self.x_max += 1e-5
+ else:
+ raise NotImplementedError("per-channel mode is not currently supported for activation.")
+
+ def __repr__(self):
+ return (
+ f"{self.__class__.__name__}(activation_bit={self.activation_bit}, "
+ f"quant_mode: {self.quant_mode}, Act_min: {self.x_min.item():.2f}, "
+ f"Act_max: {self.x_max.item():.2f})"
+ )
+
+ def forward(
+ self,
+ x,
+ pre_act_scaling_factor=None,
+ identity=None,
+ identity_scaling_factor=None,
+ specified_min=None,
+ specified_max=None,
+ ):
+ x_act = x if identity is None else identity + x
+ # collect running stats if training
+ if self.training:
+ assert not self.percentile, "percentile mode is not currently supported for activation."
+ assert not self.per_channel, "per-channel mode is not currently supported for activation."
+ x_min = x_act.data.min()
+ x_max = x_act.data.max()
+
+ assert x_max.isnan().sum() == 0 and x_min.isnan().sum() == 0, (
+ "NaN detected when computing min/max of the activation"
+ )
+
+ # Initialization
+ if self.x_min.min() > -1.1e-5 and self.x_max.max() < 1.1e-5:
+ self.x_min = self.x_min + x_min
+ self.x_max = self.x_max + x_max
+
+ # exponential moving average (EMA)
+ # use momentum to prevent the quantized values change greatly every iteration
+ elif self.act_range_momentum == -1:
+ self.x_min = torch.min(self.x_min, x_min)
+ self.x_max = torch.max(self.x_max, x_max)
+ else:
+ self.x_min = self.x_min * self.act_range_momentum + x_min * (1 - self.act_range_momentum)
+ self.x_max = self.x_max * self.act_range_momentum + x_max * (1 - self.act_range_momentum)
+
+ if not self.quant_mode:
+ return x_act, None
+
+ x_min = self.x_min if specified_min is None else specified_min
+ x_max = self.x_max if specified_max is None else specified_max
+
+ self.act_scaling_factor = symmetric_linear_quantization_params(
+ self.activation_bit, x_min, x_max, per_channel=self.per_channel
+ )
+
+ if pre_act_scaling_factor is None:
+ # this is for the input quantization
+ quant_act_int = self.act_function(x, self.activation_bit, self.percentile, self.act_scaling_factor)
+ else:
+ quant_act_int = FixedPointMul.apply(
+ x,
+ pre_act_scaling_factor,
+ self.activation_bit,
+ self.act_scaling_factor,
+ identity,
+ identity_scaling_factor,
+ )
+
+ correct_output_scale = self.act_scaling_factor.view(-1)
+
+ return quant_act_int * correct_output_scale, self.act_scaling_factor
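+
+    # Editor's note (illustrative): during training this module tracks the activation range
+    # (`x_min`, `x_max`) with an exponential moving average, derives `act_scaling_factor` from
+    # it, and requantizes the input to that scale; when `pre_act_scaling_factor` is given, the
+    # rescaling goes through `FixedPointMul` (defined below) so that the same arithmetic could
+    # be reproduced with integer-only operations. A hypothetical usage sketch:
+    #
+    #   act = QuantAct(8, quant_mode=True)
+    #   y, y_scale = act(x)                            # quantize a fresh activation
+    #   z, z_scale = act(h, pre_act_scaling_factor=s)  # requantize an already-scaled tensor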
+
+
+class QuantLinear(nn.Module):
+ """
+ Quantized version of `torch.nn.Linear`. Adds quantization-specific arguments on top of `torch.nn.Linear`.
+
+ Args:
+ weight_bit (`int`, *optional*, defaults to `8`):
+ Bitwidth for the quantized weight.
+ bias_bit (`int`, *optional*, defaults to `32`):
+ Bitwidth for the quantized bias.
+ per_channel (`bool`, *optional*, defaults to `False`):
+ Whether or not to use channel-wise quantization.
+ quant_mode (`bool`, *optional*, defaults to `False`):
+ Whether or not the layer is quantized.
+ """
+
+ def __init__(
+ self, in_features, out_features, bias=True, weight_bit=8, bias_bit=32, per_channel=False, quant_mode=False
+ ):
+ super().__init__()
+ self.in_features = in_features
+ self.out_features = out_features
+
+ self.weight = nn.Parameter(torch.zeros([out_features, in_features]))
+ self.register_buffer("weight_integer", torch.zeros_like(self.weight))
+ self.register_buffer("fc_scaling_factor", torch.zeros(self.out_features))
+ if bias:
+ self.bias = nn.Parameter(torch.zeros(out_features))
+ self.register_buffer("bias_integer", torch.zeros_like(self.bias))
+
+ self.weight_bit = weight_bit
+ self.quant_mode = quant_mode
+ self.per_channel = per_channel
+ self.bias_bit = bias_bit
+ self.percentile_mode = False
+ self.weight_function = SymmetricQuantFunction.apply
+
+ def __repr__(self):
+ s = super().__repr__()
+ s = f"({s} weight_bit={self.weight_bit}, quant_mode={self.quant_mode})"
+ return s
+
+ def forward(self, x, prev_act_scaling_factor=None):
+ if not self.quant_mode:
+ return nn.functional.linear(x, weight=self.weight, bias=self.bias), None
+
+ # assert that prev_act_scaling_factor is a scalar tensor
+ assert prev_act_scaling_factor is not None and prev_act_scaling_factor.shape == (1,), (
+ "Input activation to the QuantLinear layer should be globally (non-channel-wise) quantized. "
+ "Please add a QuantAct layer with `per_channel = True` before this QuantAct layer"
+ )
+
+ w = self.weight
+ w_transform = w.data.detach()
+ if self.per_channel:
+ w_min, _ = torch.min(w_transform, dim=1, out=None)
+ w_max, _ = torch.max(w_transform, dim=1, out=None)
+ else:
+ w_min = w_transform.min().expand(1)
+ w_max = w_transform.max().expand(1)
+
+ self.fc_scaling_factor = symmetric_linear_quantization_params(self.weight_bit, w_min, w_max, self.per_channel)
+ self.weight_integer = self.weight_function(
+ self.weight, self.weight_bit, self.percentile_mode, self.fc_scaling_factor
+ )
+
+ bias_scaling_factor = self.fc_scaling_factor * prev_act_scaling_factor
+
+ if self.bias is not None:
+ self.bias_integer = self.weight_function(self.bias, self.bias_bit, False, bias_scaling_factor)
+
+ prev_act_scaling_factor = prev_act_scaling_factor.view(1, -1)
+ x_int = x / prev_act_scaling_factor
+
+ return (
+ nn.functional.linear(x_int, weight=self.weight_integer, bias=self.bias_integer) * bias_scaling_factor,
+ bias_scaling_factor,
+ )
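+
+    # Editor's note (illustrative): the key invariant here is scale propagation. With
+    # W_int = round(W / s_w) and x_int = x / s_x, the integer matmul is rescaled by
+    # `bias_scaling_factor` = s_w * s_x, and the bias is quantized directly at that combined
+    # scale so it can be added inside the integer accumulation:
+    #
+    #   y ≈ (x_int @ W_int.T + b_int) * (s_w * s_x)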
+
+
+class IntGELU(nn.Module):
+ """
+ Quantized version of `torch.nn.GELU`. Adds quantization-specific arguments on top of `torch.nn.GELU`.
+
+ Args:
+ quant_mode (`bool`, *optional*, defaults to `False`):
+ Whether or not the layer is quantized.
+ force_dequant (`str`, *optional*, defaults to `"none"`):
+ Force dequantize the layer if either "gelu" or "nonlinear" is given.
+ """
+
+ def __init__(self, quant_mode=True, force_dequant="none"):
+ super().__init__()
+ self.quant_mode = quant_mode
+
+ if force_dequant in ["nonlinear", "gelu"]:
+ logger.info("Force dequantize gelu")
+ self.quant_mode = False
+
+ if not self.quant_mode:
+ self.activation_fn = nn.GELU()
+
+ self.k = 1.4142
+ self.const = 14 # dummy integer constant
+ self.coeff = [-0.2888, -1.769, 1] # a(x+b)**2 + c
+ self.coeff[2] /= self.coeff[0]
+
+ def int_erf(self, x_int, scaling_factor):
+ b_int = torch.floor(self.coeff[1] / scaling_factor)
+ c_int = torch.floor(self.coeff[2] / scaling_factor**2)
+ sign = torch.sign(x_int)
+
+ abs_int = torch.min(torch.abs(x_int), -b_int)
+ y_int = sign * ((abs_int + b_int) ** 2 + c_int)
+ scaling_factor = scaling_factor**2 * self.coeff[0]
+
+ # avoid overflow
+ y_int = floor_ste.apply(y_int / 2**self.const)
+ scaling_factor = scaling_factor * 2**self.const
+
+ return y_int, scaling_factor
+
+ def forward(self, x, scaling_factor=None):
+ if not self.quant_mode:
+ return self.activation_fn(x), None
+
+ x_int = x / scaling_factor
+ sigmoid_int, sigmoid_scaling_factor = self.int_erf(x_int, scaling_factor / self.k)
+
+ shift_int = 1.0 // sigmoid_scaling_factor
+
+ x_int = x_int * (sigmoid_int + shift_int)
+ scaling_factor = scaling_factor * sigmoid_scaling_factor / 2
+
+ return x_int * scaling_factor, scaling_factor
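+
+    # Editor's note (illustrative): this follows the i-GELU construction: GELU(x) =
+    # x * 0.5 * (1 + erf(x / sqrt(2))), with erf approximated on integers by the second-order
+    # polynomial a * (min(|x|, -b) + b)**2 + c (coefficients in `self.coeff`) and the sign
+    # reapplied afterwards, so no floating-point erf is needed in quant mode.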
+
+
+class IntSoftmax(nn.Module):
+ """
+ Quantized version of `torch.nn.Softmax`. Adds quantization-specific arguments on top of `torch.nn.Softmax`.
+
+ Args:
+ output_bit (`int`):
+ Bitwidth for the layer output activation.
+ quant_mode (`bool`, *optional*, defaults to `False`):
+ Whether or not the layer is quantized.
+ force_dequant (`str`, *optional*, defaults to `"none"`):
+ Force dequantize the layer if either "softmax" or "nonlinear" is given.
+ """
+
+ def __init__(self, output_bit, quant_mode=False, force_dequant="none"):
+ super().__init__()
+ self.output_bit = output_bit
+ self.max_bit = 32
+ self.quant_mode = quant_mode
+
+ if force_dequant in ["nonlinear", "softmax"]:
+ logger.info("Force dequantize softmax")
+ self.quant_mode = False
+
+ self.act = QuantAct(16, quant_mode=self.quant_mode)
+ self.x0 = -0.6931 # -ln2
+ self.const = 30 # dummy integer constant
+ self.coef = [0.35815147, 0.96963238, 1.0] # ax**2 + bx + c
+ self.coef[1] /= self.coef[0]
+ self.coef[2] /= self.coef[0]
+
+ def int_polynomial(self, x_int, scaling_factor):
+ with torch.no_grad():
+ b_int = torch.floor(self.coef[1] / scaling_factor)
+ c_int = torch.floor(self.coef[2] / scaling_factor**2)
+ z = (x_int + b_int) * x_int + c_int
+ scaling_factor = self.coef[0] * scaling_factor**2
+ return z, scaling_factor
+
+ def int_exp(self, x_int, scaling_factor):
+ with torch.no_grad():
+ x0_int = torch.floor(self.x0 / scaling_factor)
+ x_int = torch.max(x_int, self.const * x0_int)
+
+ q = floor_ste.apply(x_int / x0_int)
+ r = x_int - x0_int * q
+ exp_int, exp_scaling_factor = self.int_polynomial(r, scaling_factor)
+ exp_int = torch.clamp(floor_ste.apply(exp_int * 2 ** (self.const - q)), min=0)
+ scaling_factor = exp_scaling_factor / 2**self.const
+ return exp_int, scaling_factor
+
+ def forward(self, x, scaling_factor):
+ if not self.quant_mode:
+ return nn.functional.softmax(x, dim=-1), None
+
+ x_int = x / scaling_factor
+
+ x_int_max, _ = x_int.max(dim=-1, keepdim=True)
+ x_int = x_int - x_int_max
+ exp_int, exp_scaling_factor = self.int_exp(x_int, scaling_factor)
+
+ # Avoid overflow
+ exp, exp_scaling_factor = self.act(exp_int, exp_scaling_factor)
+ exp_int = exp / exp_scaling_factor
+
+ exp_int_sum = exp_int.sum(dim=-1, keepdim=True)
+ factor = floor_ste.apply(2**self.max_bit / exp_int_sum)
+ exp_int = floor_ste.apply(exp_int * factor / 2 ** (self.max_bit - self.output_bit))
+ scaling_factor = 1 / 2**self.output_bit
+ return exp_int * scaling_factor, scaling_factor
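+
+    # Editor's note (illustrative): integer softmax relies on the i-exp decomposition. After
+    # subtracting the row max, every x <= 0 is written as x = (-ln 2) * q + r with r in
+    # (-ln 2, 0], so exp(x) = 2**(-q) * exp(r); exp(r) is approximated by the second-order
+    # polynomial in `int_polynomial` and the 2**(-q) factor becomes a right shift. The final
+    # normalization folds the division by the sum into a `2**max_bit` fixed-point factor.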
+
+
+class IntLayerNorm(nn.Module):
+ """
+ Quantized version of `torch.nn.LayerNorm`. Adds quantization-specific arguments on top of `torch.nn.LayerNorm`.
+
+ Args:
+ output_bit (`int`, *optional*, defaults to `8`):
+ Bitwidth for the layer output activation.
+ quant_mode (`bool`, *optional*, defaults to `False`):
+ Whether or not the layer is quantized.
+ force_dequant (`str`, *optional*, defaults to `"none"`):
+ Force dequantize the layer if either "layernorm" or "nonlinear" is given.
+ """
+
+ def __init__(self, normalized_shape, eps, output_bit=8, quant_mode=False, force_dequant="none"):
+ super().__init__()
+ self.normalized_shape = normalized_shape
+ self.eps = eps
+
+ self.weight = nn.Parameter(torch.zeros(normalized_shape))
+ self.bias = nn.Parameter(torch.zeros(normalized_shape))
+
+ self.quant_mode = quant_mode
+ if force_dequant in ["nonlinear", "layernorm"]:
+ logger.info("Force dequantize layernorm")
+ self.quant_mode = False
+
+ self.register_buffer("shift", torch.zeros(1))
+ self.output_bit = output_bit
+ self.max_bit = 32
+ self.dim_sqrt = None
+ self.activation = QuantAct(self.output_bit, quant_mode=self.quant_mode)
+
+ def set_shift(self, y_int):
+ with torch.no_grad():
+ y_sq_int = y_int**2
+ var_int = torch.sum(y_sq_int, axis=2, keepdim=True)
+ shift = (torch.log2(torch.sqrt(var_int / 2**self.max_bit)).ceil()).max()
+ shift_old = self.shift
+ self.shift = torch.max(self.shift, shift)
+ logger.info(f"Dynamic shift adjustment: {int(shift_old)} -> {int(self.shift)}")
+
+ def overflow_fallback(self, y_int):
+ """
+ This fallback function is called when overflow is detected during training time, and adjusts the `self.shift`
+ to avoid overflow in the subsequent runs.
+ """
+ self.set_shift(y_int) # adjusts `self.shift`
+ y_int_shifted = floor_ste.apply(y_int / 2**self.shift)
+ y_sq_int = y_int_shifted**2
+ var_int = torch.sum(y_sq_int, axis=2, keepdim=True)
+ return var_int
+
+ def forward(self, x, scaling_factor=None):
+ if not self.quant_mode:
+ mean = x.mean(axis=2, keepdim=True)
+ y = x - mean
+ var = torch.mean(y**2, axis=2, keepdim=True)
+ x = y / torch.sqrt(self.eps + var)
+ x = x * self.weight + self.bias
+ return x, None
+
+ # compute sqrt of the feature dimension if it is the first run
+ if self.dim_sqrt is None:
+ n = torch.tensor(x.shape[2], dtype=torch.float)
+ self.dim_sqrt = torch.sqrt(n).to(x.device)
+
+ # Normalization: computes mean and variance(std)
+ x_int = x / scaling_factor
+ mean_int = round_ste.apply(x_int.mean(axis=2, keepdim=True))
+ y_int = x_int - mean_int
+ y_int_shifted = floor_ste.apply(y_int / 2**self.shift)
+ y_sq_int = y_int_shifted**2
+ var_int = torch.sum(y_sq_int, axis=2, keepdim=True)
+
+ # overflow handling in training time
+ if self.training:
+ # if overflow is detected
+ if var_int.max() >= 2**self.max_bit:
+ var_int = self.overflow_fallback(y_int)
+ assert var_int.max() < 2**self.max_bit + 0.1, (
+ "Error detected in overflow handling: "
+ "`var_int` exceeds `self.max_bit` (the maximum possible bit width)"
+ )
+
+ # To be replaced with integer-sqrt kernel that produces the same output
+ std_int = floor_ste.apply(torch.sqrt(var_int)) * 2**self.shift
+ factor = floor_ste.apply(2**31 / std_int)
+ y_int = floor_ste.apply(y_int * factor / 2)
+ scaling_factor = self.dim_sqrt / 2**30
+
+ # scaling and shifting
+ bias = self.bias.data.detach() / (self.weight.data.detach())
+ bias_int = floor_ste.apply(bias / scaling_factor)
+
+ y_int = y_int + bias_int
+ scaling_factor = scaling_factor * self.weight
+ x = y_int * scaling_factor
+
+ return x, scaling_factor
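+
+    # Editor's note (illustrative): the variance is accumulated in integers, so `shift` divides
+    # `y_int` by 2**shift to keep sum(y_int**2) below 2**max_bit; `overflow_fallback` grows the
+    # shift on the fly if an overflow is detected during training. The square root itself is
+    # still computed in floating point here; the in-line comment marks it as a stand-in for a
+    # true integer-sqrt kernel.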
+
+
+def get_percentile_min_max(input, lower_percentile, upper_percentile, output_tensor=False):
+ """
+ Calculate the percentile max and min values in a given tensor
+
+ Args:
+ input (`torch.Tensor`):
+ The target tensor to calculate percentile max and min.
+ lower_percentile (`float`):
+            If 0.1, we return the value of the smallest 0.1% of values in the tensor as the percentile min.
+ upper_percentile (`float`):
+            If 99.9, we return the value of the largest 0.1% of values in the tensor as the percentile max.
+ output_tensor (`bool`, *optional*, defaults to `False`):
+ If True, this function returns tensors, otherwise it returns values.
+
+ Returns:
+ `Tuple(torch.Tensor, torch.Tensor)`: Percentile min and max value of *input*
+ """
+ input_length = input.shape[0]
+
+ lower_index = round(input_length * (1 - lower_percentile * 0.01))
+ upper_index = round(input_length * upper_percentile * 0.01)
+
+ upper_bound = torch.kthvalue(input, k=upper_index).values
+
+ if lower_percentile == 0:
+ lower_bound = upper_bound * 0
+ # lower_index += 1
+ else:
+ lower_bound = -torch.kthvalue(-input, k=lower_index).values
+
+ if not output_tensor:
+ lower_bound = lower_bound.item()
+ upper_bound = upper_bound.item()
+ return lower_bound, upper_bound
+
+
+def linear_quantize(input, scale, zero_point, inplace=False):
+ """
+ Quantize single-precision input tensor to integers with the given scaling factor and zeropoint.
+
+ Args:
+ input (`torch.Tensor`):
+ Single-precision input tensor to be quantized.
+ scale (`torch.Tensor`):
+ Scaling factor for quantization.
+        zero_point (`torch.Tensor`):
+ Shift for quantization.
+ inplace (`bool`, *optional*, defaults to `False`):
+ Whether to compute inplace or not.
+
+ Returns:
+ `torch.Tensor`: Linearly quantized value of *input* according to *scale* and *zero_point*.
+ """
+ # reshape scale and zeropoint for convolutional weights and activation
+ if len(input.shape) == 4:
+ scale = scale.view(-1, 1, 1, 1)
+ zero_point = zero_point.view(-1, 1, 1, 1)
+ # reshape scale and zeropoint for linear weights
+ elif len(input.shape) == 2:
+ scale = scale.view(-1, 1)
+ zero_point = zero_point.view(-1, 1)
+ else:
+ scale = scale.view(-1)
+ zero_point = zero_point.view(-1)
+ # quantized = float / scale + zero_point
+ if inplace:
+ input.mul_(1.0 / scale).add_(zero_point).round_()
+ return input
+ return torch.round(1.0 / scale * input + zero_point)
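+
+
+# Illustrative note (added by the editor): linear_quantize computes q = round(x / scale + zero_point).
+# For example, with scale = 0.1 and zero_point = 0, an input of 0.26 maps to round(2.6) = 3;
+# dequantizing gives 3 * 0.1 = 0.3, so the rounding error is bounded by scale / 2.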
+
+
+def symmetric_linear_quantization_params(num_bits, saturation_min, saturation_max, per_channel=False):
+ """
+ Compute the scaling factor with the given quantization range for symmetric quantization.
+
+    Args:
+        num_bits (`int`):
+            Number of bits used for quantization.
+        saturation_min (`torch.Tensor`):
+ Lower bound for quantization range.
+ saturation_max (`torch.Tensor`):
+ Upper bound for quantization range.
+ per_channel (`bool`, *optional*, defaults to `False`):
+            Whether or not to use channel-wise quantization.
+
+ Returns:
+ `torch.Tensor`: Scaling factor that linearly quantizes the given range between *saturation_min* and
+ *saturation_max*.
+ """
+ # in this part, we do not need any gradient computation,
+ # in order to enforce this, we put torch.no_grad()
+ with torch.no_grad():
+ n = 2 ** (num_bits - 1) - 1
+
+ if per_channel:
+ scale, _ = torch.max(torch.stack([saturation_min.abs(), saturation_max.abs()], dim=1), dim=1)
+ scale = torch.clamp(scale, min=1e-8) / n
+
+ else:
+ scale = max(saturation_min.abs(), saturation_max.abs())
+ scale = torch.clamp(scale, min=1e-8) / n
+
+ return scale
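+
+# Illustrative note (added by the editor): for symmetric k-bit quantization the scale is
+# max(|saturation_min|, |saturation_max|) / (2**(k - 1) - 1). With num_bits = 8 and a range of
+# [-0.5, 1.0], n = 127 and scale = 1.0 / 127 ≈ 0.00787, so one quantization step corresponds to
+# about 0.00787 in the original value range.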
+
+
+class SymmetricQuantFunction(Function):
+ """
+ Class to quantize the given floating-point values using symmetric quantization with given range and bitwidth.
+ """
+
+ @staticmethod
+ def forward(ctx, x, k, percentile_mode, scale):
+ """
+ Args:
+ x (`torch.Tensor`):
+ Floating point tensor to be quantized.
+ k (`int`):
+ Quantization bitwidth.
+ percentile_mode (`bool`):
+ Whether or not to use percentile calibration.
+ scale (`torch.Tensor`):
+ Pre-calculated scaling factor for *x*. Note that the current implementation of SymmetricQuantFunction
+ requires pre-calculated scaling factor.
+
+ Returns:
+ `torch.Tensor`: Symmetric-quantized value of *input*.
+ """
+ zero_point = torch.tensor(0.0, device=scale.device)
+
+ n = 2 ** (k - 1) - 1
+ new_quant_x = linear_quantize(x, scale, zero_point, inplace=False)
+ new_quant_x = torch.clamp(new_quant_x, -n, n - 1)
+
+ ctx.scale = scale
+ return new_quant_x
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ scale = ctx.scale
+ if len(grad_output.shape) == 4:
+ scale = scale.view(-1, 1, 1, 1)
+ # reshape scale and zeropoint for linear weights
+ elif len(grad_output.shape) == 2:
+ scale = scale.view(-1, 1)
+ else:
+ scale = scale.view(-1)
+
+ return grad_output.clone() / scale, None, None, None, None
+
+
+class floor_ste(Function):
+ """
+    Straight-through Estimator (STE) for torch.floor()
+ """
+
+ @staticmethod
+ def forward(ctx, x):
+ return torch.floor(x)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ return grad_output.clone()
+
+
+class round_ste(Function):
+ """
+    Straight-through Estimator (STE) for torch.round()
+ """
+
+ @staticmethod
+ def forward(ctx, x):
+ return torch.round(x)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ return grad_output.clone()
+
+
+def batch_frexp(inputs, max_bit=31):
+ """
+ Decompose the scaling factor into mantissa and twos exponent.
+
+ Args:
+        inputs (`torch.Tensor`):
+            Target scaling factor to decompose.
+        max_bit (`int`, *optional*, defaults to 31):
+            Number of bits used to represent the integer mantissa.
+
+ Returns:
+        `Tuple(torch.Tensor, torch.Tensor)`: mantissa and exponent
+ """
+
+ shape_of_input = inputs.size()
+
+    # flatten the input into a 1-d tensor
+ inputs = inputs.view(-1)
+
+ output_m, output_e = np.frexp(inputs.cpu().numpy())
+ tmp_m = []
+ for m in output_m:
+ int_m_shifted = int(
+ decimal.Decimal(m * (2**max_bit)).quantize(decimal.Decimal("1"), rounding=decimal.ROUND_HALF_UP)
+ )
+ tmp_m.append(int_m_shifted)
+ output_m = np.array(tmp_m)
+
+ output_e = float(max_bit) - output_e
+
+ return (
+ torch.from_numpy(output_m).to(inputs.device).view(shape_of_input),
+ torch.from_numpy(output_e).to(inputs.device).view(shape_of_input),
+ )
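+
+# Illustrative note (added by the editor): batch_frexp rewrites each scaling factor s as an
+# integer mantissa and a power-of-two exponent, s ≈ m / 2**e. For s = 0.3, np.frexp gives
+# (0.6, -1), so m = round(0.6 * 2**31) = 1288490189 and e = 31 - (-1) = 32, and indeed
+# 1288490189 / 2**32 ≈ 0.3. Multiplying by m and shifting right by e then stands in for a
+# floating-point multiply.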
+
+
+class FixedPointMul(Function):
+ """
+ Function to perform fixed-point arithmetic that can match integer arithmetic on hardware.
+
+ Args:
+ pre_act (`torch.Tensor`):
+ Input tensor.
+ pre_act_scaling_factor (`torch.Tensor`):
+ Scaling factor of the input tensor *pre_act*.
+ bit_num (`int`):
+ Quantization bitwidth.
+ z_scaling_factor (`torch.Tensor`):
+ Scaling factor of the output tensor.
+ identity (`torch.Tensor`, *optional*):
+ Identity tensor, if exists.
+ identity_scaling_factor (`torch.Tensor`, *optional*):
+ Scaling factor of the identity tensor *identity*, if exists.
+
+ Returns:
+        `torch.Tensor`: Output tensor (*pre_act* if *identity* is not given, otherwise the addition of *pre_act* and
+ *identity*), whose scale is rescaled to *z_scaling_factor*.
+ """
+
+ @staticmethod
+ def forward(
+ ctx,
+ pre_act,
+ pre_act_scaling_factor,
+ bit_num,
+ z_scaling_factor,
+ identity=None,
+ identity_scaling_factor=None,
+ ):
+ if len(pre_act_scaling_factor.shape) == 3:
+ reshape = lambda x: x # noqa: E731
+ else:
+ reshape = lambda x: x.view(1, 1, -1) # noqa: E731
+ ctx.identity = identity
+
+ n = 2 ** (bit_num - 1) - 1
+
+ with torch.no_grad():
+ pre_act_scaling_factor = reshape(pre_act_scaling_factor)
+ if identity is not None:
+ identity_scaling_factor = reshape(identity_scaling_factor)
+
+ ctx.z_scaling_factor = z_scaling_factor
+
+ z_int = torch.round(pre_act / pre_act_scaling_factor)
+ _A = pre_act_scaling_factor.type(torch.double)
+ _B = (z_scaling_factor.type(torch.float)).type(torch.double)
+ new_scale = _A / _B
+ new_scale = reshape(new_scale)
+
+ m, e = batch_frexp(new_scale)
+
+ output = z_int.type(torch.double) * m.type(torch.double)
+ output = torch.round(output / (2.0**e))
+
+ if identity is not None:
+ # needs addition of identity activation
+ wx_int = torch.round(identity / identity_scaling_factor)
+
+ _A = identity_scaling_factor.type(torch.double)
+ _B = (z_scaling_factor.type(torch.float)).type(torch.double)
+ new_scale = _A / _B
+ new_scale = reshape(new_scale)
+
+ m1, e1 = batch_frexp(new_scale)
+ output1 = wx_int.type(torch.double) * m1.type(torch.double)
+ output1 = torch.round(output1 / (2.0**e1))
+
+ output = output1 + output
+
+ return torch.clamp(output.type(torch.float), -n - 1, n)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ identity_grad = None
+ if ctx.identity is not None:
+ identity_grad = grad_output.clone() / ctx.z_scaling_factor
+ return grad_output.clone() / ctx.z_scaling_factor, None, None, None, None, identity_grad, None
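+
+
+# Editor's note (illustrative): FixedPointMul ties the pieces above together: the input is
+# expressed as integers at its own scale (z_int = round(pre_act / pre_act_scaling_factor)), the
+# ratio of the input scale to the output scale is decomposed with batch_frexp into (m, e), and
+# the requantized output is round(z_int * m / 2**e), clamped to the target bitwidth. The
+# optional identity branch applies the same treatment to a residual input before the addition.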
diff --git a/venv/lib/python3.13/site-packages/transformers/models/idefics3/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/idefics3/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e1dd3bfda7fbce5af140b438ec00f3ff51718ed5
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/idefics3/__init__.py
@@ -0,0 +1,30 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_idefics3 import *
+ from .image_processing_idefics3 import *
+ from .image_processing_idefics3_fast import *
+ from .modeling_idefics3 import *
+ from .processing_idefics3 import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/venv/lib/python3.13/site-packages/transformers/models/idefics3/configuration_idefics3.py b/venv/lib/python3.13/site-packages/transformers/models/idefics3/configuration_idefics3.py
new file mode 100644
index 0000000000000000000000000000000000000000..97a2e57f1d8dc4cc2b8c27ebade7ef1a5109f97f
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/idefics3/configuration_idefics3.py
@@ -0,0 +1,190 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Idefics3 model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+from ..auto import CONFIG_MAPPING, AutoConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+class Idefics3VisionConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`Idefics3VisionModel`]. It is used to instantiate a
+ Idefics3 vision encoder according to the specified arguments, defining the model architecture. Instantiating a
+ configuration with the defaults will yield a similar configuration to that of the SigLIP checkpoint
+ [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) used in the Idefics3 model
+ [HuggingFaceM4/Idefics3-8B-Llama3](https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3).
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ hidden_size (`int`, *optional*, defaults to 1152):
+ Dimensionality of the encoder layers and the pooler layer.
+ intermediate_size (`int`, *optional*, defaults to 3072):
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+ num_hidden_layers (`int`, *optional*, defaults to 12):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 16):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ num_channels (`int`, *optional*, defaults to 3):
+ Number of channels in the input images.
+ image_size (`int`, *optional*, defaults to 224):
+ The size (resolution) of each image.
+ patch_size (`int`, *optional*, defaults to 32):
+ The size (resolution) of each patch.
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-06):
+ The epsilon used by the layer normalization layers.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+
+ Example:
+
+ ```python
+ >>> from transformers.models.idefics3.modeling_idefics3 import Idefics3VisionTransformer
+ >>> from transformers.models.idefics3.configuration_idefics3 import Idefics3VisionConfig
+
+ >>> # Initializing a Idefics3VisionConfig with google/siglip-base-patch16-224 style configuration
+ >>> configuration = Idefics3VisionConfig()
+
+ >>> # Initializing a Idefics3VisionTransformer (with random weights) from the google/siglip-base-patch16-224 style configuration
+ >>> model = Idefics3VisionTransformer(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "idefics3_vision"
+ base_config_key = "vision_config"
+
+ def __init__(
+ self,
+ hidden_size=1152,
+ intermediate_size=3072,
+ num_hidden_layers=12,
+ num_attention_heads=16,
+ num_channels=3,
+ image_size=224,
+ patch_size=32,
+ hidden_act="gelu_pytorch_tanh",
+ layer_norm_eps=1e-6,
+ attention_dropout=0.0,
+ initializer_range=0.02,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.num_channels = num_channels
+ self.patch_size = patch_size
+ self.image_size = image_size
+ self.attention_dropout = attention_dropout
+ self.layer_norm_eps = layer_norm_eps
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+
+
+class Idefics3Config(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`Idefics3Model`]. It is used to instantiate a
+ Idefics3 model according to the specified arguments, defining the model architecture. Instantiating a
+ configuration with the defaults will yield a similar configuration to that of the model of the Idefics3
+ [HuggingFaceM4/Idefics3-8B-Llama3](https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should cache the key/value pairs of the attention mechanism. Only
+ relevant if `config.is_decoder=True`.
+ image_token_id (`int`, *optional*, defaults to 128257):
+ The id of the "image" token.
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether or not to tie the word embeddings with the token embeddings.
+        vision_config (`Idefics3VisionConfig` or `dict`, *optional*, defaults to `Idefics3VisionConfig`):
+ Custom vision config or dict for the vision tower
+ text_config (`PretrainedConfig` or `dict`, *optional*, defaults to `LlamaConfig`):
+ Custom text config or dict for the text model
+ scale_factor (`int`, *optional*, defaults to 2):
+ The scale factor for the image encoder.
+ pad_token_id (`int`, *optional*, defaults to 128002):
+ The id of the padding token.
+
+ Example:
+ ```python
+ >>> from transformers import Idefics3Model, Idefics3Config
+ >>> # Initializing configuration
+ >>> configuration = Idefics3Config()
+ >>> # Initializing a model from the configuration
+ >>> model = Idefics3Model(configuration)
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "idefics3"
+ sub_configs = {"text_config": AutoConfig, "vision_config": Idefics3VisionConfig}
+
+ def __init__(
+ self,
+ use_cache=True,
+ image_token_id=128257,
+ tie_word_embeddings=False,
+ vision_config=None,
+ text_config=None,
+ scale_factor=2,
+ pad_token_id=128_002,
+ **kwargs,
+ ):
+ self.image_token_id = image_token_id
+ self.use_cache = use_cache
+ self.tie_word_embeddings = tie_word_embeddings
+
+ if vision_config is None:
+ self.vision_config = Idefics3VisionConfig()
+ logger.info("vision_config is None, using default vision config")
+ elif isinstance(vision_config, dict):
+ self.vision_config = Idefics3VisionConfig(**vision_config)
+ elif isinstance(vision_config, Idefics3VisionConfig):
+ self.vision_config = vision_config
+
+ if isinstance(text_config, dict):
+ text_config["model_type"] = text_config.get("model_type", "llama")
+ text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
+ elif text_config is None:
+ logger.info("text_config is None, using default text config")
+ text_config = CONFIG_MAPPING["llama"](
+ rms_norm_eps=1e-5,
+ pad_token_id=pad_token_id,
+ tie_word_embeddings=False,
+ )
+
+ self.text_config = text_config
+ self.scale_factor = scale_factor
+
+ super().__init__(**kwargs, pad_token_id=pad_token_id, tie_word_embeddings=tie_word_embeddings)
+
+
+__all__ = ["Idefics3Config", "Idefics3VisionConfig"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/idefics3/image_processing_idefics3.py b/venv/lib/python3.13/site-packages/transformers/models/idefics3/image_processing_idefics3.py
new file mode 100644
index 0000000000000000000000000000000000000000..e460a041965ab1315f992b3378df40dc3af7d735
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/idefics3/image_processing_idefics3.py
@@ -0,0 +1,904 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from collections.abc import Iterable
+from typing import Any, Optional, Union
+
+import numpy as np
+
+from ...image_processing_utils import BaseImageProcessor, BatchFeature
+from ...image_transforms import PaddingMode, pad, to_channel_dimension_format, to_pil_image
+from ...image_utils import (
+ IMAGENET_STANDARD_MEAN,
+ IMAGENET_STANDARD_STD,
+ ChannelDimension,
+ ImageInput,
+ PILImageResampling,
+ get_image_size,
+ infer_channel_dimension_format,
+ is_scaled_image,
+ make_nested_list_of_images,
+ to_numpy_array,
+ valid_images,
+ validate_preprocess_arguments,
+)
+from ...utils import TensorType, is_vision_available, logging
+
+
+logger = logging.get_logger(__name__)
+MAX_IMAGE_SIZE = 4096 # 4k resolution as absolute maximum
+
+
+if is_vision_available():
+ import PIL
+ from PIL import Image
+
+
+def _resize_output_size_rescale_to_max_len(
+ height: int, width: int, min_len: Optional[int] = 1, max_len: Optional[int] = None
+) -> tuple[int, int]:
+ """
+    Get the output size of the image after resizing, given the minimum and maximum allowed sizes.
+ Args:
+ height (`int`):
+ Height of the input image.
+ width (`int`):
+ Width of the input image.
+ min_len (`int`, *optional*, defaults to 1):
+ Minimum size of the output image.
+ max_len (`int`, *optional*, defaults to the maximum size of the image):
+ Maximum size of the output image.
+ Returns:
+ The output size of the image after resizing.
+ """
+ max_len = max(height, width) if max_len is None else max_len
+ aspect_ratio = width / height
+
+ if width >= height:
+ width = max_len
+ height = int(width / aspect_ratio)
+ if height % 2 != 0:
+ height += 1
+ elif height > width:
+ height = max_len
+ width = int(height * aspect_ratio)
+ if width % 2 != 0:
+ width += 1
+
+ # Avoid resizing to a size smaller than min_len
+ height = max(height, min_len)
+ width = max(width, min_len)
+ return height, width
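+
+# Illustrative note (added by the editor): for a 300x600 (height x width) image and
+# max_len = 364, the longer side becomes 364 and the shorter side follows the aspect ratio:
+# width = 364, height = int(364 / 2) = 182 (already even), so the returned size is (182, 364).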
+
+
+def _resize_output_size_scale_below_upper_bound(
+    height: int, width: int, max_len: Optional[int] = None
+) -> tuple[int, int]:
+ """
+    Get the output size of the image after resizing, given an upper bound on its longest side.
+ Args:
+ height (`int`):
+ Height of the input image.
+ width (`int`):
+ Width of the input image.
+        max_len (`int`, *optional*, defaults to the maximum size of the image):
+ Defines the maximum dimensions of the image.
+ Returns:
+ The output size of the image after resizing.
+ """
+ max_len = max(height, width) if max_len is None else max_len
+
+ aspect_ratio = width / height
+ if width >= height and width > max_len:
+ width = max_len
+ height = int(width / aspect_ratio)
+ elif height > width and height > max_len:
+ height = max_len
+ width = int(height * aspect_ratio)
+
+ # Avoid resizing to a size smaller than 1
+ height = max(height, 1)
+ width = max(width, 1)
+ return height, width
+
+
+def get_resize_output_image_size(
+ image,
+ resolution_max_side: int,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+) -> tuple[int, int]:
+ """
+    Get the output size of the image after resizing, given the maximum allowed size of the longest edge.
+ Args:
+ image (`np.ndarray`):
+ Image to resize.
+ resolution_max_side (`int`):
+ The longest edge of the image will be resized to this value. The shortest edge will be resized to keep the
+ input aspect ratio.
+ input_data_format (`ChannelDimension` or `str`):
+ The channel dimension format of the input image.
+ Returns:
+ The output size of the image after resizing.
+ """
+ height, width = get_image_size(image, channel_dim=input_data_format)
+
+ # Find the output size, when rescaling the longest edge to max_len and preserving the aspect ratio
+ height, width = _resize_output_size_rescale_to_max_len(height, width, max_len=resolution_max_side)
+ # Find the output size when scaling the image to be below the MAX_IMAGE_SIZE
+ height, width = _resize_output_size_scale_below_upper_bound(height, width, max_len=MAX_IMAGE_SIZE)
+ return height, width
+
+
+# Copied from transformers.models.detr.image_processing_detr.max_across_indices
+def max_across_indices(values: Iterable[Any]) -> list[Any]:
+ """
+ Return the maximum value across all indices of an iterable of values.
+ """
+ return [max(values_i) for values_i in zip(*values)]
+
+
+def get_max_height_width(
+ images_list: list[list[np.ndarray]], input_data_format: Optional[Union[str, ChannelDimension]] = None
+) -> tuple[int, int]:
+ """
+ Get the maximum height and width across all images in a batch.
+ """
+ if input_data_format is None:
+ input_data_format = infer_channel_dimension_format(images_list[0][0], num_channels=(1, 3, 4))
+
+ max_height = max_width = float("-inf")
+ for images in images_list:
+ for image in images:
+ height, width = get_image_size(image, channel_dim=input_data_format)
+ max_height = max(height, max_height)
+ max_width = max(width, max_width)
+ return (max_height, max_width)
+
+
+# Copied from transformers.models.detr.image_processing_detr.make_pixel_mask
+def make_pixel_mask(
+ image: np.ndarray, output_size: tuple[int, int], input_data_format: Optional[Union[str, ChannelDimension]] = None
+) -> np.ndarray:
+ """
+ Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding.
+ Args:
+ image (`np.ndarray`):
+ Image to make the pixel mask for.
+ output_size (`tuple[int, int]`):
+ Output size of the mask.
+ """
+ input_height, input_width = get_image_size(image, channel_dim=input_data_format)
+ mask = np.zeros(output_size, dtype=np.int64)
+ mask[:input_height, :input_width] = 1
+ return mask
+
+
+def convert_to_rgb(
+ image: np.ndarray,
+ palette: Optional[PIL.ImagePalette.ImagePalette] = None,
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+) -> ImageInput:
+ """
+ Converts an image to RGB format.
+ Args:
+ image (`np.ndarray`):
+ The image to convert.
+ palette (list[int], *optional*):
+ The palette to use if given.
+ data_format (ChannelDimension or str, *optional*):
+ The channel dimension format for the output image. If not provided, it will be the same as the input image.
+ input_data_format (ChannelDimension or str, *optional*):
+ The channel dimension format of the input image.
+ """
+ if input_data_format is None:
+ input_data_format = infer_channel_dimension_format(image, num_channels=(1, 3, 4))
+
+ # For all transformations, we want to keep the same data format as the input image unless otherwise specified.
+ # The resized image from PIL will always have channels last, so find the input format first.
+ data_format = input_data_format if data_format is None else data_format
+
+ mode = "P" if palette is not None else None
+ image = to_pil_image(image, image_mode=mode, input_data_format=input_data_format)
+ if image.mode == "P" and palette is not None:
+ image.putpalette(palette)
+
+ image_rgba = image.convert("RGBA")
+ background = Image.new("RGBA", image_rgba.size, (255, 255, 255))
+ alpha_composite = Image.alpha_composite(background, image_rgba)
+ alpha_composite = alpha_composite.convert("RGB")
+
+ output_array = np.array(alpha_composite)
+ # The image is always in channels last format after converting from a PIL image
+ output_array = to_channel_dimension_format(output_array, data_format, input_channel_dim=ChannelDimension.LAST)
+ return output_array
+
+
+# FIXME Amy: make a more general crop function that isn't just centre crop
+def _crop(
+ image: np.ndarray,
+ w1: int,
+ h1: int,
+ w2: int,
+ h2: int,
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+) -> np.ndarray:
+ if data_format is None:
+ data_format = infer_channel_dimension_format(image, num_channels=(1, 3, 4))
+
+ if data_format == ChannelDimension.FIRST:
+ image = image[:, h1:h2, w1:w2]
+ elif data_format == ChannelDimension.LAST:
+ image = image[h1:h2, w1:w2, :]
+ else:
+ raise ValueError("Invalid channel dimension format.")
+
+ return image
+
+
+class Idefics3ImageProcessor(BaseImageProcessor):
+ r"""
+ Constructs a Idefics3 image processor.
+ Args:
+ do_convert_rgb (`bool`, *optional*, defaults to `True`):
+ Whether to convert the image to RGB. This is useful if the input image is of a different format e.g. RGBA.
+ Only has an effect if the input image is in the PIL format.
+ do_resize (`bool`, *optional*, defaults to `True`):
+ Whether to resize the image. The longest edge of the image is resized to be <= `size["longest_edge"]`, with the
+ shortest edge resized to keep the input aspect ratio.
+ size (`Dict`, *optional*, defaults to `{"longest_edge": 4 * 364}`):
+ Controls the size of the output image. This is a dictionary containing the key "longest_edge".
+ The image will be resized such that the longest edge is <= `size["longest_edge"]` and the shortest edge is resized
+ to keep the input aspect ratio.
+ resample (`Resampling`, *optional*, defaults to `Resampling.LANCZOS`):
+ Resampling filter to use when resizing the image.
+ do_image_splitting (`bool`, *optional*, defaults to `True`):
+ Whether to split the image into sub-images concatenated with the original image. They are split into patches
+ such that each patch has a size of `max_image_size["height"]` x `max_image_size["width"]`.
+ max_image_size (`Dict`, *optional*, defaults to `{"longest_edge": 364}`):
+ Maximum resolution of the patches of images accepted by the model. This is a dictionary containing the key "longest_edge".
+ do_rescale (`bool`, *optional*, defaults to `True`):
+ Whether to rescale the image. If set to `True`, the image is rescaled to have pixel values between 0 and 1.
+ rescale_factor (`float`, *optional*, defaults to `1/255`):
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
+ do_normalize (`bool`, *optional*, defaults to `True`):
+ Whether to normalize the image. If set to `True`, the image is normalized to have a mean of `image_mean` and
+ a standard deviation of `image_std`.
+        image_mean (`float` or `list[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
+            Mean to use if normalizing the image. This is a float or list of floats the length of the number of
+            channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
+        image_std (`float` or `list[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
+            Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
+            number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
+ do_pad (`bool`, *optional*, defaults to `True`):
+ Whether or not to pad the images to the largest height and width in the batch and number of images per
+ sample in the batch, such that the returned tensor is of shape (batch_size, max_num_images, num_channels, max_height, max_width).
+ """
+
+ model_input_names = ["pixel_values", "pixel_attention_mask"]
+
+ def __init__(
+ self,
+ do_convert_rgb: bool = True,
+ do_resize: bool = True,
+ size: Optional[dict[str, int]] = None,
+ resample: PILImageResampling = PILImageResampling.LANCZOS,
+ do_image_splitting: bool = True,
+ max_image_size: Optional[dict[str, int]] = None,
+ do_rescale: bool = True,
+ rescale_factor: float = 1 / 255,
+ do_normalize: bool = True,
+ image_mean: Optional[Union[float, list[float]]] = None,
+ image_std: Optional[Union[float, list[float]]] = None,
+ do_pad: bool = True,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ self.do_convert_rgb = do_convert_rgb
+ self.do_resize = do_resize
+ self.size = size if size is not None else {"longest_edge": 4 * 364}
+ self.resample = resample
+ self.do_image_splitting = do_image_splitting
+ self.max_image_size = max_image_size if max_image_size is not None else {"longest_edge": 364}
+ self.do_rescale = do_rescale
+ self.rescale_factor = rescale_factor
+ self.do_normalize = do_normalize
+ self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
+ self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
+ self.do_pad = do_pad
+
+ def resize(
+ self,
+ image: np.ndarray,
+ size: dict[str, int],
+ resample: PILImageResampling = PILImageResampling.LANCZOS,
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs,
+ ) -> np.ndarray:
+ """
+ Resize an image. The longest edge of the image is resized to size["longest_edge"], with the shortest edge
+ resized to keep the input aspect ratio. Can also be used with size["height"] and size["width"].
+ Args:
+ image (`np.ndarray`):
+ Image to resize.
+ size (`dict[str, int]`):
+ Size of the output image.
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.LANCZOS`):
+ Resampling filter to use when resizing the image.
+ data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format of the output image. If not provided, it will be the same as the input image.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format of the input image. If not provided, it will be inferred.
+ """
+ if input_data_format is None:
+ input_data_format = infer_channel_dimension_format(image, num_channels=(1, 3, 4))
+
+ # For all transformations, we want to keep the same data format as the input image unless otherwise specified.
+ # The resized image from PIL will always have channels last, so find the input format first.
+ data_format = input_data_format if data_format is None else data_format
+
+ if "longest_edge" in size:
+ size = get_resize_output_image_size(
+ image, resolution_max_side=size["longest_edge"], input_data_format=input_data_format
+ )
+ elif "height" in size and "width" in size:
+ size = (size["height"], size["width"])
+ else:
+ raise ValueError("size must be a dictionary with key 'longest_edge' or 'height' and 'width'.")
+
+ image_mode = None
+ if image.ndim == 2 or image.shape[-1] == 1:
+ image_mode = "P"
+ image = to_pil_image(image, image_mode=image_mode, input_data_format=input_data_format)
+
+ resized_image = image.resize((size[1], size[0]), resample=resample)
+ resized_image = np.array(resized_image)
+
+ # If the input image channel dimension was of size 1, then it is dropped when converting to a PIL image
+ # so we need to add it back if necessary.
+ resized_image = np.expand_dims(resized_image, axis=-1) if resized_image.ndim == 2 else resized_image
+ # The image is always in channels last format after converting from a PIL image
+ resized_image = to_channel_dimension_format(
+ resized_image, data_format, input_channel_dim=ChannelDimension.LAST
+ )
+ return resized_image
+
+ def split_image(
+ self,
+ image,
+ max_image_size: dict[str, int],
+ resample: PILImageResampling = PILImageResampling.LANCZOS,
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ ):
+ """
+ Split an image into squares of side max_image_size and the original image resized to max_image_size.
+ That means that a single image becomes a sequence of images.
+ This is a "trick" to spend more compute on each image with no changes in the vision encoder.
+ 1) If one side of the original image is larger than `max_image_size`, resize it to `max_image_size` while preserving the aspect ratio.
+ 2) Divide the resulting image into `ceil(height / max_image_size)` x `ceil(width / max_image_size)`
+ sub-images of the same size each (image_size, image_size). Typically, 364x364.
+ 3) Returns the list of the crops and the original image, in addition to the number of splits for the height and the width.
+ Args:
+ image (`np.ndarray`):
+ Images to split.
+ max_image_size (`dict[str, int]`):
+ Maximum size of the output image. If the image is larger than this size, it will be split into
+ patches of this size, and the original image will be concatenated with the patches, resized to max_size.
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.LANCZOS`):
+ Resampling filter to use when resizing the image.
+ data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format of the output image. If not provided, it will be the same as the input image.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format of the input image. If not provided, it will be inferred.
+ """
+ height, width = get_image_size(image, channel_dim=input_data_format)
+ max_height = max_width = max_image_size["longest_edge"]
+
+ frames = []
+ if height > max_height or width > max_width:
+ # Calculate the number of splits
+ num_splits_h = math.ceil(height / max_height)
+ num_splits_w = math.ceil(width / max_width)
+ # Calculate the optimal width and height for the sub-images
+ optimal_height = math.ceil(height / num_splits_h)
+ optimal_width = math.ceil(width / num_splits_w)
+
+ # Iterate through each row and column
+ for r in range(num_splits_h):
+ for c in range(num_splits_w):
+ # Calculate the starting point of the crop
+ start_x = c * optimal_width
+ start_y = r * optimal_height
+
+ # Calculate the ending point of the crop
+ end_x = min(start_x + optimal_width, width)
+ end_y = min(start_y + optimal_height, height)
+
+ # Crop the image
+ cropped_image = _crop(
+ image,
+ start_x,
+ start_y,
+ end_x,
+ end_y,
+ data_format=data_format,
+ )
+ frames.append(cropped_image)
+
+ # For the global image at the end, we resize it to match the max_image_size, for cpu memory efficiency
+ global_image_height, global_image_width = max_height, max_width
+ if height != global_image_height or width != global_image_width:
+ image = self.resize(
+ image,
+ {"height": global_image_height, "width": global_image_width},
+ resample=resample,
+ input_data_format=data_format,
+ )
+ else:
+ num_splits_h, num_splits_w = 0, 0
+
+ frames.append(image)
+
+ return frames, num_splits_h, num_splits_w
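+
+    # Editor's note (illustrative): for a 1000x1500 (height x width) image with
+    # max_image_size = {"longest_edge": 364}, num_splits_h = ceil(1000 / 364) = 3 and
+    # num_splits_w = ceil(1500 / 364) = 5, so this returns 15 crops of roughly 334x300 plus
+    # the original image resized to 364x364, i.e. 16 frames in total.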
+
+ def resize_for_vision_encoder(
+ self,
+ image: np.ndarray,
+ vision_encoder_max_size: int,
+ resample: PILImageResampling = PILImageResampling.LANCZOS,
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ ):
+ """
+ Resize images to be multiples of `vision_encoder_max_size` while preserving the aspect ratio.
+ Args:
+ image (`np.ndarray`):
+ Images to resize.
+ vision_encoder_max_size (`int`):
+ Maximum size of the output image. If the image is larger than this size, it will be split into
+ patches of this size, and the original image will be concatenated with the patches, resized to max_size.
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.LANCZOS`):
+ Resampling filter to use when resizing the image.
+ data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format of the output image. If not provided, it will be the same as the input image.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format of the input image. If not provided, it will be inferred
+ """
+ height, width = get_image_size(image, channel_dim=input_data_format)
+
+ aspect_ratio = width / height
+ if width >= height:
+ width = math.ceil(width / vision_encoder_max_size) * vision_encoder_max_size
+ height = int(width / aspect_ratio)
+ height = math.ceil(height / vision_encoder_max_size) * vision_encoder_max_size
+ elif height > width:
+ height = math.ceil(height / vision_encoder_max_size) * vision_encoder_max_size
+ width = int(height * aspect_ratio)
+ width = math.ceil(width / vision_encoder_max_size) * vision_encoder_max_size
+ new_size = {"height": height, "width": width}
+ return self.resize(
+ image, size=new_size, resample=resample, input_data_format=input_data_format, data_format=data_format
+ )
+
+ def _pad_image(
+ self,
+ image: np.ndarray,
+ output_size: tuple[int, int],
+ constant_values: Union[float, Iterable[float]] = 0,
+ data_format: Optional[ChannelDimension] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ ) -> np.ndarray:
+ """
+ Pad an image with zeros to the given size.
+ """
+ input_height, input_width = get_image_size(image, channel_dim=input_data_format)
+ output_height, output_width = output_size
+
+ pad_bottom = output_height - input_height
+ pad_right = output_width - input_width
+ padding = ((0, pad_bottom), (0, pad_right))
+ padded_image = pad(
+ image,
+ padding,
+ mode=PaddingMode.CONSTANT,
+ constant_values=constant_values,
+ data_format=data_format,
+ input_data_format=input_data_format,
+ )
+ return padded_image
+
+ def pad(
+ self,
+ images: list[list[np.ndarray]],
+ constant_values: Union[float, Iterable[float]] = 0,
+ return_pixel_mask: bool = True,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ data_format: Optional[ChannelDimension] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ ) -> BatchFeature:
+ """
+        For a list of list of images, pads each image to the bottom and right with zeros up to the largest height and width in the batch.
+        For each sample in the batch, pads the sample with empty images up to the maximum number of images per sample in the batch. Optionally returns a pixel mask.
+ Args:
+ images (`list[list[np.ndarray]]`):
+ List of list of images to pad. Pads to the largest height and width in the batch.
+ constant_values (`float` or `Iterable[float]`, *optional*):
+ The value to use for the padding if `mode` is `"constant"`.
+ return_pixel_mask (`bool`, *optional*, defaults to `True`):
+ Whether to return a pixel mask.
+ return_tensors (`str` or `TensorType`, *optional*):
+ The type of tensors to return. Can be one of:
+ - Unset: Return a list of `np.ndarray`.
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format of the input image. If not provided, it will be inferred.
+ """
+ pad_size = get_max_height_width(images, input_data_format=input_data_format)
+
+ batch_size = len(images)
+ max_num_images = max(len(images_) for images_ in images)
+ input_data_format = (
+ infer_channel_dimension_format(images[0][0], num_channels=(1, 3, 4))
+ if input_data_format is None
+ else input_data_format
+ )
+ data_format = input_data_format if data_format is None else data_format
+        # filter out empty image lists, then take the first image of the first non-empty sample
+ first_image_in_list = [sample_images for sample_images in images if sample_images][0][0]
+
+ if input_data_format == ChannelDimension.FIRST:
+ n_channels = first_image_in_list.shape[0]
+ elif input_data_format == ChannelDimension.LAST:
+ n_channels = first_image_in_list.shape[-1]
+ else:
+ raise ValueError("Invalid channel dimension format.")
+
+ def empty_image(size, input_data_format):
+ if input_data_format == ChannelDimension.FIRST:
+ return np.zeros((n_channels, *size), dtype=np.uint8)
+ elif input_data_format == ChannelDimension.LAST:
+ return np.zeros((*size, n_channels), dtype=np.uint8)
+
+ padded_images_list = [
+ [empty_image(pad_size, data_format) for _ in range(max_num_images)] for _ in range(batch_size)
+ ]
+ padded_masks = [[np.zeros(pad_size, dtype=np.int64) for _ in range(max_num_images)] for _ in range(batch_size)]
+
+ for batch_idx in range(batch_size):
+ for sample_idx, image in enumerate(images[batch_idx]):
+ padded_images_list[batch_idx][sample_idx] = self._pad_image(
+ image,
+ pad_size,
+ constant_values=constant_values,
+ data_format=data_format,
+ input_data_format=input_data_format,
+ )
+ padded_masks[batch_idx][sample_idx] = make_pixel_mask(
+ image, output_size=pad_size, input_data_format=input_data_format
+ )
+
+ padded_masks = padded_masks if return_pixel_mask else None
+ return padded_images_list, padded_masks
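+
+    # Illustrative example (hypothetical batch): if the batch holds two samples with 3 and 1 images respectively
+    # and the largest image is 364x364, every sample is padded to 3 images. The second sample receives two
+    # all-zero "empty images" whose pixel masks are all zeros, while real images keep masks with 1s over their
+    # valid (un-padded) region.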
+
+ def preprocess(
+ self,
+ images: ImageInput,
+ do_convert_rgb: Optional[bool] = None,
+ do_resize: Optional[bool] = None,
+ size: Optional[dict[str, int]] = None,
+ resample: Optional[PILImageResampling] = None,
+ do_image_splitting: Optional[bool] = None,
+ do_rescale: Optional[bool] = None,
+ max_image_size: Optional[dict[str, int]] = None,
+ rescale_factor: Optional[float] = None,
+ do_normalize: Optional[bool] = None,
+ image_mean: Optional[Union[float, list[float]]] = None,
+ image_std: Optional[Union[float, list[float]]] = None,
+ do_pad: Optional[bool] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ return_row_col_info: bool = False,
+ data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ ):
+ """
+ Preprocess a batch of images.
+ Args:
+ images (`ImageInput`):
+ A list of images to preprocess.
+ do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
+ Whether to convert the image to RGB.
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+ Whether to resize the image.
+ size (`dict[str, int]`, *optional*, defaults to `self.size`):
+                Size of the image after resizing, with the longest edge resized while keeping the input aspect ratio.
+ resample (`int`, *optional*, defaults to `self.resample`):
+ Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
+ has an effect if `do_resize` is set to `True`.
+ do_image_splitting (`bool`, *optional*, defaults to `self.do_image_splitting`):
+ Whether to split the image into sub-images concatenated with the original image. They are split into patches
+ such that each patch has a size of `max_image_size["height"]` x `max_image_size["width"]`.
+ max_image_size (`Dict`, *optional*, defaults to `self.max_image_size`):
+ Maximum resolution of the images. If the image is larger than this size, the image is split into patches.
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+ Whether to rescale the image.
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
+ Whether to normalize the image.
+ image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`):
+ Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
+ image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`):
+ Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
+ `True`.
+ do_pad (`bool`, *optional*, defaults to `self.do_pad`):
+ Whether or not to pad the images to the largest height and width in the batch.
+ return_tensors (`str` or `TensorType`, *optional*):
+ The type of tensors to return. Can be one of:
+ - Unset: Return a list of `np.ndarray`.
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
+            return_row_col_info (`bool`, *optional*, defaults to `False`):
+ Whether to return the number of rows and columns of the split images. This is used for the
+ `Idefics3Processor` to generate prompt strings based on the number of rows and columns.
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
+ The channel dimension format for the output image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - Unset: Use the channel dimension format of the input image.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
+ from the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+ """
+ do_resize = do_resize if do_resize is not None else self.do_resize
+ size = size if size is not None else self.size
+ resample = resample if resample is not None else self.resample
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
+ do_image_splitting = do_image_splitting if do_image_splitting is not None else self.do_image_splitting
+ max_image_size = max_image_size if max_image_size is not None else self.max_image_size
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
+ image_mean = image_mean if image_mean is not None else self.image_mean
+ image_std = image_std if image_std is not None else self.image_std
+ do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
+ do_pad = do_pad if do_pad is not None else self.do_pad
+
+ images = self.fetch_images(images)
+ images_list = make_nested_list_of_images(images)
+
+ if not valid_images(images_list[0]):
+ raise ValueError(
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
+ "torch.Tensor, tf.Tensor or jax.ndarray."
+ )
+
+ validate_preprocess_arguments(
+ do_rescale=do_rescale,
+ rescale_factor=rescale_factor,
+ do_normalize=do_normalize,
+ image_mean=image_mean,
+ image_std=image_std,
+ do_resize=do_resize,
+ size=size,
+ resample=resample,
+ )
+
+ # save the palettes for conversion to RGB
+ palettes_list = [
+ [im.getpalette() if isinstance(im, Image.Image) and im.mode == "P" else None for im in images]
+ for images in images_list
+ ]
+
+ # All transformations expect numpy arrays.
+ images_list = [[to_numpy_array(image) for image in images] for images in images_list]
+ # Search for the first image in the image list.
+ # NOTE: we can't slice the first image with images_list[0][0] if the first batch contains no images. See #36682
+ first_image_in_list = [images for images in images_list if images][0][0]
+
+ # Extra channel dimension for grayscale images
+ if input_data_format in [ChannelDimension.LAST, None]:
+ images_list = [
+ [np.expand_dims(img, axis=-1) if img.ndim == 2 else img for img in images] for images in images_list
+ ]
+ elif input_data_format == ChannelDimension.FIRST:
+ images_list = [
+ [np.expand_dims(img, axis=0) if img.ndim == 2 else img for img in images] for images in images_list
+ ]
+
+ if do_rescale and is_scaled_image(first_image_in_list):
+ logger.warning_once(
+ "It looks like you are trying to rescale already rescaled images. If the input"
+ " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
+ )
+
+ # We assume that all images have the same channel dimension format.
+ if input_data_format is None:
+ input_data_format = infer_channel_dimension_format(first_image_in_list, num_channels=(1, 3, 4))
+
+ if do_resize:
+ images_list = [
+ [
+ self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
+ for image in images
+ ]
+ for images in images_list
+ ]
+
+ if do_image_splitting:
+ # We first resize both height and width of each image to the nearest max_image_size multiple, disregarding the aspect ratio
+ # for size=(10, max_image_size) -> rescaled_size=(max_image_size, max_image_size)
+ # for size=(11, max_image_size+1) -> rescaled_size=(max_image_size, max_image_size*2)
+ images_list = [
+ [
+ self.resize_for_vision_encoder(
+ image, max_image_size["longest_edge"], resample=resample, input_data_format=input_data_format
+ )
+ for image in images
+ ]
+ for images in images_list
+ ]
+ images_list_split_arrays = []
+ palettes_list_split_arrays = []
+ images_list_rows = []
+ images_list_cols = []
+ for images, palettes in zip(images_list, palettes_list):
+ split_image_arrays = []
+ split_palettes_arrays = []
+ image_rows = []
+ image_cols = []
+ for image, palette in zip(images, palettes):
+ split_image_array, rows, cols = self.split_image(
+ image,
+ max_image_size=max_image_size,
+ resample=resample,
+ input_data_format=input_data_format,
+ )
+ split_image_arrays.extend(split_image_array)
+ split_palettes_arrays.extend([palette] * len(split_image_array))
+ image_rows.append(rows)
+ image_cols.append(cols)
+ images_list_split_arrays.append(split_image_arrays)
+ palettes_list_split_arrays.append(split_palettes_arrays)
+ images_list_rows.append(image_rows)
+ images_list_cols.append(image_cols)
+ images_list = images_list_split_arrays
+ palettes_list = palettes_list_split_arrays
+ else:
+ # We square the images to max_image_size
+ images_list = [
+ [
+ self.resize(
+ image=image,
+ size={"height": max_image_size["longest_edge"], "width": max_image_size["longest_edge"]},
+ resample=resample,
+ input_data_format=input_data_format,
+ )
+ for image in images
+ ]
+ for images in images_list
+ ]
+ images_list_rows = [[0] * len(images) for images in images_list]
+ images_list_cols = [[0] * len(images) for images in images_list]
+
+ if do_convert_rgb:
+ images_list = [
+ [convert_to_rgb(img, palette) for img, palette in zip(images, palettes)]
+ for images, palettes in zip(images_list, palettes_list)
+ ]
+
+ if do_rescale:
+ images_list = [
+ [self.rescale(image, rescale_factor, input_data_format=input_data_format) for image in images]
+ for images in images_list
+ ]
+
+ if do_normalize:
+ images_list = [
+ [
+ self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
+ for image in images
+ ]
+ for images in images_list
+ ]
+
+ pixel_attention_mask = None
+ if do_pad:
+ images_list, pixel_attention_mask = self.pad(
+ images_list, return_pixel_mask=True, return_tensors=return_tensors, input_data_format=input_data_format
+ )
+
+ if data_format is not None:
+ images_list = [
+ [
+ to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
+ for image in images
+ ]
+ for images in images_list
+ ]
+
+ # Faster tensor conversion
+ data = {"pixel_values": np.array(images_list) if do_pad and return_tensors is not None else images_list}
+ if pixel_attention_mask is not None:
+ data["pixel_attention_mask"] = (
+ np.array(pixel_attention_mask) if do_pad and return_tensors is not None else pixel_attention_mask
+ )
+
+ encoding = BatchFeature(data=data, tensor_type=return_tensors)
+
+ # This is needed for generating correct text inputs in the processor - we don't pad to the max number of images
+ if return_row_col_info:
+ encoding["rows"] = images_list_rows
+ encoding["cols"] = images_list_cols
+
+ return encoding
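+
+    # Illustrative usage of `preprocess` via `__call__` (a sketch; assumes the "HuggingFaceM4/Idefics3-8B-Llama3"
+    # checkpoint and the default processing settings):
+    #
+    #     from PIL import Image
+    #     from transformers import Idefics3ImageProcessor
+    #
+    #     image_processor = Idefics3ImageProcessor.from_pretrained("HuggingFaceM4/Idefics3-8B-Llama3")
+    #     image = Image.new("RGB", (1200, 800))  # width x height
+    #     batch = image_processor(images=[[image]], return_tensors="pt", return_row_col_info=True)
+    #     batch["pixel_values"].shape  # e.g. torch.Size([1, 13, 3, 364, 364]): a 3x4 grid of crops plus the global image
+    #     batch["rows"], batch["cols"]  # e.g. [[3]], [[4]]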
+
+ def get_number_of_image_patches(self, height: int, width: int, images_kwargs=None):
+ """
+ A utility that returns number of image patches for a given image size.
+
+ Args:
+ height (`int`):
+ Height of the input image.
+ width (`int`):
+ Width of the input image.
+            images_kwargs (`dict`, *optional*):
+                Any kwargs to override defaults of the image processor.
+        Returns:
+            `tuple[int, int, int]`: The number of patches per image, together with the number of rows and columns of the split grid.
+ """
+        # `images_kwargs` is optional; fall back to an empty dict so the processor defaults below are used
+        images_kwargs = images_kwargs if images_kwargs is not None else {}
+        do_image_splitting = images_kwargs.get("do_image_splitting", self.do_image_splitting)
+ max_image_size = images_kwargs.get("max_image_size", self.max_image_size)
+ size = images_kwargs.get("size", self.size)
+
+ num_patches = num_rows = num_cols = 1
+ if do_image_splitting:
+ height, width = _resize_output_size_rescale_to_max_len(height, width, max_len=size["longest_edge"])
+ height, width = _resize_output_size_scale_below_upper_bound(height, width, max_len=4096)
+ aspect_ratio = width / height
+
+ if width >= height:
+ resized_width = math.ceil(width / max_image_size["longest_edge"]) * max_image_size["longest_edge"]
+ resized_height = int(width / aspect_ratio)
+ resized_height = math.ceil(height / max_image_size["longest_edge"]) * max_image_size["longest_edge"]
+ elif height > width:
+ resized_height = math.ceil(height / max_image_size["longest_edge"]) * max_image_size["longest_edge"]
+ resized_width = int(height * aspect_ratio)
+ resized_width = math.ceil(width / max_image_size["longest_edge"]) * max_image_size["longest_edge"]
+
+ max_height = max_width = max_image_size["longest_edge"]
+ if resized_height > max_height or resized_width > max_width:
+ # Calculate the number of splits
+ num_rows = math.ceil(resized_height / max_height)
+ num_cols = math.ceil(resized_width / max_width)
+ num_patches = num_rows * num_cols + 1
+
+ return num_patches, num_rows, num_cols
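+
+    # Illustrative example (hypothetical sizes, default `size` and `max_image_size`): an 800x1200
+    # (height x width) image is first rescaled to 970x1456, rounded up to 1092x1456, and split into a 3x4 grid,
+    # so `get_number_of_image_patches(800, 1200, images_kwargs={})` returns (13, 3, 4): 12 crops plus the
+    # resized global image.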
+
+
+__all__ = ["Idefics3ImageProcessor"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/idefics3/image_processing_idefics3_fast.py b/venv/lib/python3.13/site-packages/transformers/models/idefics3/image_processing_idefics3_fast.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b0c0e6180f9841e3bd25bf0eec5ea19a262d49c
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/idefics3/image_processing_idefics3_fast.py
@@ -0,0 +1,542 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import math
+from typing import Optional, Union
+
+import torch
+
+from ...image_processing_utils_fast import (
+ BaseImageProcessorFast,
+ BatchFeature,
+ DefaultFastImageProcessorKwargs,
+ SizeDict,
+ group_images_by_shape,
+ reorder_images,
+)
+from ...image_utils import (
+ IMAGENET_STANDARD_MEAN,
+ IMAGENET_STANDARD_STD,
+ ImageInput,
+ PILImageResampling,
+ make_nested_list_of_images,
+)
+from ...processing_utils import Unpack
+from ...utils import TensorType, auto_docstring, is_torchvision_available, logging
+
+
+if is_torchvision_available():
+ from torchvision.transforms import functional as F
+
+
+logger = logging.get_logger(__name__)
+
+MAX_IMAGE_SIZE = 4096 # 4k resolution as absolute maximum
+
+
+def _resize_output_size_rescale_to_max_len(
+ height: int, width: int, min_len: Optional[int] = 1, max_len: Optional[int] = None
+) -> tuple[int, int]:
+ """
+    Get the output size of the image after rescaling its longest edge to `max_len` while preserving the aspect ratio.
+ Args:
+ height (`int`):
+ Height of the input image.
+ width (`int`):
+ Width of the input image.
+ min_len (`int`, *optional*, defaults to 1):
+ Minimum size of the output image.
+ max_len (`int`, *optional*, defaults to the maximum size of the image):
+ Maximum size of the output image.
+ Returns:
+ The output size of the image after resizing.
+ """
+ max_len = max(height, width) if max_len is None else max_len
+ aspect_ratio = width / height
+
+ if width >= height:
+ width = max_len
+ height = int(width / aspect_ratio)
+ if height % 2 != 0:
+ height += 1
+ elif height > width:
+ height = max_len
+ width = int(height * aspect_ratio)
+ if width % 2 != 0:
+ width += 1
+
+ # Avoid resizing to a size smaller than min_len
+ height = max(height, min_len)
+ width = max(width, min_len)
+ return height, width
+
+
+def _resize_output_size_scale_below_upper_bound(
+    height: int, width: int, max_len: Optional[int] = None
+) -> tuple[int, int]:
+ """
+    Get the output size of the image after scaling it down, if needed, so that neither side exceeds `max_len`, while preserving the aspect ratio.
+ Args:
+ height (`int`):
+ Height of the input image.
+ width (`int`):
+ Width of the input image.
+        max_len (`int`, *optional*, defaults to the maximum size of the image):
+            Defines the maximum side length of the image.
+ Returns:
+ The output size of the image after resizing.
+ """
+ max_len = max(height, width) if max_len is None else max_len
+
+ aspect_ratio = width / height
+ if width >= height and width > max_len:
+ width = max_len
+ height = int(width / aspect_ratio)
+ elif height > width and height > max_len:
+ height = max_len
+ width = int(height * aspect_ratio)
+
+ # Avoid resizing to a size smaller than 1
+ height = max(height, 1)
+ width = max(width, 1)
+ return height, width
+
+
+def get_resize_output_image_size(
+ image,
+ resolution_max_side: int,
+) -> tuple[int, int]:
+ """
+    Get the output size of the image after resizing its longest edge to `resolution_max_side` while preserving the
+    aspect ratio, capped at `MAX_IMAGE_SIZE`.
+ Args:
+ image (`torch.Tensor`):
+ Image to resize.
+ resolution_max_side (`int`):
+ The longest edge of the image will be resized to this value. The shortest edge will be resized to keep the
+ input aspect ratio.
+ Returns:
+ The output size of the image after resizing.
+ """
+ height, width = image.size()[-2:]
+
+ # Find the output size, when rescaling the longest edge to max_len and preserving the aspect ratio
+ height, width = _resize_output_size_rescale_to_max_len(height, width, max_len=resolution_max_side)
+ # Find the output size when scaling the image to be below the MAX_IMAGE_SIZE
+ height, width = _resize_output_size_scale_below_upper_bound(height, width, max_len=MAX_IMAGE_SIZE)
+ return height, width
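+
+# Illustrative example (hypothetical sizes): with `resolution_max_side=1456`, an 800x1200 (height x width) image
+# gets its longest edge rescaled to 1456 while the other edge follows the aspect ratio, giving (970, 1456); both
+# values are below MAX_IMAGE_SIZE, so no further downscaling happens.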
+
+
+def get_max_height_width(images_list: list[list["torch.Tensor"]]) -> tuple[int, int]:
+ """
+ Get the maximum height and width across all images in a batch.
+ """
+ image_sizes = []
+ for images in images_list:
+ for image in images:
+ image_sizes.append(image.size()[-2:])
+
+ max_height = max(size[0] for size in image_sizes)
+ max_width = max(size[1] for size in image_sizes)
+ return (max_height, max_width)
+
+
+def make_pixel_mask(image: "torch.Tensor", output_size: tuple[int, int]) -> "torch.Tensor":
+ """
+ Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding.
+
+ Args:
+ image (`torch.Tensor`):
+ Image to make the pixel mask for.
+ output_size (`Tuple[int, int]`):
+ Output size of the mask.
+ """
+ input_height, input_width = image.size()[-2:]
+ mask = torch.zeros(output_size, dtype=torch.int64, device=image.device)
+ mask[:input_height, :input_width] = 1
+ return mask
+
+
+class Idefics3FastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
+ """
+ do_image_splitting (`bool`, *optional*, defaults to `True`):
+ Whether to split the image into sub-images concatenated with the original image. They are split into patches
+ such that each patch has a size of `max_image_size["height"]` x `max_image_size["width"]`.
+ max_image_size (`Dict`, *optional*, defaults to `{"longest_edge": 364}`):
+ Maximum resolution of the patches of images accepted by the model. This is a dictionary containing the key "longest_edge".
+ return_row_col_info (`bool`, *optional*, defaults to `False`):
+ Whether to return the row and column information of the images.
+ """
+
+ do_image_splitting: Optional[bool]
+ max_image_size: Optional[dict[str, int]]
+ return_row_col_info: Optional[bool]
+
+
+@auto_docstring
+class Idefics3ImageProcessorFast(BaseImageProcessorFast):
+ resample = PILImageResampling.LANCZOS
+ image_mean = IMAGENET_STANDARD_MEAN
+ image_std = IMAGENET_STANDARD_STD
+ size = {"longest_edge": 4 * 364}
+ max_image_size = {"longest_edge": 364}
+ do_resize = True
+ do_rescale = True
+ do_normalize = True
+ do_convert_rgb = True
+ do_image_splitting = True
+ do_pad = True
+ return_row_col_info = False
+ valid_kwargs = Idefics3FastImageProcessorKwargs
+
+ def _prepare_images_structure(self, images: ImageInput, expected_ndims: int = 3) -> ImageInput:
+ """
+ Prepare a nested images structure for processing.
+ """
+ return make_nested_list_of_images(images, expected_ndims=expected_ndims)
+
+ def resize(
+ self,
+ image: "torch.Tensor",
+ size: SizeDict,
+ interpolation: Optional["F.InterpolationMode"] = None,
+ antialias: bool = True,
+ **kwargs,
+ ) -> "torch.Tensor":
+ """
+ Resize an image. The longest edge of the image is resized to size.longest_edge, with the shortest edge
+ resized to keep the input aspect ratio. Can also be used with size.height and size.width.
+ Args:
+            image (`torch.Tensor`):
+ Image to resize.
+ size (`Dict[str, int]`):
+ Size of the output image.
+ interpolation (`InterpolationMode`, *optional*, defaults to `InterpolationMode.BILINEAR`):
+ `InterpolationMode` filter to use when resizing the image e.g. `InterpolationMode.BICUBIC`.
+ antialias (`bool`, *optional*, defaults to `True`):
+ Whether to use antialiasing when resizing the image.
+ """
+ interpolation = interpolation if interpolation is not None else F.InterpolationMode.BILINEAR
+ if interpolation == F.InterpolationMode.LANCZOS:
+ logger.warning_once(
+                "You have used the fast image processor with LANCZOS resampling, which is not yet supported for torch.Tensor. "
+                "BICUBIC resampling will be used as an alternative. Please fall back to the slow image processor if you "
+                "want full consistency with the original model."
+ )
+ interpolation = F.InterpolationMode.BICUBIC
+
+ if size.longest_edge:
+ size = get_resize_output_image_size(image, resolution_max_side=size.longest_edge)
+ elif size.height and size.width:
+ size = (size.height, size.width)
+ else:
+ raise ValueError("size must be a dictionary with key 'longest_edge' or 'height' and 'width'.")
+
+ return F.resize(image, size, interpolation=interpolation, antialias=antialias)
+
+ def split_images(
+ self,
+ images: torch.Tensor,
+ max_image_size: dict[str, int],
+ interpolation: Optional["F.InterpolationMode"] = None,
+ ):
+ """
+ Split an image into squares of side max_image_size and the original image resized to max_image_size.
+ That means that a single image becomes a sequence of images.
+ This is a "trick" to spend more compute on each image with no changes in the vision encoder.
+ 1) If one side of the original image is larger than `max_image_size`, resize it to `max_image_size` while preserving the aspect ratio.
+ 2) Divide the resulting image into `ceil(height / max_image_size)` x `ceil(width / max_image_size)`
+ sub-images of the same size each (image_size, image_size). Typically, 364x364.
+ 3) Returns the list of the crops and the original image, in addition to the number of splits for the height and the width.
+ Args:
+ images (`torch.Tensor`):
+ Images to split.
+ max_image_size (`Dict[str, int]`):
+ Maximum size of the output image. If the image is larger than this size, it will be split into
+ patches of this size, and the original image will be concatenated with the patches, resized to max_size.
+ interpolation (`InterpolationMode`, *optional*, defaults to `InterpolationMode.BILINEAR`):
+ `InterpolationMode` filter to use when resizing the image e.g. `InterpolationMode.BICUBIC`.
+ """
+ batch_size, num_channels, height, width = images.size()
+ height_dim, width_dim = 2, 3
+
+ max_height = max_width = max_image_size["longest_edge"]
+
+ frames = []
+ if height > max_height or width > max_width:
+ # Calculate the number of splits
+ num_splits_h = math.ceil(height / max_height)
+ num_splits_w = math.ceil(width / max_width)
+
+ # Split the images by height, then by width
+ frames = (
+ images.unfold(height_dim, size=max_height, step=max_height)
+ .unfold(width_dim, size=max_width, step=max_width)
+ .contiguous()
+ .view(batch_size, num_channels, -1, max_height, max_width)
+ .permute(0, 2, 1, 3, 4)
+ ) # batch_size x n_frames x num_channels x height x width
+
+ # For the global image at the end, we resize it to match the max_image_size, for cpu memory efficiency
+ global_image_height, global_image_width = max_height, max_width
+ images = self.resize(
+ images, SizeDict(height=global_image_height, width=global_image_width), interpolation=interpolation
+ )
+
+ frames = torch.cat((frames, images.unsqueeze(1)), dim=1)
+ else:
+ num_splits_h, num_splits_w = 0, 0
+ frames = images.unsqueeze(1)
+
+ num_splits_h = [num_splits_h] * batch_size
+ num_splits_w = [num_splits_w] * batch_size
+
+ return frames, num_splits_h, num_splits_w
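+
+    # Illustrative example (hypothetical shapes): a stack of shape (2, 3, 728, 1092) with
+    # `max_image_size={"longest_edge": 364}` unfolds into a 2x3 grid of 364x364 tiles per image; after appending
+    # each image resized to 364x364, `frames` has shape (2, 7, 3, 364, 364) and the method returns
+    # num_splits_h=[2, 2] and num_splits_w=[3, 3].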
+
+ def resize_for_vision_encoder(
+ self,
+ image: torch.Tensor,
+ vision_encoder_max_size: int,
+ interpolation: Optional["F.InterpolationMode"] = None,
+ ):
+ """
+ Resize images to be multiples of `vision_encoder_max_size` while preserving the aspect ratio.
+ Args:
+ image (`torch.Tensor`):
+ Images to resize.
+ vision_encoder_max_size (`int`):
+                Maximum size accepted by the vision encoder. The image is resized so that its height and width are
+                multiples of this value, while roughly preserving the aspect ratio.
+ interpolation (`InterpolationMode`, *optional*, defaults to `InterpolationMode.BILINEAR`):
+ `InterpolationMode` filter to use when resizing the image e.g. `InterpolationMode.BICUBIC`.
+ """
+ height, width = image.size()[-2:]
+
+ aspect_ratio = width / height
+ if width >= height:
+ width = math.ceil(width / vision_encoder_max_size) * vision_encoder_max_size
+ height = int(width / aspect_ratio)
+ height = math.ceil(height / vision_encoder_max_size) * vision_encoder_max_size
+ elif height > width:
+ height = math.ceil(height / vision_encoder_max_size) * vision_encoder_max_size
+ width = int(height * aspect_ratio)
+ width = math.ceil(width / vision_encoder_max_size) * vision_encoder_max_size
+ new_size = SizeDict(height=height, width=width)
+ return self.resize(image, size=new_size, interpolation=interpolation)
+
+ def pad(
+ self,
+ image: torch.Tensor,
+ padded_size: tuple[int, int],
+ fill: int = 0,
+ return_pixel_mask: bool = True,
+ ):
+ original_size = image.shape[-2:]
+ padding_bottom = padded_size[0] - original_size[0]
+ padding_right = padded_size[1] - original_size[1]
+
+ if padding_bottom < 0 or padding_right < 0:
+ raise ValueError(
+ f"Padding dimensions are negative. Please make sure that the padded size is larger than the "
+ f"original size. Got padded size: {padded_size}, original size: {original_size}."
+ )
+
+ # Only pad if necessary
+ if original_size != padded_size:
+ padding = (0, 0, padding_right, padding_bottom)
+ image = F.pad(image, padding, fill=fill, padding_mode="constant")
+
+ # Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding.
+ pixel_mask = None
+ if return_pixel_mask:
+ pixel_mask = torch.zeros_like(image[..., 0, :, :], dtype=torch.int64)
+ pixel_mask[: original_size[0], : original_size[1]] = 1
+
+ return image, pixel_mask
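+
+    # Illustrative example (hypothetical shapes): padding a (3, 300, 400) image tensor to `padded_size=(364, 364)`
+    # returns a (3, 364, 364) tensor padded with zeros on the bottom and right, plus an int64 mask of shape
+    # (364, 364) that is 1 over the top-left 300x400 region.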
+
+ @auto_docstring
+ def preprocess(self, images: ImageInput, **kwargs: Unpack[Idefics3FastImageProcessorKwargs]) -> BatchFeature:
+ return super().preprocess(images, **kwargs)
+
+ def _preprocess(
+ self,
+ images: list[list["torch.Tensor"]],
+ do_resize: bool,
+ size: SizeDict,
+ interpolation: Optional["F.InterpolationMode"],
+ do_rescale: bool,
+ rescale_factor: float,
+ do_normalize: bool,
+ image_mean: Optional[Union[float, list[float]]],
+ image_std: Optional[Union[float, list[float]]],
+ do_pad: Optional[bool],
+ do_image_splitting: Optional[bool],
+ max_image_size: Optional[dict[str, int]],
+ return_row_col_info: Optional[bool],
+ disable_grouping: Optional[bool],
+ return_tensors: Optional[Union[str, TensorType]],
+ **kwargs,
+ ) -> BatchFeature:
+ """
+ Process a batch of images for the model.
+ """
+
+ grouped_images, grouped_images_index = group_images_by_shape(
+ images, is_nested=True, disable_grouping=disable_grouping
+ )
+ resized_images_grouped = {}
+ for shape, stacked_images in grouped_images.items():
+ if do_resize:
+ stacked_images = self.resize(stacked_images, size, interpolation=interpolation)
+ resized_images_grouped[shape] = stacked_images
+ resized_images = reorder_images(resized_images_grouped, grouped_images_index, is_nested=True)
+
+ grouped_images, grouped_images_index = group_images_by_shape(
+ resized_images, is_nested=True, disable_grouping=disable_grouping
+ )
+ split_images_grouped = {}
+ if do_image_splitting:
+ rows_grouped = {}
+ cols_grouped = {}
+ for shape, stacked_images in grouped_images.items():
+ stacked_images = self.resize_for_vision_encoder(
+ stacked_images, max_image_size["longest_edge"], interpolation=interpolation
+ )
+ stacked_images, rows, cols = self.split_images(
+ stacked_images, max_image_size=max_image_size, interpolation=interpolation
+ )
+ split_images_grouped[shape] = stacked_images
+ rows_grouped[shape] = rows
+ cols_grouped[shape] = cols
+ processed_images = reorder_images(split_images_grouped, grouped_images_index, is_nested=True)
+ rows = reorder_images(rows_grouped, grouped_images_index, is_nested=True)
+ cols = reorder_images(cols_grouped, grouped_images_index, is_nested=True)
+            # flatten the doubly nested list into a nested list
+ for i, group_images in enumerate(processed_images):
+ processed_images[i] = [image for sublist in group_images for image in sublist]
+ else:
+ for shape, stacked_images in grouped_images.items():
+ # We square the images to max_image_size
+ stacked_images = self.resize(
+ image=stacked_images,
+ size=SizeDict(height=max_image_size["longest_edge"], width=max_image_size["longest_edge"]),
+ interpolation=interpolation,
+ )
+ split_images_grouped[shape] = stacked_images
+ processed_images = reorder_images(split_images_grouped, grouped_images_index, is_nested=True)
+ rows = [[0] * len(images) for images in processed_images]
+ cols = [[0] * len(images) for images in processed_images]
+ # Group images by size for further processing
+ # Needed in case do_resize is False, or resize returns images with different sizes
+ grouped_images, grouped_images_index = group_images_by_shape(
+ processed_images, is_nested=True, disable_grouping=disable_grouping
+ )
+ processed_images_grouped = {}
+ for shape, stacked_images in grouped_images.items():
+ # Fused rescale and normalize
+ stacked_images = self.rescale_and_normalize(
+ stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
+ )
+ processed_images_grouped[shape] = stacked_images
+ processed_images = reorder_images(processed_images_grouped, grouped_images_index, is_nested=True)
+ if do_pad:
+ # Get max images per batch
+ max_num_images = max(len(images_) for images_ in processed_images)
+ max_height, max_width = get_max_height_width(processed_images)
+
+ processed_images_padded = torch.zeros(
+ len(processed_images),
+ max_num_images,
+ *(processed_images[0][0].shape[0], max_height, max_width),
+ device=processed_images[0][0].device,
+ )
+ pixel_attention_masks = torch.zeros(
+ len(processed_images),
+ max_num_images,
+ *(max_height, max_width),
+ device=processed_images[0][0].device,
+ )
+ for i, images in enumerate(processed_images):
+ for j, image in enumerate(images):
+ processed_images_padded[i, j], pixel_attention_masks[i, j] = self.pad(
+ image, (max_height, max_width)
+ )
+ processed_images = processed_images_padded
+
+ if do_pad:
+ data = {"pixel_values": processed_images, "pixel_attention_mask": pixel_attention_masks}
+ elif return_tensors == "pt":
+ data = {"pixel_values": torch.stack([torch.stack(images) for images in processed_images])}
+ else:
+ data = {"pixel_values": processed_images}
+ # This is needed for generating correct text inputs in the processor - we don't pad to the max number of images
+ encoding = BatchFeature(data=data, tensor_type=return_tensors)
+
+ if return_row_col_info:
+ encoding["rows"] = rows
+ encoding["cols"] = cols
+
+ return encoding
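+
+    # Illustrative usage (a sketch; assumes the "HuggingFaceM4/Idefics3-8B-Llama3" checkpoint):
+    #
+    #     from PIL import Image
+    #     from transformers import Idefics3ImageProcessorFast
+    #
+    #     image_processor = Idefics3ImageProcessorFast.from_pretrained("HuggingFaceM4/Idefics3-8B-Llama3")
+    #     batch = image_processor(images=Image.new("RGB", (1200, 800)), return_tensors="pt", return_row_col_info=True)
+    #     # Same output structure as the slow processor, but computed with batched torch ops on images grouped by shape
+    #     # (note the numerical values can differ slightly since LANCZOS resampling is replaced by BICUBIC).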
+
+ def to_dict(self):
+ encoder_dict = super().to_dict()
+ encoder_dict.pop("_valid_processor_keys", None)
+ encoder_dict.pop("return_row_col_info", None)
+ return encoder_dict
+
+ def get_number_of_image_patches(self, height: int, width: int, images_kwargs=None):
+ """
+ A utility that returns number of image patches for a given image size.
+
+ Args:
+ height (`int`):
+ Height of the input image.
+ width (`int`):
+ Width of the input image.
+            images_kwargs (`dict`, *optional*):
+                Any kwargs to override defaults of the image processor.
+        Returns:
+            `tuple[int, int, int]`: The number of patches per image, together with the number of rows and columns of the split grid.
+ """
+        # `images_kwargs` is optional; fall back to an empty dict so the processor defaults below are used
+        images_kwargs = images_kwargs if images_kwargs is not None else {}
+        do_image_splitting = images_kwargs.get("do_image_splitting", self.do_image_splitting)
+ max_image_size = images_kwargs.get("max_image_size", self.max_image_size)
+ size = images_kwargs.get("size", self.size)
+
+ num_patches = num_rows = num_cols = 1
+ if do_image_splitting:
+ height, width = _resize_output_size_rescale_to_max_len(height, width, max_len=size["longest_edge"])
+ height, width = _resize_output_size_scale_below_upper_bound(height, width, max_len=MAX_IMAGE_SIZE)
+ aspect_ratio = width / height
+
+ if width >= height:
+ resized_width = math.ceil(width / max_image_size["longest_edge"]) * max_image_size["longest_edge"]
+ resized_height = int(width / aspect_ratio)
+ resized_height = math.ceil(height / max_image_size["longest_edge"]) * max_image_size["longest_edge"]
+ elif height > width:
+ resized_height = math.ceil(height / max_image_size["longest_edge"]) * max_image_size["longest_edge"]
+ resized_width = int(height * aspect_ratio)
+ resized_width = math.ceil(width / max_image_size["longest_edge"]) * max_image_size["longest_edge"]
+
+ max_height = max_width = max_image_size["longest_edge"]
+ if resized_height > max_height or resized_width > max_width:
+ # Calculate the number of splits
+ num_rows = math.ceil(resized_height / max_height)
+ num_cols = math.ceil(resized_width / max_width)
+ num_patches = num_rows * num_cols + 1
+
+ return num_patches, num_rows, num_cols
+
+
+__all__ = ["Idefics3ImageProcessorFast"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/idefics3/modeling_idefics3.py b/venv/lib/python3.13/site-packages/transformers/models/idefics3/modeling_idefics3.py
new file mode 100644
index 0000000000000000000000000000000000000000..89bbd931fadc35845f1d15428d428c188de3c3ae
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/idefics3/modeling_idefics3.py
@@ -0,0 +1,975 @@
+# coding=utf-8
+# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch Idefics3 model."""
+
+from dataclasses import dataclass
+from typing import Callable, Optional, Union
+
+import torch
+from torch import nn
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache
+from ...generation import GenerationMixin
+from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import BaseModelOutput, ModelOutput
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
+from ...utils.generic import check_model_inputs
+from ..auto import AutoModel
+from .configuration_idefics3 import Idefics3Config, Idefics3VisionConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for Idefics3 model's outputs that may also contain a past key/values (to speed up sequential decoding).
+ """
+)
+class Idefics3BaseModelOutputWithPast(ModelOutput):
+ r"""
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+ If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
+ hidden_size)` is output.
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
+ `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
+ input) to speed up sequential decoding.
+ image_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
+        Tuple of `torch.FloatTensor` (one for the output of the image embeddings) of shape `(batch_size, num_images,
+        sequence_length, hidden_size)`.
+        Image hidden states of the model produced by the vision encoder.
+ """
+
+ last_hidden_state: Optional[torch.FloatTensor] = None
+ past_key_values: Optional[Cache] = None
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
+ attentions: Optional[tuple[torch.FloatTensor]] = None
+ image_hidden_states: Optional[tuple[torch.FloatTensor]] = None
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for Idefics causal language model (or autoregressive) outputs.
+ """
+)
+class Idefics3CausalLMOutputWithPast(ModelOutput):
+ r"""
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Language modeling loss (for next-token prediction).
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ image_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
+        Tuple of `torch.FloatTensor` (one for the output of the image embeddings) of shape `(batch_size, num_images,
+        sequence_length, hidden_size)`.
+        Image hidden states of the model produced by the vision encoder.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: Optional[torch.FloatTensor] = None
+ past_key_values: Optional[Cache] = None
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
+ attentions: Optional[tuple[torch.FloatTensor]] = None
+ image_hidden_states: Optional[tuple[torch.FloatTensor]] = None
+
+
+# Copied from transformers.models.idefics2.modeling_idefics2.Idefics2VisionEmbeddings with Idefics2->Idefics3
+class Idefics3VisionEmbeddings(nn.Module):
+ """
+    This is a modified version of `siglip.modeling_siglip.SiglipVisionEmbeddings` to enable images of variable
+ resolution.
+
+ The modifications are adapted from [Patch n' Pack: NaViT, a Vision Transformer for any Aspect Ratio and Resolution](https://huggingface.co/papers/2307.06304)
+ which allows treating images in their native aspect ratio and without the need to resize them to the same
+ fixed size. In particular, we start from the original pre-trained SigLIP model
+    (which uses fixed-size square images) and adapt it by training on images of variable resolutions.
+ """
+
+ def __init__(self, config: Idefics3VisionConfig):
+ super().__init__()
+ self.embed_dim = config.hidden_size
+ self.image_size = config.image_size
+ self.patch_size = config.patch_size
+
+ self.patch_embedding = nn.Conv2d(
+ in_channels=config.num_channels,
+ out_channels=self.embed_dim,
+ kernel_size=self.patch_size,
+ stride=self.patch_size,
+ padding="valid",
+ )
+
+ self.num_patches_per_side = self.image_size // self.patch_size
+ self.num_patches = self.num_patches_per_side**2
+ self.num_positions = self.num_patches
+ self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
+
+ def forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor) -> torch.Tensor:
+ batch_size, _, max_im_h, max_im_w = pixel_values.shape
+
+ patch_embeds = self.patch_embedding(pixel_values)
+ embeddings = patch_embeds.flatten(2).transpose(1, 2)
+
+ max_nb_patches_h, max_nb_patches_w = max_im_h // self.patch_size, max_im_w // self.patch_size
+ boundaries = torch.arange(
+ 1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side, device=pixel_values.device
+ )
+ position_ids = torch.full(
+ size=(batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0, device=pixel_values.device
+ )
+
+ for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
+ nb_patches_h = p_attn_mask[:, 0].sum()
+ nb_patches_w = p_attn_mask[0].sum()
+
+ h_indices = torch.arange(nb_patches_h, device=position_ids.device, dtype=pixel_values.dtype)
+ w_indices = torch.arange(nb_patches_w, device=position_ids.device, dtype=pixel_values.dtype)
+
+ fractional_coords_h = h_indices / nb_patches_h * (1 - 1e-6)
+ fractional_coords_w = w_indices / nb_patches_w * (1 - 1e-6)
+
+ bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
+ bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)
+
+ pos_ids = (bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w).flatten()
+ position_ids[batch_idx][p_attn_mask.view(-1)] = pos_ids
+
+ embeddings = embeddings + self.position_embedding(position_ids)
+ return embeddings
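+
+    # Illustrative note (hypothetical sizes): with `image_size=364` and `patch_size=14` the position table has
+    # 26 * 26 = 676 entries. For each (possibly smaller) image in a padded batch, the fractional patch coordinates
+    # are bucketized onto this fixed 26x26 grid, so the same pretrained position embeddings serve images of
+    # varying resolution and aspect ratio.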
+
+
+# Copied from transformers.models.siglip.modeling_siglip.eager_attention_forward
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: float,
+ dropout: float = 0.0,
+ **kwargs,
+):
+ attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
+ if attention_mask is not None:
+ attn_weights = attn_weights + attention_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+
+ attn_output = torch.matmul(attn_weights, value)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
+# Copied from transformers.models.siglip.modeling_siglip.SiglipAttention with Siglip->Idefics3Vision
+class Idefics3VisionAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ # Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.embed_dim = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.embed_dim // self.num_heads
+ if self.head_dim * self.num_heads != self.embed_dim:
+ raise ValueError(
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
+ f" {self.num_heads})."
+ )
+ self.scale = self.head_dim**-0.5
+ self.dropout = config.attention_dropout
+
+ self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
+ self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
+ self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
+
+ # Ignore copy
+ self.is_causal = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+ """Input shape: Batch x Time x Channel"""
+
+ batch_size, seq_length, embed_dim = hidden_states.shape
+
+ queries = self.q_proj(hidden_states)
+ keys = self.k_proj(hidden_states)
+ values = self.v_proj(hidden_states)
+
+ queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
+ keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
+ values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ queries,
+ keys,
+ values,
+ attention_mask,
+ is_causal=self.is_causal,
+ scaling=self.scale,
+ dropout=0.0 if not self.training else self.dropout,
+ )
+
+ attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()
+ attn_output = self.out_proj(attn_output)
+
+ return attn_output, attn_weights
+
+
+# Copied from transformers.models.siglip.modeling_siglip.SiglipMLP with Siglip->Idefics3Vision
+class Idefics3VisionMLP(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.activation_fn = ACT2FN[config.hidden_act]
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.fc1(hidden_states)
+ hidden_states = self.activation_fn(hidden_states)
+ hidden_states = self.fc2(hidden_states)
+ return hidden_states
+
+
+class Idefics3SimpleMLP(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ input_size = config.vision_config.hidden_size * (config.scale_factor**2)
+ output_size = config.text_config.hidden_size
+ self.proj = nn.Linear(input_size, output_size, bias=False)
+
+ def forward(self, x):
+ return self.proj(x)
+
+
+# Copied from transformers.models.idefics2.modeling_idefics2.Idefics2EncoderLayer with Idefics2->Idefics3
+class Idefics3EncoderLayer(GradientCheckpointingLayer):
+ def __init__(self, config: Idefics3VisionConfig):
+ super().__init__()
+ self.embed_dim = config.hidden_size
+ self.self_attn = Idefics3VisionAttention(config)
+ self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+ self.mlp = Idefics3VisionMLP(config)
+ self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+
+ @auto_docstring
+ # Copied from transformers.models.siglip.modeling_siglip.SiglipEncoderLayer.forward
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: torch.Tensor,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> torch.FloatTensor:
+ residual = hidden_states
+
+ hidden_states = self.layer_norm1(hidden_states)
+ hidden_states, _ = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ **kwargs,
+ )
+ hidden_states = residual + hidden_states
+
+ residual = hidden_states
+ hidden_states = self.layer_norm2(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+
+ return hidden_states
+
+
+# Copied from transformers.models.siglip.modeling_siglip.SiglipEncoder with Siglip->Idefics3
+class Idefics3Encoder(nn.Module):
+ """
+ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
+ [`Idefics3EncoderLayer`].
+
+ Args:
+ config: Idefics3Config
+ """
+
+ def __init__(self, config: Idefics3Config):
+ super().__init__()
+ self.config = config
+ self.layers = nn.ModuleList([Idefics3EncoderLayer(config) for _ in range(config.num_hidden_layers)])
+ self.gradient_checkpointing = False
+
+ # Ignore copy
+ @auto_docstring
+ def forward(
+ self,
+ inputs_embeds,
+ attention_mask: Optional[torch.Tensor] = None,
+ ) -> Union[tuple, BaseModelOutput]:
+ hidden_states = inputs_embeds
+ for encoder_layer in self.layers:
+ layer_outputs = encoder_layer(
+ hidden_states,
+ attention_mask,
+ )
+
+ hidden_states = layer_outputs
+
+ return BaseModelOutput(last_hidden_state=hidden_states)
+
+
+# Copied from transformers.models.llama.modeling_llama.repeat_kv
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+ """
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Idefics3
+class Idefics3RMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ Idefics3RMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+ def extra_repr(self):
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+
+class Idefics3Connector(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.scale_factor = config.scale_factor
+ self.modality_projection = Idefics3SimpleMLP(config)
+
+ def pixel_shuffle(self, x, scale_factor=2):
+ bsz, seq, embed_dim = x.size()
+ height = width = int(seq**0.5)
+ x = x.view(bsz, height, width, embed_dim)
+ x = x.view(bsz, height, int(width / scale_factor), embed_dim * scale_factor)
+ x = x.permute(0, 2, 1, 3)
+ x = x.reshape(bsz, int(width / scale_factor), int(height / scale_factor), embed_dim * (scale_factor**2))
+ x = x.permute(0, 2, 1, 3)
+ x = x.reshape(bsz, int(seq / (scale_factor**2)), embed_dim * (scale_factor**2))
+ return x
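+
+    # Illustrative example (hypothetical shapes): with `scale_factor=2`, a (batch, 676, hidden) sequence of vision
+    # features (a 26x26 patch grid) becomes (batch, 169, 4 * hidden): every 2x2 block of patches is folded into the
+    # channel dimension before the modality projection maps it to the text hidden size.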
+
+ def forward(self, image_hidden_states):
+ image_hidden_states = self.pixel_shuffle(image_hidden_states, self.scale_factor)
+ image_hidden_states = self.modality_projection(image_hidden_states)
+ return image_hidden_states
+
+
+@auto_docstring
+class Idefics3PreTrainedModel(PreTrainedModel):
+ config: Idefics3Config
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["Idefics3VisionAttention", "Idefics3DecoderLayer"]
+ _skip_keys_device_placement = "past_key_values"
+ _supports_flash_attn = True
+ _supports_sdpa = True
+ _supports_flex_attn = True
+
+ _supports_attention_backend = True
+
+ def _init_weights(self, module):
+ std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
+
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.weight.data.fill_(1.0)
+ module.bias.data.zero_()
+ elif isinstance(module, Idefics3RMSNorm):
+ module.weight.data.fill_(1.0)
+
+
+@auto_docstring(
+ custom_intro="""
+ The Idefics3 Vision Transformer Model outputting raw image embedding.
+ """
+)
+class Idefics3VisionTransformer(Idefics3PreTrainedModel):
+ config: Idefics3VisionConfig
+ _supports_sdpa = True
+ _supports_flash_attn = True
+ _supports_flex_attn = True
+ _can_record_outputs = {
+ "hidden_states": Idefics3EncoderLayer,
+ "attentions": Idefics3VisionAttention,
+ }
+
+ def __init__(self, config: Idefics3VisionConfig):
+ super().__init__(config)
+ embed_dim = config.hidden_size
+
+ self.embeddings = Idefics3VisionEmbeddings(config)
+ self.encoder = Idefics3Encoder(config)
+ self.patch_size = config.patch_size
+ self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
+
+ # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2VisionTransformer.get_input_embeddings
+ def get_input_embeddings(self):
+ return self.embeddings
+
+ # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2VisionTransformer.set_input_embeddings
+ def set_input_embeddings(self, value):
+ self.embeddings = value
+
+ @check_model_inputs(tie_last_hidden_states=False)
+ def forward(
+ self,
+ pixel_values,
+ patch_attention_mask: Optional[torch.BoolTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> Union[tuple, BaseModelOutput]:
+ batch_size = pixel_values.size(0)
+ if patch_attention_mask is None:
+ patch_size = self.patch_size
+ patch_attention_mask = torch.ones(
+ (
+ batch_size,
+ pixel_values.size(2) // patch_size,
+ pixel_values.size(3) // patch_size,
+ )
+ )
+ patch_attention_mask = patch_attention_mask.to(dtype=torch.bool, device=pixel_values.device)
+
+ hidden_states = self.embeddings(pixel_values=pixel_values, patch_attention_mask=patch_attention_mask)
+
+ patch_attention_mask = patch_attention_mask.view(batch_size, -1)
+ # The call to `_upad_input` in `_flash_attention_forward` is expensive
+        # So when the `patch_attention_mask` is full of 1s (i.e. attending to the whole sequence),
+        # we avoid passing it, which is equivalent to attending to the full sequence.
+ if self.config._attn_implementation != "flash_attention_2":
+ patch_attention_mask = _prepare_4d_attention_mask(patch_attention_mask, hidden_states.dtype)
+ elif not torch.any(~patch_attention_mask):
+ patch_attention_mask = None
+
+ encoder_outputs: BaseModelOutput = self.encoder(
+ inputs_embeds=hidden_states,
+ attention_mask=patch_attention_mask,
+ )
+
+ last_hidden_state = encoder_outputs.last_hidden_state
+ last_hidden_state = self.post_layernorm(last_hidden_state)
+
+ return BaseModelOutput(
+ last_hidden_state=last_hidden_state,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ Idefics3 model consisting of a SigLIP vision encoder and a Llama3 language decoder.
+ """
+)
+class Idefics3Model(Idefics3PreTrainedModel):
+ def __init__(self, config: Idefics3Config):
+ super().__init__(config)
+ self.padding_idx = self.config.text_config.pad_token_id
+ self.vocab_size = self.config.text_config.vocab_size
+
+ self.vision_model = Idefics3VisionTransformer._from_config(config.vision_config)
+ self.connector = Idefics3Connector(config)
+ self.text_model = AutoModel.from_config(config.text_config)
+
+ self.image_seq_len = int(
+ ((config.vision_config.image_size // config.vision_config.patch_size) ** 2) / (config.scale_factor**2)
+ )
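+ # Worked example (editorial note; the concrete numbers assume the released Idefics3-8B-Llama3 checkpoint,
+ # i.e. vision_config.image_size = 364, vision_config.patch_size = 14 and scale_factor = 2):
+ # image_seq_len = ((364 // 14) ** 2) / (2 ** 2) = 676 / 4 = 169, which matches the processor's
+ # default `image_seq_len` of 169.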
+ self.image_token_id = self.config.image_token_id
+
+ self.post_init()
+
+ # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2Model.enable_input_require_grads
+ def enable_input_require_grads(self):
+ """
+ Enables the gradients for the input embeddings.
+
+ This is useful for lora when using gradient checkpointing.
+ c.f. https://github.com/huggingface/peft/issues/1402#issuecomment-1913675032
+
+ Override to set output.requires_grad = True for both the decoder's and vision model's embeddings.
+ """
+
+ def get_lowest_module(module):
+ if len(list(module.children())) == 0:
+ # If the module has no children, it is a leaf module (e.g., Linear, Conv2d, etc.)
+ return module
+ else:
+ # Recursively call the function on each child module
+ return get_lowest_module(list(module.children())[0])
+
+ def make_inputs_require_grads(module, input, output):
+ output.requires_grad_(True)
+
+ self._text_require_grads_hook = self.get_input_embeddings().register_forward_hook(make_inputs_require_grads)
+ self._vision_require_grads_hook = get_lowest_module(self.vision_model).register_forward_hook(
+ make_inputs_require_grads
+ )
+
+ # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2Model.disable_input_require_grads
+ def disable_input_require_grads(self):
+ self._text_require_grads_hook.remove()
+ self._vision_require_grads_hook.remove()
+
+ # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2Model.get_input_embeddings
+ def get_input_embeddings(self):
+ return self.text_model.get_input_embeddings()
+
+ # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2Model.set_input_embeddings
+ def set_input_embeddings(self, value):
+ self.text_model.set_input_embeddings(value)
+
+ def inputs_merger(
+ self,
+ input_ids: torch.LongTensor,
+ inputs_embeds: Optional[torch.Tensor],
+ image_hidden_states: Optional[torch.Tensor],
+ ):
+ """
+ This method aims at merging the token embeddings with the image hidden states into one single sequence of vectors that are fed to the transformer LM.
+ The merging happens as follows:
+ - The text token sequence is: `tok_1 tok_2 tok_3 ... tok_4`.
+ - We get the image hidden states for the image through the vision encoder and that hidden state, after a pixel shuffle operation, is then projected into the text embedding space.
+ We thus have a sequence of image hidden states of size (1, image_seq_len, hidden_dim), where 1 is the batch size for a single image and hidden_dim is the hidden dimension of the LM transformer.
+ - The merging happens so that we obtain the following sequence: `vector_tok_1 vector_tok_2 vector_tok_3 vector_fake_tok_around_image {sequence of image_seq_len image hidden states} vector_fake_tok_around_image vector_tok_4`. That sequence is fed to the LM.
+ - To fit the format of that sequence, `input_ids`, `inputs_embeds` and `attention_mask` are all adapted to insert the image hidden states.
+ """
+ if input_ids is None:
+ special_image_mask = inputs_embeds == self.get_input_embeddings()(
+ torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ special_image_mask = special_image_mask.all(-1)
+ else:
+ special_image_mask = input_ids == self.config.image_token_id
+
+ special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+ image_hidden_states = image_hidden_states.to(inputs_embeds.device, inputs_embeds.dtype)
+ inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_hidden_states)
+ return inputs_embeds
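+
+ # Minimal sketch of the merge above (editorial illustration; the token ids are made up):
+ #   input_ids          = [bos, tok_a, <image>, <image>, tok_b]
+ #   special_image_mask = [0,   0,     1,       1,       0]   (expanded over the hidden dimension)
+ # image_hidden_states contributes exactly two vectors here, so masked_scatter writes them into the two
+ # <image> positions of inputs_embeds while every text position keeps its original embedding.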
+
+ def get_image_features(
+ self, pixel_values: torch.FloatTensor, pixel_attention_mask: Optional[torch.LongTensor] = None
+ ):
+ """
+ Encodes images into continuous embeddings that can be forwarded to the language model.
+
+ Args:
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
+ The tensors corresponding to the input images.
+ pixel_attention_mask (`torch.LongTensor`, *optional*):
+ The attention mask indicating padded regions in the image.
+ """
+ batch_size, num_images, num_channels, height, width = pixel_values.shape
+ pixel_values = pixel_values.to(dtype=self.dtype) # fp16 compatibility
+ pixel_values = pixel_values.view(batch_size * num_images, *pixel_values.shape[2:])
+
+ # Remove padding images - padding images are full 0.
+ nb_values_per_image = pixel_values.shape[1:].numel()
+ real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image
+ pixel_values = pixel_values[real_images_inds].contiguous()
+
+ # Handle the vision attention mask
+ if pixel_attention_mask is None:
+ pixel_attention_mask = torch.ones(
+ size=(pixel_values.size(0), pixel_values.size(2), pixel_values.size(3)),
+ dtype=torch.bool,
+ device=pixel_values.device,
+ )
+ else:
+ # Remove padding images from the mask
+ pixel_attention_mask = pixel_attention_mask.view(batch_size * num_images, *pixel_attention_mask.shape[2:])
+ pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous()
+
+ patch_size = self.config.vision_config.patch_size
+ patches_subgrid = pixel_attention_mask.unfold(dimension=1, size=patch_size, step=patch_size)
+ patches_subgrid = patches_subgrid.unfold(dimension=2, size=patch_size, step=patch_size)
+ patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
+
+ # Get sequence from the vision encoder
+ image_hidden_states = self.vision_model(pixel_values=pixel_values, patch_attention_mask=patch_attention_mask)
+ image_hidden_states = image_hidden_states.last_hidden_state
+
+ # Modality projection & resampling
+ image_hidden_states = self.connector(image_hidden_states)
+ return image_hidden_states
+
+ @can_return_tuple
+ @auto_docstring(
+ custom_intro="""
+ Inputs fed to the model can have an arbitrary number of images. To account for this, pixel_values fed to
+ the model have image padding -> (batch_size, max_num_images, 3, max_heights, max_widths) where
+ max_num_images is the maximum number of images among the batch_size samples in the batch.
+ Padding images are not needed beyond padding the pixel_values at the entrance of the model.
+ For efficiency, we only pass through the vision_model's forward the real images by
+ discarding the padding images i.e. pixel_values of size (image_batch_size, 3, height, width) where
+ image_batch_size would be 7 when num_images_per_sample=[1, 3, 1, 2] and max_num_images would be 3.
+ """
+ )
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ pixel_attention_mask: Optional[torch.BoolTensor] = None,
+ image_hidden_states: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ return_dict: Optional[bool] = None,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> Union[tuple, Idefics3BaseModelOutputWithPast]:
+ r"""
+ pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*):
+ Mask to avoid performing attention on padding pixel indices.
+ image_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
+ The hidden states of the image encoder after modality projection.
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if self.training and self.text_model.gradient_checkpointing and use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ # retrieve input_ids and inputs_embeds
+ if input_ids is not None:
+ batch_size, seq_length = input_ids.shape
+ elif inputs_embeds is not None:
+ batch_size, seq_length, _ = inputs_embeds.shape
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ if use_cache and past_key_values is None:
+ past_key_values = DynamicCache(config=self.config)
+
+ if inputs_embeds is None:
+ inputs_embeds = self.text_model.get_input_embeddings()(input_ids).to(self.device)
+
+ # START VISUAL INPUTS INTEGRATION
+ if pixel_values is not None and image_hidden_states is not None:
+ raise ValueError("You cannot specify both pixel_values and image_hidden_states at the same time")
+ elif pixel_values is not None:
+ image_hidden_states = self.get_image_features(pixel_values, pixel_attention_mask)
+ elif image_hidden_states is not None:
+ image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=input_ids.device)
+
+ if image_hidden_states is not None:
+ # When we generate, we don't want to replace the potential image_token_id that we generated by images
+ # that simply don't exist
+ inputs_embeds = self.inputs_merger(
+ input_ids=input_ids,
+ inputs_embeds=inputs_embeds,
+ image_hidden_states=image_hidden_states,
+ )
+
+ outputs = self.text_model(
+ inputs_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ cache_position=cache_position,
+ return_dict=True,
+ **kwargs,
+ )
+
+ return Idefics3BaseModelOutputWithPast(
+ last_hidden_state=outputs.last_hidden_state,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ image_hidden_states=image_hidden_states,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ The Idefics3 Model with a language modeling head. It is made up of a SigLIP vision encoder and a Llama3 language decoder, with a language modeling head on top.
+ """
+)
+class Idefics3ForConditionalGeneration(Idefics3PreTrainedModel, GenerationMixin):
+ _tied_weights_keys = ["lm_head.weight"]
+
+ # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2ForConditionalGeneration.__init__ with Idefics2->Idefics3
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = Idefics3Model(config)
+ self.image_token_id = self.config.image_token_id
+
+ self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
+ self.vocab_size = config.text_config.vocab_size
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2ForConditionalGeneration.enable_input_require_grads
+ def enable_input_require_grads(self):
+ """
+ Enables the gradients for the input embeddings. This is useful for fine-tuning adapter weights while keeping
+ the model weights fixed.
+ """
+
+ def make_inputs_require_grads(module, input, output):
+ output.requires_grad_(True)
+
+ self._text_require_grads_hook = self.get_input_embeddings().register_forward_hook(make_inputs_require_grads)
+ self._vision_require_grads_hook = self.model.vision_model.get_input_embeddings().register_forward_hook(
+ make_inputs_require_grads
+ )
+
+ # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2ForConditionalGeneration.disable_input_require_grads
+ def disable_input_require_grads(self):
+ self._text_require_grads_hook.remove()
+ self._vision_require_grads_hook.remove()
+
+ # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2ForConditionalGeneration.get_input_embeddings
+ def get_input_embeddings(self):
+ return self.model.text_model.get_input_embeddings()
+
+ # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2ForConditionalGeneration.set_input_embeddings
+ def set_input_embeddings(self, value):
+ self.model.text_model.set_input_embeddings(value)
+
+ def get_image_features(
+ self, pixel_values: torch.FloatTensor, pixel_attention_mask: Optional[torch.LongTensor] = None
+ ):
+ return self.model.get_image_features(pixel_values=pixel_values, pixel_attention_mask=pixel_attention_mask)
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ pixel_attention_mask: Optional[torch.BoolTensor] = None,
+ image_hidden_states: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ return_dict: Optional[bool] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> Union[tuple, Idefics3CausalLMOutputWithPast]:
+ r"""
+ pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*):
+ Mask to avoid performing attention on padding pixel indices.
+ image_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
+ The hidden states of the image encoder after modality projection.
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or `model.image_token_id` (where `model` is your instance of `Idefics3ForConditionalGeneration`).
+ Tokens with indices set to `model.image_token_id` are ignored (masked), the loss is only
+ computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Example:
+
+ ```python
+ >>> import requests
+ >>> import torch
+ >>> from PIL import Image
+ >>> from io import BytesIO
+
+ >>> from transformers import AutoProcessor, AutoModelForVision2Seq
+ >>> from transformers.image_utils import load_image
+
+ >>> # Note that passing the image urls (instead of the actual pil images) to the processor is also possible
+ >>> image1 = load_image("https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg")
+ >>> image2 = load_image("https://cdn.britannica.com/59/94459-050-DBA42467/Skyline-Chicago.jpg")
+ >>> image3 = load_image("https://cdn.britannica.com/68/170868-050-8DDE8263/Golden-Gate-Bridge-San-Francisco.jpg")
+
+ >>> processor = AutoProcessor.from_pretrained("HuggingFaceM4/Idefics3-8B-Llama3")
+ >>> model = AutoModelForVision2Seq.from_pretrained("HuggingFaceM4/Idefics3-8B-Llama3", dtype=torch.bfloat16, device_map="auto")
+
+ >>> # Create inputs
+ >>> messages = [
+ ... {
+ ... "role": "user",
+ ... "content": [
+ ... {"type": "image"},
+ ... {"type": "text", "text": "In this image, we can see the city of New York, and more specifically the Statue of Liberty."},
+ ... {"type": "image"},
+ ... {"type": "text", "text": "What can we see in this image?"},
+ ... ]
+ ... },
+ ... {
+ ... "role": "user",
+ ... "content": [
+ ... {"type": "image"},
+ ... {"type": "text", "text": "In which city is that bridge located?"},
+ ... ]
+ ... }
+ ... ]
+
+ >>> prompts = [processor.apply_chat_template([message], add_generation_prompt=True) for message in messages]
+ >>> images = [[image1, image2], [image3]]
+ >>> inputs = processor(text=prompts, images=images, padding=True, return_tensors="pt").to(model.device)
+
+ >>> # Generate
+ >>> generated_ids = model.generate(**inputs, max_new_tokens=256)
+ >>> generated_texts = processor.batch_decode(generated_ids, skip_special_tokens=True)
+
+ >>> print(generated_texts[0])
+ Assistant: There are buildings, trees, lights, and water visible in this image.
+
+ >>> print(generated_texts[1])
+ Assistant: The bridge is in San Francisco.
+ ```"""
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ pixel_values=pixel_values,
+ pixel_attention_mask=pixel_attention_mask,
+ image_hidden_states=image_hidden_states,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ cache_position=cache_position,
+ return_dict=True,
+ **kwargs,
+ )
+
+ hidden_states = outputs[0]
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(
+ logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
+ )
+
+ return Idefics3CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ image_hidden_states=outputs.image_hidden_states,
+ )
+
+ # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2ForConditionalGeneration.prepare_inputs_for_generation
+ def prepare_inputs_for_generation(
+ self,
+ input_ids,
+ past_key_values=None,
+ attention_mask=None,
+ inputs_embeds=None,
+ cache_position=None,
+ pixel_values=None,
+ pixel_attention_mask=None,
+ image_hidden_states=None,
+ logits_to_keep=None,
+ **kwargs,
+ ):
+ # Overwritten -- there are mutually exclusive inputs (if the logic to make `image_hidden_states` take
+ # precedence is moved to the model, we can remove this fn)
+
+ model_inputs = super().prepare_inputs_for_generation(
+ input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ inputs_embeds=inputs_embeds,
+ cache_position=cache_position,
+ pixel_values=pixel_values,
+ pixel_attention_mask=pixel_attention_mask,
+ image_hidden_states=image_hidden_states,
+ logits_to_keep=logits_to_keep,
+ **kwargs,
+ )
+
+ if image_hidden_states is not None or cache_position[0] != 0:
+ model_inputs["pixel_values"] = None
+ model_inputs["pixel_attention_mask"] = None
+
+ return model_inputs
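+
+ # Editorial note on the override above: once generation has started (`cache_position[0] != 0`) the image
+ # features computed during prefill already live in the key/value cache, so `pixel_values` and
+ # `pixel_attention_mask` are dropped to avoid re-encoding the images at every decoding step. The same
+ # applies when precomputed `image_hidden_states` are passed in directly.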
+
+
+__all__ = ["Idefics3ForConditionalGeneration", "Idefics3PreTrainedModel", "Idefics3Model", "Idefics3VisionTransformer"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/idefics3/processing_idefics3.py b/venv/lib/python3.13/site-packages/transformers/models/idefics3/processing_idefics3.py
new file mode 100644
index 0000000000000000000000000000000000000000..00ee8df6d414c72761aa71003f1bec1d30350b1e
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/idefics3/processing_idefics3.py
@@ -0,0 +1,404 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Processor class for Idefics3.
+"""
+
+import re
+from itertools import accumulate
+from typing import TYPE_CHECKING, Optional, Union
+
+import numpy as np
+
+from ...feature_extraction_utils import BatchFeature
+from ...image_utils import ImageInput, is_valid_image, load_image
+from ...processing_utils import ImagesKwargs, MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack
+from ...tokenization_utils_base import AddedToken, BatchEncoding, TextInput
+from ...utils import logging
+
+
+if TYPE_CHECKING:
+ from ...tokenization_utils_base import PreTokenizedInput
+
+logger = logging.get_logger(__name__)
+
+
+def is_url(val) -> bool:
+ return isinstance(val, str) and val.startswith("http")
+
+
+def is_image_or_image_url(elem):
+ return is_url(elem) or is_valid_image(elem)
+
+
+def _prompt_split_image(image_seq_len, image_rows, image_cols, fake_token_around_image, image_token, global_img_token):
+ """Prompt with expanded image tokens for when the image is split into patches."""
+ text_split_images = ""
+ for n_h in range(image_rows):
+ for n_w in range(image_cols):
+ text_split_images += (
+ f"{fake_token_around_image}" + f"" + f"{image_token}" * image_seq_len
+ )
+ text_split_images += "\n"
+
+ text_split_images += (
+ f"\n{fake_token_around_image}"
+ + f"{global_img_token}"
+ + f"{image_token}" * image_seq_len
+ + f"{fake_token_around_image}"
+ )
+ return text_split_images
+
+
+def _prompt_single_image(image_seq_len, fake_token_around_image, image_token, global_img_token):
+ """Prompt with expanded image tokens for a single image."""
+ return (
+ f"{fake_token_around_image}"
+ + f"{global_img_token}"
+ + f"{image_token}" * image_seq_len
+ + f"{fake_token_around_image}"
+ )
+
+
+def get_image_prompt_string(
+ image_rows, image_cols, image_seq_len, fake_token_around_image, image_token, global_img_token
+):
+ if image_rows == 0 and image_cols == 0:
+ return _prompt_single_image(
+ image_seq_len,
+ fake_token_around_image=fake_token_around_image,
+ image_token=image_token,
+ global_img_token=global_img_token,
+ )
+ return _prompt_split_image(
+ image_seq_len, image_rows, image_cols, fake_token_around_image, image_token, global_img_token
+ )
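+
+
+ # Illustration (editorial, not part of the upstream file): with image_rows=1, image_cols=2, image_seq_len=2,
+ # fake_token_around_image="<fake_token_around_image>", image_token="<image>" and global_img_token="<global-img>",
+ # get_image_prompt_string returns the single string
+ #   "<fake_token_around_image><row_1_col_1><image><image>"
+ #   "<fake_token_around_image><row_1_col_2><image><image>\n"
+ #   "\n<fake_token_around_image><global-img><image><image><fake_token_around_image>"
+ # (shown here split over three lines), i.e. one <row_x_col_y> block per patch followed by the global image block.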
+
+
+class Idefics3ImagesKwargs(ImagesKwargs, total=False):
+ return_row_col_info: Optional[bool]
+ max_image_size: Optional[dict[str, int]]
+
+
+class Idefics3ProcessorKwargs(ProcessingKwargs, total=False):
+ images_kwargs: Idefics3ImagesKwargs
+
+ _defaults = {
+ "text_kwargs": {
+ "add_special_tokens": True,
+ "padding": False,
+ "is_split_into_words": False,
+ "return_mm_token_type_ids": False,
+ },
+ "images_kwargs": {
+ "return_row_col_info": True,
+ },
+ }
+
+
+class Idefics3Processor(ProcessorMixin):
+ r"""
+ Constructs an Idefics3 processor which wraps a Llama tokenizer and Idefics3 image processor into a single processor.
+
+ [`Idefics3Processor`] offers all the functionalities of [`Idefics3ImageProcessor`] and [`Idefics3TokenizerFast`]. See
+ the docstring of [`~IdeficsProcessor.__call__`] and [`~IdeficsProcessor.decode`] for more information.
+
+ Args:
+ image_processor (`Idefics3ImageProcessor`):
+ An instance of [`Idefics3ImageProcessor`]. The image processor is a required input.
+ tokenizer (`PreTrainedTokenizerBase`, *optional*):
+ An instance of [`PreTrainedTokenizerBase`]. This should correspond with the model's text model. The tokenizer is a required input.
+ image_seq_len (`int`, *optional*, defaults to 169):
+ The length of the image sequence i.e. the number of tokens per image in the input.
+ This parameter is used to build the string from the input prompt and image tokens and should match the
+ value the model used. It is computed as: image_seq_len = int(((image_size // patch_size) ** 2) / (scale_factor**2))
+ chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
+ in a chat into a tokenizable string.
+ """
+
+ attributes = ["image_processor", "tokenizer"]
+ image_processor_class = "Idefics3ImageProcessor"
+ tokenizer_class = "AutoTokenizer"
+
+ def __init__(
+ self, image_processor, tokenizer=None, image_seq_len: int = 169, chat_template: Optional[str] = None, **kwargs
+ ):
+ self.fake_image_token = AddedToken("<fake_token_around_image>", normalized=False, special=True).content
+ self.image_token = AddedToken("<image>", normalized=False, special=True).content
+ self.end_of_utterance_token = AddedToken("<end_of_utterance>", normalized=False, special=True).content
+ self.global_image_tag = "<global-img>"  # https://github.com/huggingface/transformers/pull/32473/files/8063e5e17362571b693f1db95167f5443a3be1b2#r1734825341
+ self.image_seq_len = image_seq_len
+ self.image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
+ self.fake_image_token_id = tokenizer.convert_tokens_to_ids(self.fake_image_token)
+ self.global_image_token_id = tokenizer.convert_tokens_to_ids(self.global_image_tag)
+ self.row_col_ids = [
+ tokenizer.convert_tokens_to_ids(f"<row_{i + 1}_col_{j + 1}>") for i in range(6) for j in range(6)
+ ]
+
+ # This regex matches one or more occurrences of <global-img> tags (optionally surrounded by newline characters)
+ # or <row_x_col_y> tags (where x and y are digits, also optionally surrounded by newline characters).
+ self._regex_to_remove_extra_special_tokens = re.compile(r"(\n?<global-img>\n?|<row_\d+_col_\d+>\n?)+")
+
+ tokens_to_add = {
+ "additional_special_tokens": [
+ self.fake_image_token,
+ self.image_token,
+ self.end_of_utterance_token,
+ ]
+ }
+ tokenizer.add_special_tokens(tokens_to_add)
+ self.image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
+
+ super().__init__(image_processor, tokenizer, chat_template=chat_template, **kwargs)
+
+ def _extract_images_from_prompts(self, prompts):
+ prompt_images = []
+ for prompt in prompts:
+ images = []
+ for elem in prompt:
+ if is_valid_image(elem):
+ images.append(elem)
+ elif is_url(elem):
+ images.append(load_image(elem))
+ prompt_images.append(images)
+ return prompt_images
+
+ def __call__(
+ self,
+ images: Union[ImageInput, list[ImageInput], list[list[ImageInput]]] = None,
+ text: Union[TextInput, "PreTokenizedInput", list[TextInput], list["PreTokenizedInput"]] = None,
+ audio=None,
+ videos=None,
+ image_seq_len: Optional[int] = None,
+ **kwargs: Unpack[Idefics3ProcessorKwargs],
+ ) -> BatchEncoding:
+ """
+ Processes the input prompts and returns a BatchEncoding.
+
+ Example:
+
+ ```python
+ >>> import requests
+ >>> from transformers import Idefics3Processor
+ >>> from transformers.image_utils import load_image
+
+ >>> processor = Idefics3Processor.from_pretrained("HuggingFaceM4/Idefics3-8B-Llama3")
+ >>> processor.image_processor.do_image_splitting = False # Force as False to simplify the example
+
+ >>> url1 = "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg"
+ >>> url2 = "https://cdn.britannica.com/59/94459-050-DBA42467/Skyline-Chicago.jpg"
+
+ >>> image1, image2 = load_image(url1), load_image(url2)
+ >>> images = [[image1], [image2]]
+
+ >>> text = [
+ ... "In this image, we see",
+ ... "bla bla bla",
+ ... ]
+ >>> outputs = processor(images=images, text=text, return_tensors="pt", padding=True)
+ >>> input_ids = outputs.input_ids
+ >>> input_tokens = processor.tokenizer.batch_decode(input_ids)
+ >>> print(input_tokens)
+ ['<|begin_of_text|><fake_token_around_image><global-img>((<image>)*169)<fake_token_around_image> In this image, we see', '<|reserved_special_token_0|><|reserved_special_token_0|><|reserved_special_token_0|><|begin_of_text|>bla bla bla<fake_token_around_image><global-img>((<image>)*169)<fake_token_around_image>']
+ ```
+
+ Args:
+ images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`, `list[torch.Tensor]`, *optional*):
+ The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
+ tensor. If is of type `list[ImageInput]`, it's assumed that this is for a single prompt i.e. of batch size 1.
+ text (`Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]]`, *optional*):
+ The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
+ (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
+ `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
+ Wherever an image token, `<image>`, is encountered it is expanded to
+ `<fake_token_around_image>` + `<global-img>` + `<image>` * `image_seq_len` + `<fake_token_around_image>`.
+ image_seq_len (`int`, *optional*):
+ The length of the image sequence. If not provided, the default value of self.image_seq_len is used.
+ image_seq_len should be equal to int(((image_size // patch_size) ** 2) / (scale_factor**2))
+ return_tensors (`Union[str, TensorType]`, *optional*):
+ If set, will return tensors of a particular framework. See [`PreTrainedTokenizerFast.__call__`] for more
+ information.
+ """
+ if text is None and images is None:
+ raise ValueError("You must provide either `text` or `images`.")
+
+ output_kwargs = self._merge_kwargs(
+ Idefics3ProcessorKwargs,
+ tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+ **kwargs,
+ )
+
+ image_seq_len = image_seq_len if image_seq_len is not None else self.image_seq_len
+ return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False)
+ return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
+
+ n_images_in_text = []
+ n_images_in_images = []
+ inputs = {}
+
+ if text is not None:
+ if isinstance(text, str):
+ text = [text]
+ elif not isinstance(text, list) or not isinstance(text[0], str):
+ raise ValueError("Invalid input text. Please provide a string, or a list of strings")
+ n_images_in_text = [sample.count(self.image_token) for sample in text]
+
+ if images is not None:
+ if is_image_or_image_url(images):
+ images = [[images]]
+ elif isinstance(images, (list, tuple)) and is_image_or_image_url(images[0]):
+ if text is not None:
+ if sum(n_images_in_text) != len(images):
+ raise ValueError(
+ f"The total number of {self.image_token} tokens in the prompts should be the same as the number of images passed."
+ f" Found {sum(n_images_in_text)} {self.image_token} tokens and {len(images)} images."
+ )
+ # Reorganize the images to match the prompts
+ cumsum_images_in_text = [0] + list(accumulate(n_images_in_text))
+ images = [
+ images[cumsum_images_in_text[i] : cumsum_images_in_text[i + 1]]
+ for i in range(len(n_images_in_text))
+ ]
+ else:
+ images = [images]
+ elif (
+ not isinstance(images, (list, tuple))
+ and not isinstance(images[0], (list, tuple))
+ and not is_image_or_image_url(images[0][0])
+ ):
+ raise ValueError(
+ "Invalid input images. Please provide a single image or a list of images or a list of list of images."
+ )
+ n_images_in_images = [len(sample) for sample in images]
+
+ # Load images if they are URLs
+ images = [[load_image(im) if is_url(im) else im for im in sample] for sample in images]
+
+ image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
+ inputs.update(image_inputs)
+
+ if text is not None:
+ if n_images_in_images != n_images_in_text:
+ raise ValueError(
+ f"The number of images in the text {n_images_in_text} and images {n_images_in_images} should be the same."
+ )
+
+ image_rows = inputs.pop("rows", [[0] * len(text)])
+ image_cols = inputs.pop("cols", [[0] * len(text)])
+
+ fake_image_token = self.fake_image_token
+ image_token = self.image_token
+ global_img_token = self.global_image_tag
+
+ prompt_strings = []
+ batch_image_seq_lengths = []
+ for sample, sample_rows, sample_cols in zip(text, image_rows, image_cols):
+ # Replace the image token with fake tokens around the expanded image token sequence of length `image_seq_len`
+ image_prompt_strings = []
+ image_seq_lengths = []
+ for n_rows, n_cols in zip(sample_rows, sample_cols):
+ image_prompt_string = get_image_prompt_string(
+ n_rows,
+ n_cols,
+ image_seq_len,
+ image_token=image_token,
+ fake_token_around_image=fake_image_token,
+ global_img_token=global_img_token,
+ )
+ # Add +2 and +3 for special BOI/EOI/fake_image_wrapper tokens
+ row_length = (self.image_seq_len + 2) * n_cols + 1
+ image_seq_lengths.append((self.image_seq_len + 3) + row_length * n_rows)
+ image_prompt_strings.append(image_prompt_string)
+
+ batch_image_seq_lengths.append(image_seq_lengths)
+ split_sample = sample.split(image_token)
+ if len(split_sample) == 0:
+ raise ValueError("The image token should be present in the text.")
+
+ # Place in the image prompt strings where the image tokens are
+ sample = split_sample[0]
+ for i, image_prompt_string in enumerate(image_prompt_strings):
+ sample += image_prompt_string + split_sample[i + 1]
+ prompt_strings.append(sample)
+
+ text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"])
+ self._check_special_mm_tokens(prompt_strings, text_inputs, modalities=["image"])
+ inputs.update(text_inputs)
+
+ elif text is not None:
+ if any(n_images_in_text):
+ raise ValueError(
+ f"Found {sum(n_images_in_text)} {self.image_token} tokens in the text but no images were passed."
+ )
+ text_inputs = self.tokenizer(text=text, **output_kwargs["text_kwargs"])
+ inputs.update(text_inputs)
+
+ if return_mm_token_type_ids:
+ array_ids = np.array(inputs["input_ids"])
+ mm_token_type_ids = np.zeros_like(array_ids)
+ for i, seq_lengths in enumerate(batch_image_seq_lengths):
+ image_start_positions = np.where(array_ids[i] == self.fake_image_token_id)[0]
+ j = 0
+ for seq_len in seq_lengths:
+ if j >= len(image_start_positions):
+ break
+ start = image_start_positions[j]
+ end = start + seq_len
+ mm_token_type_ids[i, start:end] = 1
+ j = np.searchsorted(image_start_positions, end)
+
+ inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()
+
+ return BatchFeature(data=inputs, tensor_type=return_tensors)
+
+ def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
+ """
+ Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
+
+ Args:
+ image_sizes (`list[list[int]]`, *optional*):
+ The input sizes formatted as (height, width) per each image.
+
+ Returns:
+ `MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
+ input modalities, along with other useful data.
+ """
+
+ vision_data = {}
+ if image_sizes is not None:
+ images_kwargs = Idefics3ProcessorKwargs._defaults.get("images_kwargs", {})
+ images_kwargs.update(kwargs)
+
+ num_image_row_cols = [
+ self.image_processor.get_number_of_image_patches(*image_size, images_kwargs)
+ for image_size in image_sizes
+ ]
+
+ base_image_length = self.image_seq_len + 3
+ col_length = self.image_seq_len + 2
+ num_image_tokens = []
+ num_image_patches = []
+
+ for num_patches, num_rows, num_cols in num_image_row_cols:
+ row_length = col_length * num_cols + 1
+ num_image_tokens.append(base_image_length + (row_length * num_rows))
+ num_image_patches.append(num_patches)
+
+ vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
+
+ return MultiModalData(**vision_data)
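+
+ # Worked example for the arithmetic above (editorial illustration, assuming the default image_seq_len of 169):
+ # for an image split into a 2 x 2 grid of patches plus the global image,
+ #   base_image_length = 169 + 3 = 172        # global image block
+ #   col_length        = 169 + 2 = 171        # one <row_x_col_y> patch block
+ #   row_length        = 171 * 2 + 1 = 343    # a row of 2 patch blocks plus its trailing newline token
+ #   num_image_tokens  = 172 + 343 * 2 = 858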
+
+
+__all__ = ["Idefics3Processor"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/imagegpt/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/imagegpt/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..098ffb6296f547e6dd9f1f990d21e28bc5cb0f7b
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/imagegpt/__init__.py
@@ -0,0 +1,30 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_imagegpt import *
+ from .feature_extraction_imagegpt import *
+ from .image_processing_imagegpt import *
+ from .image_processing_imagegpt_fast import *
+ from .modeling_imagegpt import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/venv/lib/python3.13/site-packages/transformers/models/imagegpt/configuration_imagegpt.py b/venv/lib/python3.13/site-packages/transformers/models/imagegpt/configuration_imagegpt.py
new file mode 100644
index 0000000000000000000000000000000000000000..8cfa8d5e47826e4323f74b33430fa872005c88ea
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/imagegpt/configuration_imagegpt.py
@@ -0,0 +1,200 @@
+# coding=utf-8
+# Copyright 2021 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""OpenAI ImageGPT configuration"""
+
+from collections import OrderedDict
+from collections.abc import Mapping
+from typing import TYPE_CHECKING, Any, Optional
+
+from ...configuration_utils import PretrainedConfig
+from ...onnx import OnnxConfig
+from ...utils import logging
+
+
+if TYPE_CHECKING:
+ from ... import FeatureExtractionMixin, TensorType
+
+logger = logging.get_logger(__name__)
+
+
+class ImageGPTConfig(PretrainedConfig):
+ """
+ This is the configuration class to store the configuration of a [`ImageGPTModel`] or a [`TFImageGPTModel`]. It is
+ used to instantiate a GPT-2 model according to the specified arguments, defining the model architecture.
+ Instantiating a configuration with the defaults will yield a similar configuration to that of the ImageGPT
+ [openai/imagegpt-small](https://huggingface.co/openai/imagegpt-small) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 512):
+ Vocabulary size of the GPT-2 model. Defines the number of different tokens that can be represented by the
+ `inputs_ids` passed when calling [`ImageGPTModel`] or [`TFImageGPTModel`].
+ n_positions (`int`, *optional*, defaults to 32*32):
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
+ just in case (e.g., 512 or 1024 or 2048).
+ n_embd (`int`, *optional*, defaults to 512):
+ Dimensionality of the embeddings and hidden states.
+ n_layer (`int`, *optional*, defaults to 24):
+ Number of hidden layers in the Transformer encoder.
+ n_head (`int`, *optional*, defaults to 8):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ n_inner (`int`, *optional*, defaults to None):
+ Dimensionality of the inner feed-forward layers. `None` will set it to 4 times n_embd
+ activation_function (`str`, *optional*, defaults to `"quick_gelu"`):
+ Activation function (can be one of the activation functions defined in src/transformers/activations.py).
+ Defaults to "quick_gelu".
+ resid_pdrop (`float`, *optional*, defaults to 0.1):
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+ embd_pdrop (`int`, *optional*, defaults to 0.1):
+ The dropout ratio for the embeddings.
+ attn_pdrop (`float`, *optional*, defaults to 0.1):
+ The dropout ratio for the attention.
+ layer_norm_epsilon (`float`, *optional*, defaults to 1e-5):
+ The epsilon to use in the layer normalization layers.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ scale_attn_weights (`bool`, *optional*, defaults to `True`):
+ Scale attention weights by dividing by sqrt(hidden_size).
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models).
+ scale_attn_by_inverse_layer_idx (`bool`, *optional*, defaults to `False`):
+ Whether to additionally scale attention weights by `1 / (layer_idx + 1)`.
+ reorder_and_upcast_attn (`bool`, *optional*, defaults to `False`):
+ Whether to scale keys (K) prior to computing attention (dot-product) and upcast attention
+ dot-product/softmax to float() when training with mixed precision.
+
+ Example:
+
+ ```python
+ >>> from transformers import ImageGPTConfig, ImageGPTModel
+
+ >>> # Initializing a ImageGPT configuration
+ >>> configuration = ImageGPTConfig()
+
+ >>> # Initializing a model (with random weights) from the configuration
+ >>> model = ImageGPTModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "imagegpt"
+ keys_to_ignore_at_inference = ["past_key_values"]
+ attribute_map = {
+ "hidden_size": "n_embd",
+ "max_position_embeddings": "n_positions",
+ "num_attention_heads": "n_head",
+ "num_hidden_layers": "n_layer",
+ }
+
+ def __init__(
+ self,
+ vocab_size=512 + 1, # add one for start of sentence (sos) token
+ n_positions=32 * 32,
+ n_embd=512,
+ n_layer=24,
+ n_head=8,
+ n_inner=None,
+ activation_function="quick_gelu",
+ resid_pdrop=0.1,
+ embd_pdrop=0.1,
+ attn_pdrop=0.1,
+ layer_norm_epsilon=1e-5,
+ initializer_range=0.02,
+ scale_attn_weights=True,
+ use_cache=True,
+ tie_word_embeddings=False,
+ scale_attn_by_inverse_layer_idx=False,
+ reorder_and_upcast_attn=False,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.n_positions = n_positions
+ self.n_embd = n_embd
+ self.n_layer = n_layer
+ self.n_head = n_head
+ self.n_inner = n_inner
+ self.activation_function = activation_function
+ self.resid_pdrop = resid_pdrop
+ self.embd_pdrop = embd_pdrop
+ self.attn_pdrop = attn_pdrop
+ self.layer_norm_epsilon = layer_norm_epsilon
+ self.initializer_range = initializer_range
+ self.scale_attn_weights = scale_attn_weights
+ self.use_cache = use_cache
+ self.scale_attn_by_inverse_layer_idx = scale_attn_by_inverse_layer_idx
+ self.reorder_and_upcast_attn = reorder_and_upcast_attn
+ self.tie_word_embeddings = tie_word_embeddings
+
+ super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
+
+
+class ImageGPTOnnxConfig(OnnxConfig):
+ @property
+ def inputs(self) -> Mapping[str, Mapping[int, str]]:
+ return OrderedDict(
+ [
+ ("input_ids", {0: "batch", 1: "sequence"}),
+ ]
+ )
+
+ def generate_dummy_inputs(
+ self,
+ preprocessor: "FeatureExtractionMixin",
+ batch_size: int = 1,
+ seq_length: int = -1,
+ is_pair: bool = False,
+ framework: Optional["TensorType"] = None,
+ num_channels: int = 3,
+ image_width: int = 32,
+ image_height: int = 32,
+ ) -> Mapping[str, Any]:
+ """
+ Generate inputs to provide to the ONNX exporter for the specific framework
+
+ Args:
+ preprocessor ([`PreTrainedTokenizerBase`] or [`FeatureExtractionMixin`]):
+ The preprocessor associated with this model configuration.
+ batch_size (`int`, *optional*, defaults to 1):
+ The batch size to export the model for (-1 means dynamic axis).
+ seq_length (`int`, *optional*, defaults to -1):
+ The sequence length to export the model for (-1 means dynamic axis).
+ is_pair (`bool`, *optional*, defaults to `False`):
+ Indicate if the input is a pair (sentence 1, sentence 2)
+ framework (`TensorType`, *optional*, defaults to `None`):
+ The framework (PyTorch or TensorFlow) that the tokenizer will generate tensors for.
+ num_channels (`int`, *optional*, defaults to 3):
+ The number of channels of the generated images.
+ image_width (`int`, *optional*, defaults to 32):
+ The width of the generated images.
+ image_height (`int`, *optional*, defaults to 32):
+ The height of the generated images.
+
+ Returns:
+ Mapping[str, Tensor] holding the kwargs to provide to the model's forward function
+ """
+
+ input_image = self._generate_dummy_images(batch_size, num_channels, image_height, image_width)
+ inputs = dict(preprocessor(images=input_image, return_tensors=framework))
+
+ return inputs
+
+
+__all__ = ["ImageGPTConfig", "ImageGPTOnnxConfig"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/imagegpt/feature_extraction_imagegpt.py b/venv/lib/python3.13/site-packages/transformers/models/imagegpt/feature_extraction_imagegpt.py
new file mode 100644
index 0000000000000000000000000000000000000000..46787f139f10a0c339e4cea9524d34c00a03ceb6
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/imagegpt/feature_extraction_imagegpt.py
@@ -0,0 +1,38 @@
+# coding=utf-8
+# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Feature extractor class for ImageGPT."""
+
+import warnings
+
+from ...utils import logging
+from ...utils.import_utils import requires
+from .image_processing_imagegpt import ImageGPTImageProcessor
+
+
+logger = logging.get_logger(__name__)
+
+
+@requires(backends=("vision",))
+class ImageGPTFeatureExtractor(ImageGPTImageProcessor):
+ def __init__(self, *args, **kwargs) -> None:
+ warnings.warn(
+ "The class ImageGPTFeatureExtractor is deprecated and will be removed in version 5 of Transformers."
+ " Please use ImageGPTImageProcessor instead.",
+ FutureWarning,
+ )
+ super().__init__(*args, **kwargs)
+
+
+__all__ = ["ImageGPTFeatureExtractor"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/imagegpt/image_processing_imagegpt.py b/venv/lib/python3.13/site-packages/transformers/models/imagegpt/image_processing_imagegpt.py
new file mode 100644
index 0000000000000000000000000000000000000000..aa2114509f7026d9f266be952dc220f2de1b632d
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/imagegpt/image_processing_imagegpt.py
@@ -0,0 +1,314 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Image processor class for ImageGPT."""
+
+from typing import Optional, Union
+
+import numpy as np
+
+from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
+from ...image_transforms import rescale, resize, to_channel_dimension_format
+from ...image_utils import (
+ ChannelDimension,
+ ImageInput,
+ PILImageResampling,
+ infer_channel_dimension_format,
+ is_scaled_image,
+ make_list_of_images,
+ to_numpy_array,
+ valid_images,
+ validate_preprocess_arguments,
+)
+from ...utils import TensorType, filter_out_non_signature_kwargs, is_vision_available, logging
+from ...utils.import_utils import requires
+
+
+if is_vision_available():
+ import PIL
+
+
+logger = logging.get_logger(__name__)
+
+
+def squared_euclidean_distance(a, b):
+ b = b.T
+ a2 = np.sum(np.square(a), axis=1)
+ b2 = np.sum(np.square(b), axis=0)
+ ab = np.matmul(a, b)
+ d = a2[:, None] - 2 * ab + b2[None, :]
+ return d
+
+
+def color_quantize(x, clusters):
+ x = x.reshape(-1, 3)
+ d = squared_euclidean_distance(x, clusters)
+ return np.argmin(d, axis=1)
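+
+
+ # Small sketch (editorial illustration, not part of the upstream file) of how the two helpers above map
+ # RGB pixels onto cluster indices:
+ #
+ #   pixels   = np.array([[0.9, 0.9, 0.9], [-1.0, -1.0, -1.0]])   # 2 pixels already scaled to [-1, 1]
+ #   clusters = np.array([[1.0, 1.0, 1.0], [-1.0, -1.0, -1.0]])   # 2 color clusters
+ #   color_quantize(pixels, clusters)                             # -> array([0, 1])
+ #
+ # each pixel is assigned the index of its nearest cluster under squared Euclidean distance.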
+
+
+@requires(backends=("vision",))
+class ImageGPTImageProcessor(BaseImageProcessor):
+ r"""
+ Constructs an ImageGPT image processor. This image processor can be used to resize images to a smaller resolution
+ (such as 32x32 or 64x64), normalize them and finally color quantize them to obtain sequences of "pixel values"
+ (color clusters).
+
+ Args:
+ clusters (`np.ndarray` or `list[list[int]]`, *optional*):
+ The color clusters to use, of shape `(n_clusters, 3)` when color quantizing. Can be overridden by `clusters`
+ in `preprocess`.
+ do_resize (`bool`, *optional*, defaults to `True`):
+ Whether to resize the image's dimensions to `(size["height"], size["width"])`. Can be overridden by
+ `do_resize` in `preprocess`.
+ size (`dict[str, int]` *optional*, defaults to `{"height": 256, "width": 256}`):
+ Size of the image after resizing. Can be overridden by `size` in `preprocess`.
+ resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
+ Resampling filter to use if resizing the image. Can be overridden by `resample` in `preprocess`.
+ do_normalize (`bool`, *optional*, defaults to `True`):
+ Whether to normalize the image pixel value to between [-1, 1]. Can be overridden by `do_normalize` in
+ `preprocess`.
+ do_color_quantize (`bool`, *optional*, defaults to `True`):
+ Whether to color quantize the image. Can be overridden by `do_color_quantize` in `preprocess`.
+ """
+
+ model_input_names = ["pixel_values"]
+
+ def __init__(
+ self,
+ # clusters is a first argument to maintain backwards compatibility with the old ImageGPTImageProcessor
+ clusters: Optional[Union[list[list[int]], np.ndarray]] = None,
+ do_resize: bool = True,
+ size: Optional[dict[str, int]] = None,
+ resample: PILImageResampling = PILImageResampling.BILINEAR,
+ do_normalize: bool = True,
+ do_color_quantize: bool = True,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ size = size if size is not None else {"height": 256, "width": 256}
+ size = get_size_dict(size)
+ self.clusters = np.array(clusters) if clusters is not None else None
+ self.do_resize = do_resize
+ self.size = size
+ self.resample = resample
+ self.do_normalize = do_normalize
+ self.do_color_quantize = do_color_quantize
+
+ # Copied from transformers.models.vit.image_processing_vit.ViTImageProcessor.resize
+ def resize(
+ self,
+ image: np.ndarray,
+ size: dict[str, int],
+ resample: PILImageResampling = PILImageResampling.BILINEAR,
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs,
+ ) -> np.ndarray:
+ """
+ Resize an image to `(size["height"], size["width"])`.
+
+ Args:
+ image (`np.ndarray`):
+ Image to resize.
+ size (`dict[str, int]`):
+ Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
+ `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BILINEAR`.
+ data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the output image. If unset, the channel dimension format of the input
+ image is used. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
+ from the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+
+ Returns:
+ `np.ndarray`: The resized image.
+ """
+ size = get_size_dict(size)
+ if "height" not in size or "width" not in size:
+ raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}")
+ output_size = (size["height"], size["width"])
+ return resize(
+ image,
+ size=output_size,
+ resample=resample,
+ data_format=data_format,
+ input_data_format=input_data_format,
+ **kwargs,
+ )
+
+ def normalize(
+ self,
+ image: np.ndarray,
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ ) -> np.ndarray:
+ """
+ Normalizes an image's pixel values to between [-1, 1].
+
+ Args:
+ image (`np.ndarray`):
+ Image to normalize.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format of the input image. If not provided, it will be inferred.
+ """
+ image = rescale(image=image, scale=1 / 127.5, data_format=data_format, input_data_format=input_data_format)
+ image = image - 1
+ return image
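+
+ # Worked example of the scaling above (editorial note): a uint8 pixel value v is mapped to v / 127.5 - 1,
+ # so 0 -> -1.0, 127.5 -> 0.0 and 255 -> 1.0, i.e. pixel values end up in [-1, 1] as the docstring states.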
+
+ @filter_out_non_signature_kwargs()
+ def preprocess(
+ self,
+ images: ImageInput,
+ do_resize: Optional[bool] = None,
+ size: Optional[dict[str, int]] = None,
+ resample: Optional[PILImageResampling] = None,
+ do_normalize: Optional[bool] = None,
+ do_color_quantize: Optional[bool] = None,
+ clusters: Optional[Union[list[list[int]], np.ndarray]] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ data_format: Optional[Union[str, ChannelDimension]] = ChannelDimension.FIRST,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ ) -> BatchFeature:
+ """
+ Preprocess an image or batch of images.
+
+ Args:
+ images (`ImageInput`):
+ Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
+ passing in images with pixel values between 0 and 1, set `do_normalize=False`.
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+ Whether to resize the image.
+ size (`dict[str, int]`, *optional*, defaults to `self.size`):
+ Size of the image after resizing.
+ resample (`int`, *optional*, defaults to `self.resample`):
+ Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`, Only
+ has an effect if `do_resize` is set to `True`.
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
+ Whether to normalize the image
+ do_color_quantize (`bool`, *optional*, defaults to `self.do_color_quantize`):
+ Whether to color quantize the image.
+ clusters (`np.ndarray` or `list[list[int]]`, *optional*, defaults to `self.clusters`):
+ Clusters used to quantize the image of shape `(n_clusters, 3)`. Only has an effect if
+ `do_color_quantize` is set to `True`.
+ return_tensors (`str` or `TensorType`, *optional*):
+ The type of tensors to return. Can be one of:
+ - Unset: Return a list of `np.ndarray`.
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
+ The channel dimension format for the output image. Can be one of:
+ - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ Only has an effect if `do_color_quantize` is set to `False`.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
+ from the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+ """
+ do_resize = do_resize if do_resize is not None else self.do_resize
+ size = size if size is not None else self.size
+ size = get_size_dict(size)
+ resample = resample if resample is not None else self.resample
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
+ do_color_quantize = do_color_quantize if do_color_quantize is not None else self.do_color_quantize
+ clusters = clusters if clusters is not None else self.clusters
+ clusters = np.array(clusters) if clusters is not None else None
+
+ images = make_list_of_images(images)
+
+ if not valid_images(images):
+ raise ValueError(
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
+ "torch.Tensor, tf.Tensor or jax.ndarray."
+ )
+
+ # Here, normalize() divides pixel values by a constant factor,
+ # hence the method does not need image_mean and image_std.
+ validate_preprocess_arguments(
+ do_resize=do_resize,
+ size=size,
+ resample=resample,
+ )
+
+ if do_color_quantize and clusters is None:
+ raise ValueError("Clusters must be specified if do_color_quantize is True.")
+
+ # All transformations expect numpy arrays.
+ images = [to_numpy_array(image) for image in images]
+
+ if do_normalize and is_scaled_image(images[0]):
+ logger.warning_once(
+ "It looks like you are trying to rescale already rescaled images. If you wish to do this, "
+ "make sure to set `do_normalize` to `False` and that pixel values are between [-1, 1].",
+ )
+
+ if input_data_format is None:
+ # We assume that all images have the same channel dimension format.
+ input_data_format = infer_channel_dimension_format(images[0])
+
+ if do_resize:
+ images = [
+ self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
+ for image in images
+ ]
+
+ if do_normalize:
+ images = [self.normalize(image=image, input_data_format=input_data_format) for image in images]
+
+ if do_color_quantize:
+ images = [to_channel_dimension_format(image, ChannelDimension.LAST, input_data_format) for image in images]
+ # color quantize from (batch_size, height, width, 3) to (batch_size, height, width)
+ images = np.array(images)
+ images = color_quantize(images, clusters).reshape(images.shape[:-1])
+
+ # flatten to (batch_size, height*width)
+ batch_size = images.shape[0]
+ images = images.reshape(batch_size, -1)
+
+ # We need to convert back to a list of images to keep consistent behaviour across processors.
+ images = list(images)
+ data = {"input_ids": images}
+ else:
+ images = [to_channel_dimension_format(image, data_format, input_data_format) for image in images]
+ data = {"pixel_values": images}
+ return BatchFeature(data=data, tensor_type=return_tensors)
+
+ def to_dict(self):
+ output = super().to_dict()
+ # Ensure clusters are JSON/equality friendly
+ if output.get("clusters") is not None and isinstance(output["clusters"], np.ndarray):
+ output["clusters"] = output["clusters"].tolist()
+ # Set these keys to None so the slow processor's serialized output matches the fast processor in save/load tests
+ missing_keys = ["image_mean", "image_std", "rescale_factor", "do_rescale"]
+ for key in missing_keys:
+ if key in output:
+ output[key] = None
+
+ return output
+
+
+__all__ = ["ImageGPTImageProcessor"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/imagegpt/image_processing_imagegpt_fast.py b/venv/lib/python3.13/site-packages/transformers/models/imagegpt/image_processing_imagegpt_fast.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a6bcc53ae1acaefa1f4a923f7d5b422c7a62d8e
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/imagegpt/image_processing_imagegpt_fast.py
@@ -0,0 +1,198 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Fast Image processor class for ImageGPT."""
+
+from typing import Optional, Union
+
+import numpy as np
+import torch
+from torchvision.transforms.v2 import functional as F
+
+from ...image_processing_utils import BatchFeature
+from ...image_processing_utils_fast import (
+ BaseImageProcessorFast,
+ DefaultFastImageProcessorKwargs,
+)
+from ...image_transforms import group_images_by_shape, reorder_images
+from ...image_utils import PILImageResampling
+from ...processing_utils import Unpack
+from ...utils import (
+ TensorType,
+ auto_docstring,
+)
+
+
+def squared_euclidean_distance_torch(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
+ """
+ Compute squared Euclidean distances between all pixels and clusters.
+
+ Args:
+ a: (N, 3) tensor of pixel RGB values
+ b: (M, 3) tensor of cluster RGB values
+
+ Returns:
+ (N, M) tensor of squared distances
+ """
+ b = b.t() # (3, M)
+ a2 = torch.sum(a**2, dim=1) # (N,)
+ b2 = torch.sum(b**2, dim=0) # (M,)
+ ab = torch.matmul(a, b) # (N, M)
+ d = a2[:, None] - 2 * ab + b2[None, :] # Squared Euclidean Distance: a^2 - 2ab + b^2
+ return d # (N, M) tensor of squared distances
+
+
+def color_quantize_torch(x: torch.Tensor, clusters: torch.Tensor) -> torch.Tensor:
+ """
+ Assign each pixel to its nearest color cluster.
+
+ Args:
+ x: (H*W, 3) tensor of flattened pixel RGB values
+ clusters: (n_clusters, 3) tensor of cluster RGB values
+
+ Returns:
+ (H*W,) tensor of cluster indices
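+
+ Illustrative example (hypothetical values): for pixels `[[0, 0, 0], [250, 10, 5]]` and clusters
+ `[[0, 0, 0], [255, 0, 0]]`, the returned indices are `[0, 1]`.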
+ """
+ d = squared_euclidean_distance_torch(x, clusters)
+ return torch.argmin(d, dim=1)
+
+
+class ImageGPTFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
+ """
+ clusters (`np.ndarray` or `list[list[int]]` or `torch.Tensor`, *optional*):
+ The color clusters to use, of shape `(n_clusters, 3)` when color quantizing. Can be overridden by `clusters`
+ in `preprocess`.
+ do_color_quantize (`bool`, *optional*, defaults to `True`):
+ Controls whether to apply color quantization to convert continuous pixel values to discrete cluster indices.
+ When True, each pixel is assigned to its nearest color cluster, enabling ImageGPT's discrete token modeling.
+ """
+
+ clusters: Optional[Union[np.ndarray, list[list[int]], torch.Tensor]]
+ do_color_quantize: Optional[bool]
+
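+ # Minimal usage sketch for the class below (illustrative; assumes the `openai/imagegpt-small` checkpoint):
+ #   processor = ImageGPTImageProcessorFast.from_pretrained("openai/imagegpt-small")
+ #   encoding = processor(images=image, return_tensors="pt")  # BatchFeature with "input_ids" cluster indices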
+
+@auto_docstring
+class ImageGPTImageProcessorFast(BaseImageProcessorFast):
+ model_input_names = ["input_ids"]
+ resample = PILImageResampling.BILINEAR
+ do_color_quantize = True
+ clusters = None
+ image_mean = [0.5, 0.5, 0.5]
+ image_std = [0.5, 0.5, 0.5]
+ do_rescale = True
+ do_normalize = True
+ valid_kwargs = ImageGPTFastImageProcessorKwargs
+
+ def __init__(
+ self,
+ clusters: Optional[Union[list, np.ndarray, torch.Tensor]] = None, # keep as arg for backwards compatibility
+ **kwargs: Unpack[ImageGPTFastImageProcessorKwargs],
+ ):
+ r"""
+ clusters (`np.ndarray` or `list[list[int]]` or `torch.Tensor`, *optional*):
+ The color clusters to use, of shape `(n_clusters, 3)` when color quantizing. Can be overridden by `clusters`
+ in `preprocess`.
+ """
+ clusters = torch.as_tensor(clusters, dtype=torch.float32) if clusters is not None else None
+ super().__init__(clusters=clusters, **kwargs)
+
+ def _preprocess(
+ self,
+ images: list["torch.Tensor"],
+ do_resize: bool,
+ size: dict[str, int],
+ interpolation: Optional["F.InterpolationMode"],
+ do_center_crop: bool,
+ crop_size: dict[str, int],
+ do_rescale: bool,
+ rescale_factor: float,
+ do_normalize: bool,
+ image_mean: Optional[Union[float, list[float]]],
+ image_std: Optional[Union[float, list[float]]],
+ do_color_quantize: Optional[bool] = None,
+ clusters: Optional[Union[list, np.ndarray, torch.Tensor]] = None,
+ disable_grouping: Optional[bool] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ **kwargs,
+ ):
+ # Group images by size for batched resizing
+ grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
+ resized_images_grouped = {}
+ for shape, stacked_images in grouped_images.items():
+ if do_resize:
+ stacked_images = self.resize(image=stacked_images, size=size, interpolation=interpolation)
+ resized_images_grouped[shape] = stacked_images
+ resized_images = reorder_images(resized_images_grouped, grouped_images_index)
+
+ # Group images by size for further processing
+ # Needed in case do_resize is False, or resize returns images with different sizes
+ grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
+ processed_images_grouped = {}
+ for shape, stacked_images in grouped_images.items():
+ if do_center_crop:
+ stacked_images = self.center_crop(stacked_images, crop_size)
+ # Fused rescale and normalize
+ stacked_images = self.rescale_and_normalize(
+ stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
+ )
+ processed_images_grouped[shape] = stacked_images
+
+ pixel_values = reorder_images(processed_images_grouped, grouped_images_index)
+
+ # If color quantization is requested, perform it; otherwise return pixel values
+ if do_color_quantize:
+ # Prepare clusters
+ if clusters is None:
+ raise ValueError("Clusters must be provided for color quantization.")
+ # Convert to torch tensor if needed (clusters might be passed as list/numpy)
+ clusters_torch = (
+ torch.as_tensor(clusters, dtype=torch.float32) if not isinstance(clusters, torch.Tensor) else clusters
+ ).to(pixel_values[0].device, dtype=pixel_values[0].dtype)
+
+ # Group images by shape for batch processing
+ # (pixel_values may be a list of tensors with differing shapes at this point)
+ grouped_images, grouped_images_index = group_images_by_shape(
+ pixel_values, disable_grouping=disable_grouping
+ )
+ # Process each group
+ input_ids_grouped = {}
+
+ for shape, stacked_images in grouped_images.items():
+ input_ids = color_quantize_torch(
+ stacked_images.permute(0, 2, 3, 1).reshape(-1, 3), clusters_torch
+ ) # (B*H*W,) flattened cluster indices
+ input_ids_grouped[shape] = input_ids.reshape(stacked_images.shape[0], -1) # (B, H*W)
+
+ input_ids = reorder_images(input_ids_grouped, grouped_images_index)
+
+ return BatchFeature(
+ data={"input_ids": torch.stack(input_ids, dim=0) if return_tensors else input_ids},
+ tensor_type=return_tensors,
+ )
+
+ pixel_values = torch.stack(pixel_values, dim=0) if return_tensors else pixel_values
+ return BatchFeature(data={"pixel_values": pixel_values}, tensor_type=return_tensors)
+
+ def to_dict(self):
+ # Convert torch tensors to lists for JSON serialization
+ output = super().to_dict()
+ if output.get("clusters") is not None and isinstance(output["clusters"], torch.Tensor):
+ output["clusters"] = output["clusters"].tolist()
+
+ return output
+
+
+__all__ = ["ImageGPTImageProcessorFast"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/imagegpt/modeling_imagegpt.py b/venv/lib/python3.13/site-packages/transformers/models/imagegpt/modeling_imagegpt.py
new file mode 100644
index 0000000000000000000000000000000000000000..a962141e447946633a2598b3c0bfe43b46f10e6d
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/imagegpt/modeling_imagegpt.py
@@ -0,0 +1,1024 @@
+# coding=utf-8
+# Copyright 2021 The OpenAI Team Authors and HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch OpenAI ImageGPT model."""
+
+import math
+import os
+from typing import Any, Optional, Union
+
+import torch
+from torch import nn
+from torch.nn import CrossEntropyLoss
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
+from ...generation import GenerationMixin
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import (
+ BaseModelOutputWithPastAndCrossAttentions,
+ CausalLMOutputWithCrossAttentions,
+ SequenceClassifierOutputWithPast,
+)
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
+from ...utils import (
+ auto_docstring,
+ logging,
+ torch_float,
+)
+from .configuration_imagegpt import ImageGPTConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+def load_tf_weights_in_imagegpt(model, config, imagegpt_checkpoint_path):
+ """
+ Load TF checkpoints into a PyTorch model.
+ """
+ try:
+ import re
+
+ import tensorflow as tf
+ except ImportError:
+ logger.error(
+ "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
+ "https://www.tensorflow.org/install/ for installation instructions."
+ )
+ raise
+ tf_path = os.path.abspath(imagegpt_checkpoint_path)
+ logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
+ # Load weights from TF model
+ init_vars = tf.train.list_variables(tf_path)
+ names = []
+ arrays = []
+
+ for name, shape in init_vars:
+ logger.info(f"Loading TF weight {name} with shape {shape}")
+ array = tf.train.load_variable(tf_path, name)
+ names.append(name)
+ arrays.append(array.squeeze())
+
+ for name, array in zip(names, arrays):
+ name = name[6:] # skip "model/"
+ name = name.split("/")
+
+ # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v
+ # which are not required for using pretrained model
+ if any(
+ n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
+ for n in name
+ ) or name[-1] in ["_step"]:
+ logger.info("Skipping {}".format("/".join(name)))
+ continue
+
+ pointer = model
+ if name[-1] not in ["wtet"]:
+ pointer = getattr(pointer, "transformer")
+
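+ # Walk the remaining TF scope components and resolve them to the matching PyTorch submodule or parameter.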
+ for m_name in name:
+ if re.fullmatch(r"[A-Za-z]+\d+", m_name):
+ scope_names = re.split(r"(\d+)", m_name)
+ else:
+ scope_names = [m_name]
+
+ if scope_names[0] == "w" or scope_names[0] == "g":
+ pointer = getattr(pointer, "weight")
+ elif scope_names[0] == "b":
+ pointer = getattr(pointer, "bias")
+ elif scope_names[0] == "wpe" or scope_names[0] == "wte":
+ pointer = getattr(pointer, scope_names[0])
+ pointer = getattr(pointer, "weight")
+ elif scope_names[0] in ["q_proj", "k_proj", "v_proj"]:
+ pointer = getattr(pointer, "c_attn")
+ pointer = getattr(pointer, "weight")
+ elif len(name) == 3 and name[1] == "attn" and scope_names[0] == "c_proj":
+ pointer = getattr(pointer, scope_names[0])
+ pointer = getattr(pointer, "weight")
+ elif scope_names[0] == "wtet":
+ pointer = getattr(pointer, "lm_head")
+ pointer = getattr(pointer, "weight")
+ elif scope_names[0] == "sos":
+ pointer = getattr(pointer, "wte")
+ pointer = getattr(pointer, "weight")
+ else:
+ pointer = getattr(pointer, scope_names[0])
+ if len(scope_names) >= 2:
+ num = int(scope_names[1])
+ pointer = pointer[num]
+
+ if len(name) > 1 and name[1] == "attn" or name[-1] == "wtet" or name[-1] == "sos" or name[-1] == "wte":
+ pass # array is used to initialize only part of the pointer so sizes won't match
+ else:
+ try:
+ assert pointer.shape == array.shape
+ except AssertionError as e:
+ e.args += (pointer.shape, array.shape)
+ raise
+
+ logger.info(f"Initialize PyTorch weight {name}")
+
+ if name[-1] == "q_proj":
+ pointer.data[:, : config.n_embd] = torch.from_numpy(array.reshape(config.n_embd, config.n_embd)).T
+ elif name[-1] == "k_proj":
+ pointer.data[:, config.n_embd : 2 * config.n_embd] = torch.from_numpy(
+ array.reshape(config.n_embd, config.n_embd)
+ ).T
+ elif name[-1] == "v_proj":
+ pointer.data[:, 2 * config.n_embd :] = torch.from_numpy(array.reshape(config.n_embd, config.n_embd)).T
+ elif len(name) == 3 and name[1] == "attn" and name[2] == "c_proj":
+ pointer.data = torch.from_numpy(array.reshape(config.n_embd, config.n_embd))
+ elif name[-1] == "wtet":
+ pointer.data = torch.from_numpy(array)
+ elif name[-1] == "wte":
+ pointer.data[: config.vocab_size - 1, :] = torch.from_numpy(array)
+ elif name[-1] == "sos":
+ pointer.data[-1] = torch.from_numpy(array)
+ else:
+ pointer.data = torch.from_numpy(array)
+
+ return model
+
+
+class ImageGPTLayerNorm(nn.Module):
+ def __init__(self, hidden_size: int, eps: float = 1e-5):
+ super().__init__()
+ self.eps = eps
+ self.weight = nn.Parameter(torch.Tensor(hidden_size))
+
+ def forward(self, tensor: torch.Tensor) -> torch.Tensor:
+ # input is not mean centered
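+ # (RMSNorm-style: divide by the root-mean-square of the features and apply a learned gain, with no bias term)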
+ tensor = tensor / torch.sqrt(torch.mean(torch.square(tensor), axis=-1, keepdim=True) + self.eps)
+ tensor = tensor * self.weight
+ return tensor
+
+
+class ImageGPTAttention(nn.Module):
+ def __init__(self, config, is_cross_attention: Optional[bool] = False, layer_idx: Optional[int] = None):
+ super().__init__()
+
+ max_positions = config.max_position_embeddings
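+ # Lower-triangular boolean mask of shape (1, 1, max_positions, max_positions) implementing causal attention.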
+ self.register_buffer(
+ "bias",
+ torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
+ 1, 1, max_positions, max_positions
+ ),
+ persistent=False,
+ )
+ self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False)
+
+ self.embed_dim = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.embed_dim // self.num_heads
+ self.split_size = self.embed_dim
+ if self.head_dim * self.num_heads != self.embed_dim:
+ raise ValueError(
+ f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
+ f" {self.num_heads})."
+ )
+
+ self.scale_attn_weights = config.scale_attn_weights
+ self.is_cross_attention = is_cross_attention
+
+ # Layer-wise attention scaling, reordering, and upcasting
+ self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx
+ self.layer_idx = layer_idx
+ self.reorder_and_upcast_attn = config.reorder_and_upcast_attn
+
+ if self.is_cross_attention:
+ self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim)
+ self.q_attn = Conv1D(self.embed_dim, self.embed_dim)
+ else:
+ self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
+ self.c_proj = Conv1D(self.embed_dim, self.embed_dim)
+
+ self.attn_dropout = nn.Dropout(config.attn_pdrop)
+ self.resid_dropout = nn.Dropout(config.resid_pdrop)
+
+ self.pruned_heads = set()
+
+ def prune_heads(self, heads):
+ if len(heads) == 0:
+ return
+ heads, index = find_pruneable_heads_and_indices(heads, self.num_heads, self.head_dim, self.pruned_heads)
+ index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)])
+
+ # Prune conv1d layers
+ self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
+ self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)
+
+ # Update hyper params
+ self.split_size = (self.split_size // self.num_heads) * (self.num_heads - len(heads))
+ self.num_heads = self.num_heads - len(heads)
+ self.pruned_heads = self.pruned_heads.union(heads)
+
+ def _attn(self, query, key, value, attention_mask=None, head_mask=None):
+ attn_weights = torch.matmul(query, key.transpose(-1, -2))
+
+ if self.scale_attn_weights:
+ attn_weights = attn_weights / torch_float(value.size(-1) ** 0.5)
+
+ # Layer-wise attention scaling
+ if self.scale_attn_by_inverse_layer_idx:
+ attn_weights = attn_weights / float(self.layer_idx + 1)
+
+ if not self.is_cross_attention:
+ # if only "normal" attention layer implements causal mask
+ query_length, key_length = query.size(-2), key.size(-2)
+ causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
+ mask_value = torch.finfo(attn_weights.dtype).min
+ # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
+ # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
+ mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype, device=attn_weights.device)
+ attn_weights = torch.where(causal_mask, attn_weights, mask_value)
+
+ if attention_mask is not None:
+ # Apply the attention mask
+ attn_weights = attn_weights + attention_mask
+
+ attn_weights = nn.Softmax(dim=-1)(attn_weights)
+
+ # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise
+ attn_weights = attn_weights.type(value.dtype)
+ attn_weights = self.attn_dropout(attn_weights)
+
+ # Mask heads if we want to
+ if head_mask is not None:
+ attn_weights = attn_weights * head_mask
+
+ attn_output = torch.matmul(attn_weights, value)
+
+ return attn_output, attn_weights
+
+ def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None):
+ # Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM)
+ bsz, num_heads, q_seq_len, dk = query.size()
+ _, _, k_seq_len, _ = key.size()
+
+ # Preallocate attn_weights for `baddbmm`
+ attn_weights = torch.empty(bsz * num_heads, q_seq_len, k_seq_len, dtype=torch.float32, device=query.device)
+
+ # Compute Scale Factor
+ scale_factor = 1.0
+ if self.scale_attn_weights:
+ scale_factor /= float(value.size(-1)) ** 0.5
+
+ if self.scale_attn_by_inverse_layer_idx:
+ scale_factor /= float(self.layer_idx + 1)
+
+ # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
+ with torch.autocast(query.device.type, enabled=False):
+ q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
+ attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
+ attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
+
+ if not self.is_cross_attention:
+ # if only "normal" attention layer implements causal mask
+ query_length, key_length = query.size(-2), key.size(-2)
+ causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
+ mask_value = torch.finfo(attn_weights.dtype).min
+ # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
+ # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
+ mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype, device=attn_weights.device)
+ attn_weights = torch.where(causal_mask, attn_weights, mask_value)
+
+ if attention_mask is not None:
+ # Apply the attention mask
+ attn_weights = attn_weights + attention_mask
+
+ attn_weights = nn.Softmax(dim=-1)(attn_weights)
+
+ # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise
+ if attn_weights.dtype != torch.float32:
+ raise RuntimeError("Error with upcasting, attn_weights does not have dtype torch.float32")
+ attn_weights = attn_weights.type(value.dtype)
+ attn_weights = self.attn_dropout(attn_weights)
+
+ # Mask heads if we want to
+ if head_mask is not None:
+ attn_weights = attn_weights * head_mask
+
+ attn_output = torch.matmul(attn_weights, value)
+
+ return attn_output, attn_weights
+
+ def _split_heads(self, tensor, num_heads, attn_head_size):
+ """
+ Splits hidden_size dim into attn_head_size and num_heads
+ """
+ new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
+ tensor = tensor.view(*new_shape)
+ return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)
+
+ def _merge_heads(self, tensor, num_heads, attn_head_size):
+ """
+ Merges attn_head_size dim and num_attn_heads dim into hidden_size
+ """
+ tensor = tensor.permute(0, 2, 1, 3).contiguous()
+ new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
+ return tensor.view(new_shape)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ layer_past: Optional[Cache] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ use_cache: Optional[bool] = False,
+ output_attentions: Optional[bool] = False,
+ cache_position: Optional[torch.Tensor] = None,
+ ) -> tuple:
+ is_cross_attention = encoder_hidden_states is not None
+ bsz, seq_len, _ = hidden_states.shape
+
+ if layer_past is not None:
+ if isinstance(layer_past, EncoderDecoderCache):
+ is_updated = layer_past.is_updated.get(self.layer_idx)
+ if is_cross_attention:
+ # after the first generated id, we can subsequently re-use all key/value_states from cache
+ curr_past_key_value = layer_past.cross_attention_cache
+ else:
+ curr_past_key_value = layer_past.self_attention_cache
+ else:
+ curr_past_key_value = layer_past
+
+ current_states = encoder_hidden_states if is_cross_attention else hidden_states
+ if is_cross_attention:
+ if not hasattr(self, "q_attn"):
+ raise ValueError(
+ "If class is used as cross attention, the weights `q_attn` have to be defined. "
+ "Please make sure to instantiate class with `ImageGPTAttention(..., is_cross_attention=True)`."
+ )
+
+ if layer_past is not None and is_updated:
+ # reuse k,v, cross_attentions, and compute only q
+ query = self.q_attn(hidden_states)
+ key = curr_past_key_value.layers[self.layer_idx].keys
+ value = curr_past_key_value.layers[self.layer_idx].values
+ else:
+ query = self.q_attn(hidden_states)
+ key, value = self.c_attn(current_states).split(self.split_size, dim=2)
+ key = key.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
+ value = value.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
+ else:
+ query, key, value = self.c_attn(current_states).split(self.split_size, dim=2)
+ key = key.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
+ value = value.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
+
+ if layer_past is not None:
+ # save all key/value_states to cache to be re-used for fast auto-regressive generation
+ cache_position = cache_position if not is_cross_attention else None
+ key, value = curr_past_key_value.update(key, value, self.layer_idx, {"cache_position": cache_position})
+ # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
+ if is_cross_attention:
+ layer_past.is_updated[self.layer_idx] = True
+
+ query = query.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+ if self.reorder_and_upcast_attn:
+ attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask)
+ else:
+ attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
+
+ attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
+ attn_output = self.c_proj(attn_output)
+ attn_output = self.resid_dropout(attn_output)
+
+ return attn_output, attn_weights
+
+
+class ImageGPTMLP(nn.Module):
+ def __init__(self, intermediate_size, config):
+ super().__init__()
+ embed_dim = config.hidden_size
+ self.c_fc = Conv1D(intermediate_size, embed_dim)
+ self.c_proj = Conv1D(embed_dim, intermediate_size)
+ self.act = ACT2FN[config.activation_function]
+ self.dropout = nn.Dropout(config.resid_pdrop)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.c_fc(hidden_states)
+ hidden_states = self.act(hidden_states)
+ hidden_states = self.c_proj(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ return hidden_states
+
+
+class ImageGPTBlock(GradientCheckpointingLayer):
+ def __init__(self, config, layer_idx=None):
+ super().__init__()
+ hidden_size = config.hidden_size
+ inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
+
+ self.ln_1 = ImageGPTLayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+ self.attn = ImageGPTAttention(config, layer_idx=layer_idx)
+ self.ln_2 = ImageGPTLayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+
+ if config.add_cross_attention:
+ self.crossattention = ImageGPTAttention(config, is_cross_attention=True, layer_idx=layer_idx)
+ self.ln_cross_attn = ImageGPTLayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+
+ self.mlp = ImageGPTMLP(inner_dim, config)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ layer_past: Optional[Cache] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ use_cache: Optional[bool] = False,
+ output_attentions: Optional[bool] = False,
+ cache_position: Optional[torch.Tensor] = None,
+ ) -> tuple:
+ residual = hidden_states
+ hidden_states = self.ln_1(hidden_states)
+ attn_outputs = self.attn(
+ hidden_states,
+ layer_past=layer_past,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ cache_position=cache_position,
+ )
+ attn_output = attn_outputs[0]
+ outputs = attn_outputs[1:]
+ # residual connection
+ hidden_states = attn_output + residual
+
+ if encoder_hidden_states is not None:
+ # add one self-attention block for cross-attention
+ if not hasattr(self, "crossattention"):
+ raise ValueError(
+ f"If `encoder_hidden_states` are passed, {self} has to be instantiated with "
+ "cross-attention layers by setting `config.add_cross_attention=True`"
+ )
+ residual = hidden_states
+ hidden_states = self.ln_cross_attn(hidden_states)
+ cross_attn_outputs = self.crossattention(
+ hidden_states,
+ layer_past=layer_past,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ output_attentions=output_attentions,
+ cache_position=cache_position,
+ )
+ attn_output = cross_attn_outputs[0]
+ # residual connection
+ hidden_states = residual + attn_output
+ outputs = outputs + cross_attn_outputs[1:] # add cross attentions if we output attention weights
+
+ residual = hidden_states
+ hidden_states = self.ln_2(hidden_states)
+ feed_forward_hidden_states = self.mlp(hidden_states)
+ # residual connection
+ hidden_states = residual + feed_forward_hidden_states
+
+ return (hidden_states,) + outputs
+
+
+@auto_docstring
+class ImageGPTPreTrainedModel(PreTrainedModel):
+ config: ImageGPTConfig
+ load_tf_weights = load_tf_weights_in_imagegpt
+ base_model_prefix = "transformer"
+ main_input_name = "input_ids"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["ImageGPTBlock"]
+
+ def __init__(self, *inputs, **kwargs):
+ super().__init__(*inputs, **kwargs)
+
+ def _init_weights(self, module):
+ """Initialize the weights."""
+ if isinstance(module, (nn.Linear, Conv1D)):
+ # Slightly different from the TF version which uses truncated_normal for initialization
+ # cf https://github.com/pytorch/pytorch/pull/5617
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+ elif isinstance(module, ImageGPTLayerNorm):
+ module.weight.data.fill_(1.0)
+
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
+ #
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
+ for name, p in module.named_parameters():
+ if "c_proj" in name and "weight" in name:
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
+ p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer)))
+
+
+@auto_docstring
+class ImageGPTModel(ImageGPTPreTrainedModel):
+ def __init__(self, config: ImageGPTConfig):
+ super().__init__(config)
+
+ self.embed_dim = config.hidden_size
+
+ self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
+ self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
+
+ self.drop = nn.Dropout(config.embd_pdrop)
+ self.h = nn.ModuleList([ImageGPTBlock(config, layer_idx=i) for i in range(config.num_hidden_layers)])
+ self.ln_f = ImageGPTLayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
+
+ # Model parallel
+ self.model_parallel = False
+ self.device_map = None
+ self.gradient_checkpointing = False
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.wte
+
+ def set_input_embeddings(self, new_embeddings):
+ self.wte = new_embeddings
+
+ def _prune_heads(self, heads_to_prune):
+ """
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
+ """
+ for layer, heads in heads_to_prune.items():
+ self.h[layer].attn.prune_heads(heads)
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ past_key_values: Optional[Cache] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.Tensor] = None,
+ **kwargs: Any,
+ ) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]:
+ r"""
+ input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
+ `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
+ `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input
+ sequence tokens in the vocabulary.
+
+ If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
+ `input_ids`.
+
+ Indices can be obtained using [`AutoImageProcessor`]. See [`ImageGPTImageProcessor.__call__`] for details.
+
+ Examples:
+
+ ```python
+ >>> from transformers import AutoImageProcessor, ImageGPTModel
+ >>> from PIL import Image
+ >>> import requests
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> image_processor = AutoImageProcessor.from_pretrained("openai/imagegpt-small")
+ >>> model = ImageGPTModel.from_pretrained("openai/imagegpt-small")
+
+ >>> inputs = image_processor(images=image, return_tensors="pt")
+ >>> outputs = model(**inputs)
+ >>> last_hidden_states = outputs.last_hidden_state
+ ```"""
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
+ input_shape = input_ids.size()
+ input_ids = input_ids.view(-1, input_shape[-1])
+ batch_size = input_ids.shape[0]
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ batch_size = inputs_embeds.shape[0]
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ if use_cache and past_key_values is None:
+ past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
+ if use_cache and isinstance(past_key_values, tuple):
+ logger.warning_once(
+ "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
+ "You should pass an instance of `EncoderDecoderCache` instead, e.g. "
+ "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`."
+ )
+ past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)
+
+ past_length = past_key_values.get_seq_length() if past_key_values is not None else 0
+
+ if token_type_ids is not None:
+ token_type_ids = token_type_ids.view(-1, input_shape[-1])
+
+ if position_ids is None:
+ position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
+ position_ids = position_ids.unsqueeze(0)
+
+ # ImageGPTAttention mask.
+ if attention_mask is not None:
+ if batch_size <= 0:
+ raise ValueError("batch_size has to be defined and > 0")
+ attention_mask = attention_mask.view(batch_size, -1)
+ # We create a 3D attention mask from a 2D tensor mask.
+ # Sizes are [batch_size, 1, 1, to_seq_length]
+ # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
+ # this attention mask is simpler than the triangular masking of causal attention
+ # used in OpenAI GPT; we just need to prepare the broadcast dimension here.
+ attention_mask = attention_mask[:, None, None, :]
+
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
+ # masked positions, this operation will create a tensor which is 0.0 for
+ # positions we want to attend and the dtype's smallest value for masked positions.
+ # Since we are adding it to the raw scores before the softmax, this is
+ # effectively the same as removing these entirely.
+ attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
+ attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
+
+ # If a 2D or 3D attention mask is provided for the cross-attention
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
+ if self.config.add_cross_attention and encoder_hidden_states is not None:
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
+ if encoder_attention_mask is None:
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
+ encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+ else:
+ encoder_attention_mask = None
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicates we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # head_mask has shape n_layer x batch x n_heads x N x N
+ head_mask = self.get_head_mask(head_mask, self.config.n_layer)
+
+ if inputs_embeds is None:
+ inputs_embeds = self.wte(input_ids)
+ position_embeds = self.wpe(position_ids)
+ hidden_states = inputs_embeds + position_embeds.to(inputs_embeds.device)
+
+ if token_type_ids is not None:
+ token_type_embeds = self.wte(token_type_ids)
+ hidden_states = hidden_states + token_type_embeds
+
+ hidden_states = self.drop(hidden_states)
+ output_shape = input_shape + (hidden_states.size(-1),)
+
+ all_self_attentions = () if output_attentions else None
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
+ all_hidden_states = () if output_hidden_states else None
+ for i, block in enumerate(self.h):
+ # Model parallel
+ if self.model_parallel:
+ torch.cuda.set_device(hidden_states.device)
+ # Ensure that attention_mask is always on the same device as hidden_states
+ if attention_mask is not None:
+ attention_mask = attention_mask.to(hidden_states.device)
+ if isinstance(head_mask, torch.Tensor):
+ head_mask = head_mask.to(hidden_states.device)
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ outputs = block(
+ hidden_states,
+ past_key_values,
+ attention_mask,
+ head_mask[i],
+ encoder_hidden_states, # as a positional argument for gradient checkpointing
+ encoder_attention_mask=encoder_attention_mask,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ cache_position=cache_position,
+ )
+
+ hidden_states = outputs[0]
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (outputs[1],)
+ if self.config.add_cross_attention:
+ all_cross_attentions = all_cross_attentions + (outputs[2],)
+
+ # Model Parallel: If it's the last layer for that device, put things on the next device
+ if self.model_parallel:
+ for k, v in self.device_map.items():
+ if i == v[-1] and "cuda:" + str(k) != self.last_device:
+ hidden_states = hidden_states.to("cuda:" + str(k + 1))
+
+ hidden_states = self.ln_f(hidden_states)
+ hidden_states = hidden_states.view(*output_shape)
+
+ # Add last hidden state
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(
+ v
+ for v in [hidden_states, past_key_values, all_hidden_states, all_self_attentions, all_cross_attentions]
+ if v is not None
+ )
+
+ return BaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ cross_attentions=all_cross_attentions,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ The ImageGPT Model transformer with a language modeling head on top (linear layer with weights tied to the input
+ embeddings).
+ """
+)
+class ImageGPTForCausalImageModeling(ImageGPTPreTrainedModel, GenerationMixin):
+ _tied_weights_keys = ["lm_head.weight"]
+
+ def __init__(self, config: ImageGPTConfig):
+ super().__init__(config)
+ self.transformer = ImageGPTModel(config)
+ self.lm_head = nn.Linear(config.n_embd, config.vocab_size - 1, bias=False)
+
+ # Model parallel
+ self.model_parallel = False
+ self.device_map = None
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ past_key_values: Optional[Cache] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.Tensor] = None,
+ **kwargs: Any,
+ ) -> Union[tuple, CausalLMOutputWithCrossAttentions]:
+ r"""
+ input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
+ `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
+ `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input
+ sequence tokens in the vocabulary.
+
+ If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
+ `input_ids`.
+
+ Indices can be obtained using [`AutoImageProcessor`]. See [`ImageGPTImageProcessor.__call__`] for details.
+ labels (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*):
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
+ `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100`
+ are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`.
+
+ Examples:
+
+ ```python
+ >>> from transformers import AutoImageProcessor, ImageGPTForCausalImageModeling
+ >>> import torch
+ >>> import matplotlib.pyplot as plt
+ >>> import numpy as np
+
+ >>> image_processor = AutoImageProcessor.from_pretrained("openai/imagegpt-small")
+ >>> model = ImageGPTForCausalImageModeling.from_pretrained("openai/imagegpt-small")
+ >>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ >>> model.to(device) # doctest: +IGNORE_RESULT
+
+ >>> # unconditional generation of 4 images
+ >>> batch_size = 4
+ >>> context = torch.full((batch_size, 1), model.config.vocab_size - 1) # initialize with SOS token
+ >>> context = context.to(device)
+ >>> output = model.generate(
+ ... input_ids=context, max_length=model.config.n_positions + 1, temperature=1.0, do_sample=True, top_k=40
+ ... )
+
+ >>> clusters = image_processor.clusters
+ >>> height = image_processor.size["height"]
+ >>> width = image_processor.size["width"]
+
+ >>> samples = output[:, 1:].detach().cpu().numpy()
+ >>> samples_img = [
+ ... np.reshape(np.rint(127.5 * (clusters[s] + 1.0)), [height, width, 3]).astype(np.uint8) for s in samples
+ ... ] # convert color cluster tokens back to pixels
+ >>> f, axes = plt.subplots(1, batch_size, dpi=300)
+
+ >>> for img, ax in zip(samples_img, axes): # doctest: +IGNORE_RESULT
+ ... ax.axis("off")
+ ... ax.imshow(img)
+ ```"""
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ transformer_outputs = self.transformer(
+ input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ cache_position=cache_position,
+ )
+ hidden_states = transformer_outputs[0]
+
+ lm_logits = self.lm_head(hidden_states)
+
+ loss = None
+ if labels is not None:
+ # Shift so that tokens < n predict n
+ shift_logits = lm_logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ # Flatten the tokens
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
+
+ if not return_dict:
+ output = (lm_logits,) + transformer_outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return CausalLMOutputWithCrossAttentions(
+ loss=loss,
+ logits=lm_logits,
+ past_key_values=transformer_outputs.past_key_values,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
+ cross_attentions=transformer_outputs.cross_attentions,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ The ImageGPT Model transformer with an image classification head on top (linear layer).
+ [`ImageGPTForImageClassification`] average-pools the hidden states in order to do the classification.
+ """
+)
+class ImageGPTForImageClassification(ImageGPTPreTrainedModel):
+ def __init__(self, config: ImageGPTConfig):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.transformer = ImageGPTModel(config)
+ self.score = nn.Linear(config.n_embd, self.num_labels, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ past_key_values: Optional[Cache] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ **kwargs: Any,
+ ) -> Union[tuple, SequenceClassifierOutputWithPast]:
+ r"""
+ input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
+ `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
+ `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input
+ sequence tokens in the vocabulary.
+
+ If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
+ `input_ids`.
+
+ Indices can be obtained using [`AutoImageProcessor`]. See [`ImageGPTImageProcessor.__call__`] for details.
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+
+ Examples:
+
+ ```python
+ >>> from transformers import AutoImageProcessor, ImageGPTForImageClassification
+ >>> from PIL import Image
+ >>> import requests
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> image_processor = AutoImageProcessor.from_pretrained("openai/imagegpt-small")
+ >>> model = ImageGPTForImageClassification.from_pretrained("openai/imagegpt-small")
+
+ >>> inputs = image_processor(images=image, return_tensors="pt")
+ >>> outputs = model(**inputs)
+ >>> logits = outputs.logits
+ ```"""
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ transformer_outputs = self.transformer(
+ input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ hidden_states = transformer_outputs[0]
+ # average-pool the hidden states along the sequence dimension
+ pooled_hidden_states = hidden_states.mean(dim=1)
+ # project from (batch_size, hidden_size) to (batch_size, num_labels)
+ logits = self.score(pooled_hidden_states)
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(labels, logits, self.config)
+
+ if not return_dict:
+ output = (logits,) + transformer_outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return SequenceClassifierOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=transformer_outputs.past_key_values,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
+ )
+
+
+__all__ = [
+ "ImageGPTForCausalImageModeling",
+ "ImageGPTForImageClassification",
+ "ImageGPTModel",
+ "ImageGPTPreTrainedModel",
+ "load_tf_weights_in_imagegpt",
+]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/instructblipvideo/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/instructblipvideo/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2eb06450487cbea467b3c7be4be07ad524b47042
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/instructblipvideo/__init__.py
@@ -0,0 +1,30 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_instructblipvideo import *
+ from .image_processing_instructblipvideo import *
+ from .modeling_instructblipvideo import *
+ from .processing_instructblipvideo import *
+ from .video_processing_instructblipvideo import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/venv/lib/python3.13/site-packages/transformers/models/instructblipvideo/configuration_instructblipvideo.py b/venv/lib/python3.13/site-packages/transformers/models/instructblipvideo/configuration_instructblipvideo.py
new file mode 100644
index 0000000000000000000000000000000000000000..af2acc83387675e5bac3fcfa7c6ffe5c793838a0
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/instructblipvideo/configuration_instructblipvideo.py
@@ -0,0 +1,345 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/instructblipvideo/modular_instructblipvideo.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_instructblipvideo.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from ...configuration_utils import PretrainedConfig
+from ...models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
+from ...utils import logging
+from ..auto import CONFIG_MAPPING, AutoConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+class InstructBlipVideoVisionConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`InstructBlipVideoVisionModel`]. It is used to
+ instantiate an InstructBlipVideo vision encoder according to the specified arguments, defining the model architecture.
+ Instantiating a configuration with the defaults will yield a similar configuration to that of the InstructBlipVideo
+ [Salesforce/instruct-blip-flan-t5](https://huggingface.co/Salesforce/instruct-blip-flan-t5) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ hidden_size (`int`, *optional*, defaults to 1408):
+ Dimensionality of the encoder layers and the pooler layer.
+ intermediate_size (`int`, *optional*, defaults to 6144):
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+ num_hidden_layers (`int`, *optional*, defaults to 39):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 16):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ image_size (`int`, *optional*, defaults to 224):
+ The size (resolution) of each image.
+ patch_size (`int`, *optional*, defaults to 14):
+ The size (resolution) of each patch.
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"selu"` and `"gelu_new"` are supported.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-06):
+ The epsilon used by the layer normalization layers.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ initializer_range (`float`, *optional*, defaults to 1e-10):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ qkv_bias (`bool`, *optional*, defaults to `True`):
+ Whether to add a bias to the queries and values in the self-attention layers.
+
+ Example:
+
+ ```python
+ >>> from transformers import InstructBlipVideoVisionConfig, InstructBlipVideoVisionModel
+
+ >>> # Initializing an InstructBlipVideoVisionConfig with Salesforce/instruct-blip-flan-t5 style configuration
+ >>> configuration = InstructBlipVideoVisionConfig()
+
+ >>> # Initializing an InstructBlipVideoVisionModel (with random weights) from the Salesforce/instruct-blip-flan-t5 style configuration
+ >>> model = InstructBlipVideoVisionModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "instructblipvideo_vision_model"
+ base_config_key = "vision_config"
+
+ def __init__(
+ self,
+ hidden_size=1408,
+ intermediate_size=6144,
+ num_hidden_layers=39,
+ num_attention_heads=16,
+ image_size=224,
+ patch_size=14,
+ hidden_act="gelu",
+ layer_norm_eps=1e-6,
+ attention_dropout=0.0,
+ initializer_range=1e-10,
+ qkv_bias=True,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.patch_size = patch_size
+ self.image_size = image_size
+ self.initializer_range = initializer_range
+ self.attention_dropout = attention_dropout
+ self.layer_norm_eps = layer_norm_eps
+ self.hidden_act = hidden_act
+ self.qkv_bias = qkv_bias
+
+
+class InstructBlipVideoQFormerConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`InstructBlipVideoQFormerModel`]. It is used to
+ instantiate an InstructBlipVideo Querying Transformer (Q-Former) model according to the specified arguments, defining the
+ model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of
+ the InstructBlipVideo [Salesforce/instruct-blip-flan-t5](https://huggingface.co/Salesforce/instruct-blip-flan-t5)
+ architecture. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs.
+ Read the documentation from [`PretrainedConfig`] for more information.
+
+ Note that [`InstructBlipVideoQFormerModel`] is very similar to [`BertLMHeadModel`] with interleaved cross-attention.
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 30522):
+ Vocabulary size of the Q-Former model. Defines the number of different tokens that can be represented by
+ the `input_ids` passed when calling the model.
+ hidden_size (`int`, *optional*, defaults to 768):
+ Dimensionality of the encoder layers and the pooler layer.
+ num_hidden_layers (`int`, *optional*, defaults to 12):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 12):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ intermediate_size (`int`, *optional*, defaults to 3072):
+ Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
+ hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
+ The dropout ratio for the attention probabilities.
+ max_position_embeddings (`int`, *optional*, defaults to 512):
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
+ just in case (e.g., 512 or 1024 or 2048).
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12):
+ The epsilon used by the layer normalization layers.
+ pad_token_id (`int`, *optional*, defaults to 0):
+ Token id used for padding sequences.
+ position_embedding_type (`str`, *optional*, defaults to `"absolute"`):
+ Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For
+ positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to
+ [Self-Attention with Relative Position Representations (Shaw et al.)](https://huggingface.co/papers/1803.02155).
+ For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models
+ with Better Relative Position Embeddings (Huang et al.)](https://huggingface.co/papers/2009.13658).
+ cross_attention_frequency (`int`, *optional*, defaults to 2):
+ The frequency of adding cross-attention to the Transformer layers.
+ encoder_hidden_size (`int`, *optional*, defaults to 1408):
+ The hidden size of the hidden states for cross-attention.
+
+ Examples:
+
+ ```python
+ >>> from transformers import InstructBlipVideoQFormerConfig, InstructBlipVideoQFormerModel
+
+ >>> # Initializing an InstructBlipVideoQFormerConfig with Salesforce/instruct-blip-flan-t5 style configuration
+ >>> configuration = InstructBlipVideoQFormerConfig()
+
+ >>> # Initializing a model (with random weights) from the Salesforce/instruct-blip-flan-t5 style configuration
+ >>> model = InstructBlipVideoQFormerModel(configuration)
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "instructblipvideo_qformer"
+ base_config_key = "qformer_config"
+
+ def __init__(
+ self,
+ vocab_size=30522,
+ hidden_size=768,
+ num_hidden_layers=12,
+ num_attention_heads=12,
+ intermediate_size=3072,
+ hidden_act="gelu",
+ hidden_dropout_prob=0.1,
+ attention_probs_dropout_prob=0.1,
+ max_position_embeddings=512,
+ initializer_range=0.02,
+ layer_norm_eps=1e-12,
+ pad_token_id=0,
+ position_embedding_type="absolute",
+ cross_attention_frequency=2,
+ encoder_hidden_size=1408,
+ **kwargs,
+ ):
+ super().__init__(pad_token_id=pad_token_id, **kwargs)
+
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.hidden_act = hidden_act
+ self.intermediate_size = intermediate_size
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.max_position_embeddings = max_position_embeddings
+ self.initializer_range = initializer_range
+ self.layer_norm_eps = layer_norm_eps
+ self.position_embedding_type = position_embedding_type
+ self.cross_attention_frequency = cross_attention_frequency
+ self.encoder_hidden_size = encoder_hidden_size
+
+
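+# A minimal illustration of how `cross_attention_frequency` is consumed by the modeling code
+# (see `InstructBlipVideoQFormerLayer.__init__`): a Q-Former layer gets a cross-attention block
+# over the vision features whenever `layer_idx % cross_attention_frequency == 0`. With the
+# defaults above (12 layers, frequency 2), that is layers 0, 2, 4, 6, 8 and 10:
+#
+#     config = InstructBlipVideoQFormerConfig()
+#     [i for i in range(config.num_hidden_layers) if i % config.cross_attention_frequency == 0]
+#     # -> [0, 2, 4, 6, 8, 10]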
+class InstructBlipVideoConfig(PretrainedConfig):
+ r"""
+ [`InstructBlipVideoConfig`] is the configuration class to store the configuration of a
+ [`InstructBlipVideoForConditionalGeneration`]. It is used to instantiate an InstructBlipVideo model according to the specified
+ arguments, defining the vision model, Q-Former model and language model configs. Instantiating a configuration with
+ the defaults will yield a similar configuration to that of the InstructBlipVideo
+ [Salesforce/instruct-blip-flan-t5](https://huggingface.co/Salesforce/instruct-blip-flan-t5) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ vision_config (`dict`, *optional*):
+ Dictionary of configuration options used to initialize [`InstructBlipVideoVisionConfig`].
+ qformer_config (`dict`, *optional*):
+ Dictionary of configuration options used to initialize [`InstructBlipVideoQFormerConfig`].
+ text_config (`dict`, *optional*):
+ Dictionary of configuration options used to initialize any [`PretrainedConfig`].
+ num_query_tokens (`int`, *optional*, defaults to 32):
+ The number of query tokens passed through the Transformer.
+ video_token_index (`int`, *optional*):
+ Token index of special video token.
+ kwargs (*optional*):
+ Dictionary of keyword arguments.
+
+ Example:
+
+ ```python
+ >>> from transformers import (
+ ... InstructBlipVideoVisionConfig,
+ ... InstructBlipVideoQFormerConfig,
+ ... OPTConfig,
+ ... InstructBlipVideoConfig,
+ ... InstructBlipVideoForConditionalGeneration,
+ ... )
+
+ >>> # Initializing an InstructBlipVideoConfig with Salesforce/instruct-blip-flan-t5 style configuration
+ >>> configuration = InstructBlipVideoConfig()
+
+ >>> # Initializing an InstructBlipVideoForConditionalGeneration (with random weights) from the Salesforce/instruct-blip-flan-t5 style configuration
+ >>> model = InstructBlipVideoForConditionalGeneration(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+
+ >>> # We can also initialize an InstructBlipVideoConfig from an InstructBlipVideoVisionConfig, an InstructBlipVideoQFormerConfig and any PretrainedConfig
+
+ >>> # Initializing InstructBlipVideo vision, InstructBlipVideo Q-Former and language model configurations
+ >>> vision_config = InstructBlipVideoVisionConfig()
+ >>> qformer_config = InstructBlipVideoQFormerConfig()
+ >>> text_config = OPTConfig()
+
+ >>> config = InstructBlipVideoConfig.from_vision_qformer_text_configs(vision_config, qformer_config, text_config)
+ ```"""
+
+ model_type = "instructblipvideo"
+ attribute_map = {
+ "video_token_id": "video_token_index",
+ }
+ sub_configs = {
+ "text_config": AutoConfig,
+ "qformer_config": InstructBlipVideoQFormerConfig,
+ "vision_config": InstructBlipVideoVisionConfig,
+ }
+
+ def __init__(
+ self,
+ vision_config=None,
+ qformer_config=None,
+ text_config=None,
+ num_query_tokens=32,
+ video_token_index=None,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ if vision_config is None:
+ vision_config = {}
+ logger.info("vision_config is None. initializing the InstructBlipVideoVisionConfig with default values.")
+
+ if qformer_config is None:
+ qformer_config = {}
+ logger.info("qformer_config is None. Initializing the InstructBlipVideoQFormerConfig with default values.")
+
+ if text_config is None:
+ text_config = {}
+ logger.info("text_config is None. Initializing the text config with default values (`OPTConfig`).")
+
+ self.vision_config = InstructBlipVideoVisionConfig(**vision_config)
+ self.qformer_config = InstructBlipVideoQFormerConfig(**qformer_config)
+ text_model_type = text_config.get("model_type", "opt")
+ self.text_config = CONFIG_MAPPING[text_model_type](**text_config)
+
+ self.num_query_tokens = num_query_tokens
+ self.video_token_index = video_token_index
+ self.qformer_config.encoder_hidden_size = self.vision_config.hidden_size
+ self.use_decoder_only_language_model = self.text_config.model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
+ self.initializer_factor = 1.0
+ self.initializer_range = 0.02
+
+ @classmethod
+ def from_vision_qformer_text_configs(
+ cls,
+ vision_config: InstructBlipVideoVisionConfig,
+ qformer_config: InstructBlipVideoQFormerConfig,
+ text_config: PretrainedConfig,
+ **kwargs,
+ ):
+ r"""
+ Instantiate a [`InstructBlipVideoConfig`] (or a derived class) from a InstructBlipVideo vision model, Q-Former and
+ language model configurations.
+
+ Returns:
+ [`InstructBlipVideoConfig`]: An instance of a configuration object
+ """
+
+ return cls(
+ vision_config=vision_config.to_dict(),
+ qformer_config=qformer_config.to_dict(),
+ text_config=text_config.to_dict(),
+ **kwargs,
+ )
+
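+# A short usage sketch: the composite config derives a few fields from its sub-configs, e.g. the
+# Q-Former's `encoder_hidden_size` is tied to the vision hidden size and
+# `use_decoder_only_language_model` reflects whether the text backbone is a causal LM:
+#
+#     config = InstructBlipVideoConfig()                 # defaults to an OPT text backbone
+#     config.qformer_config.encoder_hidden_size          # == config.vision_config.hidden_size (1408)
+#     config.use_decoder_only_language_model             # True, since OPT is a causal LM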
+
+__all__ = ["InstructBlipVideoConfig", "InstructBlipVideoQFormerConfig", "InstructBlipVideoVisionConfig"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/instructblipvideo/image_processing_instructblipvideo.py b/venv/lib/python3.13/site-packages/transformers/models/instructblipvideo/image_processing_instructblipvideo.py
new file mode 100644
index 0000000000000000000000000000000000000000..56391b59dbdd8c04b8995708c1c471ee9f07d75e
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/instructblipvideo/image_processing_instructblipvideo.py
@@ -0,0 +1,332 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Image processor class for InstructBLIPVideo. Largely a copy of Blip2Processor with the addition of video processing abilities.
+"""
+
+from typing import Optional, Union
+
+import numpy as np
+
+from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
+from ...image_transforms import convert_to_rgb, resize, to_channel_dimension_format
+from ...image_utils import (
+ OPENAI_CLIP_MEAN,
+ OPENAI_CLIP_STD,
+ ChannelDimension,
+ ImageInput,
+ PILImageResampling,
+ infer_channel_dimension_format,
+ is_scaled_image,
+ to_numpy_array,
+ valid_images,
+ validate_preprocess_arguments,
+)
+from ...utils import TensorType, filter_out_non_signature_kwargs, logging
+from ...video_utils import VideoInput, make_batched_videos
+
+
+logger = logging.get_logger(__name__)
+
+
+# TODO (raushan): processor can be removed after v5 release. Kept for backwards compatibility
+# Copied from transformers.models.blip.image_processing_blip.BlipImageProcessor with Blip->InstructBlipVideo, BLIP->InstructBLIPVideo
+class InstructBlipVideoImageProcessor(BaseImageProcessor):
+ r"""
+ Constructs an InstructBLIPVideo image processor.
+
+ Args:
+ do_resize (`bool`, *optional*, defaults to `True`):
+ Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the
+ `do_resize` parameter in the `preprocess` method.
+ size (`dict`, *optional*, defaults to `{"height": 384, "width": 384}`):
+ Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess`
+ method.
+ resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
+ Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`. Can be
+ overridden by the `resample` parameter in the `preprocess` method.
+ do_rescale (`bool`, *optional*, defaults to `True`):
+ Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the
+ `do_rescale` parameter in the `preprocess` method.
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
+ Scale factor to use if rescaling the image. Only has an effect if `do_rescale` is set to `True`. Can be
+ overridden by the `rescale_factor` parameter in the `preprocess` method.
+ do_normalize (`bool`, *optional*, defaults to `True`):
+ Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
+ method.
+ image_mean (`float` or `list[float]`, *optional*, defaults to `OPENAI_CLIP_MEAN`):
+ Mean to use if normalizing the image. This is a float or list of floats the length of the number of
+ channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
+ image_std (`float` or `list[float]`, *optional*, defaults to `OPENAI_CLIP_STD`):
+ Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
+ number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
+ do_convert_rgb (`bool`, *optional*, defaults to `True`):
+ Whether to convert the image to RGB.
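+
+ Example (a minimal sketch; the video frames below are synthetic random data):
+
+ ```python
+ >>> import numpy as np
+ >>> from transformers import InstructBlipVideoImageProcessor
+
+ >>> image_processor = InstructBlipVideoImageProcessor()
+ >>> video = [np.random.randint(0, 256, (240, 320, 3), dtype=np.uint8) for _ in range(4)]  # 4 frames
+ >>> pixel_values = image_processor(images=[video], return_tensors="np")["pixel_values"]
+ >>> # expected shape: (batch_size, num_frames, 3, 384, 384) with the default `size`
+ ```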
+ """
+
+ model_input_names = ["pixel_values"]
+
+ def __init__(
+ self,
+ do_resize: bool = True,
+ size: Optional[dict[str, int]] = None,
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
+ do_rescale: bool = True,
+ rescale_factor: Union[int, float] = 1 / 255,
+ do_normalize: bool = True,
+ image_mean: Optional[Union[float, list[float]]] = None,
+ image_std: Optional[Union[float, list[float]]] = None,
+ do_convert_rgb: bool = True,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ size = size if size is not None else {"height": 384, "width": 384}
+ size = get_size_dict(size, default_to_square=True)
+
+ self.do_resize = do_resize
+ self.size = size
+ self.resample = resample
+ self.do_rescale = do_rescale
+ self.rescale_factor = rescale_factor
+ self.do_normalize = do_normalize
+ self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
+ self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
+ self.do_convert_rgb = do_convert_rgb
+
+ # Copied from transformers.models.vit.image_processing_vit.ViTImageProcessor.resize with PILImageResampling.BILINEAR->PILImageResampling.BICUBIC
+ def resize(
+ self,
+ image: np.ndarray,
+ size: dict[str, int],
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs,
+ ) -> np.ndarray:
+ """
+ Resize an image to `(size["height"], size["width"])`.
+
+ Args:
+ image (`np.ndarray`):
+ Image to resize.
+ size (`dict[str, int]`):
+ Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
+ `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BICUBIC`.
+ data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the output image. If unset, the channel dimension format of the input
+ image is used. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
+ from the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+
+ Returns:
+ `np.ndarray`: The resized image.
+ """
+ size = get_size_dict(size)
+ if "height" not in size or "width" not in size:
+ raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}")
+
+ output_size = (size["height"], size["width"])
+ return resize(
+ image,
+ size=output_size,
+ resample=resample,
+ data_format=data_format,
+ input_data_format=input_data_format,
+ **kwargs,
+ )
+
+ # Ignore copy
+ @filter_out_non_signature_kwargs()
+ def preprocess(
+ self,
+ images: Optional[VideoInput] = None,
+ do_resize: Optional[bool] = None,
+ size: Optional[dict[str, int]] = None,
+ resample: Optional[PILImageResampling] = None,
+ do_rescale: Optional[bool] = None,
+ rescale_factor: Optional[float] = None,
+ do_normalize: Optional[bool] = None,
+ image_mean: Optional[Union[float, list[float]]] = None,
+ image_std: Optional[Union[float, list[float]]] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ do_convert_rgb: Optional[bool] = None,
+ data_format: ChannelDimension = ChannelDimension.FIRST,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ ) -> BatchFeature:
+ """
+ Preprocess a video or batch of images/videos.
+
+ Args:
+ videos (`VideoInput`):
+ Video frames to preprocess. Expects a single or batch of videos as a list of frames with pixel values
+ ranging from 0 to 255. If passing in video with pixel values between 0 and 1, set `do_rescale=False`.
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+ Whether to resize the video.
+ size (`dict[str, int]`, *optional*, defaults to `self.size`):
+ Controls the size of the video after `resize`. The shortest edge of the image is resized to
+ `size["shortest_edge"]` whilst preserving the aspect ratio. If the longest edge of this resized image
+ is > `int(size["shortest_edge"] * (1333 / 800))`, then the image is resized again to make the longest
+ edge equal to `int(size["shortest_edge"] * (1333 / 800))`.
+ resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
+ Resampling filter to use if resizing the video. Only has an effect if `do_resize` is set to `True`.
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+ Whether to rescale the video values between [0 - 1].
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
+ Rescale factor to rescale the video by if `do_rescale` is set to `True`.
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
+ Whether to normalize the video.
+ image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`):
+ Image mean to normalize the video by if `do_normalize` is set to `True`.
+ image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`):
+ Image standard deviation to normalize the video by if `do_normalize` is set to `True`.
+ do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
+ Whether to convert the image to RGB.
+ return_tensors (`str` or `TensorType`, *optional*):
+ The type of tensors to return. Can be one of:
+ - Unset: Return a list of `np.ndarray`.
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
+ The channel dimension format for the output image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - Unset: Use the channel dimension format of the input image.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
+ from the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+ """
+ do_resize = do_resize if do_resize is not None else self.do_resize
+ resample = resample if resample is not None else self.resample
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
+ image_mean = image_mean if image_mean is not None else self.image_mean
+ image_std = image_std if image_std is not None else self.image_std
+ do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
+
+ size = size if size is not None else self.size
+ size = get_size_dict(size, default_to_square=False)
+
+ videos = make_batched_videos(images)
+ logger.warning(
+ "`InstructBlipVideoImageProcessor` is deprecated and will be removed in v5.0. "
+ "We recommend to load an instance of `InstructBlipVideoVideoProcessor` to process videos for the model. "
+ )
+
+ validate_preprocess_arguments(
+ do_rescale=do_rescale,
+ rescale_factor=rescale_factor,
+ do_normalize=do_normalize,
+ image_mean=image_mean,
+ image_std=image_std,
+ do_resize=do_resize,
+ size=size,
+ resample=resample,
+ )
+
+ if not valid_images(videos):
+ raise ValueError(
+ "Invalid input type. Must be of type PIL.Image.Image, numpy.ndarray, "
+ "torch.Tensor, tf.Tensor or jax.ndarray."
+ )
+
+ pixel_values = [
+ [
+ self._preprocess_image(
+ image=frame,
+ do_resize=do_resize,
+ size=size,
+ resample=resample,
+ do_rescale=do_rescale,
+ rescale_factor=rescale_factor,
+ do_normalize=do_normalize,
+ image_mean=image_mean,
+ image_std=image_std,
+ do_convert_rgb=do_convert_rgb,
+ data_format=data_format,
+ input_data_format=input_data_format,
+ )
+ for frame in video
+ ]
+ for video in videos
+ ]
+
+ encoded_outputs = BatchFeature(data={"pixel_values": pixel_values}, tensor_type=return_tensors)
+ return encoded_outputs
+
+ # Ignore copy
+ def _preprocess_image(
+ self,
+ image: Optional[ImageInput] = None,
+ do_resize: Optional[bool] = None,
+ size: Optional[dict[str, int]] = None,
+ resample: Optional[PILImageResampling] = None,
+ do_rescale: Optional[bool] = None,
+ rescale_factor: Optional[float] = None,
+ do_normalize: Optional[bool] = None,
+ image_mean: Optional[Union[float, list[float]]] = None,
+ image_std: Optional[Union[float, list[float]]] = None,
+ do_convert_rgb: Optional[bool] = None,
+ data_format: ChannelDimension = ChannelDimension.FIRST,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ ) -> np.ndarray:
+ # PIL RGBA images are converted to RGB
+ if do_convert_rgb:
+ image = convert_to_rgb(image)
+
+ # All transformations expect numpy arrays.
+ image = to_numpy_array(image)
+
+ if do_rescale and is_scaled_image(image):
+ logger.warning_once(
+ "It looks like you are trying to rescale already rescaled video frames. If the input"
+ " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
+ )
+
+ if input_data_format is None:
+ # We assume that all images have the same channel dimension format.
+ input_data_format = infer_channel_dimension_format(image)
+
+ if do_resize:
+ image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
+
+ if do_rescale:
+ image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
+
+ if do_normalize:
+ image = self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
+
+ image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
+
+ return image
+
+
+__all__ = ["InstructBlipVideoImageProcessor"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/instructblipvideo/modeling_instructblipvideo.py b/venv/lib/python3.13/site-packages/transformers/models/instructblipvideo/modeling_instructblipvideo.py
new file mode 100644
index 0000000000000000000000000000000000000000..79049b83e88a8df7b484180924553a66ab8c863c
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/instructblipvideo/modeling_instructblipvideo.py
@@ -0,0 +1,1604 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/instructblipvideo/modular_instructblipvideo.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_instructblipvideo.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from dataclasses import dataclass
+from typing import Any, Callable, Optional, Union
+
+import torch
+from torch import nn
+
+from ...activations import ACT2FN
+from ...generation import GenerationMixin
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import (
+ BaseModelOutput,
+ BaseModelOutputWithPastAndCrossAttentions,
+ BaseModelOutputWithPooling,
+ BaseModelOutputWithPoolingAndCrossAttentions,
+)
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_int
+from ...utils.generic import OutputRecorder, check_model_inputs
+from ..auto import AutoModel, AutoModelForCausalLM, AutoModelForSeq2SeqLM
+from .configuration_instructblipvideo import (
+ InstructBlipVideoConfig,
+ InstructBlipVideoQFormerConfig,
+ InstructBlipVideoVisionConfig,
+)
+
+
+logger = logging.get_logger(__name__)
+
+
+class InstructBlipVideoVisionEmbeddings(nn.Module):
+ def __init__(self, config: InstructBlipVideoVisionConfig):
+ super().__init__()
+ self.config = config
+ self.embed_dim = config.hidden_size
+ self.image_size = config.image_size
+ self.patch_size = config.patch_size
+
+ self.class_embedding = nn.Parameter(torch.randn(1, 1, self.embed_dim))
+
+ self.patch_embedding = nn.Conv2d(
+ in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size
+ )
+
+ self.num_patches = (self.image_size // self.patch_size) ** 2
+ self.num_positions = self.num_patches + 1
+
+ self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))
+
+ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
+ """
+ This method interpolates the pre-trained position encodings so that the model can be used on higher-resolution
+ images. It is also adapted to support torch.jit tracing.
+
+ Adapted from:
+ - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
+ - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
+ """
+
+ num_patches = embeddings.shape[1] - 1
+ num_positions = self.position_embedding.shape[1] - 1
+
+ # always interpolate when tracing to ensure the exported model works for dynamic input shapes
+ if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
+ return self.position_embedding
+
+ class_pos_embed = self.position_embedding[:, :1]
+ patch_pos_embed = self.position_embedding[:, 1:]
+
+ dim = embeddings.shape[-1]
+
+ new_height = height // self.patch_size
+ new_width = width // self.patch_size
+
+ sqrt_num_positions = torch_int(num_positions**0.5)
+ patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
+ patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
+
+ patch_pos_embed = nn.functional.interpolate(
+ patch_pos_embed,
+ size=(new_height, new_width),
+ mode="bicubic",
+ align_corners=False,
+ )
+
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+
+ return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
+
+ def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
+ batch_size, _, height, width = pixel_values.shape
+ target_dtype = self.patch_embedding.weight.dtype
+ patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid]
+ patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
+ class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype)
+ embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
+ if interpolate_pos_encoding:
+ position_embedding = self.interpolate_pos_encoding(embeddings, height, width)
+ else:
+ position_embedding = self.position_embedding
+ embeddings = embeddings + position_embedding[:, : embeddings.size(1), :].to(target_dtype)
+ return embeddings
+
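+# Worked example for the position-encoding shapes above (standard ViT patch arithmetic, shown for
+# this model's defaults): with image_size=224 and patch_size=14 there are (224 // 14) ** 2 = 256
+# patches plus one class token, i.e. 257 position embeddings. Passing a 448x448 frame with
+# `interpolate_pos_encoding=True` resizes the 16x16 patch grid to 32x32, giving
+# 32 * 32 + 1 = 1025 embeddings without retraining the position table.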
+
+@auto_docstring
+class InstructBlipVideoPreTrainedModel(PreTrainedModel):
+ config: InstructBlipVideoConfig
+ base_model_prefix = "blip"
+ supports_gradient_checkpointing = True
+ _supports_attention_backend = True
+ _supports_flash_attn = True
+ _supports_sdpa = True
+ _supports_flex_attn = True
+
+ _can_compile_fullgraph = True
+
+ _no_split_modules = [
+ "InstructBlipVideoQFormerEmbeddings",
+ "InstructBlipVideoAttention",
+ "InstructBlipVideoQFormerMultiHeadAttention",
+ "InstructBlipVideoQFormerSelfOutput",
+ ]
+
+ def _init_weights(self, module):
+ """Initialize the weights"""
+ factor = self.config.initializer_range
+
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
+ module.weight.data.normal_(mean=0.0, std=factor)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=factor)
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+ elif isinstance(module, InstructBlipVideoVisionEmbeddings):
+ nn.init.trunc_normal_(module.position_embedding, mean=0.0, std=factor)
+ nn.init.trunc_normal_(module.class_embedding, mean=0.0, std=factor)
+ elif isinstance(module, (InstructBlipVideoForConditionalGeneration, InstructBlipVideoModel)):
+ module.query_tokens.data.zero_()
+
+
+# Adapted from transformers.models.siglip.modeling_siglip.eager_attention_forward -> InstructBlipVideo doesn't cast attn weights to fp32
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: float,
+ dropout: float = 0.0,
+ **kwargs,
+):
+ attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
+ if attention_mask is not None:
+ attn_weights = attn_weights + attention_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+
+ attn_output = torch.matmul(attn_weights, value)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
+class InstructBlipVideoAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.embed_dim = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.embed_dim // self.num_heads
+ if self.head_dim * self.num_heads != self.embed_dim:
+ raise ValueError(
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
+ f" {self.num_heads})."
+ )
+ self.scale = self.head_dim**-0.5
+ self.is_causal = False
+ self.attention_dropout = config.attention_dropout
+
+ # small tweak here compared to CLIP, no bias here
+ self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=False)
+
+ if config.qkv_bias:
+ q_bias = nn.Parameter(torch.zeros(self.embed_dim))
+ v_bias = nn.Parameter(torch.zeros(self.embed_dim))
+ else:
+ q_bias = None
+ v_bias = None
+
+ if q_bias is not None:
+ qkv_bias = torch.cat((q_bias, torch.zeros_like(v_bias, requires_grad=False), v_bias))
+ self.qkv.bias = nn.Parameter(qkv_bias)
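+ # The packed bias follows the qkv projection order [query | key | value]; the key slice is
+ # kept at zero because, as in BLIP-2, only the query and value projections carry a bias.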
+
+ self.projection = nn.Linear(self.embed_dim, self.embed_dim)
+
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ head_mask: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
+ """Input shape: Batch x Time x Channel"""
+
+ bsz, tgt_len, embed_dim = hidden_states.size()
+
+ mixed_qkv = self.qkv(hidden_states)
+
+ mixed_qkv = mixed_qkv.reshape(bsz, tgt_len, 3, self.num_heads, embed_dim // self.num_heads).permute(
+ 2, 0, 3, 1, 4
+ )
+ query_states, key_states, value_states = mixed_qkv[0], mixed_qkv[1], mixed_qkv[2]
+
+ attention_interface: Callable = eager_attention_forward
+
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask=None,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ scaling=self.scale,
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
+ attn_output = self.projection(attn_output)
+
+ return attn_output, attn_weights
+
+
+class InstructBlipVideoMLP(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.activation_fn = ACT2FN[config.hidden_act]
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.fc1(hidden_states)
+ hidden_states = self.activation_fn(hidden_states)
+ hidden_states = self.fc2(hidden_states)
+ return hidden_states
+
+
+class InstructBlipVideoEncoderLayer(GradientCheckpointingLayer):
+ def __init__(self, config: InstructBlipVideoConfig):
+ super().__init__()
+ self.embed_dim = config.hidden_size
+ self.self_attn = InstructBlipVideoAttention(config)
+ self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+ self.mlp = InstructBlipVideoMLP(config)
+ self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+
+ @auto_docstring
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: torch.Tensor,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> torch.FloatTensor:
+ residual = hidden_states
+
+ hidden_states = self.layer_norm1(hidden_states)
+ hidden_states, _ = self.self_attn(
+ hidden_states=hidden_states,
+ head_mask=attention_mask,
+ **kwargs,
+ )
+ hidden_states = hidden_states + residual
+ residual = hidden_states
+ hidden_states = self.layer_norm2(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+
+ hidden_states = hidden_states + residual
+
+ return hidden_states
+
+
+class InstructBlipVideoEncoder(nn.Module):
+ """
+ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
+ [`InstructBlipVideoEncoderLayer`].
+
+ Args:
+ config (`InstructBlipVideoConfig`):
+ The corresponding vision configuration for the `InstructBlipVideoEncoder`.
+ """
+
+ def __init__(self, config: InstructBlipVideoConfig):
+ super().__init__()
+ self.config = config
+ self.layers = nn.ModuleList([InstructBlipVideoEncoderLayer(config) for _ in range(config.num_hidden_layers)])
+ self.gradient_checkpointing = False
+
+ @auto_docstring
+ def forward(
+ self,
+ inputs_embeds,
+ attention_mask: Optional[torch.Tensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> Union[tuple, BaseModelOutput]:
+ hidden_states = inputs_embeds
+ for encoder_layer in self.layers:
+ hidden_states = encoder_layer(
+ hidden_states,
+ attention_mask=attention_mask,
+ **kwargs,
+ )
+
+ return BaseModelOutput(last_hidden_state=hidden_states)
+
+
+class InstructBlipVideoVisionModel(InstructBlipVideoPreTrainedModel):
+ main_input_name = "pixel_values"
+ config: InstructBlipVideoVisionConfig
+ _can_record_outputs = {
+ "hidden_states": InstructBlipVideoEncoderLayer,
+ "attentions": InstructBlipVideoAttention,
+ }
+
+ def __init__(self, config: InstructBlipVideoVisionConfig):
+ super().__init__(config)
+ self.config = config
+ embed_dim = config.hidden_size
+
+ self.embeddings = InstructBlipVideoVisionEmbeddings(config)
+ self.encoder = InstructBlipVideoEncoder(config)
+ self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
+
+ self.post_init()
+
+ @check_model_inputs(tie_last_hidden_states=False)
+ @auto_docstring
+ def forward(
+ self,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ interpolate_pos_encoding: bool = False,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> Union[tuple, BaseModelOutputWithPooling]:
+ if pixel_values is None:
+ raise ValueError("You have to specify pixel_values")
+
+ hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
+
+ encoder_outputs: BaseModelOutput = self.encoder(
+ inputs_embeds=hidden_states,
+ **kwargs,
+ )
+
+ last_hidden_state = encoder_outputs.last_hidden_state
+ last_hidden_state = self.post_layernorm(last_hidden_state)
+
+ pooled_output = last_hidden_state[:, 0, :]
+ pooled_output = self.post_layernorm(pooled_output)
+
+ return BaseModelOutputWithPooling(
+ last_hidden_state=last_hidden_state,
+ pooler_output=pooled_output,
+ )
+
+ def get_input_embeddings(self):
+ return self.embeddings
+
+
+class InstructBlipVideoQFormerMultiHeadAttention(nn.Module):
+ def __init__(self, config, is_cross_attention=False):
+ super().__init__()
+ self.config = config
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
+ raise ValueError(
+ "The hidden size (%d) is not a multiple of the number of attention heads (%d)"
+ % (config.hidden_size, config.num_attention_heads)
+ )
+
+ self.num_attention_heads = config.num_attention_heads
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
+ if is_cross_attention:
+ self.key = nn.Linear(config.encoder_hidden_size, self.all_head_size)
+ self.value = nn.Linear(config.encoder_hidden_size, self.all_head_size)
+ else:
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
+
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
+ self.max_position_embeddings = config.max_position_embeddings
+ self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
+ self.save_attention = False
+
+ def save_attn_gradients(self, attn_gradients):
+ self.attn_gradients = attn_gradients
+
+ def get_attn_gradients(self):
+ return self.attn_gradients
+
+ def save_attention_map(self, attention_map):
+ self.attention_map = attention_map
+
+ def get_attention_map(self):
+ return self.attention_map
+
+ def transpose_for_scores(self, x):
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
+ x = x.view(*new_x_shape)
+ return x.permute(0, 2, 1, 3)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ **kwargs: Unpack[TransformersKwargs],
+ ):
+ # If this is instantiated as a cross-attention module, the keys
+ # and values come from an encoder; the attention mask needs to be
+ # such that the encoder's padding tokens are not attended to.
+ is_cross_attention = encoder_hidden_states is not None
+
+ if is_cross_attention:
+ key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
+ value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
+ attention_mask = encoder_attention_mask
+ else:
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
+
+ mixed_query_layer = self.query(hidden_states)
+
+ query_layer = self.transpose_for_scores(mixed_query_layer)
+
+ # Take the dot product between "query" and "key" to get the raw attention scores.
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
+ seq_length = hidden_states.size()[1]
+ position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
+ position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
+ distance = position_ids_l - position_ids_r
+ positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
+ positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
+
+ if self.position_embedding_type == "relative_key":
+ relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
+ attention_scores = attention_scores + relative_position_scores
+ elif self.position_embedding_type == "relative_key_query":
+ relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
+ relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
+ attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
+
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+ attention_scores_dtype = attention_scores.dtype
+
+ if attention_mask is not None:
+ # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
+ attention_scores = attention_scores + attention_mask
+
+ # Normalize the attention scores to probabilities.
+ attention_probs = nn.Softmax(dim=-1)(attention_scores).to(attention_scores_dtype)
+
+ if is_cross_attention and self.save_attention:
+ self.save_attention_map(attention_probs)
+ attention_probs.register_hook(self.save_attn_gradients)
+
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ attention_probs_dropped = self.dropout(attention_probs)
+
+ # Mask heads if we want to
+ if head_mask is not None:
+ attention_probs_dropped = attention_probs_dropped * head_mask
+
+ context_layer = torch.matmul(attention_probs_dropped, value_layer)
+
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+ context_layer = context_layer.view(*new_context_layer_shape)
+
+ return context_layer, attention_probs
+
+
+class InstructBlipVideoQFormerSelfOutput(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
+ return hidden_states
+
+
+class InstructBlipVideoQFormerAttention(nn.Module):
+ def __init__(self, config, is_cross_attention=False):
+ super().__init__()
+ self.attention = InstructBlipVideoQFormerMultiHeadAttention(config, is_cross_attention)
+ self.output = InstructBlipVideoQFormerSelfOutput(config)
+ self.pruned_heads = set()
+
+ def prune_heads(self, heads):
+ if len(heads) == 0:
+ return
+ heads, index = find_pruneable_heads_and_indices(
+ heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
+ )
+
+ # Prune linear layers
+ self.attention.query = prune_linear_layer(self.attention.query, index)
+ self.attention.key = prune_linear_layer(self.attention.key, index)
+ self.attention.value = prune_linear_layer(self.attention.value, index)
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
+
+ # Update hyper params and store pruned heads
+ self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
+ self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
+ self.pruned_heads = self.pruned_heads.union(heads)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> torch.Tensor:
+ attn_output, _ = self.attention(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ **kwargs,
+ )
+ attention_output = self.output(attn_output, hidden_states)
+ return attention_output
+
+
+class InstructBlipVideoQFormerIntermediate(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
+ if isinstance(config.hidden_act, str):
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.intermediate_act_fn = config.hidden_act
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.intermediate_act_fn(hidden_states)
+ return hidden_states
+
+
+class InstructBlipVideoQFormerOutput(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
+ return hidden_states
+
+
+class InstructBlipVideoQFormerLayer(GradientCheckpointingLayer):
+ def __init__(self, config, layer_idx):
+ super().__init__()
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
+ self.seq_len_dim = 1
+ self.attention = InstructBlipVideoQFormerAttention(config)
+
+ self.layer_idx = layer_idx
+
+ if layer_idx % config.cross_attention_frequency == 0:
+ self.crossattention = InstructBlipVideoQFormerAttention(config, is_cross_attention=True)
+ self.has_cross_attention = True
+ else:
+ self.has_cross_attention = False
+
+ self.intermediate = InstructBlipVideoQFormerIntermediate(config)
+ self.output = InstructBlipVideoQFormerOutput(config)
+
+ self.intermediate_query = InstructBlipVideoQFormerIntermediate(config)
+ self.output_query = InstructBlipVideoQFormerOutput(config)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ query_length=0,
+ **kwargs: Unpack[TransformersKwargs],
+ ):
+ attention_output = self.attention(
+ hidden_states,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ **kwargs,
+ )
+
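+ # The first `query_length` positions are the learnable query tokens: only they cross-attend to
+ # the vision features and go through the query-specific feed-forward branch. Any remaining
+ # positions are instruction text tokens and use the plain text feed-forward branch.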
+ if query_length > 0:
+ query_attention_output = attention_output[:, :query_length, :]
+
+ if self.has_cross_attention:
+ if encoder_hidden_states is None:
+ raise ValueError("encoder_hidden_states must be given for cross-attention layers")
+ query_attention_output = self.crossattention(
+ query_attention_output,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ **kwargs,
+ )
+
+ layer_output = apply_chunking_to_forward(
+ self.feed_forward_chunk_query,
+ self.chunk_size_feed_forward,
+ self.seq_len_dim,
+ query_attention_output,
+ )
+
+ if attention_output.shape[1] > query_length:
+ layer_output_text = apply_chunking_to_forward(
+ self.feed_forward_chunk,
+ self.chunk_size_feed_forward,
+ self.seq_len_dim,
+ attention_output[:, query_length:, :],
+ ).to(layer_output.device)
+ layer_output = torch.cat([layer_output, layer_output_text], dim=1)
+ else:
+ layer_output = apply_chunking_to_forward(
+ self.feed_forward_chunk,
+ self.chunk_size_feed_forward,
+ self.seq_len_dim,
+ attention_output,
+ )
+ return layer_output
+
+ def feed_forward_chunk(self, attention_output):
+ intermediate_output = self.intermediate(attention_output)
+ layer_output = self.output(intermediate_output, attention_output)
+ return layer_output
+
+ def feed_forward_chunk_query(self, attention_output):
+ intermediate_output = self.intermediate_query(attention_output)
+ layer_output = self.output_query(intermediate_output, attention_output)
+ return layer_output
+
+
+class InstructBlipVideoQFormerEncoder(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.layer = nn.ModuleList(
+ [InstructBlipVideoQFormerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+ self.gradient_checkpointing = False
+
+ @can_return_tuple
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ query_length=0,
+ **kwargs: Unpack[TransformersKwargs],
+ ):
+ for i in range(self.config.num_hidden_layers):
+ layer_module = self.layer[i]
+ layer_head_mask = head_mask[i] if head_mask is not None else None
+
+ hidden_states = layer_module(
+ hidden_states,
+ attention_mask,
+ layer_head_mask,
+ encoder_hidden_states, # as a positional argument for gradient checkpointing
+ encoder_attention_mask=encoder_attention_mask,
+ query_length=query_length,
+ **kwargs,
+ )
+
+ return BaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ )
+
+
+class InstructBlipVideoQFormerEmbeddings(nn.Module):
+ """Construct the embeddings from word and position embeddings."""
+
+ def __init__(self, config):
+ super().__init__()
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
+
+ self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
+ self.register_buffer(
+ "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
+ )
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
+
+ self.config = config
+
+ def forward(
+ self,
+ input_ids=None,
+ position_ids=None,
+ query_embeds=None,
+ past_key_values_length=0,
+ ):
+ if input_ids is not None:
+ seq_length = input_ids.size()[1]
+ else:
+ seq_length = 0
+
+ if position_ids is None:
+ position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length].clone()
+
+ if input_ids is not None:
+ embeddings = self.word_embeddings(input_ids)
+ if self.position_embedding_type == "absolute":
+ position_embeddings = self.position_embeddings(position_ids.to(embeddings.device))
+ embeddings = embeddings + position_embeddings
+
+ if query_embeds is not None:
+ embeddings = torch.cat((query_embeds, embeddings), dim=1)
+ else:
+ embeddings = query_embeds
+
+ embeddings = embeddings.to(self.layernorm.weight.dtype)
+ embeddings = self.layernorm(embeddings)
+ embeddings = self.dropout(embeddings)
+ return embeddings
+
+
+class InstructBlipVideoQFormerModel(InstructBlipVideoPreTrainedModel):
+ """
+ Querying Transformer (Q-Former), used in InstructBlipVideo. Slightly modified from BLIP-2 as it also takes the
+ instruction as input.
+ """
+
+ _supports_attention_backend = False # adds position on attn weights before last matmul
+ _supports_flash_attn = False
+ _supports_sdpa = False
+ _supports_flex_attn = False
+
+ _can_record_outputs = {
+ "hidden_states": InstructBlipVideoQFormerLayer,
+ "attentions": [
+ OutputRecorder(InstructBlipVideoQFormerMultiHeadAttention, index=1, layer_name=".attention"),
+ ],
+ "cross_attentions": [
+ OutputRecorder(InstructBlipVideoQFormerMultiHeadAttention, index=1, layer_name=".crossattention"),
+ ],
+ }
+
+ def __init__(self, config: InstructBlipVideoQFormerConfig):
+ super().__init__(config)
+ self.config = config
+
+ self.embeddings = InstructBlipVideoQFormerEmbeddings(config)
+
+ self.encoder = InstructBlipVideoQFormerEncoder(config)
+
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embeddings.word_embeddings
+
+ def set_input_embeddings(self, value):
+ self.embeddings.word_embeddings = value
+
+ def _prune_heads(self, heads_to_prune):
+ """
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
+ class PreTrainedModel
+ """
+ for layer, heads in heads_to_prune.items():
+ self.encoder.layer[layer].attention.prune_heads(heads)
+
+ def get_extended_attention_mask(
+ self,
+ attention_mask: torch.Tensor,
+ input_shape: tuple[int],
+ device: torch.device,
+ has_query: bool = False,
+ ) -> torch.Tensor:
+ """
+ Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
+
+ Arguments:
+ attention_mask (`torch.Tensor`):
+ Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
+ input_shape (`tuple[int]`):
+ The shape of the input to the model.
+ device: (`torch.device`):
+ The device of the input to the model.
+
+ Returns:
+ `torch.Tensor`: The extended attention mask, with the same dtype as `attention_mask.dtype`.
+ """
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+ # ourselves in which case we just need to make it broadcastable to all heads.
+ if attention_mask.dim() == 3:
+ extended_attention_mask = attention_mask[:, None, :, :]
+ elif attention_mask.dim() == 2:
+ # Provided a padding mask of dimensions [batch_size, seq_length]
+ # - the model is an encoder, so make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
+ extended_attention_mask = attention_mask[:, None, None, :]
+ else:
+ raise ValueError(
+ f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})",
+ )
+
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
+ # masked positions, this operation will create a tensor which is 0.0 for
+ # positions we want to attend and -10000.0 for masked positions.
+ # Since we are adding it to the raw scores before the softmax, this is
+ # effectively the same as removing these entirely.
+ extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
+ extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
+ return extended_attention_mask
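+
+ # Illustrative example (assumption, not executed here): a 2D padding mask
+ #   attention_mask = [[1, 1, 0]]          # shape (batch=1, seq=3)
+ # becomes an additive mask of shape (1, 1, 1, 3) equal to
+ #   [[[[0.0, 0.0, -10000.0]]]]
+ # which, once added to the raw attention scores, effectively removes the padded position.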
+
+ @check_model_inputs()
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: torch.LongTensor,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ query_embeds: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> Union[tuple[torch.FloatTensor], BaseModelOutputWithPoolingAndCrossAttentions]:
+ r"""
+ query_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Hidden states to be used in the attention computation. When cross-attention is performed,
+ these serve as the query (the key and value are taken from `encoder_hidden_states`).
+ """
+ if input_ids is None and query_embeds is None:
+ raise ValueError("You have to specify query_embeds when input_ids is None")
+
+ query_length = query_embeds.shape[1] if query_embeds is not None else 0
+
+ embedding_output = self.embeddings(
+ input_ids=input_ids,
+ position_ids=position_ids,
+ query_embeds=query_embeds,
+ )
+
+ input_shape = embedding_output.size()[:-1]
+ batch_size, seq_length = input_shape
+ device = embedding_output.device
+
+ if attention_mask is None:
+ attention_mask = torch.ones((batch_size, seq_length), device=device)
+
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+ # ourselves in which case we just need to make it broadcastable to all heads.
+ extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, device)
+
+ # If a 2D or 3D attention mask is provided for the cross-attention
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
+ if encoder_hidden_states is not None:
+ if isinstance(encoder_hidden_states, list):
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
+ else:
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
+
+ if isinstance(encoder_attention_mask, list):
+ encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
+ elif encoder_attention_mask is None:
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+ else:
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+ else:
+ encoder_extended_attention_mask = None
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+ encoder_outputs: BaseModelOutput = self.encoder(
+ embedding_output,
+ attention_mask=extended_attention_mask,
+ head_mask=head_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_extended_attention_mask,
+ query_length=query_length,
+ **kwargs,
+ )
+ sequence_output = encoder_outputs.last_hidden_state
+ pooled_output = sequence_output[:, 0, :]
+
+ return BaseModelOutputWithPoolingAndCrossAttentions(
+ last_hidden_state=sequence_output,
+ pooler_output=pooled_output,
+ )
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Class defining the outputs of [`InstructBlipVideoForConditionalGeneration`].
+ """
+)
+class InstructBlipVideoForConditionalGenerationModelOutput(ModelOutput):
+ r"""
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Language modeling loss from the language model.
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+ Prediction scores of the language modeling head of the language model.
+ vision_outputs (`BaseModelOutputWithPooling`):
+ Outputs of the vision encoder.
+ qformer_outputs (`BaseModelOutputWithPoolingAndCrossAttentions`):
+ Outputs of the Q-Former (Querying Transformer).
+ language_model_outputs (`CausalLMOutputWithPast` or `Seq2SeqLMOutput`):
+ Outputs of the language model.
+ """
+
+ loss: Optional[tuple[torch.FloatTensor]] = None
+ logits: Optional[tuple[torch.FloatTensor]] = None
+ vision_outputs: Optional[torch.FloatTensor] = None
+ qformer_outputs: Optional[tuple[torch.FloatTensor]] = None
+ language_model_outputs: Optional[tuple[torch.FloatTensor]] = None
+
+ def to_tuple(self) -> tuple[Any]:
+ return tuple(
+ self[k]
+ if k not in ["vision_outputs", "qformer_outputs", "language_model_outputs"]
+ else getattr(self, k).to_tuple()
+ for k in self.keys()
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ InstructBlipVideo base Model consisting of language model, qformer and vision encoder.
+ """
+)
+class InstructBlipVideoModel(InstructBlipVideoPreTrainedModel):
+ main_input_name = "pixel_values"
+ _keep_in_fp32_modules = ["query_tokens"] # TODO @ArthurZucker I don't know why this is required for FP8
+
+ def __init__(self, config: InstructBlipVideoConfig):
+ super().__init__(config)
+
+ self.vision_model = InstructBlipVideoVisionModel(config.vision_config)
+ self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size))
+ self.qformer = InstructBlipVideoQFormerModel(config.qformer_config)
+
+ self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size)
+ self.language_model = AutoModel.from_config(config.text_config)
+
+ if self.language_model._no_split_modules is not None:
+ self._no_split_modules.extend(self.language_model._no_split_modules)
+
+ if self.language_model._keep_in_fp32_modules is not None:
+ self._keep_in_fp32_modules.extend(self.language_model._keep_in_fp32_modules)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.language_model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.language_model.set_input_embeddings(value)
+
+ def _tie_weights(self):
+ if not self.config.use_decoder_only_language_model:
+ self.language_model.encoder.embed_tokens = self.language_model.shared
+ self.language_model.decoder.embed_tokens = self.language_model.shared
+
+ def _preprocess_accelerate(self):
+ r"""
+ Some pre-processing hacks to make the model `accelerate` compatible. Check
+ https://github.com/huggingface/transformers/pull/21707 for more details.
+ """
+ hf_device_map = self.hf_device_map
+
+ if len(hf_device_map) > 1 and "language_model" not in hf_device_map and torch.cuda.device_count() > 1:
+ # warn users about unexpected behavior when using multi-GPU + InstructBlipVideo + `accelerate`.
+ logger.warning(
+ "The `language_model` is not in the `hf_device_map` dictionary and you are running your script"
+ " in a multi-GPU environment. this may lead to unexpected behavior when using `accelerate`."
+ " Please pass a `device_map` that contains `language_model` to remove this warning."
+ " Please refer to https://github.com/huggingface/blog/blob/main/accelerate-large-models.md for"
+ " more details on creating a `device_map` for large models.",
+ )
+
+ if hasattr(self.language_model, "_hf_hook"):
+ self.language_model._hf_hook.io_same_device = True # For `generate` compatibility
+
+ def get_placeholder_mask(self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor):
+ """
+ Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`.
+ """
+ if input_ids is None:
+ special_image_mask = inputs_embeds == self.get_input_embeddings()(
+ torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ special_image_mask = special_image_mask.all(-1)
+ else:
+ special_image_mask = input_ids == self.config.image_token_id
+
+ special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+ return special_image_mask
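+
+ # Shape sketch (illustrative assumption): the boolean mask starts as (batch, seq) marking the
+ # multimodal placeholder tokens, then `unsqueeze(-1).expand_as(inputs_embeds)` broadcasts it to
+ # (batch, seq, hidden) so `masked_scatter` can overwrite exactly those embedding slots with the
+ # projected Q-Former outputs.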
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ pixel_values: torch.FloatTensor,
+ qformer_input_ids: torch.FloatTensor,
+ qformer_attention_mask: Optional[torch.LongTensor] = None,
+ input_ids: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.LongTensor] = None,
+ decoder_input_ids: Optional[torch.LongTensor] = None,
+ decoder_attention_mask: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ interpolate_pos_encoding: bool = False,
+ use_cache: Optional[bool] = None,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> Union[tuple, InstructBlipVideoForConditionalGenerationModelOutput]:
+ r"""
+ qformer_input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Indices of input sequence tokens in the vocabulary of the Q-Former. Input tokens can optionally be provided
+ to serve as text prompt, which the Q-Former model will encode.
+
+ Indices can be obtained using [`InstructBlipVideoProcessor`]. See [`InstructBlipVideoProcessor.__call__`] for
+ details.
+
+ [What are input IDs?](../glossary#input-ids)
+ qformer_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+ Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
+ be used by default.
+
+ Only relevant in case an encoder-decoder language model (like T5) is used.
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # step 1: forward the images through the vision encoder,
+ # we process in a batched way, later unbatch it back (video has frames=4 always)
+ batch_size, frames, channel, height, width = pixel_values.shape
+ pixel_values = pixel_values.reshape(batch_size * frames, channel, height, width)
+
+ vision_outputs = self.vision_model(
+ pixel_values=pixel_values,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ interpolate_pos_encoding=interpolate_pos_encoding,
+ )
+ image_embeds = vision_outputs[0]
+
+ # step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
+ image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
+
+ # difference with BLIP-2 here: we also feed the instruction prompt to the Q-Former
+ query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
+ query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device)
+
+ if qformer_attention_mask is None:
+ qformer_attention_mask = torch.ones_like(qformer_input_ids)
+
+ qformer_input_ids = qformer_input_ids.repeat_interleave(frames, dim=0)
+ qformer_attention_mask = qformer_attention_mask.repeat_interleave(frames, dim=0)
+ qformer_attention_mask = torch.cat([query_attention_mask, qformer_attention_mask], dim=1)
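+ # Shape sketch (illustrative assumption): with batch_size videos of 4 frames each,
+ #   query_tokens:           (batch * 4, num_query_tokens, qformer_hidden)
+ #   qformer_input_ids:      (batch * 4, text_len)  after repeat_interleave
+ #   qformer_attention_mask: (batch * 4, num_query_tokens + text_len)
+ # so every frame attends to the same instruction prompt alongside its own query slots.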
+ query_outputs = self.qformer(
+ input_ids=qformer_input_ids,
+ attention_mask=qformer_attention_mask,
+ query_embeds=query_tokens,
+ encoder_hidden_states=image_embeds,
+ encoder_attention_mask=image_attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ query_output = query_outputs[0][:, : query_tokens.size(1), :]
+
+ # step 3: use the language model, conditioned on the query outputs and the prompt
+ language_model_inputs = self.language_projection(query_output)
+
+ # unbatch inputs back, each video-frame gets `num_query_tokens` seq length
+ language_model_inputs = language_model_inputs.reshape(batch_size, self.config.num_query_tokens * frames, -1)
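+ # Shape sketch (illustrative assumption): query_output is (batch * frames, num_query_tokens, qformer_hidden);
+ # after the projection and this reshape, language_model_inputs is (batch, num_query_tokens * frames, text_hidden),
+ # i.e. one block of query embeddings per frame, laid out frame after frame along the sequence dimension.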
+ if inputs_embeds is None:
+ inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
+ special_image_mask = input_ids == self.config.video_token_id
+ if attention_mask is None:
+ attention_mask = torch.ones_like(input_ids)
+ else:
+ special_image_mask = inputs_embeds == self.get_input_embeddings()(
+ torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ special_image_mask = special_image_mask.all(-1)
+
+ special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+ language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
+ inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
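+ # Note (assumption): `masked_scatter` fills the placeholder positions in reading order, so the prompt is
+ # expected to contain exactly `num_query_tokens * frames` video tokens for the query embeddings to line up.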
+
+ if self.config.use_decoder_only_language_model:
+ outputs = self.language_model(
+ inputs_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ use_cache=use_cache,
+ **kwargs,
+ )
+ else:
+ outputs = self.language_model(
+ inputs_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ decoder_input_ids=decoder_input_ids,
+ decoder_attention_mask=decoder_attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ use_cache=use_cache,
+ **kwargs,
+ )
+
+ return InstructBlipVideoForConditionalGenerationModelOutput(
+ vision_outputs=vision_outputs,
+ qformer_outputs=query_outputs,
+ language_model_outputs=outputs,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ InstructBlipVideo Model for generating text given an image and an optional text prompt. The model consists of a vision
+ encoder, Querying Transformer (Q-Former) and a language model.
+
+ One can optionally pass `input_ids` to the model, which serve as a text prompt, to make the language model continue
+ the prompt. Otherwise, the language model starts generating text from the [BOS] (beginning-of-sequence) token.
+ """
+)
+class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel, GenerationMixin):
+ config: InstructBlipVideoConfig
+ main_input_name = "pixel_values"
+
+ _can_compile_fullgraph = True
+ _keep_in_fp32_modules = ["query_tokens"] # TODO @ArthurZucker I don't know why this is required for FP8
+
+ def __init__(self, config: InstructBlipVideoConfig):
+ super().__init__(config)
+
+ self.vision_model = InstructBlipVideoVisionModel._from_config(config.vision_config)
+
+ self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size))
+ self.qformer = InstructBlipVideoQFormerModel._from_config(config.qformer_config)
+
+ self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size)
+
+ if config.use_decoder_only_language_model:
+ language_model = AutoModelForCausalLM.from_config(config.text_config)
+ else:
+ language_model = AutoModelForSeq2SeqLM.from_config(config.text_config)
+
+ if language_model._no_split_modules is not None:
+ self._no_split_modules.extend(language_model._no_split_modules)
+
+ if language_model._keep_in_fp32_modules is not None:
+ self._keep_in_fp32_modules.extend(language_model._keep_in_fp32_modules)
+
+ self.language_model = language_model
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.language_model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.language_model.set_input_embeddings(value)
+
+ def set_output_embeddings(self, new_embeddings):
+ self.language_model.set_output_embeddings(new_embeddings)
+
+ def get_output_embeddings(self) -> nn.Module:
+ return self.language_model.get_output_embeddings()
+
+ def get_encoder(self):
+ return self.language_model.get_encoder()
+
+ def get_decoder(self):
+ return self.language_model.get_decoder()
+
+ def _tie_weights(self):
+ if not self.config.use_decoder_only_language_model:
+ self.language_model.encoder.embed_tokens = self.language_model.shared
+ self.language_model.decoder.embed_tokens = self.language_model.shared
+
+ def _preprocess_accelerate(self):
+ r"""
+ Some pre-processing hacks to make the model `accelerate` compatible. Check
+ https://github.com/huggingface/transformers/pull/21707 for more details.
+ """
+ hf_device_map = self.hf_device_map
+
+ if len(hf_device_map) > 1 and "language_model" not in hf_device_map and torch.cuda.device_count() > 1:
+ # warn users about unexpected behavior when using multi-GPU + InstructBlipVideo + `accelerate`.
+ logger.warning(
+ "The `language_model` is not in the `hf_device_map` dictionary and you are running your script"
+ " in a multi-GPU environment. this may lead to unexpected behavior when using `accelerate`."
+ " Please pass a `device_map` that contains `language_model` to remove this warning."
+ " Please refer to https://github.com/huggingface/blog/blob/main/accelerate-large-models.md for"
+ " more details on creating a `device_map` for large models.",
+ )
+
+ if hasattr(self.language_model, "_hf_hook"):
+ self.language_model._hf_hook.io_same_device = True # For `generate` compatibility
+
+ def get_image_features(
+ self,
+ pixel_values: torch.FloatTensor,
+ qformer_input_ids: torch.LongTensor,
+ qformer_attention_mask: Optional[torch.LongTensor] = None,
+ interpolate_pos_encoding: Optional[bool] = False,
+ return_dict: Optional[bool] = False,
+ ):
+ """
+ Not used by this model: InstructBlipVideo only supports videos, so image features are never encoded
+ separately. Use [`~InstructBlipVideoForConditionalGeneration.get_video_features`] instead.
+ """
+ pass
+
+ def get_placeholder_mask(self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor):
+ """
+ Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`.
+ """
+ if input_ids is None:
+ special_image_mask = inputs_embeds == self.get_input_embeddings()(
+ torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ special_image_mask = special_image_mask.all(-1)
+ else:
+ special_image_mask = input_ids == self.config.video_token_id
+
+ special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+ return special_image_mask
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ pixel_values: torch.FloatTensor,
+ qformer_input_ids: torch.FloatTensor,
+ qformer_attention_mask: Optional[torch.LongTensor] = None,
+ input_ids: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.LongTensor] = None,
+ decoder_input_ids: Optional[torch.LongTensor] = None,
+ decoder_attention_mask: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ labels: Optional[torch.LongTensor] = None,
+ return_dict: Optional[bool] = None,
+ interpolate_pos_encoding: bool = False,
+ use_cache: Optional[bool] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> Union[tuple, InstructBlipVideoForConditionalGenerationModelOutput]:
+ r"""
+ qformer_input_ids (`torch.LongTensor` of shape (batch_size, sequence_length)):
+ The sequence used as a prompt to be fed to the Q-Former module.
+ qformer_attention_mask (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
+ Mask to avoid performing attention on padding token indices.
+
+ Examples:
+
+ ```python
+ >>> from transformers import InstructBlipVideoProcessor, InstructBlipVideoForConditionalGeneration
+ >>> import torch
+ >>> from huggingface_hub import hf_hub_download
+ >>> import av
+ >>> import numpy as np
+
+ >>> def read_video_pyav(container, indices):
+ ... '''
+ ... Decode the video with PyAV decoder.
+ ... Args:
+ ... container (`av.container.input.InputContainer`): PyAV container.
+ ... indices (`list[int]`): List of frame indices to decode.
+ ... Returns:
+ ... result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
+ ... '''
+ ... frames = []
+ ... container.seek(0)
+ ... start_index = indices[0]
+ ... end_index = indices[-1]
+ ... for i, frame in enumerate(container.decode(video=0)):
+ ... if i > end_index:
+ ... break
+ ... if i >= start_index and i in indices:
+ ... frames.append(frame)
+ ... return np.stack([x.to_ndarray(format="rgb24") for x in frames])
+
+ >>> model = InstructBlipVideoForConditionalGeneration.from_pretrained("Salesforce/instructblip-vicuna-7b", device_map="auto")
+ >>> processor = InstructBlipVideoProcessor.from_pretrained("Salesforce/instructblip-vicuna-7b")
+
+ >>> file_path = hf_hub_download(
+ ... repo_id="nielsr/video-demo", filename="eating_spaghetti.mp4", repo_type="dataset"
+ ... )
+ >>> container = av.open(file_path)
+
+ >>> # sample uniformly 4 frames from the video
+ >>> total_frames = container.streams.video[0].frames
+ >>> indices = np.arange(0, total_frames, total_frames / 4).astype(int)
+ >>> clip = read_video_pyav(container, indices)
+
+ >>> prompt = "What is happening in the video?"
+ >>> inputs = processor(text=prompt, images=clip, return_tensors="pt").to(model.device)
+
+ >>> outputs = model.generate(
+ ... **inputs,
+ ... do_sample=False,
+ ... num_beams=5,
+ ... max_length=256,
+ ... repetition_penalty=1.5,
+ ... length_penalty=1.0,
+ ... )
+ >>> generated_text = processor.batch_decode(outputs, skip_special_tokens=True)[0].strip()
+ >>> print(generated_text)
+ "A person is eating a bowl of pasta, and they are using a fork to eat it. The person is sitting at a table, and the plate of pasta is on the table in front"
+ ```"""
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ language_model_inputs, vision_outputs, query_outputs = self.get_video_features(
+ pixel_values,
+ qformer_input_ids=qformer_input_ids,
+ qformer_attention_mask=qformer_attention_mask,
+ interpolate_pos_encoding=interpolate_pos_encoding,
+ return_dict=True,
+ )
+ vision_outputs = vision_outputs.to_tuple() if not return_dict else vision_outputs
+ query_outputs = query_outputs.to_tuple() if not return_dict else query_outputs
+
+ if inputs_embeds is None:
+ inputs_embeds = self.get_input_embeddings()(input_ids)
+
+ if attention_mask is None:
+ attention_mask = torch.ones_like(input_ids)
+
+ language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
+ special_image_mask = self.get_placeholder_mask(input_ids, inputs_embeds=inputs_embeds)
+ inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
+
+ if self.config.use_decoder_only_language_model:
+ outputs = self.language_model(
+ inputs_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ use_cache=use_cache,
+ **kwargs,
+ )
+ logits = outputs.logits if return_dict else outputs[0]
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(
+ logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
+ )
+
+ else:
+ outputs = self.language_model(
+ inputs_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ decoder_input_ids=decoder_input_ids,
+ decoder_attention_mask=decoder_attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ labels=labels,
+ use_cache=use_cache,
+ **kwargs,
+ )
+ loss = outputs.loss if return_dict else outputs[0]
+ logits = outputs.logits if return_dict else outputs[1]
+
+ return InstructBlipVideoForConditionalGenerationModelOutput(
+ loss=loss,
+ logits=logits,
+ vision_outputs=vision_outputs,
+ qformer_outputs=query_outputs,
+ language_model_outputs=outputs,
+ )
+
+ @torch.no_grad()
+ def generate(
+ self,
+ pixel_values: torch.FloatTensor,
+ qformer_input_ids: Optional[torch.LongTensor] = None,
+ qformer_attention_mask: Optional[torch.LongTensor] = None,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ interpolate_pos_encoding: bool = False,
+ **generate_kwargs,
+ ) -> torch.LongTensor:
+ r"""
+ Overrides `generate` function to be able to use the model as a conditional generator.
+
+ Args:
+ pixel_values (`torch.FloatTensor` of shape (batch_size, num_channels, height, width) or
+ (batch_size, num_frames, num_channels, height, width)): Input images or videos to be processed.
+ qformer_input_ids (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
+ The sequence used as a prompt to be fed to the Q-Former module.
+ qformer_attention_mask (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
+ Mask to avoid performing attention on padding token indices.
+ input_ids (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
+ The sequence used as a prompt for the generation.
+ attention_mask (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
+ Mask to avoid performing attention on padding token indices.
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Embedded representation of the inputs. Should be float, not int tokens.
+ interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
+ Whether to interpolate the positional encoding of the image embeddings.
+
+ Returns:
+ `torch.LongTensor`: The generated token ids, which can be decoded into captions with the processor.
+ """
+ if hasattr(self, "hf_device_map"):
+ # preprocess for `accelerate`
+ self._preprocess_accelerate()
+
+ batch_size = pixel_values.shape[0]
+ language_model_inputs, vision_outputs, query_outputs = self.get_video_features(
+ pixel_values,
+ qformer_input_ids=qformer_input_ids,
+ qformer_attention_mask=qformer_attention_mask,
+ interpolate_pos_encoding=interpolate_pos_encoding,
+ return_dict=True,
+ )
+
+ if inputs_embeds is None:
+ if input_ids is None:
+ video_tokens = [self.config.video_token_index] * self.config.num_query_tokens * 4
+ start_tokens = video_tokens + [self.config.text_config.bos_token_id]
+ input_ids = torch.tensor([start_tokens], dtype=torch.long, device=pixel_values.device)
+ input_ids = input_ids.repeat(batch_size, 1)
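+ # Illustrative note (assumption): with the default 32 query tokens, this fallback prompt is
+ # 32 * 4 = 128 video placeholder tokens followed by BOS, i.e. a 129-token sequence whose video
+ # slots are later overwritten by the projected Q-Former outputs via `masked_scatter`.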
+ inputs_embeds = self.get_input_embeddings()(input_ids)
+
+ if attention_mask is None:
+ attention_mask = torch.ones_like(input_ids)
+
+ language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
+ special_image_mask = self.get_placeholder_mask(input_ids, inputs_embeds=inputs_embeds)
+ inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
+
+ inputs = {"inputs_embeds": inputs_embeds, "attention_mask": attention_mask}
+ if not self.language_model.config.is_encoder_decoder:
+ inputs["input_ids"] = input_ids
+
+ outputs = self.language_model.generate(**inputs, **generate_kwargs)
+
+ return outputs
+
+ def get_video_features(
+ self,
+ pixel_values: torch.FloatTensor,
+ qformer_input_ids: torch.LongTensor,
+ qformer_attention_mask: Optional[torch.LongTensor] = None,
+ interpolate_pos_encoding: Optional[bool] = False,
+ return_dict: Optional[bool] = False,
+ ):
+ """
+ Encodes images into continuous embeddings that can be forwarded to the language model.
+
+ Args:
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
+ The tensors corresponding to the input images.
+ """
+ # step 1: forward the images through the vision encoder,
+ # we process in a batched way, later unbatch it back (video has frames=4 always)
+ batch_size, frames, channel, height, width = pixel_values.shape
+ pixel_values = pixel_values.reshape(batch_size * frames, channel, height, width)
+
+ vision_outputs = self.vision_model(
+ pixel_values=pixel_values,
+ interpolate_pos_encoding=interpolate_pos_encoding,
+ return_dict=True,
+ )
+ image_embeds = vision_outputs[0]
+
+ # step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
+ image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
+
+ # difference with BLIP-2 here: we also feed the instruction prompt to the Q-Former
+ query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
+ query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device)
+
+ if qformer_attention_mask is None:
+ qformer_attention_mask = torch.ones_like(qformer_input_ids)
+
+ qformer_input_ids = qformer_input_ids.repeat_interleave(frames, dim=0)
+ qformer_attention_mask = qformer_attention_mask.repeat_interleave(frames, dim=0)
+ qformer_attention_mask = torch.cat([query_attention_mask, qformer_attention_mask], dim=1)
+ query_outputs = self.qformer(
+ input_ids=qformer_input_ids,
+ attention_mask=qformer_attention_mask,
+ query_embeds=query_tokens,
+ encoder_hidden_states=image_embeds,
+ encoder_attention_mask=image_attention_mask,
+ return_dict=True,
+ )
+ query_output = query_outputs[0][:, : query_tokens.size(1), :]
+
+ # step 3: use the language model, conditioned on the query outputs and the prompt
+ language_model_inputs = self.language_projection(query_output)
+
+ # unbatch inputs back, each video-frame gets `num_query_tokens` seq length
+ language_model_inputs = language_model_inputs.reshape(batch_size, self.config.num_query_tokens * frames, -1)
+ if return_dict:
+ return language_model_inputs, vision_outputs, query_outputs
+ return language_model_inputs
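+
+ # Usage sketch (hypothetical shapes, they depend on the checkpoint): for pixel_values of shape
+ # (2, 4, 3, 224, 224) and a tokenized instruction of length 8,
+ #   feats = self.get_video_features(pixel_values, qformer_input_ids)
+ # returns a tensor of shape (2, num_query_tokens * 4, text_hidden); with return_dict=True the vision
+ # and Q-Former outputs are returned alongside the projected features.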
+
+
+__all__ = [
+ "InstructBlipVideoVisionModel",
+ "InstructBlipVideoPreTrainedModel",
+ "InstructBlipVideoQFormerModel",
+ "InstructBlipVideoModel",
+ "InstructBlipVideoForConditionalGeneration",
+]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/instructblipvideo/modular_instructblipvideo.py b/venv/lib/python3.13/site-packages/transformers/models/instructblipvideo/modular_instructblipvideo.py
new file mode 100644
index 0000000000000000000000000000000000000000..5619c2e79b9a410edf443584c0a0f92eaf4de565
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/instructblipvideo/modular_instructblipvideo.py
@@ -0,0 +1,613 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Union
+
+import torch
+
+from transformers.models.instructblip.configuration_instructblip import (
+ InstructBlipQFormerConfig,
+ InstructBlipVisionConfig,
+)
+from transformers.models.instructblip.modeling_instructblip import (
+ InstructBlipForConditionalGeneration,
+ InstructBlipForConditionalGenerationModelOutput,
+ InstructBlipModel,
+ InstructBlipPreTrainedModel,
+ InstructBlipQFormerModel,
+ InstructBlipVisionModel,
+ TransformersKwargs,
+)
+
+from ...configuration_utils import PretrainedConfig
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
+from ...processing_utils import Unpack
+from ...utils import logging
+from ..auto import CONFIG_MAPPING, AutoConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+class InstructBlipVideoVisionConfig(InstructBlipVisionConfig):
+ pass
+
+
+class InstructBlipVideoQFormerConfig(InstructBlipQFormerConfig):
+ pass
+
+
+class InstructBlipVideoConfig(PretrainedConfig):
+ r"""
+ [`InstructBlipVideoConfig`] is the configuration class to store the configuration of a
+ [`InstructBlipVideoForConditionalGeneration`]. It is used to instantiate an InstructBlipVideo model according to the specified
+ arguments, defining the vision model, Q-Former model and language model configs. Instantiating a configuration with
+ the defaults will yield a similar configuration to that of the InstructBlipVideo
+ [Salesforce/instruct-blip-flan-t5](https://huggingface.co/Salesforce/instruct-blip-flan-t5) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ vision_config (`dict`, *optional*):
+ Dictionary of configuration options used to initialize [`InstructBlipVideoVisionConfig`].
+ qformer_config (`dict`, *optional*):
+ Dictionary of configuration options used to initialize [`InstructBlipVideoQFormerConfig`].
+ text_config (`dict`, *optional*):
+ Dictionary of configuration options used to initialize any [`PretrainedConfig`].
+ num_query_tokens (`int`, *optional*, defaults to 32):
+ The number of query tokens passed through the Transformer.
+ video_token_index (`int`, *optional*):
+ Token index of the special video token.
+ kwargs (*optional*):
+ Dictionary of keyword arguments.
+
+ Example:
+
+ ```python
+ >>> from transformers import (
+ ... InstructBlipVideoVisionConfig,
+ ... InstructBlipVideoQFormerConfig,
+ ... OPTConfig,
+ ... InstructBlipVideoConfig,
+ ... InstructBlipVideoForConditionalGeneration,
+ ... )
+
+ >>> # Initializing an InstructBlipVideoConfig with Salesforce/instruct-blip-flan-t5 style configuration
+ >>> configuration = InstructBlipVideoConfig()
+
+ >>> # Initializing an InstructBlipVideoForConditionalGeneration (with random weights) from the Salesforce/instruct-blip-flan-t5 style configuration
+ >>> model = InstructBlipVideoForConditionalGeneration(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+
+ >>> # We can also initialize an InstructBlipVideoConfig from an InstructBlipVideoVisionConfig, an InstructBlipVideoQFormerConfig and any PretrainedConfig
+
+ >>> # Initializing InstructBlipVideo vision, InstructBlipVideo Q-Former and language model configurations
+ >>> vision_config = InstructBlipVideoVisionConfig()
+ >>> qformer_config = InstructBlipVideoQFormerConfig()
+ >>> text_config = OPTConfig()
+
+ >>> config = InstructBlipVideoConfig.from_vision_qformer_text_configs(vision_config, qformer_config, text_config)
+ ```"""
+
+ model_type = "instructblipvideo"
+ attribute_map = {
+ "video_token_id": "video_token_index",
+ }
+ sub_configs = {
+ "text_config": AutoConfig,
+ "qformer_config": InstructBlipVideoQFormerConfig,
+ "vision_config": InstructBlipVideoVisionConfig,
+ }
+
+ def __init__(
+ self,
+ vision_config=None,
+ qformer_config=None,
+ text_config=None,
+ num_query_tokens=32,
+ video_token_index=None,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ if vision_config is None:
+ vision_config = {}
+ logger.info("vision_config is None. initializing the InstructBlipVideoVisionConfig with default values.")
+
+ if qformer_config is None:
+ qformer_config = {}
+ logger.info("qformer_config is None. Initializing the InstructBlipVideoQFormerConfig with default values.")
+
+ if text_config is None:
+ text_config = {}
+ logger.info("text_config is None. Initializing the text config with default values (`OPTConfig`).")
+
+ self.vision_config = InstructBlipVideoVisionConfig(**vision_config)
+ self.qformer_config = InstructBlipVideoQFormerConfig(**qformer_config)
+ text_model_type = text_config.get("model_type", "opt")
+ self.text_config = CONFIG_MAPPING[text_model_type](**text_config)
+
+ self.num_query_tokens = num_query_tokens
+ self.video_token_index = video_token_index
+ self.qformer_config.encoder_hidden_size = self.vision_config.hidden_size
+ self.use_decoder_only_language_model = self.text_config.model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
+ self.initializer_factor = 1.0
+ self.initializer_range = 0.02
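+
+ # Illustrative note (assumption): the text model type drives the generation path, e.g.
+ #   InstructBlipVideoConfig(text_config={"model_type": "opt"})  -> use_decoder_only_language_model=True
+ #   InstructBlipVideoConfig(text_config={"model_type": "t5"})   -> use_decoder_only_language_model=False
+ # because only causal-LM model types appear in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.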
+
+ @classmethod
+ def from_vision_qformer_text_configs(
+ cls,
+ vision_config: InstructBlipVideoVisionConfig,
+ qformer_config: InstructBlipVideoQFormerConfig,
+ text_config: PretrainedConfig,
+ **kwargs,
+ ):
+ r"""
+ Instantiate a [`InstructBlipVideoConfig`] (or a derived class) from a InstructBlipVideo vision model, Q-Former and
+ language model configurations.
+
+ Returns:
+ [`InstructBlipVideoConfig`]: An instance of a configuration object
+ """
+
+ return cls(
+ vision_config=vision_config.to_dict(),
+ qformer_config=qformer_config.to_dict(),
+ text_config=text_config.to_dict(),
+ **kwargs,
+ )
+
+
+class InstructBlipVideoPreTrainedModel(InstructBlipPreTrainedModel):
+ pass
+
+
+class InstructBlipVideoVisionModel(InstructBlipVisionModel):
+ pass
+
+
+class InstructBlipVideoQFormerModel(InstructBlipQFormerModel):
+ pass
+
+
+class InstructBlipVideoForConditionalGenerationModelOutput(InstructBlipForConditionalGenerationModelOutput):
+ pass
+
+
+class InstructBlipVideoModel(InstructBlipModel):
+ def forward(
+ self,
+ pixel_values: torch.FloatTensor,
+ qformer_input_ids: torch.FloatTensor,
+ qformer_attention_mask: Optional[torch.LongTensor] = None,
+ input_ids: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.LongTensor] = None,
+ decoder_input_ids: Optional[torch.LongTensor] = None,
+ decoder_attention_mask: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ interpolate_pos_encoding: bool = False,
+ use_cache: Optional[bool] = None,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> Union[tuple, InstructBlipVideoForConditionalGenerationModelOutput]:
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # step 1: forward the images through the vision encoder,
+ # we process in a batched way, later unbatch it back (video has frames=4 always)
+ batch_size, frames, channel, height, width = pixel_values.shape
+ pixel_values = pixel_values.reshape(batch_size * frames, channel, height, width)
+
+ vision_outputs = self.vision_model(
+ pixel_values=pixel_values,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ interpolate_pos_encoding=interpolate_pos_encoding,
+ )
+ image_embeds = vision_outputs[0]
+
+ # step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
+ image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
+
+ # difference with BLIP-2 here: we also feed the instruction prompt to the Q-Former
+ query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
+ query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device)
+
+ if qformer_attention_mask is None:
+ qformer_attention_mask = torch.ones_like(qformer_input_ids)
+
+ qformer_input_ids = qformer_input_ids.repeat_interleave(frames, dim=0)
+ qformer_attention_mask = qformer_attention_mask.repeat_interleave(frames, dim=0)
+ qformer_attention_mask = torch.cat([query_attention_mask, qformer_attention_mask], dim=1)
+ query_outputs = self.qformer(
+ input_ids=qformer_input_ids,
+ attention_mask=qformer_attention_mask,
+ query_embeds=query_tokens,
+ encoder_hidden_states=image_embeds,
+ encoder_attention_mask=image_attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ query_output = query_outputs[0][:, : query_tokens.size(1), :]
+
+ # step 3: use the language model, conditioned on the query outputs and the prompt
+ language_model_inputs = self.language_projection(query_output)
+
+ # unbatch inputs back, each video-frame gets `num_query_tokens` seq length
+ language_model_inputs = language_model_inputs.reshape(batch_size, self.config.num_query_tokens * frames, -1)
+ if inputs_embeds is None:
+ inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
+ special_image_mask = input_ids == self.config.video_token_id
+ if attention_mask is None:
+ attention_mask = torch.ones_like(input_ids)
+ else:
+ special_image_mask = inputs_embeds == self.get_input_embeddings()(
+ torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ special_image_mask = special_image_mask.all(-1)
+
+ special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+ language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
+ inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
+
+ if self.config.use_decoder_only_language_model:
+ outputs = self.language_model(
+ inputs_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ use_cache=use_cache,
+ **kwargs,
+ )
+ else:
+ outputs = self.language_model(
+ inputs_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ decoder_input_ids=decoder_input_ids,
+ decoder_attention_mask=decoder_attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ use_cache=use_cache,
+ **kwargs,
+ )
+
+ return InstructBlipVideoForConditionalGenerationModelOutput(
+ vision_outputs=vision_outputs,
+ qformer_outputs=query_outputs,
+ language_model_outputs=outputs,
+ )
+
+
+class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGeneration):
+ def get_video_features(
+ self,
+ pixel_values: torch.FloatTensor,
+ qformer_input_ids: torch.LongTensor,
+ qformer_attention_mask: Optional[torch.LongTensor] = None,
+ interpolate_pos_encoding: Optional[bool] = False,
+ return_dict: Optional[bool] = False,
+ ):
+ """
+ Encodes images into continuous embeddings that can be forwarded to the language model.
+
+ Args:
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
+ The tensors corresponding to the input images.
+ """
+ # step 1: forward the images through the vision encoder,
+ # we process in a batched way, later unbatch it back (video has frames=4 always)
+ batch_size, frames, channel, height, width = pixel_values.shape
+ pixel_values = pixel_values.reshape(batch_size * frames, channel, height, width)
+
+ vision_outputs = self.vision_model(
+ pixel_values=pixel_values,
+ interpolate_pos_encoding=interpolate_pos_encoding,
+ return_dict=True,
+ )
+ image_embeds = vision_outputs[0]
+
+ # step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
+ image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
+
+ # difference with BLIP-2 here: we also feed the instruction prompt to the Q-Former
+ query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
+ query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device)
+
+ if qformer_attention_mask is None:
+ qformer_attention_mask = torch.ones_like(qformer_input_ids)
+
+ qformer_input_ids = qformer_input_ids.repeat_interleave(frames, dim=0)
+ qformer_attention_mask = qformer_attention_mask.repeat_interleave(frames, dim=0)
+ qformer_attention_mask = torch.cat([query_attention_mask, qformer_attention_mask], dim=1)
+ query_outputs = self.qformer(
+ input_ids=qformer_input_ids,
+ attention_mask=qformer_attention_mask,
+ query_embeds=query_tokens,
+ encoder_hidden_states=image_embeds,
+ encoder_attention_mask=image_attention_mask,
+ return_dict=True,
+ )
+ query_output = query_outputs[0][:, : query_tokens.size(1), :]
+
+ # step 3: use the language model, conditioned on the query outputs and the prompt
+ language_model_inputs = self.language_projection(query_output)
+
+ # unbatch inputs back, each video-frame gets `num_query_tokens` seq length
+ language_model_inputs = language_model_inputs.reshape(batch_size, self.config.num_query_tokens * frames, -1)
+ if return_dict:
+ return language_model_inputs, vision_outputs, query_outputs
+ return language_model_inputs
+
+ # Model supports only videos
+ def get_image_features(
+ self,
+ pixel_values: torch.FloatTensor,
+ qformer_input_ids: torch.LongTensor,
+ qformer_attention_mask: Optional[torch.LongTensor] = None,
+ interpolate_pos_encoding: Optional[bool] = False,
+ return_dict: Optional[bool] = False,
+ ):
+ pass
+
+ def get_placeholder_mask(self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor):
+ """
+ Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`.
+ """
+ if input_ids is None:
+ special_image_mask = inputs_embeds == self.get_input_embeddings()(
+ torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ special_image_mask = special_image_mask.all(-1)
+ else:
+ special_image_mask = input_ids == self.config.video_token_id
+
+ special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+ return special_image_mask
+
+ def forward(
+ self,
+ pixel_values: torch.FloatTensor,
+ qformer_input_ids: torch.FloatTensor,
+ qformer_attention_mask: Optional[torch.LongTensor] = None,
+ input_ids: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.LongTensor] = None,
+ decoder_input_ids: Optional[torch.LongTensor] = None,
+ decoder_attention_mask: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ labels: Optional[torch.LongTensor] = None,
+ return_dict: Optional[bool] = None,
+ interpolate_pos_encoding: bool = False,
+ use_cache: Optional[bool] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> Union[tuple, InstructBlipVideoForConditionalGenerationModelOutput]:
+ r"""
+ qformer_input_ids (`torch.LongTensor` of shape (batch_size, sequence_length)):
+ The sequence used as a prompt to be fed to the Q-Former module.
+ qformer_attention_mask (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
+ Mask to avoid performing attention on padding token indices.
+
+ Examples:
+
+ ```python
+ >>> from transformers import InstructBlipVideoProcessor, InstructBlipVideoForConditionalGeneration
+ >>> import torch
+ >>> from huggingface_hub import hf_hub_download
+ >>> import av
+ >>> import numpy as np
+
+ >>> def read_video_pyav(container, indices):
+ ... '''
+ ... Decode the video with PyAV decoder.
+ ... Args:
+ ... container (`av.container.input.InputContainer`): PyAV container.
+ ... indices (`list[int]`): List of frame indices to decode.
+ ... Returns:
+ ... result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
+ ... '''
+ ... frames = []
+ ... container.seek(0)
+ ... start_index = indices[0]
+ ... end_index = indices[-1]
+ ... for i, frame in enumerate(container.decode(video=0)):
+ ... if i > end_index:
+ ... break
+ ... if i >= start_index and i in indices:
+ ... frames.append(frame)
+ ... return np.stack([x.to_ndarray(format="rgb24") for x in frames])
+
+ >>> model = InstructBlipVideoForConditionalGeneration.from_pretrained("Salesforce/instructblip-vicuna-7b", device_map="auto")
+ >>> processor = InstructBlipVideoProcessor.from_pretrained("Salesforce/instructblip-vicuna-7b")
+
+ >>> file_path = hf_hub_download(
+ ... repo_id="nielsr/video-demo", filename="eating_spaghetti.mp4", repo_type="dataset"
+ ... )
+ >>> container = av.open(file_path)
+
+ >>> # sample uniformly 4 frames from the video
+ >>> total_frames = container.streams.video[0].frames
+ >>> indices = np.arange(0, total_frames, total_frames / 4).astype(int)
+ >>> clip = read_video_pyav(container, indices)
+
+ >>> prompt = "What is happening in the video?"
+ >>> inputs = processor(text=prompt, images=clip, return_tensors="pt").to(model.device)
+
+ >>> outputs = model.generate(
+ ... **inputs,
+ ... do_sample=False,
+ ... num_beams=5,
+ ... max_length=256,
+ ... repetition_penalty=1.5,
+ ... length_penalty=1.0,
+ ... )
+ >>> generated_text = processor.batch_decode(outputs, skip_special_tokens=True)[0].strip()
+ >>> print(generated_text)
+ "A person is eating a bowl of pasta, and they are using a fork to eat it. The person is sitting at a table, and the plate of pasta is on the table in front"
+ ```"""
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ language_model_inputs, vision_outputs, query_outputs = self.get_video_features(
+ pixel_values,
+ qformer_input_ids=qformer_input_ids,
+ qformer_attention_mask=qformer_attention_mask,
+ interpolate_pos_encoding=interpolate_pos_encoding,
+ return_dict=True,
+ )
+ vision_outputs = vision_outputs.to_tuple() if not return_dict else vision_outputs
+ query_outputs = query_outputs.to_tuple() if not return_dict else query_outputs
+
+ if inputs_embeds is None:
+ inputs_embeds = self.get_input_embeddings()(input_ids)
+
+ if attention_mask is None:
+ attention_mask = torch.ones_like(input_ids)
+
+ language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
+ special_image_mask = self.get_placeholder_mask(input_ids, inputs_embeds=inputs_embeds)
+ inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
+
+ if self.config.use_decoder_only_language_model:
+ outputs = self.language_model(
+ inputs_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ use_cache=use_cache,
+ **kwargs,
+ )
+ logits = outputs.logits if return_dict else outputs[0]
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(
+ logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
+ )
+
+ else:
+ outputs = self.language_model(
+ inputs_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ decoder_input_ids=decoder_input_ids,
+ decoder_attention_mask=decoder_attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ labels=labels,
+ use_cache=use_cache,
+ **kwargs,
+ )
+ loss = outputs.loss if return_dict else outputs[0]
+ logits = outputs.logits if return_dict else outputs[1]
+
+ return InstructBlipVideoForConditionalGenerationModelOutput(
+ loss=loss,
+ logits=logits,
+ vision_outputs=vision_outputs,
+ qformer_outputs=query_outputs,
+ language_model_outputs=outputs,
+ )
+
+ @torch.no_grad()
+ def generate(
+ self,
+ pixel_values: torch.FloatTensor,
+ qformer_input_ids: Optional[torch.LongTensor] = None,
+ qformer_attention_mask: Optional[torch.LongTensor] = None,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ interpolate_pos_encoding: bool = False,
+ **generate_kwargs,
+ ) -> torch.LongTensor:
+ r"""
+ Overrides `generate` function to be able to use the model as a conditional generator.
+
+ Args:
+ pixel_values (`torch.FloatTensor` of shape (batch_size, num_channels, height, width) or
+ (batch_size, num_frames, num_channels, height, width)): Input images or videos to be processed.
+ qformer_input_ids (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
+ The sequence used as a prompt to be fed to the Q-Former module.
+ qformer_attention_mask (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
+ Mask to avoid performing attention on padding token indices.
+ input_ids (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
+ The sequence used as a prompt for the generation.
+ attention_mask (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
+ Mask to avoid performing attention on padding token indices.
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Embedded representation of the inputs. Should be float, not int tokens.
+ interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
+ Whether to interpolate the positional encoding of the image embeddings.
+
+ Returns:
+ `torch.LongTensor`: The generated token ids, which can be decoded into captions with the processor.
+ """
+ if hasattr(self, "hf_device_map"):
+ # preprocess for `accelerate`
+ self._preprocess_accelerate()
+
+ batch_size = pixel_values.shape[0]
+ language_model_inputs, vision_outputs, query_outputs = self.get_video_features(
+ pixel_values,
+ qformer_input_ids=qformer_input_ids,
+ qformer_attention_mask=qformer_attention_mask,
+ interpolate_pos_encoding=interpolate_pos_encoding,
+ return_dict=True,
+ )
+
+ if inputs_embeds is None:
+ if input_ids is None:
+ video_tokens = [self.config.video_token_index] * self.config.num_query_tokens * 4
+ start_tokens = video_tokens + [self.config.text_config.bos_token_id]
+ input_ids = torch.tensor([start_tokens], dtype=torch.long, device=pixel_values.device)
+ input_ids = input_ids.repeat(batch_size, 1)
+ inputs_embeds = self.get_input_embeddings()(input_ids)
+
+ if attention_mask is None:
+ attention_mask = torch.ones_like(input_ids)
+
+ language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
+ special_image_mask = self.get_placeholder_mask(input_ids, inputs_embeds=inputs_embeds)
+ inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
+
+ inputs = {"inputs_embeds": inputs_embeds, "attention_mask": attention_mask}
+ if not self.language_model.config.is_encoder_decoder:
+ inputs["input_ids"] = input_ids
+
+ outputs = self.language_model.generate(**inputs, **generate_kwargs)
+
+ return outputs
+
+
+__all__ = [
+ "InstructBlipVideoConfig",
+ "InstructBlipVideoQFormerConfig",
+ "InstructBlipVideoVisionConfig",
+ "InstructBlipVideoVisionModel",
+ "InstructBlipVideoPreTrainedModel",
+ "InstructBlipVideoQFormerModel",
+ "InstructBlipVideoModel",
+ "InstructBlipVideoForConditionalGeneration",
+]
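+
+# Illustrative usage sketch (not part of the original file): the checkpoint name, frame count and
+# image size below are assumptions chosen only for demonstration; actual values depend on the
+# pretrained model that is loaded.
+#
+#   import torch
+#   from transformers import InstructBlipVideoProcessor, InstructBlipVideoForConditionalGeneration
+#
+#   name = "Salesforce/instructblip-vicuna-7b"  # hypothetical checkpoint
+#   processor = InstructBlipVideoProcessor.from_pretrained(name)
+#   model = InstructBlipVideoForConditionalGeneration.from_pretrained(name)
+#   pixel_values = torch.randn(1, 4, 3, 224, 224)  # (batch, num_frames, channels, height, width)
+#   qformer_ids = processor.qformer_tokenizer("What is happening in the video?", return_tensors="pt").input_ids
+#   generated = model.generate(pixel_values=pixel_values, qformer_input_ids=qformer_ids, max_new_tokens=20)
+#   print(processor.tokenizer.batch_decode(generated, skip_special_tokens=True))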
diff --git a/venv/lib/python3.13/site-packages/transformers/models/instructblipvideo/processing_instructblipvideo.py b/venv/lib/python3.13/site-packages/transformers/models/instructblipvideo/processing_instructblipvideo.py
new file mode 100644
index 0000000000000000000000000000000000000000..ee4e843e2f330ad969c84144b9223cb053ad7ec4
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/instructblipvideo/processing_instructblipvideo.py
@@ -0,0 +1,215 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Processor class for InstructBLIPVideo. Largely a copy of Blip2Processor, with the addition of a tokenizer for the Q-Former.
+"""
+
+import os
+from typing import Optional, Union
+
+from ...image_processing_utils import BatchFeature
+from ...processing_utils import ProcessorMixin
+from ...tokenization_utils_base import (
+ AddedToken,
+ PaddingStrategy,
+ PreTokenizedInput,
+ TextInput,
+ TruncationStrategy,
+)
+from ...utils import TensorType, logging
+from ...video_utils import VideoInput
+from ..auto import AutoTokenizer
+
+
+logger = logging.get_logger(__name__)
+
+
+class InstructBlipVideoProcessor(ProcessorMixin):
+ r"""
+ Constructs an InstructBLIPVideo processor which wraps an InstructBLIPVideo video processor and a LLaMa/T5 tokenizer into a
+ single processor.
+
+ [`InstructBlipVideoProcessor`] offers all the functionalities of [`InstructBlipVideoImageProcessor`] and [`AutoTokenizer`]. See the
+ docstring of [`~InstructBlipVideoProcessor.__call__`] and [`~InstructBlipVideoProcessor.decode`] for more information.
+
+ Args:
+ video_processor (`InstructBlipVideoVideoProcessor`):
+ An instance of [`InstructBlipVideoVideoProcessor`]. The video processor is a required input.
+ tokenizer (`AutoTokenizer`):
+ An instance of [`PreTrainedTokenizer`]. The tokenizer is a required input.
+ qformer_tokenizer (`AutoTokenizer`):
+ An instance of [`PreTrainedTokenizer`]. The Q-Former tokenizer is a required input.
+ num_query_tokens (`int`, *optional*):
+ Number of tokens used by the Q-Former as queries; should be the same as in the model's config.
+ """
+
+ attributes = ["video_processor", "tokenizer", "qformer_tokenizer"]
+ video_processor_class = "AutoVideoProcessor"
+ tokenizer_class = "AutoTokenizer"
+ qformer_tokenizer_class = "AutoTokenizer"
+
+ def __init__(self, video_processor, tokenizer, qformer_tokenizer, num_query_tokens=None, **kwargs):
+ if not hasattr(tokenizer, "video_token"):
+ self.video_token = AddedToken("<video>", normalized=False, special=True)
+ sep_token (`str`, *optional*, defaults to `"</s>"`):
+ The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
+ sequence classification or for a text and a question for question answering. It is also used as the last
+ token of a sequence built with special tokens.
+ cls_token (`str`, *optional*, defaults to `"<s>"`):
+ The classifier token which is used when doing sequence classification (classification of the whole sequence
+ instead of per-token classification). It is the first token of the sequence when built with special tokens.
+ unk_token (`str`, *optional*, defaults to `"<unk>"`):
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+ token instead.
+ pad_token (`str`, *optional*, defaults to `"<pad>"`):
+ The token used for padding, for example when batching sequences of different lengths.
+ mask_token (`str`, *optional*, defaults to `"<mask>"`):
+ The token used for masking values. This is the token used when training this model with masked language
+ modeling. This is the token which the model will try to predict.
+ cls_token_box (`list[int]`, *optional*, defaults to `[0, 0, 0, 0]`):
+ The bounding box to use for the special [CLS] token.
+ sep_token_box (`list[int]`, *optional*, defaults to `[1000, 1000, 1000, 1000]`):
+ The bounding box to use for the special [SEP] token.
+ pad_token_box (`list[int]`, *optional*, defaults to `[0, 0, 0, 0]`):
+ The bounding box to use for the special [PAD] token.
+ pad_token_label (`int`, *optional*, defaults to -100):
+ The label to use for padding tokens. Defaults to -100, which is the `ignore_index` of PyTorch's
+ CrossEntropyLoss.
+ only_label_first_subword (`bool`, *optional*, defaults to `True`):
+ Whether or not to only label the first subword, in case word labels are provided.
+ sp_model_kwargs (`dict`, *optional*):
+ Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
+ SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
+ to set:
+
+ - `enable_sampling`: Enable subword regularization.
+ - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.
+
+ - `nbest_size = {0,1}`: No sampling is performed.
+ - `nbest_size > 1`: samples from the nbest_size results.
+ - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice)
+ using forward-filtering-and-backward-sampling algorithm.
+
+ - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
+ BPE-dropout.
+
+ Attributes:
+ sp_model (`SentencePieceProcessor`):
+ The *SentencePiece* processor that is used for every conversion (string, tokens and IDs).
+ """
+
+ vocab_files_names = VOCAB_FILES_NAMES
+ model_input_names = ["input_ids", "attention_mask"]
+
+ def __init__(
+ self,
+ vocab_file,
+ bos_token="<s>",
+ eos_token="</s>",
+ sep_token="</s>",
+ cls_token="<s>",
+ unk_token="<unk>",
+ pad_token="<pad>",
+ mask_token="<mask>",
+ cls_token_box=[0, 0, 0, 0],
+ sep_token_box=[1000, 1000, 1000, 1000],
+ pad_token_box=[0, 0, 0, 0],
+ pad_token_label=-100,
+ only_label_first_subword=True,
+ sp_model_kwargs: Optional[dict[str, Any]] = None,
+ **kwargs,
+ ) -> None:
+ # Mask token behave like a normal word, i.e. include the space before it
+ mask_token = AddedToken(mask_token, lstrip=True, special=True) if isinstance(mask_token, str) else mask_token
+
+ self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
+
+ self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+ self.sp_model.Load(str(vocab_file))
+ self.vocab_file = vocab_file
+
+ # Original fairseq vocab and spm vocab must be "aligned":
+ # Vocab | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9
+ # -------- | ------- | ------- | ------ | ------- | --- | --- | --- | ----- | ----- | ----
+ # fairseq | '<s>' | '<pad>' | '</s>' | '<unk>' | ',' | '.' | '▁' | 's' | '▁de' | '-'
+ # spm | '<unk>' | '<s>' | '</s>' | ',' | '.' | '▁' | 's' | '▁de' | '-' | '▁a'
+
+ # Mimic fairseq token-to-id alignment for the first 4 tokens
+ self.fairseq_tokens_to_ids = {"<s>": 0, "<pad>": 1, "</s>": 2, "<unk>": 3}
+
+ # The first "real" token "," has position 4 in the original fairseq vocab and position 3 in the spm vocab
+ self.fairseq_offset = 1
+
+ self.fairseq_tokens_to_ids["<mask>"] = len(self.sp_model) + self.fairseq_offset
+ self.fairseq_ids_to_tokens = {v: k for k, v in self.fairseq_tokens_to_ids.items()}
+
+ # additional properties
+ self.cls_token_box = cls_token_box
+ self.sep_token_box = sep_token_box
+ self.pad_token_box = pad_token_box
+ self.pad_token_label = pad_token_label
+ self.only_label_first_subword = only_label_first_subword
+
+ super().__init__(
+ bos_token=bos_token,
+ eos_token=eos_token,
+ unk_token=unk_token,
+ sep_token=sep_token,
+ cls_token=cls_token,
+ pad_token=pad_token,
+ mask_token=mask_token,
+ cls_token_box=cls_token_box,
+ sep_token_box=sep_token_box,
+ pad_token_box=pad_token_box,
+ pad_token_label=pad_token_label,
+ only_label_first_subword=only_label_first_subword,
+ sp_model_kwargs=self.sp_model_kwargs,
+ **kwargs,
+ )
+
+ def __getstate__(self):
+ state = self.__dict__.copy()
+ state["sp_model"] = None
+ state["sp_model_proto"] = self.sp_model.serialized_model_proto()
+ return state
+
+ def __setstate__(self, d):
+ self.__dict__.update(d)
+
+ # for backward compatibility
+ if not hasattr(self, "sp_model_kwargs"):
+ self.sp_model_kwargs = {}
+
+ self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+ self.sp_model.LoadFromSerializedProto(self.sp_model_proto)
+
+ def build_inputs_with_special_tokens(
+ self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
+ ) -> list[int]:
+ """
+ Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
+ adding special tokens. An XLM-RoBERTa sequence has the following format:
+
+ - single sequence: `<s> X </s>`
+ - pair of sequences: `<s> A </s></s> B </s>`
+
+ Args:
+ token_ids_0 (`list[int]`):
+ List of IDs to which the special tokens will be added.
+ token_ids_1 (`list[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+
+ Returns:
+ `list[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
+ """
+
+ if token_ids_1 is None:
+ return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
+ cls = [self.cls_token_id]
+ sep = [self.sep_token_id]
+ return cls + token_ids_0 + sep + sep + token_ids_1 + sep
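+ # For example, with the fairseq alignment used above (`<s>` = 0, `</s>` = 2):
+ # build_inputs_with_special_tokens([10, 11]) -> [0, 10, 11, 2]
+ # build_inputs_with_special_tokens([10, 11], [12, 13]) -> [0, 10, 11, 2, 2, 12, 13, 2]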
+
+ def get_special_tokens_mask(
+ self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None, already_has_special_tokens: bool = False
+ ) -> list[int]:
+ """
+ Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
+ special tokens using the tokenizer `prepare_for_model` method.
+
+ Args:
+ token_ids_0 (`list[int]`):
+ List of IDs.
+ token_ids_1 (`list[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+ already_has_special_tokens (`bool`, *optional*, defaults to `False`):
+ Whether or not the token list is already formatted with special tokens for the model.
+
+ Returns:
+ `list[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+ """
+
+ if already_has_special_tokens:
+ return super().get_special_tokens_mask(
+ token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
+ )
+
+ if token_ids_1 is None:
+ return [1] + ([0] * len(token_ids_0)) + [1]
+ return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1]
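+ # For example: get_special_tokens_mask([10, 11], [12, 13]) -> [1, 0, 0, 1, 1, 0, 0, 1]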
+
+ def create_token_type_ids_from_sequences(
+ self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
+ ) -> list[int]:
+ """
+ Create a mask from the two sequences passed to be used in a sequence-pair classification task. XLM-RoBERTa does
+ not make use of token type ids, therefore a list of zeros is returned.
+
+ Args:
+ token_ids_0 (`list[int]`):
+ List of IDs.
+ token_ids_1 (`list[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+
+ Returns:
+ `list[int]`: List of zeros.
+
+ """
+
+ sep = [self.sep_token_id]
+ cls = [self.cls_token_id]
+
+ if token_ids_1 is None:
+ return len(cls + token_ids_0 + sep) * [0]
+ return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]
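+ # For example: create_token_type_ids_from_sequences([10, 11], [12, 13]) -> [0, 0, 0, 0, 0, 0, 0, 0]
+ # (eight zeros, one per position in `<s> t0 t1 </s></s> t2 t3 </s>`)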
+
+ @property
+ def vocab_size(self):
+ return len(self.sp_model) + self.fairseq_offset + 1 # Add the <mask> token
+
+ def get_vocab(self):
+ vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
+ vocab.update(self.added_tokens_encoder)
+ return vocab
+
+ def _tokenize(self, text: str) -> list[str]:
+ return self.sp_model.encode(text, out_type=str)
+
+ def _convert_token_to_id(self, token):
+ """Converts a token (str) in an id using the vocab."""
+ if token in self.fairseq_tokens_to_ids:
+ return self.fairseq_tokens_to_ids[token]
+ spm_id = self.sp_model.PieceToId(token)
+
+ # Need to return unknown token if the SP model returned 0
+ return spm_id + self.fairseq_offset if spm_id else self.unk_token_id
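+ # For example, with the alignment above, "," has spm id 3 and is mapped to fairseq id 3 + 1 = 4,
+ # while an out-of-vocabulary piece (spm id 0) falls back to `unk_token_id`.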
+
+ def _convert_id_to_token(self, index):
+ """Converts an index (integer) in a token (str) using the vocab."""
+ if index in self.fairseq_ids_to_tokens:
+ return self.fairseq_ids_to_tokens[index]
+ return self.sp_model.IdToPiece(index - self.fairseq_offset)
+
+ def convert_tokens_to_string(self, tokens):
+ """Converts a sequence of tokens (strings for sub-words) in a single string."""
+ out_string = "".join(tokens).replace(SPIECE_UNDERLINE, " ").strip()
+ return out_string
+
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
+ if not os.path.isdir(save_directory):
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
+ return
+ out_vocab_file = os.path.join(
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+ )
+
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
+ copyfile(self.vocab_file, out_vocab_file)
+ elif not os.path.isfile(self.vocab_file):
+ with open(out_vocab_file, "wb") as fi:
+ content_spiece_model = self.sp_model.serialized_model_proto()
+ fi.write(content_spiece_model)
+
+ return (out_vocab_file,)
+
+ @add_end_docstrings(LAYOUTXLM_ENCODE_KWARGS_DOCSTRING)
+ def __call__(
+ self,
+ text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]],
+ text_pair: Optional[Union[PreTokenizedInput, list[PreTokenizedInput]]] = None,
+ boxes: Optional[Union[list[list[int]], list[list[list[int]]]]] = None,
+ word_labels: Optional[Union[list[int], list[list[int]]]] = None,
+ add_special_tokens: bool = True,
+ padding: Union[bool, str, PaddingStrategy] = False,
+ truncation: Union[bool, str, TruncationStrategy] = None,
+ max_length: Optional[int] = None,
+ stride: int = 0,
+ pad_to_multiple_of: Optional[int] = None,
+ padding_side: Optional[str] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ return_token_type_ids: Optional[bool] = None,
+ return_attention_mask: Optional[bool] = None,
+ return_overflowing_tokens: bool = False,
+ return_special_tokens_mask: bool = False,
+ return_offsets_mapping: bool = False,
+ return_length: bool = False,
+ verbose: bool = True,
+ **kwargs,
+ ) -> BatchEncoding:
+ """
+ Main method to tokenize and prepare for the model one or several sequence(s) or one or several pair(s) of
+ sequences with word-level normalized bounding boxes and optional labels.
+
+ Args:
+ text (`str`, `list[str]`, `list[list[str]]`):
+ The sequence or batch of sequences to be encoded. Each sequence can be a string, a list of strings
+ (words of a single example or questions of a batch of examples) or a list of list of strings (batch of
+ words).
+ text_pair (`list[str]`, `list[list[str]]`):
+ The sequence or batch of sequences to be encoded. Each sequence should be a list of strings
+ (pretokenized string).
+ boxes (`list[list[int]]`, `list[list[list[int]]]`):
+ Word-level bounding boxes. Each bounding box should be normalized to be on a 0-1000 scale.
+ word_labels (`list[int]`, `list[list[int]]`, *optional*):
+ Word-level integer labels (for token classification tasks such as FUNSD, CORD).
+ """
+
+ # Input type checking for clearer error
+ def _is_valid_text_input(t):
+ if isinstance(t, str):
+ # Strings are fine
+ return True
+ elif isinstance(t, (list, tuple)):
+ # List are fine as long as they are...
+ if len(t) == 0:
+ # ... empty
+ return True
+ elif isinstance(t[0], str):
+ # ... list of strings
+ return True
+ elif isinstance(t[0], (list, tuple)):
+ # ... list with an empty list or with a list of strings
+ return len(t[0]) == 0 or isinstance(t[0][0], str)
+ else:
+ return False
+ else:
+ return False
+
+ if text_pair is not None:
+ # in case text + text_pair are provided, text = questions, text_pair = words
+ if not _is_valid_text_input(text):
+ raise ValueError("text input must of type `str` (single example) or `list[str]` (batch of examples). ")
+ if not isinstance(text_pair, (list, tuple)):
+ raise ValueError(
+ "words must of type `list[str]` (single pretokenized example), "
+ "or `list[list[str]]` (batch of pretokenized examples)."
+ )
+ else:
+ # in case only text is provided => must be words
+ if not isinstance(text, (list, tuple)):
+ raise ValueError(
+ "Words must of type `list[str]` (single pretokenized example), "
+ "or `list[list[str]]` (batch of pretokenized examples)."
+ )
+
+ if text_pair is not None:
+ is_batched = isinstance(text, (list, tuple))
+ else:
+ is_batched = isinstance(text, (list, tuple)) and text and isinstance(text[0], (list, tuple))
+
+ words = text if text_pair is None else text_pair
+ if boxes is None:
+ raise ValueError("You must provide corresponding bounding boxes")
+ if is_batched:
+ if len(words) != len(boxes):
+ raise ValueError("You must provide words and boxes for an equal amount of examples")
+ for words_example, boxes_example in zip(words, boxes):
+ if len(words_example) != len(boxes_example):
+ raise ValueError("You must provide as many words as there are bounding boxes")
+ else:
+ if len(words) != len(boxes):
+ raise ValueError("You must provide as many words as there are bounding boxes")
+
+ if is_batched:
+ if text_pair is not None and len(text) != len(text_pair):
+ raise ValueError(
+ f"batch length of `text`: {len(text)} does not match batch length of `text_pair`:"
+ f" {len(text_pair)}."
+ )
+ batch_text_or_text_pairs = list(zip(text, text_pair)) if text_pair is not None else text
+ is_pair = bool(text_pair is not None)
+ return self.batch_encode_plus(
+ batch_text_or_text_pairs=batch_text_or_text_pairs,
+ is_pair=is_pair,
+ boxes=boxes,
+ word_labels=word_labels,
+ add_special_tokens=add_special_tokens,
+ padding=padding,
+ truncation=truncation,
+ max_length=max_length,
+ stride=stride,
+ pad_to_multiple_of=pad_to_multiple_of,
+ padding_side=padding_side,
+ return_tensors=return_tensors,
+ return_token_type_ids=return_token_type_ids,
+ return_attention_mask=return_attention_mask,
+ return_overflowing_tokens=return_overflowing_tokens,
+ return_special_tokens_mask=return_special_tokens_mask,
+ return_offsets_mapping=return_offsets_mapping,
+ return_length=return_length,
+ verbose=verbose,
+ **kwargs,
+ )
+ else:
+ return self.encode_plus(
+ text=text,
+ text_pair=text_pair,
+ boxes=boxes,
+ word_labels=word_labels,
+ add_special_tokens=add_special_tokens,
+ padding=padding,
+ truncation=truncation,
+ max_length=max_length,
+ stride=stride,
+ pad_to_multiple_of=pad_to_multiple_of,
+ padding_side=padding_side,
+ return_tensors=return_tensors,
+ return_token_type_ids=return_token_type_ids,
+ return_attention_mask=return_attention_mask,
+ return_overflowing_tokens=return_overflowing_tokens,
+ return_special_tokens_mask=return_special_tokens_mask,
+ return_offsets_mapping=return_offsets_mapping,
+ return_length=return_length,
+ verbose=verbose,
+ **kwargs,
+ )
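+ # Illustrative call (a sketch; the vocabulary file path is an assumption):
+ # tokenizer = LayoutXLMTokenizer("sentencepiece.bpe.model")
+ # encoding = tokenizer(["hello", "world"], boxes=[[1, 2, 3, 4], [5, 6, 7, 8]], word_labels=[1, 2])
+ # `encoding` then holds `input_ids`, `bbox`, `labels` and `attention_mask` for this single example.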
+
+ def _batch_encode_plus(
+ self,
+ batch_text_or_text_pairs: Union[
+ list[TextInput],
+ list[TextInputPair],
+ list[PreTokenizedInput],
+ ],
+ is_pair: Optional[bool] = None,
+ boxes: Optional[list[list[list[int]]]] = None,
+ word_labels: Optional[list[list[int]]] = None,
+ add_special_tokens: bool = True,
+ padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
+ truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
+ max_length: Optional[int] = None,
+ stride: int = 0,
+ pad_to_multiple_of: Optional[int] = None,
+ padding_side: Optional[str] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ return_token_type_ids: Optional[bool] = None,
+ return_attention_mask: Optional[bool] = None,
+ return_overflowing_tokens: bool = False,
+ return_special_tokens_mask: bool = False,
+ return_offsets_mapping: bool = False,
+ return_length: bool = False,
+ verbose: bool = True,
+ **kwargs,
+ ) -> BatchEncoding:
+ if return_offsets_mapping:
+ raise NotImplementedError(
+ "return_offset_mapping is not available when using Python tokenizers. "
+ "To use this feature, change your tokenizer to one deriving from "
+ "transformers.PreTrainedTokenizerFast."
+ )
+
+ batch_outputs = self._batch_prepare_for_model(
+ batch_text_or_text_pairs=batch_text_or_text_pairs,
+ is_pair=is_pair,
+ boxes=boxes,
+ word_labels=word_labels,
+ add_special_tokens=add_special_tokens,
+ padding_strategy=padding_strategy,
+ truncation_strategy=truncation_strategy,
+ max_length=max_length,
+ stride=stride,
+ pad_to_multiple_of=pad_to_multiple_of,
+ padding_side=padding_side,
+ return_attention_mask=return_attention_mask,
+ return_token_type_ids=return_token_type_ids,
+ return_overflowing_tokens=return_overflowing_tokens,
+ return_special_tokens_mask=return_special_tokens_mask,
+ return_length=return_length,
+ return_tensors=return_tensors,
+ verbose=verbose,
+ )
+
+ return BatchEncoding(batch_outputs)
+
+ @add_end_docstrings(LAYOUTXLM_ENCODE_KWARGS_DOCSTRING)
+ def _batch_prepare_for_model(
+ self,
+ batch_text_or_text_pairs,
+ is_pair: Optional[bool] = None,
+ boxes: Optional[list[list[int]]] = None,
+ word_labels: Optional[list[list[int]]] = None,
+ add_special_tokens: bool = True,
+ padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
+ truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
+ max_length: Optional[int] = None,
+ stride: int = 0,
+ pad_to_multiple_of: Optional[int] = None,
+ padding_side: Optional[str] = None,
+ return_tensors: Optional[str] = None,
+ return_token_type_ids: Optional[bool] = None,
+ return_attention_mask: Optional[bool] = None,
+ return_overflowing_tokens: bool = False,
+ return_special_tokens_mask: bool = False,
+ return_length: bool = False,
+ verbose: bool = True,
+ ) -> BatchEncoding:
+ """
+ Prepares a sequence of input ids, or a pair of sequences of input ids, so that it can be used by the model. It
+ adds special tokens, truncates sequences if overflowing while taking into account the special tokens and
+ manages a moving window (with user-defined stride) for overflowing tokens.
+
+ Args:
+ batch_ids_pairs: list of tokenized input ids or input ids pairs
+ """
+
+ batch_outputs = {}
+ for idx, example in enumerate(zip(batch_text_or_text_pairs, boxes)):
+ batch_text_or_text_pair, boxes_example = example
+ outputs = self.prepare_for_model(
+ batch_text_or_text_pair[0] if is_pair else batch_text_or_text_pair,
+ batch_text_or_text_pair[1] if is_pair else None,
+ boxes_example,
+ word_labels=word_labels[idx] if word_labels is not None else None,
+ add_special_tokens=add_special_tokens,
+ padding=PaddingStrategy.DO_NOT_PAD.value, # we pad in batch afterward
+ truncation=truncation_strategy.value,
+ max_length=max_length,
+ stride=stride,
+ pad_to_multiple_of=None, # we pad in batch afterward
+ padding_side=None, # we pad in batch afterward
+ return_attention_mask=False, # we pad in batch afterward
+ return_token_type_ids=return_token_type_ids,
+ return_overflowing_tokens=return_overflowing_tokens,
+ return_special_tokens_mask=return_special_tokens_mask,
+ return_length=return_length,
+ return_tensors=None, # We convert the whole batch to tensors at the end
+ prepend_batch_axis=False,
+ verbose=verbose,
+ )
+
+ for key, value in outputs.items():
+ if key not in batch_outputs:
+ batch_outputs[key] = []
+ batch_outputs[key].append(value)
+
+ batch_outputs = self.pad(
+ batch_outputs,
+ padding=padding_strategy.value,
+ max_length=max_length,
+ pad_to_multiple_of=pad_to_multiple_of,
+ padding_side=padding_side,
+ return_attention_mask=return_attention_mask,
+ )
+
+ batch_outputs = BatchEncoding(batch_outputs, tensor_type=return_tensors)
+
+ return batch_outputs
+
+ def _encode_plus(
+ self,
+ text: Union[TextInput, PreTokenizedInput],
+ text_pair: Optional[PreTokenizedInput] = None,
+ boxes: Optional[list[list[int]]] = None,
+ word_labels: Optional[list[int]] = None,
+ add_special_tokens: bool = True,
+ padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
+ truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
+ max_length: Optional[int] = None,
+ stride: int = 0,
+ pad_to_multiple_of: Optional[int] = None,
+ padding_side: Optional[str] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ return_token_type_ids: Optional[bool] = None,
+ return_attention_mask: Optional[bool] = None,
+ return_overflowing_tokens: bool = False,
+ return_special_tokens_mask: bool = False,
+ return_offsets_mapping: bool = False,
+ return_length: bool = False,
+ verbose: bool = True,
+ **kwargs,
+ ) -> BatchEncoding:
+ if return_offsets_mapping:
+ raise NotImplementedError(
+ "return_offset_mapping is not available when using Python tokenizers. "
+ "To use this feature, change your tokenizer to one deriving from "
+ "transformers.PreTrainedTokenizerFast. "
+ "More information on available tokenizers at "
+ "https://github.com/huggingface/transformers/pull/2674"
+ )
+
+ return self.prepare_for_model(
+ text=text,
+ text_pair=text_pair,
+ boxes=boxes,
+ word_labels=word_labels,
+ add_special_tokens=add_special_tokens,
+ padding=padding_strategy.value,
+ truncation=truncation_strategy.value,
+ max_length=max_length,
+ stride=stride,
+ pad_to_multiple_of=pad_to_multiple_of,
+ padding_side=padding_side,
+ return_tensors=return_tensors,
+ prepend_batch_axis=True,
+ return_attention_mask=return_attention_mask,
+ return_token_type_ids=return_token_type_ids,
+ return_overflowing_tokens=return_overflowing_tokens,
+ return_special_tokens_mask=return_special_tokens_mask,
+ return_length=return_length,
+ verbose=verbose,
+ )
+
+ @add_end_docstrings(LAYOUTXLM_ENCODE_KWARGS_DOCSTRING)
+ def prepare_for_model(
+ self,
+ text: Union[TextInput, PreTokenizedInput],
+ text_pair: Optional[PreTokenizedInput] = None,
+ boxes: Optional[list[list[int]]] = None,
+ word_labels: Optional[list[int]] = None,
+ add_special_tokens: bool = True,
+ padding: Union[bool, str, PaddingStrategy] = False,
+ truncation: Union[bool, str, TruncationStrategy] = None,
+ max_length: Optional[int] = None,
+ stride: int = 0,
+ pad_to_multiple_of: Optional[int] = None,
+ padding_side: Optional[str] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ return_token_type_ids: Optional[bool] = None,
+ return_attention_mask: Optional[bool] = None,
+ return_overflowing_tokens: bool = False,
+ return_special_tokens_mask: bool = False,
+ return_offsets_mapping: bool = False,
+ return_length: bool = False,
+ verbose: bool = True,
+ prepend_batch_axis: bool = False,
+ **kwargs,
+ ) -> BatchEncoding:
+ """
+ Prepares a sequence or a pair of sequences so that it can be used by the model. It adds special tokens,
+ truncates sequences if overflowing while taking into account the special tokens and manages a moving window
+ (with user defined stride) for overflowing tokens.
+
+ Word-level `boxes` are turned into token-level `bbox`. If provided, word-level `word_labels` are turned into
+ token-level `labels`. The word label is used for the first token of the word, while remaining tokens are
+ labeled with -100, such that they will be ignored by the loss function.
+
+ Args:
+ text (`str`, `list[str]`, `list[list[str]]`):
+ The first sequence to be encoded. This can be a string, a list of strings or a list of list of strings.
+ text_pair (`list[str]` or `list[int]`, *optional*):
+ Optional second sequence to be encoded. This can be a list of strings (words of a single example) or a
+ list of list of strings (words of a batch of examples).
+ """
+
+ # Backward compatibility for 'truncation_strategy', 'pad_to_max_length'
+ padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies(
+ padding=padding,
+ truncation=truncation,
+ max_length=max_length,
+ pad_to_multiple_of=pad_to_multiple_of,
+ verbose=verbose,
+ **kwargs,
+ )
+
+ tokens = []
+ pair_tokens = []
+ token_boxes = []
+ pair_token_boxes = []
+ labels = []
+
+ if text_pair is None:
+ if word_labels is None:
+ # CASE 1: document image classification (training + inference) + CASE 2: token classification (inference)
+ for word, box in zip(text, boxes):
+ if len(word) < 1: # skip empty words
+ continue
+ word_tokens = self.tokenize(word)
+ tokens.extend(word_tokens)
+ token_boxes.extend([box] * len(word_tokens))
+ else:
+ # CASE 2: token classification (training)
+ for word, box, label in zip(text, boxes, word_labels):
+ if len(word) < 1: # skip empty words
+ continue
+ word_tokens = self.tokenize(word)
+ tokens.extend(word_tokens)
+ token_boxes.extend([box] * len(word_tokens))
+ if self.only_label_first_subword:
+ # Use the real label id for the first token of the word, and padding ids for the remaining tokens
+ labels.extend([label] + [self.pad_token_label] * (len(word_tokens) - 1))
+ else:
+ labels.extend([label] * len(word_tokens))
+ else:
+ # CASE 3: document visual question answering (inference)
+ # text = question
+ # text_pair = words
+ tokens = self.tokenize(text)
+ token_boxes = [self.pad_token_box for _ in range(len(tokens))] + [self.sep_token_box]
+
+ for word, box in zip(text_pair, boxes):
+ if len(word) < 1: # skip empty words
+ continue
+ word_tokens = self.tokenize(word)
+ pair_tokens.extend(word_tokens)
+ pair_token_boxes.extend([box] * len(word_tokens))
+
+ # Create ids + pair_ids
+ ids = self.convert_tokens_to_ids(tokens)
+ pair_ids = self.convert_tokens_to_ids(pair_tokens) if pair_tokens else None
+
+ # Compute the total size of the returned encodings
+ pair = bool(pair_ids is not None)
+ len_ids = len(ids)
+ len_pair_ids = len(pair_ids) if pair else 0
+ total_len = len_ids + len_pair_ids + (self.num_special_tokens_to_add(pair=pair) if add_special_tokens else 0)
+
+ # Truncation: Handle max sequence length
+ overflowing_tokens = []
+ overflowing_token_boxes = []
+ overflowing_labels = []
+ if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE and max_length and total_len > max_length:
+ (
+ ids,
+ token_boxes,
+ pair_ids,
+ pair_token_boxes,
+ labels,
+ overflowing_tokens,
+ overflowing_token_boxes,
+ overflowing_labels,
+ ) = self.truncate_sequences(
+ ids,
+ token_boxes,
+ pair_ids=pair_ids,
+ pair_token_boxes=pair_token_boxes,
+ labels=labels,
+ num_tokens_to_remove=total_len - max_length,
+ truncation_strategy=truncation_strategy,
+ stride=stride,
+ )
+
+ if return_token_type_ids and not add_special_tokens:
+ raise ValueError(
+ "Asking to return token_type_ids while setting add_special_tokens to False "
+ "results in an undefined behavior. Please set add_special_tokens to True or "
+ "set return_token_type_ids to None."
+ )
+
+ # Load from model defaults
+ if return_token_type_ids is None:
+ return_token_type_ids = "token_type_ids" in self.model_input_names
+ if return_attention_mask is None:
+ return_attention_mask = "attention_mask" in self.model_input_names
+
+ encoded_inputs = {}
+
+ if return_overflowing_tokens:
+ encoded_inputs["overflowing_tokens"] = overflowing_tokens
+ encoded_inputs["overflowing_token_boxes"] = overflowing_token_boxes
+ encoded_inputs["overflowing_labels"] = overflowing_labels
+ encoded_inputs["num_truncated_tokens"] = total_len - max_length
+
+ # Add special tokens
+ if add_special_tokens:
+ sequence = self.build_inputs_with_special_tokens(ids, pair_ids)
+ token_type_ids = self.create_token_type_ids_from_sequences(ids, pair_ids)
+ token_boxes = [self.cls_token_box] + token_boxes + [self.sep_token_box]
+ if pair_token_boxes:
+ pair_token_boxes = pair_token_boxes + [self.sep_token_box]
+ if labels:
+ labels = [self.pad_token_label] + labels + [self.pad_token_label]
+ else:
+ sequence = ids + pair_ids if pair else ids
+ token_type_ids = [0] * len(ids) + ([0] * len(pair_ids) if pair else [])
+
+ # Build output dictionary
+ encoded_inputs["input_ids"] = sequence
+ encoded_inputs["bbox"] = token_boxes + pair_token_boxes
+ if return_token_type_ids:
+ encoded_inputs["token_type_ids"] = token_type_ids
+ if return_special_tokens_mask:
+ if add_special_tokens:
+ encoded_inputs["special_tokens_mask"] = self.get_special_tokens_mask(ids, pair_ids)
+ else:
+ encoded_inputs["special_tokens_mask"] = [0] * len(sequence)
+
+ if labels:
+ encoded_inputs["labels"] = labels
+
+ # Check lengths
+ self._eventual_warn_about_too_long_sequence(encoded_inputs["input_ids"], max_length, verbose)
+
+ # Padding
+ if padding_strategy != PaddingStrategy.DO_NOT_PAD or return_attention_mask:
+ encoded_inputs = self.pad(
+ encoded_inputs,
+ max_length=max_length,
+ padding=padding_strategy.value,
+ pad_to_multiple_of=pad_to_multiple_of,
+ padding_side=padding_side,
+ return_attention_mask=return_attention_mask,
+ )
+
+ if return_length:
+ encoded_inputs["length"] = len(encoded_inputs["input_ids"])
+
+ batch_outputs = BatchEncoding(
+ encoded_inputs, tensor_type=return_tensors, prepend_batch_axis=prepend_batch_axis
+ )
+
+ return batch_outputs
+
+ def truncate_sequences(
+ self,
+ ids: list[int],
+ token_boxes: list[list[int]],
+ pair_ids: Optional[list[int]] = None,
+ pair_token_boxes: Optional[list[list[int]]] = None,
+ labels: Optional[list[int]] = None,
+ num_tokens_to_remove: int = 0,
+ truncation_strategy: Union[str, TruncationStrategy] = "longest_first",
+ stride: int = 0,
+ ) -> tuple[list[int], list[int], list[int]]:
+ """
+ Truncates a sequence pair in-place following the strategy.
+
+ Args:
+ ids (`list[int]`):
+ Tokenized input ids of the first sequence. Can be obtained from a string by chaining the `tokenize` and
+ `convert_tokens_to_ids` methods.
+ token_boxes (`list[list[int]]`):
+ Bounding boxes of the first sequence.
+ pair_ids (`list[int]`, *optional*):
+ Tokenized input ids of the second sequence. Can be obtained from a string by chaining the `tokenize`
+ and `convert_tokens_to_ids` methods.
+ pair_token_boxes (`list[list[int]]`, *optional*):
+ Bounding boxes of the second sequence.
+ labels (`list[int]`, *optional*):
+ Labels of the first sequence (for token classification tasks).
+ num_tokens_to_remove (`int`, *optional*, defaults to 0):
+ Number of tokens to remove using the truncation strategy.
+ truncation_strategy (`str` or [`~tokenization_utils_base.TruncationStrategy`], *optional*, defaults to `'longest_first'`):
+ The strategy to follow for truncation. Can be:
+
+ - `'longest_first'`: Truncate to a maximum length specified with the argument `max_length` or to the
+ maximum acceptable input length for the model if that argument is not provided. This will truncate
+ token by token, removing a token from the longest sequence in the pair if a pair of sequences (or a
+ batch of pairs) is provided.
+ - `'only_first'`: Truncate to a maximum length specified with the argument `max_length` or to the
+ maximum acceptable input length for the model if that argument is not provided. This will only
+ truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
+ - `'only_second'`: Truncate to a maximum length specified with the argument `max_length` or to the
+ maximum acceptable input length for the model if that argument is not provided. This will only
+ truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
+ - `'do_not_truncate'`: No truncation (i.e., can output batch with sequence lengths greater
+ than the model maximum admissible input size).
+ stride (`int`, *optional*, defaults to 0):
+ If set to a positive number, the overflowing tokens returned will contain some tokens from the main
+ sequence returned. The value of this argument defines the number of additional tokens.
+
+ Returns:
+ `tuple`: The truncated `ids`, `token_boxes`, `pair_ids`, `pair_token_boxes` and `labels`, followed by the
+ lists of overflowing tokens, overflowing token boxes and overflowing labels.
+ """
+ if num_tokens_to_remove <= 0:
+ return ids, token_boxes, pair_ids, pair_token_boxes, labels, [], [], []
+
+ if not isinstance(truncation_strategy, TruncationStrategy):
+ truncation_strategy = TruncationStrategy(truncation_strategy)
+
+ overflowing_tokens = []
+ overflowing_token_boxes = []
+ overflowing_labels = []
+ if truncation_strategy == TruncationStrategy.LONGEST_FIRST:
+ for _ in range(num_tokens_to_remove):
+ if pair_ids is None or len(ids) > len(pair_ids):
+ if not overflowing_tokens:
+ window_len = min(len(ids), stride + 1)
+ else:
+ window_len = 1
+ overflowing_tokens.extend(ids[-window_len:])
+ overflowing_token_boxes.extend(token_boxes[-window_len:])
+ overflowing_labels.extend(labels[-window_len:])
+ ids = ids[:-1]
+ token_boxes = token_boxes[:-1]
+ labels = labels[:-1]
+ else:
+ if not overflowing_tokens:
+ window_len = min(len(pair_ids), stride + 1)
+ else:
+ window_len = 1
+ overflowing_tokens.extend(pair_ids[-window_len:])
+ overflowing_token_boxes.extend(pair_token_boxes[-window_len:])
+ pair_ids = pair_ids[:-1]
+ pair_token_boxes = pair_token_boxes[:-1]
+ elif truncation_strategy == TruncationStrategy.ONLY_FIRST:
+ if len(ids) > num_tokens_to_remove:
+ window_len = min(len(ids), stride + num_tokens_to_remove)
+ overflowing_tokens = ids[-window_len:]
+ overflowing_token_boxes = token_boxes[-window_len:]
+ overflowing_labels = labels[-window_len:]
+ ids = ids[:-num_tokens_to_remove]
+ token_boxes = token_boxes[:-num_tokens_to_remove]
+ labels = labels[:-num_tokens_to_remove]
+ else:
+ logger.error(
+ f"We need to remove {num_tokens_to_remove} to truncate the input "
+ f"but the first sequence has a length {len(ids)}. "
+ f"Please select another truncation strategy than {truncation_strategy}, "
+ "for instance 'longest_first' or 'only_second'."
+ )
+ elif truncation_strategy == TruncationStrategy.ONLY_SECOND and pair_ids is not None:
+ if len(pair_ids) > num_tokens_to_remove:
+ window_len = min(len(pair_ids), stride + num_tokens_to_remove)
+ overflowing_tokens = pair_ids[-window_len:]
+ overflowing_token_boxes = pair_token_boxes[-window_len:]
+ pair_ids = pair_ids[:-num_tokens_to_remove]
+ pair_token_boxes = pair_token_boxes[:-num_tokens_to_remove]
+ else:
+ logger.error(
+ f"We need to remove {num_tokens_to_remove} to truncate the input "
+ f"but the second sequence has a length {len(pair_ids)}. "
+ f"Please select another truncation strategy than {truncation_strategy}, "
+ "for instance 'longest_first' or 'only_first'."
+ )
+
+ return (
+ ids,
+ token_boxes,
+ pair_ids,
+ pair_token_boxes,
+ labels,
+ overflowing_tokens,
+ overflowing_token_boxes,
+ overflowing_labels,
+ )
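+ # For example, with `truncation_strategy="longest_first"` and `num_tokens_to_remove=2`, tokens (and their
+ # boxes/labels) are dropped one at a time from the end of whichever sequence is currently longer, and the
+ # removed pieces are collected in the overflowing lists.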
+
+ def _pad(
+ self,
+ encoded_inputs: Union[dict[str, EncodedInput], BatchEncoding],
+ max_length: Optional[int] = None,
+ padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
+ pad_to_multiple_of: Optional[int] = None,
+ padding_side: Optional[str] = None,
+ return_attention_mask: Optional[bool] = None,
+ ) -> dict:
+ """
+ Pad encoded inputs (on left/right and up to predefined length or max length in the batch)
+
+ Args:
+ encoded_inputs:
+ Dictionary of tokenized inputs (`list[int]`) or batch of tokenized inputs (`list[list[int]]`).
+ max_length: maximum length of the returned list and optionally padding length (see below).
+ Will truncate by taking into account the special tokens.
+ padding_strategy: PaddingStrategy to use for padding.
+
+ - PaddingStrategy.LONGEST Pad to the longest sequence in the batch
+ - PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
+ - PaddingStrategy.DO_NOT_PAD: Do not pad
+ The tokenizer padding sides are defined in self.padding_side:
+
+ - 'left': pads on the left of the sequences
+ - 'right': pads on the right of the sequences
+ pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
+ This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
+ `>= 7.5` (Volta).
+ padding_side (`str`, *optional*):
+ The side on which the model should have padding applied. Should be selected between ['right', 'left'].
+ Default value is picked from the class attribute of the same name.
+ return_attention_mask:
+ (optional) Set to False to avoid returning attention mask (default: set to model specifics)
+ """
+ # Load from model defaults
+ if return_attention_mask is None:
+ return_attention_mask = "attention_mask" in self.model_input_names
+
+ required_input = encoded_inputs[self.model_input_names[0]]
+
+ if padding_strategy == PaddingStrategy.LONGEST:
+ max_length = len(required_input)
+
+ if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
+ max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
+
+ needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length
+
+ # Initialize attention mask if not present.
+ if return_attention_mask and "attention_mask" not in encoded_inputs:
+ encoded_inputs["attention_mask"] = [1] * len(required_input)
+
+ if needs_to_be_padded:
+ difference = max_length - len(required_input)
+ padding_side = padding_side if padding_side is not None else self.padding_side
+ if padding_side == "right":
+ if return_attention_mask:
+ encoded_inputs["attention_mask"] = encoded_inputs["attention_mask"] + [0] * difference
+ if "token_type_ids" in encoded_inputs:
+ encoded_inputs["token_type_ids"] = (
+ encoded_inputs["token_type_ids"] + [self.pad_token_type_id] * difference
+ )
+ if "bbox" in encoded_inputs:
+ encoded_inputs["bbox"] = encoded_inputs["bbox"] + [self.pad_token_box] * difference
+ if "labels" in encoded_inputs:
+ encoded_inputs["labels"] = encoded_inputs["labels"] + [self.pad_token_label] * difference
+ if "special_tokens_mask" in encoded_inputs:
+ encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"] + [1] * difference
+ encoded_inputs[self.model_input_names[0]] = required_input + [self.pad_token_id] * difference
+ elif padding_side == "left":
+ if return_attention_mask:
+ encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"]
+ if "token_type_ids" in encoded_inputs:
+ encoded_inputs["token_type_ids"] = [self.pad_token_type_id] * difference + encoded_inputs[
+ "token_type_ids"
+ ]
+ if "bbox" in encoded_inputs:
+ encoded_inputs["bbox"] = [self.pad_token_box] * difference + encoded_inputs["bbox"]
+ if "labels" in encoded_inputs:
+ encoded_inputs["labels"] = [self.pad_token_label] * difference + encoded_inputs["labels"]
+ if "special_tokens_mask" in encoded_inputs:
+ encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"]
+ encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input
+ else:
+ raise ValueError("Invalid padding strategy:" + str(padding_side))
+
+ return encoded_inputs
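+ # For example, right-padding a length-3 example to `max_length=5` appends two `pad_token_id`s to
+ # `input_ids`, two 0s to `attention_mask`, two `pad_token_label`s (-100) to `labels`, and two copies of
+ # `pad_token_box` ([0, 0, 0, 0] by default) to `bbox`.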
+
+
+__all__ = ["LayoutXLMTokenizer"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/layoutxlm/tokenization_layoutxlm_fast.py b/venv/lib/python3.13/site-packages/transformers/models/layoutxlm/tokenization_layoutxlm_fast.py
new file mode 100644
index 0000000000000000000000000000000000000000..6710c6c8cb66ed08da2df391c28ad1db2e6cf81d
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/layoutxlm/tokenization_layoutxlm_fast.py
@@ -0,0 +1,815 @@
+# coding=utf-8
+# Copyright 2021 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License
+"""Tokenization classes for LayoutXLM model."""
+
+import os
+from shutil import copyfile
+from typing import Optional, Union
+
+from ...tokenization_utils import AddedToken
+from ...tokenization_utils_base import (
+ BatchEncoding,
+ EncodedInput,
+ PreTokenizedInput,
+ TextInput,
+ TextInputPair,
+ TruncationStrategy,
+)
+from ...tokenization_utils_fast import PreTrainedTokenizerFast
+from ...utils import PaddingStrategy, TensorType, add_end_docstrings, is_sentencepiece_available, logging
+from ..xlm_roberta.tokenization_xlm_roberta_fast import (
+ VOCAB_FILES_NAMES,
+)
+
+
+if is_sentencepiece_available():
+ from .tokenization_layoutxlm import LayoutXLMTokenizer
+else:
+ LayoutXLMTokenizer = None
+
+
+logger = logging.get_logger(__name__)
+
+LAYOUTXLM_ENCODE_KWARGS_DOCSTRING = r"""
+ add_special_tokens (`bool`, *optional*, defaults to `True`):
+ Whether or not to encode the sequences with the special tokens relative to their model.
+ padding (`bool`, `str` or [`~file_utils.PaddingStrategy`], *optional*, defaults to `False`):
+ Activates and controls padding. Accepts the following values:
+
+ - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
+ sequence if provided).
+ - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
+ acceptable input length for the model if that argument is not provided.
+ - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
+ lengths).
+ truncation (`bool`, `str` or [`~tokenization_utils_base.TruncationStrategy`], *optional*, defaults to `False`):
+ Activates and controls truncation. Accepts the following values:
+
+ - `True` or `'longest_first'`: Truncate to a maximum length specified with the argument `max_length` or
+ to the maximum acceptable input length for the model if that argument is not provided. This will
+ truncate token by token, removing a token from the longest sequence in the pair if a pair of
+ sequences (or a batch of pairs) is provided.
+ - `'only_first'`: Truncate to a maximum length specified with the argument `max_length` or to the
+ maximum acceptable input length for the model if that argument is not provided. This will only
+ truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
+ - `'only_second'`: Truncate to a maximum length specified with the argument `max_length` or to the
+ maximum acceptable input length for the model if that argument is not provided. This will only
+ truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
+ - `False` or `'do_not_truncate'` (default): No truncation (i.e., can output batch with sequence lengths
+ greater than the model maximum admissible input size).
+ max_length (`int`, *optional*):
+ Controls the maximum length to use by one of the truncation/padding parameters.
+
+ If left unset or set to `None`, this will use the predefined model maximum length if a maximum length
+ is required by one of the truncation/padding parameters. If the model has no specific maximum input
+ length (like XLNet) truncation/padding to a maximum length will be deactivated.
+ stride (`int`, *optional*, defaults to 0):
+ If set to a number along with `max_length`, the overflowing tokens returned when
+ `return_overflowing_tokens=True` will contain some tokens from the end of the truncated sequence
+ returned to provide some overlap between truncated and overflowing sequences. The value of this
+ argument defines the number of overlapping tokens.
+ pad_to_multiple_of (`int`, *optional*):
+ If set will pad the sequence to a multiple of the provided value. This is especially useful to enable
+ the use of Tensor Cores on NVIDIA hardware with compute capability `>= 7.5` (Volta).
+ return_tensors (`str` or [`~file_utils.TensorType`], *optional*):
+ If set, will return tensors instead of list of python integers. Acceptable values are:
+
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
+ - `'np'`: Return Numpy `np.ndarray` objects.
+ return_token_type_ids (`bool`, *optional*):
+ Whether to return token type IDs. If left to the default, will return the token type IDs according to
+ the specific tokenizer's default, defined by the `return_outputs` attribute.
+
+ [What are token type IDs?](../glossary#token-type-ids)
+ return_attention_mask (`bool`, *optional*):
+ Whether to return the attention mask. If left to the default, will return the attention mask according
+ to the specific tokenizer's default, defined by the `return_outputs` attribute.
+
+ [What are attention masks?](../glossary#attention-mask)
+ return_overflowing_tokens (`bool`, *optional*, defaults to `False`):
+ Whether or not to return overflowing token sequences. If a pair of sequences of input ids (or a batch
+ of pairs) is provided with `truncation_strategy = longest_first` or `True`, an error is raised instead
+ of returning overflowing tokens.
+ return_special_tokens_mask (`bool`, *optional*, defaults to `False`):
+ Whether or not to return special tokens mask information.
+ return_offsets_mapping (`bool`, *optional*, defaults to `False`):
+ Whether or not to return `(char_start, char_end)` for each token.
+
+ This is only available on fast tokenizers inheriting from [`PreTrainedTokenizerFast`], if using
+ Python's tokenizer, this method will raise `NotImplementedError`.
+ return_length (`bool`, *optional*, defaults to `False`):
+ Whether or not to return the lengths of the encoded inputs.
+ verbose (`bool`, *optional*, defaults to `True`):
+ Whether or not to print more information and warnings.
+ **kwargs: passed to the `self.tokenize()` method
+
+ Return:
+ [`BatchEncoding`]: A [`BatchEncoding`] with the following fields:
+
+ - **input_ids** -- List of token ids to be fed to a model.
+
+ [What are input IDs?](../glossary#input-ids)
+
+ - **bbox** -- List of bounding boxes to be fed to a model.
+
+ - **token_type_ids** -- List of token type ids to be fed to a model (when `return_token_type_ids=True` or
+ if *"token_type_ids"* is in `self.model_input_names`).
+
+ [What are token type IDs?](../glossary#token-type-ids)
+
+ - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
+ `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names`).
+
+ [What are attention masks?](../glossary#attention-mask)
+
+ - **labels** -- List of labels to be fed to a model. (when `word_labels` is specified).
+ - **overflowing_tokens** -- List of overflowing tokens sequences (when a `max_length` is specified and
+ `return_overflowing_tokens=True`).
+ - **num_truncated_tokens** -- Number of tokens truncated (when a `max_length` is specified and
+ `return_overflowing_tokens=True`).
+ - **special_tokens_mask** -- List of 0s and 1s, with 1 specifying added special tokens and 0 specifying
+ regular sequence tokens (when `add_special_tokens=True` and `return_special_tokens_mask=True`).
+ - **length** -- The length of the inputs (when `return_length=True`).
+"""
+
+
+class LayoutXLMTokenizerFast(PreTrainedTokenizerFast):
+ """
+ Construct a "fast" LayoutXLM tokenizer (backed by HuggingFace's *tokenizers* library). Adapted from
+ [`RobertaTokenizer`] and [`XLNetTokenizer`]. Based on
+ [BPE](https://huggingface.co/docs/tokenizers/python/latest/components.html?highlight=BPE#models).
+
+ This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
+ refer to this superclass for more information regarding those methods.
+
+ Args:
+ vocab_file (`str`):
+ Path to the vocabulary file.
+ bos_token (`str`, *optional*, defaults to `"<s>"`):
+ The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier token.
+
+ <Tip>
+
+ When building a sequence using special tokens, this is not the token that is used for the beginning of
+ sequence. The token used is the `cls_token`.
+
+ </Tip>
+
+ eos_token (`str`, *optional*, defaults to `"</s>"`):
+ The end of sequence token.
+
+ <Tip>
+
+ When building a sequence using special tokens, this is not the token that is used for the end of sequence.
+ The token used is the `sep_token`.
+
+ </Tip>
+
+ sep_token (`str`, *optional*, defaults to `"</s>"`):
+ The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
+ sequence classification or for a text and a question for question answering. It is also used as the last
+ token of a sequence built with special tokens.
+ cls_token (`str`, *optional*, defaults to `"<s>"`):
+ The classifier token which is used when doing sequence classification (classification of the whole sequence
+ instead of per-token classification). It is the first token of the sequence when built with special tokens.
+ unk_token (`str`, *optional*, defaults to `"<unk>"`):
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+ token instead.
+ pad_token (`str`, *optional*, defaults to `"<pad>"`):
+ The token used for padding, for example when batching sequences of different lengths.
+ mask_token (`str`, *optional*, defaults to `"<mask>"`):
+ The token used for masking values. This is the token used when training this model with masked language
+ modeling. This is the token which the model will try to predict.
+ cls_token_box (`list[int]`, *optional*, defaults to `[0, 0, 0, 0]`):
+ The bounding box to use for the special [CLS] token.
+ sep_token_box (`list[int]`, *optional*, defaults to `[1000, 1000, 1000, 1000]`):
+ The bounding box to use for the special [SEP] token.
+ pad_token_box (`list[int]`, *optional*, defaults to `[0, 0, 0, 0]`):
+ The bounding box to use for the special [PAD] token.
+ pad_token_label (`int`, *optional*, defaults to -100):
+ The label to use for padding tokens. Defaults to -100, which is the `ignore_index` of PyTorch's
+ CrossEntropyLoss.
+ only_label_first_subword (`bool`, *optional*, defaults to `True`):
+ Whether or not to only label the first subword, in case word labels are provided.
+ additional_special_tokens (`list[str]`, *optional*, defaults to `["<s>NOTUSED", "</s>NOTUSED"]`):
+ Additional special tokens used by the tokenizer.
+ """
+
+ vocab_files_names = VOCAB_FILES_NAMES
+ model_input_names = ["input_ids", "attention_mask"]
+ slow_tokenizer_class = LayoutXLMTokenizer
+
+ def __init__(
+ self,
+ vocab_file=None,
+ tokenizer_file=None,
+ bos_token="<s>",
+ eos_token="</s>",
+ sep_token="</s>",
+ cls_token="<s>",
+ unk_token="<unk>",
+ pad_token="<pad>",
+ mask_token="<mask>",
+ cls_token_box=[0, 0, 0, 0],
+ sep_token_box=[1000, 1000, 1000, 1000],
+ pad_token_box=[0, 0, 0, 0],
+ pad_token_label=-100,
+ only_label_first_subword=True,
+ **kwargs,
+ ):
+ # Mask token behave like a normal word, i.e. include the space before it
+ mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token
+
+ super().__init__(
+ vocab_file,
+ tokenizer_file=tokenizer_file,
+ bos_token=bos_token,
+ eos_token=eos_token,
+ sep_token=sep_token,
+ cls_token=cls_token,
+ unk_token=unk_token,
+ pad_token=pad_token,
+ mask_token=mask_token,
+ cls_token_box=cls_token_box,
+ sep_token_box=sep_token_box,
+ pad_token_box=pad_token_box,
+ pad_token_label=pad_token_label,
+ only_label_first_subword=only_label_first_subword,
+ **kwargs,
+ )
+
+ self.vocab_file = vocab_file
+
+ # additional properties
+ self.cls_token_box = cls_token_box
+ self.sep_token_box = sep_token_box
+ self.pad_token_box = pad_token_box
+ self.pad_token_label = pad_token_label
+ self.only_label_first_subword = only_label_first_subword
+
+ @add_end_docstrings(LAYOUTXLM_ENCODE_KWARGS_DOCSTRING)
+ def __call__(
+ self,
+ text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]],
+ text_pair: Optional[Union[PreTokenizedInput, list[PreTokenizedInput]]] = None,
+ boxes: Optional[Union[list[list[int]], list[list[list[int]]]]] = None,
+ word_labels: Optional[Union[list[int], list[list[int]]]] = None,
+ add_special_tokens: bool = True,
+ padding: Union[bool, str, PaddingStrategy] = False,
+ truncation: Union[bool, str, TruncationStrategy] = None,
+ max_length: Optional[int] = None,
+ stride: int = 0,
+ pad_to_multiple_of: Optional[int] = None,
+ padding_side: Optional[str] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ return_token_type_ids: Optional[bool] = None,
+ return_attention_mask: Optional[bool] = None,
+ return_overflowing_tokens: bool = False,
+ return_special_tokens_mask: bool = False,
+ return_offsets_mapping: bool = False,
+ return_length: bool = False,
+ verbose: bool = True,
+ **kwargs,
+ ) -> BatchEncoding:
+ """
+ Main method to tokenize and prepare for the model one or several sequence(s) or one or several pair(s) of
+ sequences with word-level normalized bounding boxes and optional labels.
+
+ Args:
+ text (`str`, `list[str]`, `list[list[str]]`):
+ The sequence or batch of sequences to be encoded. Each sequence can be a string, a list of strings
+ (words of a single example or questions of a batch of examples) or a list of list of strings (batch of
+ words).
+ text_pair (`list[str]`, `list[list[str]]`):
+ The sequence or batch of sequences to be encoded. Each sequence should be a list of strings
+ (pretokenized string).
+ boxes (`list[list[int]]`, `list[list[list[int]]]`):
+ Word-level bounding boxes. Each bounding box should be normalized to be on a 0-1000 scale.
+ word_labels (`list[int]`, `list[list[int]]`, *optional*):
+ Word-level integer labels (for token classification tasks such as FUNSD, CORD).
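+
+ Example (a minimal sketch; the checkpoint name, words and boxes below are illustrative):
+
+ ```python
+ >>> from transformers import LayoutXLMTokenizerFast
+
+ >>> tokenizer = LayoutXLMTokenizerFast.from_pretrained("microsoft/layoutxlm-base")
+ >>> words = ["hello", "world"]
+ >>> boxes = [[637, 773, 693, 782], [698, 773, 733, 782]]  # one 0-1000 normalized box per word
+ >>> encoding = tokenizer(words, boxes=boxes)
+ >>> sorted(encoding.keys())
+ ['attention_mask', 'bbox', 'input_ids']
+ ```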
+ """
+
+ # Input type checking for clearer error
+ def _is_valid_text_input(t):
+ if isinstance(t, str):
+ # Strings are fine
+ return True
+ elif isinstance(t, (list, tuple)):
+ # Lists are fine as long as they are...
+ if len(t) == 0:
+ # ... empty
+ return True
+ elif isinstance(t[0], str):
+ # ... list of strings
+ return True
+ elif isinstance(t[0], (list, tuple)):
+ # ... list with an empty list or with a list of strings
+ return len(t[0]) == 0 or isinstance(t[0][0], str)
+ else:
+ return False
+ else:
+ return False
+
+ if text_pair is not None:
+ # in case text + text_pair are provided, text = questions, text_pair = words
+ if not _is_valid_text_input(text):
+ raise ValueError("text input must of type `str` (single example) or `list[str]` (batch of examples). ")
+ if not isinstance(text_pair, (list, tuple)):
+ raise ValueError(
+ "words must of type `list[str]` (single pretokenized example), "
+ "or `list[list[str]]` (batch of pretokenized examples)."
+ )
+ else:
+ # in case only text is provided => must be words
+ if not isinstance(text, (list, tuple)):
+ raise ValueError(
+ "Words must of type `list[str]` (single pretokenized example), "
+ "or `list[list[str]]` (batch of pretokenized examples)."
+ )
+
+ if text_pair is not None:
+ is_batched = isinstance(text, (list, tuple))
+ else:
+ is_batched = isinstance(text, (list, tuple)) and text and isinstance(text[0], (list, tuple))
+
+ words = text if text_pair is None else text_pair
+ if boxes is None:
+ raise ValueError("You must provide corresponding bounding boxes")
+ if is_batched:
+ if len(words) != len(boxes):
+ raise ValueError("You must provide words and boxes for an equal amount of examples")
+ for words_example, boxes_example in zip(words, boxes):
+ if len(words_example) != len(boxes_example):
+ raise ValueError("You must provide as many words as there are bounding boxes")
+ else:
+ if len(words) != len(boxes):
+ raise ValueError("You must provide as many words as there are bounding boxes")
+
+ if is_batched:
+ if text_pair is not None and len(text) != len(text_pair):
+ raise ValueError(
+ f"batch length of `text`: {len(text)} does not match batch length of `text_pair`:"
+ f" {len(text_pair)}."
+ )
+ batch_text_or_text_pairs = list(zip(text, text_pair)) if text_pair is not None else text
+ is_pair = bool(text_pair is not None)
+ return self.batch_encode_plus(
+ batch_text_or_text_pairs=batch_text_or_text_pairs,
+ is_pair=is_pair,
+ boxes=boxes,
+ word_labels=word_labels,
+ add_special_tokens=add_special_tokens,
+ padding=padding,
+ truncation=truncation,
+ max_length=max_length,
+ stride=stride,
+ pad_to_multiple_of=pad_to_multiple_of,
+ padding_side=padding_side,
+ return_tensors=return_tensors,
+ return_token_type_ids=return_token_type_ids,
+ return_attention_mask=return_attention_mask,
+ return_overflowing_tokens=return_overflowing_tokens,
+ return_special_tokens_mask=return_special_tokens_mask,
+ return_offsets_mapping=return_offsets_mapping,
+ return_length=return_length,
+ verbose=verbose,
+ **kwargs,
+ )
+ else:
+ return self.encode_plus(
+ text=text,
+ text_pair=text_pair,
+ boxes=boxes,
+ word_labels=word_labels,
+ add_special_tokens=add_special_tokens,
+ padding=padding,
+ truncation=truncation,
+ max_length=max_length,
+ stride=stride,
+ pad_to_multiple_of=pad_to_multiple_of,
+ padding_side=padding_side,
+ return_tensors=return_tensors,
+ return_token_type_ids=return_token_type_ids,
+ return_attention_mask=return_attention_mask,
+ return_overflowing_tokens=return_overflowing_tokens,
+ return_special_tokens_mask=return_special_tokens_mask,
+ return_offsets_mapping=return_offsets_mapping,
+ return_length=return_length,
+ verbose=verbose,
+ **kwargs,
+ )
+
+ def tokenize(self, text: str, pair: Optional[str] = None, add_special_tokens: bool = False, **kwargs) -> list[str]:
+ batched_input = [(text, pair)] if pair else [text]
+
+ self._tokenizer.encode_special_tokens = kwargs.pop(
+ "split_special_tokens", self._tokenizer.encode_special_tokens
+ )
+
+ encodings = self._tokenizer.encode_batch(
+ batched_input, add_special_tokens=add_special_tokens, is_pretokenized=False, **kwargs
+ )
+
+ return encodings[0].tokens
+
+ def _batch_encode_plus(
+ self,
+ batch_text_or_text_pairs: Union[
+ list[TextInput],
+ list[TextInputPair],
+ list[PreTokenizedInput],
+ ],
+ is_pair: Optional[bool] = None,
+ boxes: Optional[list[list[list[int]]]] = None,
+ word_labels: Optional[list[list[int]]] = None,
+ add_special_tokens: bool = True,
+ padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
+ truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
+ max_length: Optional[int] = None,
+ stride: int = 0,
+ pad_to_multiple_of: Optional[int] = None,
+ padding_side: Optional[str] = None,
+ return_tensors: Optional[str] = None,
+ return_token_type_ids: Optional[bool] = None,
+ return_attention_mask: Optional[bool] = None,
+ return_overflowing_tokens: bool = False,
+ return_special_tokens_mask: bool = False,
+ return_offsets_mapping: bool = False,
+ return_length: bool = False,
+ verbose: bool = True,
+ **kwargs,
+ ) -> BatchEncoding:
+ if not isinstance(batch_text_or_text_pairs, list):
+ raise TypeError(f"batch_text_or_text_pairs has to be a list (got {type(batch_text_or_text_pairs)})")
+
+ # Set the truncation and padding strategy and restore the initial configuration
+ self.set_truncation_and_padding(
+ padding_strategy=padding_strategy,
+ truncation_strategy=truncation_strategy,
+ max_length=max_length,
+ stride=stride,
+ pad_to_multiple_of=pad_to_multiple_of,
+ padding_side=padding_side,
+ )
+
+ if is_pair:
+ batch_text_or_text_pairs = [(text.split(), text_pair) for text, text_pair in batch_text_or_text_pairs]
+
+ encodings = self._tokenizer.encode_batch(
+ batch_text_or_text_pairs,
+ add_special_tokens=add_special_tokens,
+ is_pretokenized=True, # we set this to True as LayoutXLM always expects pretokenized inputs
+ )
+
+ # Convert encoding to dict
+ # `Tokens` has type: tuple[
+ # list[dict[str, list[list[int]]]] or list[dict[str, 2D-Tensor]],
+ # list[EncodingFast]
+ # ]
+ # with nested dimensions corresponding to batch, overflows, sequence length
+ tokens_and_encodings = [
+ self._convert_encoding(
+ encoding=encoding,
+ return_token_type_ids=return_token_type_ids,
+ return_attention_mask=return_attention_mask,
+ return_overflowing_tokens=return_overflowing_tokens,
+ return_special_tokens_mask=return_special_tokens_mask,
+ return_offsets_mapping=True
+ if word_labels is not None
+ else return_offsets_mapping, # we use offsets to create the labels
+ return_length=return_length,
+ verbose=verbose,
+ )
+ for encoding in encodings
+ ]
+
+ # Convert the output to have dict[list] from list[dict] and remove the additional overflows dimension
+ # From (variable) shape (batch, overflows, sequence length) to ~ (batch * overflows, sequence length)
+ # (we say ~ because the number of overflows varies with the example in the batch)
+ #
+ # To match each overflowing sample with the original sample in the batch
+ # we add an overflow_to_sample_mapping array (see below)
+ sanitized_tokens = {}
+ for key in tokens_and_encodings[0][0]:
+ stack = [e for item, _ in tokens_and_encodings for e in item[key]]
+ sanitized_tokens[key] = stack
+ sanitized_encodings = [e for _, item in tokens_and_encodings for e in item]
+
+ # If returning overflowing tokens, we need to return a mapping
+ # from the batch idx to the original sample
+ if return_overflowing_tokens:
+ overflow_to_sample_mapping = []
+ for i, (toks, _) in enumerate(tokens_and_encodings):
+ overflow_to_sample_mapping += [i] * len(toks["input_ids"])
+ sanitized_tokens["overflow_to_sample_mapping"] = overflow_to_sample_mapping
+
+ for input_ids in sanitized_tokens["input_ids"]:
+ self._eventual_warn_about_too_long_sequence(input_ids, max_length, verbose)
+
+ # create the token boxes
+ token_boxes = []
+ for batch_index in range(len(sanitized_tokens["input_ids"])):
+ if return_overflowing_tokens:
+ original_index = sanitized_tokens["overflow_to_sample_mapping"][batch_index]
+ else:
+ original_index = batch_index
+ token_boxes_example = []
+ for id, sequence_id, word_id in zip(
+ sanitized_tokens["input_ids"][batch_index],
+ sanitized_encodings[batch_index].sequence_ids,
+ sanitized_encodings[batch_index].word_ids,
+ ):
+ if word_id is not None:
+ if is_pair and sequence_id == 0:
+ token_boxes_example.append(self.pad_token_box)
+ else:
+ token_boxes_example.append(boxes[original_index][word_id])
+ else:
+ if id == self.cls_token_id:
+ token_boxes_example.append(self.cls_token_box)
+ elif id == self.sep_token_id:
+ token_boxes_example.append(self.sep_token_box)
+ elif id == self.pad_token_id:
+ token_boxes_example.append(self.pad_token_box)
+ else:
+ raise ValueError("Id not recognized")
+ token_boxes.append(token_boxes_example)
+
+ sanitized_tokens["bbox"] = token_boxes
+
+ # optionally, create the labels
+ if word_labels is not None:
+ labels = []
+ for batch_index in range(len(sanitized_tokens["input_ids"])):
+ if return_overflowing_tokens:
+ original_index = sanitized_tokens["overflow_to_sample_mapping"][batch_index]
+ else:
+ original_index = batch_index
+ labels_example = []
+ for id, offset, word_id in zip(
+ sanitized_tokens["input_ids"][batch_index],
+ sanitized_tokens["offset_mapping"][batch_index],
+ sanitized_encodings[batch_index].word_ids,
+ ):
+ if word_id is not None:
+ if self.only_label_first_subword:
+ if offset[0] == 0:
+ # Use the real label id for the first token of the word, and padding ids for the remaining tokens
+ labels_example.append(word_labels[original_index][word_id])
+ else:
+ labels_example.append(self.pad_token_label)
+ else:
+ labels_example.append(word_labels[original_index][word_id])
+ else:
+ labels_example.append(self.pad_token_label)
+ labels.append(labels_example)
+
+ sanitized_tokens["labels"] = labels
+ # finally, remove offsets if the user didn't want them
+ if not return_offsets_mapping:
+ del sanitized_tokens["offset_mapping"]
+
+ return BatchEncoding(sanitized_tokens, sanitized_encodings, tensor_type=return_tensors)
+
+ def _encode_plus(
+ self,
+ text: Union[TextInput, PreTokenizedInput],
+ text_pair: Optional[PreTokenizedInput] = None,
+ boxes: Optional[list[list[int]]] = None,
+ word_labels: Optional[list[int]] = None,
+ add_special_tokens: bool = True,
+ padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
+ truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
+ max_length: Optional[int] = None,
+ stride: int = 0,
+ pad_to_multiple_of: Optional[int] = None,
+ padding_side: Optional[str] = None,
+ return_tensors: Optional[bool] = None,
+ return_token_type_ids: Optional[bool] = None,
+ return_attention_mask: Optional[bool] = None,
+ return_overflowing_tokens: bool = False,
+ return_special_tokens_mask: bool = False,
+ return_offsets_mapping: bool = False,
+ return_length: bool = False,
+ verbose: bool = True,
+ **kwargs,
+ ) -> BatchEncoding:
+ # make it a batched input
+ # 2 options:
+ # 1) only text, in which case text must be a list of str
+ # 2) text + text_pair, in which case text = str and text_pair a list of str
+ batched_input = [(text, text_pair)] if text_pair else [text]
+ batched_boxes = [boxes]
+ batched_word_labels = [word_labels] if word_labels is not None else None
+ batched_output = self._batch_encode_plus(
+ batched_input,
+ is_pair=bool(text_pair is not None),
+ boxes=batched_boxes,
+ word_labels=batched_word_labels,
+ add_special_tokens=add_special_tokens,
+ padding_strategy=padding_strategy,
+ truncation_strategy=truncation_strategy,
+ max_length=max_length,
+ stride=stride,
+ pad_to_multiple_of=pad_to_multiple_of,
+ padding_side=padding_side,
+ return_tensors=return_tensors,
+ return_token_type_ids=return_token_type_ids,
+ return_attention_mask=return_attention_mask,
+ return_overflowing_tokens=return_overflowing_tokens,
+ return_special_tokens_mask=return_special_tokens_mask,
+ return_offsets_mapping=return_offsets_mapping,
+ return_length=return_length,
+ verbose=verbose,
+ **kwargs,
+ )
+
+ # If return_tensors is None, we can remove the leading batch axis
+ # Overflowing tokens are returned as a batch of output so we keep them in this case
+ if return_tensors is None and not return_overflowing_tokens:
+ batched_output = BatchEncoding(
+ {
+ key: value[0] if len(value) > 0 and isinstance(value[0], list) else value
+ for key, value in batched_output.items()
+ },
+ batched_output.encodings,
+ )
+
+ self._eventual_warn_about_too_long_sequence(batched_output["input_ids"], max_length, verbose)
+
+ return batched_output
+
+ def _pad(
+ self,
+ encoded_inputs: Union[dict[str, EncodedInput], BatchEncoding],
+ max_length: Optional[int] = None,
+ padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
+ pad_to_multiple_of: Optional[int] = None,
+ padding_side: Optional[str] = None,
+ return_attention_mask: Optional[bool] = None,
+ ) -> dict:
+ """
+ Pad encoded inputs (on left/right and up to predefined length or max length in the batch)
+
+ Args:
+ encoded_inputs:
+ Dictionary of tokenized inputs (`list[int]`) or batch of tokenized inputs (`list[list[int]]`).
+ max_length: maximum length of the returned list and optionally padding length (see below).
+ Will truncate by taking into account the special tokens.
+ padding_strategy: PaddingStrategy to use for padding.
+
+ - PaddingStrategy.LONGEST: Pad to the longest sequence in the batch
+ - PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
+ - PaddingStrategy.DO_NOT_PAD: Do not pad
+ The tokenizer padding sides are defined in self.padding_side:
+
+ - 'left': pads on the left of the sequences
+ - 'right': pads on the right of the sequences
+ pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
+ This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
+ `>= 7.5` (Volta).
+ padding_side (`str`, *optional*):
+ The side on which the model should have padding applied. Should be selected between ['right', 'left'].
+ Default value is picked from the class attribute of the same name.
+ return_attention_mask:
+ (optional) Set to False to avoid returning attention mask (default: set to model specifics)
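+
+ Example (a minimal sketch; the token ids and boxes below are illustrative and assume a tokenizer loaded
+ with `from_pretrained`):
+
+ ```python
+ >>> from transformers import LayoutXLMTokenizerFast
+ >>> from transformers.utils import PaddingStrategy
+
+ >>> tokenizer = LayoutXLMTokenizerFast.from_pretrained("microsoft/layoutxlm-base")
+ >>> enc = {"input_ids": [0, 9, 2], "bbox": [[0, 0, 0, 0], [1, 2, 3, 4], [1000, 1000, 1000, 1000]]}
+ >>> padded = tokenizer._pad(enc, max_length=5, padding_strategy=PaddingStrategy.MAX_LENGTH)
+ >>> len(padded["input_ids"]), len(padded["bbox"]), len(padded["attention_mask"])
+ (5, 5, 5)
+ ```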
+ """
+ # Load from model defaults
+ if return_attention_mask is None:
+ return_attention_mask = "attention_mask" in self.model_input_names
+
+ required_input = encoded_inputs[self.model_input_names[0]]
+
+ if padding_strategy == PaddingStrategy.LONGEST:
+ max_length = len(required_input)
+
+ if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
+ max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
+
+ needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length
+
+ # Initialize attention mask if not present.
+ if return_attention_mask and "attention_mask" not in encoded_inputs:
+ encoded_inputs["attention_mask"] = [1] * len(required_input)
+
+ if needs_to_be_padded:
+ difference = max_length - len(required_input)
+ padding_side = padding_side if padding_side is not None else self.padding_side
+ if padding_side == "right":
+ if return_attention_mask:
+ encoded_inputs["attention_mask"] = encoded_inputs["attention_mask"] + [0] * difference
+ if "token_type_ids" in encoded_inputs:
+ encoded_inputs["token_type_ids"] = (
+ encoded_inputs["token_type_ids"] + [self.pad_token_type_id] * difference
+ )
+ if "bbox" in encoded_inputs:
+ encoded_inputs["bbox"] = encoded_inputs["bbox"] + [self.pad_token_box] * difference
+ if "labels" in encoded_inputs:
+ encoded_inputs["labels"] = encoded_inputs["labels"] + [self.pad_token_label] * difference
+ if "special_tokens_mask" in encoded_inputs:
+ encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"] + [1] * difference
+ encoded_inputs[self.model_input_names[0]] = required_input + [self.pad_token_id] * difference
+ elif padding_side == "left":
+ if return_attention_mask:
+ encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"]
+ if "token_type_ids" in encoded_inputs:
+ encoded_inputs["token_type_ids"] = [self.pad_token_type_id] * difference + encoded_inputs[
+ "token_type_ids"
+ ]
+ if "bbox" in encoded_inputs:
+ encoded_inputs["bbox"] = [self.pad_token_box] * difference + encoded_inputs["bbox"]
+ if "labels" in encoded_inputs:
+ encoded_inputs["labels"] = [self.pad_token_label] * difference + encoded_inputs["labels"]
+ if "special_tokens_mask" in encoded_inputs:
+ encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"]
+ encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input
+ else:
+ raise ValueError("Invalid padding strategy:" + str(padding_side))
+
+ return encoded_inputs
+
+ def build_inputs_with_special_tokens(
+ self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
+ ) -> list[int]:
+ """
+ Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
+ adding special tokens. An XLM-RoBERTa sequence has the following format:
+
+ - single sequence: `<s> X </s>`
+ - pair of sequences: `<s> A </s></s> B </s>`
+
+ Args:
+ token_ids_0 (`list[int]`):
+ List of IDs to which the special tokens will be added.
+ token_ids_1 (`list[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+
+ Returns:
+ `list[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
+ """
+
+ if token_ids_1 is None:
+ return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
+ cls = [self.cls_token_id]
+ sep = [self.sep_token_id]
+ return cls + token_ids_0 + sep + sep + token_ids_1 + sep
+
+ def create_token_type_ids_from_sequences(
+ self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
+ ) -> list[int]:
+ """
+ Create a mask from the two sequences passed to be used in a sequence-pair classification task. XLM-RoBERTa does
+ not make use of token type ids, therefore a list of zeros is returned.
+
+ Args:
+ token_ids_0 (`list[int]`):
+ List of IDs.
+ token_ids_1 (`list[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+
+ Returns:
+ `list[int]`: List of zeros.
+
+ """
+
+ sep = [self.sep_token_id]
+ cls = [self.cls_token_id]
+
+ if token_ids_1 is None:
+ return len(cls + token_ids_0 + sep) * [0]
+ return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]
+
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
+ if not self.can_save_slow_tokenizer:
+ raise ValueError(
+ "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
+ "tokenizer."
+ )
+
+ if not os.path.isdir(save_directory):
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory.")
+ return
+ out_vocab_file = os.path.join(
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+ )
+
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+ copyfile(self.vocab_file, out_vocab_file)
+
+ return (out_vocab_file,)
+
+
+__all__ = ["LayoutXLMTokenizerFast"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/llava_onevision/__init__.py b/venv/lib/python3.13/site-packages/transformers/models/llava_onevision/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b24d99815e0122c9b3e07c86547d074777e07b8
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/llava_onevision/__init__.py
@@ -0,0 +1,31 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_llava_onevision import *
+ from .image_processing_llava_onevision import *
+ from .image_processing_llava_onevision_fast import *
+ from .modeling_llava_onevision import *
+ from .processing_llava_onevision import *
+ from .video_processing_llava_onevision import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/venv/lib/python3.13/site-packages/transformers/models/llava_onevision/configuration_llava_onevision.py b/venv/lib/python3.13/site-packages/transformers/models/llava_onevision/configuration_llava_onevision.py
new file mode 100644
index 0000000000000000000000000000000000000000..21ead3df17061f3c3545a9d997396babe6d7fbd6
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/llava_onevision/configuration_llava_onevision.py
@@ -0,0 +1,194 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import (
+ logging,
+)
+from ..auto import CONFIG_MAPPING, AutoConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+class LlavaOnevisionConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`LlavaOnevisionForConditionalGeneration`]. It is used to instantiate a
+ Llava-Onevision model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of the [llava-hf/llava-onevision-qwen2-7b-ov-hf](https://huggingface.co/llava-hf/llava-onevision-qwen2-7b-ov-hf)
+ model.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ vision_config (`Union[AutoConfig, dict]`, *optional*, defaults to `SiglipVisionConfig`):
+ The config object or dictionary of the vision backbone.
+ text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `Qwen2Config`):
+ The config object or dictionary of the text backbone.
+ image_token_index (`int`, *optional*, defaults to 151646):
+ The image token index to encode the image prompt.
+ video_token_index (`int`, *optional*, defaults to 151647):
+ The video token index to encode the video prompt.
+ projector_hidden_act (`str`, *optional*, defaults to `"gelu"`):
+ The activation function used by the multimodal projector.
+ vision_feature_select_strategy (`str`, *optional*, defaults to `"full"`):
+ The feature selection strategy used to select the vision feature from the vision backbone.
+ Can be one of `"default"` or `"full"`. If `"default"`, the CLS token is removed from the vision features.
+ If `"full"`, the full vision features are used.
+ vision_feature_layer (`Union[int, list[int]]`, *optional*, defaults to -1):
+ The index of the layer to select the vision feature. If multiple indices are provided,
+ the vision feature of the corresponding indices will be concatenated to form the
+ vision features.
+ vision_aspect_ratio (`str`, *optional*, defaults to `"anyres_max_9"`):
+ Aspect ratio used when processing image features. The default value is "anyres_max_9".
+ image_grid_pinpoints (`List`, *optional*):
+ A list of possible resolutions to use for processing high resolution images. Each item in the list should be a tuple or list
+ of the form `(height, width)`.
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether the model's input and output word embeddings should be tied.
+ multimodal_projector_bias (`bool`, *optional*, defaults to `True`):
+ Whether to use bias in the multimodal projector.
+
+ Example:
+
+ ```python
+ >>> from transformers import LlavaOnevisionForConditionalGeneration, LlavaOnevisionConfig, SiglipVisionConfig, Qwen2Config
+
+ >>> # Initializing a Siglip vision config
+ >>> vision_config = SiglipVisionConfig()
+
+ >>> # Initializing a Qwen2 config
+ >>> text_config = Qwen2Config()
+
+ >>> # Initializing a Llava-Onevision llava-hf/llava-onevision-qwen2-7b-ov-hf style configuration
+ >>> configuration = LlavaOnevisionConfig(vision_config, text_config)
+
+ >>> # Initializing a model from the llava-hf/llava-onevision-qwen2-7b-ov-hf style configuration
+ >>> model = LlavaOnevisionForConditionalGeneration(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "llava_onevision"
+ attribute_map = {
+ "image_token_id": "image_token_index",
+ "video_token_id": "video_token_index",
+ }
+ sub_configs = {"text_config": AutoConfig, "vision_config": AutoConfig}
+
+ def __init__(
+ self,
+ vision_config=None,
+ text_config=None,
+ image_token_index=151646,
+ video_token_index=151647,
+ projector_hidden_act="gelu",
+ vision_feature_select_strategy="full",
+ vision_feature_layer=-1,
+ vision_aspect_ratio="anyres_max_9",
+ image_grid_pinpoints=None,
+ tie_word_embeddings=False,
+ multimodal_projector_bias=True,
+ **kwargs,
+ ):
+ self.image_token_index = image_token_index
+ self.video_token_index = video_token_index
+ self.projector_hidden_act = projector_hidden_act
+ self.multimodal_projector_bias = multimodal_projector_bias
+
+ if vision_feature_select_strategy not in ["default", "full"]:
+ raise ValueError(
+ "vision_feature_select_strategy should be one of 'default', 'full'."
+ f"Got: {vision_feature_select_strategy}"
+ )
+
+ self.vision_feature_select_strategy = vision_feature_select_strategy
+ self.vision_feature_layer = vision_feature_layer
+ self.vision_aspect_ratio = vision_aspect_ratio
+ image_grid_pinpoints = (
+ image_grid_pinpoints
+ if image_grid_pinpoints is not None
+ else [
+ [384, 384],
+ [384, 768],
+ [384, 1152],
+ [384, 1536],
+ [384, 1920],
+ [384, 2304],
+ [768, 384],
+ [768, 768],
+ [768, 1152],
+ [768, 1536],
+ [768, 1920],
+ [768, 2304],
+ [1152, 384],
+ [1152, 768],
+ [1152, 1152],
+ [1152, 1536],
+ [1152, 1920],
+ [1152, 2304],
+ [1536, 384],
+ [1536, 768],
+ [1536, 1152],
+ [1536, 1536],
+ [1536, 1920],
+ [1536, 2304],
+ [1920, 384],
+ [1920, 768],
+ [1920, 1152],
+ [1920, 1536],
+ [1920, 1920],
+ [1920, 2304],
+ [2304, 384],
+ [2304, 768],
+ [2304, 1152],
+ [2304, 1536],
+ [2304, 1920],
+ [2304, 2304],
+ ]
+ )
+ self.image_grid_pinpoints = image_grid_pinpoints
+
+ if isinstance(vision_config, dict):
+ vision_config["model_type"] = vision_config.get("model_type", "siglip_vision_model")
+ vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config)
+ elif vision_config is None:
+ vision_config = CONFIG_MAPPING["siglip_vision_model"](
+ hidden_size=1152,
+ intermediate_size=4304,
+ patch_size=14,
+ image_size=384,
+ num_hidden_layers=26,
+ num_attention_heads=16,
+ vision_use_head=False,
+ )
+
+ self.vision_config = vision_config
+
+ if isinstance(text_config, dict):
+ text_config["model_type"] = text_config.get("model_type", "qwen2")
+ text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
+ elif text_config is None:
+ text_config = CONFIG_MAPPING["qwen2"]()
+
+ self.text_config = text_config
+
+ super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
+
+
+__all__ = ["LlavaOnevisionConfig"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/llava_onevision/image_processing_llava_onevision.py b/venv/lib/python3.13/site-packages/transformers/models/llava_onevision/image_processing_llava_onevision.py
new file mode 100644
index 0000000000000000000000000000000000000000..836a1984a522fdd5e6cf0db15aca94f0bffddc22
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/llava_onevision/image_processing_llava_onevision.py
@@ -0,0 +1,786 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Image processor class for LLaVa-Onevision."""
+
+from collections.abc import Iterable
+from typing import Optional, Union
+
+import numpy as np
+
+from ...image_processing_utils import (
+ BaseImageProcessor,
+ BatchFeature,
+ get_patch_output_size,
+ get_size_dict,
+ select_best_resolution,
+)
+from ...image_transforms import (
+ PaddingMode,
+ convert_to_rgb,
+ pad,
+ resize,
+ to_channel_dimension_format,
+)
+from ...image_utils import (
+ OPENAI_CLIP_MEAN,
+ OPENAI_CLIP_STD,
+ ChannelDimension,
+ ImageInput,
+ PILImageResampling,
+ get_image_size,
+ infer_channel_dimension_format,
+ is_scaled_image,
+ make_flat_list_of_images,
+ to_numpy_array,
+ valid_images,
+ validate_preprocess_arguments,
+)
+from ...utils import TensorType, is_vision_available, logging
+
+
+logger = logging.get_logger(__name__)
+
+
+if is_vision_available():
+ from PIL import Image
+
+
+# Copied from transformers.models.llava_next.image_processing_llava_next.divide_to_patches
+def divide_to_patches(image: np.ndarray, patch_size: int, input_data_format) -> list[np.ndarray]:
+ """
+ Divides an image into patches of a specified size.
+
+ Args:
+ image (`np.ndarray`):
+ The input image.
+ patch_size (`int`):
+ The size of each patch.
+ input_data_format (`ChannelDimension` or `str`):
+ The channel dimension format of the input image.
+
+ Returns:
+ list: A list of np.ndarray representing the patches.
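+
+ Example (a small sketch with a channels-last dummy array):
+
+ ```python
+ >>> import numpy as np
+ >>> from transformers.image_utils import ChannelDimension
+ >>> from transformers.models.llava_onevision.image_processing_llava_onevision import divide_to_patches
+
+ >>> image = np.zeros((4, 4, 3), dtype=np.uint8)
+ >>> patches = divide_to_patches(image, patch_size=2, input_data_format=ChannelDimension.LAST)
+ >>> len(patches), patches[0].shape
+ (4, (2, 2, 3))
+ ```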
+ """
+ patches = []
+ height, width = get_image_size(image, channel_dim=input_data_format)
+ for i in range(0, height, patch_size):
+ for j in range(0, width, patch_size):
+ if input_data_format == ChannelDimension.LAST:
+ patch = image[i : i + patch_size, j : j + patch_size]
+ else:
+ patch = image[:, i : i + patch_size, j : j + patch_size]
+ patches.append(patch)
+
+ return patches
+
+
+# Copied from transformers.models.llava_next.image_processing_llava_next.expand_to_square
+def expand_to_square(image: np.ndarray, background_color, input_data_format) -> np.ndarray:
+ """
+ Expands an image to a square by adding a background color.
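+
+ Example (a small sketch with a channels-last dummy array):
+
+ ```python
+ >>> import numpy as np
+ >>> from transformers.image_utils import ChannelDimension
+ >>> from transformers.models.llava_onevision.image_processing_llava_onevision import expand_to_square
+
+ >>> image = np.zeros((2, 4, 3), dtype=np.uint8)
+ >>> expand_to_square(image, background_color=255, input_data_format=ChannelDimension.LAST).shape
+ (4, 4, 3)
+ ```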
+ """
+
+ height, width = get_image_size(image, channel_dim=input_data_format)
+ if width == height:
+ return image
+ elif width > height:
+ result = np.ones((width, width, image.shape[2]), dtype=image.dtype) * background_color
+ result[(width - height) // 2 : (width - height) // 2 + height, :] = image
+ return result
+ else:
+ result = np.ones((height, height, image.shape[2]), dtype=image.dtype) * background_color
+ result[:, (height - width) // 2 : (height - width) // 2 + width] = image
+ return result
+
+
+class LlavaOnevisionImageProcessor(BaseImageProcessor):
+ r"""
+ Constructs a LLaVa-Onevision image processor. Based on [`SiglipImageProcessor`], with support for processing each video frame.
+
+ Args:
+ do_resize (`bool`, *optional*, defaults to `True`):
+ Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by
+ `do_resize` in the `preprocess` method.
+ size (`dict[str, int]` *optional*, defaults to `{"height": 384, "width": 384}`):
+ Size of the image after resizing. Can be overridden by `size` in the `preprocess` method.
+ image_grid_pinpoints (`List` *optional*, defaults to a grid of resolutions from `[384, 384]` to `[2304, 2304]` in multiples of 384):
+ A list of possible resolutions to use for processing high resolution images. The best resolution is selected
+ based on the original size of the image. Can be overridden by `image_grid_pinpoints` in the `preprocess`
+ method. Not used for processing videos.
+ resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
+ Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method.
+ do_rescale (`bool`, *optional*, defaults to `True`):
+ Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
+ the `preprocess` method.
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
+ Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
+ method.
+ do_normalize (`bool`, *optional*, defaults to `True`):
+ Whether to normalize the image. Can be overridden by `do_normalize` in the `preprocess` method.
+ image_mean (`float` or `list[float]`, *optional*, defaults to `[0.48145466, 0.4578275, 0.40821073]`):
+ Mean to use if normalizing the image. This is a float or list of floats the length of the number of
+ channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
+ image_std (`float` or `list[float]`, *optional*, defaults to `[0.26862954, 0.26130258, 0.27577711]`):
+ Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
+ number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
+ do_pad (`bool`, *optional*, defaults to `True`):
+ Whether to pad the image. If `True`, will pad the patch dimension of the images in the batch to the largest
+ number of patches in the batch. Padding will be applied to the bottom and right with zeros.
+ do_convert_rgb (`bool`, *optional*, defaults to `True`):
+ Whether to convert the image to RGB.
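+
+ Example (a minimal sketch; the dummy image below is illustrative):
+
+ ```python
+ >>> import numpy as np
+ >>> from transformers import LlavaOnevisionImageProcessor
+
+ >>> image_processor = LlavaOnevisionImageProcessor()
+ >>> image = np.zeros((480, 640, 3), dtype=np.uint8)
+ >>> inputs = image_processor(images=image, return_tensors="np")
+ >>> pixel_values = inputs["pixel_values"]  # expected shape: (num_images, num_patches, num_channels, height, width)
+ ```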
+ """
+
+ model_input_names = ["pixel_values", "image_sizes", "batch_num_images"]
+
+ def __init__(
+ self,
+ do_resize: bool = True,
+ size: Optional[dict[str, int]] = None,
+ image_grid_pinpoints: Optional[list] = None,
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
+ do_rescale: bool = True,
+ rescale_factor: Union[int, float] = 1 / 255,
+ do_normalize: bool = True,
+ image_mean: Optional[Union[float, list[float]]] = None,
+ image_std: Optional[Union[float, list[float]]] = None,
+ do_pad: Optional[bool] = True,
+ do_convert_rgb: bool = True,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ size = size if size is not None else {"height": 384, "width": 384}
+ size = get_size_dict(size, default_to_square=False)
+ image_grid_pinpoints = (
+ image_grid_pinpoints
+ if image_grid_pinpoints is not None
+ else [
+ [384, 384],
+ [384, 768],
+ [384, 1152],
+ [384, 1536],
+ [384, 1920],
+ [384, 2304],
+ [768, 384],
+ [768, 768],
+ [768, 1152],
+ [768, 1536],
+ [768, 1920],
+ [768, 2304],
+ [1152, 384],
+ [1152, 768],
+ [1152, 1152],
+ [1152, 1536],
+ [1152, 1920],
+ [1152, 2304],
+ [1536, 384],
+ [1536, 768],
+ [1536, 1152],
+ [1536, 1536],
+ [1536, 1920],
+ [1536, 2304],
+ [1920, 384],
+ [1920, 768],
+ [1920, 1152],
+ [1920, 1536],
+ [1920, 1920],
+ [1920, 2304],
+ [2304, 384],
+ [2304, 768],
+ [2304, 1152],
+ [2304, 1536],
+ [2304, 1920],
+ [2304, 2304],
+ ]
+ )
+
+ self.do_resize = do_resize
+ self.size = size
+ self.image_grid_pinpoints = image_grid_pinpoints
+ self.resample = resample
+ self.do_rescale = do_rescale
+ self.rescale_factor = rescale_factor
+ self.do_normalize = do_normalize
+ self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
+ self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
+ self.do_pad = do_pad
+ self.do_convert_rgb = do_convert_rgb
+
+ # Copied from transformers.models.llava_next.image_processing_llava_next.LlavaNextImageProcessor.pad
+ def pad(
+ self,
+ image: np.ndarray,
+ padding: Union[int, tuple[int, int], Iterable[tuple[int, int]]],
+ mode: PaddingMode = PaddingMode.CONSTANT,
+ constant_values: Union[float, Iterable[float]] = 0.0,
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ ) -> np.ndarray:
+ """
+ Pads the `image` with the specified `padding` and `mode`. Padding can be in the (`height`, `width`)
+ dimension or in the (`num_patches`) dimension. In the second case an iterable of tuples is expected
+ as input.
+
+ Args:
+ image (`np.ndarray`):
+ The image to pad.
+ padding (`int` or `tuple[int, int]` or `Iterable[tuple[int, int]]`):
+ Padding to apply to the edges of the height, width axes. Can be one of three formats:
+ - `((before_height, after_height), (before_width, after_width))` unique pad widths for each axis.
+ - `((before, after),)` yields same before and after pad for height and width.
+ - `(pad,)` or int is a shortcut for before = after = pad width for all axes.
+ mode (`PaddingMode`):
+ The padding mode to use. Can be one of:
+ - `"constant"`: pads with a constant value.
+ - `"reflect"`: pads with the reflection of the vector mirrored on the first and last values of the
+ vector along each axis.
+ - `"replicate"`: pads with the replication of the last value on the edge of the array along each axis.
+ - `"symmetric"`: pads with the reflection of the vector mirrored along the edge of the array.
+ constant_values (`float` or `Iterable[float]`, *optional*):
+ The value to use for the padding if `mode` is `"constant"`.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format for the output image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ If unset, will use same as the input image.
+ input_data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format for the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ If unset, will use the inferred format of the input image.
+
+ Returns:
+ `np.ndarray`: The padded image.
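+
+ Example (a small sketch padding the `num_patches` dimension of a dummy array):
+
+ ```python
+ >>> import numpy as np
+ >>> from transformers import LlavaOnevisionImageProcessor
+
+ >>> image_processor = LlavaOnevisionImageProcessor()
+ >>> patches = np.zeros((2, 3, 8, 8))  # (num_patches, num_channels, height, width)
+ >>> padded = image_processor.pad(patches, padding=((0, 1), (0, 0), (0, 0), (0, 0)))
+ >>> padded.shape
+ (3, 3, 8, 8)
+ ```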
+
+ """
+
+ # call the general `pad` if padding on `height/width`, otherwise it's the `num_patches` dim
+ if isinstance(padding, int) or len(padding) != 4:
+ return pad(image, padding, mode, constant_values, data_format, input_data_format)
+
+ if input_data_format is None:
+ input_data_format = infer_channel_dimension_format(image)
+ if mode == PaddingMode.CONSTANT:
+ image = np.pad(image, padding, mode="constant", constant_values=constant_values)
+ elif mode == PaddingMode.REFLECT:
+ image = np.pad(image, padding, mode="reflect")
+ elif mode == PaddingMode.REPLICATE:
+ image = np.pad(image, padding, mode="edge")
+ elif mode == PaddingMode.SYMMETRIC:
+ image = np.pad(image, padding, mode="symmetric")
+ else:
+ raise ValueError(f"Invalid padding mode: {mode}")
+ image = (
+ to_channel_dimension_format(image, data_format, input_data_format) if data_format is not None else image
+ )
+ return image
+
+ # Copied from transformers.models.llava_next.image_processing_llava_next.LlavaNextImageProcessor._resize_for_patching
+ def _resize_for_patching(
+ self, image: np.ndarray, target_resolution: tuple, resample, input_data_format: ChannelDimension
+ ) -> np.ndarray:
+ """
+ Resizes an image to a target resolution while maintaining aspect ratio.
+
+ Args:
+ image (np.ndarray):
+ The input image.
+ target_resolution (tuple):
+ The target resolution (height, width) of the image.
+ resample (`PILImageResampling`):
+ Resampling filter to use if resizing the image.
+ input_data_format (`ChannelDimension` or `str`):
+ The channel dimension format of the input image.
+
+ Returns:
+ np.ndarray: The resized and padded image.
+ """
+ new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)
+
+ # Resize the image
+ resized_image = resize(image, (new_height, new_width), resample=resample, input_data_format=input_data_format)
+
+ return resized_image
+
+ # Copied from transformers.models.llava_next.image_processing_llava_next.LlavaNextImageProcessor._get_padding_size
+ def _get_padding_size(self, original_resolution: tuple, target_resolution: tuple):
+ original_height, original_width = original_resolution
+ target_height, target_width = target_resolution
+ paste_x, r_x = divmod(target_width - original_width, 2)
+ paste_y, r_y = divmod(target_height - original_height, 2)
+ return (paste_y, paste_y + r_y), (paste_x, paste_x + r_x)
+
+ # Copied from transformers.models.llava_next.image_processing_llava_next.LlavaNextImageProcessor._pad_for_patching
+ def _pad_for_patching(
+ self, image: np.ndarray, target_resolution: tuple, input_data_format: ChannelDimension
+ ) -> np.ndarray:
+ """
+ Pad an image to a target resolution while maintaining aspect ratio.
+ """
+ new_resolution = get_patch_output_size(image, target_resolution, input_data_format)
+ padding = self._get_padding_size(new_resolution, target_resolution)
+
+ padded_image = self.pad(image, padding=padding)
+
+ return padded_image
+
+ # Copied from transformers.models.llava_next.image_processing_llava_next.LlavaNextImageProcessor.get_image_patches
+ def get_image_patches(
+ self,
+ image: np.ndarray,
+ grid_pinpoints,
+ size: tuple,
+ patch_size: int,
+ resample: PILImageResampling,
+ data_format: ChannelDimension,
+ input_data_format: ChannelDimension,
+ ) -> list[np.ndarray]:
+ """
+ Process an image with variable resolutions by dividing it into patches.
+
+ Args:
+ image (np.ndarray):
+ The input image to be processed.
+ grid_pinpoints (List):
+ A list of possible resolutions.
+ size (`tuple`):
+ Size to resize the original image to.
+ patch_size (`int`):
+ Size of the patches to divide the image into.
+ resample (`PILImageResampling`):
+ Resampling filter to use if resizing the image.
+ data_format (`ChannelDimension` or `str`):
+ The channel dimension format for the output image.
+ input_data_format (`ChannelDimension` or `str`):
+ The channel dimension format of the input image.
+
+ Returns:
+ list[np.ndarray]: A list of NumPy arrays containing the processed image patches.
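+
+ Example (a minimal sketch; the grid and the dummy image below are illustrative):
+
+ ```python
+ >>> import numpy as np
+ >>> from transformers import LlavaOnevisionImageProcessor
+ >>> from transformers.image_utils import ChannelDimension, PILImageResampling
+
+ >>> image_processor = LlavaOnevisionImageProcessor()
+ >>> patches = image_processor.get_image_patches(
+ ...     np.zeros((500, 800, 3), dtype=np.uint8),
+ ...     grid_pinpoints=[[384, 768], [768, 384]],
+ ...     size=(384, 384),
+ ...     patch_size=384,
+ ...     resample=PILImageResampling.BICUBIC,
+ ...     data_format=ChannelDimension.FIRST,
+ ...     input_data_format=ChannelDimension.LAST,
+ ... )
+ >>> len(patches)  # the resized base image plus one patch per cell of the best grid resolution
+ 3
+ ```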
+ """
+ if not isinstance(grid_pinpoints, list):
+ raise TypeError("grid_pinpoints must be a list of possible resolutions.")
+
+ possible_resolutions = grid_pinpoints
+
+ image_size = get_image_size(image, channel_dim=input_data_format)
+ best_resolution = select_best_resolution(image_size, possible_resolutions)
+ resized_image = self._resize_for_patching(
+ image, best_resolution, resample=resample, input_data_format=input_data_format
+ )
+ padded_image = self._pad_for_patching(resized_image, best_resolution, input_data_format=input_data_format)
+
+ patches = divide_to_patches(padded_image, patch_size=patch_size, input_data_format=input_data_format)
+
+ # make sure that all patches are in the input data format
+ patches = [
+ to_channel_dimension_format(patch, channel_dim=data_format, input_channel_dim=input_data_format)
+ for patch in patches
+ ]
+
+ resized_original_image = resize(
+ image,
+ size=size,
+ resample=resample,
+ data_format=data_format,
+ input_data_format=input_data_format,
+ )
+
+ image_patches = [resized_original_image] + patches
+
+ return image_patches
+
+ # Copied from transformers.models.llava_next.image_processing_llava_next.LlavaNextImageProcessor._pad_for_batching
+ def _pad_for_batching(
+ self,
+ pixel_values: list[np.ndarray],
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ ):
+ """
+ Pads images on the `num_of_patches` dimension with zeros to form a batch of same number of patches.
+
+ Args:
+ pixel_values (`list[np.ndarray]`):
+ An array of pixel values for each image, of shape (`batch_size`, `num_patches`, `image_in_3D`)
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format for the output image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ If unset, will use same as the input image.
+ input_data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format for the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ If unset, will use the inferred format of the input image.
+
+ Returns:
+ list[`np.ndarray`]: The padded images.
+ """
+ max_patch = max(len(x) for x in pixel_values)
+ pixel_values = [
+ self.pad(
+ image,
+ padding=((0, max_patch - image.shape[0]), (0, 0), (0, 0), (0, 0)),
+ data_format=data_format,
+ input_data_format=input_data_format,
+ )
+ for image in pixel_values
+ ]
+
+ return pixel_values
+
+ # Copied from transformers.models.llava.image_processing_llava.LlavaImageProcessor.pad_to_square
+ def pad_to_square(
+ self,
+ image: np.ndarray,
+ background_color: Union[int, tuple[int, int, int]] = 0,
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ ) -> np.ndarray:
+ """
+ Pads an image to a square based on the longest edge.
+
+ Args:
+ image (`np.ndarray`):
+ The image to pad.
+ background_color (`int` or `tuple[int, int, int]`, *optional*, defaults to 0):
+ The color to use for the padding. Can be an integer for single-channel images or a
+ tuple of integers for multi-channel images. If passed as an integer
+ in multi-channel mode, it will default to `0` for the remaining channels.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format for the output image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ If unset, will use same as the input image.
+ input_data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format for the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ If unset, will use the inferred format of the input image.
+
+ Returns:
+ `np.ndarray`: The padded image.
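+
+ Example (a small sketch with a channels-last dummy array):
+
+ ```python
+ >>> import numpy as np
+ >>> from transformers import LlavaOnevisionImageProcessor
+ >>> from transformers.image_utils import ChannelDimension
+
+ >>> image_processor = LlavaOnevisionImageProcessor()
+ >>> image = np.zeros((4, 6, 3), dtype=np.uint8)
+ >>> image_processor.pad_to_square(image, input_data_format=ChannelDimension.LAST).shape
+ (6, 6, 3)
+ ```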
+ """
+ height, width = get_image_size(image, input_data_format)
+ num_channels = image.shape[0] if input_data_format == ChannelDimension.FIRST else image.shape[-1]
+
+ if height == width:
+ image = (
+ to_channel_dimension_format(image, data_format, input_data_format)
+ if data_format is not None
+ else image
+ )
+ return image
+
+ max_dim = max(height, width)
+
+ # Ensure background_color is the correct shape
+ if isinstance(background_color, int):
+ background_color = [background_color]
+ elif len(background_color) != num_channels:
+ raise ValueError(
+ f"background_color must have no more than {num_channels} elements to match the number of channels"
+ )
+
+ if input_data_format == ChannelDimension.FIRST:
+ result = np.zeros((num_channels, max_dim, max_dim), dtype=image.dtype)
+ for i, color in enumerate(background_color):
+ result[i, :, :] = color
+ if width > height:
+ start = (max_dim - height) // 2
+ result[:, start : start + height, :] = image
+ else:
+ start = (max_dim - width) // 2
+ result[:, :, start : start + width] = image
+ else:
+ result = np.zeros((max_dim, max_dim, num_channels), dtype=image.dtype)
+ for i, color in enumerate(background_color):
+ result[:, :, i] = color
+ if width > height:
+ start = (max_dim - height) // 2
+ result[start : start + height, :, :] = image
+ else:
+ start = (max_dim - width) // 2
+ result[:, start : start + width, :] = image
+
+ image = (
+ to_channel_dimension_format(result, data_format, input_data_format) if data_format is not None else result
+ )
+ return image
+
+ def _preprocess(
+ self,
+ images: ImageInput,
+ do_resize: Optional[bool] = None,
+ size: Optional[dict[str, int]] = None,
+ resample: Optional[PILImageResampling] = None,
+ do_rescale: Optional[bool] = None,
+ rescale_factor: Optional[float] = None,
+ do_normalize: Optional[bool] = None,
+ image_mean: Optional[Union[float, list[float]]] = None,
+ image_std: Optional[Union[float, list[float]]] = None,
+ do_convert_rgb: Optional[bool] = None,
+ data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ ) -> Image.Image:
+ """
+ Args:
+ images (`ImageInput`):
+ Batch of frames (one video) to preprocess. Expects a batch of frames with pixel values ranging from 0 to 255. If
+ passing in images with pixel values between 0 and 1, set `do_rescale=False`.
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+ Whether to resize the image.
+ size (`dict[str, int]`, *optional*, defaults to `self.size`):
+ Size of the image after resizing. Shortest edge of the image is resized to size["shortest_edge"], with
+ the longest edge resized to keep the input aspect ratio.
+ resample (`int`, *optional*, defaults to `self.resample`):
+ Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
+ has an effect if `do_resize` is set to `True`.
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+ Whether to rescale the image.
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
+ Whether to normalize the image.
+ image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`):
+ Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
+ image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`):
+ Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
+ `True`.
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
+ The channel dimension format for the output image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - Unset: Use the channel dimension format of the input image.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
+ from the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+ """
+ if do_resize:
+ images = [
+ resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
+ for image in images
+ ]
+
+ if do_rescale:
+ images = [
+ self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
+ for image in images
+ ]
+
+ if do_normalize:
+ images = [
+ self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
+ for image in images
+ ]
+
+ images = [
+ to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
+ ]
+
+ return images
+
+ def preprocess(
+ self,
+ images: ImageInput,
+ do_resize: Optional[bool] = None,
+ size: Optional[dict[str, int]] = None,
+ image_grid_pinpoints: Optional[list] = None,
+ resample: Optional[PILImageResampling] = None,
+ do_rescale: Optional[bool] = None,
+ rescale_factor: Optional[float] = None,
+ do_normalize: Optional[bool] = None,
+ image_mean: Optional[Union[float, list[float]]] = None,
+ image_std: Optional[Union[float, list[float]]] = None,
+ do_pad: Optional[bool] = None,
+ do_convert_rgb: Optional[bool] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ ):
+ """
+ Args:
+ images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`, `list[torch.Tensor]`):
+ The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
+ tensor. Both channels-first and channels-last formats are supported.
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+ Whether to resize the image.
+ size (`dict[str, int]`, *optional*, defaults to `self.size`):
+ Size of the image after resizing. Shortest edge of the image is resized to size["shortest_edge"], with
+ the longest edge resized to keep the input aspect ratio.
+ image_grid_pinpoints (`List` *optional*, defaults to `self.image_grid_pinpoints`):
+ A list of possible resolutions to use for processing high resolution images. The best resolution is
+ selected based on the original size of the image.
+ resample (`int`, *optional*, defaults to `self.resample`):
+ Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
+ has an effect if `do_resize` is set to `True`.
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+ Whether to rescale the image.
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
+ Whether to normalize the image.
+ image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`):
+ Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
+ image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`):
+ Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
+ `True`.
+ do_pad (`bool`, *optional*, defaults to `self.do_pad`):
+ Whether to pad the image. If `True`, will pad the patch dimension of the images in the batch to the largest
+ number of patches in the batch. Padding will be applied to the bottom and right with zeros.
+ do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
+ Whether to convert the image to RGB.
+ return_tensors (`str` or `TensorType`, *optional*):
+ The type of tensors to return. Can be one of:
+ - Unset: Return a list of `np.ndarray`.
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
+ The channel dimension format for the output image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - Unset: Use the channel dimension format of the input image.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
+ from the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+
+ """
+ do_resize = do_resize if do_resize is not None else self.do_resize
+ size = size if size is not None else self.size
+ size = get_size_dict(size, default_to_square=False)
+ image_grid_pinpoints = image_grid_pinpoints if image_grid_pinpoints is not None else self.image_grid_pinpoints
+ resample = resample if resample is not None else self.resample
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
+ image_mean = image_mean if image_mean is not None else self.image_mean
+ image_std = image_std if image_std is not None else self.image_std
+ do_pad = do_pad if do_pad is not None else self.do_pad
+ do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
+
+ if isinstance(images, (tuple, list)) and isinstance(images[0], (tuple, list)):
+ # if the first element is a list, we assume that all elements are lists
+ images = [x for x in images if x] # handle text-only case
+ batch_num_images = [len(x) for x in images]
+ elif isinstance(images, (tuple, list)):
+ # treat this as a single-image case for backward compatibility
+ batch_num_images = [1] * len(images)
+ else:
+ batch_num_images = [1]
+ # only single image patching is supported
+ need_patching = [n == 1 for n in batch_num_images for _ in range(n)]
+
+ images = self.fetch_images(images)
+ images = make_flat_list_of_images(images)
+
+ if not valid_images(images):
+ raise ValueError(
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
+ "torch.Tensor, tf.Tensor or jax.ndarray."
+ )
+
+ validate_preprocess_arguments(
+ do_rescale=do_rescale,
+ rescale_factor=rescale_factor,
+ do_normalize=do_normalize,
+ image_mean=image_mean,
+ image_std=image_std,
+ do_resize=do_resize,
+ size=size,
+ resample=resample,
+ )
+
+ if do_convert_rgb:
+ images = [convert_to_rgb(image) for image in images]
+
+ # All transformations expect numpy arrays.
+ images = [to_numpy_array(image) for image in images]
+
+ if do_rescale and is_scaled_image(images[0]):
+ logger.warning_once(
+ "It looks like you are trying to rescale already rescaled images. If the input"
+ " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
+ )
+
+ if input_data_format is None:
+ # We assume that all images have the same channel dimension format.
+ input_data_format = infer_channel_dimension_format(images[0])
+
+ size_tuple = (
+ (size["height"], size["width"])
+ if "height" in size and "width" in size
+ else (size["shortest_edge"], size["shortest_edge"])
+ )
+
+ processed_images = []
+ image_sizes = [get_image_size(image, channel_dim=input_data_format) for image in images]
+ for i, image in enumerate(images):
+ if need_patching[i]:
+ # convert image into a list of patches
+ # we intentionally use the same data format as the input data format
+ image_patches = self.get_image_patches(
+ image,
+ image_grid_pinpoints,
+ size=size_tuple,
+ patch_size=size_tuple[0],
+ resample=resample,
+ data_format=input_data_format,
+ input_data_format=input_data_format,
+ )
+ else:
+ padded_image = self.pad_to_square(
+ image=image,
+ background_color=tuple(int(x * 255) for x in self.image_mean),
+ input_data_format=input_data_format,
+ )
+ image_patches = [padded_image]
+
+ # preprocess patches
+ pixel_values = self._preprocess(
+ image_patches,
+ do_resize=do_resize,
+ size=size_tuple,
+ resample=resample,
+ do_rescale=do_rescale,
+ rescale_factor=rescale_factor,
+ do_normalize=do_normalize,
+ image_mean=image_mean,
+ image_std=image_std,
+ data_format=data_format,
+ input_data_format=input_data_format,
+ )
+ pixel_values = np.array(pixel_values)
+ processed_images.append(pixel_values)
+
+ if do_pad:
+ processed_images = self._pad_for_batching(processed_images)
+
+ return BatchFeature(
+ data={"pixel_values": processed_images, "image_sizes": image_sizes, "batch_num_images": batch_num_images},
+ tensor_type=return_tensors,
+ )
+
+
+__all__ = ["LlavaOnevisionImageProcessor"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/llava_onevision/image_processing_llava_onevision_fast.py b/venv/lib/python3.13/site-packages/transformers/models/llava_onevision/image_processing_llava_onevision_fast.py
new file mode 100644
index 0000000000000000000000000000000000000000..11872cb67bf3a8806c08c08ab698a99cbbfaa006
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/llava_onevision/image_processing_llava_onevision_fast.py
@@ -0,0 +1,346 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/llava_onevision/modular_llava_onevision.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_llava_onevision.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Union
+
+import torch
+from torchvision.transforms.v2 import functional as F
+
+from ...image_processing_utils import BatchFeature, get_patch_output_size, select_best_resolution
+from ...image_processing_utils_fast import (
+ BaseImageProcessorFast,
+ DefaultFastImageProcessorKwargs,
+ divide_to_patches,
+ group_images_by_shape,
+ reorder_images,
+)
+from ...image_utils import (
+ OPENAI_CLIP_MEAN,
+ OPENAI_CLIP_STD,
+ ChannelDimension,
+ ImageInput,
+ PILImageResampling,
+ SizeDict,
+ get_image_size,
+)
+from ...processing_utils import Unpack
+from ...utils import TensorType, auto_docstring
+
+
+class LlavaOnevisionFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
+ """
+ image_grid_pinpoints (`list[list[int]]`, *optional*):
+ A list of possible resolutions to use for processing high resolution images. The best resolution is selected
+ based on the original size of the image. Can be overridden by `image_grid_pinpoints` in the `preprocess`
+ method.
+ """
+
+ image_grid_pinpoints: Optional[list[list[int]]]
+
+
+@auto_docstring
+class LlavaOnevisionImageProcessorFast(BaseImageProcessorFast):
+ resample = PILImageResampling.BICUBIC
+ image_mean = OPENAI_CLIP_MEAN
+ image_std = OPENAI_CLIP_STD
+ size = {"height": 384, "width": 384}
+ default_to_square = False
+ crop_size = None
+ do_resize = True
+ do_center_crop = None
+ do_rescale = True
+ do_normalize = True
+ do_convert_rgb = True
+ do_pad = True
+ image_grid_pinpoints = [[384, 384], [384, 768], [384, 1152], [384, 1536], [384, 1920], [384, 2304], [768, 384], [768, 768], [768, 1152], [768, 1536], [768, 1920], [768, 2304], [1152, 384], [1152, 768], [1152, 1152], [1152, 1536], [1152, 1920], [1152, 2304], [1536, 384], [1536, 768], [1536, 1152], [1536, 1536], [1536, 1920], [1536, 2304], [1920, 384], [1920, 768], [1920, 1152], [1920, 1536], [1920, 1920], [1920, 2304], [2304, 384], [2304, 768], [2304, 1152], [2304, 1536], [2304, 1920], [2304, 2304]] # fmt: skip
+ valid_kwargs = LlavaOnevisionFastImageProcessorKwargs
+ model_input_names = ["pixel_values", "image_sizes", "batch_num_images"]
+
+ def __init__(self, **kwargs: Unpack[LlavaOnevisionFastImageProcessorKwargs]):
+ super().__init__(**kwargs)
+
+ @auto_docstring
+ def preprocess(self, images: ImageInput, **kwargs: Unpack[LlavaOnevisionFastImageProcessorKwargs]) -> BatchFeature:
+ if isinstance(images, (tuple, list)) and isinstance(images[0], (tuple, list)):
+ # if the first element is a list, we assume that all elements are lists
+ batch_num_images = [len(x) for x in images]
+ elif isinstance(images, (tuple, list)):
+ # treat this as a single-image case for backward compatibility
+ batch_num_images = [1] * len(images)
+ else:
+ batch_num_images = [1]
+ kwargs["batch_num_images"] = batch_num_images
+ return super().preprocess(images, **kwargs)
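+
+    # Illustrative note (not library code): how `batch_num_images` is derived by the override above.
+    #   nested input  [[img, img], [img]]  -> batch_num_images == [2, 1]
+    #   flat input    [img, img, img]      -> batch_num_images == [1, 1, 1]
+    #   single image  img                  -> batch_num_images == [1]
+    # Downstream, only samples containing exactly one image go through AnyRes patching (see
+    # `need_patching` in `_preprocess` below); multi-image samples are padded to a square instead.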
+
+ def _resize_for_patching(
+ self,
+ image: "torch.Tensor",
+ target_resolution: tuple,
+ interpolation: "F.InterpolationMode",
+ input_data_format: ChannelDimension,
+ ) -> "torch.Tensor":
+ """
+ Resizes an image to a target resolution while maintaining aspect ratio.
+
+ Args:
+ image ("torch.Tensor"):
+ The input image.
+ target_resolution (tuple):
+ The target resolution (height, width) of the image.
+ interpolation (`InterpolationMode`):
+ Resampling filter to use if resizing the image.
+ input_data_format (`ChannelDimension` or `str`):
+ The channel dimension format of the input image.
+
+ Returns:
+ "torch.Tensor": The resized and padded image.
+ """
+ new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)
+
+ # Resize the image
+ resized_image = self.resize(
+ image=image,
+ size=SizeDict(height=new_height, width=new_width),
+ interpolation=interpolation,
+ )
+
+ return resized_image
+
+ def _get_padding_size(self, original_resolution: tuple, target_resolution: tuple):
+ original_height, original_width = original_resolution
+ target_height, target_width = target_resolution
+ paste_x, r_x = divmod(target_width - original_width, 2)
+ paste_y, r_y = divmod(target_height - original_height, 2)
+ return [paste_x, paste_y, paste_x + r_x, paste_y + r_y]
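+
+    # Worked example (comment only) for `_get_padding_size`: the result is [left, top, right, bottom],
+    # splitting the extra pixels evenly and giving any odd remainder to the right/bottom edge, e.g.
+    #   _get_padding_size((300, 384), (384, 384)) -> [0, 42, 0, 42]
+    #   _get_padding_size((300, 384), (385, 384)) -> [0, 42, 0, 43]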
+
+ def _pad_for_patching(
+ self, image: "torch.Tensor", target_resolution: tuple, input_data_format: ChannelDimension
+ ) -> "torch.Tensor":
+ """
+ Pad an image to a target resolution while maintaining aspect ratio.
+ """
+ new_resolution = get_patch_output_size(image, target_resolution, input_data_format)
+ padding = self._get_padding_size(new_resolution, target_resolution)
+
+ padded_image = F.pad(image, padding=padding)
+
+ return padded_image
+
+ def _get_image_patches(
+ self,
+ image: "torch.Tensor",
+ grid_pinpoints,
+ size: tuple,
+ patch_size: int,
+ interpolation: "F.InterpolationMode",
+ ) -> list["torch.Tensor"]:
+ """
+ Process an image with variable resolutions by dividing it into patches.
+
+ Args:
+ image ("torch.Tensor"):
+ The input image to be processed.
+            grid_pinpoints (`list`):
+                A list of possible resolutions; the best resolution is selected based on the
+                original size of the image.
+ size (`tuple`):
+ Size to resize the original image to.
+ patch_size (`int`):
+ Size of the patches to divide the image into.
+ interpolation (`"InterpolationMode"`):
+ Resampling filter to use if resizing the image.
+
+ Returns:
+ list["torch.Tensor"]: A list of NumPy arrays containing the processed image patches.
+ """
+ if not isinstance(grid_pinpoints, list):
+ raise TypeError("grid_pinpoints must be a list of possible resolutions.")
+
+ possible_resolutions = grid_pinpoints
+
+ image_size = get_image_size(image, channel_dim=ChannelDimension.FIRST)
+ best_resolution = select_best_resolution(image_size, possible_resolutions)
+ resized_image = self._resize_for_patching(
+ image, best_resolution, interpolation=interpolation, input_data_format=ChannelDimension.FIRST
+ )
+ padded_image = self._pad_for_patching(resized_image, best_resolution, input_data_format=ChannelDimension.FIRST)
+ patches = divide_to_patches(padded_image, patch_size=patch_size)
+ resized_original_image = F.resize(image, size=size, interpolation=interpolation)
+
+ image_patches = [resized_original_image] + patches
+
+ return image_patches
+
+ def _pad_for_batching(
+ self,
+ pixel_values: list["torch.Tensor"],
+ ) -> list["torch.Tensor"]:
+ """
+        Pads images along the patch dimension with zeros so that every image in the batch has the same number of patches.
+
+ Args:
+ pixel_values (`list[torch.Tensor]`):
+                A list of pixel value tensors, one per image, each of shape (`num_patches`, `num_channels`, `height`, `width`).
+
+ Returns:
+ list[`torch.Tensor`]: The padded images.
+ """
+ max_patch = max(len(x) for x in pixel_values)
+ pixel_values = [
+ torch.nn.functional.pad(image, pad=[0, 0, 0, 0, 0, 0, 0, max_patch - image.shape[0]])
+ for image in pixel_values
+ ]
+
+ return pixel_values
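+
+    # Worked example (comment only) for `_pad_for_batching`: images with 5 and 10 patches are both
+    # zero-padded along the patch dimension to 10 so they can later be stacked into one tensor.
+    # The model ignores the padded patches because `image_sizes` tells it how many patches are real.
+    #   pixel_values = [torch.rand(5, 3, 384, 384), torch.rand(10, 3, 384, 384)]
+    #   padded = self._pad_for_batching(pixel_values)
+    #   [p.shape[0] for p in padded]  # -> [10, 10]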
+
+ def _preprocess(
+ self,
+ images: list["torch.Tensor"],
+ do_resize: bool,
+ size: SizeDict,
+ image_grid_pinpoints: list[list[int]],
+ interpolation: Optional["F.InterpolationMode"],
+ do_center_crop: bool,
+ crop_size: SizeDict,
+ do_rescale: bool,
+ rescale_factor: float,
+ do_normalize: bool,
+ image_mean: Optional[Union[float, list[float]]],
+ image_std: Optional[Union[float, list[float]]],
+ do_pad: bool,
+ batch_num_images: list[int],
+ disable_grouping: Optional[bool],
+ return_tensors: Optional[Union[str, TensorType]],
+ **kwargs,
+ ) -> BatchFeature:
+ processed_images = []
+ image_sizes = []
+
+ # only single image patching is supported
+ need_patching = [n == 1 for n in batch_num_images for _ in range(n)]
+
+ # Determine the size tuple
+ if size and size.height and size.width:
+ size_tuple = (size.height, size.width)
+ else:
+ size_tuple = (size.shortest_edge, size.shortest_edge)
+
+ # Determine the patch size
+ if crop_size and crop_size.height:
+ patch_size = crop_size.height
+ elif size and size.height:
+ patch_size = size.height
+ else:
+ patch_size = size.shortest_edge
+
+ for i, image in enumerate(images):
+ if need_patching[i]:
+ image_patches = self._get_image_patches(
+ image,
+ image_grid_pinpoints,
+ size=size_tuple,
+ patch_size=patch_size,
+ interpolation=interpolation,
+ )
+ else:
+ padded_image = self.pad_to_square(
+ images=image, background_color=tuple(int(x * 255) for x in self.image_mean)
+ )
+ image_patches = [padded_image]
+
+ # Group images by size for batched processing
+ processed_image_patches_grouped = {}
+ grouped_image_patches, grouped_image_patches_index = group_images_by_shape(
+ image_patches, disable_grouping=disable_grouping
+ )
+ for shape, stacked_image_patches in grouped_image_patches.items():
+ if do_resize:
+ stacked_image_patches = self.resize(
+ image=stacked_image_patches,
+ size=size,
+ interpolation=interpolation,
+ )
+ if do_center_crop:
+ stacked_image_patches = self.center_crop(stacked_image_patches, crop_size)
+ # Fused rescale and normalize
+ stacked_image_patches = self.rescale_and_normalize(
+ stacked_image_patches, do_rescale, rescale_factor, do_normalize, image_mean, image_std
+ )
+ processed_image_patches_grouped[shape] = stacked_image_patches
+ processed_image_patches = reorder_images(processed_image_patches_grouped, grouped_image_patches_index)
+ processed_image_patches = (
+ torch.stack(processed_image_patches, dim=0) if return_tensors else processed_image_patches
+ )
+ processed_images.append(processed_image_patches)
+ image_sizes.append(get_image_size(image, ChannelDimension.FIRST))
+
+ if do_pad:
+ processed_images = self._pad_for_batching(processed_images)
+ processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
+ return BatchFeature(
+ data={"pixel_values": processed_images, "image_sizes": image_sizes, "batch_num_images": batch_num_images},
+ tensor_type=return_tensors,
+ )
+
+ # Copied from transformers.models.llava.image_processing_llava_fast.LlavaImageProcessorFast.pad_to_square
+ def pad_to_square(
+ self,
+ images: "torch.Tensor",
+ background_color: Union[int, tuple[int, int, int]] = 0,
+ ) -> "torch.Tensor":
+ """
+ Pads an image to a square based on the longest edge.
+
+ Args:
+            images (`torch.Tensor`):
+                The images to pad.
+            background_color (`int` or `tuple[int, int, int]`, *optional*, defaults to 0):
+                The color to use for the padding. Can be an integer for single-channel images or a
+                tuple of integers for multi-channel images. If an integer is passed for a
+                multi-channel image, the remaining channels default to `0`.
+ Returns:
+ `torch.Tensor`: The padded images.
+ """
+ height, width = get_image_size(images, ChannelDimension.FIRST)
+
+ if height == width:
+ return images
+
+ num_channels = images.shape[1] if len(images.shape) == 4 else images.shape[0]
+ if isinstance(background_color, int):
+ background_color = [background_color] + [0] * (num_channels - 1)
+ elif len(background_color) != num_channels:
+ raise ValueError(
+ f"background_color must have no more than {num_channels} elements to match the number of channels"
+ )
+
+ max_dim = max(height, width)
+ paste_x_left = (max_dim - width) // 2
+ paste_y_left = (max_dim - height) // 2
+ paste_x_right = max_dim - width - paste_x_left
+ paste_y_right = max_dim - height - paste_y_left
+ padded_images = F.pad(
+ images, padding=[paste_x_left, paste_y_left, paste_x_right, paste_y_right], fill=background_color
+ )
+
+ return padded_images
+
+
+__all__ = ["LlavaOnevisionImageProcessorFast"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/llava_onevision/modeling_llava_onevision.py b/venv/lib/python3.13/site-packages/transformers/models/llava_onevision/modeling_llava_onevision.py
new file mode 100644
index 0000000000000000000000000000000000000000..727655374574c62ef2c351358dd28cd44b2ed32a
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/llava_onevision/modeling_llava_onevision.py
@@ -0,0 +1,960 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/llava_onevision/modular_llava_onevision.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_llava_onevision.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from dataclasses import dataclass
+from typing import Optional, Union
+
+import numpy as np
+import torch
+from torch import nn
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache
+from ...generation import GenerationMixin
+from ...image_processing_utils import select_best_resolution
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput
+from ...modeling_utils import PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
+from ..auto import AutoModel
+from .configuration_llava_onevision import LlavaOnevisionConfig
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for Llava outputs, with hidden states and attentions.
+ """
+)
+class LlavaOnevisionModelOutputWithPast(BaseModelOutputWithPast):
+ r"""
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ image_hidden_states (`torch.FloatTensor`, *optional*):
+ A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
+ image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
+ video_hidden_states (`torch.FloatTensor`, *optional*):
+ A `torch.FloatTensor` of size `(batch_size * num_frames, num_videos, sequence_length, hidden_size)`.
+ video_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
+ """
+
+ image_hidden_states: Optional[torch.FloatTensor] = None
+
+ video_hidden_states: Optional[torch.FloatTensor] = None
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for LlavaOnevision causal language model (or autoregressive) outputs.
+ """
+)
+class LlavaOnevisionCausalLMOutputWithPast(ModelOutput):
+ r"""
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Language modeling loss (for next-token prediction).
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ image_hidden_states (`torch.FloatTensor`, *optional*):
+        A `torch.FloatTensor` of size `(batch_size * num_patches, num_images, sequence_length, hidden_size)`.
+ image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
+ video_hidden_states (`torch.FloatTensor`, *optional*):
+ A `torch.FloatTensor` of size `(batch_size * num_frames, num_videos, sequence_length, hidden_size)`.
+ video_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: Optional[torch.FloatTensor] = None
+ past_key_values: Optional[Cache] = None
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
+ attentions: Optional[tuple[torch.FloatTensor]] = None
+ image_hidden_states: Optional[torch.FloatTensor] = None
+
+ video_hidden_states: Optional[torch.FloatTensor] = None
+
+
+@auto_docstring
+class LlavaOnevisionPreTrainedModel(PreTrainedModel):
+ config: LlavaOnevisionConfig
+ base_model_prefix = ""
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["LlamaDecoderLayer"]
+ _skip_keys_device_placement = "past_key_values"
+
+ _supports_flash_attn = True
+ _supports_sdpa = True
+
+ _can_compile_fullgraph = True
+ _supports_flex_attn = True
+ _supports_attention_backend = True
+
+ def _init_weights(self, module):
+ std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
+
+ if isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, LlavaOnevisionModel):
+ embed_std = 1 / math.sqrt(self.config.text_config.hidden_size)
+ module.image_newline.data.normal_(mean=0.0, std=embed_std)
+
+
+class LlavaOnevisionMultiModalProjector(nn.Module):
+ def __init__(self, config: LlavaOnevisionConfig):
+ super().__init__()
+ # We have hidden_size * the number of vision feature layers
+ num_feature_layers = 1 if isinstance(config.vision_feature_layer, int) else len(config.vision_feature_layer)
+ self.linear_1 = nn.Linear(
+ config.vision_config.hidden_size * num_feature_layers,
+ config.text_config.hidden_size,
+ bias=config.multimodal_projector_bias,
+ )
+ self.act = ACT2FN[config.projector_hidden_act]
+ self.linear_2 = nn.Linear(
+ config.text_config.hidden_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias
+ )
+
+ def forward(self, image_features):
+ hidden_states = self.linear_1(image_features)
+ hidden_states = self.act(hidden_states)
+ hidden_states = self.linear_2(hidden_states)
+ return hidden_states
+
+
+def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
+ """
+ Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
+
+ Args:
+ image_size (`tuple`):
+            The size of the input image in the format (height, width).
+ grid_pinpoints (`List`):
+ A list containing possible resolutions. Each item in the list should be a tuple or list
+ of the form `(height, width)`.
+ patch_size (`int`):
+ The size of each image patch.
+
+ Returns:
+        tuple: The shape of the image patch grid in the format (num_patches_height, num_patches_width).
+ """
+ if not isinstance(grid_pinpoints, list):
+ raise TypeError("grid_pinpoints should be a list of tuples or lists")
+
+    # ! VERY IMPORTANT: if image_size is a tensor, it must be converted into a tuple, otherwise it will cause incorrect calculations
+ if not isinstance(image_size, (list, tuple)):
+ if not isinstance(image_size, (torch.Tensor, np.ndarray)):
+ raise TypeError(
+ f"image_size invalid type: {type(image_size)} not valid, should be either list, tuple, np.ndarray or tensor"
+ )
+ image_size = image_size.tolist()
+
+ height, width = select_best_resolution(image_size, grid_pinpoints)
+ return height // patch_size, width // patch_size
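+
+# Worked example (comment only), assuming `select_best_resolution` picks (768, 1152) for the image:
+# with patch_size=384 the returned grid is (768 // 384, 1152 // 384) == (2, 3), i.e. 2 patches high
+# and 3 patches wide.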
+
+
+def image_size_to_num_patches(image_size, grid_pinpoints, patch_size: int):
+ """
+ Calculate the number of patches after the preprocessing for images of any resolution.
+
+ Args:
+ image_size (`torch.LongTensor` or `np.ndarray` or `tuple[int, int]`):
+            The size of the input image in the format (height, width).
+ grid_pinpoints (`List`):
+ A list containing possible resolutions. Each item in the list should be a tuple or list
+ of the form `(height, width)`.
+ patch_size (`int`):
+ The size of each image patch.
+
+ Returns:
+ int: the number of patches
+ """
+ if not isinstance(grid_pinpoints, list):
+ raise TypeError("grid_pinpoints should be a list of tuples or lists")
+
+    # ! VERY IMPORTANT: if image_size is a tensor, it must be converted into a tuple, otherwise it will cause incorrect calculations
+ if not isinstance(image_size, (list, tuple)):
+ if not isinstance(image_size, (torch.Tensor, np.ndarray)):
+ raise TypeError(f"image_size invalid type {type(image_size)} with value {image_size}")
+ image_size = image_size.tolist()
+
+ best_resolution = select_best_resolution(image_size, grid_pinpoints)
+ height, width = best_resolution
+ num_patches = 0
+ # consider change to ceil(height/patch_size)*ceil(width/patch_size) + 1
+ for i in range(0, height, patch_size):
+ for j in range(0, width, patch_size):
+ num_patches += 1
+ # add the base patch
+ num_patches += 1
+ return num_patches
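+
+# Worked example (comment only), assuming the best resolution for the image is (768, 1152) and
+# patch_size=384: the double loop visits 2 * 3 = 6 grid cells, plus one base patch for the resized
+# full image, so the function returns 7.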
+
+
+def unpad_image(tensor, original_size):
+ """
+ Unpads a PyTorch tensor of a padded and resized image.
+
+ Args:
+ tensor (`torch.Tensor`):
+ The image tensor, assumed to be of shape (num_channels, height, width).
+ original_size (`tuple`):
+ The original size of the image (height, width).
+
+ Returns:
+ `torch.Tensor`: The unpadded image tensor.
+ """
+ if not isinstance(original_size, (list, tuple)):
+ if not isinstance(original_size, (torch.Tensor, np.ndarray)):
+ raise TypeError(
+ f"image_size invalid type: {type(original_size)} not valid, should be either list, tuple, np.ndarray or tensor"
+ )
+ original_size = original_size.tolist()
+ original_height, original_width = original_size
+ current_height, current_width = tensor.shape[1:]
+
+ original_aspect_ratio = original_width / original_height
+ current_aspect_ratio = current_width / current_height
+
+ if original_aspect_ratio > current_aspect_ratio:
+ scale_factor = current_width / original_width
+ new_height = int(round(original_height * scale_factor, 7))
+ padding = (current_height - new_height) // 2
+ unpadded_tensor = tensor[:, padding : current_height - padding, :]
+ else:
+ scale_factor = current_height / original_height
+ new_width = int(round(original_width * scale_factor, 7))
+ padding = (current_width - new_width) // 2
+ unpadded_tensor = tensor[:, :, padding : current_width - padding]
+
+ return unpadded_tensor
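+
+# Worked example (comment only): for a (C, 384, 384) padded feature map whose original image was
+# 720x1280, original_aspect_ratio (1280 / 720) > current_aspect_ratio (1.0), so
+# scale_factor = 384 / 1280 = 0.3, new_height = 216, padding = (384 - 216) // 2 = 84,
+# and the result is tensor[:, 84:300, :] with shape (C, 216, 384).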
+
+
+@auto_docstring(
+ custom_intro="""
+    The LLaVA-OneVision model, which consists of a vision backbone and a language model, without a language modeling head.
+ """
+)
+class LlavaOnevisionModel(LlavaOnevisionPreTrainedModel):
+ _checkpoint_conversion_mapping = {"language_model.model": "language_model"}
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.vision_tower = AutoModel.from_config(config.vision_config)
+
+ self.multi_modal_projector = LlavaOnevisionMultiModalProjector(config)
+ embed_std = 1 / math.sqrt(config.text_config.hidden_size)
+ self.image_newline = nn.Parameter(torch.randn(config.text_config.hidden_size, dtype=self.dtype) * embed_std)
+
+ self.vocab_size = config.text_config.vocab_size
+ self.language_model = AutoModel.from_config(config.text_config)
+ self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.language_model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.language_model.set_input_embeddings(value)
+
+ def set_decoder(self, decoder):
+ self.language_model = decoder
+
+ def get_decoder(self):
+ return self.language_model
+
+ def pack_image_features(self, image_features, image_sizes, image_newline=None, vision_aspect_ratio="anyres_max_9"):
+ """
+ Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors.
+
+ Args:
+            image_features (`list[torch.Tensor]` of length num_images, each of shape `(num_patches, image_length, embed_dim)`):
+                List of image feature tensors, each containing the visual features of all patches.
+            image_sizes (`torch.Tensor` of shape `(num_images, 2)`):
+                Actual image size of each image (height, width).
+            image_newline (`torch.Tensor` of shape `(embed_dim)`):
+                New line embedding vector.
+            vision_aspect_ratio (`str`, *optional*, defaults to `"anyres_max_9"`):
+                Aspect ratio used when processing image features.
+        Returns:
+            image_features (`list[torch.Tensor]`):
+                List of packed image feature tensors, each of shape `(feat_len, embed_dim)`.
+            feature_lens (`torch.LongTensor`):
+                Token length of each image in `image_features`.
+ """
+ new_image_features = []
+ feature_lens = []
+ for image_idx, image_feature in enumerate(image_features):
+ if image_feature.shape[0] > 1:
+ base_image_feature = image_feature[0]
+ image_feature = image_feature[1:]
+ height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size
+ if height * width != base_image_feature.shape[0]:
+ raise ValueError("The number of patches is not consistent with the image size.")
+ num_patch_height, num_patch_width = get_anyres_image_grid_shape(
+ image_sizes[image_idx],
+ self.config.image_grid_pinpoints,
+ self.config.vision_config.image_size,
+ )
+ image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
+ image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
+ image_feature = image_feature.flatten(1, 2).flatten(2, 3)
+ image_feature = unpad_image(image_feature, image_sizes[image_idx])
+ max_num_patches = int(vision_aspect_ratio.strip("anyres_max_"))
+ channels, curr_height, curr_width = image_feature.shape
+ ratio = math.sqrt(curr_height * curr_width / (max_num_patches * height**2))
+ if ratio > 1.1:
+ image_feature = image_feature[None]
+ image_feature = nn.functional.interpolate(
+ image_feature, [int(curr_height // ratio), int(curr_width // ratio)], mode="bilinear"
+ )[0]
+ if image_newline is not None:
+ image_feature = torch.cat(
+ (
+ image_feature,
+ image_newline[:, None, None]
+ .expand(*image_feature.shape[:-1], 1)
+ .to(image_feature.device, image_feature.dtype),
+ ),
+ dim=-1,
+ )
+ image_feature = image_feature.flatten(1, 2).transpose(0, 1)
+ image_feature = torch.cat((base_image_feature, image_feature), dim=0)
+ else:
+ image_feature = image_feature[0]
+ if image_newline is not None:
+ image_feature = torch.cat((image_feature, image_newline[None].to(image_feature)), dim=0)
+ new_image_features.append(image_feature)
+ feature_lens.append(image_feature.size(0))
+ feature_lens = torch.tensor(feature_lens, dtype=torch.long, device=image_features[0].device)
+ return new_image_features, feature_lens
+
+ def get_image_features(
+ self,
+ pixel_values: torch.FloatTensor,
+ image_sizes: torch.Tensor,
+ vision_feature_layer: Optional[Union[int, list[int]]] = None,
+ vision_feature_select_strategy: Optional[str] = None,
+ vision_aspect_ratio: Optional[str] = None,
+ batch_num_images: Optional[torch.LongTensor] = None,
+ ):
+ """
+ Obtains image last hidden states from the vision tower and apply multimodal projection.
+
+ Args:
+            pixel_values (`torch.FloatTensor` of shape `(batch_size, num_patches, channels, height, width)`):
+ The tensors corresponding to the input images.
+            image_sizes (`torch.Tensor` of shape `(num_images, 2)`):
+                Actual image size of each image (height, width).
+ vision_feature_layer (`Union[int, list[int]]`):
+ The index of the layer to select the vision feature. If multiple indices are provided,
+ the vision feature of the corresponding indices will be concatenated to form the
+ vision features.
+ vision_feature_select_strategy (`str`):
+ The feature selection strategy used to select the vision feature from the vision backbone.
+ Can be one of `"default"` or `"full"`
+ batch_num_images (`torch.LongTensor`, *optional*):
+ Number of images in each sample.
+ Returns:
+            image_features (list[`torch.Tensor`]): List of image feature tensors, each containing the visual
+                features of all patches and of shape `(num_patches, image_length, embed_dim)`.
+ """
+ vision_feature_layer = (
+ vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
+ )
+ vision_feature_select_strategy = (
+ vision_feature_select_strategy
+ if vision_feature_select_strategy is not None
+ else self.config.vision_feature_select_strategy
+ )
+ vision_aspect_ratio = (
+ vision_aspect_ratio if vision_aspect_ratio is not None else self.config.vision_aspect_ratio
+ )
+
+ # ! infer image_num_patches from image_sizes
+ if batch_num_images is None:
+ # treat this as a single-image case for backward compatibility
+ need_patching = [True] * len(image_sizes)
+ else:
+ need_patching = [n == 1 for n in batch_num_images for _ in range(n)]
+ image_num_patches = [
+ image_size_to_num_patches(
+ image_size=imsize,
+ grid_pinpoints=self.config.image_grid_pinpoints,
+ patch_size=self.config.vision_config.image_size,
+ )
+ if should_patch
+ else 1
+ for imsize, should_patch in zip(image_sizes, need_patching)
+ ]
+ if pixel_values.dim() == 5:
+ # stacked if input is (batch_size, num_patches, num_channels, height, width)
+ _pixel_values_list = [pix_val[:num_patch] for pix_val, num_patch in zip(pixel_values, image_num_patches)]
+ pixel_values = torch.cat(_pixel_values_list, dim=0)
+ elif pixel_values.dim() != 4:
+ # otherwise has to be stacked from list of (num_patches, num_channels, height, width)
+ raise ValueError(f"pixel_values of shape {pixel_values.shape}, expect to be of 4 or 5 dimensions")
+
+ image_features = self.vision_tower(pixel_values, output_hidden_states=True)
+ # If we have one vision feature layer, return the corresponding hidden states,
+ # otherwise, select the hidden states of each feature layer and concatenate them
+ if isinstance(vision_feature_layer, int):
+ selected_image_feature = image_features.hidden_states[vision_feature_layer]
+ else:
+ hs_pool = [image_features.hidden_states[layer_idx] for layer_idx in vision_feature_layer]
+ selected_image_feature = torch.cat(hs_pool, dim=-1)
+
+ if vision_feature_select_strategy == "default":
+ selected_image_feature = selected_image_feature[:, 1:]
+ image_features = self.multi_modal_projector(selected_image_feature)
+ image_features = torch.split(image_features, image_num_patches, dim=0)
+
+ image_features, feature_lens = self.pack_image_features(
+ image_features,
+ image_sizes,
+ image_newline=self.image_newline,
+ vision_aspect_ratio=vision_aspect_ratio,
+ )
+ return image_features
+
+ def get_placeholder_mask(
+ self,
+ input_ids: torch.LongTensor,
+ inputs_embeds: torch.FloatTensor,
+ image_features: Optional[torch.FloatTensor] = None,
+ video_features: Optional[torch.FloatTensor] = None,
+ ):
+ """
+ Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
+ equal to the length of multimodal features. If the lengths are different, an error is raised.
+ """
+ if input_ids is None:
+ special_image_mask = inputs_embeds == self.get_input_embeddings()(
+ torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ special_image_mask = special_image_mask.all(-1)
+ special_video_mask = inputs_embeds == self.get_input_embeddings()(
+ torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ special_video_mask = special_video_mask.all(-1)
+ else:
+ special_image_mask = input_ids == self.config.image_token_id
+ special_video_mask = input_ids == self.config.video_token_id
+
+ n_image_tokens = special_image_mask.sum()
+ special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+ if image_features is not None and inputs_embeds[special_image_mask].numel() != image_features.numel():
+ raise ValueError(
+ f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {image_features.shape[0]}"
+ )
+
+ n_video_tokens = special_video_mask.sum()
+ special_video_mask = special_video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+ if video_features is not None and inputs_embeds[special_video_mask].numel() != video_features.numel():
+ raise ValueError(
+ f"Videos features and image tokens do not match: tokens: {n_video_tokens}, features {video_features.shape[0]}"
+ )
+
+ return special_image_mask, special_video_mask
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ image_sizes: Optional[torch.LongTensor] = None,
+ pixel_values_videos: Optional[torch.FloatTensor] = None,
+ image_sizes_videos: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ vision_feature_layer: Optional[Union[int, list[int]]] = None,
+ vision_feature_select_strategy: Optional[str] = None,
+ vision_aspect_ratio: Optional[str] = None,
+ batch_num_images: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> Union[tuple, LlavaOnevisionModelOutputWithPast]:
+ r"""
+ image_sizes_videos (`torch.LongTensor` of shape `(batch_size, frames, 2)`, *optional*):
+ The sizes of the videos in the batch, being (height, width) for each frame in the video.
+ vision_aspect_ratio (`str`, *optional*, defaults to `"anyres_max_9"`):
+            Aspect ratio used when processing image features. The default value is "anyres_max_9".
+ batch_num_images (`torch.LongTensor`, *optional*):
+ Number of images in each sample.
+ """
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ vision_feature_layer = (
+ vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
+ )
+ vision_feature_select_strategy = (
+ vision_feature_select_strategy
+ if vision_feature_select_strategy is not None
+ else self.config.vision_feature_select_strategy
+ )
+ vision_aspect_ratio = (
+ vision_aspect_ratio if vision_aspect_ratio is not None else self.config.vision_aspect_ratio
+ )
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if inputs_embeds is None:
+ inputs_embeds = self.get_input_embeddings()(input_ids)
+
+ # Images are processed with Anyres
+ if pixel_values is not None:
+ image_features = self.get_image_features(
+ pixel_values,
+ image_sizes,
+ vision_feature_layer=vision_feature_layer,
+ vision_feature_select_strategy=vision_feature_select_strategy,
+ batch_num_images=batch_num_images,
+ )
+ image_features = torch.cat(image_features, dim=0)
+ image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+ special_image_mask, _ = self.get_placeholder_mask(
+ input_ids, inputs_embeds=inputs_embeds, image_features=image_features
+ )
+ inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
+
+ # Video are simply embedded and further pooled to decrease seq len
+ if pixel_values_videos is not None:
+ video_features = self.get_video_features(
+ pixel_values_videos,
+ vision_feature_layer=vision_feature_layer,
+ vision_feature_select_strategy=vision_feature_select_strategy,
+ )
+ image_newline = (
+ self.image_newline[None, None, :].repeat(video_features.shape[0], 1, 1).to(video_features.device)
+ )
+ video_features = torch.cat((video_features, image_newline), dim=1)
+ video_features = video_features.flatten(0, 1).to(inputs_embeds.device, inputs_embeds.dtype)
+ _, special_video_mask = self.get_placeholder_mask(
+ input_ids, inputs_embeds=inputs_embeds, video_features=video_features
+ )
+ inputs_embeds = inputs_embeds.masked_scatter(special_video_mask, video_features)
+
+ outputs = self.language_model(
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=True,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ return LlavaOnevisionModelOutputWithPast(
+ last_hidden_state=outputs.last_hidden_state,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ image_hidden_states=image_features if pixel_values is not None else None,
+ video_hidden_states=video_features if pixel_values_videos is not None else None,
+ )
+
+ def get_video_features(
+ self,
+ pixel_values: torch.FloatTensor,
+ vision_feature_layer: Union[int, list[int]],
+ vision_feature_select_strategy: str,
+ ):
+ """
+ Obtains video last hidden states from the vision tower, apply multimodal projection and pooling.
+
+ Args:
+            pixel_values (`torch.FloatTensor` of shape `(batch_size, num_frames, channels, height, width)`):
+ The tensors corresponding to the input video.
+            vision_feature_layer (`Union[int, list[int]]`):
+ The index of the layer to select the vision feature. If multiple indices are provided,
+ the vision feature of the corresponding indices will be concatenated to form the
+ vision features.
+ vision_feature_select_strategy (`str`):
+ The feature selection strategy used to select the vision feature from the vision backbone.
+ Can be one of `"default"` or `"full"`
+ Returns:
+            video_features (list[`torch.Tensor`]): List of video feature tensors, each containing the visual
+                features of all patches and of shape `(num_videos, video_length, embed_dim)`.
+ """
+ batch_size, frames, channels, height, width = pixel_values.shape
+ pixel_values = pixel_values.view(batch_size * frames, channels, height, width)
+ video_features = self.vision_tower(pixel_values, output_hidden_states=True)
+
+ # If we have one vision feature layer, return the corresponding hidden states,
+ # otherwise, select the hidden states of each feature layer and concatenate them
+ if isinstance(vision_feature_layer, int):
+ selected_video_feature = video_features.hidden_states[vision_feature_layer]
+ else:
+ hs_pool = [video_features.hidden_states[layer_idx] for layer_idx in vision_feature_layer]
+ selected_video_feature = torch.cat(hs_pool, dim=-1)
+
+ if vision_feature_select_strategy == "default":
+ selected_video_feature = selected_video_feature[:, 1:]
+ video_features = self.multi_modal_projector(selected_video_feature)
+
+ video_features = self.apply_pooling(video_features)
+ video_features = video_features.reshape(batch_size, frames * video_features.shape[1], -1)
+
+ return video_features
+
+ def apply_pooling(self, image_features):
+ height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size
+ batch_frames, seq_len, dim = image_features.shape
+ image_features = image_features.view(batch_frames, height, width, -1)
+ image_features = image_features.permute(0, 3, 1, 2).contiguous()
+
+ height, width = image_features.shape[2:]
+ scaled_shape = [math.ceil(height / 2), math.ceil(width / 2)]
+ image_features = nn.functional.interpolate(image_features, size=scaled_shape, mode="bilinear")
+
+ image_features = image_features.permute(0, 2, 3, 1)
+ image_features = image_features.view(batch_frames, -1, dim)
+ return image_features
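+
+    # Illustrative note (assuming a SigLIP-style vision tower with image_size=384 and patch_size=14,
+    # as in the released checkpoints): each frame yields 384 // 14 = 27 patches per side, i.e. 729
+    # tokens, and bilinear pooling to ceil(27 / 2) = 14 per side reduces that to 196 tokens per frame
+    # before frames are concatenated along the sequence dimension in `get_video_features`.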
+
+
+@auto_docstring(
+ custom_intro="""
+    The LLaVA-OneVision model, which consists of a vision backbone and a language model.
+ """
+)
+class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, GenerationMixin):
+ _checkpoint_conversion_mapping = {
+ "^language_model.model": "model.language_model",
+ "^vision_tower": "model.vision_tower",
+ "^multi_modal_projector": "model.multi_modal_projector",
+ "^image_newline": "model.image_newline",
+ "^language_model.lm_head": "lm_head",
+ }
+ _tied_weights_keys = ["lm_head.weight"]
+
+ def __init__(self, config: LlavaOnevisionConfig):
+ super().__init__(config)
+ self.model = LlavaOnevisionModel(config)
+ self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.model.set_input_embeddings(value)
+
+ def get_output_embeddings(self) -> nn.Module:
+ return self.lm_head
+
+ def set_decoder(self, decoder):
+ self.model.set_decoder(decoder)
+
+ def get_decoder(self):
+ return self.model.get_decoder()
+
+ def pack_image_features(self, image_features, image_sizes, vision_feature_select_strategy, image_newline=None):
+ return self.model.pack_image_features(
+ image_features=image_features,
+ image_sizes=image_sizes,
+ vision_feature_select_strategy=vision_feature_select_strategy,
+ image_newline=image_newline,
+ )
+
+ def get_image_features(
+ self,
+ pixel_values: torch.FloatTensor,
+ image_sizes: torch.Tensor,
+ vision_feature_layer: Optional[Union[int, list[int]]] = None,
+ vision_feature_select_strategy: Optional[str] = None,
+ ):
+ return self.model.get_image_features(
+ pixel_values=pixel_values,
+ image_sizes=image_sizes,
+ vision_feature_layer=vision_feature_layer,
+ vision_feature_select_strategy=vision_feature_select_strategy,
+ )
+
+ # Make modules available through conditional class for BC
+ @property
+ def language_model(self):
+ return self.model.language_model
+
+ @property
+ def vision_tower(self):
+ return self.model.vision_tower
+
+ @property
+ def multi_modal_projector(self):
+ return self.model.multi_modal_projector
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ image_sizes: Optional[torch.LongTensor] = None,
+ pixel_values_videos: Optional[torch.FloatTensor] = None,
+ image_sizes_videos: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ vision_feature_layer: Optional[Union[int, list[int]]] = None,
+ vision_feature_select_strategy: Optional[str] = None,
+ vision_aspect_ratio: Optional[str] = None,
+ batch_num_images: Optional[torch.LongTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> Union[tuple, LlavaOnevisionCausalLMOutputWithPast]:
+ r"""
+ image_sizes_videos (`torch.LongTensor` of shape `(batch_size, frames, 2)`, *optional*):
+ The sizes of the videos in the batch, being (height, width) for each frame in the video.
+ vision_aspect_ratio (`str`, *optional*, defaults to `"anyres_max_9"`):
+            Aspect ratio used when processing image features. The default value is "anyres_max_9".
+ batch_num_images (`torch.LongTensor`, *optional*):
+ Number of images in each sample.
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Example:
+
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> import torch
+ >>> from transformers import LlavaOnevisionProcessor, LlavaOnevisionForConditionalGeneration
+
+ >>> model = LlavaOnevisionForConditionalGeneration.from_pretrained("llava-hf/llava-onevision-qwen2-7b-ov-hf", dtype="float16", device_map="cuda:0")
+ >>> processor = LlavaOnevisionProcessor.from_pretrained("llava-hf/llava-onevision-qwen2-7b-ov-hf")
+
+ >>> conversation = [
+ ... {
+ ... "role": "user",
+ ... "content": [
+ ... {"type": "text", "text": "What is shown in this image?"},
+ ... {"type": "image"},
+ ... ],
+ ... },
+ ... ]
+ >>> prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
+
+ >>> image_file = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> raw_image = Image.open(requests.get(image_file, stream=True).raw)
+ >>> inputs = processor(text=prompt, images=raw_image, return_tensors='pt').to(0, torch.float16)
+
+ >>> output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
+ >>> processor.batch_decode(output, skip_special_tokens=True)[0]
+ "user\n\nWhat is shown in this image?\nassistant\ncat"
+ ```"""
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ vision_feature_layer = (
+ vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
+ )
+ vision_feature_select_strategy = (
+ vision_feature_select_strategy
+ if vision_feature_select_strategy is not None
+ else self.config.vision_feature_select_strategy
+ )
+ vision_aspect_ratio = (
+ vision_aspect_ratio if vision_aspect_ratio is not None else self.config.vision_aspect_ratio
+ )
+
+ outputs = self.model(
+ input_ids=input_ids,
+ pixel_values=pixel_values,
+ pixel_values_videos=pixel_values_videos,
+ image_sizes=image_sizes,
+ image_sizes_videos=image_sizes_videos,
+ vision_aspect_ratio=vision_aspect_ratio,
+ vision_feature_layer=vision_feature_layer,
+ vision_feature_select_strategy=vision_feature_select_strategy,
+ batch_num_images=batch_num_images,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=True,
+ cache_position=cache_position,
+ logits_to_keep=logits_to_keep,
+ **kwargs,
+ )
+
+ hidden_states = outputs[0]
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(
+ logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
+ )
+
+ return LlavaOnevisionCausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ image_hidden_states=outputs.image_hidden_states,
+ video_hidden_states=outputs.video_hidden_states,
+ )
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids,
+ past_key_values=None,
+ inputs_embeds=None,
+ pixel_values=None,
+ image_sizes=None,
+ pixel_values_videos=None,
+ image_sizes_videos=None,
+ attention_mask=None,
+ cache_position=None,
+ logits_to_keep=None,
+ **kwargs,
+ ):
+ # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
+
+ model_inputs = super().prepare_inputs_for_generation(
+ input_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ cache_position=cache_position,
+ logits_to_keep=logits_to_keep,
+ **kwargs,
+ )
+
+ if cache_position[0] == 0:
+            # Pixel values are only forwarded on the first (prefill) step; in the cached decoding stage
+            # they should be None because the input ids no longer contain the special image tokens.
+ model_inputs["pixel_values"] = pixel_values
+ model_inputs["image_sizes"] = image_sizes
+ model_inputs["pixel_values_videos"] = pixel_values_videos
+ model_inputs["image_sizes_videos"] = image_sizes_videos
+
+ return model_inputs
+
+ @staticmethod
+ def _prepare_4d_causal_attention_mask_with_cache_position(
+ attention_mask: torch.Tensor,
+ sequence_length: int,
+ target_length: int,
+ dtype: torch.dtype,
+ cache_position: torch.Tensor,
+ batch_size: int,
+ **kwargs,
+ ):
+ """
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
+
+ Args:
+ attention_mask (`torch.Tensor`):
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
+ `(batch_size, 1, query_length, key_value_length)`.
+ sequence_length (`int`):
+ The sequence length being processed.
+ target_length (`int`):
+ The target length: when generating with static cache, the mask should be as long as the static cache,
+ to account for the 0 padding, the part of the cache that is not filled yet.
+ dtype (`torch.dtype`):
+ The dtype to use for the 4D attention mask.
+ cache_position (`torch.Tensor`):
+ Indices depicting the position of the input sequence tokens in the sequence.
+ batch_size (`torch.Tensor`):
+ Batch size.
+ """
+ if attention_mask is not None and attention_mask.dim() == 4:
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+ causal_mask = attention_mask
+ else:
+ min_dtype = torch.finfo(dtype).min
+ causal_mask = torch.full(
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
+ )
+ if sequence_length != 1:
+ causal_mask = torch.triu(causal_mask, diagonal=1)
+ causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+ if attention_mask is not None:
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
+ mask_length = attention_mask.shape[-1]
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
+ causal_mask.device
+ )
+ padding_mask = padding_mask == 0
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+ padding_mask, min_dtype
+ )
+
+ return causal_mask
+
+ def get_video_features(
+ self,
+ pixel_values: torch.FloatTensor,
+ vision_feature_layer: Optional[Union[int, list[int]]] = None,
+ vision_feature_select_strategy: Optional[str] = None,
+ ):
+ return self.model.get_video_features(
+ pixel_values=pixel_values,
+ vision_feature_layer=vision_feature_layer,
+ vision_feature_select_strategy=vision_feature_select_strategy,
+ )
+
+
+__all__ = ["LlavaOnevisionModel", "LlavaOnevisionForConditionalGeneration", "LlavaOnevisionPreTrainedModel"]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/llava_onevision/modular_llava_onevision.py b/venv/lib/python3.13/site-packages/transformers/models/llava_onevision/modular_llava_onevision.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4f64dee8e041096dc7b648f23341e3eb5c11a1c
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/llava_onevision/modular_llava_onevision.py
@@ -0,0 +1,744 @@
+# coding=utf-8
+# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import Optional, Union
+
+import torch
+from torch import nn
+from torchvision.transforms.v2 import functional as F
+
+from transformers.models.llava_next.image_processing_llava_next_fast import LlavaNextImageProcessorFast
+from transformers.models.llava_next_video.modeling_llava_next_video import (
+ LlavaNextVideoCausalLMOutputWithPast,
+ LlavaNextVideoForConditionalGeneration,
+ LlavaNextVideoModel,
+ LlavaNextVideoModelOutputWithPast,
+ LlavaNextVideoPreTrainedModel,
+ TransformersKwargs,
+ get_anyres_image_grid_shape,
+ image_size_to_num_patches,
+ unpad_image,
+)
+
+from ...cache_utils import Cache
+from ...image_processing_utils import BatchFeature
+from ...image_processing_utils_fast import DefaultFastImageProcessorKwargs, group_images_by_shape, reorder_images
+from ...image_utils import (
+ OPENAI_CLIP_MEAN,
+ OPENAI_CLIP_STD,
+ ChannelDimension,
+ ImageInput,
+ PILImageResampling,
+ SizeDict,
+ get_image_size,
+)
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...processing_utils import Unpack
+from ...utils import (
+ TensorType,
+ auto_docstring,
+ can_return_tuple,
+ logging,
+)
+
+
+logger = logging.get_logger(__name__)
+
+
+class LlavaOnevisionFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
+ """
+ image_grid_pinpoints (`list[list[int]]`, *optional*):
+ A list of possible resolutions to use for processing high resolution images. The best resolution is selected
+ based on the original size of the image. Can be overridden by `image_grid_pinpoints` in the `preprocess`
+ method.
+ """
+
+ image_grid_pinpoints: Optional[list[list[int]]]
+
+
+class LlavaOnevisionImageProcessorFast(LlavaNextImageProcessorFast):
+ resample = PILImageResampling.BICUBIC
+ image_mean = OPENAI_CLIP_MEAN
+ image_std = OPENAI_CLIP_STD
+ size = {"height": 384, "width": 384}
+ crop_size = None
+ default_to_square = False
+ do_resize = True
+ do_center_crop = None
+ do_rescale = True
+ do_normalize = True
+ do_convert_rgb = True
+ do_pad = True
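+ # Candidate (height, width) anyres resolutions: every multiple of the 384px base tile, from a 1x1 up to a 6x6 grid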
+ image_grid_pinpoints = [[384, 384], [384, 768], [384, 1152], [384, 1536], [384, 1920], [384, 2304], [768, 384], [768, 768], [768, 1152], [768, 1536], [768, 1920], [768, 2304], [1152, 384], [1152, 768], [1152, 1152], [1152, 1536], [1152, 1920], [1152, 2304], [1536, 384], [1536, 768], [1536, 1152], [1536, 1536], [1536, 1920], [1536, 2304], [1920, 384], [1920, 768], [1920, 1152], [1920, 1536], [1920, 1920], [1920, 2304], [2304, 384], [2304, 768], [2304, 1152], [2304, 1536], [2304, 1920], [2304, 2304]] # fmt: skip
+ model_input_names = ["pixel_values", "image_sizes", "batch_num_images"]
+
+ # Copied from transformers.models.llava.image_processing_llava_fast.LlavaImageProcessorFast.pad_to_square
+ def pad_to_square(
+ self,
+ images: "torch.Tensor",
+ background_color: Union[int, tuple[int, int, int]] = 0,
+ ) -> "torch.Tensor":
+ """
+ Pads an image to a square based on the longest edge.
+
+ Args:
+ images (`torch.Tensor`):
+ The images to pad.
+ background_color (`int` or `tuple[int, int, int]`, *optional*, defaults to 0):
+ The color to use for the padding. Can be an integer for a single channel or a
+ tuple of integers for multi-channel images. If an integer is passed for a
+ multi-channel image, the remaining channels default to `0`.
+ Returns:
+ `torch.Tensor`: The padded images.
+ """
+ height, width = get_image_size(images, ChannelDimension.FIRST)
+
+ if height == width:
+ return images
+
+ num_channels = images.shape[1] if len(images.shape) == 4 else images.shape[0]
+ if isinstance(background_color, int):
+ background_color = [background_color] + [0] * (num_channels - 1)
+ elif len(background_color) != num_channels:
+ raise ValueError(
+ f"background_color must have no more than {num_channels} elements to match the number of channels"
+ )
+
+ max_dim = max(height, width)
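+ # Center the image on a square max_dim x max_dim canvas; any odd remainder of the padding goes to the right/bottom edge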
+ paste_x_left = (max_dim - width) // 2
+ paste_y_left = (max_dim - height) // 2
+ paste_x_right = max_dim - width - paste_x_left
+ paste_y_right = max_dim - height - paste_y_left
+ padded_images = F.pad(
+ images, padding=[paste_x_left, paste_y_left, paste_x_right, paste_y_right], fill=background_color
+ )
+
+ return padded_images
+
+ @auto_docstring
+ def preprocess(self, images: ImageInput, **kwargs: Unpack[LlavaOnevisionFastImageProcessorKwargs]) -> BatchFeature:
+ if isinstance(images, (tuple, list)) and isinstance(images[0], (tuple, list)):
+ # if the first element is a list, we assume that all elements are lists
+ batch_num_images = [len(x) for x in images]
+ elif isinstance(images, (tuple, list)):
+ # treat this as a single-image case for backward compatibility
+ batch_num_images = [1] * len(images)
+ else:
+ batch_num_images = [1]
+ kwargs["batch_num_images"] = batch_num_images
+ return super().preprocess(images, **kwargs)
+
+ def _preprocess(
+ self,
+ images: list["torch.Tensor"],
+ do_resize: bool,
+ size: SizeDict,
+ image_grid_pinpoints: list[list[int]],
+ interpolation: Optional["F.InterpolationMode"],
+ do_center_crop: bool,
+ crop_size: SizeDict,
+ do_rescale: bool,
+ rescale_factor: float,
+ do_normalize: bool,
+ image_mean: Optional[Union[float, list[float]]],
+ image_std: Optional[Union[float, list[float]]],
+ do_pad: bool,
+ batch_num_images: list[int],
+ disable_grouping: Optional[bool],
+ return_tensors: Optional[Union[str, TensorType]],
+ **kwargs,
+ ) -> BatchFeature:
+ processed_images = []
+ image_sizes = []
+
+ # only single image patching is supported
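+ # images coming from multi-image samples skip anyres patching and are padded to a square below instead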
+ need_patching = [n == 1 for n in batch_num_images for _ in range(n)]
+
+ # Determine the size tuple
+ if size and size.height and size.width:
+ size_tuple = (size.height, size.width)
+ else:
+ size_tuple = (size.shortest_edge, size.shortest_edge)
+
+ # Determine the patch size
+ if crop_size and crop_size.height:
+ patch_size = crop_size.height
+ elif size and size.height:
+ patch_size = size.height
+ else:
+ patch_size = size.shortest_edge
+
+ for i, image in enumerate(images):
+ if need_patching[i]:
+ image_patches = self._get_image_patches(
+ image,
+ image_grid_pinpoints,
+ size=size_tuple,
+ patch_size=patch_size,
+ interpolation=interpolation,
+ )
+ else:
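+ # pad to a square using the (rescaled) image mean as background, so the padded area becomes roughly zero after normalization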
+ padded_image = self.pad_to_square(
+ images=image, background_color=tuple(int(x * 255) for x in self.image_mean)
+ )
+ image_patches = [padded_image]
+
+ # Group images by size for batched processing
+ processed_image_patches_grouped = {}
+ grouped_image_patches, grouped_image_patches_index = group_images_by_shape(
+ image_patches, disable_grouping=disable_grouping
+ )
+ for shape, stacked_image_patches in grouped_image_patches.items():
+ if do_resize:
+ stacked_image_patches = self.resize(
+ image=stacked_image_patches,
+ size=size,
+ interpolation=interpolation,
+ )
+ if do_center_crop:
+ stacked_image_patches = self.center_crop(stacked_image_patches, crop_size)
+ # Fused rescale and normalize
+ stacked_image_patches = self.rescale_and_normalize(
+ stacked_image_patches, do_rescale, rescale_factor, do_normalize, image_mean, image_std
+ )
+ processed_image_patches_grouped[shape] = stacked_image_patches
+ processed_image_patches = reorder_images(processed_image_patches_grouped, grouped_image_patches_index)
+ processed_image_patches = (
+ torch.stack(processed_image_patches, dim=0) if return_tensors else processed_image_patches
+ )
+ processed_images.append(processed_image_patches)
+ image_sizes.append(get_image_size(image, ChannelDimension.FIRST))
+
+ if do_pad:
+ processed_images = self._pad_for_batching(processed_images)
+ processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
+ return BatchFeature(
+ data={"pixel_values": processed_images, "image_sizes": image_sizes, "batch_num_images": batch_num_images},
+ tensor_type=return_tensors,
+ )
+
+
+class LlavaOnevisionModelOutputWithPast(LlavaNextVideoModelOutputWithPast):
+ pass
+
+
+class LlavaOnevisionCausalLMOutputWithPast(LlavaNextVideoCausalLMOutputWithPast):
+ pass
+
+
+class LlavaOnevisionPreTrainedModel(LlavaNextVideoPreTrainedModel):
+ pass
+
+
+class LlavaOnevisionModel(LlavaNextVideoModel):
+ def __init__(self, config):
+ super().__init__(config)
+ del self.vision_resampler
+
+ def pack_image_features(self, image_features, image_sizes, image_newline=None, vision_aspect_ratio="anyres_max_9"):
+ """
+ Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors.
+
+ Args:
+ image_features (`list[torch.Tensor]` of length num_images, each of shape `(num_patches, image_length, embed_dim)`)
+ List of image feature tensors, each containing all the visual features of all patches.
+ image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
+ Actual image size of each image (H, W).
+ image_newline (`torch.Tensor` of shape `(embed_dim)`)
+ New line embedding vector.
+ vision_aspect_ratio (`str`, *optional*, defaults to `"anyres_max_9"`):
+ Aspect ratio used when processing image features. The default value is "anyres_max_9".
+ Returns:
+ image_features (`torch.Tensor` of shape `(all_feat_len, embed_dim)`)
+ feature_lens (`list[int]`)
+ token length of each image in image_features
+ """
+ new_image_features = []
+ feature_lens = []
+ for image_idx, image_feature in enumerate(image_features):
+ if image_feature.shape[0] > 1:
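+ # patch 0 holds the base (globally resized) image; the remaining patches tile the high-resolution version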
+ base_image_feature = image_feature[0]
+ image_feature = image_feature[1:]
+ height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size
+ if height * width != base_image_feature.shape[0]:
+ raise ValueError("The number of patches is not consistent with the image size.")
+ num_patch_height, num_patch_width = get_anyres_image_grid_shape(
+ image_sizes[image_idx],
+ self.config.image_grid_pinpoints,
+ self.config.vision_config.image_size,
+ )
+ image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
+ image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
+ image_feature = image_feature.flatten(1, 2).flatten(2, 3)
+ image_feature = unpad_image(image_feature, image_sizes[image_idx])
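+ # cap the unpadded feature grid at `anyres_max_N` tiles' worth of features; if it is more than ~10% too large (in linear scale), downsample it bilinearly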
+ max_num_patches = int(vision_aspect_ratio.strip("anyres_max_"))
+ channels, curr_height, curr_width = image_feature.shape
+ ratio = math.sqrt(curr_height * curr_width / (max_num_patches * height**2))
+ if ratio > 1.1:
+ image_feature = image_feature[None]
+ image_feature = nn.functional.interpolate(
+ image_feature, [int(curr_height // ratio), int(curr_width // ratio)], mode="bilinear"
+ )[0]
+ if image_newline is not None:
+ image_feature = torch.cat(
+ (
+ image_feature,
+ image_newline[:, None, None]
+ .expand(*image_feature.shape[:-1], 1)
+ .to(image_feature.device, image_feature.dtype),
+ ),
+ dim=-1,
+ )
+ image_feature = image_feature.flatten(1, 2).transpose(0, 1)
+ image_feature = torch.cat((base_image_feature, image_feature), dim=0)
+ else:
+ image_feature = image_feature[0]
+ if image_newline is not None:
+ image_feature = torch.cat((image_feature, image_newline[None].to(image_feature)), dim=0)
+ new_image_features.append(image_feature)
+ feature_lens.append(image_feature.size(0))
+ feature_lens = torch.tensor(feature_lens, dtype=torch.long, device=image_features[0].device)
+ return new_image_features, feature_lens
+
+ def apply_pooling(self, image_features):
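+ # Reshape per-frame features back to the 2D patch grid and halve each spatial dimension (rounding up) with bilinear interpolation, ~4x fewer tokens per frame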
+ height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size
+ batch_frames, seq_len, dim = image_features.shape
+ image_features = image_features.view(batch_frames, height, width, -1)
+ image_features = image_features.permute(0, 3, 1, 2).contiguous()
+
+ height, width = image_features.shape[2:]
+ scaled_shape = [math.ceil(height / 2), math.ceil(width / 2)]
+ image_features = nn.functional.interpolate(image_features, size=scaled_shape, mode="bilinear")
+
+ image_features = image_features.permute(0, 2, 3, 1)
+ image_features = image_features.view(batch_frames, -1, dim)
+ return image_features
+
+ def get_image_features(
+ self,
+ pixel_values: torch.FloatTensor,
+ image_sizes: torch.Tensor,
+ vision_feature_layer: Optional[Union[int, list[int]]] = None,
+ vision_feature_select_strategy: Optional[str] = None,
+ vision_aspect_ratio: Optional[str] = None,
+ batch_num_images: Optional[torch.LongTensor] = None,
+ ):
+ """
+ Obtains image last hidden states from the vision tower and applies the multimodal projection.
+
+ Args:
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_patches, channels, height, width)`)
+ The tensors corresponding to the input images.
+ image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
+ Actual image size of each image (H, W).
+ vision_feature_layer (`Union[int, list[int]]`):
+ The index of the layer to select the vision feature. If multiple indices are provided,
+ the vision feature of the corresponding indices will be concatenated to form the
+ vision features.
+ vision_feature_select_strategy (`str`):
+ The feature selection strategy used to select the vision feature from the vision backbone.
+ Can be one of `"default"` or `"full"`
+ batch_num_images (`torch.LongTensor`, *optional*):
+ Number of images in each sample.
+ Returns:
+ image_features (list[`torch.Tensor`]): List of image feature tensors, each containing all the visual features of all patches
+ and of shape `(num_patches, image_length, embed_dim)`.
+ """
+ vision_feature_layer = (
+ vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
+ )
+ vision_feature_select_strategy = (
+ vision_feature_select_strategy
+ if vision_feature_select_strategy is not None
+ else self.config.vision_feature_select_strategy
+ )
+ vision_aspect_ratio = (
+ vision_aspect_ratio if vision_aspect_ratio is not None else self.config.vision_aspect_ratio
+ )
+
+ # ! infer image_num_patches from image_sizes
+ if batch_num_images is None:
+ # treat this as a single-image case for backward compatibility
+ need_patching = [True] * len(image_sizes)
+ else:
+ need_patching = [n == 1 for n in batch_num_images for _ in range(n)]
+ image_num_patches = [
+ image_size_to_num_patches(
+ image_size=imsize,
+ grid_pinpoints=self.config.image_grid_pinpoints,
+ patch_size=self.config.vision_config.image_size,
+ )
+ if should_patch
+ else 1
+ for imsize, should_patch in zip(image_sizes, need_patching)
+ ]
+ if pixel_values.dim() == 5:
+ # stacked if input is (batch_size, num_patches, num_channels, height, width)
+ _pixel_values_list = [pix_val[:num_patch] for pix_val, num_patch in zip(pixel_values, image_num_patches)]
+ pixel_values = torch.cat(_pixel_values_list, dim=0)
+ elif pixel_values.dim() != 4:
+ # otherwise has to be stacked from list of (num_patches, num_channels, height, width)
+ raise ValueError(f"pixel_values of shape {pixel_values.shape}, expect to be of 4 or 5 dimensions")
+
+ image_features = self.vision_tower(pixel_values, output_hidden_states=True)
+ # If we have one vision feature layer, return the corresponding hidden states,
+ # otherwise, select the hidden states of each feature layer and concatenate them
+ if isinstance(vision_feature_layer, int):
+ selected_image_feature = image_features.hidden_states[vision_feature_layer]
+ else:
+ hs_pool = [image_features.hidden_states[layer_idx] for layer_idx in vision_feature_layer]
+ selected_image_feature = torch.cat(hs_pool, dim=-1)
+
+ if vision_feature_select_strategy == "default":
+ selected_image_feature = selected_image_feature[:, 1:]
+ image_features = self.multi_modal_projector(selected_image_feature)
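+ # split the projected features back into one tensor per image, according to how many patches each image produced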
+ image_features = torch.split(image_features, image_num_patches, dim=0)
+
+ image_features, feature_lens = self.pack_image_features(
+ image_features,
+ image_sizes,
+ image_newline=self.image_newline,
+ vision_aspect_ratio=vision_aspect_ratio,
+ )
+ return image_features
+
+ def get_video_features(
+ self,
+ pixel_values: torch.FloatTensor,
+ vision_feature_layer: Union[int, list[int]],
+ vision_feature_select_strategy: str,
+ ):
+ """
+ Obtains video last hidden states from the vision tower and applies the multimodal projection and pooling.
+
+ Args:
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_frames, channels, height, width)`)
+ The tensors corresponding to the input video.
+ vision_feature_layer (`Union[int, list[int]]`, *optional*, defaults to `-2`):
+ The index of the layer to select the vision feature. If multiple indices are provided,
+ the vision feature of the corresponding indices will be concatenated to form the
+ vision features.
+ vision_feature_select_strategy (`str`):
+ The feature selection strategy used to select the vision feature from the vision backbone.
+ Can be one of `"default"` or `"full"`
+ Returns:
+ video_features (list[`torch.Tensor`]): List of video feature tensors, each containing all the visual features of all patches
+ and of shape `(num_videos, video_length, embed_dim)`.
+ """
+ batch_size, frames, channels, height, width = pixel_values.shape
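+ # fold the frame dimension into the batch so the vision tower encodes every frame as an independent image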
+ pixel_values = pixel_values.view(batch_size * frames, channels, height, width)
+ video_features = self.vision_tower(pixel_values, output_hidden_states=True)
+
+ # If we have one vision feature layer, return the corresponding hidden states,
+ # otherwise, select the hidden states of each feature layer and concatenate them
+ if isinstance(vision_feature_layer, int):
+ selected_video_feature = video_features.hidden_states[vision_feature_layer]
+ else:
+ hs_pool = [video_features.hidden_states[layer_idx] for layer_idx in vision_feature_layer]
+ selected_video_feature = torch.cat(hs_pool, dim=-1)
+
+ if vision_feature_select_strategy == "default":
+ selected_video_feature = selected_video_feature[:, 1:]
+ video_features = self.multi_modal_projector(selected_video_feature)
+
+ video_features = self.apply_pooling(video_features)
+ video_features = video_features.reshape(batch_size, frames * video_features.shape[1], -1)
+
+ return video_features
+
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ image_sizes: Optional[torch.LongTensor] = None,
+ pixel_values_videos: Optional[torch.FloatTensor] = None,
+ image_sizes_videos: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ vision_feature_layer: Optional[Union[int, list[int]]] = None,
+ vision_feature_select_strategy: Optional[str] = None,
+ vision_aspect_ratio: Optional[str] = None,
+ batch_num_images: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> Union[tuple, LlavaOnevisionModelOutputWithPast]:
+ r"""
+ image_sizes_videos (`torch.LongTensor` of shape `(batch_size, frames, 2)`, *optional*):
+ The sizes of the videos in the batch, being (height, width) for each frame in the video.
+ vision_aspect_ratio (`str`, *optional*, defaults to `"anyres_max_9"`):
+ Aspect ratio used when processing image features. The default value is "anyres_max_9".
+ batch_num_images (`torch.LongTensor`, *optional*):
+ Number of images in each sample.
+ """
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ vision_feature_layer = (
+ vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
+ )
+ vision_feature_select_strategy = (
+ vision_feature_select_strategy
+ if vision_feature_select_strategy is not None
+ else self.config.vision_feature_select_strategy
+ )
+ vision_aspect_ratio = (
+ vision_aspect_ratio if vision_aspect_ratio is not None else self.config.vision_aspect_ratio
+ )
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if inputs_embeds is None:
+ inputs_embeds = self.get_input_embeddings()(input_ids)
+
+ # Images are processed with Anyres
+ if pixel_values is not None:
+ image_features = self.get_image_features(
+ pixel_values,
+ image_sizes,
+ vision_feature_layer=vision_feature_layer,
+ vision_feature_select_strategy=vision_feature_select_strategy,
+ batch_num_images=batch_num_images,
+ )
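+ # concatenate the per-image feature lists and scatter them into the text embeddings at the image placeholder positions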
+ image_features = torch.cat(image_features, dim=0)
+ image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+ special_image_mask, _ = self.get_placeholder_mask(
+ input_ids, inputs_embeds=inputs_embeds, image_features=image_features
+ )
+ inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
+
+ # Videos are simply embedded and further pooled to decrease the sequence length
+ if pixel_values_videos is not None:
+ video_features = self.get_video_features(
+ pixel_values_videos,
+ vision_feature_layer=vision_feature_layer,
+ vision_feature_select_strategy=vision_feature_select_strategy,
+ )
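+ # append one learned image_newline embedding per video, then flatten and scatter into the text embeddings at the video placeholder positions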
+ image_newline = (
+ self.image_newline[None, None, :].repeat(video_features.shape[0], 1, 1).to(video_features.device)
+ )
+ video_features = torch.cat((video_features, image_newline), dim=1)
+ video_features = video_features.flatten(0, 1).to(inputs_embeds.device, inputs_embeds.dtype)
+ _, special_video_mask = self.get_placeholder_mask(
+ input_ids, inputs_embeds=inputs_embeds, video_features=video_features
+ )
+ inputs_embeds = inputs_embeds.masked_scatter(special_video_mask, video_features)
+
+ outputs = self.language_model(
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=True,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ return LlavaOnevisionModelOutputWithPast(
+ last_hidden_state=outputs.last_hidden_state,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ image_hidden_states=image_features if pixel_values is not None else None,
+ video_hidden_states=video_features if pixel_values_videos is not None else None,
+ )
+
+
+class LlavaOnevisionForConditionalGeneration(LlavaNextVideoForConditionalGeneration):
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ image_sizes: Optional[torch.LongTensor] = None,
+ pixel_values_videos: Optional[torch.FloatTensor] = None,
+ image_sizes_videos: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ vision_feature_layer: Optional[Union[int, list[int]]] = None,
+ vision_feature_select_strategy: Optional[str] = None,
+ vision_aspect_ratio: Optional[str] = None,
+ batch_num_images: Optional[torch.LongTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> Union[tuple, LlavaOnevisionCausalLMOutputWithPast]:
+ r"""
+ image_sizes_videos (`torch.LongTensor` of shape `(batch_size, frames, 2)`, *optional*):
+ The sizes of the videos in the batch, being (height, width) for each frame in the video.
+ vision_aspect_ratio (`str`, *optional*, defaults to `"anyres_max_9"`):
+ Aspect ratio used when processing image features. The default value is "anyres_max_9".
+ batch_num_images (`torch.LongTensor`, *optional*):
+ Number of images in each sample.
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Example:
+
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> import torch
+ >>> from transformers import LlavaOnevisionProcessor, LlavaOnevisionForConditionalGeneration
+
+ >>> model = LlavaOnevisionForConditionalGeneration.from_pretrained("llava-hf/llava-onevision-qwen2-7b-ov-hf", dtype="float16", device_map="cuda:0")
+ >>> processor = LlavaOnevisionProcessor.from_pretrained("llava-hf/llava-onevision-qwen2-7b-ov-hf")
+
+ >>> conversation = [
+ ... {
+ ... "role": "user",
+ ... "content": [
+ ... {"type": "text", "text": "What is shown in this image?"},
+ ... {"type": "image"},
+ ... ],
+ ... },
+ ... ]
+ >>> prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
+
+ >>> image_file = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> raw_image = Image.open(requests.get(image_file, stream=True).raw)
+ >>> inputs = processor(text=prompt, images=raw_image, return_tensors='pt').to(0, torch.float16)
+
+ >>> output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
+ >>> processor.batch_decode(output, skip_special_tokens=True)[0]
+ "user\n\nWhat is shown in this image?\nassistant\ncat"
+ ```"""
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ vision_feature_layer = (
+ vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
+ )
+ vision_feature_select_strategy = (
+ vision_feature_select_strategy
+ if vision_feature_select_strategy is not None
+ else self.config.vision_feature_select_strategy
+ )
+ vision_aspect_ratio = (
+ vision_aspect_ratio if vision_aspect_ratio is not None else self.config.vision_aspect_ratio
+ )
+
+ outputs = self.model(
+ input_ids=input_ids,
+ pixel_values=pixel_values,
+ pixel_values_videos=pixel_values_videos,
+ image_sizes=image_sizes,
+ image_sizes_videos=image_sizes_videos,
+ vision_aspect_ratio=vision_aspect_ratio,
+ vision_feature_layer=vision_feature_layer,
+ vision_feature_select_strategy=vision_feature_select_strategy,
+ batch_num_images=batch_num_images,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=True,
+ cache_position=cache_position,
+ logits_to_keep=logits_to_keep,
+ **kwargs,
+ )
+
+ hidden_states = outputs[0]
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(
+ logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
+ )
+
+ return LlavaOnevisionCausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ image_hidden_states=outputs.image_hidden_states,
+ video_hidden_states=outputs.video_hidden_states,
+ )
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids,
+ past_key_values=None,
+ inputs_embeds=None,
+ pixel_values=None,
+ image_sizes=None,
+ pixel_values_videos=None,
+ image_sizes_videos=None,
+ attention_mask=None,
+ cache_position=None,
+ logits_to_keep=None,
+ **kwargs,
+ ):
+ # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
+
+ model_inputs = super().prepare_inputs_for_generation(
+ input_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ cache_position=cache_position,
+ logits_to_keep=logits_to_keep,
+ **kwargs,
+ )
+
+ if cache_position[0] == 0:
+ # Pixel values are only forwarded on the first (prefill) step, when the prompt still contains the special image/video tokens;
+ # during cached decoding they are no longer present, so pixel values stay None
+ model_inputs["pixel_values"] = pixel_values
+ model_inputs["image_sizes"] = image_sizes
+ model_inputs["pixel_values_videos"] = pixel_values_videos
+ model_inputs["image_sizes_videos"] = image_sizes_videos
+
+ return model_inputs
+
+
+__all__ = [
+ "LlavaOnevisionImageProcessorFast",
+ "LlavaOnevisionModel",
+ "LlavaOnevisionForConditionalGeneration",
+ "LlavaOnevisionPreTrainedModel",
+]
diff --git a/venv/lib/python3.13/site-packages/transformers/models/llava_onevision/processing_llava_onevision.py b/venv/lib/python3.13/site-packages/transformers/models/llava_onevision/processing_llava_onevision.py
new file mode 100644
index 0000000000000000000000000000000000000000..9fadc6af3067e3478c3f21626e4f9d3f99e64b00
--- /dev/null
+++ b/venv/lib/python3.13/site-packages/transformers/models/llava_onevision/processing_llava_onevision.py
@@ -0,0 +1,333 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Processor class for LLaVa-Onevision.
+"""
+
+import math
+from collections.abc import Iterable
+from typing import Optional, Union
+
+import numpy as np
+
+from ...feature_extraction_utils import BatchFeature
+from ...image_processing_utils import select_best_resolution
+from ...image_utils import ImageInput, get_image_size, to_numpy_array
+from ...processing_utils import MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack
+from ...tokenization_utils_base import PreTokenizedInput, TextInput
+from ...utils import logging
+from ...video_utils import VideoInput
+
+
+logger = logging.get_logger(__name__)
+
+
+class LlavaOnevisionProcessorKwargs(ProcessingKwargs, total=False):
+ # see processing_utils.ProcessingKwargs documentation for usage.
+ _defaults = {
+ "text_kwargs": {
+ "padding": False,
+ "return_mm_token_type_ids": False,
+ },
+ "image_kwargs": {},
+ "videos_kwargs": {},
+ }
+
+
+class LlavaOnevisionProcessor(ProcessorMixin):
+ r"""
+ Constructs a LLaVa-Onevision processor which wraps a LLaVa-Onevision video processor, LLaVa-NeXT image processor and a LLaMa tokenizer into a single processor.
+
+ [`LlavaOnevisionProcessor`] offers all the functionalities of [`LlavaOnevisionVideoProcessor`], [`LlavaOnevisionImageProcessor`] and [`LlamaTokenizerFast`]. See the
+ [`~LlavaOnevisionVideoProcessor.__call__`], [`~LlavaOnevisionProcessor.__call__`] and [`~LlavaOnevisionProcessor.decode`] for more information.
+
+ Args:
+ image_processor ([`LlavaOnevisionImageProcessor`], *optional*):
+ The image processor is a required input.
+ tokenizer ([`LlamaTokenizerFast`], *optional*):
+ The tokenizer is a required input.
+ video_processor ([`LlavaOnevisionVideoProcessor`], *optional*):
+ The video processor is a required input.
+ num_image_tokens (`int`, *optional*):
+ Number of image tokens for one image that will be returned by the vision tower.
+ vision_feature_select_strategy (`str`, *optional*):
+ The feature selection strategy used to select the vision feature from the vision backbone.
+ Should be the same as in the model's config.
+ chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
+ in a chat into a tokenizable string.
+ image_token (`str`, *optional*, defaults to `"<image>"`):
+ Special token used to denote image location.
+ video_token (`str`, *optional*, defaults to `"