Add files using upload-large-folder tool
This view is limited to 50 files because it contains too many changes.
- venv/lib/python3.13/site-packages/transformers/models/albert/__init__.py +31 -0
- venv/lib/python3.13/site-packages/transformers/models/albert/configuration_albert.py +170 -0
- venv/lib/python3.13/site-packages/transformers/models/albert/modeling_albert.py +1349 -0
- venv/lib/python3.13/site-packages/transformers/models/albert/modeling_flax_albert.py +1132 -0
- venv/lib/python3.13/site-packages/transformers/models/albert/modeling_tf_albert.py +1572 -0
- venv/lib/python3.13/site-packages/transformers/models/albert/tokenization_albert.py +320 -0
- venv/lib/python3.13/site-packages/transformers/models/albert/tokenization_albert_fast.py +178 -0
- venv/lib/python3.13/site-packages/transformers/models/apertus/__init__.py +32 -0
- venv/lib/python3.13/site-packages/transformers/models/apertus/configuration_apertus.py +214 -0
- venv/lib/python3.13/site-packages/transformers/models/apertus/modeling_apertus.py +488 -0
- venv/lib/python3.13/site-packages/transformers/models/apertus/modular_apertus.py +371 -0
- venv/lib/python3.13/site-packages/transformers/models/arcee/__init__.py +27 -0
- venv/lib/python3.13/site-packages/transformers/models/arcee/configuration_arcee.py +201 -0
- venv/lib/python3.13/site-packages/transformers/models/arcee/modeling_arcee.py +506 -0
- venv/lib/python3.13/site-packages/transformers/models/arcee/modular_arcee.py +225 -0
- venv/lib/python3.13/site-packages/transformers/models/aria/__init__.py +30 -0
- venv/lib/python3.13/site-packages/transformers/models/aria/configuration_aria.py +307 -0
- venv/lib/python3.13/site-packages/transformers/models/aria/image_processing_aria.py +527 -0
- venv/lib/python3.13/site-packages/transformers/models/aria/modeling_aria.py +1275 -0
- venv/lib/python3.13/site-packages/transformers/models/aria/modular_aria.py +1610 -0
- venv/lib/python3.13/site-packages/transformers/models/aria/processing_aria.py +189 -0
- venv/lib/python3.13/site-packages/transformers/models/auto/__init__.py +35 -0
- venv/lib/python3.13/site-packages/transformers/models/auto/auto_factory.py +882 -0
- venv/lib/python3.13/site-packages/transformers/models/auto/configuration_auto.py +1404 -0
- venv/lib/python3.13/site-packages/transformers/models/auto/feature_extraction_auto.py +422 -0
- venv/lib/python3.13/site-packages/transformers/models/auto/image_processing_auto.py +688 -0
- venv/lib/python3.13/site-packages/transformers/models/auto/modeling_auto.py +0 -0
- venv/lib/python3.13/site-packages/transformers/models/auto/modeling_flax_auto.py +413 -0
- venv/lib/python3.13/site-packages/transformers/models/auto/modeling_tf_auto.py +776 -0
- venv/lib/python3.13/site-packages/transformers/models/auto/processing_auto.py +443 -0
- venv/lib/python3.13/site-packages/transformers/models/auto/tokenization_auto.py +1235 -0
- venv/lib/python3.13/site-packages/transformers/models/auto/video_processing_auto.py +393 -0
- venv/lib/python3.13/site-packages/transformers/models/aya_vision/__init__.py +28 -0
- venv/lib/python3.13/site-packages/transformers/models/aya_vision/configuration_aya_vision.py +110 -0
- venv/lib/python3.13/site-packages/transformers/models/aya_vision/modeling_aya_vision.py +518 -0
- venv/lib/python3.13/site-packages/transformers/models/aya_vision/modular_aya_vision.py +297 -0
- venv/lib/python3.13/site-packages/transformers/models/aya_vision/processing_aya_vision.py +257 -0
- venv/lib/python3.13/site-packages/transformers/models/barthez/__init__.py +27 -0
- venv/lib/python3.13/site-packages/transformers/models/barthez/tokenization_barthez.py +291 -0
- venv/lib/python3.13/site-packages/transformers/models/barthez/tokenization_barthez_fast.py +193 -0
- venv/lib/python3.13/site-packages/transformers/models/bert_japanese/__init__.py +26 -0
- venv/lib/python3.13/site-packages/transformers/models/bert_japanese/tokenization_bert_japanese.py +952 -0
- venv/lib/python3.13/site-packages/transformers/models/bertweet/__init__.py +26 -0
- venv/lib/python3.13/site-packages/transformers/models/bertweet/tokenization_bertweet.py +769 -0
- venv/lib/python3.13/site-packages/transformers/models/biogpt/__init__.py +28 -0
- venv/lib/python3.13/site-packages/transformers/models/biogpt/configuration_biogpt.py +134 -0
- venv/lib/python3.13/site-packages/transformers/models/biogpt/modeling_biogpt.py +967 -0
- venv/lib/python3.13/site-packages/transformers/models/biogpt/modular_biogpt.py +789 -0
- venv/lib/python3.13/site-packages/transformers/models/biogpt/tokenization_biogpt.py +331 -0
- venv/lib/python3.13/site-packages/transformers/models/bit/__init__.py +29 -0
venv/lib/python3.13/site-packages/transformers/models/albert/__init__.py
ADDED
@@ -0,0 +1,31 @@
+# Copyright 2020 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+    from .configuration_albert import *
+    from .modeling_albert import *
+    from .modeling_flax_albert import *
+    from .modeling_tf_albert import *
+    from .tokenization_albert import *
+    from .tokenization_albert_fast import *
+else:
+    import sys
+
+    _file = globals()["__file__"]
+    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
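The `__init__.py` above defers all real imports: under `TYPE_CHECKING` the submodules are imported for static analysis only, while at runtime the package's module object is replaced by a `_LazyModule` that resolves names on first attribute access. Below is a minimal sketch of that pattern, assuming nothing about the actual `_LazyModule` internals; the `LazyModule` class and `attr_to_submodule` mapping are illustrative, not the transformers implementation.

```python
# Illustrative sketch of the lazy-import pattern; not the transformers._LazyModule source.
import importlib
import types


class LazyModule(types.ModuleType):
    """Defers importing submodules until one of their attributes is first accessed."""

    def __init__(self, name: str, attr_to_submodule: dict[str, str]):
        super().__init__(name)
        # e.g. {"AlbertConfig": ".configuration_albert", "AlbertModel": ".modeling_albert"}
        self._attr_to_submodule = attr_to_submodule
        self.__all__ = list(attr_to_submodule)

    def __getattr__(self, attr: str):
        if attr not in self._attr_to_submodule:
            raise AttributeError(f"module {self.__name__!r} has no attribute {attr!r}")
        submodule = importlib.import_module(self._attr_to_submodule[attr], package=self.__name__)
        value = getattr(submodule, attr)
        setattr(self, attr, value)  # cache: each submodule import runs at most once
        return value
```

The payoff of this design is that importing the top-level package stays cheap even though it ships hundreds of model files; framework-specific modules (PyTorch, TF, Flax) load only when a name from them is actually used.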
venv/lib/python3.13/site-packages/transformers/models/albert/configuration_albert.py
ADDED
@@ -0,0 +1,170 @@
+# coding=utf-8
+# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""ALBERT model configuration"""
+
+from collections import OrderedDict
+from collections.abc import Mapping
+
+from ...configuration_utils import PretrainedConfig
+from ...onnx import OnnxConfig
+
+
+class AlbertConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`AlbertModel`] or a [`TFAlbertModel`]. It is used
+    to instantiate an ALBERT model according to the specified arguments, defining the model architecture. Instantiating
+    a configuration with the defaults will yield a similar configuration to that of the ALBERT
+    [albert/albert-xxlarge-v2](https://huggingface.co/albert/albert-xxlarge-v2) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        vocab_size (`int`, *optional*, defaults to 30000):
+            Vocabulary size of the ALBERT model. Defines the number of different tokens that can be represented by the
+            `inputs_ids` passed when calling [`AlbertModel`] or [`TFAlbertModel`].
+        embedding_size (`int`, *optional*, defaults to 128):
+            Dimensionality of vocabulary embeddings.
+        hidden_size (`int`, *optional*, defaults to 4096):
+            Dimensionality of the encoder layers and the pooler layer.
+        num_hidden_layers (`int`, *optional*, defaults to 12):
+            Number of hidden layers in the Transformer encoder.
+        num_hidden_groups (`int`, *optional*, defaults to 1):
+            Number of groups for the hidden layers; parameters in the same group are shared.
+        num_attention_heads (`int`, *optional*, defaults to 64):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        intermediate_size (`int`, *optional*, defaults to 16384):
+            The dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
+        inner_group_num (`int`, *optional*, defaults to 1):
+            The number of inner repetitions of attention and ffn.
+        hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu_new"`):
+            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+            `"relu"`, `"silu"` and `"gelu_new"` are supported.
+        hidden_dropout_prob (`float`, *optional*, defaults to 0):
+            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+        attention_probs_dropout_prob (`float`, *optional*, defaults to 0):
+            The dropout ratio for the attention probabilities.
+        max_position_embeddings (`int`, *optional*, defaults to 512):
+            The maximum sequence length that this model might ever be used with. Typically set this to something large
+            (e.g., 512 or 1024 or 2048).
+        type_vocab_size (`int`, *optional*, defaults to 2):
+            The vocabulary size of the `token_type_ids` passed when calling [`AlbertModel`] or [`TFAlbertModel`].
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        layer_norm_eps (`float`, *optional*, defaults to 1e-12):
+            The epsilon used by the layer normalization layers.
+        classifier_dropout_prob (`float`, *optional*, defaults to 0.1):
+            The dropout ratio for attached classifiers.
+        position_embedding_type (`str`, *optional*, defaults to `"absolute"`):
+            Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For
+            positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to
+            [Self-Attention with Relative Position Representations (Shaw et al.)](https://huggingface.co/papers/1803.02155).
+            For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models
+            with Better Relative Position Embeddings (Huang et al.)](https://huggingface.co/papers/2009.13658).
+        pad_token_id (`int`, *optional*, defaults to 0):
+            Padding token id.
+        bos_token_id (`int`, *optional*, defaults to 2):
+            Beginning of stream token id.
+        eos_token_id (`int`, *optional*, defaults to 3):
+            End of stream token id.
+
+    Examples:
+
+    ```python
+    >>> from transformers import AlbertConfig, AlbertModel
+
+    >>> # Initializing an ALBERT-xxlarge style configuration
+    >>> albert_xxlarge_configuration = AlbertConfig()
+
+    >>> # Initializing an ALBERT-base style configuration
+    >>> albert_base_configuration = AlbertConfig(
+    ...     hidden_size=768,
+    ...     num_attention_heads=12,
+    ...     intermediate_size=3072,
+    ... )
+
+    >>> # Initializing a model (with random weights) from the ALBERT-base style configuration
+    >>> model = AlbertModel(albert_base_configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "albert"
+
+    def __init__(
+        self,
+        vocab_size=30000,
+        embedding_size=128,
+        hidden_size=4096,
+        num_hidden_layers=12,
+        num_hidden_groups=1,
+        num_attention_heads=64,
+        intermediate_size=16384,
+        inner_group_num=1,
+        hidden_act="gelu_new",
+        hidden_dropout_prob=0,
+        attention_probs_dropout_prob=0,
+        max_position_embeddings=512,
+        type_vocab_size=2,
+        initializer_range=0.02,
+        layer_norm_eps=1e-12,
+        classifier_dropout_prob=0.1,
+        position_embedding_type="absolute",
+        pad_token_id=0,
+        bos_token_id=2,
+        eos_token_id=3,
+        **kwargs,
+    ):
+        super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
+
+        self.vocab_size = vocab_size
+        self.embedding_size = embedding_size
+        self.hidden_size = hidden_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_hidden_groups = num_hidden_groups
+        self.num_attention_heads = num_attention_heads
+        self.inner_group_num = inner_group_num
+        self.hidden_act = hidden_act
+        self.intermediate_size = intermediate_size
+        self.hidden_dropout_prob = hidden_dropout_prob
+        self.attention_probs_dropout_prob = attention_probs_dropout_prob
+        self.max_position_embeddings = max_position_embeddings
+        self.type_vocab_size = type_vocab_size
+        self.initializer_range = initializer_range
+        self.layer_norm_eps = layer_norm_eps
+        self.classifier_dropout_prob = classifier_dropout_prob
+        self.position_embedding_type = position_embedding_type
+
+
+# Copied from transformers.models.bert.configuration_bert.BertOnnxConfig with Roberta->Albert
+class AlbertOnnxConfig(OnnxConfig):
+    @property
+    def inputs(self) -> Mapping[str, Mapping[int, str]]:
+        if self.task == "multiple-choice":
+            dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
+        else:
+            dynamic_axis = {0: "batch", 1: "sequence"}
+        return OrderedDict(
+            [
+                ("input_ids", dynamic_axis),
+                ("attention_mask", dynamic_axis),
+                ("token_type_ids", dynamic_axis),
+            ]
+        )
+
+
+__all__ = ["AlbertConfig", "AlbertOnnxConfig"]
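One consequence of the grouping parameters in this configuration: the number of uniquely parameterized encoder layers is `num_hidden_groups * inner_group_num`, independent of `num_hidden_layers` (the docstring's pruning example of 12 layers, 2 groups, 2 inner groups giving 4 distinct layers follows the same arithmetic). A short sketch of that arithmetic, assuming only that `transformers` is importable and using the defaults shown above:

```python
from transformers import AlbertConfig

config = AlbertConfig()  # defaults: num_hidden_layers=12, num_hidden_groups=1, inner_group_num=1
unique_layers = config.num_hidden_groups * config.inner_group_num
print(unique_layers)  # 1: all 12 encoder passes reuse a single shared layer's weights
print(config.num_hidden_layers // config.num_hidden_groups)  # 12 forward passes per group
```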
venv/lib/python3.13/site-packages/transformers/models/albert/modeling_albert.py
ADDED
|
@@ -0,0 +1,1349 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2018 Google AI, Google Brain and the HuggingFace Inc. team.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""PyTorch ALBERT model."""
|
| 16 |
+
|
| 17 |
+
import math
|
| 18 |
+
import os
|
| 19 |
+
from dataclasses import dataclass
|
| 20 |
+
from typing import Optional, Union
|
| 21 |
+
|
| 22 |
+
import torch
|
| 23 |
+
from torch import nn
|
| 24 |
+
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
| 25 |
+
|
| 26 |
+
from ...activations import ACT2FN
|
| 27 |
+
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa
|
| 28 |
+
from ...modeling_outputs import (
|
| 29 |
+
BaseModelOutput,
|
| 30 |
+
BaseModelOutputWithPooling,
|
| 31 |
+
MaskedLMOutput,
|
| 32 |
+
MultipleChoiceModelOutput,
|
| 33 |
+
QuestionAnsweringModelOutput,
|
| 34 |
+
SequenceClassifierOutput,
|
| 35 |
+
TokenClassifierOutput,
|
| 36 |
+
)
|
| 37 |
+
from ...modeling_utils import PreTrainedModel
|
| 38 |
+
from ...pytorch_utils import (
|
| 39 |
+
apply_chunking_to_forward,
|
| 40 |
+
find_pruneable_heads_and_indices,
|
| 41 |
+
prune_linear_layer,
|
| 42 |
+
)
|
| 43 |
+
from ...utils import ModelOutput, auto_docstring, logging
|
| 44 |
+
from .configuration_albert import AlbertConfig
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
logger = logging.get_logger(__name__)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def load_tf_weights_in_albert(model, config, tf_checkpoint_path):
|
| 51 |
+
"""Load tf checkpoints in a pytorch model."""
|
| 52 |
+
try:
|
| 53 |
+
import re
|
| 54 |
+
|
| 55 |
+
import numpy as np
|
| 56 |
+
import tensorflow as tf
|
| 57 |
+
except ImportError:
|
| 58 |
+
logger.error(
|
| 59 |
+
"Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
|
| 60 |
+
"https://www.tensorflow.org/install/ for installation instructions."
|
| 61 |
+
)
|
| 62 |
+
raise
|
| 63 |
+
tf_path = os.path.abspath(tf_checkpoint_path)
|
| 64 |
+
logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
|
| 65 |
+
# Load weights from TF model
|
| 66 |
+
init_vars = tf.train.list_variables(tf_path)
|
| 67 |
+
names = []
|
| 68 |
+
arrays = []
|
| 69 |
+
for name, shape in init_vars:
|
| 70 |
+
logger.info(f"Loading TF weight {name} with shape {shape}")
|
| 71 |
+
array = tf.train.load_variable(tf_path, name)
|
| 72 |
+
names.append(name)
|
| 73 |
+
arrays.append(array)
|
| 74 |
+
|
| 75 |
+
for name, array in zip(names, arrays):
|
| 76 |
+
print(name)
|
| 77 |
+
|
| 78 |
+
for name, array in zip(names, arrays):
|
| 79 |
+
original_name = name
|
| 80 |
+
|
| 81 |
+
# If saved from the TF HUB module
|
| 82 |
+
name = name.replace("module/", "")
|
| 83 |
+
|
| 84 |
+
# Renaming and simplifying
|
| 85 |
+
name = name.replace("ffn_1", "ffn")
|
| 86 |
+
name = name.replace("bert/", "albert/")
|
| 87 |
+
name = name.replace("attention_1", "attention")
|
| 88 |
+
name = name.replace("transform/", "")
|
| 89 |
+
name = name.replace("LayerNorm_1", "full_layer_layer_norm")
|
| 90 |
+
name = name.replace("LayerNorm", "attention/LayerNorm")
|
| 91 |
+
name = name.replace("transformer/", "")
|
| 92 |
+
|
| 93 |
+
# The feed forward layer had an 'intermediate' step which has been abstracted away
|
| 94 |
+
name = name.replace("intermediate/dense/", "")
|
| 95 |
+
name = name.replace("ffn/intermediate/output/dense/", "ffn_output/")
|
| 96 |
+
|
| 97 |
+
# ALBERT attention was split between self and output which have been abstracted away
|
| 98 |
+
name = name.replace("/output/", "/")
|
| 99 |
+
name = name.replace("/self/", "/")
|
| 100 |
+
|
| 101 |
+
# The pooler is a linear layer
|
| 102 |
+
name = name.replace("pooler/dense", "pooler")
|
| 103 |
+
|
| 104 |
+
# The classifier was simplified to predictions from cls/predictions
|
| 105 |
+
name = name.replace("cls/predictions", "predictions")
|
| 106 |
+
name = name.replace("predictions/attention", "predictions")
|
| 107 |
+
|
| 108 |
+
# Naming was changed to be more explicit
|
| 109 |
+
name = name.replace("embeddings/attention", "embeddings")
|
| 110 |
+
name = name.replace("inner_group_", "albert_layers/")
|
| 111 |
+
name = name.replace("group_", "albert_layer_groups/")
|
| 112 |
+
|
| 113 |
+
# Classifier
|
| 114 |
+
if len(name.split("/")) == 1 and ("output_bias" in name or "output_weights" in name):
|
| 115 |
+
name = "classifier/" + name
|
| 116 |
+
|
| 117 |
+
# No ALBERT model currently handles the next sentence prediction task
|
| 118 |
+
if "seq_relationship" in name:
|
| 119 |
+
name = name.replace("seq_relationship/output_", "sop_classifier/classifier/")
|
| 120 |
+
name = name.replace("weights", "weight")
|
| 121 |
+
|
| 122 |
+
name = name.split("/")
|
| 123 |
+
|
| 124 |
+
# Ignore the gradients applied by the LAMB/ADAM optimizers.
|
| 125 |
+
if (
|
| 126 |
+
"adam_m" in name
|
| 127 |
+
or "adam_v" in name
|
| 128 |
+
or "AdamWeightDecayOptimizer" in name
|
| 129 |
+
or "AdamWeightDecayOptimizer_1" in name
|
| 130 |
+
or "global_step" in name
|
| 131 |
+
):
|
| 132 |
+
logger.info(f"Skipping {'/'.join(name)}")
|
| 133 |
+
continue
|
| 134 |
+
|
| 135 |
+
pointer = model
|
| 136 |
+
for m_name in name:
|
| 137 |
+
if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
|
| 138 |
+
scope_names = re.split(r"_(\d+)", m_name)
|
| 139 |
+
else:
|
| 140 |
+
scope_names = [m_name]
|
| 141 |
+
|
| 142 |
+
if scope_names[0] == "kernel" or scope_names[0] == "gamma":
|
| 143 |
+
pointer = getattr(pointer, "weight")
|
| 144 |
+
elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
|
| 145 |
+
pointer = getattr(pointer, "bias")
|
| 146 |
+
elif scope_names[0] == "output_weights":
|
| 147 |
+
pointer = getattr(pointer, "weight")
|
| 148 |
+
elif scope_names[0] == "squad":
|
| 149 |
+
pointer = getattr(pointer, "classifier")
|
| 150 |
+
else:
|
| 151 |
+
try:
|
| 152 |
+
pointer = getattr(pointer, scope_names[0])
|
| 153 |
+
except AttributeError:
|
| 154 |
+
logger.info(f"Skipping {'/'.join(name)}")
|
| 155 |
+
continue
|
| 156 |
+
if len(scope_names) >= 2:
|
| 157 |
+
num = int(scope_names[1])
|
| 158 |
+
pointer = pointer[num]
|
| 159 |
+
|
| 160 |
+
if m_name[-11:] == "_embeddings":
|
| 161 |
+
pointer = getattr(pointer, "weight")
|
| 162 |
+
elif m_name == "kernel":
|
| 163 |
+
array = np.transpose(array)
|
| 164 |
+
try:
|
| 165 |
+
if pointer.shape != array.shape:
|
| 166 |
+
raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
|
| 167 |
+
except ValueError as e:
|
| 168 |
+
e.args += (pointer.shape, array.shape)
|
| 169 |
+
raise
|
| 170 |
+
print(f"Initialize PyTorch weight {name} from {original_name}")
|
| 171 |
+
pointer.data = torch.from_numpy(array)
|
| 172 |
+
|
| 173 |
+
return model
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
class AlbertEmbeddings(nn.Module):
|
| 177 |
+
"""
|
| 178 |
+
Construct the embeddings from word, position and token_type embeddings.
|
| 179 |
+
"""
|
| 180 |
+
|
| 181 |
+
def __init__(self, config: AlbertConfig):
|
| 182 |
+
super().__init__()
|
| 183 |
+
self.word_embeddings = nn.Embedding(config.vocab_size, config.embedding_size, padding_idx=config.pad_token_id)
|
| 184 |
+
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.embedding_size)
|
| 185 |
+
self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.embedding_size)
|
| 186 |
+
|
| 187 |
+
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
|
| 188 |
+
# any TensorFlow checkpoint file
|
| 189 |
+
self.LayerNorm = nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps)
|
| 190 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
| 191 |
+
|
| 192 |
+
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
| 193 |
+
self.register_buffer(
|
| 194 |
+
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
|
| 195 |
+
)
|
| 196 |
+
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
| 197 |
+
self.register_buffer(
|
| 198 |
+
"token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
|
| 199 |
+
)
|
| 200 |
+
|
| 201 |
+
# Copied from transformers.models.bert.modeling_bert.BertEmbeddings.forward
|
| 202 |
+
def forward(
|
| 203 |
+
self,
|
| 204 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 205 |
+
token_type_ids: Optional[torch.LongTensor] = None,
|
| 206 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 207 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 208 |
+
past_key_values_length: int = 0,
|
| 209 |
+
) -> torch.Tensor:
|
| 210 |
+
if input_ids is not None:
|
| 211 |
+
input_shape = input_ids.size()
|
| 212 |
+
else:
|
| 213 |
+
input_shape = inputs_embeds.size()[:-1]
|
| 214 |
+
|
| 215 |
+
seq_length = input_shape[1]
|
| 216 |
+
|
| 217 |
+
if position_ids is None:
|
| 218 |
+
position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
|
| 219 |
+
|
| 220 |
+
# Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
|
| 221 |
+
# when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
|
| 222 |
+
# issue #5664
|
| 223 |
+
if token_type_ids is None:
|
| 224 |
+
if hasattr(self, "token_type_ids"):
|
| 225 |
+
buffered_token_type_ids = self.token_type_ids[:, :seq_length]
|
| 226 |
+
buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
|
| 227 |
+
token_type_ids = buffered_token_type_ids_expanded
|
| 228 |
+
else:
|
| 229 |
+
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
|
| 230 |
+
|
| 231 |
+
if inputs_embeds is None:
|
| 232 |
+
inputs_embeds = self.word_embeddings(input_ids)
|
| 233 |
+
token_type_embeddings = self.token_type_embeddings(token_type_ids)
|
| 234 |
+
|
| 235 |
+
embeddings = inputs_embeds + token_type_embeddings
|
| 236 |
+
if self.position_embedding_type == "absolute":
|
| 237 |
+
position_embeddings = self.position_embeddings(position_ids)
|
| 238 |
+
embeddings += position_embeddings
|
| 239 |
+
embeddings = self.LayerNorm(embeddings)
|
| 240 |
+
embeddings = self.dropout(embeddings)
|
| 241 |
+
return embeddings
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
class AlbertAttention(nn.Module):
|
| 245 |
+
def __init__(self, config: AlbertConfig):
|
| 246 |
+
super().__init__()
|
| 247 |
+
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
|
| 248 |
+
raise ValueError(
|
| 249 |
+
f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
|
| 250 |
+
f"heads ({config.num_attention_heads}"
|
| 251 |
+
)
|
| 252 |
+
|
| 253 |
+
self.num_attention_heads = config.num_attention_heads
|
| 254 |
+
self.hidden_size = config.hidden_size
|
| 255 |
+
self.attention_head_size = config.hidden_size // config.num_attention_heads
|
| 256 |
+
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
| 257 |
+
|
| 258 |
+
self.query = nn.Linear(config.hidden_size, self.all_head_size)
|
| 259 |
+
self.key = nn.Linear(config.hidden_size, self.all_head_size)
|
| 260 |
+
self.value = nn.Linear(config.hidden_size, self.all_head_size)
|
| 261 |
+
|
| 262 |
+
self.attention_dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
| 263 |
+
self.output_dropout = nn.Dropout(config.hidden_dropout_prob)
|
| 264 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
| 265 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 266 |
+
self.pruned_heads = set()
|
| 267 |
+
|
| 268 |
+
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
| 269 |
+
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
| 270 |
+
self.max_position_embeddings = config.max_position_embeddings
|
| 271 |
+
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
|
| 272 |
+
|
| 273 |
+
def prune_heads(self, heads: list[int]) -> None:
|
| 274 |
+
if len(heads) == 0:
|
| 275 |
+
return
|
| 276 |
+
heads, index = find_pruneable_heads_and_indices(
|
| 277 |
+
heads, self.num_attention_heads, self.attention_head_size, self.pruned_heads
|
| 278 |
+
)
|
| 279 |
+
|
| 280 |
+
# Prune linear layers
|
| 281 |
+
self.query = prune_linear_layer(self.query, index)
|
| 282 |
+
self.key = prune_linear_layer(self.key, index)
|
| 283 |
+
self.value = prune_linear_layer(self.value, index)
|
| 284 |
+
self.dense = prune_linear_layer(self.dense, index, dim=1)
|
| 285 |
+
|
| 286 |
+
# Update hyper params and store pruned heads
|
| 287 |
+
self.num_attention_heads = self.num_attention_heads - len(heads)
|
| 288 |
+
self.all_head_size = self.attention_head_size * self.num_attention_heads
|
| 289 |
+
self.pruned_heads = self.pruned_heads.union(heads)
|
| 290 |
+
|
| 291 |
+
def forward(
|
| 292 |
+
self,
|
| 293 |
+
hidden_states: torch.Tensor,
|
| 294 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
| 295 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
| 296 |
+
output_attentions: bool = False,
|
| 297 |
+
) -> Union[tuple[torch.Tensor], tuple[torch.Tensor, torch.Tensor]]:
|
| 298 |
+
batch_size, seq_length, _ = hidden_states.shape
|
| 299 |
+
query_layer = self.query(hidden_states)
|
| 300 |
+
key_layer = self.key(hidden_states)
|
| 301 |
+
value_layer = self.value(hidden_states)
|
| 302 |
+
query_layer = query_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(
|
| 303 |
+
1, 2
|
| 304 |
+
)
|
| 305 |
+
key_layer = key_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2)
|
| 306 |
+
value_layer = value_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(
|
| 307 |
+
1, 2
|
| 308 |
+
)
|
| 309 |
+
|
| 310 |
+
# Take the dot product between "query" and "key" to get the raw attention scores.
|
| 311 |
+
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
| 312 |
+
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
| 313 |
+
|
| 314 |
+
if attention_mask is not None:
|
| 315 |
+
# Apply the attention mask is (precomputed for all layers in BertModel forward() function)
|
| 316 |
+
attention_scores = attention_scores + attention_mask
|
| 317 |
+
|
| 318 |
+
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
| 319 |
+
seq_length = hidden_states.size()[1]
|
| 320 |
+
position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
|
| 321 |
+
position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
|
| 322 |
+
distance = position_ids_l - position_ids_r
|
| 323 |
+
positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
|
| 324 |
+
positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
|
| 325 |
+
|
| 326 |
+
if self.position_embedding_type == "relative_key":
|
| 327 |
+
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
|
| 328 |
+
attention_scores = attention_scores + relative_position_scores
|
| 329 |
+
elif self.position_embedding_type == "relative_key_query":
|
| 330 |
+
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
|
| 331 |
+
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
|
| 332 |
+
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
|
| 333 |
+
|
| 334 |
+
# Normalize the attention scores to probabilities.
|
| 335 |
+
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
|
| 336 |
+
|
| 337 |
+
# This is actually dropping out entire tokens to attend to, which might
|
| 338 |
+
# seem a bit unusual, but is taken from the original Transformer paper.
|
| 339 |
+
attention_probs = self.attention_dropout(attention_probs)
|
| 340 |
+
|
| 341 |
+
# Mask heads if we want to
|
| 342 |
+
if head_mask is not None:
|
| 343 |
+
attention_probs = attention_probs * head_mask
|
| 344 |
+
|
| 345 |
+
context_layer = torch.matmul(attention_probs, value_layer)
|
| 346 |
+
context_layer = context_layer.transpose(2, 1).flatten(2)
|
| 347 |
+
|
| 348 |
+
projected_context_layer = self.dense(context_layer)
|
| 349 |
+
projected_context_layer_dropout = self.output_dropout(projected_context_layer)
|
| 350 |
+
layernormed_context_layer = self.LayerNorm(hidden_states + projected_context_layer_dropout)
|
| 351 |
+
return (layernormed_context_layer, attention_probs) if output_attentions else (layernormed_context_layer,)
|
| 352 |
+
|
| 353 |
+
|
| 354 |
+
class AlbertSdpaAttention(AlbertAttention):
|
| 355 |
+
def __init__(self, config):
|
| 356 |
+
super().__init__(config)
|
| 357 |
+
self.dropout_prob = config.attention_probs_dropout_prob
|
| 358 |
+
|
| 359 |
+
def forward(
|
| 360 |
+
self,
|
| 361 |
+
hidden_states: torch.Tensor,
|
| 362 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
| 363 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
| 364 |
+
output_attentions: bool = False,
|
| 365 |
+
) -> Union[tuple[torch.Tensor], tuple[torch.Tensor, torch.Tensor]]:
|
| 366 |
+
if self.position_embedding_type != "absolute" or output_attentions:
|
| 367 |
+
logger.warning(
|
| 368 |
+
"AlbertSdpaAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
|
| 369 |
+
"non-absolute `position_embedding_type` or `output_attentions=True` . Falling back to "
|
| 370 |
+
"the eager attention implementation, but specifying the eager implementation will be required from "
|
| 371 |
+
"Transformers version v5.0.0 onwards. This warning can be removed using the argument "
|
| 372 |
+
'`attn_implementation="eager"` when loading the model.'
|
| 373 |
+
)
|
| 374 |
+
return super().forward(hidden_states, attention_mask, output_attentions=output_attentions)
|
| 375 |
+
|
| 376 |
+
batch_size, seq_len, _ = hidden_states.size()
|
| 377 |
+
query_layer = (
|
| 378 |
+
self.query(hidden_states)
|
| 379 |
+
.view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
|
| 380 |
+
.transpose(1, 2)
|
| 381 |
+
)
|
| 382 |
+
key_layer = (
|
| 383 |
+
self.key(hidden_states)
|
| 384 |
+
.view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
|
| 385 |
+
.transpose(1, 2)
|
| 386 |
+
)
|
| 387 |
+
value_layer = (
|
| 388 |
+
self.value(hidden_states)
|
| 389 |
+
.view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
|
| 390 |
+
.transpose(1, 2)
|
| 391 |
+
)
|
| 392 |
+
|
| 393 |
+
attention_output = torch.nn.functional.scaled_dot_product_attention(
|
| 394 |
+
query=query_layer,
|
| 395 |
+
key=key_layer,
|
| 396 |
+
value=value_layer,
|
| 397 |
+
attn_mask=attention_mask,
|
| 398 |
+
dropout_p=self.dropout_prob if self.training else 0.0,
|
| 399 |
+
is_causal=False,
|
| 400 |
+
)
|
| 401 |
+
|
| 402 |
+
attention_output = attention_output.transpose(1, 2)
|
| 403 |
+
attention_output = attention_output.reshape(batch_size, seq_len, self.all_head_size)
|
| 404 |
+
|
| 405 |
+
projected_context_layer = self.dense(attention_output)
|
| 406 |
+
projected_context_layer_dropout = self.output_dropout(projected_context_layer)
|
| 407 |
+
layernormed_context_layer = self.LayerNorm(hidden_states + projected_context_layer_dropout)
|
| 408 |
+
return (layernormed_context_layer,)
|
| 409 |
+
|
| 410 |
+
|
| 411 |
+
ALBERT_ATTENTION_CLASSES = {
|
| 412 |
+
"eager": AlbertAttention,
|
| 413 |
+
"sdpa": AlbertSdpaAttention,
|
| 414 |
+
}
|
| 415 |
+
|
| 416 |
+
|
| 417 |
+
class AlbertLayer(nn.Module):
|
| 418 |
+
def __init__(self, config: AlbertConfig):
|
| 419 |
+
super().__init__()
|
| 420 |
+
|
| 421 |
+
self.config = config
|
| 422 |
+
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
| 423 |
+
self.seq_len_dim = 1
|
| 424 |
+
self.full_layer_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 425 |
+
self.attention = ALBERT_ATTENTION_CLASSES[config._attn_implementation](config)
|
| 426 |
+
self.ffn = nn.Linear(config.hidden_size, config.intermediate_size)
|
| 427 |
+
self.ffn_output = nn.Linear(config.intermediate_size, config.hidden_size)
|
| 428 |
+
self.activation = ACT2FN[config.hidden_act]
|
| 429 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
| 430 |
+
|
| 431 |
+
def forward(
|
| 432 |
+
self,
|
| 433 |
+
hidden_states: torch.Tensor,
|
| 434 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
| 435 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
| 436 |
+
output_attentions: bool = False,
|
| 437 |
+
output_hidden_states: bool = False,
|
| 438 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 439 |
+
attention_output = self.attention(hidden_states, attention_mask, head_mask, output_attentions)
|
| 440 |
+
|
| 441 |
+
ffn_output = apply_chunking_to_forward(
|
| 442 |
+
self.ff_chunk,
|
| 443 |
+
self.chunk_size_feed_forward,
|
| 444 |
+
self.seq_len_dim,
|
| 445 |
+
attention_output[0],
|
| 446 |
+
)
|
| 447 |
+
hidden_states = self.full_layer_layer_norm(ffn_output + attention_output[0])
|
| 448 |
+
|
| 449 |
+
return (hidden_states,) + attention_output[1:] # add attentions if we output them
|
| 450 |
+
|
| 451 |
+
def ff_chunk(self, attention_output: torch.Tensor) -> torch.Tensor:
|
| 452 |
+
ffn_output = self.ffn(attention_output)
|
| 453 |
+
ffn_output = self.activation(ffn_output)
|
| 454 |
+
ffn_output = self.ffn_output(ffn_output)
|
| 455 |
+
return ffn_output
|
| 456 |
+
|
| 457 |
+
|
| 458 |
+
class AlbertLayerGroup(nn.Module):
|
| 459 |
+
def __init__(self, config: AlbertConfig):
|
| 460 |
+
super().__init__()
|
| 461 |
+
|
| 462 |
+
self.albert_layers = nn.ModuleList([AlbertLayer(config) for _ in range(config.inner_group_num)])
|
| 463 |
+
|
| 464 |
+
def forward(
|
| 465 |
+
self,
|
| 466 |
+
hidden_states: torch.Tensor,
|
| 467 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
| 468 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
| 469 |
+
output_attentions: bool = False,
|
| 470 |
+
output_hidden_states: bool = False,
|
| 471 |
+
) -> tuple[Union[torch.Tensor, tuple[torch.Tensor]], ...]:
|
| 472 |
+
layer_hidden_states = ()
|
| 473 |
+
layer_attentions = ()
|
| 474 |
+
|
| 475 |
+
for layer_index, albert_layer in enumerate(self.albert_layers):
|
| 476 |
+
layer_output = albert_layer(hidden_states, attention_mask, head_mask[layer_index], output_attentions)
|
| 477 |
+
hidden_states = layer_output[0]
|
| 478 |
+
|
| 479 |
+
if output_attentions:
|
| 480 |
+
layer_attentions = layer_attentions + (layer_output[1],)
|
| 481 |
+
|
| 482 |
+
if output_hidden_states:
|
| 483 |
+
layer_hidden_states = layer_hidden_states + (hidden_states,)
|
| 484 |
+
|
| 485 |
+
outputs = (hidden_states,)
|
| 486 |
+
if output_hidden_states:
|
| 487 |
+
outputs = outputs + (layer_hidden_states,)
|
| 488 |
+
if output_attentions:
|
| 489 |
+
outputs = outputs + (layer_attentions,)
|
| 490 |
+
return outputs # last-layer hidden state, (layer hidden states), (layer attentions)
|
| 491 |
+
|
| 492 |
+
|
| 493 |
+
class AlbertTransformer(nn.Module):
|
| 494 |
+
def __init__(self, config: AlbertConfig):
|
| 495 |
+
super().__init__()
|
| 496 |
+
|
| 497 |
+
self.config = config
|
| 498 |
+
self.embedding_hidden_mapping_in = nn.Linear(config.embedding_size, config.hidden_size)
|
| 499 |
+
self.albert_layer_groups = nn.ModuleList([AlbertLayerGroup(config) for _ in range(config.num_hidden_groups)])
|
| 500 |
+
|
| 501 |
+
def forward(
|
| 502 |
+
self,
|
| 503 |
+
hidden_states: torch.Tensor,
|
| 504 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
| 505 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
| 506 |
+
output_attentions: bool = False,
|
| 507 |
+
output_hidden_states: bool = False,
|
| 508 |
+
return_dict: bool = True,
|
| 509 |
+
) -> Union[BaseModelOutput, tuple]:
|
| 510 |
+
hidden_states = self.embedding_hidden_mapping_in(hidden_states)
|
| 511 |
+
|
| 512 |
+
all_hidden_states = (hidden_states,) if output_hidden_states else None
|
| 513 |
+
all_attentions = () if output_attentions else None
|
| 514 |
+
|
| 515 |
+
head_mask = [None] * self.config.num_hidden_layers if head_mask is None else head_mask
|
| 516 |
+
|
| 517 |
+
for i in range(self.config.num_hidden_layers):
|
| 518 |
+
# Number of layers in a hidden group
|
| 519 |
+
layers_per_group = int(self.config.num_hidden_layers / self.config.num_hidden_groups)
|
| 520 |
+
|
| 521 |
+
# Index of the hidden group
|
| 522 |
+
group_idx = int(i / (self.config.num_hidden_layers / self.config.num_hidden_groups))
|
| 523 |
+
|
| 524 |
+
layer_group_output = self.albert_layer_groups[group_idx](
|
| 525 |
+
hidden_states,
|
| 526 |
+
attention_mask,
|
| 527 |
+
head_mask[group_idx * layers_per_group : (group_idx + 1) * layers_per_group],
|
| 528 |
+
output_attentions,
|
| 529 |
+
output_hidden_states,
|
| 530 |
+
)
|
| 531 |
+
hidden_states = layer_group_output[0]
|
| 532 |
+
|
| 533 |
+
if output_attentions:
|
| 534 |
+
all_attentions = all_attentions + layer_group_output[-1]
|
| 535 |
+
|
| 536 |
+
if output_hidden_states:
|
| 537 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 538 |
+
|
| 539 |
+
if not return_dict:
|
| 540 |
+
return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
|
| 541 |
+
return BaseModelOutput(
|
| 542 |
+
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
|
| 543 |
+
)
|
| 544 |
+
|
| 545 |
+
|
| 546 |
+
@auto_docstring
|
| 547 |
+
class AlbertPreTrainedModel(PreTrainedModel):
|
| 548 |
+
config: AlbertConfig
|
| 549 |
+
load_tf_weights = load_tf_weights_in_albert
|
| 550 |
+
base_model_prefix = "albert"
|
| 551 |
+
_supports_sdpa = True
|
| 552 |
+
|
| 553 |
+
def _init_weights(self, module):
|
| 554 |
+
"""Initialize the weights."""
|
| 555 |
+
if isinstance(module, nn.Linear):
|
| 556 |
+
# Slightly different from the TF version which uses truncated_normal for initialization
|
| 557 |
+
# cf https://github.com/pytorch/pytorch/pull/5617
|
| 558 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
| 559 |
+
if module.bias is not None:
|
| 560 |
+
module.bias.data.zero_()
|
| 561 |
+
elif isinstance(module, nn.Embedding):
|
| 562 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
| 563 |
+
if module.padding_idx is not None:
|
| 564 |
+
module.weight.data[module.padding_idx].zero_()
|
| 565 |
+
elif isinstance(module, nn.LayerNorm):
|
| 566 |
+
module.bias.data.zero_()
|
| 567 |
+
module.weight.data.fill_(1.0)
|
| 568 |
+
elif isinstance(module, AlbertMLMHead):
|
| 569 |
+
module.bias.data.zero_()
|
| 570 |
+
|
| 571 |
+
|
| 572 |
+
@dataclass
|
| 573 |
+
@auto_docstring(
|
| 574 |
+
custom_intro="""
|
| 575 |
+
Output type of [`AlbertForPreTraining`].
|
| 576 |
+
"""
|
| 577 |
+
)
|
| 578 |
+
class AlbertForPreTrainingOutput(ModelOutput):
|
| 579 |
+
r"""
|
| 580 |
+
loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
|
| 581 |
+
Total loss as the sum of the masked language modeling loss and the next sequence prediction
|
| 582 |
+
(classification) loss.
|
| 583 |
+
prediction_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
|
| 584 |
+
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
| 585 |
+
sop_logits (`torch.FloatTensor` of shape `(batch_size, 2)`):
|
| 586 |
+
Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
|
| 587 |
+
before SoftMax).
|
| 588 |
+
"""
|
| 589 |
+
|
| 590 |
+
loss: Optional[torch.FloatTensor] = None
|
| 591 |
+
prediction_logits: Optional[torch.FloatTensor] = None
|
| 592 |
+
sop_logits: Optional[torch.FloatTensor] = None
|
| 593 |
+
hidden_states: Optional[tuple[torch.FloatTensor]] = None
|
| 594 |
+
attentions: Optional[tuple[torch.FloatTensor]] = None
|
| 595 |
+
|
| 596 |
+
|
| 597 |
+
@auto_docstring
|
| 598 |
+
class AlbertModel(AlbertPreTrainedModel):
|
| 599 |
+
config: AlbertConfig
|
| 600 |
+
base_model_prefix = "albert"
|
| 601 |
+
|
| 602 |
+
def __init__(self, config: AlbertConfig, add_pooling_layer: bool = True):
|
| 603 |
+
r"""
|
| 604 |
+
add_pooling_layer (bool, *optional*, defaults to `True`):
|
| 605 |
+
Whether to add a pooling layer
|
| 606 |
+
"""
|
| 607 |
+
super().__init__(config)
|
| 608 |
+
|
| 609 |
+
self.config = config
|
| 610 |
+
self.embeddings = AlbertEmbeddings(config)
|
| 611 |
+
self.encoder = AlbertTransformer(config)
|
| 612 |
+
if add_pooling_layer:
|
| 613 |
+
self.pooler = nn.Linear(config.hidden_size, config.hidden_size)
|
| 614 |
+
self.pooler_activation = nn.Tanh()
|
| 615 |
+
else:
|
| 616 |
+
self.pooler = None
|
| 617 |
+
self.pooler_activation = None
|
| 618 |
+
|
| 619 |
+
self.attn_implementation = config._attn_implementation
|
| 620 |
+
self.position_embedding_type = config.position_embedding_type
|
| 621 |
+
|
| 622 |
+
# Initialize weights and apply final processing
|
| 623 |
+
self.post_init()
|
| 624 |
+
|
| 625 |
+
def get_input_embeddings(self) -> nn.Embedding:
|
| 626 |
+
return self.embeddings.word_embeddings
|
| 627 |
+
|
| 628 |
+
def set_input_embeddings(self, value: nn.Embedding) -> None:
|
| 629 |
+
self.embeddings.word_embeddings = value
|
| 630 |
+
|
| 631 |
+
def _prune_heads(self, heads_to_prune: dict[int, list[int]]) -> None:
|
| 632 |
+
"""
|
| 633 |
+
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} ALBERT has
|
| 634 |
+
a different architecture in that its layers are shared across groups, which then has inner groups. If an ALBERT
|
| 635 |
+
model has 12 hidden layers and 2 hidden groups, with two inner groups, there is a total of 4 different layers.
|
| 636 |
+
|
| 637 |
+
These layers are flattened: the indices [0,1] correspond to the two inner groups of the first hidden layer,
|
| 638 |
+
while [2,3] correspond to the two inner groups of the second hidden layer.
|
| 639 |
+
|
| 640 |
+
Any layer with in index other than [0,1,2,3] will result in an error. See base class PreTrainedModel for more
|
| 641 |
+
information about head pruning
|
| 642 |
+
"""
|
| 643 |
+
for layer, heads in heads_to_prune.items():
|
| 644 |
+
group_idx = int(layer / self.config.inner_group_num)
|
| 645 |
+
inner_group_idx = int(layer - group_idx * self.config.inner_group_num)
|
| 646 |
+
self.encoder.albert_layer_groups[group_idx].albert_layers[inner_group_idx].attention.prune_heads(heads)
|
| 647 |
+
|
| 648 |
+
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[BaseModelOutputWithPooling, tuple]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        batch_size, seq_length = input_shape
        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=device)
        if token_type_ids is None:
            if hasattr(self.embeddings, "token_type_ids"):
                buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
                token_type_ids = buffered_token_type_ids_expanded
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        embedding_output = self.embeddings(
            input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
        )

        use_sdpa_attention_mask = (
            self.attn_implementation == "sdpa"
            and self.position_embedding_type == "absolute"
            and head_mask is None
            and not output_attentions
        )

        if use_sdpa_attention_mask:
            extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(
                attention_mask, embedding_output.dtype, tgt_len=seq_length
            )
        else:
            extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
            extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)  # fp16 compatibility
            extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(self.dtype).min

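        # Illustrative note (an editorial addition, not part of the original file):
        # for an attention_mask row [1, 1, 0], the additive mask above becomes
        # [0.0, 0.0, torch.finfo(dtype).min], so padded positions are pushed toward
        # -inf before the softmax inside each attention layer.
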
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        encoder_outputs = self.encoder(
            embedding_output,
            extended_attention_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = encoder_outputs[0]

        pooled_output = self.pooler_activation(self.pooler(sequence_output[:, 0])) if self.pooler is not None else None

        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

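
# Minimal usage sketch (an editorial addition, assuming the public
# albert/albert-base-v2 checkpoint is reachable; not part of the original file):
#
#   from transformers import AutoTokenizer, AlbertModel
#   tokenizer = AutoTokenizer.from_pretrained("albert/albert-base-v2")
#   model = AlbertModel.from_pretrained("albert/albert-base-v2")
#   inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
#   hidden = model(**inputs).last_hidden_state  # (1, seq_len, hidden_size)
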
@auto_docstring(
    custom_intro="""
    Albert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a
    `sentence order prediction (classification)` head.
    """
)
class AlbertForPreTraining(AlbertPreTrainedModel):
    _tied_weights_keys = ["predictions.decoder.bias", "predictions.decoder.weight"]

    def __init__(self, config: AlbertConfig):
        super().__init__(config)

        self.albert = AlbertModel(config)
        self.predictions = AlbertMLMHead(config)
        self.sop_classifier = AlbertSOPHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self) -> nn.Linear:
        return self.predictions.decoder

    def set_output_embeddings(self, new_embeddings: nn.Linear) -> None:
        self.predictions.decoder = new_embeddings

    def get_input_embeddings(self) -> nn.Embedding:
        return self.albert.embeddings.word_embeddings

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        sentence_order_label: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[AlbertForPreTrainingOutput, tuple]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked);
            the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
        sentence_order_label (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
            (see `input_ids` docstring). Indices should be in `[0, 1]`. `0` indicates original order (sequence A, then
            sequence B), `1` indicates switched order (sequence B, then sequence A).

        Example:

        ```python
        >>> from transformers import AutoTokenizer, AlbertForPreTraining
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("albert/albert-base-v2")
        >>> model = AlbertForPreTraining.from_pretrained("albert/albert-base-v2")

        >>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)
        >>> # Batch size 1
        >>> outputs = model(input_ids)

        >>> prediction_logits = outputs.prediction_logits
        >>> sop_logits = outputs.sop_logits
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.albert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output, pooled_output = outputs[:2]

        prediction_scores = self.predictions(sequence_output)
        sop_scores = self.sop_classifier(pooled_output)

        total_loss = None
        if labels is not None and sentence_order_label is not None:
            loss_fct = CrossEntropyLoss()
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
            sentence_order_loss = loss_fct(sop_scores.view(-1, 2), sentence_order_label.view(-1))
            total_loss = masked_lm_loss + sentence_order_loss

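            # Illustrative note (an editorial addition, not part of the original
            # file): the pretraining objective is the unweighted sum of the MLM
            # loss over masked tokens and the 2-way sentence-order-prediction loss
            # computed on the pooled output.
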
        if not return_dict:
            output = (prediction_scores, sop_scores) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        return AlbertForPreTrainingOutput(
            loss=total_loss,
            prediction_logits=prediction_scores,
            sop_logits=sop_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class AlbertMLMHead(nn.Module):
    def __init__(self, config: AlbertConfig):
        super().__init__()

        self.LayerNorm = nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps)
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))
        self.dense = nn.Linear(config.hidden_size, config.embedding_size)
        self.decoder = nn.Linear(config.embedding_size, config.vocab_size)
        self.activation = ACT2FN[config.hidden_act]
        self.decoder.bias = self.bias

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.activation(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        hidden_states = self.decoder(hidden_states)

        prediction_scores = hidden_states

        return prediction_scores

    def _tie_weights(self) -> None:
        # For accelerate compatibility and to not break backward compatibility
        if self.decoder.bias.device.type == "meta":
            self.decoder.bias = self.bias
        else:
            # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
            self.bias = self.decoder.bias

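
# Illustrative note (an editorial addition, not part of the original file): this
# head mirrors ALBERT's factorized embedding parameterization: `dense` maps
# hidden_size down to embedding_size, and `decoder` maps embedding_size to
# vocab_size, with the decoder weight tied to the input word embeddings through
# each model's `_tied_weights_keys`.
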

class AlbertSOPHead(nn.Module):
    def __init__(self, config: AlbertConfig):
        super().__init__()

        self.dropout = nn.Dropout(config.classifier_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, pooled_output: torch.Tensor) -> torch.Tensor:
        dropout_pooled_output = self.dropout(pooled_output)
        logits = self.classifier(dropout_pooled_output)
        return logits


@auto_docstring
class AlbertForMaskedLM(AlbertPreTrainedModel):
    _tied_weights_keys = ["predictions.decoder.bias", "predictions.decoder.weight"]

    def __init__(self, config):
        super().__init__(config)

        self.albert = AlbertModel(config, add_pooling_layer=False)
        self.predictions = AlbertMLMHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self) -> nn.Linear:
        return self.predictions.decoder

    def set_output_embeddings(self, new_embeddings: nn.Linear) -> None:
        self.predictions.decoder = new_embeddings
        self.predictions.bias = new_embeddings.bias

    def get_input_embeddings(self) -> nn.Embedding:
        return self.albert.embeddings.word_embeddings

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[MaskedLMOutput, tuple]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked);
            the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Example:

        ```python
        >>> import torch
        >>> from transformers import AutoTokenizer, AlbertForMaskedLM

        >>> tokenizer = AutoTokenizer.from_pretrained("albert/albert-base-v2")
        >>> model = AlbertForMaskedLM.from_pretrained("albert/albert-base-v2")

        >>> # add mask_token
        >>> inputs = tokenizer("The capital of [MASK] is Paris.", return_tensors="pt")
        >>> with torch.no_grad():
        ...     logits = model(**inputs).logits

        >>> # retrieve index of [MASK]
        >>> mask_token_index = (inputs.input_ids == tokenizer.mask_token_id)[0].nonzero(as_tuple=True)[0]
        >>> predicted_token_id = logits[0, mask_token_index].argmax(axis=-1)
        >>> tokenizer.decode(predicted_token_id)
        'france'
        ```

        ```python
        >>> labels = tokenizer("The capital of France is Paris.", return_tensors="pt")["input_ids"]
        >>> labels = torch.where(inputs.input_ids == tokenizer.mask_token_id, labels, -100)
        >>> outputs = model(**inputs, labels=labels)
        >>> round(outputs.loss.item(), 2)
        0.81
        ```
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.albert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_outputs = outputs[0]

        prediction_scores = self.predictions(sequence_outputs)

        masked_lm_loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            output = (prediction_scores,) + outputs[2:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        return MaskedLMOutput(
            loss=masked_lm_loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@auto_docstring(
    custom_intro="""
    Albert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
    output) e.g. for GLUE tasks.
    """
)
class AlbertForSequenceClassification(AlbertPreTrainedModel):
    def __init__(self, config: AlbertConfig):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        self.albert = AlbertModel(config)
        self.dropout = nn.Dropout(config.classifier_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, self.config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[SequenceClassifierOutput, tuple]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss); if
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.albert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

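        # Illustrative note (an editorial addition, not part of the original file):
        # with num_labels=3 and integer labels, problem_type resolves to
        # "single_label_classification" and CrossEntropyLoss is applied to logits of
        # shape (batch_size, 3); float labels with num_labels > 1 instead select the
        # multi-label BCE path.
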
        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@auto_docstring
class AlbertForTokenClassification(AlbertPreTrainedModel):
    def __init__(self, config: AlbertConfig):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.albert = AlbertModel(config, add_pooling_layer=False)
        classifier_dropout_prob = (
            config.classifier_dropout_prob
            if config.classifier_dropout_prob is not None
            else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, self.config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[TokenClassifierOutput, tuple]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.albert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@auto_docstring
class AlbertForQuestionAnswering(AlbertPreTrainedModel):
    def __init__(self, config: AlbertConfig):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.albert = AlbertModel(config, add_pooling_layer=False)
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        start_positions: Optional[torch.LongTensor] = None,
        end_positions: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[QuestionAnsweringModelOutput, tuple]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.albert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        logits: torch.Tensor = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, splitting adds a dimension
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # Sometimes the start/end positions are outside our model inputs; we ignore these terms
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

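            # Illustrative note (an editorial addition, not part of the original
            # file): clamping targets to ignored_index (the sequence length) maps
            # any out-of-range answer span to a class that
            # CrossEntropyLoss(ignore_index=ignored_index) skips, and the final
            # span loss is the mean of the start and end losses.
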
        if not return_dict:
            output = (start_logits, end_logits) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@auto_docstring
class AlbertForMultipleChoice(AlbertPreTrainedModel):
    def __init__(self, config: AlbertConfig):
        super().__init__(config)

        self.albert = AlbertModel(config)
        self.dropout = nn.Dropout(config.classifier_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, 1)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[MultipleChoiceModelOutput, tuple]:
        r"""
        input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and
            [`PreTrainedTokenizer.encode`] for details.

            [What are input IDs?](../glossary#input-ids)
        token_type_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
            num_choices-1]` where *num_choices* is the size of the second dimension of the input tensors (see
            *input_ids* above).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]

        input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
        attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
        token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
        position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
        inputs_embeds = (
            inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
            if inputs_embeds is not None
            else None
        )
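        # Illustrative note (an editorial addition, not part of the original file):
        # inputs of shape (batch_size, num_choices, seq_len) are flattened above to
        # (batch_size * num_choices, seq_len) so the encoder scores each choice as
        # an independent sequence; logits are reshaped back to (batch_size,
        # num_choices) below.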
        outputs = self.albert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)
        logits: torch.Tensor = self.classifier(pooled_output)
        reshaped_logits = logits.view(-1, num_choices)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(reshaped_logits, labels)

        if not return_dict:
            output = (reshaped_logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return MultipleChoiceModelOutput(
            loss=loss,
            logits=reshaped_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


__all__ = [
    "load_tf_weights_in_albert",
    "AlbertPreTrainedModel",
    "AlbertModel",
    "AlbertForPreTraining",
    "AlbertForMaskedLM",
    "AlbertForSequenceClassification",
    "AlbertForTokenClassification",
    "AlbertForQuestionAnswering",
    "AlbertForMultipleChoice",
]
venv/lib/python3.13/site-packages/transformers/models/albert/modeling_flax_albert.py
ADDED
@@ -0,0 +1,1132 @@
# coding=utf-8
# Copyright 2021 Google AI, Google Brain and the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, Optional

import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen.attention import dot_product_attention_weights
from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax

from ...modeling_flax_outputs import (
    FlaxBaseModelOutput,
    FlaxBaseModelOutputWithPooling,
    FlaxMaskedLMOutput,
    FlaxMultipleChoiceModelOutput,
    FlaxQuestionAnsweringModelOutput,
    FlaxSequenceClassifierOutput,
    FlaxTokenClassifierOutput,
)
from ...modeling_flax_utils import (
    ACT2FN,
    FlaxPreTrainedModel,
    append_call_sample_docstring,
    append_replace_return_docstrings,
    overwrite_call_docstring,
)
from ...utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, logging
from .configuration_albert import AlbertConfig


logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "albert/albert-base-v2"
_CONFIG_FOR_DOC = "AlbertConfig"


@flax.struct.dataclass
class FlaxAlbertForPreTrainingOutput(ModelOutput):
    """
    Output type of [`FlaxAlbertForPreTraining`].

    Args:
        prediction_logits (`jnp.ndarray` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        sop_logits (`jnp.ndarray` of shape `(batch_size, 2)`):
            Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
            before SoftMax).
        hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
            `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    prediction_logits: jnp.ndarray = None
    sop_logits: jnp.ndarray = None
    hidden_states: Optional[tuple[jnp.ndarray]] = None
    attentions: Optional[tuple[jnp.ndarray]] = None


ALBERT_START_DOCSTRING = r"""

    This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading, saving and converting weights from PyTorch models).

    This model is also a
    [flax.linen.Module](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html) subclass. Use it as
    a regular Flax linen Module and refer to the Flax documentation for all matters related to general usage and
    behavior.

    Finally, this model supports inherent JAX features such as:

    - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
    - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
    - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
    - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)

    Parameters:
        config ([`AlbertConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
        dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
            The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
            `jax.numpy.bfloat16` (on TPUs).

            This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
            specified, all the computation will be performed with the given `dtype`.

            **Note that this only specifies the dtype of the computation and does not influence the dtype of model
            parameters.**

            If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
            [`~FlaxPreTrainedModel.to_bf16`].
"""

ALBERT_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`numpy.ndarray` of shape `({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`numpy.ndarray` of shape `({0})`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        token_type_ids (`numpy.ndarray` of shape `({0})`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`numpy.ndarray` of shape `({0})`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.

"""


class FlaxAlbertEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings."""

    config: AlbertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.word_embeddings = nn.Embed(
            self.config.vocab_size,
            self.config.embedding_size,
            embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
        )
        self.position_embeddings = nn.Embed(
            self.config.max_position_embeddings,
            self.config.embedding_size,
            embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
        )
        self.token_type_embeddings = nn.Embed(
            self.config.type_vocab_size,
            self.config.embedding_size,
            embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
        )
        self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)

    def __call__(self, input_ids, token_type_ids, position_ids, deterministic: bool = True):
        # Embed
        inputs_embeds = self.word_embeddings(input_ids.astype("i4"))
        position_embeds = self.position_embeddings(position_ids.astype("i4"))
        token_type_embeddings = self.token_type_embeddings(token_type_ids.astype("i4"))

        # Sum all embeddings
        hidden_states = inputs_embeds + token_type_embeddings + position_embeds

        # Layer Norm
        hidden_states = self.LayerNorm(hidden_states)
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
        return hidden_states

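
# Illustrative note (an editorial addition, not part of the original file): these
# embeddings live in `embedding_size` (128 in the base ALBERT config) rather than
# `hidden_size`; FlaxAlbertEncoder below projects them up with
# `embedding_hidden_mapping_in`, which is ALBERT's factorized embedding
# parameterization.
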

class FlaxAlbertSelfAttention(nn.Module):
    config: AlbertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        if self.config.hidden_size % self.config.num_attention_heads != 0:
            raise ValueError(
                f"`config.hidden_size`: {self.config.hidden_size} has to be a multiple of "
                f"`config.num_attention_heads`: {self.config.num_attention_heads}"
            )

        self.query = nn.Dense(
            self.config.hidden_size,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
        )
        self.key = nn.Dense(
            self.config.hidden_size,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
        )
        self.value = nn.Dense(
            self.config.hidden_size,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
        )
        self.dense = nn.Dense(
            self.config.hidden_size,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )
        self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)

    def __call__(self, hidden_states, attention_mask, deterministic=True, output_attentions: bool = False):
        head_dim = self.config.hidden_size // self.config.num_attention_heads

        query_states = self.query(hidden_states).reshape(
            hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
        )
        value_states = self.value(hidden_states).reshape(
            hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
        )
        key_states = self.key(hidden_states).reshape(
            hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
        )

        # Convert the boolean attention mask to an attention bias.
        if attention_mask is not None:
            # attention mask in the form of attention bias
            attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
            attention_bias = lax.select(
                attention_mask > 0,
                jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
            )
        else:
            attention_bias = None

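        # Illustrative note (an editorial addition, not part of the original file):
        # lax.select maps mask values > 0 to a 0.0 bias and masked positions to the
        # dtype's most negative finite value, so padded keys are suppressed in the
        # attention softmax.
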
        dropout_rng = None
        if not deterministic and self.config.attention_probs_dropout_prob > 0.0:
            dropout_rng = self.make_rng("dropout")

        attn_weights = dot_product_attention_weights(
            query_states,
            key_states,
            bias=attention_bias,
            dropout_rng=dropout_rng,
            dropout_rate=self.config.attention_probs_dropout_prob,
            broadcast_dropout=True,
            deterministic=deterministic,
            dtype=self.dtype,
            precision=None,
        )

        attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
        attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,))

        projected_attn_output = self.dense(attn_output)
        projected_attn_output = self.dropout(projected_attn_output, deterministic=deterministic)
        layernormed_attn_output = self.LayerNorm(projected_attn_output + hidden_states)
        outputs = (layernormed_attn_output, attn_weights) if output_attentions else (layernormed_attn_output,)
        return outputs


class FlaxAlbertLayer(nn.Module):
    config: AlbertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.attention = FlaxAlbertSelfAttention(self.config, dtype=self.dtype)
        self.ffn = nn.Dense(
            self.config.intermediate_size,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )
        self.activation = ACT2FN[self.config.hidden_act]
        self.ffn_output = nn.Dense(
            self.config.hidden_size,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )
        self.full_layer_layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)

    def __call__(
        self,
        hidden_states,
        attention_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
    ):
        attention_outputs = self.attention(
            hidden_states, attention_mask, deterministic=deterministic, output_attentions=output_attentions
        )
        attention_output = attention_outputs[0]
        ffn_output = self.ffn(attention_output)
        ffn_output = self.activation(ffn_output)
        ffn_output = self.ffn_output(ffn_output)
        ffn_output = self.dropout(ffn_output, deterministic=deterministic)
        hidden_states = self.full_layer_layer_norm(ffn_output + attention_output)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attention_outputs[1],)
        return outputs


class FlaxAlbertLayerCollection(nn.Module):
    config: AlbertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.layers = [
            FlaxAlbertLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.inner_group_num)
        ]

    def __call__(
        self,
        hidden_states,
        attention_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
    ):
        layer_hidden_states = ()
        layer_attentions = ()

        for layer_index, albert_layer in enumerate(self.layers):
            layer_output = albert_layer(
                hidden_states,
                attention_mask,
                deterministic=deterministic,
                output_attentions=output_attentions,
            )
            hidden_states = layer_output[0]

            if output_attentions:
                layer_attentions = layer_attentions + (layer_output[1],)

            if output_hidden_states:
                layer_hidden_states = layer_hidden_states + (hidden_states,)

        outputs = (hidden_states,)
        if output_hidden_states:
            outputs = outputs + (layer_hidden_states,)
        if output_attentions:
            outputs = outputs + (layer_attentions,)
        return outputs  # last-layer hidden state, (layer hidden states), (layer attentions)


class FlaxAlbertLayerCollections(nn.Module):
    config: AlbertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    layer_index: Optional[str] = None

    def setup(self):
        self.albert_layers = FlaxAlbertLayerCollection(self.config, dtype=self.dtype)

    def __call__(
        self,
        hidden_states,
        attention_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
    ):
        outputs = self.albert_layers(
            hidden_states,
            attention_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        return outputs


class FlaxAlbertLayerGroups(nn.Module):
    config: AlbertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.layers = [
            FlaxAlbertLayerCollections(self.config, name=str(i), layer_index=str(i), dtype=self.dtype)
            for i in range(self.config.num_hidden_groups)
        ]

    def __call__(
        self,
        hidden_states,
        attention_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        all_attentions = () if output_attentions else None
        all_hidden_states = (hidden_states,) if output_hidden_states else None

        for i in range(self.config.num_hidden_layers):
            # Index of the hidden group
            group_idx = int(i / (self.config.num_hidden_layers / self.config.num_hidden_groups))
layer_group_output = self.layers[group_idx](
|
| 416 |
+
hidden_states,
|
| 417 |
+
attention_mask,
|
| 418 |
+
deterministic=deterministic,
|
| 419 |
+
output_attentions=output_attentions,
|
| 420 |
+
output_hidden_states=output_hidden_states,
|
| 421 |
+
)
|
| 422 |
+
hidden_states = layer_group_output[0]
|
| 423 |
+
|
| 424 |
+
if output_attentions:
|
| 425 |
+
all_attentions = all_attentions + layer_group_output[-1]
|
| 426 |
+
|
| 427 |
+
if output_hidden_states:
|
| 428 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 429 |
+
|
| 430 |
+
if not return_dict:
|
| 431 |
+
return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
|
| 432 |
+
return FlaxBaseModelOutput(
|
| 433 |
+
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
|
| 434 |
+
)
|
| 435 |
+
|
| 436 |
+
|
| 437 |
+
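# Note: `FlaxAlbertLayerGroups` is where ALBERT's cross-layer parameter sharing
# happens. The loop runs `num_hidden_layers` times, but only `num_hidden_groups`
# distinct parameter sets exist, so each group is reused for a contiguous block
# of layers. A minimal sketch of the index arithmetic (illustrative values, not
# part of the model code):
#
#     num_hidden_layers, num_hidden_groups = 12, 1
#     [int(i / (num_hidden_layers / num_hidden_groups)) for i in range(12)]
#     # -> [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]  (one fully shared group)
#
#     num_hidden_layers, num_hidden_groups = 12, 3
#     [int(i / (12 / 3)) for i in range(12)]
#     # -> [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]

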
class FlaxAlbertEncoder(nn.Module):
    config: AlbertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.embedding_hidden_mapping_in = nn.Dense(
            self.config.hidden_size,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )
        self.albert_layer_groups = FlaxAlbertLayerGroups(self.config, dtype=self.dtype)

    def __call__(
        self,
        hidden_states,
        attention_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        hidden_states = self.embedding_hidden_mapping_in(hidden_states)
        return self.albert_layer_groups(
            hidden_states,
            attention_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )


class FlaxAlbertOnlyMLMHead(nn.Module):
    config: AlbertConfig
    dtype: jnp.dtype = jnp.float32
    bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros

    def setup(self):
        self.dense = nn.Dense(self.config.embedding_size, dtype=self.dtype)
        self.activation = ACT2FN[self.config.hidden_act]
        self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        self.decoder = nn.Dense(self.config.vocab_size, dtype=self.dtype, use_bias=False)
        self.bias = self.param("bias", self.bias_init, (self.config.vocab_size,))

    def __call__(self, hidden_states, shared_embedding=None):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.activation(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)

        if shared_embedding is not None:
            hidden_states = self.decoder.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
        else:
            hidden_states = self.decoder(hidden_states)

        hidden_states += self.bias
        return hidden_states

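# Note on weight tying: when `shared_embedding` is passed, the MLM decoder reuses
# the word-embedding matrix (transposed) as its kernel via `nn.Dense.apply` with
# an overridden parameter dict, instead of learning a separate output projection.
# The standalone `bias` parameter is added in both branches.

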
class FlaxAlbertSOPHead(nn.Module):
    config: AlbertConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.dropout = nn.Dropout(self.config.classifier_dropout_prob)
        self.classifier = nn.Dense(2, dtype=self.dtype)

    def __call__(self, pooled_output, deterministic=True):
        pooled_output = self.dropout(pooled_output, deterministic=deterministic)
        logits = self.classifier(pooled_output)
        return logits

class FlaxAlbertPreTrainedModel(FlaxPreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = AlbertConfig
    base_model_prefix = "albert"
    module_class: nn.Module = None

    def __init__(
        self,
        config: AlbertConfig,
        input_shape: tuple = (1, 1),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
        **kwargs,
    ):
        module = self.module_class(config=config, dtype=dtype, **kwargs)
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: tuple, params: FrozenDict = None) -> FrozenDict:
        # init input tensors
        input_ids = jnp.zeros(input_shape, dtype="i4")
        token_type_ids = jnp.zeros_like(input_ids)
        position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)
        attention_mask = jnp.ones_like(input_ids)

        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        random_params = self.module.init(
            rngs, input_ids, attention_mask, token_type_ids, position_ids, return_dict=False
        )["params"]

        if params is not None:
            random_params = flatten_dict(unfreeze(random_params))
            params = flatten_dict(unfreeze(params))
            for missing_key in self._missing_keys:
                params[missing_key] = random_params[missing_key]
            self._missing_keys = set()
            return freeze(unflatten_dict(params))
        else:
            return random_params

    @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    def __call__(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        params: Optional[dict] = None,
        dropout_rng: jax.random.PRNGKey = None,
        train: bool = False,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # init input tensors if not passed
        if token_type_ids is None:
            token_type_ids = jnp.zeros_like(input_ids)

        if position_ids is None:
            position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)

        if attention_mask is None:
            attention_mask = jnp.ones_like(input_ids)

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        return self.module.apply(
            {"params": params or self.params},
            jnp.array(input_ids, dtype="i4"),
            jnp.array(attention_mask, dtype="i4"),
            jnp.array(token_type_ids, dtype="i4"),
            jnp.array(position_ids, dtype="i4"),
            not train,
            output_attentions,
            output_hidden_states,
            return_dict,
            rngs=rngs,
        )

class FlaxAlbertModule(nn.Module):
    config: AlbertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    add_pooling_layer: bool = True

    def setup(self):
        self.embeddings = FlaxAlbertEmbeddings(self.config, dtype=self.dtype)
        self.encoder = FlaxAlbertEncoder(self.config, dtype=self.dtype)
        if self.add_pooling_layer:
            self.pooler = nn.Dense(
                self.config.hidden_size,
                kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
                dtype=self.dtype,
                name="pooler",
            )
            self.pooler_activation = nn.tanh
        else:
            self.pooler = None
            self.pooler_activation = None

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids: Optional[np.ndarray] = None,
        position_ids: Optional[np.ndarray] = None,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # make sure `token_type_ids` is correctly initialized when not passed
        if token_type_ids is None:
            token_type_ids = jnp.zeros_like(input_ids)

        # make sure `position_ids` is correctly initialized when not passed
        if position_ids is None:
            position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)

        hidden_states = self.embeddings(input_ids, token_type_ids, position_ids, deterministic=deterministic)

        outputs = self.encoder(
            hidden_states,
            attention_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = outputs[0]
        if self.add_pooling_layer:
            pooled = self.pooler(hidden_states[:, 0])
            pooled = self.pooler_activation(pooled)
        else:
            pooled = None

        if not return_dict:
            # if pooled is None, don't return it
            if pooled is None:
                return (hidden_states,) + outputs[1:]
            return (hidden_states, pooled) + outputs[1:]

        return FlaxBaseModelOutputWithPooling(
            last_hidden_state=hidden_states,
            pooler_output=pooled,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings(
    "The bare Albert Model transformer outputting raw hidden-states without any specific head on top.",
    ALBERT_START_DOCSTRING,
)
class FlaxAlbertModel(FlaxAlbertPreTrainedModel):
    module_class = FlaxAlbertModule


append_call_sample_docstring(FlaxAlbertModel, _CHECKPOINT_FOR_DOC, FlaxBaseModelOutputWithPooling, _CONFIG_FOR_DOC)

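# Note: the pooler is a Dense + tanh applied to the hidden state of the first
# token only (`hidden_states[:, 0]`, the [CLS] position). `pooler_output` is what
# the sequence-level heads below (SOP, sequence classification, multiple choice)
# consume, while token-level heads read `last_hidden_state` directly.

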
class FlaxAlbertForPreTrainingModule(nn.Module):
    config: AlbertConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.albert = FlaxAlbertModule(config=self.config, dtype=self.dtype)
        self.predictions = FlaxAlbertOnlyMLMHead(config=self.config, dtype=self.dtype)
        self.sop_classifier = FlaxAlbertSOPHead(config=self.config, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        position_ids,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # Model
        outputs = self.albert(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if self.config.tie_word_embeddings:
            shared_embedding = self.albert.variables["params"]["embeddings"]["word_embeddings"]["embedding"]
        else:
            shared_embedding = None

        hidden_states = outputs[0]
        pooled_output = outputs[1]

        prediction_scores = self.predictions(hidden_states, shared_embedding=shared_embedding)
        sop_scores = self.sop_classifier(pooled_output, deterministic=deterministic)

        if not return_dict:
            return (prediction_scores, sop_scores) + outputs[2:]

        return FlaxAlbertForPreTrainingOutput(
            prediction_logits=prediction_scores,
            sop_logits=sop_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings(
    """
    Albert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a
    `sentence order prediction (classification)` head.
    """,
    ALBERT_START_DOCSTRING,
)
class FlaxAlbertForPreTraining(FlaxAlbertPreTrainedModel):
    module_class = FlaxAlbertForPreTrainingModule


FLAX_ALBERT_FOR_PRETRAINING_DOCSTRING = """
    Returns:

    Example:

    ```python
    >>> from transformers import AutoTokenizer, FlaxAlbertForPreTraining

    >>> tokenizer = AutoTokenizer.from_pretrained("albert/albert-base-v2")
    >>> model = FlaxAlbertForPreTraining.from_pretrained("albert/albert-base-v2")

    >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="np")
    >>> outputs = model(**inputs)

    >>> prediction_logits = outputs.prediction_logits
    >>> seq_relationship_logits = outputs.sop_logits
    ```
"""

overwrite_call_docstring(
    FlaxAlbertForPreTraining,
    ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length") + FLAX_ALBERT_FOR_PRETRAINING_DOCSTRING,
)
append_replace_return_docstrings(
    FlaxAlbertForPreTraining, output_type=FlaxAlbertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC
)

class FlaxAlbertForMaskedLMModule(nn.Module):
    config: AlbertConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.albert = FlaxAlbertModule(config=self.config, add_pooling_layer=False, dtype=self.dtype)
        self.predictions = FlaxAlbertOnlyMLMHead(config=self.config, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        position_ids,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # Model
        outputs = self.albert(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        if self.config.tie_word_embeddings:
            shared_embedding = self.albert.variables["params"]["embeddings"]["word_embeddings"]["embedding"]
        else:
            shared_embedding = None

        # Compute the prediction scores
        logits = self.predictions(hidden_states, shared_embedding=shared_embedding)

        if not return_dict:
            return (logits,) + outputs[1:]

        return FlaxMaskedLMOutput(
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings("""Albert Model with a `language modeling` head on top.""", ALBERT_START_DOCSTRING)
class FlaxAlbertForMaskedLM(FlaxAlbertPreTrainedModel):
    module_class = FlaxAlbertForMaskedLMModule


append_call_sample_docstring(
    FlaxAlbertForMaskedLM, _CHECKPOINT_FOR_DOC, FlaxMaskedLMOutput, _CONFIG_FOR_DOC, revision="refs/pr/11"
)


class FlaxAlbertForSequenceClassificationModule(nn.Module):
    config: AlbertConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.albert = FlaxAlbertModule(config=self.config, dtype=self.dtype)
        classifier_dropout = (
            self.config.classifier_dropout_prob
            if self.config.classifier_dropout_prob is not None
            else self.config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(rate=classifier_dropout)
        self.classifier = nn.Dense(
            self.config.num_labels,
            dtype=self.dtype,
        )

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        position_ids,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # Model
        outputs = self.albert(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = outputs[1]
        pooled_output = self.dropout(pooled_output, deterministic=deterministic)
        logits = self.classifier(pooled_output)

        if not return_dict:
            return (logits,) + outputs[2:]

        return FlaxSequenceClassifierOutput(
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings(
    """
    Albert Model transformer with a sequence classification/regression head on top (a linear layer on top of the
    pooled output) e.g. for GLUE tasks.
    """,
    ALBERT_START_DOCSTRING,
)
class FlaxAlbertForSequenceClassification(FlaxAlbertPreTrainedModel):
    module_class = FlaxAlbertForSequenceClassificationModule


append_call_sample_docstring(
    FlaxAlbertForSequenceClassification,
    _CHECKPOINT_FOR_DOC,
    FlaxSequenceClassifierOutput,
    _CONFIG_FOR_DOC,
)


class FlaxAlbertForMultipleChoiceModule(nn.Module):
    config: AlbertConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.albert = FlaxAlbertModule(config=self.config, dtype=self.dtype)
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
        self.classifier = nn.Dense(1, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        position_ids,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        num_choices = input_ids.shape[1]
        input_ids = input_ids.reshape(-1, input_ids.shape[-1]) if input_ids is not None else None
        attention_mask = attention_mask.reshape(-1, attention_mask.shape[-1]) if attention_mask is not None else None
        token_type_ids = token_type_ids.reshape(-1, token_type_ids.shape[-1]) if token_type_ids is not None else None
        position_ids = position_ids.reshape(-1, position_ids.shape[-1]) if position_ids is not None else None

        # Model
        outputs = self.albert(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = outputs[1]
        pooled_output = self.dropout(pooled_output, deterministic=deterministic)
        logits = self.classifier(pooled_output)

        reshaped_logits = logits.reshape(-1, num_choices)

        if not return_dict:
            return (reshaped_logits,) + outputs[2:]

        return FlaxMultipleChoiceModelOutput(
            logits=reshaped_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings(
    """
    Albert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
    softmax) e.g. for RocStories/SWAG tasks.
    """,
    ALBERT_START_DOCSTRING,
)
class FlaxAlbertForMultipleChoice(FlaxAlbertPreTrainedModel):
    module_class = FlaxAlbertForMultipleChoiceModule


overwrite_call_docstring(
    FlaxAlbertForMultipleChoice, ALBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
)
append_call_sample_docstring(
    FlaxAlbertForMultipleChoice,
    _CHECKPOINT_FOR_DOC,
    FlaxMultipleChoiceModelOutput,
    _CONFIG_FOR_DOC,
)


class FlaxAlbertForTokenClassificationModule(nn.Module):
    config: AlbertConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.albert = FlaxAlbertModule(config=self.config, dtype=self.dtype, add_pooling_layer=False)
        classifier_dropout = (
            self.config.classifier_dropout_prob
            if self.config.classifier_dropout_prob is not None
            else self.config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(rate=classifier_dropout)
        self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        position_ids,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # Model
        outputs = self.albert(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
        logits = self.classifier(hidden_states)

        if not return_dict:
            return (logits,) + outputs[1:]

        return FlaxTokenClassifierOutput(
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings(
    """
    Albert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
    Named-Entity-Recognition (NER) tasks.
    """,
    ALBERT_START_DOCSTRING,
)
class FlaxAlbertForTokenClassification(FlaxAlbertPreTrainedModel):
    module_class = FlaxAlbertForTokenClassificationModule


append_call_sample_docstring(
    FlaxAlbertForTokenClassification,
    _CHECKPOINT_FOR_DOC,
    FlaxTokenClassifierOutput,
    _CONFIG_FOR_DOC,
)


class FlaxAlbertForQuestionAnsweringModule(nn.Module):
    config: AlbertConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.albert = FlaxAlbertModule(config=self.config, dtype=self.dtype, add_pooling_layer=False)
        self.qa_outputs = nn.Dense(self.config.num_labels, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        position_ids,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # Model
        outputs = self.albert(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]

        logits = self.qa_outputs(hidden_states)
        start_logits, end_logits = jnp.split(logits, self.config.num_labels, axis=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)

        if not return_dict:
            return (start_logits, end_logits) + outputs[1:]

        return FlaxQuestionAnsweringModelOutput(
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings(
    """
    Albert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
    layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    ALBERT_START_DOCSTRING,
)
class FlaxAlbertForQuestionAnswering(FlaxAlbertPreTrainedModel):
    module_class = FlaxAlbertForQuestionAnsweringModule


append_call_sample_docstring(
    FlaxAlbertForQuestionAnswering,
    _CHECKPOINT_FOR_DOC,
    FlaxQuestionAnsweringModelOutput,
    _CONFIG_FOR_DOC,
)

__all__ = [
    "FlaxAlbertPreTrainedModel",
    "FlaxAlbertModel",
    "FlaxAlbertForPreTraining",
    "FlaxAlbertForMaskedLM",
    "FlaxAlbertForSequenceClassification",
    "FlaxAlbertForMultipleChoice",
    "FlaxAlbertForTokenClassification",
    "FlaxAlbertForQuestionAnswering",
]
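All of the Flax heads above route through `FlaxAlbertPreTrainedModel.__call__`, which fills in default `token_type_ids`, `position_ids`, and `attention_mask` before applying the module. A minimal inference sketch in the same style as the pretraining docstring example above (the checkpoint name mirrors `_CHECKPOINT_FOR_DOC`; the `[MASK]`-filling logic is illustrative and not part of the library):

```python
import jax.numpy as jnp
from transformers import AutoTokenizer, FlaxAlbertForMaskedLM

tokenizer = AutoTokenizer.from_pretrained("albert/albert-base-v2")
model = FlaxAlbertForMaskedLM.from_pretrained("albert/albert-base-v2")

inputs = tokenizer("The capital of France is [MASK].", return_tensors="np")
outputs = model(**inputs)

# logits has shape (batch, seq_len, vocab_size); read off the [MASK] position
mask_index = int(jnp.argmax(inputs["input_ids"][0] == tokenizer.mask_token_id))
predicted_id = int(jnp.argmax(outputs.logits[0, mask_index]))
print(tokenizer.decode([predicted_id]))
```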
venv/lib/python3.13/site-packages/transformers/models/albert/modeling_tf_albert.py
ADDED
@@ -0,0 +1,1572 @@
# coding=utf-8
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TF 2.0 ALBERT model."""

from __future__ import annotations

import math
from dataclasses import dataclass

import numpy as np
import tensorflow as tf

from ...activations_tf import get_tf_activation
from ...modeling_tf_outputs import (
    TFBaseModelOutput,
    TFBaseModelOutputWithPooling,
    TFMaskedLMOutput,
    TFMultipleChoiceModelOutput,
    TFQuestionAnsweringModelOutput,
    TFSequenceClassifierOutput,
    TFTokenClassifierOutput,
)
from ...modeling_tf_utils import (
    TFMaskedLanguageModelingLoss,
    TFModelInputType,
    TFMultipleChoiceLoss,
    TFPreTrainedModel,
    TFQuestionAnsweringLoss,
    TFSequenceClassificationLoss,
    TFTokenClassificationLoss,
    get_initializer,
    keras,
    keras_serializable,
    unpack_inputs,
)
from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax
from ...utils import (
    ModelOutput,
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)
from .configuration_albert import AlbertConfig


logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "albert/albert-base-v2"
_CONFIG_FOR_DOC = "AlbertConfig"


class TFAlbertPreTrainingLoss:
    """
    Loss function suitable for ALBERT pretraining, that is, the task of pretraining a language model by combining SOP +
    MLM. .. note:: Any label of -100 will be ignored (along with the corresponding logits) in the loss computation.
    """

    def hf_compute_loss(self, labels: tf.Tensor, logits: tf.Tensor) -> tf.Tensor:
        loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction=keras.losses.Reduction.NONE)
        if self.config.tf_legacy_loss:
            # make sure only labels that are not equal to -100
            # are taken into account as loss
            masked_lm_active_loss = tf.not_equal(tf.reshape(tensor=labels["labels"], shape=(-1,)), -100)
            masked_lm_reduced_logits = tf.boolean_mask(
                tensor=tf.reshape(tensor=logits[0], shape=(-1, shape_list(logits[0])[2])),
                mask=masked_lm_active_loss,
            )
            masked_lm_labels = tf.boolean_mask(
                tensor=tf.reshape(tensor=labels["labels"], shape=(-1,)), mask=masked_lm_active_loss
            )
            sentence_order_active_loss = tf.not_equal(
                tf.reshape(tensor=labels["sentence_order_label"], shape=(-1,)), -100
            )
            sentence_order_reduced_logits = tf.boolean_mask(
                tensor=tf.reshape(tensor=logits[1], shape=(-1, 2)), mask=sentence_order_active_loss
            )
            sentence_order_label = tf.boolean_mask(
                tensor=tf.reshape(tensor=labels["sentence_order_label"], shape=(-1,)), mask=sentence_order_active_loss
            )
            masked_lm_loss = loss_fn(y_true=masked_lm_labels, y_pred=masked_lm_reduced_logits)
            sentence_order_loss = loss_fn(y_true=sentence_order_label, y_pred=sentence_order_reduced_logits)
            masked_lm_loss = tf.reshape(tensor=masked_lm_loss, shape=(-1, shape_list(sentence_order_loss)[0]))
            masked_lm_loss = tf.reduce_mean(input_tensor=masked_lm_loss, axis=0)

            return masked_lm_loss + sentence_order_loss

        # Clip negative labels to zero here to avoid NaNs and errors - those positions will get masked later anyway
        unmasked_lm_losses = loss_fn(y_true=tf.nn.relu(labels["labels"]), y_pred=logits[0])
        # make sure only labels that are not equal to -100
        # are taken into account for the loss computation
        lm_loss_mask = tf.cast(labels["labels"] != -100, dtype=unmasked_lm_losses.dtype)
        masked_lm_losses = unmasked_lm_losses * lm_loss_mask
        reduced_masked_lm_loss = tf.reduce_sum(masked_lm_losses) / tf.reduce_sum(lm_loss_mask)

        sop_logits = tf.reshape(logits[1], (-1, 2))
        # Clip negative labels to zero here to avoid NaNs and errors - those positions will get masked later anyway
        unmasked_sop_loss = loss_fn(y_true=tf.nn.relu(labels["sentence_order_label"]), y_pred=sop_logits)
        sop_loss_mask = tf.cast(labels["sentence_order_label"] != -100, dtype=unmasked_sop_loss.dtype)

        masked_sop_loss = unmasked_sop_loss * sop_loss_mask
        reduced_masked_sop_loss = tf.reduce_sum(masked_sop_loss) / tf.reduce_sum(sop_loss_mask)

        return tf.reshape(reduced_masked_lm_loss + reduced_masked_sop_loss, (1,))

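# A minimal sketch of the masking arithmetic used above (illustrative values,
# not part of the model code): positions whose label is -100 contribute neither
# to the numerator nor to the denominator of the mean.
#
#     labels  = tf.constant([2.0, -100.0, 5.0])
#     losses  = tf.constant([0.7, 0.9, 0.3])             # per-position losses
#     mask    = tf.cast(labels != -100, losses.dtype)    # -> [1., 0., 1.]
#     reduced = tf.reduce_sum(losses * mask) / tf.reduce_sum(mask)  # -> 0.5

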
class TFAlbertEmbeddings(keras.layers.Layer):
    """Construct the embeddings from word, position and token_type embeddings."""

    def __init__(self, config: AlbertConfig, **kwargs):
        super().__init__(**kwargs)

        self.config = config
        self.embedding_size = config.embedding_size
        self.max_position_embeddings = config.max_position_embeddings
        self.initializer_range = config.initializer_range
        self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
        self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)

    def build(self, input_shape=None):
        with tf.name_scope("word_embeddings"):
            self.weight = self.add_weight(
                name="weight",
                shape=[self.config.vocab_size, self.embedding_size],
                initializer=get_initializer(self.initializer_range),
            )

        with tf.name_scope("token_type_embeddings"):
            self.token_type_embeddings = self.add_weight(
                name="embeddings",
                shape=[self.config.type_vocab_size, self.embedding_size],
                initializer=get_initializer(self.initializer_range),
            )

        with tf.name_scope("position_embeddings"):
            self.position_embeddings = self.add_weight(
                name="embeddings",
                shape=[self.max_position_embeddings, self.embedding_size],
                initializer=get_initializer(self.initializer_range),
            )

        if self.built:
            return
        self.built = True
        if getattr(self, "LayerNorm", None) is not None:
            with tf.name_scope(self.LayerNorm.name):
                self.LayerNorm.build([None, None, self.config.embedding_size])

    # Copied from transformers.models.bert.modeling_tf_bert.TFBertEmbeddings.call
    def call(
        self,
        input_ids: tf.Tensor | None = None,
        position_ids: tf.Tensor | None = None,
        token_type_ids: tf.Tensor | None = None,
        inputs_embeds: tf.Tensor | None = None,
        past_key_values_length=0,
        training: bool = False,
    ) -> tf.Tensor:
        """
        Applies embedding based on inputs tensor.

        Returns:
            final_embeddings (`tf.Tensor`): output embedding tensor.
        """
        if input_ids is None and inputs_embeds is None:
            raise ValueError("Need to provide either `input_ids` or `input_embeds`.")

        if input_ids is not None:
            check_embeddings_within_bounds(input_ids, self.config.vocab_size)
            inputs_embeds = tf.gather(params=self.weight, indices=input_ids)

        input_shape = shape_list(inputs_embeds)[:-1]

        if token_type_ids is None:
            token_type_ids = tf.fill(dims=input_shape, value=0)

        if position_ids is None:
            position_ids = tf.expand_dims(
                tf.range(start=past_key_values_length, limit=input_shape[1] + past_key_values_length), axis=0
            )

        position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids)
        token_type_embeds = tf.gather(params=self.token_type_embeddings, indices=token_type_ids)
        final_embeddings = inputs_embeds + position_embeds + token_type_embeds
        final_embeddings = self.LayerNorm(inputs=final_embeddings)
        final_embeddings = self.dropout(inputs=final_embeddings, training=training)

        return final_embeddings


class TFAlbertAttention(keras.layers.Layer):
    """Contains the complete attention sublayer, including both dropouts and layer norm."""

    def __init__(self, config: AlbertConfig, **kwargs):
        super().__init__(**kwargs)

        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number "
                f"of attention heads ({config.num_attention_heads})"
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.sqrt_att_head_size = math.sqrt(self.attention_head_size)
        self.output_attentions = config.output_attentions

        self.query = keras.layers.Dense(
            units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query"
        )
        self.key = keras.layers.Dense(
            units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key"
        )
        self.value = keras.layers.Dense(
            units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value"
        )
        self.dense = keras.layers.Dense(
            units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
        )
        self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
        # Two different dropout probabilities; see https://github.com/google-research/albert/blob/master/modeling.py#L971-L993
        self.attention_dropout = keras.layers.Dropout(rate=config.attention_probs_dropout_prob)
        self.output_dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
        self.config = config

    def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor:
        # Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size]
        tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size))

        # Transpose the tensor from [batch_size, seq_length, num_attention_heads, attention_head_size] to [batch_size, num_attention_heads, seq_length, attention_head_size]
        return tf.transpose(tensor, perm=[0, 2, 1, 3])

    def call(
        self,
        input_tensor: tf.Tensor,
        attention_mask: tf.Tensor,
        head_mask: tf.Tensor,
        output_attentions: bool,
        training: bool = False,
    ) -> tuple[tf.Tensor]:
        batch_size = shape_list(input_tensor)[0]
        mixed_query_layer = self.query(inputs=input_tensor)
        mixed_key_layer = self.key(inputs=input_tensor)
        mixed_value_layer = self.value(inputs=input_tensor)
        query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)
        key_layer = self.transpose_for_scores(mixed_key_layer, batch_size)
        value_layer = self.transpose_for_scores(mixed_value_layer, batch_size)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        # (batch size, num_heads, seq_len_q, seq_len_k)
        attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
        dk = tf.cast(self.sqrt_att_head_size, dtype=attention_scores.dtype)
        attention_scores = tf.divide(attention_scores, dk)

        if attention_mask is not None:
            # Apply the attention mask (precomputed for all layers in the TFAlbertModel call() function)
            attention_scores = tf.add(attention_scores, attention_mask)

        # Normalize the attention scores to probabilities.
        attention_probs = stable_softmax(logits=attention_scores, axis=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.attention_dropout(inputs=attention_probs, training=training)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = tf.multiply(attention_probs, head_mask)

        context_layer = tf.matmul(attention_probs, value_layer)
        context_layer = tf.transpose(context_layer, perm=[0, 2, 1, 3])

        # (batch_size, seq_len_q, all_head_size)
        context_layer = tf.reshape(tensor=context_layer, shape=(batch_size, -1, self.all_head_size))
        self_outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
        hidden_states = self_outputs[0]
        hidden_states = self.dense(inputs=hidden_states)
        hidden_states = self.output_dropout(inputs=hidden_states, training=training)
        attention_output = self.LayerNorm(inputs=hidden_states + input_tensor)

        # add attentions if we output them
        outputs = (attention_output,) + self_outputs[1:]

        return outputs

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "query", None) is not None:
            with tf.name_scope(self.query.name):
                self.query.build([None, None, self.config.hidden_size])
        if getattr(self, "key", None) is not None:
            with tf.name_scope(self.key.name):
                self.key.build([None, None, self.config.hidden_size])
        if getattr(self, "value", None) is not None:
            with tf.name_scope(self.value.name):
                self.value.build([None, None, self.config.hidden_size])
        if getattr(self, "dense", None) is not None:
            with tf.name_scope(self.dense.name):
                self.dense.build([None, None, self.config.hidden_size])
        if getattr(self, "LayerNorm", None) is not None:
            with tf.name_scope(self.LayerNorm.name):
                self.LayerNorm.build([None, None, self.config.hidden_size])

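# Shape walk-through for `transpose_for_scores` (hypothetical sizes, not taken
# from the file): with hidden_size=768 and num_attention_heads=12, each head
# gets attention_head_size=64, so a (batch, seq_len, 768) projection is reshaped
# to (batch, seq_len, 12, 64) and transposed to (batch, 12, seq_len, 64) before
# the scaled dot-product.

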
class TFAlbertLayer(keras.layers.Layer):
    def __init__(self, config: AlbertConfig, **kwargs):
        super().__init__(**kwargs)

        self.attention = TFAlbertAttention(config, name="attention")
        self.ffn = keras.layers.Dense(
            units=config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="ffn"
        )

        if isinstance(config.hidden_act, str):
            self.activation = get_tf_activation(config.hidden_act)
        else:
            self.activation = config.hidden_act

        self.ffn_output = keras.layers.Dense(
            units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="ffn_output"
        )
        self.full_layer_layer_norm = keras.layers.LayerNormalization(
            epsilon=config.layer_norm_eps, name="full_layer_layer_norm"
        )
        self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
        self.config = config

    def call(
        self,
        hidden_states: tf.Tensor,
        attention_mask: tf.Tensor,
        head_mask: tf.Tensor,
        output_attentions: bool,
        training: bool = False,
    ) -> tuple[tf.Tensor]:
        attention_outputs = self.attention(
            input_tensor=hidden_states,
            attention_mask=attention_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
            training=training,
        )
        ffn_output = self.ffn(inputs=attention_outputs[0])
        ffn_output = self.activation(ffn_output)
        ffn_output = self.ffn_output(inputs=ffn_output)
        ffn_output = self.dropout(inputs=ffn_output, training=training)
        hidden_states = self.full_layer_layer_norm(inputs=ffn_output + attention_outputs[0])

        # add attentions if we output them
        outputs = (hidden_states,) + attention_outputs[1:]

        return outputs

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "attention", None) is not None:
            with tf.name_scope(self.attention.name):
                self.attention.build(None)
        if getattr(self, "ffn", None) is not None:
            with tf.name_scope(self.ffn.name):
                self.ffn.build([None, None, self.config.hidden_size])
        if getattr(self, "ffn_output", None) is not None:
            with tf.name_scope(self.ffn_output.name):
                self.ffn_output.build([None, None, self.config.intermediate_size])
        if getattr(self, "full_layer_layer_norm", None) is not None:
            with tf.name_scope(self.full_layer_layer_norm.name):
                self.full_layer_layer_norm.build([None, None, self.config.hidden_size])

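# Editor's note: an illustrative sketch, not part of the original module. TFAlbertLayer.call
# above follows the "post-LN" residual pattern: the FFN output is added back onto its own input
# and then normalized once. The helper below (a hypothetical name) shows the bare pattern,
# with activation and dropout omitted for brevity.
def _sketch_ffn_residual(
    x: tf.Tensor, ffn: keras.layers.Layer, ffn_output: keras.layers.Layer, norm: keras.layers.Layer
) -> tf.Tensor:
    h = ffn_output(ffn(x))  # expand to intermediate_size, project back to hidden_size
    return norm(h + x)      # residual connection first, LayerNorm last ("post-LN")
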
class TFAlbertLayerGroup(keras.layers.Layer):
    def __init__(self, config: AlbertConfig, **kwargs):
        super().__init__(**kwargs)

        self.albert_layers = [
            TFAlbertLayer(config, name=f"albert_layers_._{i}") for i in range(config.inner_group_num)
        ]

    def call(
        self,
        hidden_states: tf.Tensor,
        attention_mask: tf.Tensor,
        head_mask: tf.Tensor,
        output_attentions: bool,
        output_hidden_states: bool,
        training: bool = False,
    ) -> TFBaseModelOutput | tuple[tf.Tensor]:
        layer_hidden_states = () if output_hidden_states else None
        layer_attentions = () if output_attentions else None

        for layer_index, albert_layer in enumerate(self.albert_layers):
            if output_hidden_states:
                layer_hidden_states = layer_hidden_states + (hidden_states,)

            layer_output = albert_layer(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                head_mask=head_mask[layer_index],
                output_attentions=output_attentions,
                training=training,
            )
            hidden_states = layer_output[0]

            if output_attentions:
                layer_attentions = layer_attentions + (layer_output[1],)

        # Add last layer
        if output_hidden_states:
            layer_hidden_states = layer_hidden_states + (hidden_states,)

        return tuple(v for v in [hidden_states, layer_hidden_states, layer_attentions] if v is not None)

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "albert_layers", None) is not None:
            for layer in self.albert_layers:
                with tf.name_scope(layer.name):
                    layer.build(None)

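# Editor's note: an illustrative sketch, not part of the original module. ALBERT's defining
# trick is cross-layer parameter sharing: the transformer below re-applies the same
# TFAlbertLayerGroup objects, so depth does not multiply the parameter count. The hypothetical
# helper demonstrates the underlying Keras behaviour: calling one layer repeatedly reuses its weights.
def _sketch_weight_sharing() -> bool:
    dense = keras.layers.Dense(4)
    x = tf.zeros((1, 4))
    dense(x)  # first call builds the (single) kernel and bias
    n_before = len(dense.trainable_weights)
    dense(dense(x))  # stacking the same layer twice adds no new weights
    return len(dense.trainable_weights) == n_before  # True
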
class TFAlbertTransformer(keras.layers.Layer):
    def __init__(self, config: AlbertConfig, **kwargs):
        super().__init__(**kwargs)

        self.num_hidden_layers = config.num_hidden_layers
        self.num_hidden_groups = config.num_hidden_groups
        # Number of layers in a hidden group
        self.layers_per_group = int(config.num_hidden_layers / config.num_hidden_groups)
        self.embedding_hidden_mapping_in = keras.layers.Dense(
            units=config.hidden_size,
            kernel_initializer=get_initializer(config.initializer_range),
            name="embedding_hidden_mapping_in",
        )
        self.albert_layer_groups = [
            TFAlbertLayerGroup(config, name=f"albert_layer_groups_._{i}") for i in range(config.num_hidden_groups)
        ]
        self.config = config

    def call(
        self,
        hidden_states: tf.Tensor,
        attention_mask: tf.Tensor,
        head_mask: tf.Tensor,
        output_attentions: bool,
        output_hidden_states: bool,
        return_dict: bool,
        training: bool = False,
    ) -> TFBaseModelOutput | tuple[tf.Tensor]:
        hidden_states = self.embedding_hidden_mapping_in(inputs=hidden_states)
        all_attentions = () if output_attentions else None
        all_hidden_states = (hidden_states,) if output_hidden_states else None

        for i in range(self.num_hidden_layers):
            # Index of the hidden group
            group_idx = int(i / (self.num_hidden_layers / self.num_hidden_groups))
            layer_group_output = self.albert_layer_groups[group_idx](
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                head_mask=head_mask[group_idx * self.layers_per_group : (group_idx + 1) * self.layers_per_group],
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                training=training,
            )
            hidden_states = layer_group_output[0]

            if output_attentions:
                all_attentions = all_attentions + layer_group_output[-1]

            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)

        return TFBaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
        )

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "embedding_hidden_mapping_in", None) is not None:
            with tf.name_scope(self.embedding_hidden_mapping_in.name):
                self.embedding_hidden_mapping_in.build([None, None, self.config.embedding_size])
        if getattr(self, "albert_layer_groups", None) is not None:
            for layer in self.albert_layer_groups:
                with tf.name_scope(layer.name):
                    layer.build(None)

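# Editor's note: a worked example, not part of the original module, of the layer-to-group
# mapping computed in TFAlbertTransformer.call above. With num_hidden_layers=12 and
# num_hidden_groups=1 (the standard ALBERT configuration) every layer maps to group 0,
# i.e. one set of weights is run 12 times. Integer division below matches the
# int(i / (layers / groups)) expression whenever layers divide evenly into groups.
def _sketch_group_index(num_hidden_layers: int = 12, num_hidden_groups: int = 1) -> list[int]:
    layers_per_group = num_hidden_layers // num_hidden_groups
    return [i // layers_per_group for i in range(num_hidden_layers)]  # e.g. [0, 0, ..., 0]
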
class TFAlbertPreTrainedModel(TFPreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = AlbertConfig
    base_model_prefix = "albert"


class TFAlbertMLMHead(keras.layers.Layer):
    def __init__(self, config: AlbertConfig, input_embeddings: keras.layers.Layer, **kwargs):
        super().__init__(**kwargs)

        self.config = config
        self.embedding_size = config.embedding_size
        self.dense = keras.layers.Dense(
            config.embedding_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
        )
        if isinstance(config.hidden_act, str):
            self.activation = get_tf_activation(config.hidden_act)
        else:
            self.activation = config.hidden_act

        self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")

        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.decoder = input_embeddings

    def build(self, input_shape=None):
        self.bias = self.add_weight(shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="bias")
        self.decoder_bias = self.add_weight(
            shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="decoder/bias"
        )

        if self.built:
            return
        self.built = True
        if getattr(self, "dense", None) is not None:
            with tf.name_scope(self.dense.name):
                self.dense.build([None, None, self.config.hidden_size])
        if getattr(self, "LayerNorm", None) is not None:
            with tf.name_scope(self.LayerNorm.name):
                self.LayerNorm.build([None, None, self.config.embedding_size])

    def get_output_embeddings(self) -> keras.layers.Layer:
        return self.decoder

    def set_output_embeddings(self, value: tf.Variable):
        self.decoder.weight = value
        self.decoder.vocab_size = shape_list(value)[0]

    def get_bias(self) -> dict[str, tf.Variable]:
        return {"bias": self.bias, "decoder_bias": self.decoder_bias}

    def set_bias(self, value: tf.Variable):
        self.bias = value["bias"]
        self.decoder_bias = value["decoder_bias"]
        self.config.vocab_size = shape_list(value["bias"])[0]

    def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
        hidden_states = self.dense(inputs=hidden_states)
        hidden_states = self.activation(hidden_states)
        hidden_states = self.LayerNorm(inputs=hidden_states)
        seq_length = shape_list(tensor=hidden_states)[1]
        hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, self.embedding_size])
        hidden_states = tf.matmul(a=hidden_states, b=self.decoder.weight, transpose_b=True)
        hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, seq_length, self.config.vocab_size])
        hidden_states = tf.nn.bias_add(value=hidden_states, bias=self.decoder_bias)

        return hidden_states

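# Editor's note: an illustrative sketch, not part of the original module, of the weight tying
# performed in TFAlbertMLMHead.call above: vocabulary logits come from multiplying with the
# *transposed* input embedding matrix instead of a separate output projection, which is what
# lets the decoder reuse the embedding weights. The helper name is hypothetical.
def _sketch_tied_logits(hidden: tf.Tensor, embedding_matrix: tf.Tensor) -> tf.Tensor:
    # hidden: (batch * seq, embedding_size); embedding_matrix: (vocab_size, embedding_size)
    return tf.matmul(hidden, embedding_matrix, transpose_b=True)  # (batch * seq, vocab_size)
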
@keras_serializable
class TFAlbertMainLayer(keras.layers.Layer):
    config_class = AlbertConfig

    def __init__(self, config: AlbertConfig, add_pooling_layer: bool = True, **kwargs):
        super().__init__(**kwargs)

        self.config = config

        self.embeddings = TFAlbertEmbeddings(config, name="embeddings")
        self.encoder = TFAlbertTransformer(config, name="encoder")
        self.pooler = (
            keras.layers.Dense(
                units=config.hidden_size,
                kernel_initializer=get_initializer(config.initializer_range),
                activation="tanh",
                name="pooler",
            )
            if add_pooling_layer
            else None
        )

    def get_input_embeddings(self) -> keras.layers.Layer:
        return self.embeddings

    def set_input_embeddings(self, value: tf.Variable):
        self.embeddings.weight = value
        self.embeddings.vocab_size = shape_list(value)[0]

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        raise NotImplementedError

    @unpack_inputs
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        token_type_ids: np.ndarray | tf.Tensor | None = None,
        position_ids: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        training: bool = False,
    ) -> TFBaseModelOutputWithPooling | tuple[tf.Tensor]:
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = shape_list(input_ids)
        elif inputs_embeds is not None:
            input_shape = shape_list(inputs_embeds)[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if attention_mask is None:
            attention_mask = tf.fill(dims=input_shape, value=1)

        if token_type_ids is None:
            token_type_ids = tf.fill(dims=input_shape, value=0)

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
            training=training,
        )

        # We create a 3D attention mask from a 2D tensor mask.
        # Sizes are [batch_size, 1, 1, to_seq_length]
        # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
        # This attention mask is simpler than the triangular masking of causal attention
        # used in OpenAI GPT; we just need to prepare the broadcast dimension here.
        extended_attention_mask = tf.reshape(attention_mask, (input_shape[0], 1, 1, input_shape[1]))

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = tf.cast(extended_attention_mask, dtype=embedding_output.dtype)
        one_cst = tf.constant(1.0, dtype=embedding_output.dtype)
        ten_thousand_cst = tf.constant(-10000.0, dtype=embedding_output.dtype)
        extended_attention_mask = tf.multiply(tf.subtract(one_cst, extended_attention_mask), ten_thousand_cst)

        # Prepare head mask if needed
        # 1.0 in head_mask indicates we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        if head_mask is not None:
            raise NotImplementedError
        else:
            head_mask = [None] * self.config.num_hidden_layers

        encoder_outputs = self.encoder(
            hidden_states=embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(inputs=sequence_output[:, 0]) if self.pooler is not None else None

        if not return_dict:
            return (
                sequence_output,
                pooled_output,
            ) + encoder_outputs[1:]

        return TFBaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "embeddings", None) is not None:
            with tf.name_scope(self.embeddings.name):
                self.embeddings.build(None)
        if getattr(self, "encoder", None) is not None:
            with tf.name_scope(self.encoder.name):
                self.encoder.build(None)
        if getattr(self, "pooler", None) is not None:
            with tf.name_scope(self.pooler.name):
                self.pooler.build([None, None, self.config.hidden_size])

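# Editor's note: a minimal numeric sketch, not part of the original module, of the additive
# attention mask built in TFAlbertMainLayer.call above: 1s (attend) become 0.0 and 0s (padding)
# become -10000.0, so padded positions vanish after the softmax. The helper name is hypothetical.
def _sketch_extended_mask(attention_mask: tf.Tensor) -> tf.Tensor:
    # attention_mask: (batch, seq) of 0/1 -> (batch, 1, 1, seq) additive mask
    mask = tf.reshape(attention_mask, (shape_list(attention_mask)[0], 1, 1, -1))
    mask = tf.cast(mask, tf.float32)
    return (1.0 - mask) * -10000.0
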
@dataclass
class TFAlbertForPreTrainingOutput(ModelOutput):
    """
    Output type of [`TFAlbertForPreTraining`].

    Args:
        prediction_logits (`tf.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        sop_logits (`tf.Tensor` of shape `(batch_size, 2)`):
            Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
            before SoftMax).
        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
            `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    loss: tf.Tensor | None = None
    prediction_logits: tf.Tensor | None = None
    sop_logits: tf.Tensor | None = None
    hidden_states: tuple[tf.Tensor] | None = None
    attentions: tuple[tf.Tensor] | None = None


ALBERT_START_DOCSTRING = r"""

    This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
    as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matters related to general usage and
    behavior.

    <Tip>

    TensorFlow models and layers in `transformers` accept two formats as input:

    - having all inputs as keyword arguments (like PyTorch models), or
    - having all inputs as a list, tuple or dict in the first positional argument.

    The reason the second format is supported is that Keras methods prefer this format when passing inputs to models
    and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just
    pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second
    format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with
    the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first
    positional argument:

    - a single Tensor with `input_ids` only and nothing else: `model(input_ids)`
    - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
    `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])`
    - a dictionary with one or several input Tensors associated to the input names given in the docstring:
    `model({"input_ids": input_ids, "token_type_ids": token_type_ids})`

    Note that when creating models and layers with
    [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry
    about any of this, as you can just pass inputs like you would to any other Python function!

    </Tip>

    Args:
        config ([`AlbertConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

ALBERT_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`Numpy array` or `tf.Tensor` of shape `({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and
            [`PreTrainedTokenizer.encode`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`Numpy array` or `tf.Tensor` of shape `({0})`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        token_type_ids (`Numpy array` or `tf.Tensor` of shape `({0})`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`Numpy array` or `tf.Tensor` of shape `({0})`, *optional*):
            Indices of positions of each input sequence token in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        head_mask (`Numpy array` or `tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        inputs_embeds (`tf.Tensor` of shape `({0}, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail. This argument can be used only in eager mode; in graph mode the value in the
            config will be used instead.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail. This argument can be used only in eager mode; in graph mode the value in the config will be
            used instead.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in
            eager mode; in graph mode the value will always be set to True.
        training (`bool`, *optional*, defaults to `False`):
            Whether or not to use the model in training mode (some modules like dropout modules have different
            behaviors between training and evaluation).
"""

@add_start_docstrings(
    "The bare Albert Model transformer outputting raw hidden-states without any specific head on top.",
    ALBERT_START_DOCSTRING,
)
class TFAlbertModel(TFAlbertPreTrainedModel):
    def __init__(self, config: AlbertConfig, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        self.albert = TFAlbertMainLayer(config, name="albert")

    @unpack_inputs
    @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TFBaseModelOutputWithPooling,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        token_type_ids: np.ndarray | tf.Tensor | None = None,
        position_ids: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        training: bool | None = False,
    ) -> TFBaseModelOutputWithPooling | tuple[tf.Tensor]:
        outputs = self.albert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        return outputs

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "albert", None) is not None:
            with tf.name_scope(self.albert.name):
                self.albert.build(None)

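# Editor's note: a hedged end-to-end usage sketch, not part of the original module. The
# checkpoint name follows the examples already shown in the docstrings of this file; the
# function name is hypothetical, and the output shapes assume the base (hidden_size=768) config.
def _sketch_encode_sentence():
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("albert/albert-base-v2")
    model = TFAlbertModel.from_pretrained("albert/albert-base-v2")
    inputs = tokenizer("ALBERT shares parameters across layers.", return_tensors="tf")
    outputs = model(**inputs)
    return outputs.last_hidden_state, outputs.pooler_output  # (1, seq, 768), (1, 768)
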
@add_start_docstrings(
    """
    Albert Model with two heads on top for pretraining: a `masked language modeling` head and a `sentence order
    prediction` (classification) head.
    """,
    ALBERT_START_DOCSTRING,
)
class TFAlbertForPreTraining(TFAlbertPreTrainedModel, TFAlbertPreTrainingLoss):
    # names with a '.' represent the authorized unexpected/missing layers when a TF model is loaded from a PT model
    _keys_to_ignore_on_load_unexpected = [r"predictions.decoder.weight"]

    def __init__(self, config: AlbertConfig, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        self.num_labels = config.num_labels

        self.albert = TFAlbertMainLayer(config, name="albert")
        self.predictions = TFAlbertMLMHead(config, input_embeddings=self.albert.embeddings, name="predictions")
        self.sop_classifier = TFAlbertSOPHead(config, name="sop_classifier")

    def get_lm_head(self) -> keras.layers.Layer:
        return self.predictions

    @unpack_inputs
    @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=TFAlbertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        token_type_ids: np.ndarray | tf.Tensor | None = None,
        position_ids: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        labels: np.ndarray | tf.Tensor | None = None,
        sentence_order_label: np.ndarray | tf.Tensor | None = None,
        training: bool | None = False,
    ) -> TFAlbertForPreTrainingOutput | tuple[tf.Tensor]:
        r"""
        Return:

        Example:

        ```python
        >>> import tensorflow as tf
        >>> from transformers import AutoTokenizer, TFAlbertForPreTraining

        >>> tokenizer = AutoTokenizer.from_pretrained("albert/albert-base-v2")
        >>> model = TFAlbertForPreTraining.from_pretrained("albert/albert-base-v2")

        >>> input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[None, :]
        >>> # Batch size 1
        >>> outputs = model(input_ids)

        >>> prediction_logits = outputs.prediction_logits
        >>> sop_logits = outputs.sop_logits
        ```"""

        outputs = self.albert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )
        sequence_output, pooled_output = outputs[:2]
        prediction_scores = self.predictions(hidden_states=sequence_output)
        sop_scores = self.sop_classifier(pooled_output=pooled_output, training=training)
        total_loss = None

        if labels is not None and sentence_order_label is not None:
            d_labels = {"labels": labels}
            d_labels["sentence_order_label"] = sentence_order_label
            total_loss = self.hf_compute_loss(labels=d_labels, logits=(prediction_scores, sop_scores))

        if not return_dict:
            output = (prediction_scores, sop_scores) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        return TFAlbertForPreTrainingOutput(
            loss=total_loss,
            prediction_logits=prediction_scores,
            sop_logits=sop_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "albert", None) is not None:
            with tf.name_scope(self.albert.name):
                self.albert.build(None)
        if getattr(self, "predictions", None) is not None:
            with tf.name_scope(self.predictions.name):
                self.predictions.build(None)
        if getattr(self, "sop_classifier", None) is not None:
            with tf.name_scope(self.sop_classifier.name):
                self.sop_classifier.build(None)

class TFAlbertSOPHead(keras.layers.Layer):
    def __init__(self, config: AlbertConfig, **kwargs):
        super().__init__(**kwargs)

        self.dropout = keras.layers.Dropout(rate=config.classifier_dropout_prob)
        self.classifier = keras.layers.Dense(
            units=config.num_labels,
            kernel_initializer=get_initializer(config.initializer_range),
            name="classifier",
        )
        self.config = config

    def call(self, pooled_output: tf.Tensor, training: bool) -> tf.Tensor:
        dropout_pooled_output = self.dropout(inputs=pooled_output, training=training)
        logits = self.classifier(inputs=dropout_pooled_output)

        return logits

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "classifier", None) is not None:
            with tf.name_scope(self.classifier.name):
                self.classifier.build([None, None, self.config.hidden_size])

@add_start_docstrings("""Albert Model with a `language modeling` head on top.""", ALBERT_START_DOCSTRING)
|
| 1048 |
+
class TFAlbertForMaskedLM(TFAlbertPreTrainedModel, TFMaskedLanguageModelingLoss):
|
| 1049 |
+
# names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model
|
| 1050 |
+
_keys_to_ignore_on_load_unexpected = [r"pooler", r"predictions.decoder.weight"]
|
| 1051 |
+
|
| 1052 |
+
def __init__(self, config: AlbertConfig, *inputs, **kwargs):
|
| 1053 |
+
super().__init__(config, *inputs, **kwargs)
|
| 1054 |
+
|
| 1055 |
+
self.albert = TFAlbertMainLayer(config, add_pooling_layer=False, name="albert")
|
| 1056 |
+
self.predictions = TFAlbertMLMHead(config, input_embeddings=self.albert.embeddings, name="predictions")
|
| 1057 |
+
|
| 1058 |
+
def get_lm_head(self) -> keras.layers.Layer:
|
| 1059 |
+
return self.predictions
|
| 1060 |
+
|
| 1061 |
+
@unpack_inputs
|
| 1062 |
+
@add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
| 1063 |
+
@replace_return_docstrings(output_type=TFMaskedLMOutput, config_class=_CONFIG_FOR_DOC)
|
| 1064 |
+
def call(
|
| 1065 |
+
self,
|
| 1066 |
+
input_ids: TFModelInputType | None = None,
|
| 1067 |
+
attention_mask: np.ndarray | tf.Tensor | None = None,
|
| 1068 |
+
token_type_ids: np.ndarray | tf.Tensor | None = None,
|
| 1069 |
+
position_ids: np.ndarray | tf.Tensor | None = None,
|
| 1070 |
+
head_mask: np.ndarray | tf.Tensor | None = None,
|
| 1071 |
+
inputs_embeds: np.ndarray | tf.Tensor | None = None,
|
| 1072 |
+
output_attentions: bool | None = None,
|
| 1073 |
+
output_hidden_states: bool | None = None,
|
| 1074 |
+
return_dict: bool | None = None,
|
| 1075 |
+
labels: np.ndarray | tf.Tensor | None = None,
|
| 1076 |
+
training: bool | None = False,
|
| 1077 |
+
) -> TFMaskedLMOutput | tuple[tf.Tensor]:
|
| 1078 |
+
r"""
|
| 1079 |
+
labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 1080 |
+
Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
|
| 1081 |
+
config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
|
| 1082 |
+
loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
|
| 1083 |
+
|
| 1084 |
+
Returns:
|
| 1085 |
+
|
| 1086 |
+
Example:
|
| 1087 |
+
|
| 1088 |
+
```python
|
| 1089 |
+
>>> import tensorflow as tf
|
| 1090 |
+
>>> from transformers import AutoTokenizer, TFAlbertForMaskedLM
|
| 1091 |
+
|
| 1092 |
+
>>> tokenizer = AutoTokenizer.from_pretrained("albert/albert-base-v2")
|
| 1093 |
+
>>> model = TFAlbertForMaskedLM.from_pretrained("albert/albert-base-v2")
|
| 1094 |
+
|
| 1095 |
+
>>> # add mask_token
|
| 1096 |
+
>>> inputs = tokenizer(f"The capital of [MASK] is Paris.", return_tensors="tf")
|
| 1097 |
+
>>> logits = model(**inputs).logits
|
| 1098 |
+
|
| 1099 |
+
>>> # retrieve index of [MASK]
|
| 1100 |
+
>>> mask_token_index = tf.where(inputs.input_ids == tokenizer.mask_token_id)[0][1]
|
| 1101 |
+
>>> predicted_token_id = tf.math.argmax(logits[0, mask_token_index], axis=-1)
|
| 1102 |
+
>>> tokenizer.decode(predicted_token_id)
|
| 1103 |
+
'france'
|
| 1104 |
+
```
|
| 1105 |
+
|
| 1106 |
+
```python
|
| 1107 |
+
>>> labels = tokenizer("The capital of France is Paris.", return_tensors="tf")["input_ids"]
|
| 1108 |
+
>>> labels = tf.where(inputs.input_ids == tokenizer.mask_token_id, labels, -100)
|
| 1109 |
+
>>> outputs = model(**inputs, labels=labels)
|
| 1110 |
+
>>> round(float(outputs.loss), 2)
|
| 1111 |
+
0.81
|
| 1112 |
+
```
|
| 1113 |
+
"""
|
| 1114 |
+
outputs = self.albert(
|
| 1115 |
+
input_ids=input_ids,
|
| 1116 |
+
attention_mask=attention_mask,
|
| 1117 |
+
token_type_ids=token_type_ids,
|
| 1118 |
+
position_ids=position_ids,
|
| 1119 |
+
head_mask=head_mask,
|
| 1120 |
+
inputs_embeds=inputs_embeds,
|
| 1121 |
+
output_attentions=output_attentions,
|
| 1122 |
+
output_hidden_states=output_hidden_states,
|
| 1123 |
+
return_dict=return_dict,
|
| 1124 |
+
training=training,
|
| 1125 |
+
)
|
| 1126 |
+
sequence_output = outputs[0]
|
| 1127 |
+
prediction_scores = self.predictions(hidden_states=sequence_output, training=training)
|
| 1128 |
+
loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=prediction_scores)
|
| 1129 |
+
|
| 1130 |
+
if not return_dict:
|
| 1131 |
+
output = (prediction_scores,) + outputs[2:]
|
| 1132 |
+
|
| 1133 |
+
return ((loss,) + output) if loss is not None else output
|
| 1134 |
+
|
| 1135 |
+
return TFMaskedLMOutput(
|
| 1136 |
+
loss=loss,
|
| 1137 |
+
logits=prediction_scores,
|
| 1138 |
+
hidden_states=outputs.hidden_states,
|
| 1139 |
+
attentions=outputs.attentions,
|
| 1140 |
+
)
|
| 1141 |
+
|
| 1142 |
+
def build(self, input_shape=None):
|
| 1143 |
+
if self.built:
|
| 1144 |
+
return
|
| 1145 |
+
self.built = True
|
| 1146 |
+
if getattr(self, "albert", None) is not None:
|
| 1147 |
+
with tf.name_scope(self.albert.name):
|
| 1148 |
+
self.albert.build(None)
|
| 1149 |
+
if getattr(self, "predictions", None) is not None:
|
| 1150 |
+
with tf.name_scope(self.predictions.name):
|
| 1151 |
+
self.predictions.build(None)
|
| 1152 |
+
|
| 1153 |
+
|
| 1154 |
+
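# Editor's note: an illustrative re-derivation, not part of the original module, of the
# masked-LM loss convention documented above: label positions set to -100 are excluded,
# so only the masked tokens contribute to the loss. The helper name is hypothetical.
def _sketch_mlm_loss(logits: tf.Tensor, labels: tf.Tensor) -> tf.Tensor:
    # logits: (batch, seq, vocab); labels: (batch, seq) with -100 at unmasked positions
    active = tf.not_equal(labels, -100)
    loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    return loss_fn(tf.boolean_mask(labels, active), tf.boolean_mask(logits, active))
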
@add_start_docstrings(
    """
    Albert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
    output) e.g. for GLUE tasks.
    """,
    ALBERT_START_DOCSTRING,
)
class TFAlbertForSequenceClassification(TFAlbertPreTrainedModel, TFSequenceClassificationLoss):
    # names with a '.' represent the authorized unexpected/missing layers when a TF model is loaded from a PT model
    _keys_to_ignore_on_load_unexpected = [r"predictions"]
    _keys_to_ignore_on_load_missing = [r"dropout"]

    def __init__(self, config: AlbertConfig, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        self.num_labels = config.num_labels

        self.albert = TFAlbertMainLayer(config, name="albert")
        self.dropout = keras.layers.Dropout(rate=config.classifier_dropout_prob)
        self.classifier = keras.layers.Dense(
            units=config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
        )
        self.config = config

    @unpack_inputs
    @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint="vumichien/albert-base-v2-imdb",
        output_type=TFSequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
        expected_output="'LABEL_1'",
        expected_loss=0.12,
    )
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        token_type_ids: np.ndarray | tf.Tensor | None = None,
        position_ids: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        labels: np.ndarray | tf.Tensor | None = None,
        training: bool | None = False,
    ) -> TFSequenceClassifierOutput | tuple[tf.Tensor]:
        r"""
        labels (`tf.Tensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss); if
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        outputs = self.albert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )
        pooled_output = outputs[1]
        pooled_output = self.dropout(inputs=pooled_output, training=training)
        logits = self.classifier(inputs=pooled_output)
        loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits)

        if not return_dict:
            output = (logits,) + outputs[2:]

            return ((loss,) + output) if loss is not None else output

        return TFSequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "albert", None) is not None:
            with tf.name_scope(self.albert.name):
                self.albert.build(None)
        if getattr(self, "classifier", None) is not None:
            with tf.name_scope(self.classifier.name):
                self.classifier.build([None, None, self.config.hidden_size])

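# Editor's note: a hedged fine-tuning sketch, not part of the original module. Because the
# model above returns a loss when `labels` are supplied, recent versions of `transformers`
# let you compile with only an optimizer and fall back to that internal loss. The function
# name is hypothetical and `train_dataset` is a stand-in for a tf.data.Dataset of
# (features, labels) batches.
def _sketch_finetune(train_dataset):
    model = TFAlbertForSequenceClassification.from_pretrained("albert/albert-base-v2", num_labels=2)
    model.compile(optimizer=keras.optimizers.Adam(learning_rate=2e-5))  # internal loss is used
    model.fit(train_dataset, epochs=1)
    return model
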
@add_start_docstrings(
    """
    Albert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
    Named-Entity-Recognition (NER) tasks.
    """,
    ALBERT_START_DOCSTRING,
)
class TFAlbertForTokenClassification(TFAlbertPreTrainedModel, TFTokenClassificationLoss):
    # names with a '.' represent the authorized unexpected/missing layers when a TF model is loaded from a PT model
    _keys_to_ignore_on_load_unexpected = [r"pooler", r"predictions"]
    _keys_to_ignore_on_load_missing = [r"dropout"]

    def __init__(self, config: AlbertConfig, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        self.num_labels = config.num_labels

        self.albert = TFAlbertMainLayer(config, add_pooling_layer=False, name="albert")
        classifier_dropout_prob = (
            config.classifier_dropout_prob
            if config.classifier_dropout_prob is not None
            else config.hidden_dropout_prob
        )
        self.dropout = keras.layers.Dropout(rate=classifier_dropout_prob)
        self.classifier = keras.layers.Dense(
            units=config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
        )
        self.config = config

    @unpack_inputs
    @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TFTokenClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        token_type_ids: np.ndarray | tf.Tensor | None = None,
        position_ids: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        labels: np.ndarray | tf.Tensor | None = None,
        training: bool | None = False,
    ) -> TFTokenClassifierOutput | tuple[tf.Tensor]:
        r"""
        labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        """
        outputs = self.albert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )
        sequence_output = outputs[0]
        sequence_output = self.dropout(inputs=sequence_output, training=training)
        logits = self.classifier(inputs=sequence_output)
        loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits)

        if not return_dict:
            output = (logits,) + outputs[2:]

            return ((loss,) + output) if loss is not None else output

        return TFTokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "albert", None) is not None:
            with tf.name_scope(self.albert.name):
                self.albert.build(None)
        if getattr(self, "classifier", None) is not None:
            with tf.name_scope(self.classifier.name):
                self.classifier.build([None, None, self.config.hidden_size])

@add_start_docstrings(
    """
    Albert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
    layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    ALBERT_START_DOCSTRING,
)
class TFAlbertForQuestionAnswering(TFAlbertPreTrainedModel, TFQuestionAnsweringLoss):
    # names with a '.' represent the authorized unexpected/missing layers when a TF model is loaded from a PT model
    _keys_to_ignore_on_load_unexpected = [r"pooler", r"predictions"]

    def __init__(self, config: AlbertConfig, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        self.num_labels = config.num_labels

        self.albert = TFAlbertMainLayer(config, add_pooling_layer=False, name="albert")
        self.qa_outputs = keras.layers.Dense(
            units=config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
        )
        self.config = config

    @unpack_inputs
    @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint="vumichien/albert-base-v2-squad2",
        output_type=TFQuestionAnsweringModelOutput,
        config_class=_CONFIG_FOR_DOC,
        qa_target_start_index=12,
        qa_target_end_index=13,
        expected_output="'a nice puppet'",
        expected_loss=7.36,
    )
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        token_type_ids: np.ndarray | tf.Tensor | None = None,
        position_ids: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        start_positions: np.ndarray | tf.Tensor | None = None,
        end_positions: np.ndarray | tf.Tensor | None = None,
        training: bool | None = False,
    ) -> TFQuestionAnsweringModelOutput | tuple[tf.Tensor]:
        r"""
        start_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
            are not taken into account for computing the loss.
        end_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
            are not taken into account for computing the loss.
        """
        outputs = self.albert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )
        sequence_output = outputs[0]
        logits = self.qa_outputs(inputs=sequence_output)
        start_logits, end_logits = tf.split(value=logits, num_or_size_splits=2, axis=-1)
        start_logits = tf.squeeze(input=start_logits, axis=-1)
        end_logits = tf.squeeze(input=end_logits, axis=-1)
        loss = None

        if start_positions is not None and end_positions is not None:
            labels = {"start_position": start_positions}
            labels["end_position"] = end_positions
            loss = self.hf_compute_loss(labels=labels, logits=(start_logits, end_logits))

        if not return_dict:
            output = (start_logits, end_logits) + outputs[2:]

            return ((loss,) + output) if loss is not None else output

        return TFQuestionAnsweringModelOutput(
            loss=loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "albert", None) is not None:
            with tf.name_scope(self.albert.name):
                self.albert.build(None)
        if getattr(self, "qa_outputs", None) is not None:
            with tf.name_scope(self.qa_outputs.name):
                self.qa_outputs.build([None, None, self.config.hidden_size])

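# Editor's note: an illustrative sketch, not part of the original module, of turning the
# start/end logits produced above back into an answer span for the first item in a batch
# (eager mode assumed; the helper name is hypothetical).
def _sketch_extract_span(start_logits: tf.Tensor, end_logits: tf.Tensor, input_ids: tf.Tensor) -> tf.Tensor:
    start = int(tf.argmax(start_logits, axis=-1)[0])
    end = int(tf.argmax(end_logits, axis=-1)[0])
    return input_ids[0, start : end + 1]  # token ids of the predicted answer span
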
@add_start_docstrings(
    """
    Albert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
    softmax) e.g. for RocStories/SWAG tasks.
    """,
    ALBERT_START_DOCSTRING,
)
class TFAlbertForMultipleChoice(TFAlbertPreTrainedModel, TFMultipleChoiceLoss):
    # names with a '.' represent the authorized unexpected/missing layers when a TF model is loaded from a PT model
    _keys_to_ignore_on_load_unexpected = [r"pooler", r"predictions"]
    _keys_to_ignore_on_load_missing = [r"dropout"]

    def __init__(self, config: AlbertConfig, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        self.albert = TFAlbertMainLayer(config, name="albert")
        self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
        self.classifier = keras.layers.Dense(
            units=1, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
        )
        self.config = config

    @unpack_inputs
    @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TFMultipleChoiceModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        token_type_ids: np.ndarray | tf.Tensor | None = None,
        position_ids: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        labels: np.ndarray | tf.Tensor | None = None,
        training: bool | None = False,
    ) -> TFMultipleChoiceModelOutput | tuple[tf.Tensor]:
        r"""
        labels (`tf.Tensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the multiple choice classification loss. Indices should be in
            `[0, ..., num_choices - 1]` where `num_choices` is the size of the second dimension of the input tensors.
            (See `input_ids` above)
        """

        if input_ids is not None:
            num_choices = shape_list(input_ids)[1]
            seq_length = shape_list(input_ids)[2]
        else:
            num_choices = shape_list(inputs_embeds)[1]
            seq_length = shape_list(inputs_embeds)[2]

        flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
        flat_attention_mask = (
            tf.reshape(tensor=attention_mask, shape=(-1, seq_length)) if attention_mask is not None else None
        )
        flat_token_type_ids = (
            tf.reshape(tensor=token_type_ids, shape=(-1, seq_length)) if token_type_ids is not None else None
        )
        flat_position_ids = (
            tf.reshape(tensor=position_ids, shape=(-1, seq_length)) if position_ids is not None else None
        )
        flat_inputs_embeds = (
            tf.reshape(tensor=inputs_embeds, shape=(-1, seq_length, shape_list(inputs_embeds)[3]))
            if inputs_embeds is not None
            else None
        )
        outputs = self.albert(
            input_ids=flat_input_ids,
            attention_mask=flat_attention_mask,
            token_type_ids=flat_token_type_ids,
            position_ids=flat_position_ids,
            head_mask=head_mask,
            inputs_embeds=flat_inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )
        pooled_output = outputs[1]
        pooled_output = self.dropout(inputs=pooled_output, training=training)
        logits = self.classifier(inputs=pooled_output)
        reshaped_logits = tf.reshape(tensor=logits, shape=(-1, num_choices))
        loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=reshaped_logits)

        if not return_dict:
            output = (reshaped_logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TFMultipleChoiceModelOutput(
            loss=loss,
            logits=reshaped_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "albert", None) is not None:
            with tf.name_scope(self.albert.name):
                self.albert.build(None)
        if getattr(self, "classifier", None) is not None:
            with tf.name_scope(self.classifier.name):
                self.classifier.build([None, None, self.config.hidden_size])

__all__ = [
|
| 1563 |
+
"TFAlbertPreTrainedModel",
|
| 1564 |
+
"TFAlbertModel",
|
| 1565 |
+
"TFAlbertForPreTraining",
|
| 1566 |
+
"TFAlbertForMaskedLM",
|
| 1567 |
+
"TFAlbertForSequenceClassification",
|
| 1568 |
+
"TFAlbertForTokenClassification",
|
| 1569 |
+
"TFAlbertForQuestionAnswering",
|
| 1570 |
+
"TFAlbertForMultipleChoice",
|
| 1571 |
+
"TFAlbertMainLayer",
|
| 1572 |
+
]
|
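A minimal usage sketch for the multiple-choice head above, under stated assumptions: TensorFlow is installed and the public `albert-base-v2` checkpoint is reachable (the checkpoint name, prompt, and choices are illustrative, not taken from this diff). It encodes each (prompt, choice) pair and adds the `(batch_size, num_choices, seq_len)` axis that `call` flattens internally:

```python
# Sketch only: checkpoint name and inputs are assumptions for illustration.
import tensorflow as tf
from transformers import AlbertTokenizer, TFAlbertForMultipleChoice

tokenizer = AlbertTokenizer.from_pretrained("albert-base-v2")
model = TFAlbertForMultipleChoice.from_pretrained("albert-base-v2")

prompt = "The sky is"
choices = ["blue.", "made of cheese."]
# One (prompt, choice) pair per candidate answer, padded to a common length.
enc = tokenizer([prompt] * len(choices), choices, return_tensors="tf", padding=True)
# Add the num_choices axis: (batch_size=1, num_choices=2, seq_len).
batch = {k: tf.expand_dims(v, 0) for k, v in enc.items()}
logits = model(batch).logits  # shape (1, 2): one score per choice
print(int(tf.argmax(logits, axis=-1)[0]))
```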
venv/lib/python3.13/site-packages/transformers/models/albert/tokenization_albert.py
ADDED
@@ -0,0 +1,320 @@
# coding=utf-8
# Copyright 2018 Google AI, Google Brain and the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for ALBERT model."""

import os
import unicodedata
from shutil import copyfile
from typing import Any, Optional

import sentencepiece as spm

from ...tokenization_utils import AddedToken, PreTrainedTokenizer
from ...utils import logging
from ...utils.import_utils import requires


logger = logging.get_logger(__name__)
VOCAB_FILES_NAMES = {"vocab_file": "spiece.model"}


SPIECE_UNDERLINE = "▁"


@requires(backends=("sentencepiece",))
class AlbertTokenizer(PreTrainedTokenizer):
    """
    Construct an ALBERT tokenizer. Based on [SentencePiece](https://github.com/google/sentencepiece).

    This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
    this superclass for more information regarding those methods.

    Args:
        vocab_file (`str`):
            [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that
            contains the vocabulary necessary to instantiate a tokenizer.
        do_lower_case (`bool`, *optional*, defaults to `True`):
            Whether or not to lowercase the input when tokenizing.
        remove_space (`bool`, *optional*, defaults to `True`):
            Whether or not to strip the text when tokenizing (removing excess spaces before and after the string).
        keep_accents (`bool`, *optional*, defaults to `False`):
            Whether or not to keep accents when tokenizing.
        bos_token (`str`, *optional*, defaults to `"[CLS]"`):
            The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier
            token.

            <Tip>

            When building a sequence using special tokens, this is not the token that is used for the beginning of
            sequence. The token used is the `cls_token`.

            </Tip>

        eos_token (`str`, *optional*, defaults to `"[SEP]"`):
            The end of sequence token.

            <Tip>

            When building a sequence using special tokens, this is not the token that is used for the end of sequence.
            The token used is the `sep_token`.

            </Tip>

        unk_token (`str`, *optional*, defaults to `"<unk>"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
        sep_token (`str`, *optional*, defaults to `"[SEP]"`):
            The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
            sequence classification or for a text and a question for question answering. It is also used as the last
            token of a sequence built with special tokens.
        pad_token (`str`, *optional*, defaults to `"<pad>"`):
            The token used for padding, for example when batching sequences of different lengths.
        cls_token (`str`, *optional*, defaults to `"[CLS]"`):
            The classifier token which is used when doing sequence classification (classification of the whole sequence
            instead of per-token classification). It is the first token of the sequence when built with special tokens.
        mask_token (`str`, *optional*, defaults to `"[MASK]"`):
            The token used for masking values. This is the token used when training this model with masked language
            modeling. This is the token which the model will try to predict.
        sp_model_kwargs (`dict`, *optional*):
            Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
            SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
            to set:

            - `enable_sampling`: Enable subword regularization.
            - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.

              - `nbest_size = {0,1}`: No sampling is performed.
              - `nbest_size > 1`: samples from the nbest_size results.
              - `nbest_size < 0`: assuming that nbest_size is infinite and samples from all hypotheses (lattice)
                using forward-filtering-and-backward-sampling algorithm.

            - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
              BPE-dropout.

    Attributes:
        sp_model (`SentencePieceProcessor`):
            The *SentencePiece* processor that is used for every conversion (string, tokens and IDs).
    """

    vocab_files_names = VOCAB_FILES_NAMES

    def __init__(
        self,
        vocab_file,
        do_lower_case=True,
        remove_space=True,
        keep_accents=False,
        bos_token="[CLS]",
        eos_token="[SEP]",
        unk_token="<unk>",
        sep_token="[SEP]",
        pad_token="<pad>",
        cls_token="[CLS]",
        mask_token="[MASK]",
        sp_model_kwargs: Optional[dict[str, Any]] = None,
        **kwargs,
    ) -> None:
        # The mask token behaves like a normal word, i.e. it includes the space before it and
        # is kept in the raw text, so there should be a match in a non-normalized sentence.
        mask_token = (
            AddedToken(mask_token, lstrip=True, rstrip=False, normalized=False)
            if isinstance(mask_token, str)
            else mask_token
        )

        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs

        self.do_lower_case = do_lower_case
        self.remove_space = remove_space
        self.keep_accents = keep_accents
        self.vocab_file = vocab_file

        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        self.sp_model.Load(vocab_file)

        super().__init__(
            do_lower_case=do_lower_case,
            remove_space=remove_space,
            keep_accents=keep_accents,
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            sep_token=sep_token,
            pad_token=pad_token,
            cls_token=cls_token,
            mask_token=mask_token,
            sp_model_kwargs=self.sp_model_kwargs,
            **kwargs,
        )

    @property
    def vocab_size(self) -> int:
        return len(self.sp_model)

    def get_vocab(self) -> dict[str, int]:
        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
        vocab.update(self.added_tokens_encoder)
        return vocab

    def __getstate__(self):
        state = self.__dict__.copy()
        state["sp_model"] = None
        return state

    def __setstate__(self, d):
        self.__dict__ = d

        # for backward compatibility
        if not hasattr(self, "sp_model_kwargs"):
            self.sp_model_kwargs = {}

        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        self.sp_model.Load(self.vocab_file)

    def preprocess_text(self, inputs):
        if self.remove_space:
            outputs = " ".join(inputs.strip().split())
        else:
            outputs = inputs
        outputs = outputs.replace("``", '"').replace("''", '"')

        if not self.keep_accents:
            outputs = unicodedata.normalize("NFKD", outputs)
            outputs = "".join([c for c in outputs if not unicodedata.combining(c)])
        if self.do_lower_case:
            outputs = outputs.lower()

        return outputs

    def _tokenize(self, text: str) -> list[str]:
        """Tokenize a string."""
        text = self.preprocess_text(text)
        pieces = self.sp_model.encode(text, out_type=str)
        new_pieces = []
        for piece in pieces:
            if len(piece) > 1 and piece[-1] == "," and piece[-2].isdigit():
                # Logic to handle special cases see https://github.com/google-research/bert/blob/master/README.md#tokenization
                # `9,9` -> ['▁9', ',', '9'] instead of [`_9,`, '9']
                cur_pieces = self.sp_model.EncodeAsPieces(piece[:-1].replace(SPIECE_UNDERLINE, ""))
                if piece[0] != SPIECE_UNDERLINE and cur_pieces[0][0] == SPIECE_UNDERLINE:
                    if len(cur_pieces[0]) == 1:
                        cur_pieces = cur_pieces[1:]
                    else:
                        cur_pieces[0] = cur_pieces[0][1:]
                cur_pieces.append(piece[-1])
                new_pieces.extend(cur_pieces)
            else:
                new_pieces.append(piece)

        return new_pieces

    def _convert_token_to_id(self, token):
        """Converts a token (str) to an id using the vocab."""
        return self.sp_model.PieceToId(token)

    def _convert_id_to_token(self, index):
        """Converts an index (integer) to a token (str) using the vocab."""
        return self.sp_model.IdToPiece(index)

    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (string) to a single string."""
        current_sub_tokens = []
        out_string = ""
        prev_is_special = False
        for token in tokens:
            # make sure that special tokens are not decoded using sentencepiece model
            if token in self.all_special_tokens:
                if not prev_is_special:
                    out_string += " "
                out_string += self.sp_model.decode(current_sub_tokens) + token
                prev_is_special = True
                current_sub_tokens = []
            else:
                current_sub_tokens.append(token)
                prev_is_special = False
        out_string += self.sp_model.decode(current_sub_tokens)
        return out_string.strip()

    def build_inputs_with_special_tokens(
        self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
    ) -> list[int]:
        """
        Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating
        and adding special tokens. An ALBERT sequence has the following format:

        - single sequence: `[CLS] X [SEP]`
        - pair of sequences: `[CLS] A [SEP] B [SEP]`

        Args:
            token_ids_0 (`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
        """
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]
        if token_ids_1 is None:
            return cls + token_ids_0 + sep
        return cls + token_ids_0 + sep + token_ids_1 + sep

    def get_special_tokens_mask(
        self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None, already_has_special_tokens: bool = False
    ) -> list[int]:
        """
        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer `prepare_for_model` method.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
                Whether or not the token list is already formatted with special tokens for the model.

        Returns:
            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """

        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
            )

        if token_ids_1 is not None:
            return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
        return [1] + ([0] * len(token_ids_0)) + [1]

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            return
        out_vocab_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
        )

        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
            copyfile(self.vocab_file, out_vocab_file)
        elif not os.path.isfile(self.vocab_file):
            with open(out_vocab_file, "wb") as fi:
                content_spiece_model = self.sp_model.serialized_model_proto()
                fi.write(content_spiece_model)

        return (out_vocab_file,)


__all__ = ["AlbertTokenizer"]
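A short sketch of the special-token layout documented in `build_inputs_with_special_tokens` above, assuming the public `albert-base-v2` SentencePiece checkpoint (the checkpoint name and the exact pieces shown in comments are illustrative):

```python
# Sketch only: checkpoint name and the commented outputs are assumptions.
from transformers import AlbertTokenizer

tok = AlbertTokenizer.from_pretrained("albert-base-v2")
ids = tok.build_inputs_with_special_tokens(tok.convert_tokens_to_ids(tok.tokenize("hello world")))
print(tok.convert_ids_to_tokens(ids))  # e.g. ['[CLS]', '▁hello', '▁world', '[SEP]']
# 1 marks [CLS]/[SEP], 0 marks ordinary sequence tokens:
print(tok.get_special_tokens_mask(ids, already_has_special_tokens=True))  # [1, 0, 0, 1]
```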
venv/lib/python3.13/site-packages/transformers/models/albert/tokenization_albert_fast.py
ADDED
@@ -0,0 +1,178 @@
# coding=utf-8
# Copyright 2018 Google AI, Google Brain and the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for ALBERT model."""

import os
from shutil import copyfile
from typing import Optional

from ...tokenization_utils import AddedToken
from ...tokenization_utils_fast import PreTrainedTokenizerFast
from ...utils import is_sentencepiece_available, logging


if is_sentencepiece_available():
    from .tokenization_albert import AlbertTokenizer
else:
    AlbertTokenizer = None

logger = logging.get_logger(__name__)
VOCAB_FILES_NAMES = {"vocab_file": "spiece.model", "tokenizer_file": "tokenizer.json"}


SPIECE_UNDERLINE = "▁"


class AlbertTokenizerFast(PreTrainedTokenizerFast):
    """
    Construct a "fast" ALBERT tokenizer (backed by HuggingFace's *tokenizers* library). Based on
    [Unigram](https://huggingface.co/docs/tokenizers/python/latest/components.html?highlight=unigram#models). This
    tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should refer to
    this superclass for more information regarding those methods.

    Args:
        vocab_file (`str`):
            [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that
            contains the vocabulary necessary to instantiate a tokenizer.
        do_lower_case (`bool`, *optional*, defaults to `True`):
            Whether or not to lowercase the input when tokenizing.
        remove_space (`bool`, *optional*, defaults to `True`):
            Whether or not to strip the text when tokenizing (removing excess spaces before and after the string).
        keep_accents (`bool`, *optional*, defaults to `False`):
            Whether or not to keep accents when tokenizing.
        bos_token (`str`, *optional*, defaults to `"[CLS]"`):
            The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier
            token.

            <Tip>

            When building a sequence using special tokens, this is not the token that is used for the beginning of
            sequence. The token used is the `cls_token`.

            </Tip>

        eos_token (`str`, *optional*, defaults to `"[SEP]"`):
            The end of sequence token. When building a sequence using special tokens, this is not the token
            that is used for the end of sequence. The token used is the `sep_token`.
        unk_token (`str`, *optional*, defaults to `"<unk>"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
        sep_token (`str`, *optional*, defaults to `"[SEP]"`):
            The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
            sequence classification or for a text and a question for question answering. It is also used as the last
            token of a sequence built with special tokens.
        pad_token (`str`, *optional*, defaults to `"<pad>"`):
            The token used for padding, for example when batching sequences of different lengths.
        cls_token (`str`, *optional*, defaults to `"[CLS]"`):
            The classifier token which is used when doing sequence classification (classification of the whole sequence
            instead of per-token classification). It is the first token of the sequence when built with special tokens.
        mask_token (`str`, *optional*, defaults to `"[MASK]"`):
            The token used for masking values. This is the token used when training this model with masked language
            modeling. This is the token which the model will try to predict.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    slow_tokenizer_class = AlbertTokenizer

    def __init__(
        self,
        vocab_file=None,
        tokenizer_file=None,
        do_lower_case=True,
        remove_space=True,
        keep_accents=False,
        bos_token="[CLS]",
        eos_token="[SEP]",
        unk_token="<unk>",
        sep_token="[SEP]",
        pad_token="<pad>",
        cls_token="[CLS]",
        mask_token="[MASK]",
        **kwargs,
    ):
        # The mask token behaves like a normal word, i.e. it includes the space before it and
        # is kept in the raw text, so there should be a match in a non-normalized sentence.
        mask_token = (
            AddedToken(mask_token, lstrip=True, rstrip=False, normalized=False)
            if isinstance(mask_token, str)
            else mask_token
        )

        super().__init__(
            vocab_file,
            tokenizer_file=tokenizer_file,
            do_lower_case=do_lower_case,
            remove_space=remove_space,
            keep_accents=keep_accents,
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            sep_token=sep_token,
            pad_token=pad_token,
            cls_token=cls_token,
            mask_token=mask_token,
            **kwargs,
        )

        self.do_lower_case = do_lower_case
        self.remove_space = remove_space
        self.keep_accents = keep_accents
        self.vocab_file = vocab_file

    def build_inputs_with_special_tokens(
        self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
    ) -> list[int]:
        """
        Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating
        and adding special tokens. An ALBERT sequence has the following format:

        - single sequence: `[CLS] X [SEP]`
        - pair of sequences: `[CLS] A [SEP] B [SEP]`

        Args:
            token_ids_0 (`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
        """
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]
        if token_ids_1 is None:
            return cls + token_ids_0 + sep
        return cls + token_ids_0 + sep + token_ids_1 + sep

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
        if not self.can_save_slow_tokenizer:
            raise ValueError(
                "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
                "tokenizer."
            )

        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            return
        out_vocab_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
        )

        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
            copyfile(self.vocab_file, out_vocab_file)

        return (out_vocab_file,)


__all__ = ["AlbertTokenizerFast"]
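A parity sketch for the fast tokenizer above, assuming both the `sentencepiece` and `tokenizers` backends are installed and the illustrative `albert-base-v2` checkpoint is available; the slow and fast implementations are expected to produce the same encodings:

```python
# Sketch only: checkpoint name is an assumption for illustration.
from transformers import AlbertTokenizer, AlbertTokenizerFast

slow = AlbertTokenizer.from_pretrained("albert-base-v2")
fast = AlbertTokenizerFast.from_pretrained("albert-base-v2")

text = "hello world"
print(slow(text)["input_ids"])
print(fast(text)["input_ids"])  # expected to match the slow output
# Both build [CLS] A [SEP] B [SEP] for pairs (ids here are placeholders):
print(fast.build_inputs_with_special_tokens([10, 11], [12]))
```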
venv/lib/python3.13/site-packages/transformers/models/apertus/__init__.py
ADDED
@@ -0,0 +1,32 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team and the Swiss AI Initiative. All rights reserved.
#
# This code is based on HuggingFace's LLaMA implementation in this library.
# It has been modified from its original forms to accommodate the architectural
# differences made by the Swiss AI Initiative that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    from .configuration_apertus import *
    from .modeling_apertus import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
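A sketch of what the lazy `__init__` above buys you: the package's module object is swapped for a `_LazyModule`, so the heavy submodules are only imported on first attribute access (behavior described here is an inference from the code above, not from this diff's documentation):

```python
# Sketch only: demonstrates the deferred-import behavior of the lazy module.
import transformers.models.apertus as apertus

print(type(apertus).__name__)       # "_LazyModule" -- the placeholder module
config_cls = apertus.ApertusConfig  # first access triggers the real import
print(config_cls.model_type)        # "apertus"
```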
venv/lib/python3.13/site-packages/transformers/models/apertus/configuration_apertus.py
ADDED
@@ -0,0 +1,214 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/apertus/modular_apertus.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_apertus.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2025 the HuggingFace Inc. team and the Swiss AI Initiative. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ...configuration_utils import PretrainedConfig
from ...modeling_rope_utils import rope_config_validation


class ApertusConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`ApertusModel`]. It is used to instantiate an
    Apertus model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of the Apertus-8B,
    e.g. [swiss-ai/Apertus-8B](https://huggingface.co/swiss-ai/Apertus-8B).

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        vocab_size (`int`, *optional*, defaults to 131072):
            Vocabulary size of the Apertus model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`ApertusModel`].
        hidden_size (`int`, *optional*, defaults to 4096):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 14336):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 32):
            Number of hidden layers in the Transformer decoder.
        num_attention_heads (`int`, *optional*, defaults to 32):
            Number of attention heads for each attention layer in the Transformer decoder.
        num_key_value_heads (`int`, *optional*):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA); if
            `num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by meanpooling all the original heads within that group. For more details, check out [this
            paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
            `num_attention_heads`.
        hidden_act (`str` or `function`, *optional*, defaults to `"xielu"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 65536):
            The maximum sequence length that this model might ever be used with. Apertus supports up to 65536 tokens.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-05):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        pad_token_id (`int`, *optional*, defaults to 3):
            Padding token id.
        bos_token_id (`int`, *optional*, defaults to 1):
            Beginning of stream token id.
        eos_token_id (`int`, *optional*, defaults to 2):
            End of stream token id.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether to tie weight embeddings.
        rope_theta (`float`, *optional*, defaults to 12000000.0):
            The base period of the RoPE embeddings.
        rope_scaling (`Dict`, *optional*):
            Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply a new rope type
            and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
            accordingly.
            Expected contents:
                `rope_type` (`str`):
                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
                    'llama3'], with 'default' being the original RoPE implementation.
                `factor` (`float`, *optional*):
                    Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
                    most scaling types, a `factor` of x will enable the model to handle sequences of length x *
                    original maximum pre-trained length.
                `original_max_position_embeddings` (`int`, *optional*):
                    Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
                    pretraining.
                `attention_factor` (`float`, *optional*):
                    Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
                    computation. If unspecified, it defaults to value recommended by the implementation, using the
                    `factor` field to infer the suggested value.
                `beta_fast` (`float`, *optional*):
                    Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
                    ramp function. If unspecified, it defaults to 32.
                `beta_slow` (`float`, *optional*):
                    Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
                    ramp function. If unspecified, it defaults to 1.
                `short_factor` (`list[float]`, *optional*):
                    Only used with 'longrope'. The scaling factor to be applied to short contexts (<
                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                    size divided by the number of attention heads divided by 2.
                `long_factor` (`list[float]`, *optional*):
                    Only used with 'longrope'. The scaling factor to be applied to long contexts (>
                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                    size divided by the number of attention heads divided by 2.
                `low_freq_factor` (`float`, *optional*):
                    Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE.
                `high_freq_factor` (`float`, *optional*):
                    Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether to use a bias in the query, key, value and output projection layers during self-attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.

    ```python
    >>> from transformers import ApertusModel, ApertusConfig

    >>> # Initializing an Apertus-8B style configuration
    >>> configuration = ApertusConfig()

    >>> # Initializing a model from the Apertus-8B style configuration
    >>> model = ApertusModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "apertus"
    keys_to_ignore_at_inference = ["past_key_values"]
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise_rep",  # we need to replicate here due to the added norm on q and k
        "layers.*.self_attn.k_proj": "colwise_rep",  # we need to replicate here due to the added norm on q and k
        "layers.*.self_attn.v_proj": "colwise_rep",  # we need to replicate here due to the added norm on q and k
        "layers.*.self_attn.o_proj": "rowwise_rep",  # we need to replicate here due to the added norm on q and k
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
        "layers.*.mlp.gate_proj": "colwise",
    }
    base_model_pp_plan = {
        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
        "norm": (["hidden_states"], ["hidden_states"]),
    }

    def __init__(
        self,
        vocab_size=131072,
        hidden_size=4096,
        intermediate_size=14336,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=None,
        hidden_act="xielu",
        max_position_embeddings=65536,
        initializer_range=0.02,
        rms_norm_eps=1e-5,
        use_cache=True,
        pad_token_id=3,
        bos_token_id=1,
        eos_token_id=2,
        tie_word_embeddings=False,
        rope_theta=12000000.0,
        rope_scaling={
            "rope_type": "llama3",
            "factor": 8.0,
            "original_max_position_embeddings": 8192,
            "low_freq_factor": 1.0,
            "high_freq_factor": 4.0,
        },
        attention_bias=False,
        attention_dropout=0.0,
        **kwargs,
    ):
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        # Validate the correctness of rotary position embeddings parameters
        # BC: if there is a 'type' field, copy it to 'rope_type'.
        if self.rope_scaling is not None and "type" in self.rope_scaling:
            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
        rope_config_validation(self)


__all__ = ["ApertusConfig"]
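A worked check of the default `rope_scaling` above, assuming a transformers build that ships Apertus: with the "llama3" rule, a `factor` of 8.0 over an original context of 8192 tokens targets 8.0 × 8192 = 65536 positions, which matches the default `max_position_embeddings`:

```python
# Sketch only: verifies the rope_scaling arithmetic from the config defaults.
from transformers import ApertusConfig

cfg = ApertusConfig()
scaled = cfg.rope_scaling["factor"] * cfg.rope_scaling["original_max_position_embeddings"]
print(int(scaled), cfg.max_position_embeddings)  # 65536 65536
```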
venv/lib/python3.13/site-packages/transformers/models/apertus/modeling_apertus.py
ADDED
@@ -0,0 +1,488 @@
| 1 |
+
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
| 2 |
+
# This file was automatically generated from src/transformers/models/apertus/modular_apertus.py.
|
| 3 |
+
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
| 4 |
+
# the file from the modular. If any change should be done, please apply the change to the
|
| 5 |
+
# modular_apertus.py file directly. One of our CI enforces this.
|
| 6 |
+
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
| 7 |
+
# coding=utf-8
|
| 8 |
+
# Copyright 2025 the HuggingFace Inc. team and the Swiss AI Initiative. All rights reserved.
|
| 9 |
+
#
|
| 10 |
+
#
|
| 11 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 12 |
+
# you may not use this file except in compliance with the License.
|
| 13 |
+
# You may obtain a copy of the License at
|
| 14 |
+
#
|
| 15 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 16 |
+
#
|
| 17 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 18 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 19 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 20 |
+
# See the License for the specific language governing permissions and
|
| 21 |
+
# limitations under the License.
|
| 22 |
+
from typing import Callable, Optional, Union
|
| 23 |
+
|
| 24 |
+
import torch
|
| 25 |
+
from torch import nn
|
| 26 |
+
|
| 27 |
+
from ...activations import ACT2FN
|
| 28 |
+
from ...cache_utils import Cache, DynamicCache
|
| 29 |
+
from ...generation import GenerationMixin
|
| 30 |
+
from ...integrations import use_kernel_forward_from_hub
|
| 31 |
+
from ...masking_utils import create_causal_mask
|
| 32 |
+
from ...modeling_layers import GenericForTokenClassification, GradientCheckpointingLayer
|
| 33 |
+
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
| 34 |
+
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
| 35 |
+
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
| 36 |
+
from ...processing_utils import Unpack
|
| 37 |
+
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
|
| 38 |
+
from ...utils.deprecation import deprecate_kwarg
|
| 39 |
+
from ...utils.generic import check_model_inputs
|
| 40 |
+
from .configuration_apertus import ApertusConfig
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class ApertusMLP(nn.Module):
|
| 44 |
+
def __init__(self, config):
|
| 45 |
+
super().__init__()
|
| 46 |
+
self.config = config
|
| 47 |
+
self.hidden_size = config.hidden_size
|
| 48 |
+
self.intermediate_size = config.intermediate_size
|
| 49 |
+
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
| 50 |
+
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
| 51 |
+
self.act_fn = ACT2FN[config.hidden_act]
|
| 52 |
+
|
| 53 |
+
def forward(self, x):
|
| 54 |
+
return self.down_proj(self.act_fn(self.up_proj(x)))
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
@use_kernel_forward_from_hub("RMSNorm")
|
| 58 |
+
class ApertusRMSNorm(nn.Module):
|
| 59 |
+
def __init__(self, hidden_size, eps=1e-6):
|
| 60 |
+
"""
|
| 61 |
+
ApertusRMSNorm is equivalent to T5LayerNorm
|
| 62 |
+
"""
|
| 63 |
+
super().__init__()
|
| 64 |
+
self.weight = nn.Parameter(torch.ones(hidden_size))
|
| 65 |
+
self.variance_epsilon = eps
|
| 66 |
+
|
| 67 |
+
def forward(self, hidden_states):
|
| 68 |
+
input_dtype = hidden_states.dtype
|
| 69 |
+
hidden_states = hidden_states.to(torch.float32)
|
| 70 |
+
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
| 71 |
+
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
| 72 |
+
return self.weight * hidden_states.to(input_dtype)
|
| 73 |
+
|
| 74 |
+
def extra_repr(self):
|
| 75 |
+
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
class ApertusRotaryEmbedding(nn.Module):
|
| 79 |
+
inv_freq: torch.Tensor # fix linting for `register_buffer`
|
| 80 |
+
|
| 81 |
+
def __init__(self, config: ApertusConfig, device=None):
|
| 82 |
+
super().__init__()
|
| 83 |
+
# BC: "rope_type" was originally "type"
|
| 84 |
+
if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
|
| 85 |
+
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
| 86 |
+
else:
|
| 87 |
+
self.rope_type = "default"
|
| 88 |
+
self.max_seq_len_cached = config.max_position_embeddings
|
| 89 |
+
self.original_max_seq_len = config.max_position_embeddings
|
| 90 |
+
|
| 91 |
+
self.config = config
|
| 92 |
+
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
| 93 |
+
|
| 94 |
+
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
|
| 95 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 96 |
+
self.original_inv_freq = self.inv_freq
|
| 97 |
+
|
| 98 |
+
@torch.no_grad()
|
| 99 |
+
@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
|
| 100 |
+
def forward(self, x, position_ids):
|
| 101 |
+
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
|
| 102 |
+
position_ids_expanded = position_ids[:, None, :].float()
|
| 103 |
+
|
| 104 |
+
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
|
| 105 |
+
with torch.autocast(device_type=device_type, enabled=False): # Force float32
|
| 106 |
+
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
| 107 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
| 108 |
+
cos = emb.cos() * self.attention_scaling
|
| 109 |
+
sin = emb.sin() * self.attention_scaling
|
| 110 |
+
|
| 111 |
+
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def rotate_half(x):
|
| 115 |
+
"""Rotates half the hidden dims of the input."""
|
| 116 |
+
x1 = x[..., : x.shape[-1] // 2]
|
| 117 |
+
x2 = x[..., x.shape[-1] // 2 :]
|
| 118 |
+
return torch.cat((-x2, x1), dim=-1)
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
| 122 |
+
"""Applies Rotary Position Embedding to the query and key tensors.
|
| 123 |
+
|
| 124 |
+
Args:
|
| 125 |
+
q (`torch.Tensor`): The query tensor.
|
| 126 |
+
k (`torch.Tensor`): The key tensor.
|
| 127 |
+
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
| 128 |
+
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
| 129 |
+
position_ids (`torch.Tensor`, *optional*):
|
| 130 |
+
Deprecated and unused.
|
| 131 |
+
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
| 132 |
+
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
| 133 |
+
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
| 134 |
+
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
| 135 |
+
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
| 136 |
+
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
| 137 |
+
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
| 138 |
+
Returns:
|
| 139 |
+
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
| 140 |
+
"""
|
| 141 |
+
cos = cos.unsqueeze(unsqueeze_dim)
|
| 142 |
+
sin = sin.unsqueeze(unsqueeze_dim)
|
| 143 |
+
q_embed = (q * cos) + (rotate_half(q) * sin)
|
| 144 |
+
k_embed = (k * cos) + (rotate_half(k) * sin)
|
| 145 |
+
return q_embed, k_embed
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
| 149 |
+
"""
|
| 150 |
+
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
| 151 |
+
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
| 152 |
+
"""
|
| 153 |
+
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
| 154 |
+
if n_rep == 1:
|
| 155 |
+
return hidden_states
|
| 156 |
+
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
| 157 |
+
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
def eager_attention_forward(
|
| 161 |
+
module: nn.Module,
|
| 162 |
+
query: torch.Tensor,
|
| 163 |
+
key: torch.Tensor,
|
| 164 |
+
value: torch.Tensor,
|
| 165 |
+
attention_mask: Optional[torch.Tensor],
|
| 166 |
+
scaling: float,
|
| 167 |
+
dropout: float = 0.0,
|
| 168 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 169 |
+
):
|
| 170 |
+
key_states = repeat_kv(key, module.num_key_value_groups)
|
| 171 |
+
value_states = repeat_kv(value, module.num_key_value_groups)
|
| 172 |
+
|
| 173 |
+
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
| 174 |
+
if attention_mask is not None:
|
| 175 |
+
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
| 176 |
+
attn_weights = attn_weights + causal_mask
|
| 177 |
+
|
| 178 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
| 179 |
+
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
| 180 |
+
attn_output = torch.matmul(attn_weights, value_states)
|
| 181 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
| 182 |
+
|
| 183 |
+
return attn_output, attn_weights


class ApertusAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: ApertusConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = True

        self.q_proj = nn.Linear(
            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
        )
        self.k_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.v_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
        )
        self.q_norm = ApertusRMSNorm(self.head_dim, config.rms_norm_eps)
        self.k_norm = ApertusRMSNorm(self.head_dim, config.rms_norm_eps)

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_values: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        query_states = self.q_norm(query_states)
        key_states = self.k_norm(key_states)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_values is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
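
# Per-head layout sketch (hypothetical sizes: hidden_size=4096, 32 query heads,
# 8 KV heads, head_dim=128, seq_len=16):
#     q_proj:         [b, 16, 4096] -> view/transpose -> [b, 32, 16, 128]
#     k_proj, v_proj: [b, 16, 1024] -> view/transpose -> [b, 8, 16, 128]
# q_norm / k_norm normalize each head over head_dim *before* RoPE is applied.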


class ApertusDecoderLayer(GradientCheckpointingLayer):
    def __init__(self, config: ApertusConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        self.self_attn = ApertusAttention(config=config, layer_idx=layer_idx)

        self.mlp = ApertusMLP(config)
        self.attention_layernorm = ApertusRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.feedforward_layernorm = ApertusRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor]:
        residual = hidden_states
        hidden_states = self.attention_layernorm(hidden_states)
        hidden_states, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.feedforward_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states
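
# The layer above is a standard pre-norm residual block; schematically:
#     h = x + Attn(RMSNorm(x))
#     y = h + MLP(RMSNorm(h))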


@auto_docstring
class ApertusPreTrainedModel(PreTrainedModel):
    config: ApertusConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["ApertusDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = True

    _can_compile_fullgraph = True
    _supports_attention_backend = True
    _can_record_outputs = {
        "hidden_states": ApertusDecoderLayer,
        "attentions": ApertusAttention,
    }


@auto_docstring
class ApertusModel(ApertusPreTrainedModel):
    def __init__(self, config: ApertusConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [ApertusDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = ApertusRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = ApertusRotaryEmbedding(config=config)
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    @check_model_inputs()
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPast:
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.config)

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position: torch.Tensor = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = create_causal_mask(
            config=self.config,
            input_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            past_key_values=past_key_values,
            position_ids=position_ids,
        )

        hidden_states = inputs_embeds
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **kwargs,
            )

        hidden_states = self.norm(hidden_states)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
        )
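
# Minimal forward sketch (hypothetical, downsized config values for illustration):
#
#     config = ApertusConfig(num_hidden_layers=2, hidden_size=256, intermediate_size=512,
#                            num_attention_heads=4, num_key_value_heads=2, vocab_size=1000)
#     model = ApertusModel(config)
#     out = model(input_ids=torch.randint(0, 1000, (1, 8)))
#     assert out.last_hidden_state.shape == (1, 8, 256)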


@auto_docstring
class ApertusForCausalLM(ApertusPreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

    def __init__(self, config):
        super().__init__(config)
        self.model = ApertusModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> CausalLMOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked); the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, ApertusForCausalLM

        >>> model = ApertusForCausalLM.from_pretrained("swiss-ai/Apertus-8B")
        >>> tokenizer = AutoTokenizer.from_pretrained("swiss-ai/Apertus-8B")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
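
# `logits_to_keep` sketch (hypothetical call): during generation only the last
# position's logits are needed, so an integer keeps just the final n positions:
#
#     out = model(input_ids=ids, logits_to_keep=1)
#     next_token_logits = out.logits[:, -1, :]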


class ApertusForTokenClassification(GenericForTokenClassification, ApertusPreTrainedModel):
    pass


__all__ = ["ApertusModel", "ApertusForCausalLM", "ApertusForTokenClassification", "ApertusPreTrainedModel"]

venv/lib/python3.13/site-packages/transformers/models/apertus/modular_apertus.py
ADDED
@@ -0,0 +1,371 @@
# coding=utf-8
# Copyright 2025 the HuggingFace Inc. team and the Swiss AI Initiative. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Optional

import torch
from torch import nn

from ...cache_utils import Cache
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
from ...processing_utils import Unpack
from ...utils import TransformersKwargs, logging
from ..llama.configuration_llama import LlamaConfig
from ..llama.modeling_llama import (
    LlamaAttention,
    LlamaDecoderLayer,
    LlamaForCausalLM,
    LlamaForTokenClassification,
    LlamaModel,
    LlamaPreTrainedModel,
    LlamaRMSNorm,
    LlamaRotaryEmbedding,
    apply_rotary_pos_emb,
    eager_attention_forward,
)
from ..nemotron.modeling_nemotron import NemotronMLP


logger = logging.get_logger(__name__)


class ApertusConfig(LlamaConfig):
    r"""
    This is the configuration class to store the configuration of an [`ApertusModel`]. It is used to instantiate an
    Apertus model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of the Apertus-8B,
    e.g. [swiss-ai/Apertus-8B](https://huggingface.co/swiss-ai/Apertus-8B).

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        vocab_size (`int`, *optional*, defaults to 131072):
            Vocabulary size of the Apertus model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`ApertusModel`]
        hidden_size (`int`, *optional*, defaults to 4096):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 14336):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 32):
            Number of hidden layers in the Transformer decoder.
        num_attention_heads (`int`, *optional*, defaults to 32):
            Number of attention heads for each attention layer in the Transformer decoder.
        num_key_value_heads (`int`, *optional*):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1` the model will use Multi Query Attention (MQA), otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by meanpooling all the original heads within that group. For more details, check out [this
            paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
            `num_attention_heads`.
        hidden_act (`str` or `function`, *optional*, defaults to `"xielu"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 65536):
            The maximum sequence length that this model might ever be used with. Apertus supports up to 65536 tokens.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-05):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        pad_token_id (`int`, *optional*, defaults to 3):
            Padding token id.
        bos_token_id (`int`, *optional*, defaults to 1):
            Beginning of stream token id.
        eos_token_id (`int`, *optional*, defaults to 2):
            End of stream token id.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether to tie weight embeddings
        rope_theta (`float`, *optional*, defaults to 12000000.0):
            The base period of the RoPE embeddings.
        rope_scaling (`Dict`, *optional*):
            Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply a new rope type
            and you expect the model to work on a longer `max_position_embeddings`, we recommend updating this value
            accordingly.
            Expected contents:
                `rope_type` (`str`):
                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
                    'llama3'], with 'default' being the original RoPE implementation.
                `factor` (`float`, *optional*):
                    Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
                    most scaling types, a `factor` of x will enable the model to handle sequences of length x *
                    original maximum pre-trained length.
                `original_max_position_embeddings` (`int`, *optional*):
                    Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
                    pretraining.
                `attention_factor` (`float`, *optional*):
                    Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
                    computation. If unspecified, it defaults to the value recommended by the implementation, using the
                    `factor` field to infer the suggested value.
                `beta_fast` (`float`, *optional*):
                    Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
                    ramp function. If unspecified, it defaults to 32.
                `beta_slow` (`float`, *optional*):
                    Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
                    ramp function. If unspecified, it defaults to 1.
                `short_factor` (`list[float]`, *optional*):
                    Only used with 'longrope'. The scaling factor to be applied to short contexts (<
                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                    size divided by the number of attention heads divided by 2
                `long_factor` (`list[float]`, *optional*):
                    Only used with 'longrope'. The scaling factor to be applied to long contexts (>
                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                    size divided by the number of attention heads divided by 2
                `low_freq_factor` (`float`, *optional*):
                    Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
                `high_freq_factor` (`float`, *optional*):
                    Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether to use a bias in the query, key, value and output projection layers during self-attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.

    ```python
    >>> from transformers import ApertusModel, ApertusConfig

    >>> # Initializing an Apertus-8B style configuration
    >>> configuration = ApertusConfig()

    >>> # Initializing a model from the Apertus-8B style configuration
    >>> model = ApertusModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "apertus"
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise_rep",  # we need to replicate here due to the added norm on q and k
        "layers.*.self_attn.k_proj": "colwise_rep",  # we need to replicate here due to the added norm on q and k
        "layers.*.self_attn.v_proj": "colwise_rep",  # we need to replicate here due to the added norm on q and k
        "layers.*.self_attn.o_proj": "rowwise_rep",  # we need to replicate here due to the added norm on q and k
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
        "layers.*.mlp.gate_proj": "colwise",
    }

    def __init__(
        self,
        vocab_size=131072,
        hidden_size=4096,
        intermediate_size=14336,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=None,
        hidden_act="xielu",
        max_position_embeddings=65536,
        initializer_range=0.02,
        rms_norm_eps=1e-5,
        use_cache=True,
        pad_token_id=3,
        bos_token_id=1,
        eos_token_id=2,
        tie_word_embeddings=False,
        rope_theta=12000000.0,
        rope_scaling={
            "rope_type": "llama3",
            "factor": 8.0,
            "original_max_position_embeddings": 8192,
            "low_freq_factor": 1.0,
            "high_freq_factor": 4.0,
        },
        attention_bias=False,
        attention_dropout=0.0,
        **kwargs,
    ):
        super().__init__(
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            hidden_act=hidden_act,
            max_position_embeddings=max_position_embeddings,
            initializer_range=initializer_range,
            rms_norm_eps=rms_norm_eps,
            use_cache=use_cache,
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            attention_bias=attention_bias,
            attention_dropout=attention_dropout,
            **kwargs,
        )
        del self.pretraining_tp
        del self.mlp_bias
        del self.head_dim
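
# Configuration sketch (hypothetical override): the default `rope_scaling` uses
# the 'llama3' schedule; passing None falls back to unscaled RoPE:
#
#     config = ApertusConfig(rope_scaling=None, max_position_embeddings=8192)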


class ApertusMLP(NemotronMLP):
    def __init__(self, config):
        super().__init__(config)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)


class ApertusRMSNorm(LlamaRMSNorm):
    pass


class ApertusRotaryEmbedding(LlamaRotaryEmbedding):
    pass


class ApertusAttention(LlamaAttention):
    def __init__(self, config: ApertusConfig, layer_idx: Optional[int] = None):
        super().__init__(config, layer_idx)
        self.q_norm = ApertusRMSNorm(self.head_dim, config.rms_norm_eps)
        self.k_norm = ApertusRMSNorm(self.head_dim, config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_values: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        query_states = self.q_norm(query_states)
        key_states = self.k_norm(key_states)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_values is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights


class ApertusDecoderLayer(LlamaDecoderLayer):
    def __init__(self, config: ApertusConfig, layer_idx: int):
        super().__init__(config, layer_idx)
        self.attention_layernorm = ApertusRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.feedforward_layernorm = ApertusRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        del self.input_layernorm
        del self.post_attention_layernorm

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor]:
        residual = hidden_states
        hidden_states = self.attention_layernorm(hidden_states)
        hidden_states, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.feedforward_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states


class ApertusPreTrainedModel(LlamaPreTrainedModel):
    pass


class ApertusModel(LlamaModel):
    pass


class ApertusForCausalLM(LlamaForCausalLM):
    def forward(self, **super_kwargs):
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked); the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, ApertusForCausalLM

        >>> model = ApertusForCausalLM.from_pretrained("swiss-ai/Apertus-8B")
        >>> tokenizer = AutoTokenizer.from_pretrained("swiss-ai/Apertus-8B")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        return super().forward(**super_kwargs)


class ApertusForTokenClassification(LlamaForTokenClassification):
    pass


__all__ = [
    "ApertusConfig",
    "ApertusModel",
    "ApertusForCausalLM",
    "ApertusForTokenClassification",
    "ApertusPreTrainedModel",
]

venv/lib/python3.13/site-packages/transformers/models/arcee/__init__.py
ADDED
@@ -0,0 +1,27 @@
# Copyright 2025 Arcee AI and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    from .configuration_arcee import *
    from .modeling_arcee import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)

venv/lib/python3.13/site-packages/transformers/models/arcee/configuration_arcee.py
ADDED
@@ -0,0 +1,201 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/arcee/modular_arcee.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_arcee.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2025 Arcee AI and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ...configuration_utils import PretrainedConfig
from ...modeling_rope_utils import rope_config_validation


class ArceeConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of an [`ArceeModel`]. It is used to instantiate an Arcee
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the AFM-4.5B-Base.

    Pre-trained weights are available at
    [arcee-ai/AFM-4.5B](https://huggingface.co/arcee-ai/AFM-4.5B)
    and were used to build the examples below.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 32000):
            Vocabulary size of the Arcee model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`ArceeModel`]
        hidden_size (`int`, *optional*, defaults to 2560):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 18432):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 32):
            Number of hidden layers in the Transformer decoder.
        num_attention_heads (`int`, *optional*, defaults to 32):
            Number of attention heads for each attention layer in the Transformer decoder.
        num_key_value_heads (`int`, *optional*):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1` the model will use Multi Query Attention (MQA), otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by meanpooling all the original heads within that group. For more details, check out [this
            paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
            `num_attention_heads`.
        hidden_act (`str` or `function`, *optional*, defaults to `"relu2"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 4096):
            The maximum sequence length that this model might ever be used with. AFM-4.5B-Base supports up to 16384 tokens.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-05):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        pad_token_id (`int`, *optional*):
            Padding token id.
        bos_token_id (`int`, *optional*, defaults to 128000):
            Beginning of stream token id.
        eos_token_id (`int`, *optional*, defaults to 128001):
            End of stream token id.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether to tie weight embeddings
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        rope_scaling (`Dict`, *optional*):
            Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply a new rope type
            and you expect the model to work on a longer `max_position_embeddings`, we recommend updating this value
            accordingly.
            Expected contents:
                `rope_type` (`str`):
                    The sub-variant of RoPE to use. Can be one of ['default', 'yarn'], with 'default' being the
                    original RoPE implementation.
                `factor` (`float`, *optional*):
                    Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
                    most scaling types, a `factor` of x will enable the model to handle sequences of length x *
                    original maximum pre-trained length.
                `original_max_position_embeddings` (`int`, *optional*):
                    Used with 'yarn'. The original max position embeddings used during pretraining.
                `attention_factor` (`float`, *optional*):
                    Used with 'yarn'. The scaling factor to be applied on the attention computation. If unspecified,
                    it defaults to the value recommended by the implementation, using the `factor` field to infer the
                    suggested value.
                `beta_fast` (`float`, *optional*):
                    Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
                    ramp function. If unspecified, it defaults to 32.
                `beta_slow` (`float`, *optional*):
                    Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
                    ramp function. If unspecified, it defaults to 1.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether to use a bias in the query, key, value and output projection layers during self-attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        mlp_bias (`bool`, *optional*, defaults to `False`):
            Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers.
        head_dim (`int`, *optional*):
            The attention head dimension. If None, it will default to hidden_size // num_attention_heads

    ```python
    >>> from transformers import ArceeModel, ArceeConfig

    >>> # Initializing an Arcee AFM-4.5B-Base style configuration
    >>> configuration = ArceeConfig()

    >>> # Initializing a model from the AFM-4.5B-Base style configuration
    >>> model = ArceeModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "arcee"
    keys_to_ignore_at_inference = ["past_key_values"]
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }
    base_model_pp_plan = {
        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
        "norm": (["hidden_states"], ["hidden_states"]),
    }

    def __init__(
        self,
        vocab_size=32000,
        hidden_size=2560,
        intermediate_size=18432,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=None,
        hidden_act="relu2",
        max_position_embeddings=4096,
        initializer_range=0.02,
        rms_norm_eps=1e-5,
        use_cache=True,
        pad_token_id=None,
        bos_token_id=128000,
        eos_token_id=128001,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        attention_bias=False,
        attention_dropout=0.0,
        mlp_bias=False,
        head_dim=None,
        **kwargs,
    ):
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.mlp_bias = mlp_bias
        self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads
        # Validate the correctness of rotary position embeddings parameters
        # BC: if there is a 'type' field, copy it to 'rope_type'.
        if self.rope_scaling is not None and "type" in self.rope_scaling:
            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
        rope_config_validation(self)
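
# Hypothetical 'yarn' scaling sketch (illustrative values only, validated by
# rope_config_validation above):
#
#     config = ArceeConfig(
#         max_position_embeddings=16384,
#         rope_scaling={"rope_type": "yarn", "factor": 4.0,
#                       "original_max_position_embeddings": 4096},
#     )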


__all__ = ["ArceeConfig"]

venv/lib/python3.13/site-packages/transformers/models/arcee/modeling_arcee.py
ADDED
@@ -0,0 +1,506 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/arcee/modular_arcee.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_arcee.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2025 Arcee AI and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, Optional, Union

import torch
from torch import nn

from transformers.utils import auto_docstring

from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub
from ...masking_utils import create_causal_mask
from ...modeling_layers import (
    GenericForQuestionAnswering,
    GenericForSequenceClassification,
    GenericForTokenClassification,
    GradientCheckpointingLayer,
)
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import TransformersKwargs, can_return_tuple
from ...utils.deprecation import deprecate_kwarg
from ...utils.generic import check_model_inputs
from .configuration_arcee import ArceeConfig


class ArceeMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        return self.down_proj(self.act_fn(self.up_proj(x)))
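
# Data-flow sketch for the MLP above (assuming hidden_act="relu2", which ACT2FN
# resolves to ReLU-squared, i.e. relu(x) ** 2; note there is no gate projection):
#
#     x: [batch, seq, hidden] -> up_proj -> [batch, seq, intermediate]
#        -> relu(.)**2 -> down_proj -> [batch, seq, hidden]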


@use_kernel_forward_from_hub("RMSNorm")
class ArceeRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        ArceeRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"


class ArceeRotaryEmbedding(nn.Module):
    inv_freq: torch.Tensor  # fix linting for `register_buffer`

    def __init__(self, config: ArceeConfig, device=None):
        super().__init__()
        # BC: "rope_type" was originally "type"
        if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
        else:
            self.rope_type = "default"
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()

        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
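
# [Editor's note] Shape sketch for the rotary embedding above, assuming the default
# RoPE type: inv_freq[i] = rope_theta ** (-2 * i / head_dim) for i in [0, head_dim // 2),
# so given x of shape [batch, seq_len, hidden] and position_ids of shape [batch, seq_len],
# the returned `cos` and `sin` each have shape [batch, seq_len, head_dim] (the frequencies
# are duplicated along the last dim by `torch.cat((freqs, freqs), dim=-1)`).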


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)
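
# [Editor's note] Worked example for rotate_half (values illustrative):
#
#     >>> import torch
#     >>> rotate_half(torch.tensor([1.0, 2.0, 3.0, 4.0]))
#     tensor([-3., -4.,  1.,  2.])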


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
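
# [Editor's note] In equation form, the rotation above pairs dimension i with
# dimension i + head_dim // 2: writing (q1, q2) for such a pair at angle t,
#     q_embed = (q1 * cos(t) - q2 * sin(t),  q2 * cos(t) + q1 * sin(t)),
# i.e. a standard 2D rotation, which is exactly `q * cos + rotate_half(q) * sin`.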


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
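
# [Editor's note] Shape sketch for repeat_kv with toy sizes (GQA with 2 KV heads
# shared across 8 query heads, i.e. n_rep = 4):
#
#     >>> import torch
#     >>> kv = torch.randn(1, 2, 5, 64)  # (batch, num_key_value_heads, seqlen, head_dim)
#     >>> repeat_kv(kv, 4).shape
#     torch.Size([1, 8, 5, 64])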


def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    key_states = repeat_kv(key, module.num_key_value_groups)
    value_states = repeat_kv(value, module.num_key_value_groups)

    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
        attn_weights = attn_weights + causal_mask

    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value_states)
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights
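
# [Editor's note] The function above is plain scaled dot-product attention,
#     attn = softmax(Q @ K^T * scaling + mask) @ V,   scaling = head_dim ** -0.5,
# with the softmax computed in float32 for numerical stability and the additive
# `causal_mask` holding large negative values at positions a token may not attend to.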


class ArceeAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: ArceeConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = True

        self.q_proj = nn.Linear(
            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
        )
        self.k_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.v_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
        )

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_values: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_values is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
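
# [Editor's note] Hypothetical shape walk-through for ArceeAttention (the numbers are
# illustrative, not read from a released checkpoint): with hidden_size=2560,
# num_attention_heads=32 and num_key_value_heads=8, head_dim = 2560 // 32 = 80, so
# q_proj maps [batch, seq, 2560] -> [batch, seq, 32 * 80] while k_proj/v_proj map to
# [batch, seq, 8 * 80]; num_key_value_groups = 4, and repeat_kv inside the attention
# function expands the 8 KV heads back to 32.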


class ArceeDecoderLayer(GradientCheckpointingLayer):
    def __init__(self, config: ArceeConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        self.self_attn = ArceeAttention(config=config, layer_idx=layer_idx)

        self.mlp = ArceeMLP(config)
        self.input_layernorm = ArceeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = ArceeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        # Self Attention
        hidden_states, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states


@auto_docstring
class ArceePreTrainedModel(PreTrainedModel):
    config: ArceeConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["ArceeDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = True

    _can_compile_fullgraph = True
    _supports_attention_backend = True
    _can_record_outputs = {
        "hidden_states": ArceeDecoderLayer,
        "attentions": ArceeAttention,
    }


@auto_docstring
class ArceeModel(ArceePreTrainedModel):
    def __init__(self, config: ArceeConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [ArceeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = ArceeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = ArceeRotaryEmbedding(config=config)
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    @check_model_inputs()
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPast:
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.config)

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position: torch.Tensor = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = create_causal_mask(
            config=self.config,
            input_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            past_key_values=past_key_values,
            position_ids=position_ids,
        )

        hidden_states = inputs_embeds
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **kwargs,
            )

        hidden_states = self.norm(hidden_states)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
        )
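
# [Editor's note] Minimal usage sketch for the bare model, assuming a tiny random
# config (sizes chosen for illustration only):
#
#     >>> import torch
#     >>> from transformers import ArceeConfig, ArceeModel
#     >>> config = ArceeConfig(vocab_size=100, hidden_size=64, intermediate_size=128,
#     ...                      num_hidden_layers=2, num_attention_heads=4)
#     >>> model = ArceeModel(config)
#     >>> out = model(input_ids=torch.randint(0, 100, (1, 8)))
#     >>> out.last_hidden_state.shape
#     torch.Size([1, 8, 64])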


@auto_docstring(checkpoint="arcee-ai/AFM-4.5B")
class ArceeForCausalLM(ArceePreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

    def __init__(self, config):
        super().__init__(config)
        self.model = ArceeModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> CausalLMOutputWithPast:
        r"""
        Example:

        ```python
        >>> from transformers import AutoTokenizer, ArceeForCausalLM

        >>> model = ArceeForCausalLM.from_pretrained("arcee-ai/AFM-4.5B")
        >>> tokenizer = AutoTokenizer.from_pretrained("arcee-ai/AFM-4.5B")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@auto_docstring(checkpoint="arcee-ai/AFM-4.5B")
class ArceeForSequenceClassification(GenericForSequenceClassification, ArceePreTrainedModel):
    pass


@auto_docstring(checkpoint="arcee-ai/AFM-4.5B")
class ArceeForQuestionAnswering(GenericForQuestionAnswering, ArceePreTrainedModel):
    base_model_prefix = "transformer"  # For BC, where `transformer` was used instead of `model`


@auto_docstring(checkpoint="arcee-ai/AFM-4.5B")
class ArceeForTokenClassification(GenericForTokenClassification, ArceePreTrainedModel):
    pass


__all__ = [
    "ArceeForCausalLM",
    "ArceeForQuestionAnswering",
    "ArceeForSequenceClassification",
    "ArceeForTokenClassification",
    "ArceeModel",
    "ArceePreTrainedModel",
]
venv/lib/python3.13/site-packages/transformers/models/arcee/modular_arcee.py
ADDED
@@ -0,0 +1,225 @@
# coding=utf-8
# Copyright 2025 Arcee AI and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Arcee model."""

from transformers.utils import auto_docstring, logging

from ..llama.configuration_llama import LlamaConfig
from ..llama.modeling_llama import (
    LlamaForCausalLM,
    LlamaForQuestionAnswering,
    LlamaForSequenceClassification,
    LlamaForTokenClassification,
)
from ..nemotron.modeling_nemotron import NemotronMLP


logger = logging.get_logger(__name__)


class ArceeConfig(LlamaConfig):
    r"""
    This is the configuration class to store the configuration of a [`ArceeModel`]. It is used to instantiate an Arcee
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the AFM-4.5B-Base.

    Pre-trained weights are available at
    [arcee-ai/AFM-4.5B](https://huggingface.co/arcee-ai/AFM-4.5B)
    and were used to build the examples below.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 32000):
            Vocabulary size of the Arcee model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`ArceeModel`]
        hidden_size (`int`, *optional*, defaults to 2560):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 18432):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 32):
            Number of hidden layers in the Transformer decoder.
        num_attention_heads (`int`, *optional*, defaults to 32):
            Number of attention heads for each attention layer in the Transformer decoder.
        num_key_value_heads (`int`, *optional*):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by meanpooling all the original heads within that group. For more details checkout [this
            paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
            `num_attention_heads`.
        hidden_act (`str` or `function`, *optional*, defaults to `"relu2"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 4096):
            The maximum sequence length that this model might ever be used with. AFM-4.5B-Base supports up to 16384 tokens.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-05):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        pad_token_id (`int`, *optional*):
            Padding token id.
        bos_token_id (`int`, *optional*, defaults to 128000):
            Beginning of stream token id.
        eos_token_id (`int`, *optional*, defaults to 128001):
            End of stream token id.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether to tie weight embeddings
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        rope_scaling (`Dict`, *optional*):
            Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
            and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
            accordingly.
            Expected contents:
                `rope_type` (`str`):
                    The sub-variant of RoPE to use. Can be one of ['default', 'yarn'], with 'default' being the original RoPE implementation.
                `factor` (`float`, *optional*):
                    Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
                    most scaling types, a `factor` of x will enable the model to handle sequences of length x *
                    original maximum pre-trained length.
                `original_max_position_embeddings` (`int`, *optional*):
                    Used with 'yarn'. The original max position embeddings used during pretraining.
                `attention_factor` (`float`, *optional*):
                    Used with 'yarn'. The scaling factor to be applied on the attention computation. If unspecified,
                    it defaults to value recommended by the implementation, using the `factor` field to infer the suggested value.
                `beta_fast` (`float`, *optional*):
                    Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
                    ramp function. If unspecified, it defaults to 32.
                `beta_slow` (`float`, *optional*):
                    Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
                    ramp function. If unspecified, it defaults to 1.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether to use a bias in the query, key, value and output projection layers during self-attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        mlp_bias (`bool`, *optional*, defaults to `False`):
            Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers.
        head_dim (`int`, *optional*):
            The attention head dimension. If None, it will default to hidden_size // num_attention_heads

    ```python
    >>> from transformers import ArceeModel, ArceeConfig

    >>> # Initializing an Arcee AFM-4.5B-Base style configuration
    >>> configuration = ArceeConfig()

    >>> # Initializing a model from the AFM-4.5B-Base style configuration
    >>> model = ArceeModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "arcee"
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }

    def __init__(
        self,
        vocab_size=32000,
        hidden_size=2560,
        intermediate_size=18432,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=None,
        hidden_act="relu2",
        max_position_embeddings=4096,
        initializer_range=0.02,
        rms_norm_eps=1e-5,
        use_cache=True,
        pad_token_id=None,
        bos_token_id=128000,
        eos_token_id=128001,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        attention_bias=False,
        attention_dropout=0.0,
        mlp_bias=False,
        head_dim=None,
        **kwargs,
    ):
        super().__init__(
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            hidden_act=hidden_act,
            max_position_embeddings=max_position_embeddings,
            initializer_range=initializer_range,
            rms_norm_eps=rms_norm_eps,
            use_cache=use_cache,
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            attention_bias=attention_bias,
            attention_dropout=attention_dropout,
            mlp_bias=mlp_bias,
            head_dim=head_dim,
            **kwargs,
        )

        del self.pretraining_tp
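
# [Editor's note] Hedged example of a `rope_scaling` dict accepted by the config
# above, using only the keys documented in the docstring ('yarn' variant); the
# numbers are illustrative:
#
#     >>> config = ArceeConfig(
#     ...     max_position_embeddings=16384,
#     ...     rope_scaling={"rope_type": "yarn", "factor": 4.0, "original_max_position_embeddings": 4096},
#     ... )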


class ArceeMLP(NemotronMLP):
    pass


@auto_docstring(checkpoint="arcee-ai/AFM-4.5B")
class ArceeForCausalLM(LlamaForCausalLM):
    pass


@auto_docstring(checkpoint="arcee-ai/AFM-4.5B")
class ArceeForSequenceClassification(LlamaForSequenceClassification):
    pass


@auto_docstring(checkpoint="arcee-ai/AFM-4.5B")
class ArceeForQuestionAnswering(LlamaForQuestionAnswering):
    pass


@auto_docstring(checkpoint="arcee-ai/AFM-4.5B")
class ArceeForTokenClassification(LlamaForTokenClassification):
    pass


__all__ = [
    "ArceeConfig",
    "ArceeForCausalLM",
    "ArceeForQuestionAnswering",
    "ArceeForSequenceClassification",
    "ArceeForTokenClassification",
    "ArceeModel",  # noqa: F822
    "ArceePreTrainedModel",  # noqa: F822
]
venv/lib/python3.13/site-packages/transformers/models/aria/__init__.py
ADDED
@@ -0,0 +1,30 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    from .configuration_aria import *
    from .image_processing_aria import *
    from .modeling_aria import *
    from .processing_aria import *

else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
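
# [Editor's note] The `_LazyModule` indirection above defers the heavy submodule
# imports until an attribute is first accessed; in the illustrative usage below,
# `configuration_aria` is only imported on the second line:
#
#     >>> import transformers.models.aria as aria
#     >>> config = aria.AriaConfig()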
venv/lib/python3.13/site-packages/transformers/models/aria/configuration_aria.py
ADDED
@@ -0,0 +1,307 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/aria/modular_aria.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_aria.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2024 The Rhymes-AI Teams Authors and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional

from ...configuration_utils import PretrainedConfig
from ...modeling_rope_utils import rope_config_validation
from ..auto import CONFIG_MAPPING, AutoConfig


class AriaTextConfig(PretrainedConfig):
    r"""
    This class handles the configuration for the text component of the Aria model.
    Instantiating a configuration with the defaults will yield a similar configuration to that of the model of the Aria
    [rhymes-ai/Aria](https://huggingface.co/rhymes-ai/Aria) architecture.
    This class extends the LlamaConfig to include additional parameters specific to the Mixture of Experts (MoE) architecture.

    Args:
        vocab_size (`int`, *optional*, defaults to 32000):
            Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`LlamaModel`]
        hidden_size (`int`, *optional*, defaults to 4096):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 4096):
            The size of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 32):
            Number of hidden layers in the Transformer decoder.
        num_attention_heads (`int`, *optional*, defaults to 32):
            Number of attention heads for each attention layer in the Transformer decoder.
        num_key_value_heads (`int`, *optional*):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by meanpooling all the original heads within that group. For more details, check out [this
            paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
            `num_attention_heads`.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 2048):
            The maximum sequence length that this model might ever be used with. Llama 1 supports up to 2048 tokens,
            Llama 2 up to 4096, CodeLlama up to 16384.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        pad_token_id (`int`, *optional*, defaults to 2):
            Padding token id.
        bos_token_id (`int`, *optional*, defaults to 1):
            Beginning of stream token id.
        eos_token_id (`int`, *optional*, defaults to 2):
            End of stream token id.
        pretraining_tp (`int`, *optional*, defaults to 1):
            Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
            document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) to
            understand more about it. This value is necessary to ensure exact reproducibility of the pretraining
            results. Please refer to [this issue](https://github.com/pytorch/pytorch/issues/76232).
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether to tie weight embeddings
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        rope_scaling (`Dict`, *optional*):
            Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
            and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
            accordingly.
            Expected contents:
                `rope_type` (`str`):
                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
                    'llama3'], with 'default' being the original RoPE implementation.
                `factor` (`float`, *optional*):
                    Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
                    most scaling types, a `factor` of x will enable the model to handle sequences of length x *
                    original maximum pre-trained length.
                `original_max_position_embeddings` (`int`, *optional*):
                    Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
                    pretraining.
                `attention_factor` (`float`, *optional*):
                    Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
                    computation. If unspecified, it defaults to value recommended by the implementation, using the
                    `factor` field to infer the suggested value.
                `beta_fast` (`float`, *optional*):
                    Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
                    ramp function. If unspecified, it defaults to 32.
                `beta_slow` (`float`, *optional*):
                    Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
                    ramp function. If unspecified, it defaults to 1.
                `short_factor` (`list[float]`, *optional*):
                    Only used with 'longrope'. The scaling factor to be applied to short contexts (<
                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                    size divided by the number of attention heads divided by 2
                `long_factor` (`list[float]`, *optional*):
                    Only used with 'longrope'. The scaling factor to be applied to long contexts (>
                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                    size divided by the number of attention heads divided by 2
                `low_freq_factor` (`float`, *optional*):
                    Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
                `high_freq_factor` (`float`, *optional*):
                    Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether to use a bias in the query, key, value and output projection layers during self-attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        mlp_bias (`bool`, *optional*, defaults to `False`):
            Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers.
        head_dim (`int`, *optional*):
            The attention head dimension. If None, it will default to hidden_size // num_heads
        moe_num_experts (`int`, *optional*, defaults to 8):
            The number of experts in the MoE layer.
        moe_topk (`int`, *optional*, defaults to 2):
            The number of top experts to route to for each token.
        moe_num_shared_experts (`int`, *optional*, defaults to 2):
            The number of shared experts.
    """

    model_type = "aria_text"
    keys_to_ignore_at_inference = ["past_key_values"]
    # Default tensor parallel plan for base model `AriaTextModel`
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }
    base_model_pp_plan = {
        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
        "norm": (["hidden_states"], ["hidden_states"]),
    }
    base_config_key = "text_config"

    def __init__(
        self,
        vocab_size=32000,
        hidden_size=4096,
        intermediate_size: int = 4096,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=None,
        hidden_act="silu",
        max_position_embeddings=2048,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=2,
        bos_token_id=1,
        eos_token_id=2,
        pretraining_tp=1,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        attention_bias=False,
        attention_dropout=0.0,
        mlp_bias=False,
        head_dim=None,
        moe_num_experts: int = 8,
        moe_topk: int = 2,
        moe_num_shared_experts: int = 2,
        **kwargs,
    ):
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.pretraining_tp = pretraining_tp
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.mlp_bias = mlp_bias
        self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads
        # Validate the correctness of rotary position embeddings parameters
        # BC: if there is a 'type' field, copy it to 'rope_type'.
        if self.rope_scaling is not None and "type" in self.rope_scaling:
            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
        rope_config_validation(self)
        self.moe_num_experts = moe_num_experts
        self.moe_topk = moe_topk
        self.moe_num_shared_experts = moe_num_shared_experts
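
# [Editor's note] Illustrative instantiation of the MoE text config above (the
# values simply restate the documented defaults):
#
#     >>> config = AriaTextConfig(moe_num_experts=8, moe_topk=2, moe_num_shared_experts=2)
#     >>> config.moe_topk
#     2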


class AriaConfig(PretrainedConfig):
    r"""
    This class handles the configuration for both vision and text components of the Aria model,
    as well as additional parameters for image token handling and projector mapping.
    Instantiating a configuration with the defaults will yield a similar configuration to that of the model of the Aria
    [rhymes-ai/Aria](https://huggingface.co/rhymes-ai/Aria) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vision_config (`AriaVisionConfig` or `dict`, *optional*):
            Configuration for the vision component.
        vision_feature_layer (`int`, *optional*, defaults to -1):
            The index of the layer to select the vision feature.
        text_config (`AriaTextConfig` or `dict`, *optional*):
            Configuration for the text component.
        projector_patch_to_query_dict (`dict`, *optional*):
            Mapping of patch sizes to query dimensions.
        image_token_index (`int`, *optional*, defaults to 9):
            Index used to represent image tokens.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated normal initializer for initializing all weight matrices.

    Attributes:
        model_type (`str`):
            Type of the model, set to `"aria"`.
        image_token_index (`int`):
            Index used to represent image tokens.
        projector_patch_to_query_dict (`dict`):
            Mapping of patch sizes to query dimensions.
        vision_config (`AriaVisionConfig`):
            Configuration for the vision component.
        text_config (`AriaTextConfig`):
            Configuration for the text component.
    """

    model_type = "aria"
    attribute_map = {
        "image_token_id": "image_token_index",
    }
    sub_configs = {"text_config": AriaTextConfig, "vision_config": AutoConfig}

    def __init__(
        self,
        vision_config=None,
        vision_feature_layer: int = -1,
        text_config: AriaTextConfig = None,
        projector_patch_to_query_dict: Optional[dict] = None,
        image_token_index: int = 9,
        initializer_range: float = 0.02,
        **kwargs,
    ):
        self.image_token_index = image_token_index

        # Convert the keys and values of projector_patch_to_query_dict to integers
        # This ensures consistency even if they were provided as strings
        if projector_patch_to_query_dict is None:
            projector_patch_to_query_dict = {
                1225: 128,
                4900: 256,
            }
        self.projector_patch_to_query_dict = {int(k): int(v) for k, v in projector_patch_to_query_dict.items()}
        self.max_value_projector_patch_to_query_dict = max(self.projector_patch_to_query_dict.values())
        self.vision_feature_layer = vision_feature_layer
        if isinstance(vision_config, dict):
            vision_config["model_type"] = "idefics3_vision"
            vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config)
        elif vision_config is None:
            vision_config = CONFIG_MAPPING["idefics3_vision"]()

        self.vision_config = vision_config
        self.initializer_range = initializer_range

        if isinstance(text_config, dict) and "model_type" in text_config:
            text_config = AriaTextConfig(**text_config)
        elif text_config is None:
            text_config = AriaTextConfig()

        self.text_config = text_config

        super().__init__(**kwargs)
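
# [Editor's note] Sketch of the key/value coercion performed in `__init__` above:
# string keys and values (e.g. round-tripped through a JSON config) are normalized
# to integers:
#
#     >>> config = AriaConfig(projector_patch_to_query_dict={"1225": "128"})
#     >>> config.projector_patch_to_query_dict
#     {1225: 128}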


__all__ = ["AriaConfig", "AriaTextConfig"]
venv/lib/python3.13/site-packages/transformers/models/aria/image_processing_aria.py
ADDED
@@ -0,0 +1,527 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/aria/modular_aria.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_aria.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2024 The Rhymes-AI Teams Authors and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Iterable
from typing import Optional, Union

import numpy as np

from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_patch_output_size, select_best_resolution
from ...image_transforms import PaddingMode, convert_to_rgb, pad, resize, to_channel_dimension_format
from ...image_utils import (
    ChannelDimension,
    ImageInput,
    PILImageResampling,
    get_image_size,
    infer_channel_dimension_format,
    is_scaled_image,
    make_flat_list_of_images,
    to_numpy_array,
    valid_images,
    validate_preprocess_arguments,
)
from ...utils import TensorType, logging


logger = logging.get_logger(__name__)


def divide_to_patches(image: np.ndarray, patch_size: int, input_data_format) -> list[np.ndarray]:
    """
    Divides an image into patches of a specified size.

    Args:
        image (`np.ndarray`):
            The input image.
        patch_size (`int`):
            The size of each patch.
        input_data_format (`ChannelDimension` or `str`):
            The channel dimension format of the input image.

    Returns:
        list: A list of np.ndarray representing the patches.
    """
    patches = []
    height, width = get_image_size(image, channel_dim=input_data_format)
    for i in range(0, height, patch_size):
        for j in range(0, width, patch_size):
            if input_data_format == ChannelDimension.LAST:
                patch = image[i : i + patch_size, j : j + patch_size]
            else:
                patch = image[:, i : i + patch_size, j : j + patch_size]
            patches.append(patch)

    return patches
|
| 72 |
+
|
| 73 |
+
|
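
# A minimal sketch (illustrative, not part of the generated file; the helper name
# `_demo_divide_to_patches` is hypothetical): assuming a channels-last 980x980 RGB
# array and `patch_size=490`, the loops above walk the image in row-major order
# and yield four 490x490 crops.
def _demo_divide_to_patches():
    image = np.zeros((980, 980, 3), dtype=np.uint8)
    patches = divide_to_patches(image, patch_size=490, input_data_format=ChannelDimension.LAST)
    assert len(patches) == 4
    assert patches[0].shape == (490, 490, 3)
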
class AriaImageProcessor(BaseImageProcessor):
    """
    A vision processor for the Aria model that handles image preprocessing.
    Initialize the AriaImageProcessor.

    Args:
        image_mean (`list`, *optional*, defaults to [0.5, 0.5, 0.5]):
            Mean values for normalization.
        image_std (`list`, *optional*, defaults to [0.5, 0.5, 0.5]):
            Standard deviation values for normalization.
        max_image_size (`int`, *optional*, defaults to 980):
            Maximum image size.
        min_image_size (`int`, *optional*, defaults to 336):
            Minimum image size.
        split_resolutions (`list`, *optional*, defaults to a list of optimal resolutions as tuples):
            The optimal resolutions for splitting the image.
        split_image (`bool`, *optional*, defaults to `False`):
            Whether to split the image.
        do_convert_rgb (`bool`, *optional*, defaults to `True`):
            Whether to convert the image to RGB.
        do_rescale (`bool`, *optional*, defaults to `True`):
            Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
            the `preprocess` method.
        rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
            Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
            method.
        do_normalize (`bool`, *optional*, defaults to `True`):
            Whether to normalize the image.
        resample (PILImageResampling, *optional*, defaults to `BICUBIC`):
            The resampling filter to use if resizing the image.
    """

    model_input_names = ["pixel_values", "pixel_mask", "num_crops"]

    def __init__(
        self,
        image_mean: Optional[list[float]] = None,
        image_std: Optional[list[float]] = None,
        max_image_size: int = 980,
        min_image_size: int = 336,
        split_resolutions: Optional[list[tuple[int, int]]] = None,
        split_image: Optional[bool] = False,
        do_convert_rgb: Optional[bool] = True,
        do_rescale: bool = True,
        rescale_factor: Union[int, float] = 1 / 255,
        do_normalize: Optional[bool] = True,
        resample: PILImageResampling = PILImageResampling.BICUBIC,
        **kwargs,
    ):
        super().__init__(**kwargs)

        if image_mean is None:
            image_mean = [0.5, 0.5, 0.5]
        if image_std is None:
            image_std = [0.5, 0.5, 0.5]
        self.max_image_size = max_image_size
        self.min_image_size = min_image_size
        self.image_mean = image_mean
        self.image_std = image_std
        self.split_image = split_image
        if split_resolutions is None:
            split_resolutions = [(1, 2), (1, 3), (1, 4), (1, 5), (1, 6), (1, 7), (1, 8), (2, 4), (2, 3), (2, 2), (2, 1), (3, 1), (3, 2), (4, 1), (4, 2), (5, 1), (6, 1), (7, 1), (8, 1)]  # fmt: skip
            split_resolutions = [(el[0] * 490, el[1] * 490) for el in split_resolutions]
        self.split_resolutions = split_resolutions
        self.do_convert_rgb = do_convert_rgb
        self.do_rescale = do_rescale
        self.rescale_factor = rescale_factor
        self.do_normalize = do_normalize
        self.resample = resample

    def preprocess(
        self,
        images: Union[ImageInput, list[ImageInput]],
        image_mean: Optional[Union[float, list[float]]] = None,
        image_std: Optional[Union[float, list[float]]] = None,
        max_image_size: Optional[int] = None,
        min_image_size: Optional[int] = None,
        split_image: Optional[bool] = None,
        do_convert_rgb: Optional[bool] = None,
        do_rescale: Optional[bool] = None,
        rescale_factor: Optional[float] = None,
        do_normalize: Optional[bool] = None,
        resample: Optional[PILImageResampling] = None,
        return_tensors: Optional[Union[str, TensorType]] = "pt",
        data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ):
        """
        Process a list of images.

        Args:
            images (ImageInput or list of ImageInput):
                The input image or a list of images.
            image_mean (`list`, *optional*, defaults to [0.5, 0.5, 0.5]):
                Mean values for normalization.
            image_std (`list`, *optional*, defaults to [0.5, 0.5, 0.5]):
                Standard deviation values for normalization.
            max_image_size (`int`, *optional*, defaults to `self.max_image_size` (980)):
                Maximum image size.
            min_image_size (`int`, *optional*, defaults to `self.min_image_size` (336)):
                Minimum image size.
            split_image (`bool`, *optional*, defaults to `self.split_image` (False)):
                Whether to split the image.
            do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb` (True)):
                Whether to convert the image to RGB.
            do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
                Whether to rescale the image.
            rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
                Rescale factor to rescale the image by if `do_rescale` is set to `True`.
            do_normalize (`bool`, *optional*, defaults to `self.do_normalize` (True)):
                Whether to normalize the image.
            resample (PILImageResampling, *optional*, defaults to `self.resample` (BICUBIC)):
                The resampling filter to use if resizing the image.
            return_tensors (`str` or `TensorType`, *optional*, defaults to "pt"):
                The type of tensor to return.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format for the output image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`:
                    image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`:
                    image in (height, width, num_channels) format.
                If unset, will use the same format as the input image.
            input_data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format for the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`:
                    image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`:
                    image in (height, width, num_channels) format.
                If unset, will use the inferred format of the input image.

        Returns:
            BatchFeature:
                A BatchFeature object containing:
                - 'pixel_values':
                    Tensor of processed image pixel values.
                - 'pixel_mask':
                    Boolean pixel mask. This mask is a 2D tensor of shape (max_image_size, max_image_size) where:
                    - True (1) values indicate pixels that belong to the original resized image.
                    - False (0) values indicate pixels that are part of the padding.
                    The mask helps distinguish between actual image content and padded areas in subsequent processing steps.
                - 'num_crops':
                    The maximum number of crops across all images.
        """
        image_mean = image_mean if image_mean is not None else self.image_mean
        image_std = image_std if image_std is not None else self.image_std
        max_image_size = max_image_size if max_image_size is not None else self.max_image_size
        min_image_size = min_image_size if min_image_size is not None else self.min_image_size
        split_image = split_image if split_image is not None else self.split_image
        do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
        do_rescale = do_rescale if do_rescale is not None else self.do_rescale
        rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
        do_normalize = do_normalize if do_normalize is not None else self.do_normalize
        resample = resample if resample is not None else self.resample

        if max_image_size not in [490, 980]:
            raise ValueError("max_image_size must be either 490 or 980")

        images = self.fetch_images(images)
        images = make_flat_list_of_images(images)

        if not valid_images(images):
            raise ValueError(
                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "torch.Tensor, tf.Tensor or jax.ndarray."
            )

        validate_preprocess_arguments(
            do_normalize=do_normalize,
            image_mean=image_mean,
            image_std=image_std,
            resample=resample,
            do_rescale=do_rescale,
            rescale_factor=rescale_factor,
        )

        if do_convert_rgb:
            images = [convert_to_rgb(image) for image in images]

        # All transformations expect numpy arrays.
        images = [to_numpy_array(image) for image in images]

        if do_rescale and is_scaled_image(images[0]):
            logger.warning_once(
                "It looks like you are trying to rescale already rescaled images. If the input"
                " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
            )

        if input_data_format is None:
            # We assume that all images have the same channel dimension format.
            input_data_format = infer_channel_dimension_format(images[0])

        pixel_values = []
        pixel_masks = []
        num_crops = None

        for image in images:
            if split_image:
                crop_images = self.get_image_patches(
                    image,
                    self.split_resolutions,
                    max_image_size,
                    resample,
                    data_format=input_data_format,
                    input_data_format=input_data_format,
                )
            else:
                crop_images = [image]
            if num_crops is None or len(crop_images) > num_crops:
                num_crops = len(crop_images)

            for crop_image in crop_images:
                # At this point the scale is the rescaling factor that would bring the image to max_size in its larger dimension
                h, w = get_image_size(crop_image)
                scale = max_image_size / max(h, w)
                if w >= h:
                    new_size = (max(int(h * scale), min_image_size), max_image_size)  # h, w
                else:
                    new_size = (max_image_size, max(int(w * scale), min_image_size))  # h, w

                crop_image_resized = resize(
                    crop_image,
                    new_size,
                    resample=resample,
                    data_format=input_data_format,
                    input_data_format=input_data_format,
                )

                padding_bottom, padding_right = max_image_size - new_size[0], max_image_size - new_size[1]
                crop_image_padded = pad(
                    crop_image_resized,
                    ((0, padding_bottom), (0, padding_right)),
                    data_format=input_data_format,
                    input_data_format=input_data_format,
                )

                # Create a pixel mask
                pixel_mask = np.zeros((max_image_size, max_image_size), dtype=bool)
                pixel_mask[: new_size[0], : new_size[1]] = 1
                pixel_masks.append(pixel_mask)

                if do_rescale:
                    crop_image_padded = self.rescale(
                        image=crop_image_padded, scale=rescale_factor, input_data_format=input_data_format
                    )

                if do_normalize:
                    crop_image_padded = self.normalize(
                        crop_image_padded,
                        self.image_mean,
                        self.image_std,
                        data_format=input_data_format,
                        input_data_format=input_data_format,
                    )
                    crop_image_padded = (
                        to_channel_dimension_format(crop_image_padded, data_format, input_data_format)
                        if data_format is not None
                        else crop_image_padded
                    )

                pixel_values.append(crop_image_padded)
        return BatchFeature(
            data={
                "pixel_values": np.stack(pixel_values, axis=0),
                "pixel_mask": np.stack(pixel_masks, axis=0),
                "num_crops": num_crops,
            },
            tensor_type=return_tensors,
        )

    def _resize_for_patching(
        self, image: np.ndarray, target_resolution: tuple, resample, input_data_format: ChannelDimension
    ) -> np.ndarray:
        """
        Resizes an image to a target resolution while maintaining aspect ratio.

        Args:
            image (np.ndarray):
                The input image.
            target_resolution (tuple):
                The target resolution (height, width) of the image.
            resample (`PILImageResampling`):
                Resampling filter to use if resizing the image.
            input_data_format (`ChannelDimension` or `str`):
                The channel dimension format of the input image.

        Returns:
            np.ndarray: The resized image.
        """
        new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)

        # Resize the image
        resized_image = resize(image, (new_height, new_width), resample=resample, input_data_format=input_data_format)

        return resized_image

    def _get_padding_size(self, original_resolution: tuple, target_resolution: tuple):
        original_height, original_width = original_resolution
        target_height, target_width = target_resolution
        paste_x, r_x = divmod(target_width - original_width, 2)
        paste_y, r_y = divmod(target_height - original_height, 2)
        return (paste_y, paste_y + r_y), (paste_x, paste_x + r_x)

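    # Worked example of the divmod arithmetic above (illustrative comment, not part
    # of the generated file): padding a 100-pixel-wide image to 105 pixels gives
    # divmod(5, 2) == (2, 1), i.e. 2 pixels before and 2 + 1 = 3 pixels after, so
    # the odd leftover pixel always lands on the bottom/right side.
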
    def _pad_for_patching(
        self, image: np.ndarray, target_resolution: tuple, input_data_format: ChannelDimension
    ) -> np.ndarray:
        """
        Pad an image to a target resolution while maintaining aspect ratio.
        """
        new_resolution = get_patch_output_size(image, target_resolution, input_data_format)
        padding = self._get_padding_size(new_resolution, target_resolution)

        padded_image = self.pad(image, padding=padding)

        return padded_image

    def pad(
        self,
        image: np.ndarray,
        padding: Union[int, tuple[int, int], Iterable[tuple[int, int]]],
        mode: PaddingMode = PaddingMode.CONSTANT,
        constant_values: Union[float, Iterable[float]] = 0.0,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ) -> np.ndarray:
        """
        Pads the `image` with the specified `padding` and `mode`. Padding can be in the (`height`, `width`)
        dimension or in the (`num_patches`) dimension. In the second case, an iterable of tuples is expected
        as input.

        Args:
            image (`np.ndarray`):
                The image to pad.
            padding (`int` or `tuple[int, int]` or `Iterable[tuple[int, int]]`):
                Padding to apply to the edges of the height, width axes. Can be one of three formats:
                - `((before_height, after_height), (before_width, after_width))` unique pad widths for each axis.
                - `((before, after),)` yields same before and after pad for height and width.
                - `(pad,)` or int is a shortcut for before = after = pad width for all axes.
            mode (`PaddingMode`):
                The padding mode to use. Can be one of:
                    - `"constant"`: pads with a constant value.
                    - `"reflect"`: pads with the reflection of the vector mirrored on the first and last values of the
                      vector along each axis.
                    - `"replicate"`: pads with the replication of the last value on the edge of the array along each axis.
                    - `"symmetric"`: pads with the reflection of the vector mirrored along the edge of the array.
            constant_values (`float` or `Iterable[float]`, *optional*):
                The value to use for the padding if `mode` is `"constant"`.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format for the output image. Can be one of:
                    - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                    - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                If unset, will use the same format as the input image.
            input_data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format for the input image. Can be one of:
                    - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                    - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                If unset, will use the inferred format of the input image.

        Returns:
            `np.ndarray`: The padded image.

        """

        # call the general `pad` if padding on `height/width`, otherwise it's the `num_patches` dim
        if isinstance(padding, int) or len(padding) != 4:
            return pad(image, padding, mode, constant_values, data_format, input_data_format)

        if input_data_format is None:
            input_data_format = infer_channel_dimension_format(image)

        padding_mode_mapping = {
            PaddingMode.CONSTANT: "constant",
            PaddingMode.REFLECT: "reflect",
            PaddingMode.REPLICATE: "edge",
            PaddingMode.SYMMETRIC: "symmetric",
        }
        image = np.pad(image, padding, mode=padding_mode_mapping[mode], constant_values=constant_values)
        image = (
            to_channel_dimension_format(image, data_format, input_data_format) if data_format is not None else image
        )
        return image

    def get_image_patches(
        self,
        image: np.ndarray,
        grid_pinpoints: list[tuple[int, int]],
        patch_size: int,
        resample: PILImageResampling,
        data_format: ChannelDimension,
        input_data_format: ChannelDimension,
    ) -> list[np.ndarray]:
        """
        Process an image with variable resolutions by dividing it into patches.

        Args:
            image (`np.ndarray`):
                The input image to be processed.
            grid_pinpoints (list[tuple[int, int]]):
                A list of possible resolutions as tuples.
            patch_size (`int`):
                Size of the patches to divide the image into.
            resample (`PILImageResampling`):
                Resampling filter to use if resizing the image.
            data_format (`ChannelDimension` or `str`):
                The channel dimension format for the output image.
            input_data_format (`ChannelDimension` or `str`):
                The channel dimension format of the input image.

        Returns:
            `list[np.ndarray]`: A list of NumPy arrays containing the processed image patches.
        """
        if not isinstance(grid_pinpoints, list):
            raise TypeError("grid_pinpoints must be a list of possible resolutions.")

        possible_resolutions = grid_pinpoints

        image_size = get_image_size(image, channel_dim=input_data_format)
        best_resolution = select_best_resolution(image_size, possible_resolutions)
        resized_image = self._resize_for_patching(
            image, best_resolution, resample=resample, input_data_format=input_data_format
        )
        padded_image = self._pad_for_patching(resized_image, best_resolution, input_data_format=input_data_format)

        patches = divide_to_patches(padded_image, patch_size=patch_size, input_data_format=input_data_format)

        # make sure that all patches are in the input data format
        patches = [
            to_channel_dimension_format(patch, channel_dim=data_format, input_channel_dim=input_data_format)
            for patch in patches
        ]
        return patches

    def get_number_of_image_patches(self, height: int, width: int, images_kwargs=None):
        """
        A utility that returns the number of image patches for a given image size.

        Args:
            height (`int`):
                Height of the input image.
            width (`int`):
                Width of the input image.
            images_kwargs (`dict`, *optional*):
                Any kwargs to override defaults of the image processor.
        Returns:
            `int`: Number of patches per image.
        """
        # guard for the documented `None` default
        images_kwargs = images_kwargs if images_kwargs is not None else {}
        split_image = images_kwargs.get("split_image", self.split_image)
        max_image_size = images_kwargs.get("max_image_size", self.max_image_size)

        resized_height, resized_width = select_best_resolution((height, width), self.split_resolutions)
        num_patches = 1 if not split_image else resized_height // max_image_size * resized_width // max_image_size
        return num_patches


__all__ = ["AriaImageProcessor"]
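
# A minimal usage sketch (illustrative, not part of the generated file; the helper
# name `_demo_aria_image_processor` is hypothetical): with the defaults above, a
# single RGB image is resized so its longer side is 980 pixels, padded to a
# 980x980 square, and returned with a matching boolean pixel mask.
def _demo_aria_image_processor():
    processor = AriaImageProcessor()
    image = np.random.randint(0, 256, (600, 800, 3), dtype=np.uint8)
    batch = processor.preprocess(image, return_tensors=None)
    assert batch["pixel_values"].shape == (1, 3, 980, 980)
    assert batch["pixel_mask"].shape == (1, 980, 980)
    assert batch["num_crops"] == 1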
venv/lib/python3.13/site-packages/transformers/models/aria/modeling_aria.py
ADDED
@@ -0,0 +1,1275 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/aria/modular_aria.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_aria.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2024 The Rhymes-AI Teams Authors and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Callable, Optional, Union

import torch
from torch import nn

from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub
from ...masking_utils import create_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, ModelOutput
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
from ...utils.deprecation import deprecate_kwarg
from ...utils.generic import check_model_inputs
from ..auto import AutoModel
from .configuration_aria import AriaConfig, AriaTextConfig


@use_kernel_forward_from_hub("RMSNorm")
class AriaTextRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        AriaTextRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"

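
# A minimal numerical sketch (illustrative, not part of the generated file; the
# helper name `_demo_rmsnorm` is hypothetical): the module computes
# x / sqrt(mean(x**2) + eps) * weight, i.e. normalization by the root mean square
# without mean subtraction, unlike LayerNorm.
def _demo_rmsnorm():
    norm = AriaTextRMSNorm(hidden_size=4, eps=1e-6)
    x = torch.randn(2, 3, 4)
    expected = x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)
    assert torch.allclose(norm(x), expected, atol=1e-5)
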
class AriaProjectorMLP(nn.Module):
    """
    Feed-Forward Network module for the Aria Projector.

    Args:
        in_features (`int`):
            Input embedding dimension.
        hidden_features (`int`):
            Hidden dimension of the feed-forward network.
        output_dim (`int`):
            Output dimension.
    """

    def __init__(self, in_features, hidden_features, output_dim):
        super().__init__()
        self.linear_in = nn.Linear(in_features, hidden_features, bias=False)
        self.linear_out = nn.Linear(hidden_features, output_dim, bias=False)
        self.act = ACT2FN["gelu_new"]

    def forward(self, hidden_states):
        hidden_states = self.act(self.linear_in(hidden_states))
        hidden_states = self.linear_out(hidden_states)
        return hidden_states


class AriaCrossAttention(nn.Module):
    """
    Aria Cross-Attention module.

    Args:
        config (`AriaConfig`):
            The configuration to use.
    """

    def __init__(self, config: AriaConfig, dropout_rate: float = 0):
        super().__init__()
        hidden_size = config.vision_config.hidden_size
        num_heads = config.vision_config.num_attention_heads
        self.num_heads = num_heads
        self.q_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.k_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.v_proj = nn.Linear(hidden_size, hidden_size, bias=False)

        # Original code here: https://github.com/rhymes-ai/Aria/blob/719ff4e52b727443cba3793b0e27fe64e0244fe1/aria/model/projector.py#L48
        self.multihead_attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
        self.linear = nn.Linear(hidden_size, hidden_size)
        self.dropout = nn.Dropout(dropout_rate)

        self.layer_norm = nn.LayerNorm(hidden_size)
        self.layer_norm_kv = nn.LayerNorm(hidden_size)

    def forward(self, key_value_states, hidden_states, attn_mask=None):
        """
        Forward pass of the AriaCrossAttention module.

        Args:
            key_value_states (`torch.Tensor`):
                Input tensor for key and value.
            hidden_states (`torch.Tensor`):
                Input tensor for query.
            attn_mask (`torch.Tensor`, *optional*, defaults to None):
                Attention mask.

        Returns:
            torch.Tensor:
                Output tensor after cross-attention.
        """
        query = self.q_proj(self.layer_norm(hidden_states))

        key_value_states = self.layer_norm_kv(key_value_states)
        key = self.k_proj(key_value_states)
        value = self.v_proj(key_value_states)

        attn_output, _ = self.multihead_attn(query, key, value, attn_mask=attn_mask)

        attn_output = self.dropout(self.linear(attn_output))

        return attn_output


class AriaProjector(nn.Module):
    """
    Aria Projector module.

    This module projects vision features into the language model's embedding space, enabling interaction between vision and language components.

    Args:
        config (`AriaConfig`):
            Configuration object for the model.
    """

    def __init__(
        self,
        config: AriaConfig,
    ):
        super().__init__()

        self.patch_to_query_dict = config.projector_patch_to_query_dict
        self.in_features = config.vision_config.hidden_size
        self.num_heads = config.vision_config.num_attention_heads
        self.kv_dim = config.vision_config.hidden_size
        self.hidden_features = config.text_config.hidden_size
        self.output_dim = config.text_config.hidden_size

        self.query = nn.Parameter(torch.zeros(config.max_value_projector_patch_to_query_dict, self.in_features))

        self.cross_attn = AriaCrossAttention(config)

        self.layer_norm = nn.LayerNorm(self.in_features)
        self.feed_forward = AriaProjectorMLP(self.in_features, self.hidden_features, self.output_dim)

    def forward(self, key_value_states: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
        """
        Forward pass of the Projector module.

        Args:
            key_value_states (`torch.Tensor`):
                Input tensor of shape (batch_size, num_patches, kv_dim).
            attn_mask (`torch.Tensor`, *optional*, defaults to None):
                Attention mask.

        Returns:
            `torch.Tensor`: Output tensor of shape (batch_size, query_number, output_dim).
        """
        batch_size, num_patches = key_value_states.shape[0], key_value_states.shape[1]

        if num_patches not in self.patch_to_query_dict:
            raise KeyError(
                f"Number of patches {num_patches} not found in patch_to_query_dict amongst possible values {self.patch_to_query_dict.keys()}."
            )
        query_num = self.patch_to_query_dict[num_patches]

        queries = self.query[:query_num].unsqueeze(0).repeat(batch_size, 1, 1)

        if attn_mask is not None:
            attn_mask = attn_mask.repeat_interleave(self.num_heads, 0)
            attn_mask = attn_mask.unsqueeze(1).expand(-1, queries.size(1), -1)

        attention_out = self.cross_attn(key_value_states, queries, attn_mask=attn_mask)

        out = self.feed_forward(self.layer_norm(attention_out))

        return out


class AriaSharedExpertsMLP(nn.Module):
    """
    Shared Expert MLP for shared experts.

    Unlike routed experts, shared experts process all tokens without routing.
    This class reconfigures the intermediate size in comparison to the LlamaMLP.

    Args:
        config (`AriaTextConfig`): Configuration object for the Aria language model.
    """

    def __init__(self, config: AriaTextConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size * config.moe_num_shared_experts
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
        return down_proj


def sequential_experts_gemm(token_states, expert_weights, tokens_per_expert):
    """
    Compute the matrix multiplication (GEMM) for each expert sequentially. This approach is computationally inefficient, especially when dealing with a large number of experts.

    Args:
        token_states (torch.Tensor): Input tensor of shape (num_tokens, in_features).
        expert_weights (torch.Tensor): Weight tensor of shape (num_experts, in_features, out_features).
        tokens_per_expert (torch.Tensor): Number of tokens assigned to each expert.

    Returns:
        torch.Tensor: Output tensor of shape (num_tokens, out_features).
    """
    num_tokens = token_states.shape[0]
    out_features = expert_weights.shape[-1]
    output = torch.zeros(num_tokens, out_features, dtype=token_states.dtype, device=token_states.device)

    cumsum_num_tokens = torch.cumsum(tokens_per_expert, dim=0)
    # Insert zero at the beginning for offset index's convenience
    zero_tensor = torch.zeros(1, dtype=torch.long, device=cumsum_num_tokens.device)
    cumsum_num_tokens = torch.cat((zero_tensor, cumsum_num_tokens))

    for expert_num in range(expert_weights.shape[0]):
        start = cumsum_num_tokens[expert_num]
        end = cumsum_num_tokens[expert_num + 1]
        tokens = token_states[start:end]

        out = torch.matmul(tokens, expert_weights[expert_num])
        output[start:end] = out
    return output

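
# A minimal sketch (illustrative, not part of the generated file; the helper name
# `_demo_sequential_experts_gemm` is hypothetical): tokens are assumed to be
# pre-sorted by expert, so slicing with the cumulative counts routes the first 2
# tokens through expert 0 and the last 3 through expert 1.
def _demo_sequential_experts_gemm():
    token_states = torch.randn(5, 8)        # (num_tokens, in_features)
    expert_weights = torch.randn(2, 8, 16)  # (num_experts, in_features, out_features)
    tokens_per_expert = torch.tensor([2, 3])
    out = sequential_experts_gemm(token_states, expert_weights, tokens_per_expert)
    assert out.shape == (5, 16)
    assert torch.allclose(out[:2], token_states[:2] @ expert_weights[0])
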
class AriaGroupedExpertsGemm(nn.Module):
    """
    Grouped GEMM (General Matrix Multiplication) module for efficient expert computation.
    This module utilizes the grouped_gemm library (https://github.com/fanshiqing/grouped_gemm)
    for optimized performance. If the grouped_gemm library is not installed, it gracefully
    falls back to a sequential GEMM implementation, which may be slower but ensures
    functionality.

    Args:
        in_features (`int`):
            Number of input features.
        out_features (`int`):
            Number of output features.
        groups (`int`):
            Number of expert groups.
    """

    def __init__(self, in_features, out_features, groups):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.groups = groups
        self.weight = nn.Parameter(torch.empty(groups, in_features, out_features))

    def forward(self, input, tokens_per_expert):
        """
        Perform grouped matrix multiplication.

        Args:
            input (`torch.Tensor`):
                Input tensor of shape (num_tokens, in_features).
            tokens_per_expert (`torch.Tensor`):
                Number of tokens assigned to each expert.

        Returns:
            torch.Tensor: Output tensor of shape (num_tokens, out_features).
        """
        return sequential_experts_gemm(
            input,
            self.weight,
            tokens_per_expert.cpu(),
        )


class AriaGroupedExpertsMLP(nn.Module):
    """
    Grouped MLP module for Mixture of Experts.

    Args:
        config (`AriaTextConfig`):
            Configuration object for the model.
    """

    def __init__(self, config: AriaTextConfig) -> None:
        super().__init__()
        self.config = config
        self.fc1 = AriaGroupedExpertsGemm(config.hidden_size, config.intermediate_size * 2, config.moe_num_experts)
        self.fc2 = AriaGroupedExpertsGemm(config.intermediate_size, config.hidden_size, config.moe_num_experts)

    def forward(self, permuted_tokens, tokens_per_expert):
        """
        Forward pass of the Grouped MLP.

        Args:
            permuted_tokens (torch.Tensor): Permuted input tokens.
            tokens_per_expert (torch.Tensor): Number of tokens assigned to each expert.

        Returns:
            torch.Tensor: Output tensor after passing through the MLP.
        """
        fc1_output = self.fc1(permuted_tokens, tokens_per_expert)
        projection, gate = torch.chunk(fc1_output, 2, dim=-1)
        fc1_output = nn.functional.silu(projection) * gate
        fc2_output = self.fc2(fc1_output, tokens_per_expert)
        return fc2_output


# Token permutation adapted from https://github.com/NVIDIA/Megatron-LM/blob/54f1f78529cbc2b9cddad313e7f9d96ac0420a27/megatron/core/transformer/moe/token_dispatcher.py#L291-L587
class AriaTextMoELayer(nn.Module):
    """
    Aria Text Mixture of Experts (MoE) Layer.

    This layer applies a gating mechanism to route input tokens to different experts.

    Args:
        config (`AriaTextConfig`):
            Configuration object for the text component of the model.
    """

    def __init__(self, config: AriaTextConfig):
        super().__init__()

        self.router = nn.Linear(config.hidden_size, config.moe_num_experts, bias=False)
        self.experts = AriaGroupedExpertsMLP(config)
        self.shared_experts = AriaSharedExpertsMLP(config)
        self.config = config

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the MoE Layer.

        Args:
            hidden_states (`torch.Tensor`):
                Input tensor of shape (batch_size, sequence_length, hidden_size).

        Returns:
            torch.Tensor: Output tensor after passing through the MoE layer.

        Process:
            1. Route tokens to experts using the router.
            2. Permute tokens based on routing decisions.
            3. Process tokens through experts.
            4. Unpermute and combine expert outputs.
            5. Add shared expert output to the final result.
        """
        original_shape = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_states.size(-1))

        # Top K Routing
        logits = self.router(hidden_states)
        top_logits, top_indices = torch.topk(logits, k=self.config.moe_topk, dim=1)
        scores = nn.functional.softmax(top_logits, dim=-1)

        original_dtype = top_indices.dtype

        tokens_per_expert = torch.histc(
            top_indices.flatten().to(torch.float32),
            bins=self.config.moe_num_experts,
            min=0,
            max=self.config.moe_num_experts - 1,
        ).to(original_dtype)
        indices = top_indices

        # Token permutation
        flatten_indices = indices.view(-1)
        sorted_indices = torch.argsort(flatten_indices)
        permuted_tokens = hidden_states.index_select(0, sorted_indices // self.config.moe_topk)

        # Process through experts
        expert_output = self.experts(permuted_tokens, tokens_per_expert)

        # Token unpermutation
        unpermuted_tokens = torch.zeros(
            (scores.shape[0] * self.config.moe_topk, expert_output.size(1)),
            dtype=expert_output.dtype,
            device=expert_output.device,
        )
        unpermuted_tokens.index_copy_(0, sorted_indices, expert_output)
        unpermuted_tokens = unpermuted_tokens.view(-1, self.config.moe_topk, expert_output.size(1))

        output = (unpermuted_tokens * scores.unsqueeze(-1)).sum(dim=1).view(original_shape)

        # Add shared expert output
        shared_expert_output = self.shared_experts(hidden_states.view(original_shape))
        return output + shared_expert_output

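
# A minimal sketch of the routing bookkeeping above (illustrative, not part of the
# generated file; the helper name `_demo_moe_token_permutation` is hypothetical):
# sorting the flattened top-k expert indices groups token copies by expert, and
# `index_copy_` with the same sort order undoes the permutation.
def _demo_moe_token_permutation():
    moe_topk = 2
    top_indices = torch.tensor([[1, 0], [0, 1]])  # (num_tokens, moe_topk)
    hidden_states = torch.randn(2, 4)             # (num_tokens, hidden_size)
    sorted_indices = torch.argsort(top_indices.view(-1))
    permuted = hidden_states.index_select(0, sorted_indices // moe_topk)
    unpermuted = torch.zeros_like(permuted)
    unpermuted.index_copy_(0, sorted_indices, permuted)
    # Each token appears moe_topk times, back in its original slot order.
    assert torch.equal(unpermuted.view(-1, moe_topk, 4)[:, 0], hidden_states)
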
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

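
# A minimal sketch (illustrative, not part of the generated file; the helper name
# `_demo_rotary_pos_emb` is hypothetical): with a zero angle (cos=1, sin=0) the
# rotary embedding is the identity, and rotate_half maps [x1, x2] to [-x2, x1],
# the 90-degree rotation consumed by the sin term.
def _demo_rotary_pos_emb():
    q = torch.randn(1, 2, 3, 4)    # (batch, heads, seq_len, head_dim)
    k = torch.randn(1, 2, 3, 4)
    cos = torch.ones(1, 3, 4)      # (batch, seq_len, head_dim)
    sin = torch.zeros(1, 3, 4)
    q_embed, k_embed = apply_rotary_pos_emb(q, k, cos, sin)
    assert torch.equal(q_embed, q) and torch.equal(k_embed, k)
    assert torch.equal(rotate_half(torch.tensor([1.0, 2.0])), torch.tensor([-2.0, 1.0]))
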
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

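
# A minimal sketch (illustrative, not part of the generated file; the helper name
# `_demo_repeat_kv` is hypothetical): each key/value head is duplicated n_rep
# times along the head axis, matching torch.repeat_interleave on dim=1.
def _demo_repeat_kv():
    x = torch.randn(1, 2, 3, 4)    # (batch, kv_heads, seq_len, head_dim)
    out = repeat_kv(x, n_rep=3)
    assert out.shape == (1, 6, 3, 4)
    assert torch.equal(out, torch.repeat_interleave(x, repeats=3, dim=1))
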
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    key_states = repeat_kv(key, module.num_key_value_groups)
    value_states = repeat_kv(value, module.num_key_value_groups)

    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
        attn_weights = attn_weights + causal_mask

    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value_states)
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights

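
# A minimal sketch (illustrative, not part of the generated file; `_Stub` and
# `_demo_eager_attention` are hypothetical): with no mask, no dropout, and
# num_key_value_groups == 1, the eager path reduces to plain
# softmax(q @ k^T * scaling) @ v, i.e. scaled dot-product attention.
def _demo_eager_attention():
    class _Stub(nn.Module):
        num_key_value_groups = 1

    q = torch.randn(1, 2, 3, 4)    # (batch, heads, seq_len, head_dim)
    k = torch.randn(1, 2, 3, 4)
    v = torch.randn(1, 2, 3, 4)
    out, weights = eager_attention_forward(_Stub(), q, k, v, attention_mask=None, scaling=0.5)
    expected = torch.softmax(q @ k.transpose(2, 3) * 0.5, dim=-1) @ v
    assert torch.allclose(out, expected.transpose(1, 2), atol=1e-5)
    assert weights.shape == (1, 2, 3, 3)
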
| 497 |
+
class AriaTextAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: AriaTextConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = True

        self.q_proj = nn.Linear(
            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
        )
        self.k_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.v_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
        )

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_values: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_values is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights

class AriaTextDecoderLayer(GradientCheckpointingLayer):
    """
    Aria Text Decoder Layer.

    This class defines a single decoder layer in the language model, incorporating self-attention and Mixture of Experts (MoE) feed-forward network.

    Args:
        config (`AriaTextConfig`):
            Configuration object for the text component of the model.
        layer_idx (`int`):
            Index of the layer.
    """

    def __init__(self, config: AriaTextConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        self.self_attn = AriaTextAttention(config=config, layer_idx=layer_idx)
        self.mlp = AriaTextMoELayer(config)
        self.input_layernorm = AriaTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = AriaTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        # Self Attention
        hidden_states, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states

@auto_docstring
class AriaTextPreTrainedModel(PreTrainedModel):
    config: AriaTextConfig
    base_model_prefix = "model"
    _no_split_modules = ["AriaTextDecoderLayer", "AriaGroupedExpertsGemm"]
    supports_gradient_checkpointing = True
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn = True
    _supports_sdpa = True

    _supports_attention_backend = True
    _can_record_outputs = {
        "hidden_states": AriaTextDecoderLayer,
        "attentions": AriaTextAttention,
    }

    def _init_weights(self, module):
        super()._init_weights(module)
        if isinstance(module, AriaGroupedExpertsGemm):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)


@auto_docstring
class AriaPreTrainedModel(PreTrainedModel):
    config: AriaConfig
    base_model_prefix = ""
    supports_gradient_checkpointing = True
    _no_split_modules = ["AriaDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _can_compile_fullgraph = False  # MoE models don't work with torch.compile (dynamic slicing)
    _supports_attention_backend = True
    _can_record_outputs = {
        "hidden_states": AriaTextDecoderLayer,
        "attentions": AriaTextAttention,
    }

    def _init_weights(self, module):
        super()._init_weights(module)
        if isinstance(module, AriaProjector):
            nn.init.trunc_normal_(module.query, std=self.config.initializer_range)

class AriaTextRotaryEmbedding(nn.Module):
    inv_freq: torch.Tensor  # fix linting for `register_buffer`

    def __init__(self, config: AriaTextConfig, device=None):
        super().__init__()
        # BC: "rope_type" was originally "type"
        if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
        else:
            self.rope_type = "default"
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()

        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)

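The forward above is just an outer product of per-dimension inverse frequencies with the position ids, duplicated so that cos/sin cover the full head dimension. A standalone sketch of the default rope_type frequency computation, assuming rope_theta=10000.0 and a toy head_dim (both are illustrative values, not read from any config here), could be:

import torch

head_dim, rope_theta = 8, 10000.0  # toy values; the real ones come from the model config
inv_freq = 1.0 / (rope_theta ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))

position_ids = torch.arange(6, dtype=torch.float32)[None, :]   # (batch=1, seq=6)
freqs = position_ids[:, :, None] * inv_freq[None, None, :]     # (1, 6, head_dim/2)
emb = torch.cat((freqs, freqs), dim=-1)                        # (1, 6, head_dim)
cos, sin = emb.cos(), emb.sin()                                # what gets fed to apply_rotary_pos_emb
print(cos.shape, sin.shape)  # torch.Size([1, 6, 8]) torch.Size([1, 6, 8])
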
@auto_docstring
class AriaTextModel(AriaTextPreTrainedModel):
    def __init__(self, config: AriaTextConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [AriaTextDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = AriaTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = AriaTextRotaryEmbedding(config=config)
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    @check_model_inputs()
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPast:
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.config)

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position: torch.Tensor = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = create_causal_mask(
            config=self.config,
            input_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            past_key_values=past_key_values,
            position_ids=position_ids,
        )

        hidden_states = inputs_embeds
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **kwargs,
            )

        hidden_states = self.norm(hidden_states)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
        )

@auto_docstring
class AriaTextForCausalLM(AriaTextPreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

    def __init__(self, config: AriaTextConfig):
        super().__init__(config)
        self.model = AriaTextModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> CausalLMOutputWithPast:
        r"""
        Example:

        ```python
        >>> from transformers import AutoTokenizer, AriaTextForCausalLM

        >>> model = AriaTextForCausalLM.from_pretrained("meta-aria_text/AriaText-2-7b-hf")
        >>> tokenizer = AutoTokenizer.from_pretrained("meta-aria_text/AriaText-2-7b-hf")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

@dataclass
|
| 861 |
+
@auto_docstring(
|
| 862 |
+
custom_intro="""
|
| 863 |
+
Base class for Aria causal language model (or autoregressive) outputs.
|
| 864 |
+
"""
|
| 865 |
+
)
|
| 866 |
+
class AriaCausalLMOutputWithPast(ModelOutput):
|
| 867 |
+
r"""
|
| 868 |
+
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
|
| 869 |
+
Language modeling loss (for next-token prediction).
|
| 870 |
+
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
|
| 871 |
+
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
| 872 |
+
past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
| 873 |
+
It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
|
| 874 |
+
|
| 875 |
+
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
|
| 876 |
+
`past_key_values` input) to speed up sequential decoding.
|
| 877 |
+
image_hidden_states (`torch.FloatTensor`, *optional*):
|
| 878 |
+
A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
|
| 879 |
+
image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
|
| 880 |
+
"""
|
| 881 |
+
|
| 882 |
+
loss: Optional[torch.FloatTensor] = None
|
| 883 |
+
logits: Optional[torch.FloatTensor] = None
|
| 884 |
+
past_key_values: Optional[Cache] = None
|
| 885 |
+
hidden_states: Optional[tuple[torch.FloatTensor]] = None
|
| 886 |
+
attentions: Optional[tuple[torch.FloatTensor]] = None
|
| 887 |
+
image_hidden_states: Optional[torch.FloatTensor] = None
|
| 888 |
+
|
| 889 |
+
|
| 890 |
+
@dataclass
|
| 891 |
+
@auto_docstring(
|
| 892 |
+
custom_intro="""
|
| 893 |
+
Base class for Aria outputs, with hidden states and attentions.
|
| 894 |
+
"""
|
| 895 |
+
)
|
| 896 |
+
class AriaModelOutputWithPast(BaseModelOutputWithPast):
|
| 897 |
+
r"""
|
| 898 |
+
past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
| 899 |
+
It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
|
| 900 |
+
|
| 901 |
+
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
|
| 902 |
+
`past_key_values` input) to speed up sequential decoding.
|
| 903 |
+
image_hidden_states (`torch.FloatTensor`, *optional*):
|
| 904 |
+
A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
|
| 905 |
+
image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
|
| 906 |
+
"""
|
| 907 |
+
|
| 908 |
+
image_hidden_states: Optional[torch.FloatTensor] = None
|
| 909 |
+
|
| 910 |
+
|
| 911 |
+
@auto_docstring(
|
| 912 |
+
custom_intro="""
|
| 913 |
+
The Aria model which consists of a vision backbone and a language model, without a language modeling head.
|
| 914 |
+
"""
|
| 915 |
+
)
|
| 916 |
+
class AriaModel(AriaPreTrainedModel):
|
| 917 |
+
_checkpoint_conversion_mapping = {"language_model.model": "language_model"}
|
| 918 |
+
|
| 919 |
+
def __init__(self, config: AriaConfig):
|
| 920 |
+
super().__init__(config)
|
| 921 |
+
self.vision_tower = AutoModel.from_config(config.vision_config)
|
| 922 |
+
self.multi_modal_projector = AriaProjector(config)
|
| 923 |
+
self.language_model = AutoModel.from_config(config.text_config)
|
| 924 |
+
self.post_init()
|
| 925 |
+
|
| 926 |
+
def get_input_embeddings(self):
|
| 927 |
+
return self.language_model.get_input_embeddings()
|
| 928 |
+
|
| 929 |
+
def set_input_embeddings(self, value):
|
| 930 |
+
self.language_model.set_input_embeddings(value)
|
| 931 |
+
|
| 932 |
+
def set_decoder(self, decoder):
|
| 933 |
+
self.language_model = decoder
|
| 934 |
+
|
| 935 |
+
def get_decoder(self):
|
| 936 |
+
return self.language_model
|
| 937 |
+
|
| 938 |
+
def get_image_features(
|
| 939 |
+
self,
|
| 940 |
+
pixel_values: torch.FloatTensor,
|
| 941 |
+
pixel_mask: Optional[torch.FloatTensor] = None,
|
| 942 |
+
vision_feature_layer: int = -1,
|
| 943 |
+
):
|
| 944 |
+
"""
|
| 945 |
+
Obtains image last hidden states from the vision tower and apply multimodal projection.
|
| 946 |
+
|
| 947 |
+
Args:
|
| 948 |
+
pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`):
|
| 949 |
+
The tensors corresponding to the input images.
|
| 950 |
+
pixel_mask (`torch.FloatTensor]`, *optional*):
|
| 951 |
+
The tensors corresponding to the input image mask.
|
| 952 |
+
vision_feature_layer (`Union[int, list[int]]`, *optional*):
|
| 953 |
+
The index of the layer to select the vision feature. If multiple indices are provided,
|
| 954 |
+
the vision feature of the corresponding indices will be concatenated to form the
|
| 955 |
+
vision features.
|
| 956 |
+
Returns:
|
| 957 |
+
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
|
| 958 |
+
"""
|
| 959 |
+
vision_feature_layer = (
|
| 960 |
+
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
|
| 961 |
+
)
|
| 962 |
+
patch_attention_mask = self._create_patch_attention_mask(pixel_mask)
|
| 963 |
+
image_outputs = self.vision_tower(
|
| 964 |
+
pixel_values, patch_attention_mask=patch_attention_mask, output_hidden_states=True
|
| 965 |
+
)
|
| 966 |
+
image_attn_mask = None
|
| 967 |
+
if patch_attention_mask is not None:
|
| 968 |
+
flattened_mask = patch_attention_mask.flatten(1)
|
| 969 |
+
image_attn_mask = torch.logical_not(flattened_mask)
|
| 970 |
+
|
| 971 |
+
selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
|
| 972 |
+
image_features = self.multi_modal_projector(selected_image_feature, attn_mask=image_attn_mask)
|
| 973 |
+
return image_features
|
| 974 |
+
|
| 975 |
+
def get_placeholder_mask(
|
| 976 |
+
self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor
|
| 977 |
+
):
|
| 978 |
+
"""
|
| 979 |
+
Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
|
| 980 |
+
equal to the length of multimodal features. If the lengths are different, an error is raised.
|
| 981 |
+
"""
|
| 982 |
+
if input_ids is None:
|
| 983 |
+
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
| 984 |
+
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
|
| 985 |
+
)
|
| 986 |
+
special_image_mask = special_image_mask.all(-1)
|
| 987 |
+
else:
|
| 988 |
+
special_image_mask = input_ids == self.config.image_token_id
|
| 989 |
+
|
| 990 |
+
n_image_tokens = special_image_mask.sum()
|
| 991 |
+
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
| 992 |
+
n_image_features = image_features.shape[0] * image_features.shape[1]
|
| 993 |
+
if inputs_embeds[special_image_mask].numel() != image_features.numel():
|
| 994 |
+
raise ValueError(
|
| 995 |
+
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
| 996 |
+
)
|
| 997 |
+
return special_image_mask
|
| 998 |
+
|
| 999 |
+
@can_return_tuple
|
| 1000 |
+
@auto_docstring
|
| 1001 |
+
def forward(
|
| 1002 |
+
self,
|
| 1003 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 1004 |
+
pixel_values: Optional[torch.FloatTensor] = None,
|
| 1005 |
+
pixel_mask: Optional[torch.LongTensor] = None,
|
| 1006 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1007 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 1008 |
+
past_key_values: Optional[Cache] = None,
|
| 1009 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 1010 |
+
use_cache: Optional[bool] = None,
|
| 1011 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 1012 |
+
**kwargs: Unpack[FlashAttentionKwargs],
|
| 1013 |
+
) -> Union[tuple, AriaModelOutputWithPast]:
|
| 1014 |
+
if inputs_embeds is None:
|
| 1015 |
+
inputs_embeds = self.get_input_embeddings()(input_ids)
|
| 1016 |
+
|
| 1017 |
+
# 2. Merge text and images
|
| 1018 |
+
if pixel_values is not None and inputs_embeds.shape[1] != 1:
|
| 1019 |
+
image_features = self.get_image_features(
|
| 1020 |
+
pixel_values=pixel_values,
|
| 1021 |
+
pixel_mask=pixel_mask,
|
| 1022 |
+
vision_feature_layer=self.config.vision_feature_layer,
|
| 1023 |
+
)
|
| 1024 |
+
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
| 1025 |
+
special_image_mask = self.get_placeholder_mask(
|
| 1026 |
+
input_ids, inputs_embeds=inputs_embeds, image_features=image_features
|
| 1027 |
+
)
|
| 1028 |
+
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
|
| 1029 |
+
|
| 1030 |
+
outputs = self.language_model(
|
| 1031 |
+
attention_mask=attention_mask,
|
| 1032 |
+
position_ids=position_ids,
|
| 1033 |
+
past_key_values=past_key_values,
|
| 1034 |
+
inputs_embeds=inputs_embeds,
|
| 1035 |
+
use_cache=use_cache,
|
| 1036 |
+
cache_position=cache_position,
|
| 1037 |
+
**kwargs,
|
| 1038 |
+
)
|
| 1039 |
+
|
| 1040 |
+
return AriaModelOutputWithPast(
|
| 1041 |
+
last_hidden_state=outputs.last_hidden_state,
|
| 1042 |
+
past_key_values=outputs.past_key_values if use_cache else None,
|
| 1043 |
+
hidden_states=outputs.hidden_states,
|
| 1044 |
+
attentions=outputs.attentions,
|
| 1045 |
+
image_hidden_states=image_features if pixel_values is not None else None,
|
| 1046 |
+
)
|
| 1047 |
+
|
| 1048 |
+
def _create_patch_attention_mask(self, pixel_mask):
|
| 1049 |
+
if pixel_mask is None:
|
| 1050 |
+
return None
|
| 1051 |
+
|
| 1052 |
+
patches_subgrid = pixel_mask.unfold(
|
| 1053 |
+
dimension=1,
|
| 1054 |
+
size=self.vision_tower.config.patch_size,
|
| 1055 |
+
step=self.vision_tower.config.patch_size,
|
| 1056 |
+
)
|
| 1057 |
+
patches_subgrid = patches_subgrid.unfold(
|
| 1058 |
+
dimension=2,
|
| 1059 |
+
size=self.vision_tower.config.patch_size,
|
| 1060 |
+
step=self.vision_tower.config.patch_size,
|
| 1061 |
+
)
|
| 1062 |
+
return (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
|
| 1063 |
+
|
| 1064 |
+
|
| 1065 |
+
@auto_docstring(
|
| 1066 |
+
custom_intro="""
|
| 1067 |
+
Aria model for conditional generation tasks.
|
| 1068 |
+
|
| 1069 |
+
This model combines a vision tower, a multi-modal projector, and a language model
|
| 1070 |
+
to perform tasks that involve both image and text inputs.
|
| 1071 |
+
"""
|
| 1072 |
+
)
|
| 1073 |
+
class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
|
| 1074 |
+
_checkpoint_conversion_mapping = {
|
| 1075 |
+
"^language_model.model": "model.language_model",
|
| 1076 |
+
"^vision_tower": "model.vision_tower",
|
| 1077 |
+
"^multi_modal_projector": "model.multi_modal_projector",
|
| 1078 |
+
"^language_model.lm_head": "lm_head",
|
| 1079 |
+
}
|
| 1080 |
+
_tied_weights_keys = ["lm_head.weight"]
|
| 1081 |
+
|
| 1082 |
+
def __init__(self, config: AriaConfig):
|
| 1083 |
+
super().__init__(config)
|
| 1084 |
+
self.model = AriaModel(config)
|
| 1085 |
+
self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
|
| 1086 |
+
self.post_init()
|
| 1087 |
+
|
| 1088 |
+
def get_input_embeddings(self):
|
| 1089 |
+
return self.model.get_input_embeddings()
|
| 1090 |
+
|
| 1091 |
+
def set_input_embeddings(self, value):
|
| 1092 |
+
self.model.set_input_embeddings(value)
|
| 1093 |
+
|
| 1094 |
+
def get_output_embeddings(self) -> nn.Module:
|
| 1095 |
+
return self.lm_head
|
| 1096 |
+
|
| 1097 |
+
def set_decoder(self, decoder):
|
| 1098 |
+
self.model.set_decoder(decoder)
|
| 1099 |
+
|
| 1100 |
+
def get_decoder(self):
|
| 1101 |
+
return self.model.get_decoder()
|
| 1102 |
+
|
| 1103 |
+
def get_image_features(
|
| 1104 |
+
self,
|
| 1105 |
+
pixel_values: torch.FloatTensor,
|
| 1106 |
+
pixel_mask: Optional[torch.FloatTensor] = None,
|
| 1107 |
+
vision_feature_layer: int = -1,
|
| 1108 |
+
):
|
| 1109 |
+
return self.model.get_image_features(
|
| 1110 |
+
pixel_values=pixel_values,
|
| 1111 |
+
pixel_mask=pixel_mask,
|
| 1112 |
+
vision_feature_layer=vision_feature_layer,
|
| 1113 |
+
)
|
| 1114 |
+
|
| 1115 |
+
# Make modules available through conditional class for BC
|
| 1116 |
+
@property
|
| 1117 |
+
def language_model(self):
|
| 1118 |
+
return self.model.language_model
|
| 1119 |
+
|
| 1120 |
+
@property
|
| 1121 |
+
def vision_tower(self):
|
| 1122 |
+
return self.model.vision_tower
|
| 1123 |
+
|
| 1124 |
+
@property
|
| 1125 |
+
def multi_modal_projector(self):
|
| 1126 |
+
return self.model.multi_modal_projector
|
| 1127 |
+
|
| 1128 |
+
@can_return_tuple
|
| 1129 |
+
@auto_docstring
|
| 1130 |
+
def forward(
|
| 1131 |
+
self,
|
| 1132 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 1133 |
+
pixel_values: Optional[torch.FloatTensor] = None,
|
| 1134 |
+
pixel_mask: Optional[torch.LongTensor] = None,
|
| 1135 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1136 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 1137 |
+
past_key_values: Optional[Cache] = None,
|
| 1138 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 1139 |
+
labels: Optional[torch.LongTensor] = None,
|
| 1140 |
+
use_cache: Optional[bool] = None,
|
| 1141 |
+
logits_to_keep: Union[int, torch.Tensor] = 0,
|
| 1142 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 1143 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 1144 |
+
) -> Union[tuple, AriaCausalLMOutputWithPast]:
|
| 1145 |
+
r"""
|
| 1146 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 1147 |
+
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
| 1148 |
+
config.vocab_size]` or `model.image_token_id` (where `model` is your instance of `AriaForConditionalGeneration`).
|
| 1149 |
+
Tokens with indices set to `model.image_token_id` are ignored (masked), the loss is only
|
| 1150 |
+
computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
| 1151 |
+
|
| 1152 |
+
Example:
|
| 1153 |
+
|
| 1154 |
+
```python
|
| 1155 |
+
>>> import requests
|
| 1156 |
+
>>> import torch
|
| 1157 |
+
>>> from PIL import Image
|
| 1158 |
+
>>> from io import BytesIO
|
| 1159 |
+
|
| 1160 |
+
>>> from transformers import AutoProcessor, AutoModel
|
| 1161 |
+
>>> from transformers.image_utils import load_image
|
| 1162 |
+
|
| 1163 |
+
>>> # Note that passing the image urls (instead of the actual pil images) to the processor is also possible
|
| 1164 |
+
>>> image1 = load_image("https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg")
|
| 1165 |
+
>>> image2 = load_image("https://cdn.britannica.com/59/94459-050-DBA42467/Skyline-Chicago.jpg")
|
| 1166 |
+
>>> image3 = load_image("https://cdn.britannica.com/68/170868-050-8DDE8263/Golden-Gate-Bridge-San-Francisco.jpg")
|
| 1167 |
+
|
| 1168 |
+
>>> processor = AutoProcessor.from_pretrained("Rhymes-AI/Aria")
|
| 1169 |
+
>>> model = AutoModel.from_pretrained("Rhymes-AI/Aria", dtype=torch.bfloat16, device_map="auto")
|
| 1170 |
+
|
| 1171 |
+
>>> # Create inputs
|
| 1172 |
+
>>> messages = [
|
| 1173 |
+
... {
|
| 1174 |
+
... "role": "user",
|
| 1175 |
+
... "content": [
|
| 1176 |
+
... {"type": "image"},
|
| 1177 |
+
... {"type": "text", "text": "In this image, we can see the city of New York, and more specifically the Statue of Liberty."},
|
| 1178 |
+
... {"type": "image"},
|
| 1179 |
+
... {"type": "text", "text": "What can we see in this image?"},
|
| 1180 |
+
... ]
|
| 1181 |
+
... },
|
| 1182 |
+
... {
|
| 1183 |
+
... "role": "user",
|
| 1184 |
+
... "content": [
|
| 1185 |
+
... {"type": "image"},
|
| 1186 |
+
... {"type": "text", "text": "In which city is that bridge located?"},
|
| 1187 |
+
... ]
|
| 1188 |
+
... }
|
| 1189 |
+
... ]
|
| 1190 |
+
|
| 1191 |
+
>>> prompts = [processor.apply_chat_template([message], add_generation_prompt=True) for message in messages]
|
| 1192 |
+
>>> images = [[image1, image2], [image3]]
|
| 1193 |
+
>>> inputs = processor(text=prompts, images=images, padding=True, return_tensors="pt").to(model.device)
|
| 1194 |
+
|
| 1195 |
+
>>> # Generate
|
| 1196 |
+
>>> generated_ids = model.generate(**inputs, max_new_tokens=256)
|
| 1197 |
+
>>> generated_texts = processor.batch_decode(generated_ids, skip_special_tokens=True)
|
| 1198 |
+
|
| 1199 |
+
>>> print(generated_texts[0])
|
| 1200 |
+
Assistant: There are buildings, trees, lights, and water visible in this image.
|
| 1201 |
+
|
| 1202 |
+
>>> print(generated_texts[1])
|
| 1203 |
+
Assistant: The bridge is in San Francisco.
|
| 1204 |
+
```"""
|
| 1205 |
+
outputs = self.model(
|
| 1206 |
+
input_ids=input_ids,
|
| 1207 |
+
pixel_values=pixel_values,
|
| 1208 |
+
pixel_mask=pixel_mask,
|
| 1209 |
+
attention_mask=attention_mask,
|
| 1210 |
+
position_ids=position_ids,
|
| 1211 |
+
past_key_values=past_key_values,
|
| 1212 |
+
inputs_embeds=inputs_embeds,
|
| 1213 |
+
use_cache=use_cache,
|
| 1214 |
+
cache_position=cache_position,
|
| 1215 |
+
**kwargs,
|
| 1216 |
+
)
|
| 1217 |
+
|
| 1218 |
+
hidden_states = outputs[0]
|
| 1219 |
+
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
| 1220 |
+
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
| 1221 |
+
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
| 1222 |
+
|
| 1223 |
+
loss = None
|
| 1224 |
+
if labels is not None:
|
| 1225 |
+
loss = self.loss_function(
|
| 1226 |
+
logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
|
| 1227 |
+
)
|
| 1228 |
+
|
| 1229 |
+
return AriaCausalLMOutputWithPast(
|
| 1230 |
+
loss=loss,
|
| 1231 |
+
logits=logits,
|
| 1232 |
+
past_key_values=outputs.past_key_values,
|
| 1233 |
+
hidden_states=outputs.hidden_states,
|
| 1234 |
+
attentions=outputs.attentions,
|
| 1235 |
+
)
|
| 1236 |
+
|
| 1237 |
+
def prepare_inputs_for_generation(
|
| 1238 |
+
self,
|
| 1239 |
+
input_ids,
|
| 1240 |
+
past_key_values=None,
|
| 1241 |
+
inputs_embeds=None,
|
| 1242 |
+
pixel_values=None,
|
| 1243 |
+
pixel_mask=None,
|
| 1244 |
+
attention_mask=None,
|
| 1245 |
+
cache_position=None,
|
| 1246 |
+
logits_to_keep=None,
|
| 1247 |
+
**kwargs,
|
| 1248 |
+
):
|
| 1249 |
+
model_inputs = super().prepare_inputs_for_generation(
|
| 1250 |
+
input_ids,
|
| 1251 |
+
past_key_values=past_key_values,
|
| 1252 |
+
inputs_embeds=inputs_embeds,
|
| 1253 |
+
attention_mask=attention_mask,
|
| 1254 |
+
cache_position=cache_position,
|
| 1255 |
+
logits_to_keep=logits_to_keep,
|
| 1256 |
+
**kwargs,
|
| 1257 |
+
)
|
| 1258 |
+
|
| 1259 |
+
if cache_position[0] == 0:
|
| 1260 |
+
# If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
|
| 1261 |
+
# Otherwise we need pixel values to be passed to model
|
| 1262 |
+
model_inputs["pixel_values"] = pixel_values
|
| 1263 |
+
model_inputs["pixel_mask"] = pixel_mask
|
| 1264 |
+
|
| 1265 |
+
return model_inputs
|
| 1266 |
+
|
| 1267 |
+
|
| 1268 |
+
__all__ = [
|
| 1269 |
+
"AriaForConditionalGeneration",
|
| 1270 |
+
"AriaPreTrainedModel",
|
| 1271 |
+
"AriaTextPreTrainedModel",
|
| 1272 |
+
"AriaTextModel",
|
| 1273 |
+
"AriaModel",
|
| 1274 |
+
"AriaTextForCausalLM",
|
| 1275 |
+
]
|
venv/lib/python3.13/site-packages/transformers/models/aria/modular_aria.py
ADDED
@@ -0,0 +1,1610 @@
# coding=utf-8
# Copyright 2024 The Rhymes-AI Teams Authors and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Iterable
from typing import Optional, Union

import numpy as np
import torch
from torch import nn

from ...activations import ACT2FN
from ...cache_utils import Cache
from ...configuration_utils import PretrainedConfig
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_patch_output_size, select_best_resolution
from ...image_transforms import PaddingMode, convert_to_rgb, pad, resize, to_channel_dimension_format
from ...image_utils import (
    ChannelDimension,
    ImageInput,
    PILImageResampling,
    get_image_size,
    infer_channel_dimension_format,
    is_scaled_image,
    make_flat_list_of_images,
    to_numpy_array,
    valid_images,
    validate_preprocess_arguments,
)
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_utils import PreTrainedModel
from ...processing_utils import MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack
from ...tokenization_utils import PreTokenizedInput, TextInput
from ...utils import TensorType, TransformersKwargs, auto_docstring, can_return_tuple, logging
from ..auto import CONFIG_MAPPING, AutoConfig, AutoTokenizer
from ..llama.configuration_llama import LlamaConfig
from ..llama.modeling_llama import (
    LlamaAttention,
    LlamaDecoderLayer,
    LlamaForCausalLM,
    LlamaMLP,
    LlamaModel,
    LlamaPreTrainedModel,
    LlamaRMSNorm,
)
from ..llava.modeling_llava import (
    LlavaCausalLMOutputWithPast,
    LlavaForConditionalGeneration,
    LlavaModel,
    LlavaModelOutputWithPast,
)
from ..llava_next.image_processing_llava_next import divide_to_patches


logger = logging.get_logger(__name__)

def sequential_experts_gemm(token_states, expert_weights, tokens_per_expert):
    """
    Compute the matrix multiplication (GEMM) for each expert sequentially. This approach is computationally inefficient, especially when dealing with a large number of experts.

    Args:
        token_states (torch.Tensor): Input tensor of shape (num_tokens, in_features).
        expert_weights (torch.Tensor): Weight tensor of shape (num_experts, in_features, out_features).
        tokens_per_expert (torch.Tensor): Number of tokens assigned to each expert.

    Returns:
        torch.Tensor: Output tensor of shape (num_tokens, out_features).
    """
    num_tokens = token_states.shape[0]
    out_features = expert_weights.shape[-1]
    output = torch.zeros(num_tokens, out_features, dtype=token_states.dtype, device=token_states.device)

    cumsum_num_tokens = torch.cumsum(tokens_per_expert, dim=0)
    # Insert zero at the beginning for offset index's convenience
    zero_tensor = torch.zeros(1, dtype=torch.long, device=cumsum_num_tokens.device)
    cumsum_num_tokens = torch.cat((zero_tensor, cumsum_num_tokens))

    for expert_num in range(expert_weights.shape[0]):
        start = cumsum_num_tokens[expert_num]
        end = cumsum_num_tokens[expert_num + 1]
        tokens = token_states[start:end]

        out = torch.matmul(tokens, expert_weights[expert_num])
        output[start:end] = out
    return output

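A tiny usage sketch of the routine above. The token counts and feature sizes are made up for illustration, and the tokens are assumed to arrive already sorted by expert (which is what the cumulative-sum slicing relies on):

import torch

num_experts, in_features, out_features = 4, 16, 32
tokens_per_expert = torch.tensor([3, 0, 5, 2])              # must sum to num_tokens; experts may be empty
token_states = torch.randn(int(tokens_per_expert.sum()), in_features)
expert_weights = torch.randn(num_experts, in_features, out_features)

out = sequential_experts_gemm(token_states, expert_weights, tokens_per_expert)
print(out.shape)  # torch.Size([10, 32])
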
class AriaTextConfig(LlamaConfig):
|
| 99 |
+
r"""
|
| 100 |
+
This class handles the configuration for the text component of the Aria model.
|
| 101 |
+
Instantiating a configuration with the defaults will yield a similar configuration to that of the model of the Aria
|
| 102 |
+
[rhymes-ai/Aria](https://huggingface.co/rhymes-ai/Aria) architecture.
|
| 103 |
+
This class extends the LlamaConfig to include additional parameters specific to the Mixture of Experts (MoE) architecture.
|
| 104 |
+
|
| 105 |
+
Args:
|
| 106 |
+
vocab_size (`int`, *optional*, defaults to 32000):
|
| 107 |
+
Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the
|
| 108 |
+
`inputs_ids` passed when calling [`LlamaModel`]
|
| 109 |
+
hidden_size (`int`, *optional*, defaults to 4096):
|
| 110 |
+
Dimension of the hidden representations.
|
| 111 |
+
intermediate_size (`int`, *optional*, defaults to 4096):
|
| 112 |
+
The size of the MLP representations.
|
| 113 |
+
num_hidden_layers (`int`, *optional*, defaults to 32):
|
| 114 |
+
Number of hidden layers in the Transformer decoder.
|
| 115 |
+
num_attention_heads (`int`, *optional*, defaults to 32):
|
| 116 |
+
Number of attention heads for each attention layer in the Transformer decoder.
|
| 117 |
+
num_key_value_heads (`int`, *optional*):
|
| 118 |
+
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
| 119 |
+
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
| 120 |
+
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
|
| 121 |
+
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
|
| 122 |
+
by meanpooling all the original heads within that group. For more details, check out [this
|
| 123 |
+
paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
|
| 124 |
+
`num_attention_heads`.
|
| 125 |
+
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
| 126 |
+
The non-linear activation function (function or string) in the decoder.
|
| 127 |
+
max_position_embeddings (`int`, *optional*, defaults to 2048):
|
| 128 |
+
The maximum sequence length that this model might ever be used with. Llama 1 supports up to 2048 tokens,
|
| 129 |
+
Llama 2 up to 4096, CodeLlama up to 16384.
|
| 130 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
| 131 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
| 132 |
+
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
|
| 133 |
+
The epsilon used by the rms normalization layers.
|
| 134 |
+
use_cache (`bool`, *optional*, defaults to `True`):
|
| 135 |
+
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
| 136 |
+
relevant if `config.is_decoder=True`.
|
| 137 |
+
pad_token_id (`int`, *optional*, defaults to 2):
|
| 138 |
+
Padding token id.
|
| 139 |
+
bos_token_id (`int`, *optional*, defaults to 1):
|
| 140 |
+
Beginning of stream token id.
|
| 141 |
+
eos_token_id (`int`, *optional*, defaults to 2):
|
| 142 |
+
End of stream token id.
|
| 143 |
+
pretraining_tp (`int`, *optional*, defaults to 1):
|
| 144 |
+
Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
|
| 145 |
+
document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) to
|
| 146 |
+
understand more about it. This value is necessary to ensure exact reproducibility of the pretraining
|
| 147 |
+
results. Please refer to [this issue](https://github.com/pytorch/pytorch/issues/76232).
|
| 148 |
+
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
| 149 |
+
Whether to tie weight embeddings
|
| 150 |
+
rope_theta (`float`, *optional*, defaults to 10000.0):
|
| 151 |
+
The base period of the RoPE embeddings.
|
| 152 |
+
rope_scaling (`Dict`, *optional*):
|
| 153 |
+
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
|
| 154 |
+
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
|
| 155 |
+
accordingly.
|
| 156 |
+
Expected contents:
|
| 157 |
+
`rope_type` (`str`):
|
| 158 |
+
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
|
| 159 |
+
'llama3'], with 'default' being the original RoPE implementation.
|
| 160 |
+
`factor` (`float`, *optional*):
|
| 161 |
+
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
|
| 162 |
+
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
|
| 163 |
+
original maximum pre-trained length.
|
| 164 |
+
`original_max_position_embeddings` (`int`, *optional*):
|
| 165 |
+
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
|
| 166 |
+
pretraining.
|
| 167 |
+
`attention_factor` (`float`, *optional*):
|
| 168 |
+
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
|
| 169 |
+
computation. If unspecified, it defaults to value recommended by the implementation, using the
|
| 170 |
+
`factor` field to infer the suggested value.
|
| 171 |
+
`beta_fast` (`float`, *optional*):
|
| 172 |
+
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
|
| 173 |
+
ramp function. If unspecified, it defaults to 32.
|
| 174 |
+
`beta_slow` (`float`, *optional*):
|
| 175 |
+
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
|
| 176 |
+
ramp function. If unspecified, it defaults to 1.
|
| 177 |
+
`short_factor` (`list[float]`, *optional*):
|
| 178 |
+
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
|
| 179 |
+
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
| 180 |
+
size divided by the number of attention heads divided by 2
|
| 181 |
+
`long_factor` (`list[float]`, *optional*):
|
| 182 |
+
Only used with 'longrope'. The scaling factor to be applied to long contexts (<
|
| 183 |
+
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
| 184 |
+
size divided by the number of attention heads divided by 2
|
| 185 |
+
`low_freq_factor` (`float`, *optional*):
|
| 186 |
+
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
|
| 187 |
+
`high_freq_factor` (`float`, *optional*):
|
| 188 |
+
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
|
| 189 |
+
attention_bias (`bool`, *optional*, defaults to `False`):
|
| 190 |
+
Whether to use a bias in the query, key, value and output projection layers during self-attention.
|
| 191 |
+
attention_dropout (`float`, *optional*, defaults to 0.0):
|
| 192 |
+
The dropout ratio for the attention probabilities.
|
| 193 |
+
mlp_bias (`bool`, *optional*, defaults to `False`):
|
| 194 |
+
Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers.
|
| 195 |
+
head_dim (`int`, *optional*):
|
| 196 |
+
The attention head dimension. If None, it will default to hidden_size // num_heads
|
| 197 |
+
moe_num_experts (`int`, *optional*, defaults to 8):
|
| 198 |
+
The number of experts in the MoE layer.
|
| 199 |
+
moe_topk (`int`, *optional*, defaults to 2):
|
| 200 |
+
The number of top experts to route to for each token.
|
| 201 |
+
moe_num_shared_experts (`int`, *optional*, defaults to 2):
|
| 202 |
+
The number of shared experts.
|
| 203 |
+
"""
|
| 204 |
+
|
| 205 |
+
model_type = "aria_text"
|
| 206 |
+
base_config_key = "text_config"
|
| 207 |
+
|
| 208 |
+
def __init__(
|
| 209 |
+
self,
|
| 210 |
+
intermediate_size: int = 4096,
|
| 211 |
+
moe_num_experts: int = 8,
|
| 212 |
+
moe_topk: int = 2,
|
| 213 |
+
moe_num_shared_experts: int = 2,
|
| 214 |
+
pad_token_id=2,
|
| 215 |
+
**super_kwargs,
|
| 216 |
+
):
|
| 217 |
+
super().__init__(pad_token_id=pad_token_id, **super_kwargs)
|
| 218 |
+
self.intermediate_size = intermediate_size
|
| 219 |
+
self.moe_num_experts = moe_num_experts
|
| 220 |
+
self.moe_topk = moe_topk
|
| 221 |
+
self.moe_num_shared_experts = moe_num_shared_experts
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
class AriaConfig(PretrainedConfig):
    r"""
    This class handles the configuration for both vision and text components of the Aria model,
    as well as additional parameters for image token handling and projector mapping.
    Instantiating a configuration with the defaults will yield a similar configuration to that of the Aria
    [rhymes-ai/Aria](https://huggingface.co/rhymes-ai/Aria) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vision_config (`AriaVisionConfig` or `dict`, *optional*):
            Configuration for the vision component.
        vision_feature_layer (`int`, *optional*, defaults to -1):
            The index of the layer to select the vision feature.
        text_config (`AriaTextConfig` or `dict`, *optional*):
            Configuration for the text component.
        projector_patch_to_query_dict (`dict`, *optional*):
            Mapping of patch sizes to query dimensions.
        image_token_index (`int`, *optional*, defaults to 9):
            Index used to represent image tokens.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated normal initializer for initializing all weight matrices.

    Attributes:
        model_type (`str`):
            Type of the model, set to `"aria"`.
        image_token_index (`int`):
            Index used to represent image tokens.
        projector_patch_to_query_dict (`dict`):
            Mapping of patch sizes to query dimensions.
        vision_config (`AriaVisionConfig`):
            Configuration for the vision component.
        text_config (`AriaTextConfig`):
            Configuration for the text component.
    """

    model_type = "aria"
    attribute_map = {
        "image_token_id": "image_token_index",
    }
    sub_configs = {"text_config": AriaTextConfig, "vision_config": AutoConfig}

    def __init__(
        self,
        vision_config=None,
        vision_feature_layer: int = -1,
        text_config: AriaTextConfig = None,
        projector_patch_to_query_dict: Optional[dict] = None,
        image_token_index: int = 9,
        initializer_range: float = 0.02,
        **kwargs,
    ):
        self.image_token_index = image_token_index

        # Convert the keys and values of projector_patch_to_query_dict to integers
        # This ensures consistency even if they were provided as strings
        if projector_patch_to_query_dict is None:
            projector_patch_to_query_dict = {
                1225: 128,
                4900: 256,
            }
        self.projector_patch_to_query_dict = {int(k): int(v) for k, v in projector_patch_to_query_dict.items()}
        self.max_value_projector_patch_to_query_dict = max(self.projector_patch_to_query_dict.values())
        self.vision_feature_layer = vision_feature_layer
        if isinstance(vision_config, dict):
            vision_config["model_type"] = "idefics3_vision"
            vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config)
        elif vision_config is None:
            vision_config = CONFIG_MAPPING["idefics3_vision"]()

        self.vision_config = vision_config
        self.initializer_range = initializer_range

        if isinstance(text_config, dict) and "model_type" in text_config:
            text_config = AriaTextConfig(**text_config)
        elif text_config is None:
            text_config = AriaTextConfig()

        self.text_config = text_config

        super().__init__(**kwargs)


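# Editor's note: a minimal sketch, not part of the original file, showing how the
# constructor above normalizes its inputs: string keys in `projector_patch_to_query_dict`
# are coerced to int, and a plain dict for the text sub-config is promoted to a config object.
def _example_aria_config():
    config = AriaConfig(
        projector_patch_to_query_dict={"1225": 128, "4900": 256},
        text_config={"model_type": "aria_text", "moe_num_experts": 8},
    )
    assert config.projector_patch_to_query_dict == {1225: 128, 4900: 256}
    assert config.max_value_projector_patch_to_query_dict == 256
    assert isinstance(config.text_config, AriaTextConfig)
    return config

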
class AriaTextRMSNorm(LlamaRMSNorm):
    pass


class AriaProjectorMLP(nn.Module):
    """
    Feed-Forward Network module for the Aria Projector.

    Args:
        in_features (`int`):
            Input embedding dimension.
        hidden_features (`int`):
            Hidden dimension of the feed-forward network.
        output_dim (`int`):
            Output dimension.
    """

    def __init__(self, in_features, hidden_features, output_dim):
        super().__init__()
        self.linear_in = nn.Linear(in_features, hidden_features, bias=False)
        self.linear_out = nn.Linear(hidden_features, output_dim, bias=False)
        self.act = ACT2FN["gelu_new"]

    def forward(self, hidden_states):
        hidden_states = self.act(self.linear_in(hidden_states))
        hidden_states = self.linear_out(hidden_states)
        return hidden_states


class AriaCrossAttention(nn.Module):
|
| 338 |
+
"""
|
| 339 |
+
Aria Cross-Attention module.
|
| 340 |
+
|
| 341 |
+
Args:
|
| 342 |
+
config (`AriaConfig`):
|
| 343 |
+
The configuration to use.
|
| 344 |
+
"""
|
| 345 |
+
|
| 346 |
+
def __init__(self, config: AriaConfig, dropout_rate: float = 0):
|
| 347 |
+
super().__init__()
|
| 348 |
+
hidden_size = config.vision_config.hidden_size
|
| 349 |
+
num_heads = config.vision_config.num_attention_heads
|
| 350 |
+
self.num_heads = num_heads
|
| 351 |
+
self.q_proj = nn.Linear(hidden_size, hidden_size, bias=False)
|
| 352 |
+
self.k_proj = nn.Linear(hidden_size, hidden_size, bias=False)
|
| 353 |
+
self.v_proj = nn.Linear(hidden_size, hidden_size, bias=False)
|
| 354 |
+
|
| 355 |
+
# Original code here: https://github.com/rhymes-ai/Aria/blob/719ff4e52b727443cba3793b0e27fe64e0244fe1/aria/model/projector.py#L48
|
| 356 |
+
self.multihead_attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
|
| 357 |
+
self.linear = nn.Linear(hidden_size, hidden_size)
|
| 358 |
+
self.dropout = nn.Dropout(dropout_rate)
|
| 359 |
+
|
| 360 |
+
self.layer_norm = nn.LayerNorm(hidden_size)
|
| 361 |
+
self.layer_norm_kv = nn.LayerNorm(hidden_size)
|
| 362 |
+
|
| 363 |
+
def forward(self, key_value_states, hidden_states, attn_mask=None):
|
| 364 |
+
"""
|
| 365 |
+
Forward pass of the AriaCrossAttention module.
|
| 366 |
+
|
| 367 |
+
Args:
|
| 368 |
+
key_value_states (`torch.Tensor`):
|
| 369 |
+
Input tensor for key and value.
|
| 370 |
+
hidden_states (`torch.Tensor`):
|
| 371 |
+
Input tensor for query.
|
| 372 |
+
attn_mask (`torch.Tensor`, *optional*, defaults to None):
|
| 373 |
+
Attention mask.
|
| 374 |
+
|
| 375 |
+
Returns:
|
| 376 |
+
torch.Tensor:
|
| 377 |
+
Output tensor after cross-attention.
|
| 378 |
+
"""
|
| 379 |
+
query = self.q_proj(self.layer_norm(hidden_states))
|
| 380 |
+
|
| 381 |
+
key_value_states = self.layer_norm_kv(key_value_states)
|
| 382 |
+
key = self.k_proj(key_value_states)
|
| 383 |
+
value = self.v_proj(key_value_states)
|
| 384 |
+
|
| 385 |
+
attn_output, _ = self.multihead_attn(query, key, value, attn_mask=attn_mask)
|
| 386 |
+
|
| 387 |
+
attn_output = self.dropout(self.linear(attn_output))
|
| 388 |
+
|
| 389 |
+
return attn_output
|
| 390 |
+
|
| 391 |
+
|
| 392 |
+
class AriaProjector(nn.Module):
|
| 393 |
+
"""
|
| 394 |
+
Aria Projector module.
|
| 395 |
+
|
| 396 |
+
This module projects vision features into the language model's embedding space, enabling interaction between vision and language components.
|
| 397 |
+
|
| 398 |
+
Args:
|
| 399 |
+
config (`AriaConfig`):
|
| 400 |
+
Configuration object for the model.
|
| 401 |
+
"""
|
| 402 |
+
|
| 403 |
+
def __init__(
|
| 404 |
+
self,
|
| 405 |
+
config: AriaConfig,
|
| 406 |
+
):
|
| 407 |
+
super().__init__()
|
| 408 |
+
|
| 409 |
+
self.patch_to_query_dict = config.projector_patch_to_query_dict
|
| 410 |
+
self.in_features = config.vision_config.hidden_size
|
| 411 |
+
self.num_heads = config.vision_config.num_attention_heads
|
| 412 |
+
self.kv_dim = config.vision_config.hidden_size
|
| 413 |
+
self.hidden_features = config.text_config.hidden_size
|
| 414 |
+
self.output_dim = config.text_config.hidden_size
|
| 415 |
+
|
| 416 |
+
self.query = nn.Parameter(torch.zeros(config.max_value_projector_patch_to_query_dict, self.in_features))
|
| 417 |
+
|
| 418 |
+
self.cross_attn = AriaCrossAttention(config)
|
| 419 |
+
|
| 420 |
+
self.layer_norm = nn.LayerNorm(self.in_features)
|
| 421 |
+
self.feed_forward = AriaProjectorMLP(self.in_features, self.hidden_features, self.output_dim)
|
| 422 |
+
|
| 423 |
+
def forward(self, key_value_states: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
|
| 424 |
+
"""
|
| 425 |
+
Forward pass of the Projector module.
|
| 426 |
+
|
| 427 |
+
Args:
|
| 428 |
+
key_value_states (`torch.Tensor`):
|
| 429 |
+
Input tensor of shape (batch_size, num_patches, kv_dim).
|
| 430 |
+
attn_mask (`torch.Tensor`, *optional*, default is None):
|
| 431 |
+
Attention mask.
|
| 432 |
+
|
| 433 |
+
Returns:
|
| 434 |
+
`torch.Tensor`: Output tensor of shape (batch_size, query_number, output_dim).
|
| 435 |
+
"""
|
| 436 |
+
batch_size, num_patches = key_value_states.shape[0], key_value_states.shape[1]
|
| 437 |
+
|
| 438 |
+
if num_patches not in self.patch_to_query_dict:
|
| 439 |
+
raise KeyError(
|
| 440 |
+
f"Number of patches {num_patches} not found in patch_to_query_dict amongst possible values {self.patch_to_query_dict.keys()}."
|
| 441 |
+
)
|
| 442 |
+
query_num = self.patch_to_query_dict[num_patches]
|
| 443 |
+
|
| 444 |
+
queries = self.query[:query_num].unsqueeze(0).repeat(batch_size, 1, 1)
|
| 445 |
+
|
| 446 |
+
if attn_mask is not None:
|
| 447 |
+
attn_mask = attn_mask.repeat_interleave(self.num_heads, 0)
|
| 448 |
+
attn_mask = attn_mask.unsqueeze(1).expand(-1, queries.size(1), -1)
|
| 449 |
+
|
| 450 |
+
attention_out = self.cross_attn(key_value_states, queries, attn_mask=attn_mask)
|
| 451 |
+
|
| 452 |
+
out = self.feed_forward(self.layer_norm(attention_out))
|
| 453 |
+
|
| 454 |
+
return out
|
| 455 |
+
|
| 456 |
+
|
| 457 |
+
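# Editor's note: a shape sketch, not part of the original file. The projector resamples a
# variable number of vision patches into a fixed number of query tokens, looked up in
# `projector_patch_to_query_dict` (1225 patches -> 128 queries with the defaults). The
# sub-config sizes below are assumptions taken from the default `AriaConfig`.
def _example_aria_projector_shapes():
    config = AriaConfig()  # default vision/text sub-configs
    projector = AriaProjector(config)
    batch_size, num_patches = 2, 1225
    patches = torch.randn(batch_size, num_patches, config.vision_config.hidden_size)
    projected = projector(patches)
    # 1225 patches are resampled into 128 learned queries in the text hidden size.
    assert projected.shape == (batch_size, 128, config.text_config.hidden_size)

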
class AriaImageProcessor(BaseImageProcessor):
|
| 458 |
+
"""
|
| 459 |
+
A vision processor for the Aria model that handles image preprocessing.
|
| 460 |
+
Initialize the AriaImageProcessor.
|
| 461 |
+
|
| 462 |
+
Args:
|
| 463 |
+
image_mean (`list`, *optional*, defaults to [0.5, 0.5, 0.5]):
|
| 464 |
+
Mean values for normalization.
|
| 465 |
+
image_std (`list`, *optional*, defaults to [0.5, 0.5, 0.5]):
|
| 466 |
+
Standard deviation values for normalization.
|
| 467 |
+
max_image_size (`int`, *optional*, defaults to 980):
|
| 468 |
+
Maximum image size.
|
| 469 |
+
min_image_size (`int`, *optional*, defaults to 336):
|
| 470 |
+
Minimum image size.
|
| 471 |
+
split_resolutions (`list`, *optional*, defaults to a list of optimal resolutions as tuples):
|
| 472 |
+
The optimal resolutions for splitting the image.
|
| 473 |
+
split_image (`bool`, *optional*, defaults to `False`):
|
| 474 |
+
Whether to split the image.
|
| 475 |
+
do_convert_rgb (`bool`, *optional*, defaults to `True`):
|
| 476 |
+
Whether to convert the image to RGB.
|
| 477 |
+
do_rescale (`bool`, *optional*, defaults to `True`):
|
| 478 |
+
Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
|
| 479 |
+
the `preprocess` method.
|
| 480 |
+
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
|
| 481 |
+
Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
|
| 482 |
+
method.
|
| 483 |
+
do_normalize (`bool`, *optional*, defaults to `True`):
|
| 484 |
+
Whether to normalize the image.
|
| 485 |
+
resample (PILImageResampling, *optional*, defaults to `BICUBIC`):
|
| 486 |
+
The resampling filter to use if resizing the image.
|
| 487 |
+
"""
|
| 488 |
+
|
| 489 |
+
model_input_names = ["pixel_values", "pixel_mask", "num_crops"]
|
| 490 |
+
|
| 491 |
+
def __init__(
|
| 492 |
+
self,
|
| 493 |
+
image_mean: Optional[list[float]] = None,
|
| 494 |
+
image_std: Optional[list[float]] = None,
|
| 495 |
+
max_image_size: int = 980,
|
| 496 |
+
min_image_size: int = 336,
|
| 497 |
+
split_resolutions: Optional[list[tuple[int, int]]] = None,
|
| 498 |
+
split_image: Optional[bool] = False,
|
| 499 |
+
do_convert_rgb: Optional[bool] = True,
|
| 500 |
+
do_rescale: bool = True,
|
| 501 |
+
rescale_factor: Union[int, float] = 1 / 255,
|
| 502 |
+
do_normalize: Optional[bool] = True,
|
| 503 |
+
resample: PILImageResampling = PILImageResampling.BICUBIC,
|
| 504 |
+
**kwargs,
|
| 505 |
+
):
|
| 506 |
+
super().__init__(**kwargs)
|
| 507 |
+
|
| 508 |
+
if image_mean is None:
|
| 509 |
+
image_mean = [0.5, 0.5, 0.5]
|
| 510 |
+
if image_std is None:
|
| 511 |
+
image_std = [0.5, 0.5, 0.5]
|
| 512 |
+
self.max_image_size = max_image_size
|
| 513 |
+
self.min_image_size = min_image_size
|
| 514 |
+
self.image_mean = image_mean
|
| 515 |
+
self.image_std = image_std
|
| 516 |
+
self.split_image = split_image
|
| 517 |
+
if split_resolutions is None:
|
| 518 |
+
split_resolutions = [(1, 2), (1, 3), (1, 4), (1, 5), (1, 6), (1, 7), (1, 8), (2, 4), (2, 3), (2, 2), (2, 1), (3, 1), (3, 2), (4, 1), (4, 2), (5, 1), (6, 1), (7, 1), (8, 1)] # fmt: skip
|
| 519 |
+
split_resolutions = [(el[0] * 490, el[1] * 490) for el in split_resolutions]
|
| 520 |
+
self.split_resolutions = split_resolutions
|
| 521 |
+
self.do_convert_rgb = do_convert_rgb
|
| 522 |
+
self.do_rescale = do_rescale
|
| 523 |
+
self.rescale_factor = rescale_factor
|
| 524 |
+
self.do_normalize = do_normalize
|
| 525 |
+
self.resample = resample
|
| 526 |
+
|
| 527 |
+
def preprocess(
|
| 528 |
+
self,
|
| 529 |
+
images: Union[ImageInput, list[ImageInput]],
|
| 530 |
+
image_mean: Optional[Union[float, list[float]]] = None,
|
| 531 |
+
image_std: Optional[Union[float, list[float]]] = None,
|
| 532 |
+
max_image_size: Optional[int] = None,
|
| 533 |
+
min_image_size: Optional[int] = None,
|
| 534 |
+
split_image: Optional[bool] = None,
|
| 535 |
+
do_convert_rgb: Optional[bool] = None,
|
| 536 |
+
do_rescale: Optional[bool] = None,
|
| 537 |
+
rescale_factor: Optional[float] = None,
|
| 538 |
+
do_normalize: Optional[bool] = None,
|
| 539 |
+
resample: Optional[PILImageResampling] = None,
|
| 540 |
+
return_tensors: Optional[Union[str, TensorType]] = "pt",
|
| 541 |
+
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
|
| 542 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 543 |
+
):
|
| 544 |
+
"""
|
| 545 |
+
Process a list of images.
|
| 546 |
+
|
| 547 |
+
Args:
|
| 548 |
+
images (ImageInput or list of ImageInput):
|
| 549 |
+
The input image or a list of images.
|
| 550 |
+
image_mean (`list`, *optional*, defaults to [0.5, 0.5, 0.5]):
|
| 551 |
+
Mean values for normalization.
|
| 552 |
+
image_std (`list`, *optional*, defaults to [0.5, 0.5, 0.5]):
|
| 553 |
+
Standard deviation values for normalization.
|
| 554 |
+
max_image_size (`int`, *optional*, defaults to `self.max_image_size` (980)):
|
| 555 |
+
Maximum image size.
|
| 556 |
+
min_image_size (`int`, *optional*, defaults to `self.min_image_size` (336)):
|
| 557 |
+
Minimum image size.
|
| 558 |
+
split_image (`bool`, *optional*, defaults to `self.split_image` (False)):
|
| 559 |
+
Whether to split the image.
|
| 560 |
+
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb` (True)):
|
| 561 |
+
Whether to convert the image to RGB.
|
| 562 |
+
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
|
| 563 |
+
Whether to rescale the image.
|
| 564 |
+
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
|
| 565 |
+
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
|
| 566 |
+
do_normalize (`bool`, *optional*, defaults to `self.do_normalize` (True)):
|
| 567 |
+
Whether to normalize the image.
|
| 568 |
+
resample (PILImageResampling, *optional*, defaults to `self.resample` (BICUBIC)):
|
| 569 |
+
The resampling filter to use if resizing the image.
|
| 570 |
+
return_tensors (`str` or `TensorType`, *optional*, defaults to "pt"):
|
| 571 |
+
The type of tensor to return.
|
| 572 |
+
data_format (`str` or `ChannelDimension`, *optional*):
|
| 573 |
+
The channel dimension format for the output image. Can be one of:
|
| 574 |
+
- `"channels_first"` or `ChannelDimension.FIRST`:
|
| 575 |
+
image in (num_channels, height, width) format.
|
| 576 |
+
- `"channels_last"` or `ChannelDimension.LAST`:
|
| 577 |
+
image in (height, width, num_channels) format.
|
| 578 |
+
If unset, will use same as the input image.
|
| 579 |
+
input_data_format (`str` or `ChannelDimension`, *optional*):
|
| 580 |
+
The channel dimension format for the input image. Can be one of:
|
| 581 |
+
- `"channels_first"` or `ChannelDimension.FIRST`:
|
| 582 |
+
image in (num_channels, height, width) format.
|
| 583 |
+
- `"channels_last"` or `ChannelDimension.LAST`:
|
| 584 |
+
image in (height, width, num_channels) format.
|
| 585 |
+
If unset, will use the inferred format of the input image.
|
| 586 |
+
|
| 587 |
+
Returns:
|
| 588 |
+
BatchFeature:
|
| 589 |
+
A BatchFeature object containing:
|
| 590 |
+
- 'pixel_values':
|
| 591 |
+
Tensor of processed image pixel values.
|
| 592 |
+
- 'pixel_mask':
|
| 593 |
+
Boolean pixel mask. This mask is a 2D tensor of shape (max_image_size, max_image_size) where:
|
| 594 |
+
- True (1) values indicate pixels that belong to the original resized image.
|
| 595 |
+
- False (0) values indicate pixels that are part of the padding.
|
| 596 |
+
The mask helps distinguish between actual image content and padded areas in subsequent processing steps.
|
| 597 |
+
- 'num_crops':
|
| 598 |
+
The maximum number of crops across all images.
|
| 599 |
+
"""
|
| 600 |
+
image_mean = image_mean if image_mean is not None else self.image_mean
|
| 601 |
+
image_std = image_std if image_std is not None else self.image_std
|
| 602 |
+
max_image_size = max_image_size if max_image_size is not None else self.max_image_size
|
| 603 |
+
min_image_size = min_image_size if min_image_size is not None else self.min_image_size
|
| 604 |
+
split_image = split_image if split_image is not None else self.split_image
|
| 605 |
+
do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
|
| 606 |
+
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
|
| 607 |
+
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
|
| 608 |
+
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
|
| 609 |
+
resample = resample if resample is not None else self.resample
|
| 610 |
+
|
| 611 |
+
if max_image_size not in [490, 980]:
|
| 612 |
+
raise ValueError("max_image_size must be either 490 or 980")
|
| 613 |
+
|
| 614 |
+
images = self.fetch_images(images)
|
| 615 |
+
images = make_flat_list_of_images(images)
|
| 616 |
+
|
| 617 |
+
if not valid_images(images):
|
| 618 |
+
raise ValueError(
|
| 619 |
+
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
|
| 620 |
+
"torch.Tensor, tf.Tensor or jax.ndarray."
|
| 621 |
+
)
|
| 622 |
+
|
| 623 |
+
validate_preprocess_arguments(
|
| 624 |
+
do_normalize=do_normalize,
|
| 625 |
+
image_mean=image_mean,
|
| 626 |
+
image_std=image_std,
|
| 627 |
+
resample=resample,
|
| 628 |
+
do_rescale=do_rescale,
|
| 629 |
+
rescale_factor=rescale_factor,
|
| 630 |
+
)
|
| 631 |
+
|
| 632 |
+
if do_convert_rgb:
|
| 633 |
+
images = [convert_to_rgb(image) for image in images]
|
| 634 |
+
|
| 635 |
+
# All transformations expect numpy arrays.
|
| 636 |
+
images = [to_numpy_array(image) for image in images]
|
| 637 |
+
|
| 638 |
+
if do_rescale and is_scaled_image(images[0]):
|
| 639 |
+
logger.warning_once(
|
| 640 |
+
"It looks like you are trying to rescale already rescaled images. If the input"
|
| 641 |
+
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
|
| 642 |
+
)
|
| 643 |
+
|
| 644 |
+
if input_data_format is None:
|
| 645 |
+
# We assume that all images have the same channel dimension format.
|
| 646 |
+
input_data_format = infer_channel_dimension_format(images[0])
|
| 647 |
+
|
| 648 |
+
pixel_values = []
|
| 649 |
+
pixel_masks = []
|
| 650 |
+
num_crops = None
|
| 651 |
+
|
| 652 |
+
for image in images:
|
| 653 |
+
if split_image:
|
| 654 |
+
crop_images = self.get_image_patches(
|
| 655 |
+
image,
|
| 656 |
+
self.split_resolutions,
|
| 657 |
+
max_image_size,
|
| 658 |
+
resample,
|
| 659 |
+
data_format=input_data_format,
|
| 660 |
+
input_data_format=input_data_format,
|
| 661 |
+
)
|
| 662 |
+
else:
|
| 663 |
+
crop_images = [image]
|
| 664 |
+
if num_crops is None or len(crop_images) > num_crops:
|
| 665 |
+
num_crops = len(crop_images)
|
| 666 |
+
|
| 667 |
+
for crop_image in crop_images:
|
| 668 |
+
# At this point the scale is the rescaling factor that would bring the image to max_size in its larger dimension
|
| 669 |
+
h, w = get_image_size(crop_image)
|
| 670 |
+
scale = max_image_size / max(h, w)
|
| 671 |
+
if w >= h:
|
| 672 |
+
new_size = (max(int(h * scale), min_image_size), max_image_size) # h, w
|
| 673 |
+
else:
|
| 674 |
+
new_size = (max_image_size, max(int(w * scale), min_image_size)) # h, w
|
| 675 |
+
|
| 676 |
+
crop_image_resized = resize(
|
| 677 |
+
crop_image,
|
| 678 |
+
new_size,
|
| 679 |
+
resample=resample,
|
| 680 |
+
data_format=input_data_format,
|
| 681 |
+
input_data_format=input_data_format,
|
| 682 |
+
)
|
| 683 |
+
|
| 684 |
+
padding_bottom, padding_right = max_image_size - new_size[0], max_image_size - new_size[1]
|
| 685 |
+
crop_image_padded = pad(
|
| 686 |
+
crop_image_resized,
|
| 687 |
+
((0, padding_bottom), (0, padding_right)),
|
| 688 |
+
data_format=input_data_format,
|
| 689 |
+
input_data_format=input_data_format,
|
| 690 |
+
)
|
| 691 |
+
|
| 692 |
+
# Create a pixel mask
|
| 693 |
+
pixel_mask = np.zeros((max_image_size, max_image_size), dtype=bool)
|
| 694 |
+
pixel_mask[: new_size[0], : new_size[1]] = 1
|
| 695 |
+
pixel_masks.append(pixel_mask)
|
| 696 |
+
|
| 697 |
+
if do_rescale:
|
| 698 |
+
crop_image_padded = self.rescale(
|
| 699 |
+
image=crop_image_padded, scale=rescale_factor, input_data_format=input_data_format
|
| 700 |
+
)
|
| 701 |
+
|
| 702 |
+
if do_normalize:
|
| 703 |
+
crop_image_padded = self.normalize(
|
| 704 |
+
crop_image_padded,
|
| 705 |
+
self.image_mean,
|
| 706 |
+
self.image_std,
|
| 707 |
+
data_format=input_data_format,
|
| 708 |
+
input_data_format=input_data_format,
|
| 709 |
+
)
|
| 710 |
+
crop_image_padded = (
|
| 711 |
+
to_channel_dimension_format(crop_image_padded, data_format, input_data_format)
|
| 712 |
+
if data_format is not None
|
| 713 |
+
else crop_image_padded
|
| 714 |
+
)
|
| 715 |
+
|
| 716 |
+
pixel_values.append(crop_image_padded)
|
| 717 |
+
return BatchFeature(
|
| 718 |
+
data={
|
| 719 |
+
"pixel_values": np.stack(pixel_values, axis=0),
|
| 720 |
+
"pixel_mask": np.stack(pixel_masks, axis=0),
|
| 721 |
+
"num_crops": num_crops,
|
| 722 |
+
},
|
| 723 |
+
tensor_type=return_tensors,
|
| 724 |
+
)
|
| 725 |
+
|
| 726 |
+
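# Editor's worked example, not part of the original file, for the resize-and-pad step in
# `preprocess` above, assuming the defaults max_image_size=980 and min_image_size=336:
# a crop with (h, w) = (600, 800) gives scale = 980 / 800 = 1.225, so
# new_size = (max(int(600 * 1.225), 336), 980) = (735, 980); the crop is then padded with
# padding_bottom = 980 - 735 = 245 and padding_right = 0, and the pixel mask is True only
# on the top-left 735 x 980 region.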
def _resize_for_patching(
|
| 727 |
+
self, image: np.ndarray, target_resolution: tuple, resample, input_data_format: ChannelDimension
|
| 728 |
+
) -> np.ndarray:
|
| 729 |
+
"""
|
| 730 |
+
Resizes an image to a target resolution while maintaining aspect ratio.
|
| 731 |
+
|
| 732 |
+
Args:
|
| 733 |
+
image (np.ndarray):
|
| 734 |
+
The input image.
|
| 735 |
+
target_resolution (tuple):
|
| 736 |
+
The target resolution (height, width) of the image.
|
| 737 |
+
resample (`PILImageResampling`):
|
| 738 |
+
Resampling filter to use if resizing the image.
|
| 739 |
+
input_data_format (`ChannelDimension` or `str`):
|
| 740 |
+
The channel dimension format of the input image.
|
| 741 |
+
|
| 742 |
+
Returns:
|
| 743 |
+
np.ndarray: The resized and padded image.
|
| 744 |
+
"""
|
| 745 |
+
new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)
|
| 746 |
+
|
| 747 |
+
# Resize the image
|
| 748 |
+
resized_image = resize(image, (new_height, new_width), resample=resample, input_data_format=input_data_format)
|
| 749 |
+
|
| 750 |
+
return resized_image
|
| 751 |
+
|
| 752 |
+
def _get_padding_size(self, original_resolution: tuple, target_resolution: tuple):
|
| 753 |
+
original_height, original_width = original_resolution
|
| 754 |
+
target_height, target_width = target_resolution
|
| 755 |
+
paste_x, r_x = divmod(target_width - original_width, 2)
|
| 756 |
+
paste_y, r_y = divmod(target_height - original_height, 2)
|
| 757 |
+
return (paste_y, paste_y + r_y), (paste_x, paste_x + r_x)
|
| 758 |
+
|
| 759 |
+
def _pad_for_patching(
|
| 760 |
+
self, image: np.ndarray, target_resolution: tuple, input_data_format: ChannelDimension
|
| 761 |
+
) -> np.ndarray:
|
| 762 |
+
"""
|
| 763 |
+
Pad an image to a target resolution while maintaining aspect ratio.
|
| 764 |
+
"""
|
| 765 |
+
new_resolution = get_patch_output_size(image, target_resolution, input_data_format)
|
| 766 |
+
padding = self._get_padding_size(new_resolution, target_resolution)
|
| 767 |
+
|
| 768 |
+
padded_image = self.pad(image, padding=padding)
|
| 769 |
+
|
| 770 |
+
return padded_image
|
| 771 |
+
|
| 772 |
+
def pad(
|
| 773 |
+
self,
|
| 774 |
+
image: np.ndarray,
|
| 775 |
+
padding: Union[int, tuple[int, int], Iterable[tuple[int, int]]],
|
| 776 |
+
mode: PaddingMode = PaddingMode.CONSTANT,
|
| 777 |
+
constant_values: Union[float, Iterable[float]] = 0.0,
|
| 778 |
+
data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 779 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 780 |
+
) -> np.ndarray:
|
| 781 |
+
"""
|
| 782 |
+
Pads the `image` with the specified `padding` and `mode`. Padding can be in the (`height`, `width`)
|
| 783 |
+
dimension or in the (`num_patches`) dimension. In the second case an iterable of tuples is expected
|
| 784 |
+
as input.
|
| 785 |
+
|
| 786 |
+
Args:
|
| 787 |
+
image (`np.ndarray`):
|
| 788 |
+
The image to pad.
|
| 789 |
+
padding (`int` or `tuple[int, int]` or `Iterable[tuple[int, int]]`):
|
| 790 |
+
Padding to apply to the edges of the height, width axes. Can be one of three formats:
|
| 791 |
+
- `((before_height, after_height), (before_width, after_width))` unique pad widths for each axis.
|
| 792 |
+
- `((before, after),)` yields same before and after pad for height and width.
|
| 793 |
+
- `(pad,)` or int is a shortcut for before = after = pad width for all axes.
|
| 794 |
+
mode (`PaddingMode`):
|
| 795 |
+
The padding mode to use. Can be one of:
|
| 796 |
+
- `"constant"`: pads with a constant value.
|
| 797 |
+
- `"reflect"`: pads with the reflection of the vector mirrored on the first and last values of the
|
| 798 |
+
vector along each axis.
|
| 799 |
+
- `"replicate"`: pads with the replication of the last value on the edge of the array along each axis.
|
| 800 |
+
- `"symmetric"`: pads with the reflection of the vector mirrored along the edge of the array.
|
| 801 |
+
constant_values (`float` or `Iterable[float]`, *optional*):
|
| 802 |
+
The value to use for the padding if `mode` is `"constant"`.
|
| 803 |
+
data_format (`str` or `ChannelDimension`, *optional*):
|
| 804 |
+
The channel dimension format for the output image. Can be one of:
|
| 805 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
| 806 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
| 807 |
+
If unset, will use same as the input image.
|
| 808 |
+
input_data_format (`str` or `ChannelDimension`, *optional*):
|
| 809 |
+
The channel dimension format for the input image. Can be one of:
|
| 810 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
| 811 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
| 812 |
+
If unset, will use the inferred format of the input image.
|
| 813 |
+
|
| 814 |
+
Returns:
|
| 815 |
+
`np.ndarray`: The padded image.
|
| 816 |
+
|
| 817 |
+
"""
|
| 818 |
+
|
| 819 |
+
# call the general `pad` if padding on `height/width`, otherwise it's the `num_patched` dim
|
| 820 |
+
if isinstance(padding, int) or len(padding) != 4:
|
| 821 |
+
return pad(image, padding, mode, constant_values, data_format, input_data_format)
|
| 822 |
+
|
| 823 |
+
if input_data_format is None:
|
| 824 |
+
input_data_format = infer_channel_dimension_format(image)
|
| 825 |
+
|
| 826 |
+
padding_mode_mapping = {
|
| 827 |
+
PaddingMode.CONSTANT: "constant",
|
| 828 |
+
PaddingMode.REFLECT: "reflect",
|
| 829 |
+
PaddingMode.REPLICATE: "edge",
|
| 830 |
+
PaddingMode.SYMMETRIC: "symmetric",
|
| 831 |
+
}
|
| 832 |
+
image = np.pad(image, padding, mode=padding_mode_mapping[mode], constant_values=constant_values)
|
| 833 |
+
image = (
|
| 834 |
+
to_channel_dimension_format(image, data_format, input_data_format) if data_format is not None else image
|
| 835 |
+
)
|
| 836 |
+
return image
|
| 837 |
+
|
| 838 |
+
def get_image_patches(
|
| 839 |
+
self,
|
| 840 |
+
image: np.ndarray,
|
| 841 |
+
grid_pinpoints: list[tuple[int, int]],
|
| 842 |
+
patch_size: int,
|
| 843 |
+
resample: PILImageResampling,
|
| 844 |
+
data_format: ChannelDimension,
|
| 845 |
+
input_data_format: ChannelDimension,
|
| 846 |
+
) -> list[np.ndarray]:
|
| 847 |
+
"""
|
| 848 |
+
Process an image with variable resolutions by dividing it into patches.
|
| 849 |
+
|
| 850 |
+
Args:
|
| 851 |
+
image (`np.ndarray`):
|
| 852 |
+
The input image to be processed.
|
| 853 |
+
grid_pinpoints (list[tuple[int, int]]):
|
| 854 |
+
A list of possible resolutions as tuples.
|
| 855 |
+
patch_size (`int`):
|
| 856 |
+
Size of the patches to divide the image into.
|
| 857 |
+
resample (`PILImageResampling`):
|
| 858 |
+
Resampling filter to use if resizing the image.
|
| 859 |
+
data_format (`ChannelDimension` or `str`):
|
| 860 |
+
The channel dimension format for the output image.
|
| 861 |
+
input_data_format (`ChannelDimension` or `str`):
|
| 862 |
+
The channel dimension format of the input image.
|
| 863 |
+
|
| 864 |
+
Returns:
|
| 865 |
+
`list[np.ndarray]`: A list of NumPy arrays containing the processed image patches.
|
| 866 |
+
"""
|
| 867 |
+
if not isinstance(grid_pinpoints, list):
|
| 868 |
+
raise TypeError("grid_pinpoints must be a list of possible resolutions.")
|
| 869 |
+
|
| 870 |
+
possible_resolutions = grid_pinpoints
|
| 871 |
+
|
| 872 |
+
image_size = get_image_size(image, channel_dim=input_data_format)
|
| 873 |
+
best_resolution = select_best_resolution(image_size, possible_resolutions)
|
| 874 |
+
resized_image = self._resize_for_patching(
|
| 875 |
+
image, best_resolution, resample=resample, input_data_format=input_data_format
|
| 876 |
+
)
|
| 877 |
+
padded_image = self._pad_for_patching(resized_image, best_resolution, input_data_format=input_data_format)
|
| 878 |
+
|
| 879 |
+
patches = divide_to_patches(padded_image, patch_size=patch_size, input_data_format=input_data_format)
|
| 880 |
+
|
| 881 |
+
# make sure that all patches are in the input data format
|
| 882 |
+
patches = [
|
| 883 |
+
to_channel_dimension_format(patch, channel_dim=data_format, input_channel_dim=input_data_format)
|
| 884 |
+
for patch in patches
|
| 885 |
+
]
|
| 886 |
+
return patches
|
| 887 |
+
|
| 888 |
+
def get_number_of_image_patches(self, height: int, width: int, images_kwargs=None):
|
| 889 |
+
"""
|
| 890 |
+
A utility that returns number of image patches for a given image size.
|
| 891 |
+
|
| 892 |
+
Args:
|
| 893 |
+
height (`int`):
|
| 894 |
+
Height of the input image.
|
| 895 |
+
width (`int`):
|
| 896 |
+
Width of the input image.
|
| 897 |
+
images_kwargs (`dict`, *optional*)
|
| 898 |
+
Any kwargs to override defaults of the image processor.
|
| 899 |
+
Returns:
|
| 900 |
+
`int`: Number of patches per image.
|
| 901 |
+
"""
|
| 902 |
+
split_image = images_kwargs.get("split_image", self.split_image)
|
| 903 |
+
max_image_size = images_kwargs.get("max_image_size", self.max_image_size)
|
| 904 |
+
|
| 905 |
+
resized_height, resized_width = select_best_resolution((height, width), self.split_resolutions)
|
| 906 |
+
num_patches = 1 if not split_image else resized_height // max_image_size * resized_width // max_image_size
|
| 907 |
+
return num_patches
|
| 908 |
+
|
| 909 |
+
|
| 910 |
+
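# Editor's note: a minimal sketch, not part of the original file, of the crop-count logic
# in `get_number_of_image_patches` above. It assumes `select_best_resolution` picks the
# exact-fit (980, 1960) candidate from the default `split_resolutions`.
def _example_number_of_image_patches():
    processor = AriaImageProcessor(split_image=True, max_image_size=490)
    # (980 // 490) * (1960 // 490) = 2 * 4 = 8 crops for a 980 x 1960 image.
    num_patches = processor.get_number_of_image_patches(980, 1960, images_kwargs={})
    assert num_patches == 8
    return num_patches

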
class AriaProcessorKwargs(ProcessingKwargs, total=False):
|
| 911 |
+
_defaults = {
|
| 912 |
+
"text_kwargs": {
|
| 913 |
+
"padding": False,
|
| 914 |
+
"return_mm_token_type_ids": False,
|
| 915 |
+
},
|
| 916 |
+
"images_kwargs": {
|
| 917 |
+
"max_image_size": 980,
|
| 918 |
+
"split_image": False,
|
| 919 |
+
},
|
| 920 |
+
"return_tensors": TensorType.PYTORCH,
|
| 921 |
+
}
|
| 922 |
+
|
| 923 |
+
|
| 924 |
+
class AriaProcessor(ProcessorMixin):
|
| 925 |
+
"""
|
| 926 |
+
AriaProcessor is a processor for the Aria model which wraps the Aria image preprocessor and the Llama slow tokenizer.
|
| 927 |
+
|
| 928 |
+
Args:
|
| 929 |
+
image_processor (`AriaImageProcessor`, *optional*):
|
| 930 |
+
The AriaImageProcessor to use for image preprocessing.
|
| 931 |
+
tokenizer (`PreTrainedTokenizerBase`, *optional*):
|
| 932 |
+
An instance of [`PreTrainedTokenizerBase`]. This should correspond with the model's text model. The tokenizer is a required input.
|
| 933 |
+
chat_template (`str`, *optional*):
|
| 934 |
+
A Jinja template which will be used to convert lists of messages in a chat into a tokenizable string.
|
| 935 |
+
size_conversion (`Dict`, *optional*):
|
| 936 |
+
A dictionary indicating size conversions for images.
|
| 937 |
+
"""
|
| 938 |
+
|
| 939 |
+
attributes = ["image_processor", "tokenizer"]
|
| 940 |
+
image_processor_class = "AriaImageProcessor"
|
| 941 |
+
tokenizer_class = "AutoTokenizer"
|
| 942 |
+
|
| 943 |
+
def __init__(
|
| 944 |
+
self,
|
| 945 |
+
image_processor=None,
|
| 946 |
+
tokenizer: Union[AutoTokenizer, str] = None,
|
| 947 |
+
chat_template: Optional[str] = None,
|
| 948 |
+
size_conversion: Optional[dict[Union[float, int], int]] = None,
|
| 949 |
+
):
|
| 950 |
+
if size_conversion is None:
|
| 951 |
+
size_conversion = {490: 128, 980: 256}
|
| 952 |
+
self.size_conversion = {int(k): v for k, v in size_conversion.items()}
|
| 953 |
+
|
| 954 |
+
self.image_token = tokenizer.image_token
|
| 955 |
+
self.image_token_id = tokenizer.image_token_id
|
| 956 |
+
if tokenizer is not None and tokenizer.pad_token is None:
|
| 957 |
+
tokenizer.pad_token = tokenizer.unk_token
|
| 958 |
+
|
| 959 |
+
super().__init__(image_processor, tokenizer, chat_template=chat_template)
|
| 960 |
+
|
| 961 |
+
def __call__(
|
| 962 |
+
self,
|
| 963 |
+
text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]],
|
| 964 |
+
images: Optional[ImageInput] = None,
|
| 965 |
+
audio=None,
|
| 966 |
+
videos=None,
|
| 967 |
+
**kwargs: Unpack[AriaProcessorKwargs],
|
| 968 |
+
) -> BatchFeature:
|
| 969 |
+
"""
|
| 970 |
+
Main method to prepare one or several sequence(s) and image(s) for the model.
|
| 971 |
+
|
| 972 |
+
Args:
|
| 973 |
+
text (`TextInput`, `PreTokenizedInput`, `list[TextInput]`, `list[PreTokenizedInput]`):
|
| 974 |
+
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
|
| 975 |
+
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
|
| 976 |
+
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
|
| 977 |
+
images (`ImageInput`):
|
| 978 |
+
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
|
| 979 |
+
tensor. Both channels-first and channels-last formats are supported.
|
| 980 |
+
|
| 981 |
+
|
| 982 |
+
Returns:
|
| 983 |
+
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
|
| 984 |
+
- **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
|
| 985 |
+
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
|
| 986 |
+
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
|
| 987 |
+
`None`).
|
| 988 |
+
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
|
| 989 |
+
- **pixel_mask** -- Pixel mask to be fed to a model. Returned when `images` is not `None`.
|
| 990 |
+
"""
|
| 991 |
+
output_kwargs = self._merge_kwargs(
|
| 992 |
+
AriaProcessorKwargs,
|
| 993 |
+
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
|
| 994 |
+
**kwargs,
|
| 995 |
+
)
|
| 996 |
+
|
| 997 |
+
if isinstance(text, str):
|
| 998 |
+
text = [text]
|
| 999 |
+
elif not isinstance(text, list) and not isinstance(text[0], str):
|
| 1000 |
+
raise TypeError("Invalid input text. Please provide a string, or a list of strings")
|
| 1001 |
+
|
| 1002 |
+
if images is not None:
|
| 1003 |
+
image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
|
| 1004 |
+
# expand the image_token according to the num_crops and tokens per image
|
| 1005 |
+
tokens_per_image = self.size_conversion[image_inputs.pixel_values.shape[2]]
|
| 1006 |
+
prompt_strings = []
|
| 1007 |
+
num_crops = image_inputs.pop("num_crops") * tokens_per_image
|
| 1008 |
+
for sample in text:
|
| 1009 |
+
sample = sample.replace(self.tokenizer.image_token, self.tokenizer.image_token * num_crops)
|
| 1010 |
+
prompt_strings.append(sample)
|
| 1011 |
+
|
| 1012 |
+
else:
|
| 1013 |
+
image_inputs = {}
|
| 1014 |
+
prompt_strings = text
|
| 1015 |
+
|
| 1016 |
+
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
|
| 1017 |
+
return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False)
|
| 1018 |
+
text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"], return_tensors=None)
|
| 1019 |
+
self._check_special_mm_tokens(prompt_strings, text_inputs, modalities=["image"])
|
| 1020 |
+
|
| 1021 |
+
if return_mm_token_type_ids:
|
| 1022 |
+
array_ids = np.array(text_inputs["input_ids"])
|
| 1023 |
+
mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
|
| 1024 |
+
mm_token_type_ids[array_ids == self.image_token_id] = 1
|
| 1025 |
+
text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()
|
| 1026 |
+
|
| 1027 |
+
return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors)
|
| 1028 |
+
|
| 1029 |
+
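# Editor's worked example, not part of the original file, for the placeholder expansion in
# `__call__` above: with the default size_conversion {490: 128, 980: 256} and pixel_values
# of shape (total_crops, 3, 980, 980), tokens_per_image = 256, so a prompt containing one
# image placeholder is expanded to num_crops * 256 copies of the image token before
# tokenization.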
def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
|
| 1030 |
+
"""
|
| 1031 |
+
Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
|
| 1032 |
+
Args:
|
| 1033 |
+
image_sizes (`list[list[int]]`, *optional*):
|
| 1034 |
+
The input sizes formatted as (height, width) per each image.
|
| 1035 |
+
Returns:
|
| 1036 |
+
`MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
|
| 1037 |
+
input modalities, along with other useful data.
|
| 1038 |
+
"""
|
| 1039 |
+
|
| 1040 |
+
vision_data = {}
|
| 1041 |
+
if image_sizes is not None:
|
| 1042 |
+
images_kwargs = AriaProcessorKwargs._defaults.get("images_kwargs", {})
|
| 1043 |
+
images_kwargs.update(kwargs)
|
| 1044 |
+
|
| 1045 |
+
max_size = images_kwargs.get("max_image_size", None) or self.image_processor.max_image_size
|
| 1046 |
+
num_image_patches = [
|
| 1047 |
+
self.image_processor.get_number_of_image_patches(*image_size, images_kwargs)
|
| 1048 |
+
for image_size in image_sizes
|
| 1049 |
+
]
|
| 1050 |
+
num_image_tokens = [self.size_conversion[max_size] * num_patches for num_patches in num_image_patches]
|
| 1051 |
+
vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
|
| 1052 |
+
|
| 1053 |
+
return MultiModalData(**vision_data)
|
| 1054 |
+
|
| 1055 |
+
@property
|
| 1056 |
+
def model_input_names(self):
|
| 1057 |
+
tokenizer_input_names = self.tokenizer.model_input_names
|
| 1058 |
+
image_processor_input_names = self.image_processor.model_input_names
|
| 1059 |
+
|
| 1060 |
+
# Remove `num_crops`, it is popped and used only when processing. Make a copy of list when removing
|
| 1061 |
+
# otherwise `self.image_processor.model_input_names` is also modified
|
| 1062 |
+
image_processor_input_names = [name for name in image_processor_input_names if name != "num_crops"]
|
| 1063 |
+
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
|
| 1064 |
+
|
| 1065 |
+
|
| 1066 |
+
class AriaSharedExpertsMLP(LlamaMLP):
    """
    Shared Expert MLP for shared experts.

    Unlike routed experts, shared experts process all tokens without routing.
    This class reconfigures the intermediate size in comparison to the LlamaMLP.

    Args:
        config (`AriaTextConfig`): Configuration object for the Aria language model.
    """

    def __init__(self, config: AriaTextConfig):
        super().__init__(config)
        self.intermediate_size = config.intermediate_size * config.moe_num_shared_experts


class AriaGroupedExpertsGemm(nn.Module):
|
| 1083 |
+
"""
|
| 1084 |
+
Grouped GEMM (General Matrix Multiplication) module for efficient expert computation.
|
| 1085 |
+
This module utilizes the grouped_gemm library (https://github.com/fanshiqing/grouped_gemm)
|
| 1086 |
+
for optimized performance. If the grouped_gemm library is not installed, it gracefully
|
| 1087 |
+
falls back to a sequential GEMM implementation, which may be slower but ensures
|
| 1088 |
+
functionality.
|
| 1089 |
+
|
| 1090 |
+
Args:
|
| 1091 |
+
in_features (`int`):
|
| 1092 |
+
Number of input features.
|
| 1093 |
+
out_features (`int`):
|
| 1094 |
+
Number of output features.
|
| 1095 |
+
groups (`int`):
|
| 1096 |
+
Number of expert groups.
|
| 1097 |
+
"""
|
| 1098 |
+
|
| 1099 |
+
def __init__(self, in_features, out_features, groups):
|
| 1100 |
+
super().__init__()
|
| 1101 |
+
self.in_features = in_features
|
| 1102 |
+
self.out_features = out_features
|
| 1103 |
+
self.groups = groups
|
| 1104 |
+
self.weight = nn.Parameter(torch.empty(groups, in_features, out_features))
|
| 1105 |
+
|
| 1106 |
+
def forward(self, input, tokens_per_expert):
|
| 1107 |
+
"""
|
| 1108 |
+
Perform grouped matrix multiplication.
|
| 1109 |
+
|
| 1110 |
+
Args:
|
| 1111 |
+
input (`torch.Tensor`):
|
| 1112 |
+
Input tensor of shape (num_tokens, in_features).
|
| 1113 |
+
tokens_per_expert (`torch.Tensor`):
|
| 1114 |
+
Number of tokens assigned to each expert.
|
| 1115 |
+
|
| 1116 |
+
Returns:
|
| 1117 |
+
torch.Tensor: Output tensor of shape (num_tokens, out_features).
|
| 1118 |
+
"""
|
| 1119 |
+
return sequential_experts_gemm(
|
| 1120 |
+
input,
|
| 1121 |
+
self.weight,
|
| 1122 |
+
tokens_per_expert.cpu(),
|
| 1123 |
+
)
|
| 1124 |
+
|
| 1125 |
+
|
| 1126 |
+
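# Editor's note: a minimal sketch, not part of the original file, of what the sequential
# fallback referenced above computes. The real `sequential_experts_gemm` helper lives
# elsewhere in the library; this version only illustrates the per-expert grouped matmul.
def _sequential_experts_gemm_sketch(token_states, expert_weights, tokens_per_expert):
    # token_states: (num_tokens, in_features), already sorted so that each expert's tokens
    # form a contiguous slice; expert_weights: (num_experts, in_features, out_features);
    # tokens_per_expert: (num_experts,) counts on CPU.
    output = token_states.new_empty(token_states.size(0), expert_weights.size(-1))
    start = 0
    for expert_idx, count in enumerate(tokens_per_expert.tolist()):
        end = start + int(count)
        output[start:end] = token_states[start:end] @ expert_weights[expert_idx]
        start = end
    return output

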
class AriaGroupedExpertsMLP(nn.Module):
|
| 1127 |
+
"""
|
| 1128 |
+
Grouped MLP module for Mixture of Experts.
|
| 1129 |
+
|
| 1130 |
+
Args:
|
| 1131 |
+
config (`AriaTextConfig`):
|
| 1132 |
+
Configuration object for the model.
|
| 1133 |
+
"""
|
| 1134 |
+
|
| 1135 |
+
def __init__(self, config: AriaTextConfig) -> None:
|
| 1136 |
+
super().__init__()
|
| 1137 |
+
self.config = config
|
| 1138 |
+
self.fc1 = AriaGroupedExpertsGemm(config.hidden_size, config.intermediate_size * 2, config.moe_num_experts)
|
| 1139 |
+
self.fc2 = AriaGroupedExpertsGemm(config.intermediate_size, config.hidden_size, config.moe_num_experts)
|
| 1140 |
+
|
| 1141 |
+
def forward(self, permuted_tokens, tokens_per_expert):
|
| 1142 |
+
"""
|
| 1143 |
+
Forward pass of the Grouped MLP.
|
| 1144 |
+
|
| 1145 |
+
Args:
|
| 1146 |
+
permuted_tokens (torch.Tensor): Permuted input tokens.
|
| 1147 |
+
tokens_per_expert (torch.Tensor): Number of tokens assigned to each expert.
|
| 1148 |
+
|
| 1149 |
+
Returns:
|
| 1150 |
+
torch.Tensor: Output tensor after passing through the MLP.
|
| 1151 |
+
"""
|
| 1152 |
+
fc1_output = self.fc1(permuted_tokens, tokens_per_expert)
|
| 1153 |
+
projection, gate = torch.chunk(fc1_output, 2, dim=-1)
|
| 1154 |
+
fc1_output = nn.functional.silu(projection) * gate
|
| 1155 |
+
fc2_output = self.fc2(fc1_output, tokens_per_expert)
|
| 1156 |
+
return fc2_output
|
| 1157 |
+
|
| 1158 |
+
|
| 1159 |
+
# Token permutation adapted from https://github.com/NVIDIA/Megatron-LM/blob/54f1f78529cbc2b9cddad313e7f9d96ac0420a27/megatron/core/transformer/moe/token_dispatcher.py#L291-L587
|
| 1160 |
+
class AriaTextMoELayer(nn.Module):
|
| 1161 |
+
"""
|
| 1162 |
+
Aria Text Mixture of Experts (MoE) Layer.
|
| 1163 |
+
|
| 1164 |
+
This layer applies a gating mechanism to route input tokens to different experts.
|
| 1165 |
+
|
| 1166 |
+
Args:
|
| 1167 |
+
config (`AriaTextConfig`):
|
| 1168 |
+
Configuration object for the text component of the model.
|
| 1169 |
+
"""
|
| 1170 |
+
|
| 1171 |
+
def __init__(self, config: AriaTextConfig):
|
| 1172 |
+
super().__init__()
|
| 1173 |
+
|
| 1174 |
+
self.router = nn.Linear(config.hidden_size, config.moe_num_experts, bias=False)
|
| 1175 |
+
self.experts = AriaGroupedExpertsMLP(config)
|
| 1176 |
+
self.shared_experts = AriaSharedExpertsMLP(config)
|
| 1177 |
+
self.config = config
|
| 1178 |
+
|
| 1179 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 1180 |
+
"""
|
| 1181 |
+
Forward pass of the MoE Layer.
|
| 1182 |
+
|
| 1183 |
+
Args:
|
| 1184 |
+
hidden_states (`torch.Tensor`):
|
| 1185 |
+
Input tensor of shape (batch_size, sequence_length, hidden_size).
|
| 1186 |
+
|
| 1187 |
+
Returns:
|
| 1188 |
+
torch.Tensor: Output tensor after passing through the MoE layer.
|
| 1189 |
+
|
| 1190 |
+
Process:
|
| 1191 |
+
1. Route tokens to experts using the router.
|
| 1192 |
+
2. Permute tokens based on routing decisions.
|
| 1193 |
+
3. Process tokens through experts.
|
| 1194 |
+
4. Unpermute and combine expert outputs.
|
| 1195 |
+
5. Add shared expert output to the final result.
|
| 1196 |
+
"""
|
| 1197 |
+
original_shape = hidden_states.shape
|
| 1198 |
+
hidden_states = hidden_states.view(-1, hidden_states.size(-1))
|
| 1199 |
+
|
| 1200 |
+
# Top K Routing
|
| 1201 |
+
logits = self.router(hidden_states)
|
| 1202 |
+
top_logits, top_indices = torch.topk(logits, k=self.config.moe_topk, dim=1)
|
| 1203 |
+
scores = nn.functional.softmax(top_logits, dim=-1)
|
| 1204 |
+
|
| 1205 |
+
original_dtype = top_indices.dtype
|
| 1206 |
+
|
| 1207 |
+
tokens_per_expert = torch.histc(
|
| 1208 |
+
top_indices.flatten().to(torch.float32),
|
| 1209 |
+
bins=self.config.moe_num_experts,
|
| 1210 |
+
min=0,
|
| 1211 |
+
max=self.config.moe_num_experts - 1,
|
| 1212 |
+
).to(original_dtype)
|
| 1213 |
+
indices = top_indices
|
| 1214 |
+
|
| 1215 |
+
# Token permutation
|
| 1216 |
+
flatten_indices = indices.view(-1)
|
| 1217 |
+
sorted_indices = torch.argsort(flatten_indices)
|
| 1218 |
+
permuted_tokens = hidden_states.index_select(0, sorted_indices // self.config.moe_topk)
|
| 1219 |
+
|
| 1220 |
+
# Process through experts
|
| 1221 |
+
expert_output = self.experts(permuted_tokens, tokens_per_expert)
|
| 1222 |
+
|
| 1223 |
+
# Token unpermutation
|
| 1224 |
+
unpermuted_tokens = torch.zeros(
|
| 1225 |
+
(scores.shape[0] * self.config.moe_topk, expert_output.size(1)),
|
| 1226 |
+
dtype=expert_output.dtype,
|
| 1227 |
+
device=expert_output.device,
|
| 1228 |
+
)
|
| 1229 |
+
unpermuted_tokens.index_copy_(0, sorted_indices, expert_output)
|
| 1230 |
+
unpermuted_tokens = unpermuted_tokens.view(-1, self.config.moe_topk, expert_output.size(1))
|
| 1231 |
+
|
| 1232 |
+
output = (unpermuted_tokens * scores.unsqueeze(-1)).sum(dim=1).view(original_shape)
|
| 1233 |
+
|
| 1234 |
+
# Add shared expert output
|
| 1235 |
+
shared_expert_output = self.shared_experts(hidden_states.view(original_shape))
|
| 1236 |
+
return output + shared_expert_output
|
| 1237 |
+
|
| 1238 |
+
|
| 1239 |
+
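# Editor's note: a minimal routing sketch, not part of the original file. It replays the
# top-k routing and permutation bookkeeping from `AriaTextMoELayer.forward` on toy sizes.
def _example_moe_routing():
    num_tokens, num_experts, topk = 4, 8, 2
    logits = torch.randn(num_tokens, num_experts)
    top_logits, top_indices = torch.topk(logits, k=topk, dim=1)  # (4, 2)
    scores = nn.functional.softmax(top_logits, dim=-1)
    # Sorting the flattened expert indices groups token copies by expert;
    # `sorted_indices // topk` recovers which original token each copy came from.
    sorted_indices = torch.argsort(top_indices.view(-1))
    token_for_copy = sorted_indices // topk
    assert token_for_copy.shape == (num_tokens * topk,)
    assert scores.shape == (num_tokens, topk)

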
class AriaTextAttention(LlamaAttention):
|
| 1240 |
+
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
| 1241 |
+
|
| 1242 |
+
pass
|
| 1243 |
+
|
| 1244 |
+
|
| 1245 |
+
class AriaTextDecoderLayer(LlamaDecoderLayer):
|
| 1246 |
+
"""
|
| 1247 |
+
Aria Text Decoder Layer.
|
| 1248 |
+
|
| 1249 |
+
This class defines a single decoder layer in the language model, incorporating self-attention and Mixture of Experts (MoE) feed-forward network.
|
| 1250 |
+
|
| 1251 |
+
Args:
|
| 1252 |
+
config (`AriaTextConfig`):
|
| 1253 |
+
Configuration object for the text component of the model.
|
| 1254 |
+
layer_idx (`int`):
|
| 1255 |
+
Index of the layer.
|
| 1256 |
+
"""
|
| 1257 |
+
|
| 1258 |
+
def __init__(self, config: AriaTextConfig, layer_idx: int):
|
| 1259 |
+
super().__init__(config, layer_idx)
|
| 1260 |
+
self.mlp = AriaTextMoELayer(config)
|
| 1261 |
+
|
| 1262 |
+
|
| 1263 |
+
@auto_docstring
|
| 1264 |
+
class AriaTextPreTrainedModel(PreTrainedModel):
|
| 1265 |
+
config: AriaTextConfig
|
| 1266 |
+
base_model_prefix = "model"
|
| 1267 |
+
_no_split_modules = ["AriaTextDecoderLayer", "AriaGroupedExpertsGemm"]
|
| 1268 |
+
supports_gradient_checkpointing = True
|
| 1269 |
+
_skip_keys_device_placement = "past_key_values"
|
| 1270 |
+
_supports_flash_attn = True
|
| 1271 |
+
_supports_sdpa = True
|
| 1272 |
+
|
| 1273 |
+
_supports_attention_backend = True
|
| 1274 |
+
_can_record_outputs = {
|
| 1275 |
+
"hidden_states": AriaTextDecoderLayer,
|
| 1276 |
+
"attentions": AriaTextAttention,
|
| 1277 |
+
}
|
| 1278 |
+
|
| 1279 |
+
def _init_weights(self, module):
|
| 1280 |
+
super()._init_weights(module)
|
| 1281 |
+
if isinstance(module, AriaGroupedExpertsGemm):
|
| 1282 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
| 1283 |
+
|
| 1284 |
+
|
| 1285 |
+
class AriaPreTrainedModel(LlamaPreTrainedModel):
|
| 1286 |
+
config: AriaConfig
|
| 1287 |
+
base_model_prefix = ""
|
| 1288 |
+
_can_compile_fullgraph = False # MoE models don't work with torch.compile (dynamic slicing)
|
| 1289 |
+
_supports_attention_backend = True
|
| 1290 |
+
|
| 1291 |
+
def _init_weights(self, module):
|
| 1292 |
+
PreTrainedModel._init_weights(self, module)
|
| 1293 |
+
if isinstance(module, AriaProjector):
|
| 1294 |
+
nn.init.trunc_normal_(module.query, std=self.config.initializer_range)
|
| 1295 |
+
|
| 1296 |
+
|
| 1297 |
+
class AriaTextModel(LlamaModel):
|
| 1298 |
+
def __init__(self, config: AriaTextConfig):
|
| 1299 |
+
super().__init__(config)
|
| 1300 |
+
self.layers = nn.ModuleList(
|
| 1301 |
+
[AriaTextDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
| 1302 |
+
)
|
| 1303 |
+
self.gradient_checkpointing = False
|
| 1304 |
+
self.post_init()
|
| 1305 |
+
|
| 1306 |
+
|
| 1307 |
+
class AriaTextForCausalLM(AriaTextPreTrainedModel, LlamaForCausalLM):
|
| 1308 |
+
_tied_weights_keys = ["lm_head.weight"]
|
| 1309 |
+
|
| 1310 |
+
def __init__(self, config: AriaTextConfig):
|
| 1311 |
+
super().__init__(config)
|
| 1312 |
+
self.model = AriaTextModel(config)
|
| 1313 |
+
self.vocab_size = config.vocab_size
|
| 1314 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 1315 |
+
|
| 1316 |
+
# Initialize weights and apply final processing
|
| 1317 |
+
self.post_init()
|
| 1318 |
+
|
| 1319 |
+
@auto_docstring
|
| 1320 |
+
def forward(self, **super_kwargs):
|
| 1321 |
+
return super().forward(**super_kwargs)
|
| 1322 |
+
|
| 1323 |
+
|
| 1324 |
+
class AriaCausalLMOutputWithPast(LlavaCausalLMOutputWithPast):
|
| 1325 |
+
pass
|
| 1326 |
+
|
| 1327 |
+
|
| 1328 |
+
class AriaModelOutputWithPast(LlavaModelOutputWithPast):
|
| 1329 |
+
pass
|
| 1330 |
+
|
| 1331 |
+
|
| 1332 |
+
class AriaModel(LlavaModel):
|
| 1333 |
+
def __init__(self, config: AriaConfig):
|
| 1334 |
+
super().__init__(config)
|
| 1335 |
+
self.multi_modal_projector = AriaProjector(config)
|
| 1336 |
+
|
| 1337 |
+
def _create_patch_attention_mask(self, pixel_mask):
|
| 1338 |
+
if pixel_mask is None:
|
| 1339 |
+
return None
|
| 1340 |
+
|
| 1341 |
+
patches_subgrid = pixel_mask.unfold(
|
| 1342 |
+
dimension=1,
|
| 1343 |
+
size=self.vision_tower.config.patch_size,
|
| 1344 |
+
step=self.vision_tower.config.patch_size,
|
| 1345 |
+
)
|
| 1346 |
+
patches_subgrid = patches_subgrid.unfold(
|
| 1347 |
+
dimension=2,
|
| 1348 |
+
size=self.vision_tower.config.patch_size,
|
| 1349 |
+
step=self.vision_tower.config.patch_size,
|
| 1350 |
+
)
|
| 1351 |
+
return (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
|
| 1352 |
+
|
| 1353 |
+
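# Editor's worked example, not part of the original file, for `_create_patch_attention_mask`
# above, assuming the default vision tower uses patch_size=14 on 980 x 980 inputs: the two
# unfold calls turn a (num_images, 980, 980) pixel_mask into a (num_images, 70, 70, 14, 14)
# grid, and the sum over the last two dims marks each of the 70 * 70 = 4900 patches as
# attended (True) iff it overlaps at least one unmasked pixel.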
def get_image_features(
|
| 1354 |
+
self,
|
| 1355 |
+
pixel_values: torch.FloatTensor,
|
| 1356 |
+
pixel_mask: Optional[torch.FloatTensor] = None,
|
| 1357 |
+
vision_feature_layer: int = -1,
|
| 1358 |
+
):
|
| 1359 |
+
"""
|
| 1360 |
+
Obtains image last hidden states from the vision tower and applies multimodal projection.
|
| 1361 |
+
|
| 1362 |
+
Args:
|
| 1363 |
+
pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`):
|
| 1364 |
+
The tensors corresponding to the input images.
|
| 1365 |
+
pixel_mask (`torch.FloatTensor`, *optional*):
|
| 1366 |
+
The tensors corresponding to the input image mask.
|
| 1367 |
+
vision_feature_layer (`Union[int, list[int]]`, *optional*):
|
| 1368 |
+
The index of the layer to select the vision feature. If multiple indices are provided,
|
| 1369 |
+
the vision feature of the corresponding indices will be concatenated to form the
|
| 1370 |
+
vision features.
|
| 1371 |
+
Returns:
|
| 1372 |
+
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`.
|
| 1373 |
+
"""
|
| 1374 |
+
vision_feature_layer = (
|
| 1375 |
+
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
|
| 1376 |
+
)
|
| 1377 |
+
patch_attention_mask = self._create_patch_attention_mask(pixel_mask)
|
| 1378 |
+
image_outputs = self.vision_tower(
|
| 1379 |
+
pixel_values, patch_attention_mask=patch_attention_mask, output_hidden_states=True
|
| 1380 |
+
)
|
| 1381 |
+
image_attn_mask = None
|
| 1382 |
+
if patch_attention_mask is not None:
|
| 1383 |
+
flattened_mask = patch_attention_mask.flatten(1)
|
| 1384 |
+
image_attn_mask = torch.logical_not(flattened_mask)
|
| 1385 |
+
|
| 1386 |
+
selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
|
| 1387 |
+
image_features = self.multi_modal_projector(selected_image_feature, attn_mask=image_attn_mask)
|
| 1388 |
+
return image_features
|
| 1389 |
+
|
| 1390 |
+
def forward(
|
| 1391 |
+
self,
|
| 1392 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 1393 |
+
pixel_values: Optional[torch.FloatTensor] = None,
|
| 1394 |
+
pixel_mask: Optional[torch.LongTensor] = None,
|
| 1395 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1396 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 1397 |
+
past_key_values: Optional[Cache] = None,
|
| 1398 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 1399 |
+
use_cache: Optional[bool] = None,
|
| 1400 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 1401 |
+
**kwargs: Unpack[FlashAttentionKwargs],
|
| 1402 |
+
) -> Union[tuple, AriaModelOutputWithPast]:
|
| 1403 |
+
if inputs_embeds is None:
|
| 1404 |
+
inputs_embeds = self.get_input_embeddings()(input_ids)
|
| 1405 |
+
|
| 1406 |
+
# 2. Merge text and images
|
| 1407 |
+
if pixel_values is not None and inputs_embeds.shape[1] != 1:
|
| 1408 |
+
image_features = self.get_image_features(
|
| 1409 |
+
pixel_values=pixel_values,
|
| 1410 |
+
pixel_mask=pixel_mask,
|
| 1411 |
+
vision_feature_layer=self.config.vision_feature_layer,
|
| 1412 |
+
)
|
| 1413 |
+
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
| 1414 |
+
special_image_mask = self.get_placeholder_mask(
|
| 1415 |
+
input_ids, inputs_embeds=inputs_embeds, image_features=image_features
|
| 1416 |
+
)
|
| 1417 |
+
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
|
| 1418 |
+
|
| 1419 |
+
outputs = self.language_model(
|
| 1420 |
+
attention_mask=attention_mask,
|
| 1421 |
+
position_ids=position_ids,
|
| 1422 |
+
past_key_values=past_key_values,
|
| 1423 |
+
inputs_embeds=inputs_embeds,
|
| 1424 |
+
use_cache=use_cache,
|
| 1425 |
+
cache_position=cache_position,
|
| 1426 |
+
**kwargs,
|
| 1427 |
+
)
|
| 1428 |
+
|
| 1429 |
+
return AriaModelOutputWithPast(
|
| 1430 |
+
last_hidden_state=outputs.last_hidden_state,
|
| 1431 |
+
past_key_values=outputs.past_key_values if use_cache else None,
|
| 1432 |
+
hidden_states=outputs.hidden_states,
|
| 1433 |
+
attentions=outputs.attentions,
|
| 1434 |
+
image_hidden_states=image_features if pixel_values is not None else None,
|
| 1435 |
+
)
|
| 1436 |
+
|
| 1437 |
+
|
| 1438 |
+
@auto_docstring(
|
| 1439 |
+
custom_intro="""
|
| 1440 |
+
Aria model for conditional generation tasks.
|
| 1441 |
+
|
| 1442 |
+
This model combines a vision tower, a multi-modal projector, and a language model
|
| 1443 |
+
to perform tasks that involve both image and text inputs.
|
| 1444 |
+
"""
|
| 1445 |
+
)
|
| 1446 |
+
class AriaForConditionalGeneration(LlavaForConditionalGeneration):
|
| 1447 |
+
def get_image_features(
|
| 1448 |
+
self,
|
| 1449 |
+
pixel_values: torch.FloatTensor,
|
| 1450 |
+
pixel_mask: Optional[torch.FloatTensor] = None,
|
| 1451 |
+
vision_feature_layer: int = -1,
|
| 1452 |
+
):
|
| 1453 |
+
return self.model.get_image_features(
|
| 1454 |
+
pixel_values=pixel_values,
|
| 1455 |
+
pixel_mask=pixel_mask,
|
| 1456 |
+
vision_feature_layer=vision_feature_layer,
|
| 1457 |
+
)
|
| 1458 |
+
|
| 1459 |
+
@can_return_tuple
|
| 1460 |
+
@auto_docstring
|
| 1461 |
+
def forward(
|
| 1462 |
+
self,
|
| 1463 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 1464 |
+
pixel_values: Optional[torch.FloatTensor] = None,
|
| 1465 |
+
pixel_mask: Optional[torch.LongTensor] = None,
|
| 1466 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1467 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 1468 |
+
past_key_values: Optional[Cache] = None,
|
| 1469 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 1470 |
+
labels: Optional[torch.LongTensor] = None,
|
| 1471 |
+
use_cache: Optional[bool] = None,
|
| 1472 |
+
logits_to_keep: Union[int, torch.Tensor] = 0,
|
| 1473 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 1474 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 1475 |
+
) -> Union[tuple, AriaCausalLMOutputWithPast]:
|
| 1476 |
+
r"""
|
| 1477 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 1478 |
+
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
| 1479 |
+
config.vocab_size]` or `model.image_token_id` (where `model` is your instance of `AriaForConditionalGeneration`).
|
| 1480 |
+
Tokens with indices set to `model.image_token_id` are ignored (masked), the loss is only
|
| 1481 |
+
computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
| 1482 |
+
|
| 1483 |
+
Example:
|
| 1484 |
+
|
| 1485 |
+
```python
|
| 1486 |
+
>>> import requests
|
| 1487 |
+
>>> import torch
|
| 1488 |
+
>>> from PIL import Image
|
| 1489 |
+
>>> from io import BytesIO
|
| 1490 |
+
|
| 1491 |
+
>>> from transformers import AutoProcessor, AutoModel
|
| 1492 |
+
>>> from transformers.image_utils import load_image
|
| 1493 |
+
|
| 1494 |
+
>>> # Note that passing the image urls (instead of the actual pil images) to the processor is also possible
|
| 1495 |
+
>>> image1 = load_image("https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg")
|
| 1496 |
+
>>> image2 = load_image("https://cdn.britannica.com/59/94459-050-DBA42467/Skyline-Chicago.jpg")
|
| 1497 |
+
>>> image3 = load_image("https://cdn.britannica.com/68/170868-050-8DDE8263/Golden-Gate-Bridge-San-Francisco.jpg")
|
| 1498 |
+
|
| 1499 |
+
>>> processor = AutoProcessor.from_pretrained("Rhymes-AI/Aria")
|
| 1500 |
+
>>> model = AutoModel.from_pretrained("Rhymes-AI/Aria", dtype=torch.bfloat16, device_map="auto")
|
| 1501 |
+
|
| 1502 |
+
>>> # Create inputs
|
| 1503 |
+
>>> messages = [
|
| 1504 |
+
... {
|
| 1505 |
+
... "role": "user",
|
| 1506 |
+
... "content": [
|
| 1507 |
+
... {"type": "image"},
|
| 1508 |
+
... {"type": "text", "text": "In this image, we can see the city of New York, and more specifically the Statue of Liberty."},
|
| 1509 |
+
... {"type": "image"},
|
| 1510 |
+
... {"type": "text", "text": "What can we see in this image?"},
|
| 1511 |
+
... ]
|
| 1512 |
+
... },
|
| 1513 |
+
... {
|
| 1514 |
+
... "role": "user",
|
| 1515 |
+
... "content": [
|
| 1516 |
+
... {"type": "image"},
|
| 1517 |
+
... {"type": "text", "text": "In which city is that bridge located?"},
|
| 1518 |
+
... ]
|
| 1519 |
+
... }
|
| 1520 |
+
... ]
|
| 1521 |
+
|
| 1522 |
+
>>> prompts = [processor.apply_chat_template([message], add_generation_prompt=True) for message in messages]
|
| 1523 |
+
>>> images = [[image1, image2], [image3]]
|
| 1524 |
+
>>> inputs = processor(text=prompts, images=images, padding=True, return_tensors="pt").to(model.device)
|
| 1525 |
+
|
| 1526 |
+
>>> # Generate
|
| 1527 |
+
>>> generated_ids = model.generate(**inputs, max_new_tokens=256)
|
| 1528 |
+
>>> generated_texts = processor.batch_decode(generated_ids, skip_special_tokens=True)
|
| 1529 |
+
|
| 1530 |
+
>>> print(generated_texts[0])
|
| 1531 |
+
Assistant: There are buildings, trees, lights, and water visible in this image.
|
| 1532 |
+
|
| 1533 |
+
>>> print(generated_texts[1])
|
| 1534 |
+
Assistant: The bridge is in San Francisco.
|
| 1535 |
+
```"""
|
| 1536 |
+
outputs = self.model(
|
| 1537 |
+
input_ids=input_ids,
|
| 1538 |
+
pixel_values=pixel_values,
|
| 1539 |
+
pixel_mask=pixel_mask,
|
| 1540 |
+
attention_mask=attention_mask,
|
| 1541 |
+
position_ids=position_ids,
|
| 1542 |
+
past_key_values=past_key_values,
|
| 1543 |
+
inputs_embeds=inputs_embeds,
|
| 1544 |
+
use_cache=use_cache,
|
| 1545 |
+
cache_position=cache_position,
|
| 1546 |
+
**kwargs,
|
| 1547 |
+
)
|
| 1548 |
+
|
| 1549 |
+
hidden_states = outputs[0]
|
| 1550 |
+
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
| 1551 |
+
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
| 1552 |
+
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
| 1553 |
+
|
| 1554 |
+
loss = None
|
| 1555 |
+
if labels is not None:
|
| 1556 |
+
loss = self.loss_function(
|
| 1557 |
+
logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
|
| 1558 |
+
)
|
| 1559 |
+
|
| 1560 |
+
return AriaCausalLMOutputWithPast(
|
| 1561 |
+
loss=loss,
|
| 1562 |
+
logits=logits,
|
| 1563 |
+
past_key_values=outputs.past_key_values,
|
| 1564 |
+
hidden_states=outputs.hidden_states,
|
| 1565 |
+
attentions=outputs.attentions,
|
| 1566 |
+
)
|
| 1567 |
+
|
| 1568 |
+
def prepare_inputs_for_generation(
|
| 1569 |
+
self,
|
| 1570 |
+
input_ids,
|
| 1571 |
+
past_key_values=None,
|
| 1572 |
+
inputs_embeds=None,
|
| 1573 |
+
pixel_values=None,
|
| 1574 |
+
pixel_mask=None,
|
| 1575 |
+
attention_mask=None,
|
| 1576 |
+
cache_position=None,
|
| 1577 |
+
logits_to_keep=None,
|
| 1578 |
+
**kwargs,
|
| 1579 |
+
):
|
| 1580 |
+
model_inputs = super().prepare_inputs_for_generation(
|
| 1581 |
+
input_ids,
|
| 1582 |
+
past_key_values=past_key_values,
|
| 1583 |
+
inputs_embeds=inputs_embeds,
|
| 1584 |
+
attention_mask=attention_mask,
|
| 1585 |
+
cache_position=cache_position,
|
| 1586 |
+
logits_to_keep=logits_to_keep,
|
| 1587 |
+
**kwargs,
|
| 1588 |
+
)
|
| 1589 |
+
|
| 1590 |
+
if cache_position[0] == 0:
|
| 1591 |
+
# If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
|
| 1592 |
+
# Otherwise we need pixel values to be passed to model
|
| 1593 |
+
model_inputs["pixel_values"] = pixel_values
|
| 1594 |
+
model_inputs["pixel_mask"] = pixel_mask
|
| 1595 |
+
|
| 1596 |
+
return model_inputs
|
| 1597 |
+
|
| 1598 |
+
|
| 1599 |
+
__all__ = [
|
| 1600 |
+
"AriaConfig",
|
| 1601 |
+
"AriaTextConfig",
|
| 1602 |
+
"AriaImageProcessor",
|
| 1603 |
+
"AriaProcessor",
|
| 1604 |
+
"AriaForConditionalGeneration",
|
| 1605 |
+
"AriaPreTrainedModel",
|
| 1606 |
+
"AriaTextPreTrainedModel",
|
| 1607 |
+
"AriaTextModel",
|
| 1608 |
+
"AriaModel",
|
| 1609 |
+
"AriaTextForCausalLM",
|
| 1610 |
+
]
|
venv/lib/python3.13/site-packages/transformers/models/aria/processing_aria.py
ADDED
|
@@ -0,0 +1,189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
| 2 |
+
# This file was automatically generated from src/transformers/models/aria/modular_aria.py.
|
| 3 |
+
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
| 4 |
+
# the file from the modular. If any change should be done, please apply the change to the
|
| 5 |
+
# modular_aria.py file directly. One of our CI enforces this.
|
| 6 |
+
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
| 7 |
+
# coding=utf-8
|
| 8 |
+
# Copyright 2024 The Rhymes-AI Teams Authors and The HuggingFace Inc. team. All rights reserved.
|
| 9 |
+
#
|
| 10 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 11 |
+
# you may not use this file except in compliance with the License.
|
| 12 |
+
# You may obtain a copy of the License at
|
| 13 |
+
#
|
| 14 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 15 |
+
#
|
| 16 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 17 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 18 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 19 |
+
# See the License for the specific language governing permissions and
|
| 20 |
+
# limitations under the License.
|
| 21 |
+
from typing import Optional, Union
|
| 22 |
+
|
| 23 |
+
import numpy as np
|
| 24 |
+
|
| 25 |
+
from ...image_processing_utils import BatchFeature
|
| 26 |
+
from ...image_utils import ImageInput
|
| 27 |
+
from ...processing_utils import MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack
|
| 28 |
+
from ...tokenization_utils import PreTokenizedInput, TextInput
|
| 29 |
+
from ...utils import TensorType
|
| 30 |
+
from ..auto import AutoTokenizer
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class AriaProcessorKwargs(ProcessingKwargs, total=False):
|
| 34 |
+
_defaults = {
|
| 35 |
+
"text_kwargs": {
|
| 36 |
+
"padding": False,
|
| 37 |
+
"return_mm_token_type_ids": False,
|
| 38 |
+
},
|
| 39 |
+
"images_kwargs": {
|
| 40 |
+
"max_image_size": 980,
|
| 41 |
+
"split_image": False,
|
| 42 |
+
},
|
| 43 |
+
"return_tensors": TensorType.PYTORCH,
|
| 44 |
+
}
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class AriaProcessor(ProcessorMixin):
|
| 48 |
+
"""
|
| 49 |
+
AriaProcessor is a processor for the Aria model which wraps the Aria image preprocessor and the LLama slow tokenizer.
|
| 50 |
+
|
| 51 |
+
Args:
|
| 52 |
+
image_processor (`AriaImageProcessor`, *optional*):
|
| 53 |
+
The AriaImageProcessor to use for image preprocessing.
|
| 54 |
+
tokenizer (`PreTrainedTokenizerBase`, *optional*):
|
| 55 |
+
An instance of [`PreTrainedTokenizerBase`]. This should correspond with the model's text model. The tokenizer is a required input.
|
| 56 |
+
chat_template (`str`, *optional*):
|
| 57 |
+
A Jinja template which will be used to convert lists of messages in a chat into a tokenizable string.
|
| 58 |
+
size_conversion (`Dict`, *optional*):
|
| 59 |
+
A dictionary indicating size conversions for images.
|
| 60 |
+
"""
|
| 61 |
+
|
| 62 |
+
attributes = ["image_processor", "tokenizer"]
|
| 63 |
+
image_processor_class = "AriaImageProcessor"
|
| 64 |
+
tokenizer_class = "AutoTokenizer"
|
| 65 |
+
|
| 66 |
+
def __init__(
|
| 67 |
+
self,
|
| 68 |
+
image_processor=None,
|
| 69 |
+
tokenizer: Union[AutoTokenizer, str] = None,
|
| 70 |
+
chat_template: Optional[str] = None,
|
| 71 |
+
size_conversion: Optional[dict[Union[float, int], int]] = None,
|
| 72 |
+
):
|
| 73 |
+
if size_conversion is None:
|
| 74 |
+
size_conversion = {490: 128, 980: 256}
|
| 75 |
+
self.size_conversion = {int(k): v for k, v in size_conversion.items()}
|
| 76 |
+
|
| 77 |
+
self.image_token = tokenizer.image_token
|
| 78 |
+
self.image_token_id = tokenizer.image_token_id
|
| 79 |
+
if tokenizer is not None and tokenizer.pad_token is None:
|
| 80 |
+
tokenizer.pad_token = tokenizer.unk_token
|
| 81 |
+
|
| 82 |
+
super().__init__(image_processor, tokenizer, chat_template=chat_template)
|
| 83 |
+
|
| 84 |
+
def __call__(
|
| 85 |
+
self,
|
| 86 |
+
text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]],
|
| 87 |
+
images: Optional[ImageInput] = None,
|
| 88 |
+
audio=None,
|
| 89 |
+
videos=None,
|
| 90 |
+
**kwargs: Unpack[AriaProcessorKwargs],
|
| 91 |
+
) -> BatchFeature:
|
| 92 |
+
"""
|
| 93 |
+
Main method to prepare for the model one or several sequences(s) and image(s).
|
| 94 |
+
|
| 95 |
+
Args:
|
| 96 |
+
text (`TextInput`, `PreTokenizedInput`, `list[TextInput]`, `list[PreTokenizedInput]`):
|
| 97 |
+
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
|
| 98 |
+
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
|
| 99 |
+
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
|
| 100 |
+
images (`ImageInput`):
|
| 101 |
+
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
|
| 102 |
+
tensor. Both channels-first and channels-last formats are supported.
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
Returns:
|
| 106 |
+
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
|
| 107 |
+
- **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
|
| 108 |
+
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
|
| 109 |
+
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
|
| 110 |
+
`None`).
|
| 111 |
+
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
|
| 112 |
+
- **pixel_mask** -- Pixel mask to be fed to a model. Returned when `images` is not `None`.
|
| 113 |
+
"""
|
| 114 |
+
output_kwargs = self._merge_kwargs(
|
| 115 |
+
AriaProcessorKwargs,
|
| 116 |
+
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
|
| 117 |
+
**kwargs,
|
| 118 |
+
)
|
| 119 |
+
|
| 120 |
+
if isinstance(text, str):
|
| 121 |
+
text = [text]
|
| 122 |
+
elif not isinstance(text, list) and not isinstance(text[0], str):
|
| 123 |
+
raise TypeError("Invalid input text. Please provide a string, or a list of strings")
|
| 124 |
+
|
| 125 |
+
if images is not None:
|
| 126 |
+
image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
|
| 127 |
+
# expand the image_token according to the num_crops and tokens per image
|
| 128 |
+
tokens_per_image = self.size_conversion[image_inputs.pixel_values.shape[2]]
|
| 129 |
+
prompt_strings = []
|
| 130 |
+
num_crops = image_inputs.pop("num_crops") * tokens_per_image
|
| 131 |
+
for sample in text:
|
| 132 |
+
sample = sample.replace(self.tokenizer.image_token, self.tokenizer.image_token * num_crops)
|
| 133 |
+
prompt_strings.append(sample)
|
| 134 |
+
|
| 135 |
+
else:
|
| 136 |
+
image_inputs = {}
|
| 137 |
+
prompt_strings = text
|
| 138 |
+
|
| 139 |
+
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
|
| 140 |
+
return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False)
|
| 141 |
+
text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"], return_tensors=None)
|
| 142 |
+
self._check_special_mm_tokens(prompt_strings, text_inputs, modalities=["image"])
|
| 143 |
+
|
| 144 |
+
if return_mm_token_type_ids:
|
| 145 |
+
array_ids = np.array(text_inputs["input_ids"])
|
| 146 |
+
mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
|
| 147 |
+
mm_token_type_ids[array_ids == self.image_token_id] = 1
|
| 148 |
+
text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()
|
| 149 |
+
|
| 150 |
+
return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors)
|
| 151 |
+
|
| 152 |
+
def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
|
| 153 |
+
"""
|
| 154 |
+
Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
|
| 155 |
+
Args:
|
| 156 |
+
image_sizes (`list[list[int]]`, *optional*):
|
| 157 |
+
The input sizes formatted as (height, width) per each image.
|
| 158 |
+
Returns:
|
| 159 |
+
`MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
|
| 160 |
+
input modalities, along with other useful data.
|
| 161 |
+
"""
|
| 162 |
+
|
| 163 |
+
vision_data = {}
|
| 164 |
+
if image_sizes is not None:
|
| 165 |
+
images_kwargs = AriaProcessorKwargs._defaults.get("images_kwargs", {})
|
| 166 |
+
images_kwargs.update(kwargs)
|
| 167 |
+
|
| 168 |
+
max_size = images_kwargs.get("max_image_size", None) or self.image_processor.max_image_size
|
| 169 |
+
num_image_patches = [
|
| 170 |
+
self.image_processor.get_number_of_image_patches(*image_size, images_kwargs)
|
| 171 |
+
for image_size in image_sizes
|
| 172 |
+
]
|
| 173 |
+
num_image_tokens = [self.size_conversion[max_size] * num_patches for num_patches in num_image_patches]
|
| 174 |
+
vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
|
| 175 |
+
|
| 176 |
+
return MultiModalData(**vision_data)
|
| 177 |
+
|
| 178 |
+
@property
|
| 179 |
+
def model_input_names(self):
|
| 180 |
+
tokenizer_input_names = self.tokenizer.model_input_names
|
| 181 |
+
image_processor_input_names = self.image_processor.model_input_names
|
| 182 |
+
|
| 183 |
+
# Remove `num_crops`, it is popped and used only when processing. Make a copy of list when removing
|
| 184 |
+
# otherwise `self.image_processor.model_input_names` is also modified
|
| 185 |
+
image_processor_input_names = [name for name in image_processor_input_names if name != "num_crops"]
|
| 186 |
+
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
__all__ = ["AriaProcessor"]
|
venv/lib/python3.13/site-packages/transformers/models/auto/__init__.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
from typing import TYPE_CHECKING
|
| 15 |
+
|
| 16 |
+
from ...utils import _LazyModule
|
| 17 |
+
from ...utils.import_utils import define_import_structure
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
if TYPE_CHECKING:
|
| 21 |
+
from .auto_factory import *
|
| 22 |
+
from .configuration_auto import *
|
| 23 |
+
from .feature_extraction_auto import *
|
| 24 |
+
from .image_processing_auto import *
|
| 25 |
+
from .modeling_auto import *
|
| 26 |
+
from .modeling_flax_auto import *
|
| 27 |
+
from .modeling_tf_auto import *
|
| 28 |
+
from .processing_auto import *
|
| 29 |
+
from .tokenization_auto import *
|
| 30 |
+
from .video_processing_auto import *
|
| 31 |
+
else:
|
| 32 |
+
import sys
|
| 33 |
+
|
| 34 |
+
_file = globals()["__file__"]
|
| 35 |
+
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
|
venv/lib/python3.13/site-packages/transformers/models/auto/auto_factory.py
ADDED
|
@@ -0,0 +1,882 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2021 The HuggingFace Inc. team.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Factory function to build auto-model classes."""
|
| 16 |
+
|
| 17 |
+
import copy
|
| 18 |
+
import importlib
|
| 19 |
+
import json
|
| 20 |
+
import os
|
| 21 |
+
import warnings
|
| 22 |
+
from collections import OrderedDict
|
| 23 |
+
from collections.abc import Iterator
|
| 24 |
+
from typing import Any, TypeVar, Union
|
| 25 |
+
|
| 26 |
+
from ...configuration_utils import PretrainedConfig
|
| 27 |
+
from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
|
| 28 |
+
from ...utils import (
|
| 29 |
+
CONFIG_NAME,
|
| 30 |
+
cached_file,
|
| 31 |
+
copy_func,
|
| 32 |
+
extract_commit_hash,
|
| 33 |
+
find_adapter_config_file,
|
| 34 |
+
is_peft_available,
|
| 35 |
+
is_torch_available,
|
| 36 |
+
logging,
|
| 37 |
+
requires_backends,
|
| 38 |
+
)
|
| 39 |
+
from .configuration_auto import AutoConfig, model_type_to_module_name, replace_list_option_in_docstrings
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
if is_torch_available():
|
| 43 |
+
from ...generation import GenerationMixin
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
logger = logging.get_logger(__name__)
|
| 47 |
+
|
| 48 |
+
_T = TypeVar("_T")
|
| 49 |
+
# Tokenizers will depend on packages installed, too much variance and there are no common base or Protocol
|
| 50 |
+
_LazyAutoMappingValue = tuple[Union[type[Any], None], Union[type[Any], None]]
|
| 51 |
+
|
| 52 |
+
CLASS_DOCSTRING = """
|
| 53 |
+
This is a generic model class that will be instantiated as one of the model classes of the library when created
|
| 54 |
+
with the [`~BaseAutoModelClass.from_pretrained`] class method or the [`~BaseAutoModelClass.from_config`] class
|
| 55 |
+
method.
|
| 56 |
+
|
| 57 |
+
This class cannot be instantiated directly using `__init__()` (throws an error).
|
| 58 |
+
"""
|
| 59 |
+
|
| 60 |
+
FROM_CONFIG_DOCSTRING = """
|
| 61 |
+
Instantiates one of the model classes of the library from a configuration.
|
| 62 |
+
|
| 63 |
+
Note:
|
| 64 |
+
Loading a model from its configuration file does **not** load the model weights. It only affects the
|
| 65 |
+
model's configuration. Use [`~BaseAutoModelClass.from_pretrained`] to load the model weights.
|
| 66 |
+
|
| 67 |
+
Args:
|
| 68 |
+
config ([`PretrainedConfig`]):
|
| 69 |
+
The model class to instantiate is selected based on the configuration class:
|
| 70 |
+
|
| 71 |
+
List options
|
| 72 |
+
attn_implementation (`str`, *optional*):
|
| 73 |
+
The attention implementation to use in the model (if relevant). Can be any of `"eager"` (manual implementation of the attention), `"sdpa"` (using [`F.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html)), or `"flash_attention_2"` (using [Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention)). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual `"eager"` implementation.
|
| 74 |
+
|
| 75 |
+
Examples:
|
| 76 |
+
|
| 77 |
+
```python
|
| 78 |
+
>>> from transformers import AutoConfig, BaseAutoModelClass
|
| 79 |
+
|
| 80 |
+
>>> # Download configuration from huggingface.co and cache.
|
| 81 |
+
>>> config = AutoConfig.from_pretrained("checkpoint_placeholder")
|
| 82 |
+
>>> model = BaseAutoModelClass.from_config(config)
|
| 83 |
+
```
|
| 84 |
+
"""
|
| 85 |
+
|
| 86 |
+
FROM_PRETRAINED_TORCH_DOCSTRING = """
|
| 87 |
+
Instantiate one of the model classes of the library from a pretrained model.
|
| 88 |
+
|
| 89 |
+
The model class to instantiate is selected based on the `model_type` property of the config object (either
|
| 90 |
+
passed as an argument or loaded from `pretrained_model_name_or_path` if possible), or when it's missing, by
|
| 91 |
+
falling back to using pattern matching on `pretrained_model_name_or_path`:
|
| 92 |
+
|
| 93 |
+
List options
|
| 94 |
+
|
| 95 |
+
The model is set in evaluation mode by default using `model.eval()` (so for instance, dropout modules are
|
| 96 |
+
deactivated). To train the model, you should first set it back in training mode with `model.train()`
|
| 97 |
+
|
| 98 |
+
Args:
|
| 99 |
+
pretrained_model_name_or_path (`str` or `os.PathLike`):
|
| 100 |
+
Can be either:
|
| 101 |
+
|
| 102 |
+
- A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
|
| 103 |
+
- A path to a *directory* containing model weights saved using
|
| 104 |
+
[`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
|
| 105 |
+
- A path or url to a *tensorflow index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In
|
| 106 |
+
this case, `from_tf` should be set to `True` and a configuration object should be provided as
|
| 107 |
+
`config` argument. This loading path is slower than converting the TensorFlow checkpoint in a
|
| 108 |
+
PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
|
| 109 |
+
model_args (additional positional arguments, *optional*):
|
| 110 |
+
Will be passed along to the underlying model `__init__()` method.
|
| 111 |
+
config ([`PretrainedConfig`], *optional*):
|
| 112 |
+
Configuration for the model to use instead of an automatically loaded configuration. Configuration can
|
| 113 |
+
be automatically loaded when:
|
| 114 |
+
|
| 115 |
+
- The model is a model provided by the library (loaded with the *model id* string of a pretrained
|
| 116 |
+
model).
|
| 117 |
+
- The model was saved using [`~PreTrainedModel.save_pretrained`] and is reloaded by supplying the
|
| 118 |
+
save directory.
|
| 119 |
+
- The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a
|
| 120 |
+
configuration JSON file named *config.json* is found in the directory.
|
| 121 |
+
state_dict (*dict[str, torch.Tensor]*, *optional*):
|
| 122 |
+
A state dictionary to use instead of a state dictionary loaded from saved weights file.
|
| 123 |
+
|
| 124 |
+
This option can be used if you want to create a model from a pretrained configuration but load your own
|
| 125 |
+
weights. In this case though, you should check if using [`~PreTrainedModel.save_pretrained`] and
|
| 126 |
+
[`~PreTrainedModel.from_pretrained`] is not a simpler option.
|
| 127 |
+
cache_dir (`str` or `os.PathLike`, *optional*):
|
| 128 |
+
Path to a directory in which a downloaded pretrained model configuration should be cached if the
|
| 129 |
+
standard cache should not be used.
|
| 130 |
+
from_tf (`bool`, *optional*, defaults to `False`):
|
| 131 |
+
Load the model weights from a TensorFlow checkpoint save file (see docstring of
|
| 132 |
+
`pretrained_model_name_or_path` argument).
|
| 133 |
+
force_download (`bool`, *optional*, defaults to `False`):
|
| 134 |
+
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
| 135 |
+
cached versions if they exist.
|
| 136 |
+
resume_download:
|
| 137 |
+
Deprecated and ignored. All downloads are now resumed by default when possible.
|
| 138 |
+
Will be removed in v5 of Transformers.
|
| 139 |
+
proxies (`dict[str, str]`, *optional*):
|
| 140 |
+
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
|
| 141 |
+
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
| 142 |
+
output_loading_info(`bool`, *optional*, defaults to `False`):
|
| 143 |
+
Whether ot not to also return a dictionary containing missing keys, unexpected keys and error messages.
|
| 144 |
+
local_files_only(`bool`, *optional*, defaults to `False`):
|
| 145 |
+
Whether or not to only look at local files (e.g., not try downloading the model).
|
| 146 |
+
revision (`str`, *optional*, defaults to `"main"`):
|
| 147 |
+
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
|
| 148 |
+
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
|
| 149 |
+
identifier allowed by git.
|
| 150 |
+
trust_remote_code (`bool`, *optional*, defaults to `False`):
|
| 151 |
+
Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
|
| 152 |
+
should only be set to `True` for repositories you trust and in which you have read the code, as it will
|
| 153 |
+
execute code present on the Hub on your local machine.
|
| 154 |
+
code_revision (`str`, *optional*, defaults to `"main"`):
|
| 155 |
+
The specific revision to use for the code on the Hub, if the code leaves in a different repository than
|
| 156 |
+
the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based
|
| 157 |
+
system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier
|
| 158 |
+
allowed by git.
|
| 159 |
+
kwargs (additional keyword arguments, *optional*):
|
| 160 |
+
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
|
| 161 |
+
`output_attentions=True`). Behaves differently depending on whether a `config` is provided or
|
| 162 |
+
automatically loaded:
|
| 163 |
+
|
| 164 |
+
- If a configuration is provided with `config`, `**kwargs` will be directly passed to the
|
| 165 |
+
underlying model's `__init__` method (we assume all relevant updates to the configuration have
|
| 166 |
+
already been done)
|
| 167 |
+
- If a configuration is not provided, `kwargs` will be first passed to the configuration class
|
| 168 |
+
initialization function ([`~PretrainedConfig.from_pretrained`]). Each key of `kwargs` that
|
| 169 |
+
corresponds to a configuration attribute will be used to override said attribute with the
|
| 170 |
+
supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute
|
| 171 |
+
will be passed to the underlying model's `__init__` function.
|
| 172 |
+
|
| 173 |
+
Examples:
|
| 174 |
+
|
| 175 |
+
```python
|
| 176 |
+
>>> from transformers import AutoConfig, BaseAutoModelClass
|
| 177 |
+
|
| 178 |
+
>>> # Download model and configuration from huggingface.co and cache.
|
| 179 |
+
>>> model = BaseAutoModelClass.from_pretrained("checkpoint_placeholder")
|
| 180 |
+
|
| 181 |
+
>>> # Update configuration during loading
|
| 182 |
+
>>> model = BaseAutoModelClass.from_pretrained("checkpoint_placeholder", output_attentions=True)
|
| 183 |
+
>>> model.config.output_attentions
|
| 184 |
+
True
|
| 185 |
+
|
| 186 |
+
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
|
| 187 |
+
>>> config = AutoConfig.from_pretrained("./tf_model/shortcut_placeholder_tf_model_config.json")
|
| 188 |
+
>>> model = BaseAutoModelClass.from_pretrained(
|
| 189 |
+
... "./tf_model/shortcut_placeholder_tf_checkpoint.ckpt.index", from_tf=True, config=config
|
| 190 |
+
... )
|
| 191 |
+
```
|
| 192 |
+
"""
|
| 193 |
+
|
| 194 |
+
FROM_PRETRAINED_TF_DOCSTRING = """
|
| 195 |
+
Instantiate one of the model classes of the library from a pretrained model.
|
| 196 |
+
|
| 197 |
+
The model class to instantiate is selected based on the `model_type` property of the config object (either
|
| 198 |
+
passed as an argument or loaded from `pretrained_model_name_or_path` if possible), or when it's missing, by
|
| 199 |
+
falling back to using pattern matching on `pretrained_model_name_or_path`:
|
| 200 |
+
|
| 201 |
+
List options
|
| 202 |
+
|
| 203 |
+
Args:
|
| 204 |
+
pretrained_model_name_or_path (`str` or `os.PathLike`):
|
| 205 |
+
Can be either:
|
| 206 |
+
|
| 207 |
+
- A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
|
| 208 |
+
- A path to a *directory* containing model weights saved using
|
| 209 |
+
[`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
|
| 210 |
+
- A path or url to a *PyTorch state_dict save file* (e.g, `./pt_model/pytorch_model.bin`). In this
|
| 211 |
+
case, `from_pt` should be set to `True` and a configuration object should be provided as `config`
|
| 212 |
+
argument. This loading path is slower than converting the PyTorch model in a TensorFlow model
|
| 213 |
+
using the provided conversion scripts and loading the TensorFlow model afterwards.
|
| 214 |
+
model_args (additional positional arguments, *optional*):
|
| 215 |
+
Will be passed along to the underlying model `__init__()` method.
|
| 216 |
+
config ([`PretrainedConfig`], *optional*):
|
| 217 |
+
Configuration for the model to use instead of an automatically loaded configuration. Configuration can
|
| 218 |
+
be automatically loaded when:
|
| 219 |
+
|
| 220 |
+
- The model is a model provided by the library (loaded with the *model id* string of a pretrained
|
| 221 |
+
model).
|
| 222 |
+
- The model was saved using [`~PreTrainedModel.save_pretrained`] and is reloaded by supplying the
|
| 223 |
+
save directory.
|
| 224 |
+
- The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a
|
| 225 |
+
configuration JSON file named *config.json* is found in the directory.
|
| 226 |
+
cache_dir (`str` or `os.PathLike`, *optional*):
|
| 227 |
+
Path to a directory in which a downloaded pretrained model configuration should be cached if the
|
| 228 |
+
standard cache should not be used.
|
| 229 |
+
from_pt (`bool`, *optional*, defaults to `False`):
|
| 230 |
+
Load the model weights from a PyTorch checkpoint save file (see docstring of
|
| 231 |
+
`pretrained_model_name_or_path` argument).
|
| 232 |
+
force_download (`bool`, *optional*, defaults to `False`):
|
| 233 |
+
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
| 234 |
+
cached versions if they exist.
|
| 235 |
+
resume_download:
|
| 236 |
+
Deprecated and ignored. All downloads are now resumed by default when possible.
|
| 237 |
+
Will be removed in v5 of Transformers.
|
| 238 |
+
proxies (`dict[str, str]`, *optional*):
|
| 239 |
+
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
|
| 240 |
+
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
| 241 |
+
output_loading_info(`bool`, *optional*, defaults to `False`):
|
| 242 |
+
Whether ot not to also return a dictionary containing missing keys, unexpected keys and error messages.
|
| 243 |
+
local_files_only(`bool`, *optional*, defaults to `False`):
|
| 244 |
+
Whether or not to only look at local files (e.g., not try downloading the model).
|
| 245 |
+
revision (`str`, *optional*, defaults to `"main"`):
|
| 246 |
+
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
|
| 247 |
+
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
|
| 248 |
+
identifier allowed by git.
|
| 249 |
+
trust_remote_code (`bool`, *optional*, defaults to `False`):
|
| 250 |
+
Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
|
| 251 |
+
should only be set to `True` for repositories you trust and in which you have read the code, as it will
|
| 252 |
+
execute code present on the Hub on your local machine.
|
| 253 |
+
code_revision (`str`, *optional*, defaults to `"main"`):
|
| 254 |
+
The specific revision to use for the code on the Hub, if the code leaves in a different repository than
|
| 255 |
+
the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based
|
| 256 |
+
system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier
|
| 257 |
+
allowed by git.
|
| 258 |
+
kwargs (additional keyword arguments, *optional*):
|
| 259 |
+
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
|
| 260 |
+
`output_attentions=True`). Behaves differently depending on whether a `config` is provided or
|
| 261 |
+
automatically loaded:
|
| 262 |
+
|
| 263 |
+
- If a configuration is provided with `config`, `**kwargs` will be directly passed to the
|
| 264 |
+
underlying model's `__init__` method (we assume all relevant updates to the configuration have
|
| 265 |
+
already been done)
|
| 266 |
+
- If a configuration is not provided, `kwargs` will be first passed to the configuration class
|
| 267 |
+
initialization function ([`~PretrainedConfig.from_pretrained`]). Each key of `kwargs` that
|
| 268 |
+
corresponds to a configuration attribute will be used to override said attribute with the
|
| 269 |
+
supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute
|
| 270 |
+
will be passed to the underlying model's `__init__` function.
|
| 271 |
+
|
| 272 |
+
Examples:
|
| 273 |
+
|
| 274 |
+
```python
|
| 275 |
+
>>> from transformers import AutoConfig, BaseAutoModelClass
|
| 276 |
+
|
| 277 |
+
>>> # Download model and configuration from huggingface.co and cache.
|
| 278 |
+
>>> model = BaseAutoModelClass.from_pretrained("checkpoint_placeholder")
|
| 279 |
+
|
| 280 |
+
>>> # Update configuration during loading
|
| 281 |
+
>>> model = BaseAutoModelClass.from_pretrained("checkpoint_placeholder", output_attentions=True)
|
| 282 |
+
>>> model.config.output_attentions
|
| 283 |
+
True
|
| 284 |
+
|
| 285 |
+
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
|
| 286 |
+
>>> config = AutoConfig.from_pretrained("./pt_model/shortcut_placeholder_pt_model_config.json")
|
| 287 |
+
>>> model = BaseAutoModelClass.from_pretrained(
|
| 288 |
+
... "./pt_model/shortcut_placeholder_pytorch_model.bin", from_pt=True, config=config
|
| 289 |
+
... )
|
| 290 |
+
```
|
| 291 |
+
"""
|
| 292 |
+
|
| 293 |
+
FROM_PRETRAINED_FLAX_DOCSTRING = """
|
| 294 |
+
Instantiate one of the model classes of the library from a pretrained model.
|
| 295 |
+
|
| 296 |
+
The model class to instantiate is selected based on the `model_type` property of the config object (either
|
| 297 |
+
passed as an argument or loaded from `pretrained_model_name_or_path` if possible), or when it's missing, by
|
| 298 |
+
falling back to using pattern matching on `pretrained_model_name_or_path`:
|
| 299 |
+
|
| 300 |
+
List options
|
| 301 |
+
|
| 302 |
+
Args:
|
| 303 |
+
pretrained_model_name_or_path (`str` or `os.PathLike`):
|
| 304 |
+
Can be either:
|
| 305 |
+
|
| 306 |
+
- A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
|
| 307 |
+
- A path to a *directory* containing model weights saved using
|
| 308 |
+
[`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
|
| 309 |
+
- A path or url to a *PyTorch state_dict save file* (e.g, `./pt_model/pytorch_model.bin`). In this
|
| 310 |
+
case, `from_pt` should be set to `True` and a configuration object should be provided as `config`
|
| 311 |
+
argument. This loading path is slower than converting the PyTorch model in a TensorFlow model
|
| 312 |
+
using the provided conversion scripts and loading the TensorFlow model afterwards.
|
| 313 |
+
model_args (additional positional arguments, *optional*):
|
| 314 |
+
Will be passed along to the underlying model `__init__()` method.
|
| 315 |
+
config ([`PretrainedConfig`], *optional*):
|
| 316 |
+
Configuration for the model to use instead of an automatically loaded configuration. Configuration can
|
| 317 |
+
be automatically loaded when:
|
| 318 |
+
|
| 319 |
+
- The model is a model provided by the library (loaded with the *model id* string of a pretrained
|
| 320 |
+
model).
|
| 321 |
+
- The model was saved using [`~PreTrainedModel.save_pretrained`] and is reloaded by supplying the
|
| 322 |
+
save directory.
|
| 323 |
+
- The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a
|
| 324 |
+
configuration JSON file named *config.json* is found in the directory.
|
| 325 |
+
cache_dir (`str` or `os.PathLike`, *optional*):
|
| 326 |
+
Path to a directory in which a downloaded pretrained model configuration should be cached if the
|
| 327 |
+
standard cache should not be used.
|
| 328 |
+
from_pt (`bool`, *optional*, defaults to `False`):
|
| 329 |
+
Load the model weights from a PyTorch checkpoint save file (see docstring of
|
| 330 |
+
`pretrained_model_name_or_path` argument).
|
| 331 |
+
force_download (`bool`, *optional*, defaults to `False`):
|
| 332 |
+
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
| 333 |
+
cached versions if they exist.
|
| 334 |
+
resume_download:
|
| 335 |
+
Deprecated and ignored. All downloads are now resumed by default when possible.
|
| 336 |
+
Will be removed in v5 of Transformers.
|
| 337 |
+
proxies (`dict[str, str]`, *optional*):
|
| 338 |
+
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
|
| 339 |
+
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
| 340 |
+
output_loading_info(`bool`, *optional*, defaults to `False`):
|
| 341 |
+
Whether ot not to also return a dictionary containing missing keys, unexpected keys and error messages.
|
| 342 |
+
local_files_only(`bool`, *optional*, defaults to `False`):
|
| 343 |
+
Whether or not to only look at local files (e.g., not try downloading the model).
|
| 344 |
+
revision (`str`, *optional*, defaults to `"main"`):
|
| 345 |
+
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
|
| 346 |
+
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
|
| 347 |
+
identifier allowed by git.
|
| 348 |
+
trust_remote_code (`bool`, *optional*, defaults to `False`):
|
| 349 |
+
Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
|
| 350 |
+
should only be set to `True` for repositories you trust and in which you have read the code, as it will
|
| 351 |
+
execute code present on the Hub on your local machine.
|
| 352 |
+
code_revision (`str`, *optional*, defaults to `"main"`):
|
| 353 |
+
The specific revision to use for the code on the Hub, if the code leaves in a different repository than
|
| 354 |
+
the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based
|
| 355 |
+
system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier
|
| 356 |
+
allowed by git.
|
| 357 |
+
kwargs (additional keyword arguments, *optional*):
|
| 358 |
+
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
|
| 359 |
+
`output_attentions=True`). Behaves differently depending on whether a `config` is provided or
|
| 360 |
+
automatically loaded:
|
| 361 |
+
|
| 362 |
+
- If a configuration is provided with `config`, `**kwargs` will be directly passed to the
|
| 363 |
+
underlying model's `__init__` method (we assume all relevant updates to the configuration have
|
| 364 |
+
already been done)
|
| 365 |
+
- If a configuration is not provided, `kwargs` will be first passed to the configuration class
|
| 366 |
+
initialization function ([`~PretrainedConfig.from_pretrained`]). Each key of `kwargs` that
|
| 367 |
+
corresponds to a configuration attribute will be used to override said attribute with the
|
| 368 |
+
supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute
|
| 369 |
+
will be passed to the underlying model's `__init__` function.
|
| 370 |
+
|
| 371 |
+
Examples:
|
| 372 |
+
|
| 373 |
+
```python
|
| 374 |
+
>>> from transformers import AutoConfig, BaseAutoModelClass
|
| 375 |
+
|
| 376 |
+
>>> # Download model and configuration from huggingface.co and cache.
|
| 377 |
+
>>> model = BaseAutoModelClass.from_pretrained("checkpoint_placeholder")
|
| 378 |
+
|
| 379 |
+
>>> # Update configuration during loading
|
| 380 |
+
>>> model = BaseAutoModelClass.from_pretrained("checkpoint_placeholder", output_attentions=True)
|
| 381 |
+
>>> model.config.output_attentions
|
| 382 |
+
True
|
| 383 |
+
|
| 384 |
+
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
|
| 385 |
+
>>> config = AutoConfig.from_pretrained("./pt_model/shortcut_placeholder_pt_model_config.json")
|
| 386 |
+
>>> model = BaseAutoModelClass.from_pretrained(
|
| 387 |
+
... "./pt_model/shortcut_placeholder_pytorch_model.bin", from_pt=True, config=config
|
| 388 |
+
... )
|
| 389 |
+
```
|
| 390 |
+
"""
|
| 391 |
+
|
| 392 |
+
|
| 393 |
+
def _get_model_class(config, model_mapping):
|
| 394 |
+
supported_models = model_mapping[type(config)]
|
| 395 |
+
if not isinstance(supported_models, (list, tuple)):
|
| 396 |
+
return supported_models
|
| 397 |
+
|
| 398 |
+
name_to_model = {model.__name__: model for model in supported_models}
|
| 399 |
+
architectures = getattr(config, "architectures", [])
|
| 400 |
+
for arch in architectures:
|
| 401 |
+
if arch in name_to_model:
|
| 402 |
+
return name_to_model[arch]
|
| 403 |
+
elif f"TF{arch}" in name_to_model:
|
| 404 |
+
return name_to_model[f"TF{arch}"]
|
| 405 |
+
elif f"Flax{arch}" in name_to_model:
|
| 406 |
+
return name_to_model[f"Flax{arch}"]
|
| 407 |
+
|
| 408 |
+
# If not architecture is set in the config or match the supported models, the first element of the tuple is the
|
| 409 |
+
# defaults.
|
| 410 |
+
return supported_models[0]
|
| 411 |
+
|
| 412 |
+
|
| 413 |
+
class _BaseAutoModelClass:
|
| 414 |
+
# Base class for auto models.
|
| 415 |
+
_model_mapping = None
|
| 416 |
+
|
| 417 |
+
def __init__(self, *args, **kwargs) -> None:
|
| 418 |
+
raise OSError(
|
| 419 |
+
f"{self.__class__.__name__} is designed to be instantiated "
|
| 420 |
+
f"using the `{self.__class__.__name__}.from_pretrained(pretrained_model_name_or_path)` or "
|
| 421 |
+
f"`{self.__class__.__name__}.from_config(config)` methods."
|
| 422 |
+
)
|
| 423 |
+
|
| 424 |
+
@classmethod
|
| 425 |
+
def from_config(cls, config, **kwargs):
|
| 426 |
+
trust_remote_code = kwargs.pop("trust_remote_code", None)
|
| 427 |
+
has_remote_code = hasattr(config, "auto_map") and cls.__name__ in config.auto_map
|
| 428 |
+
has_local_code = type(config) in cls._model_mapping
|
| 429 |
+
if has_remote_code:
|
| 430 |
+
class_ref = config.auto_map[cls.__name__]
|
| 431 |
+
if "--" in class_ref:
|
| 432 |
+
upstream_repo = class_ref.split("--")[0]
|
| 433 |
+
else:
|
| 434 |
+
upstream_repo = None
|
| 435 |
+
trust_remote_code = resolve_trust_remote_code(
|
| 436 |
+
trust_remote_code, config._name_or_path, has_local_code, has_remote_code, upstream_repo=upstream_repo
|
| 437 |
+
)
|
| 438 |
+
|
| 439 |
+
if has_remote_code and trust_remote_code:
|
| 440 |
+
if "--" in class_ref:
|
| 441 |
+
repo_id, class_ref = class_ref.split("--")
|
| 442 |
+
else:
|
| 443 |
+
repo_id = config.name_or_path
|
| 444 |
+
model_class = get_class_from_dynamic_module(class_ref, repo_id, **kwargs)
|
| 445 |
+
            # This block handles the case where the user is loading a model with `trust_remote_code=True`
            # but a library model exists with the same name. We don't want to override the autoclass
            # mappings in this case, or all future loads of that model will be the remote code model.
            if not has_local_code:
                cls.register(config.__class__, model_class, exist_ok=True)
                model_class.register_for_auto_class(auto_class=cls)
            _ = kwargs.pop("code_revision", None)
            model_class = add_generation_mixin_to_remote_model(model_class)
            return model_class._from_config(config, **kwargs)
        elif type(config) in cls._model_mapping:
            model_class = _get_model_class(config, cls._model_mapping)
            return model_class._from_config(config, **kwargs)

        raise ValueError(
            f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
            f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping)}."
        )

    @classmethod
    def _prepare_config_for_auto_class(cls, config: PretrainedConfig) -> PretrainedConfig:
        """Additional autoclass-specific config post-loading manipulation. May be overridden in subclasses."""
        return config

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike[str]], *model_args, **kwargs):
        config = kwargs.pop("config", None)
        trust_remote_code = kwargs.get("trust_remote_code")
        kwargs["_from_auto"] = True
        hub_kwargs_names = [
            "cache_dir",
            "force_download",
            "local_files_only",
            "proxies",
            "resume_download",
            "revision",
            "subfolder",
            "use_auth_token",
            "token",
        ]
        hub_kwargs = {name: kwargs.pop(name) for name in hub_kwargs_names if name in kwargs}
        code_revision = kwargs.pop("code_revision", None)
        commit_hash = kwargs.pop("_commit_hash", None)
        adapter_kwargs = kwargs.pop("adapter_kwargs", None)

        token = hub_kwargs.pop("token", None)
        use_auth_token = hub_kwargs.pop("use_auth_token", None)
        if use_auth_token is not None:
            warnings.warn(
                "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
                FutureWarning,
            )
            if token is not None:
                raise ValueError(
                    "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
                )
            token = use_auth_token

        if token is not None:
            hub_kwargs["token"] = token

        if commit_hash is None:
            if not isinstance(config, PretrainedConfig):
                # We make a call to the config file first (which may be absent) to get the commit hash as soon as possible
                resolved_config_file = cached_file(
                    pretrained_model_name_or_path,
                    CONFIG_NAME,
                    _raise_exceptions_for_gated_repo=False,
                    _raise_exceptions_for_missing_entries=False,
                    _raise_exceptions_for_connection_errors=False,
                    **hub_kwargs,
                )
                commit_hash = extract_commit_hash(resolved_config_file, commit_hash)
            else:
                commit_hash = getattr(config, "_commit_hash", None)

        if is_peft_available():
            if adapter_kwargs is None:
                adapter_kwargs = {}
                if token is not None:
                    adapter_kwargs["token"] = token

            maybe_adapter_path = find_adapter_config_file(
                pretrained_model_name_or_path, _commit_hash=commit_hash, **adapter_kwargs
            )

            if maybe_adapter_path is not None:
                with open(maybe_adapter_path, "r", encoding="utf-8") as f:
                    adapter_config = json.load(f)

                    adapter_kwargs["_adapter_model_path"] = pretrained_model_name_or_path
                    pretrained_model_name_or_path = adapter_config["base_model_name_or_path"]

        if not isinstance(config, PretrainedConfig):
            kwargs_orig = copy.deepcopy(kwargs)
            # ensure not to pollute the config object with dtype="auto" - since it's
            # meaningless in the context of the config object - torch.dtype values are acceptable
            if kwargs.get("torch_dtype") == "auto":
                _ = kwargs.pop("torch_dtype")
            if kwargs.get("dtype") == "auto":
                _ = kwargs.pop("dtype")
            # to not overwrite the quantization_config if config has a quantization_config
            if kwargs.get("quantization_config") is not None:
                _ = kwargs.pop("quantization_config")

            config, kwargs = AutoConfig.from_pretrained(
                pretrained_model_name_or_path,
                return_unused_kwargs=True,
                code_revision=code_revision,
                _commit_hash=commit_hash,
                **hub_kwargs,
                **kwargs,
            )

            # if torch_dtype=auto was passed here, ensure to pass it on
            if kwargs_orig.get("torch_dtype", None) == "auto":
                kwargs["torch_dtype"] = "auto"
            if kwargs_orig.get("dtype", None) == "auto":
                kwargs["dtype"] = "auto"
            if kwargs_orig.get("quantization_config", None) is not None:
                kwargs["quantization_config"] = kwargs_orig["quantization_config"]

        has_remote_code = hasattr(config, "auto_map") and cls.__name__ in config.auto_map
        has_local_code = type(config) in cls._model_mapping
        upstream_repo = None
        if has_remote_code:
            class_ref = config.auto_map[cls.__name__]
            if "--" in class_ref:
                upstream_repo = class_ref.split("--")[0]
        trust_remote_code = resolve_trust_remote_code(
            trust_remote_code,
            pretrained_model_name_or_path,
            has_local_code,
            has_remote_code,
            upstream_repo=upstream_repo,
        )
        kwargs["trust_remote_code"] = trust_remote_code

        # Set the adapter kwargs
        kwargs["adapter_kwargs"] = adapter_kwargs

        if has_remote_code and trust_remote_code:
            model_class = get_class_from_dynamic_module(
                class_ref, pretrained_model_name_or_path, code_revision=code_revision, **hub_kwargs, **kwargs
            )
            _ = hub_kwargs.pop("code_revision", None)
            # This block handles the case where the user is loading a model with `trust_remote_code=True`
            # but a library model exists with the same name. We don't want to override the autoclass
            # mappings in this case, or all future loads of that model will be the remote code model.
            if not has_local_code:
                cls.register(config.__class__, model_class, exist_ok=True)
                model_class.register_for_auto_class(auto_class=cls)
            model_class = add_generation_mixin_to_remote_model(model_class)
            return model_class.from_pretrained(
                pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs
            )
        elif type(config) in cls._model_mapping:
            model_class = _get_model_class(config, cls._model_mapping)
            if model_class.config_class == config.sub_configs.get("text_config", None):
                config = config.get_text_config()
            return model_class.from_pretrained(
                pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs
            )
        raise ValueError(
            f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
            f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping)}."
        )

    @classmethod
    def register(cls, config_class, model_class, exist_ok=False) -> None:
        """
        Register a new model for this class.

        Args:
            config_class ([`PretrainedConfig`]):
                The configuration corresponding to the model to register.
            model_class ([`PreTrainedModel`]):
                The model to register.
        """
        if hasattr(model_class, "config_class") and model_class.config_class.__name__ != config_class.__name__:
            raise ValueError(
                "The model class you are passing has a `config_class` attribute that is not consistent with the "
                f"config class you passed (model has {model_class.config_class} and you passed {config_class}. Fix "
                "one of those so they match!"
            )
        cls._model_mapping.register(config_class, model_class, exist_ok=exist_ok)


class _BaseAutoBackboneClass(_BaseAutoModelClass):
    # Base class for auto backbone models.
    _model_mapping = None

    @classmethod
    def _load_timm_backbone_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        requires_backends(cls, ["vision", "timm"])
        from ...models.timm_backbone import TimmBackboneConfig

        config = kwargs.pop("config", TimmBackboneConfig())

        if kwargs.get("out_features") is not None:
            raise ValueError("Cannot specify `out_features` for timm backbones")

        if kwargs.get("output_loading_info", False):
            raise ValueError("Cannot specify `output_loading_info=True` when loading from timm")

        num_channels = kwargs.pop("num_channels", config.num_channels)
        features_only = kwargs.pop("features_only", config.features_only)
        use_pretrained_backbone = kwargs.pop("use_pretrained_backbone", config.use_pretrained_backbone)
        out_indices = kwargs.pop("out_indices", config.out_indices)
        config = TimmBackboneConfig(
            backbone=pretrained_model_name_or_path,
            num_channels=num_channels,
            features_only=features_only,
            use_pretrained_backbone=use_pretrained_backbone,
            out_indices=out_indices,
        )
        return super().from_config(config, **kwargs)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        use_timm_backbone = kwargs.pop("use_timm_backbone", False)
        if use_timm_backbone:
            return cls._load_timm_backbone_from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)

        return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)


def insert_head_doc(docstring, head_doc: str = ""):
    if len(head_doc) > 0:
        return docstring.replace(
            "one of the model classes of the library ",
            f"one of the model classes of the library (with a {head_doc} head) ",
        )
    return docstring.replace(
        "one of the model classes of the library ", "one of the base model classes of the library "
    )


def auto_class_update(cls, checkpoint_for_example: str = "google-bert/bert-base-cased", head_doc: str = ""):
    # Create a new class with the right name from the base class
    model_mapping = cls._model_mapping
    name = cls.__name__
    class_docstring = insert_head_doc(CLASS_DOCSTRING, head_doc=head_doc)
    cls.__doc__ = class_docstring.replace("BaseAutoModelClass", name)

    # Now we need to copy and re-register `from_config` and `from_pretrained` as class methods otherwise we can't
    # have a specific docstrings for them.
    from_config = copy_func(_BaseAutoModelClass.from_config)
    from_config_docstring = insert_head_doc(FROM_CONFIG_DOCSTRING, head_doc=head_doc)
    from_config_docstring = from_config_docstring.replace("BaseAutoModelClass", name)
    from_config_docstring = from_config_docstring.replace("checkpoint_placeholder", checkpoint_for_example)
    from_config.__doc__ = from_config_docstring
    from_config = replace_list_option_in_docstrings(model_mapping._model_mapping, use_model_types=False)(from_config)
    cls.from_config = classmethod(from_config)

    if name.startswith("TF"):
        from_pretrained_docstring = FROM_PRETRAINED_TF_DOCSTRING
    elif name.startswith("Flax"):
        from_pretrained_docstring = FROM_PRETRAINED_FLAX_DOCSTRING
    else:
        from_pretrained_docstring = FROM_PRETRAINED_TORCH_DOCSTRING
    from_pretrained = copy_func(_BaseAutoModelClass.from_pretrained)
    from_pretrained_docstring = insert_head_doc(from_pretrained_docstring, head_doc=head_doc)
    from_pretrained_docstring = from_pretrained_docstring.replace("BaseAutoModelClass", name)
    from_pretrained_docstring = from_pretrained_docstring.replace("checkpoint_placeholder", checkpoint_for_example)
    shortcut = checkpoint_for_example.split("/")[-1].split("-")[0]
    from_pretrained_docstring = from_pretrained_docstring.replace("shortcut_placeholder", shortcut)
    from_pretrained.__doc__ = from_pretrained_docstring
    from_pretrained = replace_list_option_in_docstrings(model_mapping._model_mapping)(from_pretrained)
    cls.from_pretrained = classmethod(from_pretrained)
    return cls


def get_values(model_mapping):
    result = []
    for model in model_mapping.values():
        if isinstance(model, (list, tuple)):
            result += list(model)
        else:
            result.append(model)

    return result


def getattribute_from_module(module, attr):
    if attr is None:
        return None
    if isinstance(attr, tuple):
        return tuple(getattribute_from_module(module, a) for a in attr)
    if hasattr(module, attr):
        return getattr(module, attr)
    # Some of the mappings have entries model_type -> object of another model type. In that case we try to grab the
    # object at the top level.
    transformers_module = importlib.import_module("transformers")

    if module != transformers_module:
        try:
            return getattribute_from_module(transformers_module, attr)
        except ValueError:
            raise ValueError(f"Could not find {attr} neither in {module} nor in {transformers_module}!")
    else:
        raise ValueError(f"Could not find {attr} in {transformers_module}!")


def add_generation_mixin_to_remote_model(model_class):
    """
    Adds `GenerationMixin` to the inheritance of `model_class`, if `model_class` is a PyTorch model.

    This function is used for backwards compatibility purposes: in v4.45, we've started a deprecation cycle to make
    `PreTrainedModel` stop inheriting from `GenerationMixin`. Without this function, older models dynamically loaded
    from the Hub may not have the `generate` method after we remove the inheritance.
    """
    # 1. If it is not a PT model (i.e. doesn't inherit Module), do nothing
    if "torch.nn.modules.module.Module" not in str(model_class.__mro__):
        return model_class

    # 2. If it already **directly** inherits from GenerationMixin, do nothing
    if "GenerationMixin" in str(model_class.__bases__):
        return model_class

    # 3. Prior to v4.45, we could detect whether a model was `generate`-compatible if it had its own `generate` and/or
    # `prepare_inputs_for_generation` method.
    has_custom_generate_in_class = hasattr(model_class, "generate") and "GenerationMixin" not in str(
        getattr(model_class, "generate")
    )
    has_custom_prepare_inputs = hasattr(model_class, "prepare_inputs_for_generation") and "GenerationMixin" not in str(
        getattr(model_class, "prepare_inputs_for_generation")
    )
    if has_custom_generate_in_class or has_custom_prepare_inputs:
        model_class_with_generation_mixin = type(
            model_class.__name__, (model_class, GenerationMixin), {**model_class.__dict__}
        )
        return model_class_with_generation_mixin
    return model_class


class _LazyAutoMapping(OrderedDict[type[PretrainedConfig], _LazyAutoMappingValue]):
    """
    " A mapping config to object (model or tokenizer for instance) that will load keys and values when it is accessed.

    Args:
        - config_mapping: The map model type to config class
        - model_mapping: The map model type to model (or tokenizer) class
    """

    def __init__(self, config_mapping, model_mapping) -> None:
        self._config_mapping = config_mapping
        self._reverse_config_mapping = {v: k for k, v in config_mapping.items()}
        self._model_mapping = model_mapping
        self._model_mapping._model_mapping = self
        self._extra_content = {}
        self._modules = {}

    def __len__(self) -> int:
        common_keys = set(self._config_mapping.keys()).intersection(self._model_mapping.keys())
        return len(common_keys) + len(self._extra_content)

    def __getitem__(self, key: type[PretrainedConfig]) -> _LazyAutoMappingValue:
        if key in self._extra_content:
            return self._extra_content[key]
        model_type = self._reverse_config_mapping[key.__name__]
        if model_type in self._model_mapping:
            model_name = self._model_mapping[model_type]
            return self._load_attr_from_module(model_type, model_name)

        # Maybe there was several model types associated with this config.
        model_types = [k for k, v in self._config_mapping.items() if v == key.__name__]
        for mtype in model_types:
            if mtype in self._model_mapping:
                model_name = self._model_mapping[mtype]
                return self._load_attr_from_module(mtype, model_name)
        raise KeyError(key)

    def _load_attr_from_module(self, model_type, attr):
        module_name = model_type_to_module_name(model_type)
        if module_name not in self._modules:
            self._modules[module_name] = importlib.import_module(f".{module_name}", "transformers.models")
        return getattribute_from_module(self._modules[module_name], attr)

    def keys(self) -> list[type[PretrainedConfig]]:
        mapping_keys = [
            self._load_attr_from_module(key, name)
            for key, name in self._config_mapping.items()
            if key in self._model_mapping
        ]
        return mapping_keys + list(self._extra_content.keys())

    def get(self, key: type[PretrainedConfig], default: _T) -> Union[_LazyAutoMappingValue, _T]:
        try:
            return self.__getitem__(key)
        except KeyError:
            return default

    def __bool__(self) -> bool:
        return bool(self.keys())

    def values(self) -> list[_LazyAutoMappingValue]:
        mapping_values = [
            self._load_attr_from_module(key, name)
            for key, name in self._model_mapping.items()
            if key in self._config_mapping
        ]
        return mapping_values + list(self._extra_content.values())

    def items(self) -> list[tuple[type[PretrainedConfig], _LazyAutoMappingValue]]:
        mapping_items = [
            (
                self._load_attr_from_module(key, self._config_mapping[key]),
                self._load_attr_from_module(key, self._model_mapping[key]),
            )
            for key in self._model_mapping
            if key in self._config_mapping
        ]
        return mapping_items + list(self._extra_content.items())

    def __iter__(self) -> Iterator[type[PretrainedConfig]]:
        return iter(self.keys())

    def __contains__(self, item: type) -> bool:
        if item in self._extra_content:
            return True
        if not hasattr(item, "__name__") or item.__name__ not in self._reverse_config_mapping:
            return False
        model_type = self._reverse_config_mapping[item.__name__]
        return model_type in self._model_mapping

    def register(self, key: type[PretrainedConfig], value: _LazyAutoMappingValue, exist_ok=False) -> None:
        """
        Register a new model in this mapping.
        """
        if hasattr(key, "__name__") and key.__name__ in self._reverse_config_mapping:
            model_type = self._reverse_config_mapping[key.__name__]
            if model_type in self._model_mapping and not exist_ok:
                raise ValueError(f"'{key}' is already used by a Transformers model.")

        self._extra_content[key] = value


__all__ = ["get_values"]
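Editor's note: the `register` classmethod and `_LazyAutoMapping.register` above are the hooks that let user code attach a custom architecture to the auto classes. The following is a minimal illustrative sketch of that flow, not part of the uploaded file; `MyConfig` and `MyModel` are hypothetical names, while `AutoConfig.register` and `AutoModel.register` are the public entry points that end up in the `_LazyAutoMapping._extra_content` dict shown above.

# Hedged sketch: registering a custom config/model pair with the auto classes.
# MyConfig / MyModel are made-up example classes, not part of transformers.
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
import torch.nn as nn


class MyConfig(PretrainedConfig):
    model_type = "my-model"  # must be unique among registered model types

    def __init__(self, hidden_size=64, **kwargs):
        super().__init__(**kwargs)
        self.hidden_size = hidden_size


class MyModel(PreTrainedModel):
    config_class = MyConfig  # consistency with this attribute is checked by register()

    def __init__(self, config):
        super().__init__(config)
        self.linear = nn.Linear(config.hidden_size, config.hidden_size)

    def forward(self, x):
        return self.linear(x)


AutoConfig.register("my-model", MyConfig)
AutoModel.register(MyConfig, MyModel)

# AutoModel.from_config now resolves MyConfig -> MyModel through the lazy mapping.
model = AutoModel.from_config(MyConfig())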
venv/lib/python3.13/site-packages/transformers/models/auto/configuration_auto.py
ADDED
|
@@ -0,0 +1,1404 @@
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Auto Config class."""

import importlib
import os
import re
import warnings
from collections import OrderedDict
from collections.abc import Callable, Iterator, KeysView, ValuesView
from typing import Any, TypeVar, Union

from ...configuration_utils import PretrainedConfig
from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
from ...utils import CONFIG_NAME, logging


logger = logging.get_logger(__name__)


_CallableT = TypeVar("_CallableT", bound=Callable[..., Any])


CONFIG_MAPPING_NAMES = OrderedDict[str, str](
    [
        # Add configs here
("aimv2", "Aimv2Config"),
|
| 40 |
+
("aimv2_vision_model", "Aimv2VisionConfig"),
|
| 41 |
+
("albert", "AlbertConfig"),
|
| 42 |
+
("align", "AlignConfig"),
|
| 43 |
+
("altclip", "AltCLIPConfig"),
|
| 44 |
+
("apertus", "ApertusConfig"),
|
| 45 |
+
("arcee", "ArceeConfig"),
|
| 46 |
+
("aria", "AriaConfig"),
|
| 47 |
+
("aria_text", "AriaTextConfig"),
|
| 48 |
+
("audio-spectrogram-transformer", "ASTConfig"),
|
| 49 |
+
("autoformer", "AutoformerConfig"),
|
| 50 |
+
("aya_vision", "AyaVisionConfig"),
|
| 51 |
+
("bamba", "BambaConfig"),
|
| 52 |
+
("bark", "BarkConfig"),
|
| 53 |
+
("bart", "BartConfig"),
|
| 54 |
+
("beit", "BeitConfig"),
|
| 55 |
+
("bert", "BertConfig"),
|
| 56 |
+
("bert-generation", "BertGenerationConfig"),
|
| 57 |
+
("big_bird", "BigBirdConfig"),
|
| 58 |
+
("bigbird_pegasus", "BigBirdPegasusConfig"),
|
| 59 |
+
("biogpt", "BioGptConfig"),
|
| 60 |
+
("bit", "BitConfig"),
|
| 61 |
+
("bitnet", "BitNetConfig"),
|
| 62 |
+
("blenderbot", "BlenderbotConfig"),
|
| 63 |
+
("blenderbot-small", "BlenderbotSmallConfig"),
|
| 64 |
+
("blip", "BlipConfig"),
|
| 65 |
+
("blip-2", "Blip2Config"),
|
| 66 |
+
("blip_2_qformer", "Blip2QFormerConfig"),
|
| 67 |
+
("bloom", "BloomConfig"),
|
| 68 |
+
("blt", "BltConfig"),
|
| 69 |
+
("bridgetower", "BridgeTowerConfig"),
|
| 70 |
+
("bros", "BrosConfig"),
|
| 71 |
+
("camembert", "CamembertConfig"),
|
| 72 |
+
("canine", "CanineConfig"),
|
| 73 |
+
("chameleon", "ChameleonConfig"),
|
| 74 |
+
("chinese_clip", "ChineseCLIPConfig"),
|
| 75 |
+
("chinese_clip_vision_model", "ChineseCLIPVisionConfig"),
|
| 76 |
+
("clap", "ClapConfig"),
|
| 77 |
+
("clip", "CLIPConfig"),
|
| 78 |
+
("clip_text_model", "CLIPTextConfig"),
|
| 79 |
+
("clip_vision_model", "CLIPVisionConfig"),
|
| 80 |
+
("clipseg", "CLIPSegConfig"),
|
| 81 |
+
("clvp", "ClvpConfig"),
|
| 82 |
+
("code_llama", "LlamaConfig"),
|
| 83 |
+
("codegen", "CodeGenConfig"),
|
| 84 |
+
("cohere", "CohereConfig"),
|
| 85 |
+
("cohere2", "Cohere2Config"),
|
| 86 |
+
("cohere2_vision", "Cohere2VisionConfig"),
|
| 87 |
+
("colpali", "ColPaliConfig"),
|
| 88 |
+
("colqwen2", "ColQwen2Config"),
|
| 89 |
+
("conditional_detr", "ConditionalDetrConfig"),
|
| 90 |
+
("convbert", "ConvBertConfig"),
|
| 91 |
+
("convnext", "ConvNextConfig"),
|
| 92 |
+
("convnextv2", "ConvNextV2Config"),
|
| 93 |
+
("cpmant", "CpmAntConfig"),
|
| 94 |
+
("csm", "CsmConfig"),
|
| 95 |
+
("ctrl", "CTRLConfig"),
|
| 96 |
+
("cvt", "CvtConfig"),
|
| 97 |
+
("d_fine", "DFineConfig"),
|
| 98 |
+
("dab-detr", "DabDetrConfig"),
|
| 99 |
+
("dac", "DacConfig"),
|
| 100 |
+
("data2vec-audio", "Data2VecAudioConfig"),
|
| 101 |
+
("data2vec-text", "Data2VecTextConfig"),
|
| 102 |
+
("data2vec-vision", "Data2VecVisionConfig"),
|
| 103 |
+
("dbrx", "DbrxConfig"),
|
| 104 |
+
("deberta", "DebertaConfig"),
|
| 105 |
+
("deberta-v2", "DebertaV2Config"),
|
| 106 |
+
("decision_transformer", "DecisionTransformerConfig"),
|
| 107 |
+
("deepseek_v2", "DeepseekV2Config"),
|
| 108 |
+
("deepseek_v3", "DeepseekV3Config"),
|
| 109 |
+
("deepseek_vl", "DeepseekVLConfig"),
|
| 110 |
+
("deepseek_vl_hybrid", "DeepseekVLHybridConfig"),
|
| 111 |
+
("deformable_detr", "DeformableDetrConfig"),
|
| 112 |
+
("deit", "DeiTConfig"),
|
| 113 |
+
("depth_anything", "DepthAnythingConfig"),
|
| 114 |
+
("depth_pro", "DepthProConfig"),
|
| 115 |
+
("deta", "DetaConfig"),
|
| 116 |
+
("detr", "DetrConfig"),
|
| 117 |
+
("dia", "DiaConfig"),
|
| 118 |
+
("diffllama", "DiffLlamaConfig"),
|
| 119 |
+
("dinat", "DinatConfig"),
|
| 120 |
+
("dinov2", "Dinov2Config"),
|
| 121 |
+
("dinov2_with_registers", "Dinov2WithRegistersConfig"),
|
| 122 |
+
("dinov3_convnext", "DINOv3ConvNextConfig"),
|
| 123 |
+
("dinov3_vit", "DINOv3ViTConfig"),
|
| 124 |
+
("distilbert", "DistilBertConfig"),
|
| 125 |
+
("doge", "DogeConfig"),
|
| 126 |
+
("donut-swin", "DonutSwinConfig"),
|
| 127 |
+
("dots1", "Dots1Config"),
|
| 128 |
+
("dpr", "DPRConfig"),
|
| 129 |
+
("dpt", "DPTConfig"),
|
| 130 |
+
("edgetam", "EdgeTamConfig"),
|
| 131 |
+
("edgetam_video", "EdgeTamVideoConfig"),
|
| 132 |
+
("edgetam_vision_model", "EdgeTamVisionConfig"),
|
| 133 |
+
("efficientformer", "EfficientFormerConfig"),
|
| 134 |
+
("efficientloftr", "EfficientLoFTRConfig"),
|
| 135 |
+
("efficientnet", "EfficientNetConfig"),
|
| 136 |
+
("electra", "ElectraConfig"),
|
| 137 |
+
("emu3", "Emu3Config"),
|
| 138 |
+
("encodec", "EncodecConfig"),
|
| 139 |
+
("encoder-decoder", "EncoderDecoderConfig"),
|
| 140 |
+
("eomt", "EomtConfig"),
|
| 141 |
+
("ernie", "ErnieConfig"),
|
| 142 |
+
("ernie4_5", "Ernie4_5Config"),
|
| 143 |
+
("ernie4_5_moe", "Ernie4_5_MoeConfig"),
|
| 144 |
+
("ernie_m", "ErnieMConfig"),
|
| 145 |
+
("esm", "EsmConfig"),
|
| 146 |
+
("evolla", "EvollaConfig"),
|
| 147 |
+
("exaone4", "Exaone4Config"),
|
| 148 |
+
("falcon", "FalconConfig"),
|
| 149 |
+
("falcon_h1", "FalconH1Config"),
|
| 150 |
+
("falcon_mamba", "FalconMambaConfig"),
|
| 151 |
+
("fastspeech2_conformer", "FastSpeech2ConformerConfig"),
|
| 152 |
+
("fastspeech2_conformer_with_hifigan", "FastSpeech2ConformerWithHifiGanConfig"),
|
| 153 |
+
("flaubert", "FlaubertConfig"),
|
| 154 |
+
("flava", "FlavaConfig"),
|
| 155 |
+
("flex_olmo", "FlexOlmoConfig"),
|
| 156 |
+
("florence2", "Florence2Config"),
|
| 157 |
+
("fnet", "FNetConfig"),
|
| 158 |
+
("focalnet", "FocalNetConfig"),
|
| 159 |
+
("fsmt", "FSMTConfig"),
|
| 160 |
+
("funnel", "FunnelConfig"),
|
| 161 |
+
("fuyu", "FuyuConfig"),
|
| 162 |
+
("gemma", "GemmaConfig"),
|
| 163 |
+
("gemma2", "Gemma2Config"),
|
| 164 |
+
("gemma3", "Gemma3Config"),
|
| 165 |
+
("gemma3_text", "Gemma3TextConfig"),
|
| 166 |
+
("gemma3n", "Gemma3nConfig"),
|
| 167 |
+
("gemma3n_audio", "Gemma3nAudioConfig"),
|
| 168 |
+
("gemma3n_text", "Gemma3nTextConfig"),
|
| 169 |
+
("gemma3n_vision", "Gemma3nVisionConfig"),
|
| 170 |
+
("git", "GitConfig"),
|
| 171 |
+
("glm", "GlmConfig"),
|
| 172 |
+
("glm4", "Glm4Config"),
|
| 173 |
+
("glm4_moe", "Glm4MoeConfig"),
|
| 174 |
+
("glm4v", "Glm4vConfig"),
|
| 175 |
+
("glm4v_moe", "Glm4vMoeConfig"),
|
| 176 |
+
("glm4v_moe_text", "Glm4vMoeTextConfig"),
|
| 177 |
+
("glm4v_text", "Glm4vTextConfig"),
|
| 178 |
+
("glpn", "GLPNConfig"),
|
| 179 |
+
("got_ocr2", "GotOcr2Config"),
|
| 180 |
+
("gpt-sw3", "GPT2Config"),
|
| 181 |
+
("gpt2", "GPT2Config"),
|
| 182 |
+
("gpt_bigcode", "GPTBigCodeConfig"),
|
| 183 |
+
("gpt_neo", "GPTNeoConfig"),
|
| 184 |
+
("gpt_neox", "GPTNeoXConfig"),
|
| 185 |
+
("gpt_neox_japanese", "GPTNeoXJapaneseConfig"),
|
| 186 |
+
("gpt_oss", "GptOssConfig"),
|
| 187 |
+
("gptj", "GPTJConfig"),
|
| 188 |
+
("gptsan-japanese", "GPTSanJapaneseConfig"),
|
| 189 |
+
("granite", "GraniteConfig"),
|
| 190 |
+
("granite_speech", "GraniteSpeechConfig"),
|
| 191 |
+
("granitemoe", "GraniteMoeConfig"),
|
| 192 |
+
("granitemoehybrid", "GraniteMoeHybridConfig"),
|
| 193 |
+
("granitemoeshared", "GraniteMoeSharedConfig"),
|
| 194 |
+
("granitevision", "LlavaNextConfig"),
|
| 195 |
+
("graphormer", "GraphormerConfig"),
|
| 196 |
+
("grounding-dino", "GroundingDinoConfig"),
|
| 197 |
+
("groupvit", "GroupViTConfig"),
|
| 198 |
+
("helium", "HeliumConfig"),
|
| 199 |
+
("hgnet_v2", "HGNetV2Config"),
|
| 200 |
+
("hiera", "HieraConfig"),
|
| 201 |
+
("hubert", "HubertConfig"),
|
| 202 |
+
("hunyuan_v1_dense", "HunYuanDenseV1Config"),
|
| 203 |
+
("hunyuan_v1_moe", "HunYuanMoEV1Config"),
|
| 204 |
+
("ibert", "IBertConfig"),
|
| 205 |
+
("idefics", "IdeficsConfig"),
|
| 206 |
+
("idefics2", "Idefics2Config"),
|
| 207 |
+
("idefics3", "Idefics3Config"),
|
| 208 |
+
("idefics3_vision", "Idefics3VisionConfig"),
|
| 209 |
+
("ijepa", "IJepaConfig"),
|
| 210 |
+
("imagegpt", "ImageGPTConfig"),
|
| 211 |
+
("informer", "InformerConfig"),
|
| 212 |
+
("instructblip", "InstructBlipConfig"),
|
| 213 |
+
("instructblipvideo", "InstructBlipVideoConfig"),
|
| 214 |
+
("internvl", "InternVLConfig"),
|
| 215 |
+
("internvl_vision", "InternVLVisionConfig"),
|
| 216 |
+
("jamba", "JambaConfig"),
|
| 217 |
+
("janus", "JanusConfig"),
|
| 218 |
+
("jetmoe", "JetMoeConfig"),
|
| 219 |
+
("jukebox", "JukeboxConfig"),
|
| 220 |
+
("kosmos-2", "Kosmos2Config"),
|
| 221 |
+
("kosmos-2.5", "Kosmos2_5Config"),
|
| 222 |
+
("kyutai_speech_to_text", "KyutaiSpeechToTextConfig"),
|
| 223 |
+
("layoutlm", "LayoutLMConfig"),
|
| 224 |
+
("layoutlmv2", "LayoutLMv2Config"),
|
| 225 |
+
("layoutlmv3", "LayoutLMv3Config"),
|
| 226 |
+
("led", "LEDConfig"),
|
| 227 |
+
("levit", "LevitConfig"),
|
| 228 |
+
("lfm2", "Lfm2Config"),
|
| 229 |
+
("lfm2_vl", "Lfm2VlConfig"),
|
| 230 |
+
("lightglue", "LightGlueConfig"),
|
| 231 |
+
("lilt", "LiltConfig"),
|
| 232 |
+
("llama", "LlamaConfig"),
|
| 233 |
+
("llama4", "Llama4Config"),
|
| 234 |
+
("llama4_text", "Llama4TextConfig"),
|
| 235 |
+
("llava", "LlavaConfig"),
|
| 236 |
+
("llava_next", "LlavaNextConfig"),
|
| 237 |
+
("llava_next_video", "LlavaNextVideoConfig"),
|
| 238 |
+
("llava_onevision", "LlavaOnevisionConfig"),
|
| 239 |
+
("longcat_flash", "LongcatFlashConfig"),
|
| 240 |
+
("longformer", "LongformerConfig"),
|
| 241 |
+
("longt5", "LongT5Config"),
|
| 242 |
+
("luke", "LukeConfig"),
|
| 243 |
+
("lxmert", "LxmertConfig"),
|
| 244 |
+
("m2m_100", "M2M100Config"),
|
| 245 |
+
("mamba", "MambaConfig"),
|
| 246 |
+
("mamba2", "Mamba2Config"),
|
| 247 |
+
("marian", "MarianConfig"),
|
| 248 |
+
("markuplm", "MarkupLMConfig"),
|
| 249 |
+
("mask2former", "Mask2FormerConfig"),
|
| 250 |
+
("maskformer", "MaskFormerConfig"),
|
| 251 |
+
("maskformer-swin", "MaskFormerSwinConfig"),
|
| 252 |
+
("mbart", "MBartConfig"),
|
| 253 |
+
("mctct", "MCTCTConfig"),
|
| 254 |
+
("mega", "MegaConfig"),
|
| 255 |
+
("megatron-bert", "MegatronBertConfig"),
|
| 256 |
+
("metaclip_2", "MetaClip2Config"),
|
| 257 |
+
("mgp-str", "MgpstrConfig"),
|
| 258 |
+
("mimi", "MimiConfig"),
|
| 259 |
+
("minimax", "MiniMaxConfig"),
|
| 260 |
+
("ministral", "MinistralConfig"),
|
| 261 |
+
("mistral", "MistralConfig"),
|
| 262 |
+
("mistral3", "Mistral3Config"),
|
| 263 |
+
("mixtral", "MixtralConfig"),
|
| 264 |
+
("mlcd", "MLCDVisionConfig"),
|
| 265 |
+
("mllama", "MllamaConfig"),
|
| 266 |
+
("mm-grounding-dino", "MMGroundingDinoConfig"),
|
| 267 |
+
("mobilebert", "MobileBertConfig"),
|
| 268 |
+
("mobilenet_v1", "MobileNetV1Config"),
|
| 269 |
+
("mobilenet_v2", "MobileNetV2Config"),
|
| 270 |
+
("mobilevit", "MobileViTConfig"),
|
| 271 |
+
("mobilevitv2", "MobileViTV2Config"),
|
| 272 |
+
("modernbert", "ModernBertConfig"),
|
| 273 |
+
("modernbert-decoder", "ModernBertDecoderConfig"),
|
| 274 |
+
("moonshine", "MoonshineConfig"),
|
| 275 |
+
("moshi", "MoshiConfig"),
|
| 276 |
+
("mpnet", "MPNetConfig"),
|
| 277 |
+
("mpt", "MptConfig"),
|
| 278 |
+
("mra", "MraConfig"),
|
| 279 |
+
("mt5", "MT5Config"),
|
| 280 |
+
("musicgen", "MusicgenConfig"),
|
| 281 |
+
("musicgen_melody", "MusicgenMelodyConfig"),
|
| 282 |
+
("mvp", "MvpConfig"),
|
| 283 |
+
("nat", "NatConfig"),
|
| 284 |
+
("nemotron", "NemotronConfig"),
|
| 285 |
+
("nezha", "NezhaConfig"),
|
| 286 |
+
("nllb-moe", "NllbMoeConfig"),
|
| 287 |
+
("nougat", "VisionEncoderDecoderConfig"),
|
| 288 |
+
("nystromformer", "NystromformerConfig"),
|
| 289 |
+
("olmo", "OlmoConfig"),
|
| 290 |
+
("olmo2", "Olmo2Config"),
|
| 291 |
+
("olmo3", "Olmo3Config"),
|
| 292 |
+
("olmoe", "OlmoeConfig"),
|
| 293 |
+
("omdet-turbo", "OmDetTurboConfig"),
|
| 294 |
+
("oneformer", "OneFormerConfig"),
|
| 295 |
+
("open-llama", "OpenLlamaConfig"),
|
| 296 |
+
("openai-gpt", "OpenAIGPTConfig"),
|
| 297 |
+
("opt", "OPTConfig"),
|
| 298 |
+
("ovis2", "Ovis2Config"),
|
| 299 |
+
("owlv2", "Owlv2Config"),
|
| 300 |
+
("owlvit", "OwlViTConfig"),
|
| 301 |
+
("paligemma", "PaliGemmaConfig"),
|
| 302 |
+
("parakeet_ctc", "ParakeetCTCConfig"),
|
| 303 |
+
("parakeet_encoder", "ParakeetEncoderConfig"),
|
| 304 |
+
("patchtsmixer", "PatchTSMixerConfig"),
|
| 305 |
+
("patchtst", "PatchTSTConfig"),
|
| 306 |
+
("pegasus", "PegasusConfig"),
|
| 307 |
+
("pegasus_x", "PegasusXConfig"),
|
| 308 |
+
("perceiver", "PerceiverConfig"),
|
| 309 |
+
("perception_encoder", "TimmWrapperConfig"),
|
| 310 |
+
("perception_lm", "PerceptionLMConfig"),
|
| 311 |
+
("persimmon", "PersimmonConfig"),
|
| 312 |
+
("phi", "PhiConfig"),
|
| 313 |
+
("phi3", "Phi3Config"),
|
| 314 |
+
("phi4_multimodal", "Phi4MultimodalConfig"),
|
| 315 |
+
("phimoe", "PhimoeConfig"),
|
| 316 |
+
("pix2struct", "Pix2StructConfig"),
|
| 317 |
+
("pixtral", "PixtralVisionConfig"),
|
| 318 |
+
("plbart", "PLBartConfig"),
|
| 319 |
+
("poolformer", "PoolFormerConfig"),
|
| 320 |
+
("pop2piano", "Pop2PianoConfig"),
|
| 321 |
+
("prompt_depth_anything", "PromptDepthAnythingConfig"),
|
| 322 |
+
("prophetnet", "ProphetNetConfig"),
|
| 323 |
+
("pvt", "PvtConfig"),
|
| 324 |
+
("pvt_v2", "PvtV2Config"),
|
| 325 |
+
("qdqbert", "QDQBertConfig"),
|
| 326 |
+
("qwen2", "Qwen2Config"),
|
| 327 |
+
("qwen2_5_omni", "Qwen2_5OmniConfig"),
|
| 328 |
+
("qwen2_5_vl", "Qwen2_5_VLConfig"),
|
| 329 |
+
("qwen2_5_vl_text", "Qwen2_5_VLTextConfig"),
|
| 330 |
+
("qwen2_audio", "Qwen2AudioConfig"),
|
| 331 |
+
("qwen2_audio_encoder", "Qwen2AudioEncoderConfig"),
|
| 332 |
+
("qwen2_moe", "Qwen2MoeConfig"),
|
| 333 |
+
("qwen2_vl", "Qwen2VLConfig"),
|
| 334 |
+
("qwen2_vl_text", "Qwen2VLTextConfig"),
|
| 335 |
+
("qwen3", "Qwen3Config"),
|
| 336 |
+
("qwen3_moe", "Qwen3MoeConfig"),
|
| 337 |
+
("qwen3_next", "Qwen3NextConfig"),
|
| 338 |
+
("qwen3_omni_moe", "Qwen3OmniMoeConfig"),
|
| 339 |
+
("qwen3_vl", "Qwen3VLConfig"),
|
| 340 |
+
("qwen3_vl_moe", "Qwen3VLMoeConfig"),
|
| 341 |
+
("qwen3_vl_moe_text", "Qwen3VLMoeTextConfig"),
|
| 342 |
+
("qwen3_vl_text", "Qwen3VLTextConfig"),
|
| 343 |
+
("rag", "RagConfig"),
|
| 344 |
+
("realm", "RealmConfig"),
|
| 345 |
+
("recurrent_gemma", "RecurrentGemmaConfig"),
|
| 346 |
+
("reformer", "ReformerConfig"),
|
| 347 |
+
("regnet", "RegNetConfig"),
|
| 348 |
+
("rembert", "RemBertConfig"),
|
| 349 |
+
("resnet", "ResNetConfig"),
|
| 350 |
+
("retribert", "RetriBertConfig"),
|
| 351 |
+
("roberta", "RobertaConfig"),
|
| 352 |
+
("roberta-prelayernorm", "RobertaPreLayerNormConfig"),
|
| 353 |
+
("roc_bert", "RoCBertConfig"),
|
| 354 |
+
("roformer", "RoFormerConfig"),
|
| 355 |
+
("rt_detr", "RTDetrConfig"),
|
| 356 |
+
("rt_detr_resnet", "RTDetrResNetConfig"),
|
| 357 |
+
("rt_detr_v2", "RTDetrV2Config"),
|
| 358 |
+
("rwkv", "RwkvConfig"),
|
| 359 |
+
("sam", "SamConfig"),
|
| 360 |
+
("sam2", "Sam2Config"),
|
| 361 |
+
("sam2_hiera_det_model", "Sam2HieraDetConfig"),
|
| 362 |
+
("sam2_video", "Sam2VideoConfig"),
|
| 363 |
+
("sam2_vision_model", "Sam2VisionConfig"),
|
| 364 |
+
("sam_hq", "SamHQConfig"),
|
| 365 |
+
("sam_hq_vision_model", "SamHQVisionConfig"),
|
| 366 |
+
("sam_vision_model", "SamVisionConfig"),
|
| 367 |
+
("seamless_m4t", "SeamlessM4TConfig"),
|
| 368 |
+
("seamless_m4t_v2", "SeamlessM4Tv2Config"),
|
| 369 |
+
("seed_oss", "SeedOssConfig"),
|
| 370 |
+
("segformer", "SegformerConfig"),
|
| 371 |
+
("seggpt", "SegGptConfig"),
|
| 372 |
+
("sew", "SEWConfig"),
|
| 373 |
+
("sew-d", "SEWDConfig"),
|
| 374 |
+
("shieldgemma2", "ShieldGemma2Config"),
|
| 375 |
+
("siglip", "SiglipConfig"),
|
| 376 |
+
("siglip2", "Siglip2Config"),
|
| 377 |
+
("siglip2_vision_model", "Siglip2VisionConfig"),
|
| 378 |
+
("siglip_vision_model", "SiglipVisionConfig"),
|
| 379 |
+
("smollm3", "SmolLM3Config"),
|
| 380 |
+
("smolvlm", "SmolVLMConfig"),
|
| 381 |
+
("smolvlm_vision", "SmolVLMVisionConfig"),
|
| 382 |
+
("speech-encoder-decoder", "SpeechEncoderDecoderConfig"),
|
| 383 |
+
("speech_to_text", "Speech2TextConfig"),
|
| 384 |
+
("speech_to_text_2", "Speech2Text2Config"),
|
| 385 |
+
("speecht5", "SpeechT5Config"),
|
| 386 |
+
("splinter", "SplinterConfig"),
|
| 387 |
+
("squeezebert", "SqueezeBertConfig"),
|
| 388 |
+
("stablelm", "StableLmConfig"),
|
| 389 |
+
("starcoder2", "Starcoder2Config"),
|
| 390 |
+
("superglue", "SuperGlueConfig"),
|
| 391 |
+
("superpoint", "SuperPointConfig"),
|
| 392 |
+
("swiftformer", "SwiftFormerConfig"),
|
| 393 |
+
("swin", "SwinConfig"),
|
| 394 |
+
("swin2sr", "Swin2SRConfig"),
|
| 395 |
+
("swinv2", "Swinv2Config"),
|
| 396 |
+
("switch_transformers", "SwitchTransformersConfig"),
|
| 397 |
+
("t5", "T5Config"),
|
| 398 |
+
("t5gemma", "T5GemmaConfig"),
|
| 399 |
+
("table-transformer", "TableTransformerConfig"),
|
| 400 |
+
("tapas", "TapasConfig"),
|
| 401 |
+
("textnet", "TextNetConfig"),
|
| 402 |
+
("time_series_transformer", "TimeSeriesTransformerConfig"),
|
| 403 |
+
("timesfm", "TimesFmConfig"),
|
| 404 |
+
("timesformer", "TimesformerConfig"),
|
| 405 |
+
("timm_backbone", "TimmBackboneConfig"),
|
| 406 |
+
("timm_wrapper", "TimmWrapperConfig"),
|
| 407 |
+
("trajectory_transformer", "TrajectoryTransformerConfig"),
|
| 408 |
+
("transfo-xl", "TransfoXLConfig"),
|
| 409 |
+
("trocr", "TrOCRConfig"),
|
| 410 |
+
("tvlt", "TvltConfig"),
|
| 411 |
+
("tvp", "TvpConfig"),
|
| 412 |
+
("udop", "UdopConfig"),
|
| 413 |
+
("umt5", "UMT5Config"),
|
| 414 |
+
("unispeech", "UniSpeechConfig"),
|
| 415 |
+
("unispeech-sat", "UniSpeechSatConfig"),
|
| 416 |
+
("univnet", "UnivNetConfig"),
|
| 417 |
+
("upernet", "UperNetConfig"),
|
| 418 |
+
("van", "VanConfig"),
|
| 419 |
+
("vaultgemma", "VaultGemmaConfig"),
|
| 420 |
+
("video_llava", "VideoLlavaConfig"),
|
| 421 |
+
("videomae", "VideoMAEConfig"),
|
| 422 |
+
("vilt", "ViltConfig"),
|
| 423 |
+
("vipllava", "VipLlavaConfig"),
|
| 424 |
+
("vision-encoder-decoder", "VisionEncoderDecoderConfig"),
|
| 425 |
+
("vision-text-dual-encoder", "VisionTextDualEncoderConfig"),
|
| 426 |
+
("visual_bert", "VisualBertConfig"),
|
| 427 |
+
("vit", "ViTConfig"),
|
| 428 |
+
("vit_hybrid", "ViTHybridConfig"),
|
| 429 |
+
("vit_mae", "ViTMAEConfig"),
|
| 430 |
+
("vit_msn", "ViTMSNConfig"),
|
| 431 |
+
("vitdet", "VitDetConfig"),
|
| 432 |
+
("vitmatte", "VitMatteConfig"),
|
| 433 |
+
("vitpose", "VitPoseConfig"),
|
| 434 |
+
("vitpose_backbone", "VitPoseBackboneConfig"),
|
| 435 |
+
("vits", "VitsConfig"),
|
| 436 |
+
("vivit", "VivitConfig"),
|
| 437 |
+
("vjepa2", "VJEPA2Config"),
|
| 438 |
+
("voxtral", "VoxtralConfig"),
|
| 439 |
+
("voxtral_encoder", "VoxtralEncoderConfig"),
|
| 440 |
+
("wav2vec2", "Wav2Vec2Config"),
|
| 441 |
+
("wav2vec2-bert", "Wav2Vec2BertConfig"),
|
| 442 |
+
("wav2vec2-conformer", "Wav2Vec2ConformerConfig"),
|
| 443 |
+
("wavlm", "WavLMConfig"),
|
| 444 |
+
("whisper", "WhisperConfig"),
|
| 445 |
+
("xclip", "XCLIPConfig"),
|
| 446 |
+
("xcodec", "XcodecConfig"),
|
| 447 |
+
("xglm", "XGLMConfig"),
|
| 448 |
+
("xlm", "XLMConfig"),
|
| 449 |
+
("xlm-prophetnet", "XLMProphetNetConfig"),
|
| 450 |
+
("xlm-roberta", "XLMRobertaConfig"),
|
| 451 |
+
("xlm-roberta-xl", "XLMRobertaXLConfig"),
|
| 452 |
+
("xlnet", "XLNetConfig"),
|
| 453 |
+
("xlstm", "xLSTMConfig"),
|
| 454 |
+
("xmod", "XmodConfig"),
|
| 455 |
+
("yolos", "YolosConfig"),
|
| 456 |
+
("yoso", "YosoConfig"),
|
| 457 |
+
("zamba", "ZambaConfig"),
|
| 458 |
+
("zamba2", "Zamba2Config"),
|
| 459 |
+
("zoedepth", "ZoeDepthConfig"),
|
| 460 |
+
]
|
| 461 |
+
)
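Editor's note: CONFIG_MAPPING_NAMES only stores class *names*; the actual class is imported the first time a model type is looked up. A rough sketch of how this mapping is consumed follows (illustrative only; the "bert" key is just one of the entries listed above):

# Hedged sketch of how the name mapping is used downstream (not part of the file).
# CONFIG_MAPPING wraps CONFIG_MAPPING_NAMES and imports the config class lazily;
# AutoConfig.for_model resolves a model_type string through it.
from transformers import AutoConfig
from transformers.models.auto.configuration_auto import CONFIG_MAPPING

config_cls = CONFIG_MAPPING["bert"]          # -> BertConfig, imported on first access
config = AutoConfig.for_model("bert", hidden_size=256)
print(type(config).__name__)                 # "BertConfig"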
|
| 462 |
+
|
| 463 |
+
|
| 464 |
+
MODEL_NAMES_MAPPING = OrderedDict[str, str](
|
| 465 |
+
[
|
| 466 |
+
# Add full (and cased) model names here
|
| 467 |
+
("aimv2", "AIMv2"),
|
| 468 |
+
("aimv2_vision_model", "Aimv2VisionModel"),
|
| 469 |
+
("albert", "ALBERT"),
|
| 470 |
+
("align", "ALIGN"),
|
| 471 |
+
("altclip", "AltCLIP"),
|
| 472 |
+
("apertus", "Apertus"),
|
| 473 |
+
("arcee", "Arcee"),
|
| 474 |
+
("aria", "Aria"),
|
| 475 |
+
("aria_text", "AriaText"),
|
| 476 |
+
("audio-spectrogram-transformer", "Audio Spectrogram Transformer"),
|
| 477 |
+
("autoformer", "Autoformer"),
|
| 478 |
+
("aya_vision", "AyaVision"),
|
| 479 |
+
("bamba", "Bamba"),
|
| 480 |
+
("bark", "Bark"),
|
| 481 |
+
("bart", "BART"),
|
| 482 |
+
("barthez", "BARThez"),
|
| 483 |
+
("bartpho", "BARTpho"),
|
| 484 |
+
("beit", "BEiT"),
|
| 485 |
+
("bert", "BERT"),
|
| 486 |
+
("bert-generation", "Bert Generation"),
|
| 487 |
+
("bert-japanese", "BertJapanese"),
|
| 488 |
+
("bertweet", "BERTweet"),
|
| 489 |
+
("big_bird", "BigBird"),
|
| 490 |
+
("bigbird_pegasus", "BigBird-Pegasus"),
|
| 491 |
+
("biogpt", "BioGpt"),
|
| 492 |
+
("bit", "BiT"),
|
| 493 |
+
("bitnet", "BitNet"),
|
| 494 |
+
("blenderbot", "Blenderbot"),
|
| 495 |
+
("blenderbot-small", "BlenderbotSmall"),
|
| 496 |
+
("blip", "BLIP"),
|
| 497 |
+
("blip-2", "BLIP-2"),
|
| 498 |
+
("blip_2_qformer", "BLIP-2 QFormer"),
|
| 499 |
+
("bloom", "BLOOM"),
|
| 500 |
+
("blt", "Blt"),
|
| 501 |
+
("bort", "BORT"),
|
| 502 |
+
("bridgetower", "BridgeTower"),
|
| 503 |
+
("bros", "BROS"),
|
| 504 |
+
("byt5", "ByT5"),
|
| 505 |
+
("camembert", "CamemBERT"),
|
| 506 |
+
("canine", "CANINE"),
|
| 507 |
+
("chameleon", "Chameleon"),
|
| 508 |
+
("chinese_clip", "Chinese-CLIP"),
|
| 509 |
+
("chinese_clip_vision_model", "ChineseCLIPVisionModel"),
|
| 510 |
+
("clap", "CLAP"),
|
| 511 |
+
("clip", "CLIP"),
|
| 512 |
+
("clip_text_model", "CLIPTextModel"),
|
| 513 |
+
("clip_vision_model", "CLIPVisionModel"),
|
| 514 |
+
("clipseg", "CLIPSeg"),
|
| 515 |
+
("clvp", "CLVP"),
|
| 516 |
+
("code_llama", "CodeLlama"),
|
| 517 |
+
("codegen", "CodeGen"),
|
| 518 |
+
("cohere", "Cohere"),
|
| 519 |
+
("cohere2", "Cohere2"),
|
| 520 |
+
("cohere2_vision", "Cohere2Vision"),
|
| 521 |
+
("colpali", "ColPali"),
|
| 522 |
+
("colqwen2", "ColQwen2"),
|
| 523 |
+
("conditional_detr", "Conditional DETR"),
|
| 524 |
+
("convbert", "ConvBERT"),
|
| 525 |
+
("convnext", "ConvNeXT"),
|
| 526 |
+
("convnextv2", "ConvNeXTV2"),
|
| 527 |
+
("cpm", "CPM"),
|
| 528 |
+
("cpmant", "CPM-Ant"),
|
| 529 |
+
("csm", "CSM"),
|
| 530 |
+
("ctrl", "CTRL"),
|
| 531 |
+
("cvt", "CvT"),
|
| 532 |
+
("d_fine", "D-FINE"),
|
| 533 |
+
("dab-detr", "DAB-DETR"),
|
| 534 |
+
("dac", "DAC"),
|
| 535 |
+
("data2vec-audio", "Data2VecAudio"),
|
| 536 |
+
("data2vec-text", "Data2VecText"),
|
| 537 |
+
("data2vec-vision", "Data2VecVision"),
|
| 538 |
+
("dbrx", "DBRX"),
|
| 539 |
+
("deberta", "DeBERTa"),
|
| 540 |
+
("deberta-v2", "DeBERTa-v2"),
|
| 541 |
+
("decision_transformer", "Decision Transformer"),
|
| 542 |
+
("deepseek_v2", "DeepSeek-V2"),
|
| 543 |
+
("deepseek_v3", "DeepSeek-V3"),
|
| 544 |
+
("deepseek_vl", "DeepseekVL"),
|
| 545 |
+
("deepseek_vl_hybrid", "DeepseekVLHybrid"),
|
| 546 |
+
("deformable_detr", "Deformable DETR"),
|
| 547 |
+
("deit", "DeiT"),
|
| 548 |
+
("deplot", "DePlot"),
|
| 549 |
+
("depth_anything", "Depth Anything"),
|
| 550 |
+
("depth_anything_v2", "Depth Anything V2"),
|
| 551 |
+
("depth_pro", "DepthPro"),
|
| 552 |
+
("deta", "DETA"),
|
| 553 |
+
("detr", "DETR"),
|
| 554 |
+
("dia", "Dia"),
|
| 555 |
+
("dialogpt", "DialoGPT"),
|
| 556 |
+
("diffllama", "DiffLlama"),
|
| 557 |
+
("dinat", "DiNAT"),
|
| 558 |
+
("dinov2", "DINOv2"),
|
| 559 |
+
("dinov2_with_registers", "DINOv2 with Registers"),
|
| 560 |
+
("dinov3_convnext", "DINOv3 ConvNext"),
|
| 561 |
+
("dinov3_vit", "DINOv3 ViT"),
|
| 562 |
+
("distilbert", "DistilBERT"),
|
| 563 |
+
("dit", "DiT"),
|
| 564 |
+
("doge", "Doge"),
|
| 565 |
+
("donut-swin", "DonutSwin"),
|
| 566 |
+
("dots1", "dots1"),
|
| 567 |
+
("dpr", "DPR"),
|
| 568 |
+
("dpt", "DPT"),
|
| 569 |
+
("edgetam", "EdgeTAM"),
|
| 570 |
+
("edgetam_video", "EdgeTamVideo"),
|
| 571 |
+
("edgetam_vision_model", "EdgeTamVisionModel"),
|
| 572 |
+
("efficientformer", "EfficientFormer"),
|
| 573 |
+
("efficientloftr", "EfficientLoFTR"),
|
| 574 |
+
("efficientnet", "EfficientNet"),
|
| 575 |
+
("electra", "ELECTRA"),
|
| 576 |
+
("emu3", "Emu3"),
|
| 577 |
+
("encodec", "EnCodec"),
|
| 578 |
+
("encoder-decoder", "Encoder decoder"),
|
| 579 |
+
("eomt", "EoMT"),
|
| 580 |
+
("ernie", "ERNIE"),
|
| 581 |
+
("ernie4_5", "Ernie4_5"),
|
| 582 |
+
("ernie4_5_moe", "Ernie4_5_MoE"),
|
| 583 |
+
("ernie_m", "ErnieM"),
|
| 584 |
+
("esm", "ESM"),
|
| 585 |
+
("evolla", "Evolla"),
|
| 586 |
+
("exaone4", "EXAONE-4.0"),
|
| 587 |
+
("falcon", "Falcon"),
|
| 588 |
+
("falcon3", "Falcon3"),
|
| 589 |
+
("falcon_h1", "FalconH1"),
|
| 590 |
+
("falcon_mamba", "FalconMamba"),
|
| 591 |
+
("fastspeech2_conformer", "FastSpeech2Conformer"),
|
| 592 |
+
("fastspeech2_conformer_with_hifigan", "FastSpeech2ConformerWithHifiGan"),
|
| 593 |
+
("flan-t5", "FLAN-T5"),
|
| 594 |
+
("flan-ul2", "FLAN-UL2"),
|
| 595 |
+
("flaubert", "FlauBERT"),
|
| 596 |
+
("flava", "FLAVA"),
|
| 597 |
+
("flex_olmo", "FlexOlmo"),
|
| 598 |
+
("florence2", "Florence2"),
|
| 599 |
+
("fnet", "FNet"),
|
| 600 |
+
("focalnet", "FocalNet"),
|
| 601 |
+
("fsmt", "FairSeq Machine-Translation"),
|
| 602 |
+
("funnel", "Funnel Transformer"),
|
| 603 |
+
("fuyu", "Fuyu"),
|
| 604 |
+
("gemma", "Gemma"),
|
| 605 |
+
("gemma2", "Gemma2"),
|
| 606 |
+
("gemma3", "Gemma3ForConditionalGeneration"),
|
| 607 |
+
("gemma3_text", "Gemma3ForCausalLM"),
|
| 608 |
+
("gemma3n", "Gemma3nForConditionalGeneration"),
|
| 609 |
+
("gemma3n_audio", "Gemma3nAudioEncoder"),
|
| 610 |
+
("gemma3n_text", "Gemma3nForCausalLM"),
|
| 611 |
+
("gemma3n_vision", "TimmWrapperModel"),
|
| 612 |
+
("git", "GIT"),
|
| 613 |
+
("glm", "GLM"),
|
| 614 |
+
("glm4", "GLM4"),
|
| 615 |
+
("glm4_moe", "Glm4MoE"),
|
| 616 |
+
("glm4v", "GLM4V"),
|
| 617 |
+
("glm4v_moe", "GLM4VMOE"),
|
| 618 |
+
("glm4v_moe_text", "GLM4VMOE"),
|
| 619 |
+
("glm4v_text", "GLM4V"),
|
| 620 |
+
("glpn", "GLPN"),
|
| 621 |
+
("got_ocr2", "GOT-OCR2"),
|
| 622 |
+
("gpt-sw3", "GPT-Sw3"),
|
| 623 |
+
("gpt2", "OpenAI GPT-2"),
|
| 624 |
+
("gpt_bigcode", "GPTBigCode"),
|
| 625 |
+
("gpt_neo", "GPT Neo"),
|
| 626 |
+
("gpt_neox", "GPT NeoX"),
|
| 627 |
+
("gpt_neox_japanese", "GPT NeoX Japanese"),
|
| 628 |
+
("gpt_oss", "GptOss"),
|
| 629 |
+
("gptj", "GPT-J"),
|
| 630 |
+
("gptsan-japanese", "GPTSAN-japanese"),
|
| 631 |
+
("granite", "Granite"),
|
| 632 |
+
("granite_speech", "GraniteSpeech"),
|
| 633 |
+
("granitemoe", "GraniteMoeMoe"),
|
| 634 |
+
("granitemoehybrid", "GraniteMoeHybrid"),
|
| 635 |
+
("granitemoeshared", "GraniteMoeSharedMoe"),
|
| 636 |
+
("granitevision", "LLaVA-NeXT"),
|
| 637 |
+
("graphormer", "Graphormer"),
|
| 638 |
+
("grounding-dino", "Grounding DINO"),
|
| 639 |
+
("groupvit", "GroupViT"),
|
| 640 |
+
("helium", "Helium"),
|
| 641 |
+
("herbert", "HerBERT"),
|
| 642 |
+
("hgnet_v2", "HGNet-V2"),
|
| 643 |
+
("hiera", "Hiera"),
|
| 644 |
+
("hubert", "Hubert"),
|
| 645 |
+
("hunyuan_v1_dense", "HunYuanDenseV1"),
|
| 646 |
+
("hunyuan_v1_moe", "HunYuanMoeV1"),
|
| 647 |
+
("ibert", "I-BERT"),
|
| 648 |
+
("idefics", "IDEFICS"),
|
| 649 |
+
("idefics2", "Idefics2"),
|
| 650 |
+
("idefics3", "Idefics3"),
|
| 651 |
+
("idefics3_vision", "Idefics3VisionTransformer"),
|
| 652 |
+
("ijepa", "I-JEPA"),
|
| 653 |
+
("imagegpt", "ImageGPT"),
|
| 654 |
+
("informer", "Informer"),
|
| 655 |
+
("instructblip", "InstructBLIP"),
|
| 656 |
+
("instructblipvideo", "InstructBlipVideo"),
|
| 657 |
+
("internvl", "InternVL"),
|
| 658 |
+
("internvl_vision", "InternVLVision"),
|
| 659 |
+
("jamba", "Jamba"),
|
| 660 |
+
("janus", "Janus"),
|
| 661 |
+
("jetmoe", "JetMoe"),
|
| 662 |
+
("jukebox", "Jukebox"),
|
| 663 |
+
("kosmos-2", "KOSMOS-2"),
|
| 664 |
+
("kosmos-2.5", "KOSMOS-2.5"),
|
| 665 |
+
("kyutai_speech_to_text", "KyutaiSpeechToText"),
|
| 666 |
+
("layoutlm", "LayoutLM"),
|
| 667 |
+
("layoutlmv2", "LayoutLMv2"),
|
| 668 |
+
("layoutlmv3", "LayoutLMv3"),
|
| 669 |
+
("layoutxlm", "LayoutXLM"),
|
| 670 |
+
("led", "LED"),
|
| 671 |
+
("levit", "LeViT"),
|
| 672 |
+
("lfm2", "Lfm2"),
|
| 673 |
+
("lfm2_vl", "Lfm2Vl"),
|
| 674 |
+
("lightglue", "LightGlue"),
|
| 675 |
+
("lilt", "LiLT"),
|
| 676 |
+
("llama", "LLaMA"),
|
| 677 |
+
("llama2", "Llama2"),
|
| 678 |
+
("llama3", "Llama3"),
|
| 679 |
+
("llama4", "Llama4"),
|
| 680 |
+
("llama4_text", "Llama4ForCausalLM"),
|
| 681 |
+
("llava", "LLaVa"),
|
| 682 |
+
("llava_next", "LLaVA-NeXT"),
|
| 683 |
+
("llava_next_video", "LLaVa-NeXT-Video"),
|
| 684 |
+
("llava_onevision", "LLaVA-Onevision"),
|
| 685 |
+
("longcat_flash", "LongCatFlash"),
|
| 686 |
+
("longformer", "Longformer"),
|
| 687 |
+
("longt5", "LongT5"),
|
| 688 |
+
("luke", "LUKE"),
|
| 689 |
+
("lxmert", "LXMERT"),
|
| 690 |
+
("m2m_100", "M2M100"),
|
| 691 |
+
("madlad-400", "MADLAD-400"),
|
| 692 |
+
("mamba", "Mamba"),
|
| 693 |
+
("mamba2", "mamba2"),
|
| 694 |
+
("marian", "Marian"),
|
| 695 |
+
("markuplm", "MarkupLM"),
|
| 696 |
+
("mask2former", "Mask2Former"),
|
| 697 |
+
("maskformer", "MaskFormer"),
|
| 698 |
+
("maskformer-swin", "MaskFormerSwin"),
|
| 699 |
+
("matcha", "MatCha"),
|
| 700 |
+
("mbart", "mBART"),
|
| 701 |
+
("mbart50", "mBART-50"),
|
| 702 |
+
("mctct", "M-CTC-T"),
|
| 703 |
+
("mega", "MEGA"),
|
| 704 |
+
("megatron-bert", "Megatron-BERT"),
|
| 705 |
+
("megatron_gpt2", "Megatron-GPT2"),
|
| 706 |
+
("metaclip_2", "MetaCLIP 2"),
|
| 707 |
+
("mgp-str", "MGP-STR"),
|
| 708 |
+
("mimi", "Mimi"),
|
| 709 |
+
("minimax", "MiniMax"),
|
| 710 |
+
("ministral", "Ministral"),
|
| 711 |
+
("mistral", "Mistral"),
|
| 712 |
+
("mistral3", "Mistral3"),
|
| 713 |
+
("mixtral", "Mixtral"),
|
| 714 |
+
("mlcd", "MLCD"),
|
| 715 |
+
("mllama", "Mllama"),
|
| 716 |
+
("mluke", "mLUKE"),
|
| 717 |
+
("mm-grounding-dino", "MM Grounding DINO"),
|
| 718 |
+
("mms", "MMS"),
|
| 719 |
+
("mobilebert", "MobileBERT"),
|
| 720 |
+
("mobilenet_v1", "MobileNetV1"),
|
| 721 |
+
("mobilenet_v2", "MobileNetV2"),
|
| 722 |
+
("mobilevit", "MobileViT"),
|
| 723 |
+
("mobilevitv2", "MobileViTV2"),
|
| 724 |
+
("modernbert", "ModernBERT"),
|
| 725 |
+
("modernbert-decoder", "ModernBertDecoder"),
|
| 726 |
+
("moonshine", "Moonshine"),
|
| 727 |
+
("moshi", "Moshi"),
|
| 728 |
+
("mpnet", "MPNet"),
|
| 729 |
+
("mpt", "MPT"),
|
| 730 |
+
("mra", "MRA"),
|
| 731 |
+
("mt5", "MT5"),
|
| 732 |
+
("musicgen", "MusicGen"),
|
| 733 |
+
("musicgen_melody", "MusicGen Melody"),
|
| 734 |
+
("mvp", "MVP"),
|
| 735 |
+
("myt5", "myt5"),
|
| 736 |
+
("nat", "NAT"),
|
| 737 |
+
("nemotron", "Nemotron"),
|
| 738 |
+
("nezha", "Nezha"),
|
| 739 |
+
("nllb", "NLLB"),
|
| 740 |
+
("nllb-moe", "NLLB-MOE"),
|
| 741 |
+
("nougat", "Nougat"),
|
| 742 |
+
("nystromformer", "Nyströmformer"),
|
| 743 |
+
("olmo", "OLMo"),
|
| 744 |
+
("olmo2", "OLMo2"),
|
| 745 |
+
("olmo3", "Olmo3"),
|
| 746 |
+
("olmoe", "OLMoE"),
|
| 747 |
+
("omdet-turbo", "OmDet-Turbo"),
|
| 748 |
+
("oneformer", "OneFormer"),
|
| 749 |
+
("open-llama", "OpenLlama"),
|
| 750 |
+
("openai-gpt", "OpenAI GPT"),
|
| 751 |
+
("opt", "OPT"),
|
| 752 |
+
("ovis2", "Ovis2"),
|
| 753 |
+
("owlv2", "OWLv2"),
|
| 754 |
+
("owlvit", "OWL-ViT"),
|
| 755 |
+
("paligemma", "PaliGemma"),
|
| 756 |
+
("parakeet", "Parakeet"),
|
| 757 |
+
("parakeet_ctc", "Parakeet"),
|
| 758 |
+
("parakeet_encoder", "ParakeetEncoder"),
|
| 759 |
+
("patchtsmixer", "PatchTSMixer"),
|
| 760 |
+
("patchtst", "PatchTST"),
|
| 761 |
+
("pegasus", "Pegasus"),
|
| 762 |
+
("pegasus_x", "PEGASUS-X"),
|
| 763 |
+
("perceiver", "Perceiver"),
|
| 764 |
+
("perception_encoder", "PerceptionEncoder"),
|
| 765 |
+
("perception_lm", "PerceptionLM"),
|
| 766 |
+
("persimmon", "Persimmon"),
|
| 767 |
+
("phi", "Phi"),
|
| 768 |
+
("phi3", "Phi3"),
|
| 769 |
+
("phi4_multimodal", "Phi4Multimodal"),
|
| 770 |
+
("phimoe", "Phimoe"),
|
| 771 |
+
("phobert", "PhoBERT"),
|
| 772 |
+
("pix2struct", "Pix2Struct"),
|
| 773 |
+
("pixtral", "Pixtral"),
|
| 774 |
+
("plbart", "PLBart"),
|
| 775 |
+
("poolformer", "PoolFormer"),
|
| 776 |
+
("pop2piano", "Pop2Piano"),
|
| 777 |
+
("prompt_depth_anything", "PromptDepthAnything"),
|
| 778 |
+
("prophetnet", "ProphetNet"),
|
| 779 |
+
("pvt", "PVT"),
|
| 780 |
+
("pvt_v2", "PVTv2"),
|
| 781 |
+
("qdqbert", "QDQBert"),
|
| 782 |
+
("qwen2", "Qwen2"),
|
| 783 |
+
("qwen2_5_omni", "Qwen2_5Omni"),
|
| 784 |
+
("qwen2_5_vl", "Qwen2_5_VL"),
|
| 785 |
+
("qwen2_5_vl_text", "Qwen2_5_VL"),
|
| 786 |
+
("qwen2_audio", "Qwen2Audio"),
|
| 787 |
+
("qwen2_audio_encoder", "Qwen2AudioEncoder"),
|
| 788 |
+
("qwen2_moe", "Qwen2MoE"),
|
| 789 |
+
("qwen2_vl", "Qwen2VL"),
|
| 790 |
+
("qwen2_vl_text", "Qwen2VL"),
|
| 791 |
+
("qwen3", "Qwen3"),
|
| 792 |
+
("qwen3_moe", "Qwen3MoE"),
|
| 793 |
+
("qwen3_next", "Qwen3Next"),
|
| 794 |
+
("qwen3_omni_moe", "Qwen3OmniMoE"),
|
| 795 |
+
("qwen3_vl", "Qwen3VL"),
|
| 796 |
+
("qwen3_vl_moe", "Qwen3VLMoe"),
|
| 797 |
+
("qwen3_vl_moe_text", "Qwen3VLMoe"),
|
| 798 |
+
("qwen3_vl_text", "Qwen3VL"),
|
| 799 |
+
("rag", "RAG"),
|
| 800 |
+
("realm", "REALM"),
|
| 801 |
+
("recurrent_gemma", "RecurrentGemma"),
|
| 802 |
+
("reformer", "Reformer"),
|
| 803 |
+
("regnet", "RegNet"),
|
| 804 |
+
("rembert", "RemBERT"),
|
| 805 |
+
("resnet", "ResNet"),
|
| 806 |
+
("retribert", "RetriBERT"),
|
| 807 |
+
("roberta", "RoBERTa"),
|
| 808 |
+
("roberta-prelayernorm", "RoBERTa-PreLayerNorm"),
|
| 809 |
+
("roc_bert", "RoCBert"),
|
| 810 |
+
("roformer", "RoFormer"),
|
| 811 |
+
("rt_detr", "RT-DETR"),
|
| 812 |
+
("rt_detr_resnet", "RT-DETR-ResNet"),
|
| 813 |
+
("rt_detr_v2", "RT-DETRv2"),
|
| 814 |
+
("rwkv", "RWKV"),
|
| 815 |
+
("sam", "SAM"),
|
| 816 |
+
("sam2", "SAM2"),
|
| 817 |
+
("sam2_hiera_det_model", "Sam2HieraDetModel"),
|
| 818 |
+
("sam2_video", "Sam2VideoModel"),
|
| 819 |
+
("sam2_vision_model", "Sam2VisionModel"),
|
| 820 |
+
("sam_hq", "SAM-HQ"),
|
| 821 |
+
("sam_hq_vision_model", "SamHQVisionModel"),
|
| 822 |
+
("sam_vision_model", "SamVisionModel"),
|
| 823 |
+
("seamless_m4t", "SeamlessM4T"),
|
| 824 |
+
("seamless_m4t_v2", "SeamlessM4Tv2"),
|
| 825 |
+
("seed_oss", "SeedOss"),
|
| 826 |
+
("segformer", "SegFormer"),
|
| 827 |
+
("seggpt", "SegGPT"),
|
| 828 |
+
("sew", "SEW"),
|
| 829 |
+
("sew-d", "SEW-D"),
|
| 830 |
+
("shieldgemma2", "Shieldgemma2"),
|
| 831 |
+
("siglip", "SigLIP"),
|
| 832 |
+
("siglip2", "SigLIP2"),
|
| 833 |
+
("siglip2_vision_model", "Siglip2VisionModel"),
|
| 834 |
+
("siglip_vision_model", "SiglipVisionModel"),
|
| 835 |
+
("smollm3", "SmolLM3"),
|
| 836 |
+
("smolvlm", "SmolVLM"),
|
| 837 |
+
("smolvlm_vision", "SmolVLMVisionTransformer"),
|
| 838 |
+
("speech-encoder-decoder", "Speech Encoder decoder"),
|
| 839 |
+
("speech_to_text", "Speech2Text"),
|
| 840 |
+
("speech_to_text_2", "Speech2Text2"),
|
| 841 |
+
("speecht5", "SpeechT5"),
|
| 842 |
+
("splinter", "Splinter"),
|
| 843 |
+
("squeezebert", "SqueezeBERT"),
|
| 844 |
+
("stablelm", "StableLm"),
|
| 845 |
+
("starcoder2", "Starcoder2"),
|
| 846 |
+
("superglue", "SuperGlue"),
|
| 847 |
+
("superpoint", "SuperPoint"),
|
| 848 |
+
("swiftformer", "SwiftFormer"),
|
| 849 |
+
("swin", "Swin Transformer"),
|
| 850 |
+
("swin2sr", "Swin2SR"),
|
| 851 |
+
("swinv2", "Swin Transformer V2"),
|
| 852 |
+
("switch_transformers", "SwitchTransformers"),
|
| 853 |
+
("t5", "T5"),
|
| 854 |
+
("t5gemma", "T5Gemma"),
|
| 855 |
+
("t5v1.1", "T5v1.1"),
|
| 856 |
+
("table-transformer", "Table Transformer"),
|
| 857 |
+
("tapas", "TAPAS"),
|
| 858 |
+
("tapex", "TAPEX"),
|
| 859 |
+
("textnet", "TextNet"),
|
| 860 |
+
("time_series_transformer", "Time Series Transformer"),
|
| 861 |
+
("timesfm", "TimesFm"),
|
| 862 |
+
("timesformer", "TimeSformer"),
|
| 863 |
+
("timm_backbone", "TimmBackbone"),
|
| 864 |
+
("timm_wrapper", "TimmWrapperModel"),
|
| 865 |
+
("trajectory_transformer", "Trajectory Transformer"),
|
| 866 |
+
("transfo-xl", "Transformer-XL"),
|
| 867 |
+
("trocr", "TrOCR"),
|
| 868 |
+
("tvlt", "TVLT"),
|
| 869 |
+
("tvp", "TVP"),
|
| 870 |
+
("udop", "UDOP"),
|
| 871 |
+
("ul2", "UL2"),
|
| 872 |
+
("umt5", "UMT5"),
|
| 873 |
+
("unispeech", "UniSpeech"),
|
| 874 |
+
("unispeech-sat", "UniSpeechSat"),
|
| 875 |
+
("univnet", "UnivNet"),
|
| 876 |
+
("upernet", "UPerNet"),
|
| 877 |
+
("van", "VAN"),
|
| 878 |
+
("vaultgemma", "VaultGemma"),
|
| 879 |
+
("video_llava", "VideoLlava"),
|
| 880 |
+
("videomae", "VideoMAE"),
|
| 881 |
+
("vilt", "ViLT"),
|
| 882 |
+
("vipllava", "VipLlava"),
|
| 883 |
+
("vision-encoder-decoder", "Vision Encoder decoder"),
|
| 884 |
+
("vision-text-dual-encoder", "VisionTextDualEncoder"),
|
| 885 |
+
("visual_bert", "VisualBERT"),
|
| 886 |
+
("vit", "ViT"),
|
| 887 |
+
("vit_hybrid", "ViT Hybrid"),
|
| 888 |
+
("vit_mae", "ViTMAE"),
|
| 889 |
+
("vit_msn", "ViTMSN"),
|
| 890 |
+
("vitdet", "VitDet"),
|
| 891 |
+
("vitmatte", "ViTMatte"),
|
| 892 |
+
("vitpose", "ViTPose"),
|
| 893 |
+
("vitpose_backbone", "ViTPoseBackbone"),
|
| 894 |
+
("vits", "VITS"),
|
| 895 |
+
("vivit", "ViViT"),
|
| 896 |
+
("vjepa2", "VJEPA2Model"),
|
| 897 |
+
("voxtral", "Voxtral"),
|
| 898 |
+
("voxtral_encoder", "Voxtral Encoder"),
|
| 899 |
+
("wav2vec2", "Wav2Vec2"),
|
| 900 |
+
("wav2vec2-bert", "Wav2Vec2-BERT"),
|
| 901 |
+
("wav2vec2-conformer", "Wav2Vec2-Conformer"),
|
| 902 |
+
("wav2vec2_phoneme", "Wav2Vec2Phoneme"),
|
| 903 |
+
("wavlm", "WavLM"),
|
| 904 |
+
("whisper", "Whisper"),
|
| 905 |
+
("xclip", "X-CLIP"),
|
| 906 |
+
("xcodec", "X-CODEC"),
|
| 907 |
+
("xglm", "XGLM"),
|
| 908 |
+
("xlm", "XLM"),
|
| 909 |
+
("xlm-prophetnet", "XLM-ProphetNet"),
|
| 910 |
+
("xlm-roberta", "XLM-RoBERTa"),
|
| 911 |
+
("xlm-roberta-xl", "XLM-RoBERTa-XL"),
|
| 912 |
+
("xlm-v", "XLM-V"),
|
| 913 |
+
("xlnet", "XLNet"),
|
| 914 |
+
("xls_r", "XLS-R"),
|
| 915 |
+
("xlsr_wav2vec2", "XLSR-Wav2Vec2"),
|
| 916 |
+
("xlstm", "xLSTM"),
|
| 917 |
+
("xmod", "X-MOD"),
|
| 918 |
+
("yolos", "YOLOS"),
|
| 919 |
+
("yoso", "YOSO"),
|
| 920 |
+
("zamba", "Zamba"),
|
| 921 |
+
("zamba2", "Zamba2"),
|
| 922 |
+
("zoedepth", "ZoeDepth"),
|
| 923 |
+
]
|
| 924 |
+
)
|
| 925 |
+
|
| 926 |
+
# This is tied to the processing `-` -> `_` in `model_type_to_module_name`. For example, instead of putting
# `transfo-xl` (as in `CONFIG_MAPPING_NAMES`), we should use `transfo_xl`.
DEPRECATED_MODELS = [
    "bort",
    "deta",
    "efficientformer",
    "ernie_m",
    "gptsan_japanese",
    "graphormer",
    "jukebox",
    "mctct",
    "mega",
    "mmbt",
    "nat",
    "nezha",
    "open_llama",
    "qdqbert",
    "realm",
    "retribert",
    "speech_to_text_2",
    "tapex",
    "trajectory_transformer",
    "transfo_xl",
    "tvlt",
    "van",
    "vit_hybrid",
    "xlm_prophetnet",
]

SPECIAL_MODEL_TYPE_TO_MODULE_NAME = OrderedDict[str, str](
    [
        ("openai-gpt", "openai"),
        ("data2vec-audio", "data2vec"),
        ("data2vec-text", "data2vec"),
        ("data2vec-vision", "data2vec"),
        ("donut-swin", "donut"),
        ("kosmos-2", "kosmos2"),
        ("kosmos-2.5", "kosmos2_5"),
        ("maskformer-swin", "maskformer"),
        ("xclip", "x_clip"),
        ("clip_vision_model", "clip"),
        ("qwen2_audio_encoder", "qwen2_audio"),
        ("voxtral_encoder", "voxtral"),
        ("clip_text_model", "clip"),
        ("aria_text", "aria"),
        ("gemma3_text", "gemma3"),
        ("gemma3n_audio", "gemma3n"),
        ("gemma3n_text", "gemma3n"),
        ("gemma3n_vision", "gemma3n"),
        ("glm4v_text", "glm4v"),
        ("glm4v_moe_text", "glm4v_moe"),
        ("idefics3_vision", "idefics3"),
        ("siglip_vision_model", "siglip"),
        ("siglip2_vision_model", "siglip2"),
        ("aimv2_vision_model", "aimv2"),
        ("smolvlm_vision", "smolvlm"),
        ("chinese_clip_vision_model", "chinese_clip"),
        ("rt_detr_resnet", "rt_detr"),
        ("granitevision", "llava_next"),
        ("internvl_vision", "internvl"),
        ("qwen2_5_vl_text", "qwen2_5_vl"),
        ("qwen2_vl_text", "qwen2_vl"),
        ("qwen3_vl_text", "qwen3_vl"),
        ("qwen3_vl_moe_text", "qwen3_vl_moe"),
        ("sam_vision_model", "sam"),
        ("sam2_vision_model", "sam2"),
        ("edgetam_vision_model", "edgetam"),
        ("sam2_hiera_det_model", "sam2"),
        ("sam_hq_vision_model", "sam_hq"),
        ("llama4_text", "llama4"),
        ("blip_2_qformer", "blip_2"),
        ("fastspeech2_conformer_with_hifigan", "fastspeech2_conformer"),
        ("perception_encoder", "perception_lm"),
        ("parakeet_encoder", "parakeet"),
        ("parakeet_ctc", "parakeet"),
    ]
)


def model_type_to_module_name(key) -> str:
    """Converts a config key to the corresponding module."""
    # Special treatment
    if key in SPECIAL_MODEL_TYPE_TO_MODULE_NAME:
        key = SPECIAL_MODEL_TYPE_TO_MODULE_NAME[key]

        if key in DEPRECATED_MODELS:
            key = f"deprecated.{key}"
        return key

    key = key.replace("-", "_")
    if key in DEPRECATED_MODELS:
        key = f"deprecated.{key}"

    return key

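# --- Illustrative sketch (not part of the upstream file): expected key resolution,
# assuming the DEPRECATED_MODELS and SPECIAL_MODEL_TYPE_TO_MODULE_NAME tables above.
# >>> model_type_to_module_name("xlm-roberta")   # plain key: "-" becomes "_"
# 'xlm_roberta'
# >>> model_type_to_module_name("kosmos-2")      # special-cased to its host module
# 'kosmos2'
# >>> model_type_to_module_name("transfo-xl")    # deprecated models resolve under deprecated/
# 'deprecated.transfo_xl'
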
def config_class_to_model_type(config) -> Union[str, None]:
    """Converts a config class name to the corresponding model type"""
    for key, cls in CONFIG_MAPPING_NAMES.items():
        if cls == config:
            return key
    # if key not found check in extra content
    for key, cls in CONFIG_MAPPING._extra_content.items():
        if cls.__name__ == config:
            return key
    return None


class _LazyConfigMapping(OrderedDict[str, type[PretrainedConfig]]):
    """
    A dictionary that lazily loads its values when they are requested.
    """

    def __init__(self, mapping) -> None:
        self._mapping = mapping
        self._extra_content = {}
        self._modules = {}

    def __getitem__(self, key: str) -> type[PretrainedConfig]:
        if key in self._extra_content:
            return self._extra_content[key]
        if key not in self._mapping:
            raise KeyError(key)
        value = self._mapping[key]
        module_name = model_type_to_module_name(key)
        if module_name not in self._modules:
            self._modules[module_name] = importlib.import_module(f".{module_name}", "transformers.models")
        if hasattr(self._modules[module_name], value):
            return getattr(self._modules[module_name], value)

        # Some of the mappings have entries model_type -> config of another model type. In that case we try to grab the
        # object at the top level.
        transformers_module = importlib.import_module("transformers")
        return getattr(transformers_module, value)

    def keys(self) -> list[str]:
        return list(self._mapping.keys()) + list(self._extra_content.keys())

    def values(self) -> list[type[PretrainedConfig]]:
        return [self[k] for k in self._mapping] + list(self._extra_content.values())

    def items(self) -> list[tuple[str, type[PretrainedConfig]]]:
        return [(k, self[k]) for k in self._mapping] + list(self._extra_content.items())

    def __iter__(self) -> Iterator[str]:
        return iter(list(self._mapping.keys()) + list(self._extra_content.keys()))

    def __contains__(self, item: object) -> bool:
        return item in self._mapping or item in self._extra_content

    def register(self, key: str, value: type[PretrainedConfig], exist_ok=False) -> None:
        """
        Register a new configuration in this mapping.
        """
        if key in self._mapping and not exist_ok:
            raise ValueError(f"'{key}' is already used by a Transformers config, pick another name.")
        self._extra_content[key] = value

CONFIG_MAPPING = _LazyConfigMapping(CONFIG_MAPPING_NAMES)

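# --- Illustrative sketch (not part of the upstream file): how the lazy mapping is
# typically exercised. Indexing a key imports that model's module on demand, while
# register() stores user-supplied config classes in _extra_content without touching
# the static table. The model names below are examples only.
#
# config_cls = CONFIG_MAPPING["bert"]            # lazily imports transformers.models.bert
# "bert" in CONFIG_MAPPING                       # membership check, no import triggered
# CONFIG_MAPPING.register("my_model", MyConfig)  # MyConfig: hypothetical PretrainedConfig subclass
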
class _LazyLoadAllMappings(OrderedDict[str, str]):
    """
    A mapping that will load all pairs of key values at the first access (either by indexing, requesting keys, values,
    etc.)

    Args:
        mapping: The mapping to load.
    """

    def __init__(self, mapping):
        self._mapping = mapping
        self._initialized = False
        self._data = {}

    def _initialize(self):
        if self._initialized:
            return

        for model_type, map_name in self._mapping.items():
            module_name = model_type_to_module_name(model_type)
            module = importlib.import_module(f".{module_name}", "transformers.models")
            mapping = getattr(module, map_name)
            self._data.update(mapping)

        self._initialized = True

    def __getitem__(self, key):
        self._initialize()
        return self._data[key]

    def keys(self) -> KeysView[str]:
        self._initialize()
        return self._data.keys()

    def values(self) -> ValuesView[str]:
        self._initialize()
        return self._data.values()

    def items(self) -> KeysView[str]:
        self._initialize()
        return self._data.keys()

    def __iter__(self) -> Iterator[str]:
        self._initialize()
        return iter(self._data)

    def __contains__(self, item: object) -> bool:
        self._initialize()
        return item in self._data


def _get_class_name(model_class: Union[str, list[str]]):
    if isinstance(model_class, (list, tuple)):
        return " or ".join([f"[`{c}`]" for c in model_class if c is not None])
    return f"[`{model_class}`]"


def _list_model_options(indent, config_to_class=None, use_model_types=True):
    if config_to_class is None and not use_model_types:
        raise ValueError("Using `use_model_types=False` requires a `config_to_class` dictionary.")
    if use_model_types:
        if config_to_class is None:
            model_type_to_name = {model_type: f"[`{config}`]" for model_type, config in CONFIG_MAPPING_NAMES.items()}
        else:
            model_type_to_name = {
                model_type: _get_class_name(model_class)
                for model_type, model_class in config_to_class.items()
                if model_type in MODEL_NAMES_MAPPING
            }
        lines = [
            f"{indent}- **{model_type}** -- {model_type_to_name[model_type]} ({MODEL_NAMES_MAPPING[model_type]} model)"
            for model_type in sorted(model_type_to_name.keys())
        ]
    else:
        config_to_name = {
            CONFIG_MAPPING_NAMES[config]: _get_class_name(clas)
            for config, clas in config_to_class.items()
            if config in CONFIG_MAPPING_NAMES
        }
        config_to_model_name = {
            config: MODEL_NAMES_MAPPING[model_type] for model_type, config in CONFIG_MAPPING_NAMES.items()
        }
        lines = [
            f"{indent}- [`{config_name}`] configuration class:"
            f" {config_to_name[config_name]} ({config_to_model_name[config_name]} model)"
            for config_name in sorted(config_to_name.keys())
        ]
    return "\n".join(lines)


def replace_list_option_in_docstrings(
    config_to_class=None, use_model_types: bool = True
) -> Callable[[_CallableT], _CallableT]:
    def docstring_decorator(fn):
        docstrings = fn.__doc__
        if docstrings is None:
            # Example: -OO
            return fn
        lines = docstrings.split("\n")
        i = 0
        while i < len(lines) and re.search(r"^(\s*)List options\s*$", lines[i]) is None:
            i += 1
        if i < len(lines):
            indent = re.search(r"^(\s*)List options\s*$", lines[i]).groups()[0]
            if use_model_types:
                indent = f"{indent}    "
            lines[i] = _list_model_options(indent, config_to_class=config_to_class, use_model_types=use_model_types)
            docstrings = "\n".join(lines)
        else:
            raise ValueError(
                f"The function {fn} should have an empty 'List options' in its docstring as placeholder, current"
                f" docstring is:\n{docstrings}"
            )
        fn.__doc__ = docstrings
        return fn

    return docstring_decorator

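# --- Illustrative sketch (not part of the upstream file): the decorator above fills a
# bare "List options" placeholder line in a docstring with the generated model list,
# reusing that line's indentation. A minimal usage sketch:
#
# @replace_list_option_in_docstrings()
# def describe(pretrained_model_name_or_path):
#     """Pick a configuration class.
#
#     List options
#     """
#
# After decoration, describe.__doc__ contains one "- **model_type** -- [`...Config`] (...)"
# bullet per entry of CONFIG_MAPPING_NAMES in place of the placeholder line.
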
class AutoConfig:
    r"""
    This is a generic configuration class that will be instantiated as one of the configuration classes of the library
    when created with the [`~AutoConfig.from_pretrained`] class method.

    This class cannot be instantiated directly using `__init__()` (throws an error).
    """

    def __init__(self) -> None:
        raise OSError(
            "AutoConfig is designed to be instantiated "
            "using the `AutoConfig.from_pretrained(pretrained_model_name_or_path)` method."
        )

    @classmethod
    def for_model(cls, model_type: str, *args, **kwargs) -> PretrainedConfig:
        if model_type in CONFIG_MAPPING:
            config_class = CONFIG_MAPPING[model_type]
            return config_class(*args, **kwargs)
        raise ValueError(
            f"Unrecognized model identifier: {model_type}. Should contain one of {', '.join(CONFIG_MAPPING.keys())}"
        )

    @classmethod
    @replace_list_option_in_docstrings()
    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike[str]], **kwargs):
        r"""
        Instantiate one of the configuration classes of the library from a pretrained model configuration.

        The configuration class to instantiate is selected based on the `model_type` property of the config object that
        is loaded, or when it's missing, by falling back to using pattern matching on `pretrained_model_name_or_path`:

        List options

        Args:
            pretrained_model_name_or_path (`str` or `os.PathLike`):
                Can be either:

                    - A string, the *model id* of a pretrained model configuration hosted inside a model repo on
                      huggingface.co.
                    - A path to a *directory* containing a configuration file saved using the
                      [`~PretrainedConfig.save_pretrained`] method, or the [`~PreTrainedModel.save_pretrained`] method,
                      e.g., `./my_model_directory/`.
                    - A path or url to a saved configuration JSON *file*, e.g.,
                      `./my_model_directory/configuration.json`.
            cache_dir (`str` or `os.PathLike`, *optional*):
                Path to a directory in which a downloaded pretrained model configuration should be cached if the
                standard cache should not be used.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files and override the
                cached versions if they exist.
            resume_download:
                Deprecated and ignored. All downloads are now resumed by default when possible.
                Will be removed in v5 of Transformers.
            proxies (`dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
                git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
                identifier allowed by git.
            return_unused_kwargs (`bool`, *optional*, defaults to `False`):
                If `False`, then this function returns just the final configuration object.

                If `True`, then this function returns a `Tuple(config, unused_kwargs)` where *unused_kwargs* is a
                dictionary consisting of the key/value pairs whose keys are not configuration attributes: i.e., the
                part of `kwargs` which has not been used to update `config` and is otherwise ignored.
            trust_remote_code (`bool`, *optional*, defaults to `False`):
                Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
                should only be set to `True` for repositories you trust and in which you have read the code, as it will
                execute code present on the Hub on your local machine.
            kwargs (additional keyword arguments, *optional*):
                The values in kwargs of any keys which are configuration attributes will be used to override the loaded
                values. Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled
                by the `return_unused_kwargs` keyword parameter.

        Examples:

        ```python
        >>> from transformers import AutoConfig

        >>> # Download configuration from huggingface.co and cache.
        >>> config = AutoConfig.from_pretrained("google-bert/bert-base-uncased")

        >>> # Download configuration from huggingface.co (user-uploaded) and cache.
        >>> config = AutoConfig.from_pretrained("dbmdz/bert-base-german-cased")

        >>> # If configuration file is in a directory (e.g., was saved using *save_pretrained('./test/saved_model/')*).
        >>> config = AutoConfig.from_pretrained("./test/bert_saved_model/")

        >>> # Load a specific configuration file.
        >>> config = AutoConfig.from_pretrained("./test/bert_saved_model/my_configuration.json")

        >>> # Change some config attributes when loading a pretrained config.
        >>> config = AutoConfig.from_pretrained("google-bert/bert-base-uncased", output_attentions=True, foo=False)
        >>> config.output_attentions
        True

        >>> config, unused_kwargs = AutoConfig.from_pretrained(
        ...     "google-bert/bert-base-uncased", output_attentions=True, foo=False, return_unused_kwargs=True
        ... )
        >>> config.output_attentions
        True

        >>> unused_kwargs
        {'foo': False}
        ```"""
        use_auth_token = kwargs.pop("use_auth_token", None)
        if use_auth_token is not None:
            warnings.warn(
                "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
                FutureWarning,
            )
            if kwargs.get("token") is not None:
                raise ValueError(
                    "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
                )
            kwargs["token"] = use_auth_token

        kwargs["_from_auto"] = True
        kwargs["name_or_path"] = pretrained_model_name_or_path
        trust_remote_code = kwargs.pop("trust_remote_code", None)
        code_revision = kwargs.pop("code_revision", None)

        config_dict, unused_kwargs = PretrainedConfig.get_config_dict(pretrained_model_name_or_path, **kwargs)
        has_remote_code = "auto_map" in config_dict and "AutoConfig" in config_dict["auto_map"]
        has_local_code = "model_type" in config_dict and config_dict["model_type"] in CONFIG_MAPPING
        if has_remote_code:
            class_ref = config_dict["auto_map"]["AutoConfig"]
            if "--" in class_ref:
                upstream_repo = class_ref.split("--")[0]
            else:
                upstream_repo = None
            trust_remote_code = resolve_trust_remote_code(
                trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code, upstream_repo
            )

        if has_remote_code and trust_remote_code:
            config_class = get_class_from_dynamic_module(
                class_ref, pretrained_model_name_or_path, code_revision=code_revision, **kwargs
            )
            config_class.register_for_auto_class()
            return config_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
        elif "model_type" in config_dict:
            # Apply heuristic: if model_type is mistral but layer_types is present, treat as ministral
            if config_dict["model_type"] == "mistral" and "layer_types" in config_dict:
                logger.info(
                    "Detected mistral model with layer_types, treating as ministral for alternating attention compatibility. "
                )
                config_dict["model_type"] = "ministral"

            try:
                config_class = CONFIG_MAPPING[config_dict["model_type"]]
            except KeyError:
                raise ValueError(
                    f"The checkpoint you are trying to load has model type `{config_dict['model_type']}` "
                    "but Transformers does not recognize this architecture. This could be because of an "
                    "issue with the checkpoint, or because your version of Transformers is out of date.\n\n"
                    "You can update Transformers with the command `pip install --upgrade transformers`. If this "
                    "does not work, and the checkpoint is very new, then there may not be a release version "
                    "that supports this model yet. In this case, you can get the most up-to-date code by installing "
                    "Transformers from source with the command "
                    "`pip install git+https://github.com/huggingface/transformers.git`"
                )
            return config_class.from_dict(config_dict, **unused_kwargs)
        else:
            # Fallback: use pattern matching on the string.
            # We go from longer names to shorter names to catch roberta before bert (for instance)
            for pattern in sorted(CONFIG_MAPPING.keys(), key=len, reverse=True):
                if pattern in str(pretrained_model_name_or_path):
                    return CONFIG_MAPPING[pattern].from_dict(config_dict, **unused_kwargs)

        raise ValueError(
            f"Unrecognized model in {pretrained_model_name_or_path}. "
            f"Should have a `model_type` key in its {CONFIG_NAME}, or contain one of the following strings "
            f"in its name: {', '.join(CONFIG_MAPPING.keys())}"
        )

    @staticmethod
    def register(model_type, config, exist_ok=False) -> None:
        """
        Register a new configuration for this class.

        Args:
            model_type (`str`): The model type like "bert" or "gpt".
            config ([`PretrainedConfig`]): The config to register.
        """
        if issubclass(config, PretrainedConfig) and config.model_type != model_type:
            raise ValueError(
                "The config you are passing has a `model_type` attribute that is not consistent with the model type "
                f"you passed (config has {config.model_type} and you passed {model_type}. Fix one of those so they "
                "match!"
            )
        CONFIG_MAPPING.register(model_type, config, exist_ok=exist_ok)


__all__ = ["CONFIG_MAPPING", "MODEL_NAMES_MAPPING", "AutoConfig"]
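# --- Illustrative sketch (not part of the upstream file): registering a custom config
# so AutoConfig can resolve it. The model type and class below are hypothetical.
#
# from transformers import AutoConfig, PretrainedConfig
#
# class MyConfig(PretrainedConfig):
#     model_type = "my_model"
#
# AutoConfig.register("my_model", MyConfig)
# config = AutoConfig.for_model("my_model", hidden_size=128)
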
venv/lib/python3.13/site-packages/transformers/models/auto/feature_extraction_auto.py
ADDED
@@ -0,0 +1,422 @@
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""AutoFeatureExtractor class."""

import importlib
import json
import os
import warnings
from collections import OrderedDict
from typing import Optional, Union

# Build the list of all feature extractors
from ...configuration_utils import PretrainedConfig
from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
from ...feature_extraction_utils import FeatureExtractionMixin
from ...utils import CONFIG_NAME, FEATURE_EXTRACTOR_NAME, cached_file, logging
from .auto_factory import _LazyAutoMapping
from .configuration_auto import (
    CONFIG_MAPPING_NAMES,
    AutoConfig,
    model_type_to_module_name,
    replace_list_option_in_docstrings,
)


logger = logging.get_logger(__name__)

FEATURE_EXTRACTOR_MAPPING_NAMES = OrderedDict(
    [
        ("audio-spectrogram-transformer", "ASTFeatureExtractor"),
        ("beit", "BeitFeatureExtractor"),
        ("chinese_clip", "ChineseCLIPFeatureExtractor"),
        ("clap", "ClapFeatureExtractor"),
        ("clip", "CLIPFeatureExtractor"),
        ("clipseg", "ViTFeatureExtractor"),
        ("clvp", "ClvpFeatureExtractor"),
        ("conditional_detr", "ConditionalDetrFeatureExtractor"),
        ("convnext", "ConvNextFeatureExtractor"),
        ("cvt", "ConvNextFeatureExtractor"),
        ("dac", "DacFeatureExtractor"),
        ("data2vec-audio", "Wav2Vec2FeatureExtractor"),
        ("data2vec-vision", "BeitFeatureExtractor"),
        ("deformable_detr", "DeformableDetrFeatureExtractor"),
        ("deit", "DeiTFeatureExtractor"),
        ("detr", "DetrFeatureExtractor"),
        ("dia", "DiaFeatureExtractor"),
        ("dinat", "ViTFeatureExtractor"),
        ("donut-swin", "DonutFeatureExtractor"),
        ("dpt", "DPTFeatureExtractor"),
        ("encodec", "EncodecFeatureExtractor"),
        ("flava", "FlavaFeatureExtractor"),
        ("gemma3n", "Gemma3nAudioFeatureExtractor"),
        ("glpn", "GLPNFeatureExtractor"),
        ("granite_speech", "GraniteSpeechFeatureExtractor"),
        ("groupvit", "CLIPFeatureExtractor"),
        ("hubert", "Wav2Vec2FeatureExtractor"),
        ("imagegpt", "ImageGPTFeatureExtractor"),
        ("kyutai_speech_to_text", "KyutaiSpeechToTextFeatureExtractor"),
        ("layoutlmv2", "LayoutLMv2FeatureExtractor"),
        ("layoutlmv3", "LayoutLMv3FeatureExtractor"),
        ("levit", "LevitFeatureExtractor"),
        ("maskformer", "MaskFormerFeatureExtractor"),
        ("mctct", "MCTCTFeatureExtractor"),
        ("mimi", "EncodecFeatureExtractor"),
        ("mobilenet_v1", "MobileNetV1FeatureExtractor"),
        ("mobilenet_v2", "MobileNetV2FeatureExtractor"),
        ("mobilevit", "MobileViTFeatureExtractor"),
        ("moonshine", "Wav2Vec2FeatureExtractor"),
        ("moshi", "EncodecFeatureExtractor"),
        ("nat", "ViTFeatureExtractor"),
        ("owlvit", "OwlViTFeatureExtractor"),
        ("parakeet_ctc", "ParakeetFeatureExtractor"),
        ("parakeet_encoder", "ParakeetFeatureExtractor"),
        ("perceiver", "PerceiverFeatureExtractor"),
        ("phi4_multimodal", "Phi4MultimodalFeatureExtractor"),
        ("poolformer", "PoolFormerFeatureExtractor"),
        ("pop2piano", "Pop2PianoFeatureExtractor"),
        ("regnet", "ConvNextFeatureExtractor"),
        ("resnet", "ConvNextFeatureExtractor"),
        ("seamless_m4t", "SeamlessM4TFeatureExtractor"),
        ("seamless_m4t_v2", "SeamlessM4TFeatureExtractor"),
        ("segformer", "SegformerFeatureExtractor"),
        ("sew", "Wav2Vec2FeatureExtractor"),
        ("sew-d", "Wav2Vec2FeatureExtractor"),
        ("speech_to_text", "Speech2TextFeatureExtractor"),
        ("speecht5", "SpeechT5FeatureExtractor"),
        ("swiftformer", "ViTFeatureExtractor"),
        ("swin", "ViTFeatureExtractor"),
        ("swinv2", "ViTFeatureExtractor"),
        ("table-transformer", "DetrFeatureExtractor"),
        ("timesformer", "VideoMAEFeatureExtractor"),
        ("tvlt", "TvltFeatureExtractor"),
        ("unispeech", "Wav2Vec2FeatureExtractor"),
        ("unispeech-sat", "Wav2Vec2FeatureExtractor"),
        ("univnet", "UnivNetFeatureExtractor"),
        ("van", "ConvNextFeatureExtractor"),
        ("videomae", "VideoMAEFeatureExtractor"),
        ("vilt", "ViltFeatureExtractor"),
        ("vit", "ViTFeatureExtractor"),
        ("vit_mae", "ViTFeatureExtractor"),
        ("vit_msn", "ViTFeatureExtractor"),
        ("wav2vec2", "Wav2Vec2FeatureExtractor"),
        ("wav2vec2-bert", "Wav2Vec2FeatureExtractor"),
        ("wav2vec2-conformer", "Wav2Vec2FeatureExtractor"),
        ("wavlm", "Wav2Vec2FeatureExtractor"),
        ("whisper", "WhisperFeatureExtractor"),
        ("xclip", "CLIPFeatureExtractor"),
        ("xcodec", "DacFeatureExtractor"),
        ("yolos", "YolosFeatureExtractor"),
    ]
)

FEATURE_EXTRACTOR_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FEATURE_EXTRACTOR_MAPPING_NAMES)


def feature_extractor_class_from_name(class_name: str):
    for module_name, extractors in FEATURE_EXTRACTOR_MAPPING_NAMES.items():
        if class_name in extractors:
            module_name = model_type_to_module_name(module_name)

            module = importlib.import_module(f".{module_name}", "transformers.models")
            try:
                return getattr(module, class_name)
            except AttributeError:
                continue

    for extractor in FEATURE_EXTRACTOR_MAPPING._extra_content.values():
        if getattr(extractor, "__name__", None) == class_name:
            return extractor

    # We did not find the class, but maybe it's because a dep is missing. In that case, the class will be in the main
    # init and we return the proper dummy to get an appropriate error message.
    main_module = importlib.import_module("transformers")
    if hasattr(main_module, class_name):
        return getattr(main_module, class_name)

    return None


def get_feature_extractor_config(
    pretrained_model_name_or_path: Union[str, os.PathLike],
    cache_dir: Optional[Union[str, os.PathLike]] = None,
    force_download: bool = False,
    resume_download: Optional[bool] = None,
    proxies: Optional[dict[str, str]] = None,
    token: Optional[Union[bool, str]] = None,
    revision: Optional[str] = None,
    local_files_only: bool = False,
    **kwargs,
):
    """
    Loads the tokenizer configuration from a pretrained model tokenizer configuration.

    Args:
        pretrained_model_name_or_path (`str` or `os.PathLike`):
            This can be either:

            - a string, the *model id* of a pretrained model configuration hosted inside a model repo on
              huggingface.co.
            - a path to a *directory* containing a configuration file saved using the
              [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.

        cache_dir (`str` or `os.PathLike`, *optional*):
            Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
            cache should not be used.
        force_download (`bool`, *optional*, defaults to `False`):
            Whether or not to force to (re-)download the configuration files and override the cached versions if they
            exist.
        resume_download:
            Deprecated and ignored. All downloads are now resumed by default when possible.
            Will be removed in v5 of Transformers.
        proxies (`dict[str, str]`, *optional*):
            A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
            'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
        token (`str` or *bool*, *optional*):
            The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
            when running `hf auth login` (stored in `~/.huggingface`).
        revision (`str`, *optional*, defaults to `"main"`):
            The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
            git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
            identifier allowed by git.
        local_files_only (`bool`, *optional*, defaults to `False`):
            If `True`, will only try to load the tokenizer configuration from local files.

    <Tip>

    Passing `token=True` is required when you want to use a private model.

    </Tip>

    Returns:
        `Dict`: The configuration of the tokenizer.

    Examples:

    ```python
    # Download configuration from huggingface.co and cache.
    tokenizer_config = get_tokenizer_config("google-bert/bert-base-uncased")
    # This model does not have a tokenizer config so the result will be an empty dict.
    tokenizer_config = get_tokenizer_config("FacebookAI/xlm-roberta-base")

    # Save a pretrained tokenizer locally and you can reload its config
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-cased")
    tokenizer.save_pretrained("tokenizer-test")
    tokenizer_config = get_tokenizer_config("tokenizer-test")
    ```"""
    use_auth_token = kwargs.pop("use_auth_token", None)
    if use_auth_token is not None:
        warnings.warn(
            "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
            FutureWarning,
        )
        if token is not None:
            raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
        token = use_auth_token

    resolved_config_file = cached_file(
        pretrained_model_name_or_path,
        FEATURE_EXTRACTOR_NAME,
        cache_dir=cache_dir,
        force_download=force_download,
        resume_download=resume_download,
        proxies=proxies,
        token=token,
        revision=revision,
        local_files_only=local_files_only,
        _raise_exceptions_for_gated_repo=False,
        _raise_exceptions_for_missing_entries=False,
        _raise_exceptions_for_connection_errors=False,
    )
    if resolved_config_file is None:
        logger.info(
            "Could not locate the feature extractor configuration file, will try to use the model config instead."
        )
        return {}

    with open(resolved_config_file, encoding="utf-8") as reader:
        return json.load(reader)

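# --- Illustrative sketch (not part of the upstream file): get_feature_extractor_config
# returns the raw contents of the repo's preprocessor config as a dict (or {} when the
# repo has none), which is lighter than instantiating the extractor itself. The repo id
# is only an example.
#
# fe_config = get_feature_extractor_config("facebook/wav2vec2-base-960h")
# fe_config.get("feature_extractor_type")   # e.g. "Wav2Vec2FeatureExtractor"
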
class AutoFeatureExtractor:
|
| 256 |
+
r"""
|
| 257 |
+
This is a generic feature extractor class that will be instantiated as one of the feature extractor classes of the
|
| 258 |
+
library when created with the [`AutoFeatureExtractor.from_pretrained`] class method.
|
| 259 |
+
|
| 260 |
+
This class cannot be instantiated directly using `__init__()` (throws an error).
|
| 261 |
+
"""
|
| 262 |
+
|
| 263 |
+
def __init__(self):
|
| 264 |
+
raise OSError(
|
| 265 |
+
"AutoFeatureExtractor is designed to be instantiated "
|
| 266 |
+
"using the `AutoFeatureExtractor.from_pretrained(pretrained_model_name_or_path)` method."
|
| 267 |
+
)
|
| 268 |
+
|
| 269 |
+
@classmethod
|
| 270 |
+
@replace_list_option_in_docstrings(FEATURE_EXTRACTOR_MAPPING_NAMES)
|
| 271 |
+
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
|
| 272 |
+
r"""
|
| 273 |
+
Instantiate one of the feature extractor classes of the library from a pretrained model vocabulary.
|
| 274 |
+
|
| 275 |
+
The feature extractor class to instantiate is selected based on the `model_type` property of the config object
|
| 276 |
+
(either passed as an argument or loaded from `pretrained_model_name_or_path` if possible), or when it's
|
| 277 |
+
missing, by falling back to using pattern matching on `pretrained_model_name_or_path`:
|
| 278 |
+
|
| 279 |
+
List options
|
| 280 |
+
|
| 281 |
+
Params:
|
| 282 |
+
pretrained_model_name_or_path (`str` or `os.PathLike`):
|
| 283 |
+
This can be either:
|
| 284 |
+
|
| 285 |
+
- a string, the *model id* of a pretrained feature_extractor hosted inside a model repo on
|
| 286 |
+
huggingface.co.
|
| 287 |
+
- a path to a *directory* containing a feature extractor file saved using the
|
| 288 |
+
[`~feature_extraction_utils.FeatureExtractionMixin.save_pretrained`] method, e.g.,
|
| 289 |
+
`./my_model_directory/`.
|
| 290 |
+
- a path or url to a saved feature extractor JSON *file*, e.g.,
|
| 291 |
+
`./my_model_directory/preprocessor_config.json`.
|
| 292 |
+
cache_dir (`str` or `os.PathLike`, *optional*):
|
| 293 |
+
Path to a directory in which a downloaded pretrained model feature extractor should be cached if the
|
| 294 |
+
standard cache should not be used.
|
| 295 |
+
force_download (`bool`, *optional*, defaults to `False`):
|
| 296 |
+
Whether or not to force to (re-)download the feature extractor files and override the cached versions
|
| 297 |
+
if they exist.
|
| 298 |
+
resume_download:
|
| 299 |
+
Deprecated and ignored. All downloads are now resumed by default when possible.
|
| 300 |
+
Will be removed in v5 of Transformers.
|
| 301 |
+
proxies (`dict[str, str]`, *optional*):
|
| 302 |
+
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
|
| 303 |
+
'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
|
| 304 |
+
token (`str` or *bool*, *optional*):
|
| 305 |
+
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
|
| 306 |
+
when running `hf auth login` (stored in `~/.huggingface`).
|
| 307 |
+
revision (`str`, *optional*, defaults to `"main"`):
|
| 308 |
+
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
|
| 309 |
+
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
|
| 310 |
+
identifier allowed by git.
|
| 311 |
+
return_unused_kwargs (`bool`, *optional*, defaults to `False`):
|
| 312 |
+
If `False`, then this function returns just the final feature extractor object. If `True`, then this
|
| 313 |
+
functions returns a `Tuple(feature_extractor, unused_kwargs)` where *unused_kwargs* is a dictionary
|
| 314 |
+
consisting of the key/value pairs whose keys are not feature extractor attributes: i.e., the part of
|
| 315 |
+
`kwargs` which has not been used to update `feature_extractor` and is otherwise ignored.
|
| 316 |
+
trust_remote_code (`bool`, *optional*, defaults to `False`):
|
| 317 |
+
Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
|
| 318 |
+
should only be set to `True` for repositories you trust and in which you have read the code, as it will
|
| 319 |
+
execute code present on the Hub on your local machine.
|
| 320 |
+
kwargs (`dict[str, Any]`, *optional*):
|
| 321 |
+
The values in kwargs of any keys which are feature extractor attributes will be used to override the
|
| 322 |
+
loaded values. Behavior concerning key/value pairs whose keys are *not* feature extractor attributes is
|
| 323 |
+
controlled by the `return_unused_kwargs` keyword parameter.
|
| 324 |
+
|
| 325 |
+
<Tip>
|
| 326 |
+
|
| 327 |
+
Passing `token=True` is required when you want to use a private model.
|
| 328 |
+
|
| 329 |
+
</Tip>
|
| 330 |
+
|
| 331 |
+
Examples:
|
| 332 |
+
|
| 333 |
+
```python
|
| 334 |
+
>>> from transformers import AutoFeatureExtractor
|
| 335 |
+
|
| 336 |
+
>>> # Download feature extractor from huggingface.co and cache.
|
| 337 |
+
>>> feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")
|
| 338 |
+
|
| 339 |
+
>>> # If feature extractor files are in a directory (e.g. feature extractor was saved using *save_pretrained('./test/saved_model/')*)
|
| 340 |
+
>>> # feature_extractor = AutoFeatureExtractor.from_pretrained("./test/saved_model/")
|
| 341 |
+
```"""
|
| 342 |
+
use_auth_token = kwargs.pop("use_auth_token", None)
|
| 343 |
+
if use_auth_token is not None:
|
| 344 |
+
warnings.warn(
|
| 345 |
+
"The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
|
| 346 |
+
FutureWarning,
|
| 347 |
+
)
|
| 348 |
+
if kwargs.get("token") is not None:
|
| 349 |
+
raise ValueError(
|
| 350 |
+
"`token` and `use_auth_token` are both specified. Please set only the argument `token`."
|
| 351 |
+
)
|
| 352 |
+
kwargs["token"] = use_auth_token
|
| 353 |
+
|
| 354 |
+
config = kwargs.pop("config", None)
|
| 355 |
+
trust_remote_code = kwargs.pop("trust_remote_code", None)
|
| 356 |
+
kwargs["_from_auto"] = True
|
| 357 |
+
|
| 358 |
+
config_dict, _ = FeatureExtractionMixin.get_feature_extractor_dict(pretrained_model_name_or_path, **kwargs)
|
| 359 |
+
feature_extractor_class = config_dict.get("feature_extractor_type", None)
|
| 360 |
        feature_extractor_auto_map = None
        if "AutoFeatureExtractor" in config_dict.get("auto_map", {}):
            feature_extractor_auto_map = config_dict["auto_map"]["AutoFeatureExtractor"]

        # If we don't find the feature extractor class in the feature extractor config, let's try the model config.
        if feature_extractor_class is None and feature_extractor_auto_map is None:
            if not isinstance(config, PretrainedConfig):
                config = AutoConfig.from_pretrained(
                    pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
                )
            # It could be in `config.feature_extractor_type`
            feature_extractor_class = getattr(config, "feature_extractor_type", None)
            if hasattr(config, "auto_map") and "AutoFeatureExtractor" in config.auto_map:
                feature_extractor_auto_map = config.auto_map["AutoFeatureExtractor"]

        if feature_extractor_class is not None:
            feature_extractor_class = feature_extractor_class_from_name(feature_extractor_class)

        has_remote_code = feature_extractor_auto_map is not None
        has_local_code = feature_extractor_class is not None or type(config) in FEATURE_EXTRACTOR_MAPPING
        if has_remote_code:
            if "--" in feature_extractor_auto_map:
                upstream_repo = feature_extractor_auto_map.split("--")[0]
            else:
                upstream_repo = None
            trust_remote_code = resolve_trust_remote_code(
                trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code, upstream_repo
            )

        if has_remote_code and trust_remote_code:
            feature_extractor_class = get_class_from_dynamic_module(
                feature_extractor_auto_map, pretrained_model_name_or_path, **kwargs
            )
            _ = kwargs.pop("code_revision", None)
            feature_extractor_class.register_for_auto_class()
            return feature_extractor_class.from_dict(config_dict, **kwargs)
        elif feature_extractor_class is not None:
            return feature_extractor_class.from_dict(config_dict, **kwargs)
        # Last try: we use the FEATURE_EXTRACTOR_MAPPING.
        elif type(config) in FEATURE_EXTRACTOR_MAPPING:
            feature_extractor_class = FEATURE_EXTRACTOR_MAPPING[type(config)]
            return feature_extractor_class.from_dict(config_dict, **kwargs)

        raise ValueError(
            f"Unrecognized feature extractor in {pretrained_model_name_or_path}. Should have a "
            f"`feature_extractor_type` key in its {FEATURE_EXTRACTOR_NAME} of {CONFIG_NAME}, or one of the following "
            f"`model_type` keys in its {CONFIG_NAME}: {', '.join(c for c in FEATURE_EXTRACTOR_MAPPING_NAMES)}"
        )

    @staticmethod
    def register(config_class, feature_extractor_class, exist_ok=False):
        """
        Register a new feature extractor for this class.

        Args:
            config_class ([`PretrainedConfig`]):
                The configuration corresponding to the model to register.
            feature_extractor_class ([`FeatureExtractorMixin`]): The feature extractor to register.
        """
        FEATURE_EXTRACTOR_MAPPING.register(config_class, feature_extractor_class, exist_ok=exist_ok)


__all__ = ["FEATURE_EXTRACTOR_MAPPING", "AutoFeatureExtractor"]
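The `register` hook and the fallback chain in `from_pretrained` above are easiest to see with a small usage sketch. This is illustrative only and not part of the file; `MyAudioConfig` and `MyAudioFeatureExtractor` are hypothetical names.

# Usage sketch (illustrative): wiring a custom feature extractor into AutoFeatureExtractor.
from transformers import AutoConfig, AutoFeatureExtractor, PretrainedConfig, SequenceFeatureExtractor


class MyAudioConfig(PretrainedConfig):  # hypothetical config class
    model_type = "my_audio_model"


class MyAudioFeatureExtractor(SequenceFeatureExtractor):  # hypothetical extractor
    def __init__(self, feature_size=1, sampling_rate=16000, padding_value=0.0, **kwargs):
        super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs)


# AutoConfig resolves the model_type; FEATURE_EXTRACTOR_MAPPING maps the config class to the extractor class.
AutoConfig.register("my_audio_model", MyAudioConfig)
AutoFeatureExtractor.register(MyAudioConfig, MyAudioFeatureExtractor)

extractor = MyAudioFeatureExtractor()
extractor.save_pretrained("./my_audio_extractor")
# from_pretrained reads feature_extractor_type from the saved preprocessor_config.json
# and falls through to the class registered above.
reloaded = AutoFeatureExtractor.from_pretrained("./my_audio_extractor")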
venv/lib/python3.13/site-packages/transformers/models/auto/image_processing_auto.py
ADDED
|
@@ -0,0 +1,688 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""AutoImageProcessor class."""

import importlib
import json
import os
import warnings
from collections import OrderedDict
from typing import TYPE_CHECKING, Optional, Union

# Build the list of all image processors
from ...configuration_utils import PretrainedConfig
from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
from ...image_processing_utils import ImageProcessingMixin
from ...image_processing_utils_fast import BaseImageProcessorFast
from ...utils import (
    CONFIG_NAME,
    IMAGE_PROCESSOR_NAME,
    cached_file,
    is_timm_config_dict,
    is_timm_local_checkpoint,
    is_torchvision_available,
    is_vision_available,
    logging,
)
from ...utils.import_utils import requires
from .auto_factory import _LazyAutoMapping
from .configuration_auto import (
    CONFIG_MAPPING_NAMES,
    AutoConfig,
    model_type_to_module_name,
    replace_list_option_in_docstrings,
)


logger = logging.get_logger(__name__)


FORCE_FAST_IMAGE_PROCESSOR = ["Qwen2VLImageProcessor"]


if TYPE_CHECKING:
    # This significantly improves completion suggestion performance when
    # the transformers package is used with Microsoft's Pylance language server.
    IMAGE_PROCESSOR_MAPPING_NAMES: OrderedDict[str, tuple[Optional[str], Optional[str]]] = OrderedDict()
else:
    IMAGE_PROCESSOR_MAPPING_NAMES = OrderedDict(
        [
            ("aimv2", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
            ("aimv2_vision_model", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
            ("align", ("EfficientNetImageProcessor", "EfficientNetImageProcessorFast")),
            ("aria", ("AriaImageProcessor", None)),
            ("beit", ("BeitImageProcessor", "BeitImageProcessorFast")),
            ("bit", ("BitImageProcessor", "BitImageProcessorFast")),
            ("blip", ("BlipImageProcessor", "BlipImageProcessorFast")),
            ("blip-2", ("BlipImageProcessor", "BlipImageProcessorFast")),
            ("bridgetower", ("BridgeTowerImageProcessor", "BridgeTowerImageProcessorFast")),
            ("chameleon", ("ChameleonImageProcessor", "ChameleonImageProcessorFast")),
            ("chinese_clip", ("ChineseCLIPImageProcessor", "ChineseCLIPImageProcessorFast")),
            ("clip", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
            ("clipseg", ("ViTImageProcessor", "ViTImageProcessorFast")),
            ("cohere2_vision", (None, "Cohere2VisionImageProcessorFast")),
            ("conditional_detr", ("ConditionalDetrImageProcessor", "ConditionalDetrImageProcessorFast")),
            ("convnext", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")),
            ("convnextv2", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")),
            ("cvt", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")),
            ("data2vec-vision", ("BeitImageProcessor", "BeitImageProcessorFast")),
            ("deepseek_vl", ("DeepseekVLImageProcessor", "DeepseekVLImageProcessorFast")),
            ("deepseek_vl_hybrid", ("DeepseekVLHybridImageProcessor", "DeepseekVLHybridImageProcessorFast")),
            ("deformable_detr", ("DeformableDetrImageProcessor", "DeformableDetrImageProcessorFast")),
            ("deit", ("DeiTImageProcessor", "DeiTImageProcessorFast")),
            ("depth_anything", ("DPTImageProcessor", "DPTImageProcessorFast")),
            ("depth_pro", ("DepthProImageProcessor", "DepthProImageProcessorFast")),
            ("deta", ("DetaImageProcessor", None)),
            ("detr", ("DetrImageProcessor", "DetrImageProcessorFast")),
            ("dinat", ("ViTImageProcessor", "ViTImageProcessorFast")),
            ("dinov2", ("BitImageProcessor", "BitImageProcessorFast")),
            ("dinov3_vit", (None, "DINOv3ViTImageProcessorFast")),
            ("donut-swin", ("DonutImageProcessor", "DonutImageProcessorFast")),
            ("dpt", ("DPTImageProcessor", "DPTImageProcessorFast")),
            ("edgetam", (None, "Sam2ImageProcessorFast")),
            ("efficientformer", ("EfficientFormerImageProcessor", None)),
            ("efficientloftr", ("EfficientLoFTRImageProcessor", "EfficientLoFTRImageProcessorFast")),
            ("efficientnet", ("EfficientNetImageProcessor", "EfficientNetImageProcessorFast")),
            ("eomt", ("EomtImageProcessor", "EomtImageProcessorFast")),
            ("flava", ("FlavaImageProcessor", "FlavaImageProcessorFast")),
            ("focalnet", ("BitImageProcessor", "BitImageProcessorFast")),
            ("fuyu", ("FuyuImageProcessor", None)),
            ("gemma3", ("Gemma3ImageProcessor", "Gemma3ImageProcessorFast")),
            ("gemma3n", ("SiglipImageProcessor", "SiglipImageProcessorFast")),
            ("git", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
            ("glm4v", ("Glm4vImageProcessor", "Glm4vImageProcessorFast")),
            ("glpn", ("GLPNImageProcessor", None)),
            ("got_ocr2", ("GotOcr2ImageProcessor", "GotOcr2ImageProcessorFast")),
            ("grounding-dino", ("GroundingDinoImageProcessor", "GroundingDinoImageProcessorFast")),
            ("groupvit", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
            ("hiera", ("BitImageProcessor", "BitImageProcessorFast")),
            ("idefics", ("IdeficsImageProcessor", None)),
            ("idefics2", ("Idefics2ImageProcessor", "Idefics2ImageProcessorFast")),
            ("idefics3", ("Idefics3ImageProcessor", "Idefics3ImageProcessorFast")),
            ("ijepa", ("ViTImageProcessor", "ViTImageProcessorFast")),
            ("imagegpt", ("ImageGPTImageProcessor", "ImageGPTImageProcessorFast")),
            ("instructblip", ("BlipImageProcessor", "BlipImageProcessorFast")),
            ("instructblipvideo", ("InstructBlipVideoImageProcessor", None)),
            ("janus", ("JanusImageProcessor", "JanusImageProcessorFast")),
            ("kosmos-2", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
            ("kosmos-2.5", ("Kosmos2_5ImageProcessor", "Kosmos2_5ImageProcessorFast")),
            ("layoutlmv2", ("LayoutLMv2ImageProcessor", "LayoutLMv2ImageProcessorFast")),
            ("layoutlmv3", ("LayoutLMv3ImageProcessor", "LayoutLMv3ImageProcessorFast")),
            ("levit", ("LevitImageProcessor", "LevitImageProcessorFast")),
            ("lfm2_vl", (None, "Lfm2VlImageProcessorFast")),
            ("lightglue", ("LightGlueImageProcessor", None)),
            ("llama4", ("Llama4ImageProcessor", "Llama4ImageProcessorFast")),
            ("llava", ("LlavaImageProcessor", "LlavaImageProcessorFast")),
            ("llava_next", ("LlavaNextImageProcessor", "LlavaNextImageProcessorFast")),
            ("llava_next_video", ("LlavaNextVideoImageProcessor", None)),
            ("llava_onevision", ("LlavaOnevisionImageProcessor", "LlavaOnevisionImageProcessorFast")),
            ("mask2former", ("Mask2FormerImageProcessor", "Mask2FormerImageProcessorFast")),
            ("maskformer", ("MaskFormerImageProcessor", "MaskFormerImageProcessorFast")),
            ("metaclip_2", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
            ("mgp-str", ("ViTImageProcessor", "ViTImageProcessorFast")),
            ("mistral3", ("PixtralImageProcessor", "PixtralImageProcessorFast")),
            ("mlcd", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
            ("mllama", ("MllamaImageProcessor", None)),
            ("mm-grounding-dino", ("GroundingDinoImageProcessor", "GroundingDinoImageProcessorFast")),
            ("mobilenet_v1", ("MobileNetV1ImageProcessor", "MobileNetV1ImageProcessorFast")),
            ("mobilenet_v2", ("MobileNetV2ImageProcessor", "MobileNetV2ImageProcessorFast")),
            ("mobilevit", ("MobileViTImageProcessor", "MobileViTImageProcessorFast")),
            ("mobilevitv2", ("MobileViTImageProcessor", "MobileViTImageProcessorFast")),
            ("nat", ("ViTImageProcessor", "ViTImageProcessorFast")),
            ("nougat", ("NougatImageProcessor", "NougatImageProcessorFast")),
            ("oneformer", ("OneFormerImageProcessor", "OneFormerImageProcessorFast")),
            ("ovis2", ("Ovis2ImageProcessor", "Ovis2ImageProcessorFast")),
            ("owlv2", ("Owlv2ImageProcessor", "Owlv2ImageProcessorFast")),
            ("owlvit", ("OwlViTImageProcessor", "OwlViTImageProcessorFast")),
            ("paligemma", ("SiglipImageProcessor", "SiglipImageProcessorFast")),
            ("perceiver", ("PerceiverImageProcessor", "PerceiverImageProcessorFast")),
            ("perception_lm", (None, "PerceptionLMImageProcessorFast")),
            ("phi4_multimodal", (None, "Phi4MultimodalImageProcessorFast")),
            ("pix2struct", ("Pix2StructImageProcessor", None)),
            ("pixtral", ("PixtralImageProcessor", "PixtralImageProcessorFast")),
            ("poolformer", ("PoolFormerImageProcessor", "PoolFormerImageProcessorFast")),
            ("prompt_depth_anything", ("PromptDepthAnythingImageProcessor", "PromptDepthAnythingImageProcessorFast")),
            ("pvt", ("PvtImageProcessor", "PvtImageProcessorFast")),
            ("pvt_v2", ("PvtImageProcessor", "PvtImageProcessorFast")),
            ("qwen2_5_vl", ("Qwen2VLImageProcessor", "Qwen2VLImageProcessorFast")),
            ("qwen2_vl", ("Qwen2VLImageProcessor", "Qwen2VLImageProcessorFast")),
            ("qwen3_vl", ("Qwen2VLImageProcessor", "Qwen2VLImageProcessorFast")),
            ("regnet", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")),
            ("resnet", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")),
            ("rt_detr", ("RTDetrImageProcessor", "RTDetrImageProcessorFast")),
            ("sam", ("SamImageProcessor", "SamImageProcessorFast")),
            ("sam2", (None, "Sam2ImageProcessorFast")),
            ("sam_hq", ("SamImageProcessor", "SamImageProcessorFast")),
            ("segformer", ("SegformerImageProcessor", "SegformerImageProcessorFast")),
            ("seggpt", ("SegGptImageProcessor", None)),
            ("shieldgemma2", ("Gemma3ImageProcessor", "Gemma3ImageProcessorFast")),
            ("siglip", ("SiglipImageProcessor", "SiglipImageProcessorFast")),
            ("siglip2", ("Siglip2ImageProcessor", "Siglip2ImageProcessorFast")),
            ("smolvlm", ("SmolVLMImageProcessor", "SmolVLMImageProcessorFast")),
            ("superglue", ("SuperGlueImageProcessor", None)),
            ("superpoint", ("SuperPointImageProcessor", "SuperPointImageProcessorFast")),
            ("swiftformer", ("ViTImageProcessor", "ViTImageProcessorFast")),
            ("swin", ("ViTImageProcessor", "ViTImageProcessorFast")),
            ("swin2sr", ("Swin2SRImageProcessor", "Swin2SRImageProcessorFast")),
            ("swinv2", ("ViTImageProcessor", "ViTImageProcessorFast")),
            ("table-transformer", ("DetrImageProcessor", "DetrImageProcessorFast")),
            ("textnet", ("TextNetImageProcessor", "TextNetImageProcessorFast")),
            ("timesformer", ("VideoMAEImageProcessor", None)),
            ("timm_wrapper", ("TimmWrapperImageProcessor", None)),
            ("tvlt", ("TvltImageProcessor", None)),
            ("tvp", ("TvpImageProcessor", "TvpImageProcessorFast")),
            ("udop", ("LayoutLMv3ImageProcessor", "LayoutLMv3ImageProcessorFast")),
            ("upernet", ("SegformerImageProcessor", "SegformerImageProcessorFast")),
            ("van", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")),
            ("videomae", ("VideoMAEImageProcessor", None)),
            ("vilt", ("ViltImageProcessor", "ViltImageProcessorFast")),
            ("vipllava", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
            ("vit", ("ViTImageProcessor", "ViTImageProcessorFast")),
            ("vit_hybrid", ("ViTHybridImageProcessor", None)),
            ("vit_mae", ("ViTImageProcessor", "ViTImageProcessorFast")),
            ("vit_msn", ("ViTImageProcessor", "ViTImageProcessorFast")),
            ("vitmatte", ("VitMatteImageProcessor", "VitMatteImageProcessorFast")),
            ("xclip", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
            ("yolos", ("YolosImageProcessor", "YolosImageProcessorFast")),
            ("zoedepth", ("ZoeDepthImageProcessor", "ZoeDepthImageProcessorFast")),
        ]
    )

# Override to None if the packages are not available
for model_type, (slow_class, fast_class) in IMAGE_PROCESSOR_MAPPING_NAMES.items():
    if not is_vision_available():
        slow_class = None
    if not is_torchvision_available():
        fast_class = None

    IMAGE_PROCESSOR_MAPPING_NAMES[model_type] = (slow_class, fast_class)

IMAGE_PROCESSOR_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, IMAGE_PROCESSOR_MAPPING_NAMES)


def get_image_processor_class_from_name(class_name: str):
    if class_name == "BaseImageProcessorFast":
        return BaseImageProcessorFast

    for module_name, extractors in IMAGE_PROCESSOR_MAPPING_NAMES.items():
        if class_name in extractors:
            module_name = model_type_to_module_name(module_name)

            module = importlib.import_module(f".{module_name}", "transformers.models")
            try:
                return getattr(module, class_name)
            except AttributeError:
                continue

    for extractors in IMAGE_PROCESSOR_MAPPING._extra_content.values():
        for extractor in extractors:
            if getattr(extractor, "__name__", None) == class_name:
                return extractor

    # We did not find the class, but maybe it's because a dep is missing. In that case, the class will be in the main
    # init and we return the proper dummy to get an appropriate error message.
    main_module = importlib.import_module("transformers")
    if hasattr(main_module, class_name):
        return getattr(main_module, class_name)

    return None


def get_image_processor_config(
    pretrained_model_name_or_path: Union[str, os.PathLike],
    cache_dir: Optional[Union[str, os.PathLike]] = None,
    force_download: bool = False,
    resume_download: Optional[bool] = None,
    proxies: Optional[dict[str, str]] = None,
    token: Optional[Union[bool, str]] = None,
    revision: Optional[str] = None,
    local_files_only: bool = False,
    **kwargs,
):
    """
    Loads the image processor configuration from a pretrained model image processor configuration.

    Args:
        pretrained_model_name_or_path (`str` or `os.PathLike`):
            This can be either:

            - a string, the *model id* of a pretrained model configuration hosted inside a model repo on
              huggingface.co.
            - a path to a *directory* containing a configuration file saved using the
              [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.

        cache_dir (`str` or `os.PathLike`, *optional*):
            Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
            cache should not be used.
        force_download (`bool`, *optional*, defaults to `False`):
            Whether or not to force to (re-)download the configuration files and override the cached versions if they
            exist.
        resume_download:
            Deprecated and ignored. All downloads are now resumed by default when possible.
            Will be removed in v5 of Transformers.
        proxies (`dict[str, str]`, *optional*):
            A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
            'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
        token (`str` or *bool*, *optional*):
            The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
            when running `hf auth login` (stored in `~/.huggingface`).
        revision (`str`, *optional*, defaults to `"main"`):
            The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
            git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
            identifier allowed by git.
        local_files_only (`bool`, *optional*, defaults to `False`):
            If `True`, will only try to load the image processor configuration from local files.

    <Tip>

    Passing `token=True` is required when you want to use a private model.

    </Tip>

    Returns:
        `Dict`: The configuration of the image processor.

    Examples:

    ```python
    # Download configuration from huggingface.co and cache.
    image_processor_config = get_image_processor_config("google-bert/bert-base-uncased")
    # This model does not have a image processor config so the result will be an empty dict.
    image_processor_config = get_image_processor_config("FacebookAI/xlm-roberta-base")

    # Save a pretrained image processor locally and you can reload its config
    from transformers import AutoTokenizer

    image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
    image_processor.save_pretrained("image-processor-test")
    image_processor_config = get_image_processor_config("image-processor-test")
    ```"""
    use_auth_token = kwargs.pop("use_auth_token", None)
    if use_auth_token is not None:
        warnings.warn(
            "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
            FutureWarning,
        )
        if token is not None:
            raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
        token = use_auth_token

    resolved_config_file = cached_file(
        pretrained_model_name_or_path,
        IMAGE_PROCESSOR_NAME,
        cache_dir=cache_dir,
        force_download=force_download,
        resume_download=resume_download,
        proxies=proxies,
        token=token,
        revision=revision,
        local_files_only=local_files_only,
        _raise_exceptions_for_gated_repo=False,
        _raise_exceptions_for_missing_entries=False,
        _raise_exceptions_for_connection_errors=False,
    )
    if resolved_config_file is None:
        logger.info(
            "Could not locate the image processor configuration file, will try to use the model config instead."
        )
        return {}

    with open(resolved_config_file, encoding="utf-8") as reader:
        return json.load(reader)
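A quick sketch of how the two lookup helpers above behave. This is illustrative and not part of the file; the second call downloads a config from the Hub, so it assumes network access and the vision extras are installed.

# Illustrative sketch of the helpers defined above.
from transformers.models.auto.image_processing_auto import (
    get_image_processor_class_from_name,
    get_image_processor_config,
)

# Resolves a class name via IMAGE_PROCESSOR_MAPPING_NAMES, importing the model module lazily;
# returns None for names it does not know about.
cls = get_image_processor_class_from_name("ViTImageProcessor")
print(cls)

# Returns the raw preprocessor_config.json contents as a dict, or {} when the checkpoint
# ships no image processor config.
cfg = get_image_processor_config("google/vit-base-patch16-224-in21k")
print(sorted(cfg)[:3])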


def _warning_fast_image_processor_available(fast_class):
    logger.warning(
        f"Fast image processor class {fast_class} is available for this model. "
        "Using slow image processor class. To use the fast image processor class set `use_fast=True`."
    )


@requires(backends=("vision",))
class AutoImageProcessor:
    r"""
    This is a generic image processor class that will be instantiated as one of the image processor classes of the
    library when created with the [`AutoImageProcessor.from_pretrained`] class method.

    This class cannot be instantiated directly using `__init__()` (throws an error).
    """

    def __init__(self):
        raise OSError(
            "AutoImageProcessor is designed to be instantiated "
            "using the `AutoImageProcessor.from_pretrained(pretrained_model_name_or_path)` method."
        )

    @classmethod
    @replace_list_option_in_docstrings(IMAGE_PROCESSOR_MAPPING_NAMES)
    def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
        r"""
        Instantiate one of the image processor classes of the library from a pretrained model vocabulary.

        The image processor class to instantiate is selected based on the `model_type` property of the config object
        (either passed as an argument or loaded from `pretrained_model_name_or_path` if possible), or when it's
        missing, by falling back to using pattern matching on `pretrained_model_name_or_path`:

        List options

        Params:
            pretrained_model_name_or_path (`str` or `os.PathLike`):
                This can be either:

                - a string, the *model id* of a pretrained image_processor hosted inside a model repo on
                  huggingface.co.
                - a path to a *directory* containing a image processor file saved using the
                  [`~image_processing_utils.ImageProcessingMixin.save_pretrained`] method, e.g.,
                  `./my_model_directory/`.
                - a path or url to a saved image processor JSON *file*, e.g.,
                  `./my_model_directory/preprocessor_config.json`.
            cache_dir (`str` or `os.PathLike`, *optional*):
                Path to a directory in which a downloaded pretrained model image processor should be cached if the
                standard cache should not be used.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force to (re-)download the image processor files and override the cached versions if
                they exist.
            resume_download:
                Deprecated and ignored. All downloads are now resumed by default when possible.
                Will be removed in v5 of Transformers.
            proxies (`dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
            token (`str` or *bool*, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
                when running `hf auth login` (stored in `~/.huggingface`).
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
                git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
                identifier allowed by git.
            use_fast (`bool`, *optional*, defaults to `False`):
                Use a fast torchvision-base image processor if it is supported for a given model.
                If a fast image processor is not available for a given model, a normal numpy-based image processor
                is returned instead.
            return_unused_kwargs (`bool`, *optional*, defaults to `False`):
                If `False`, then this function returns just the final image processor object. If `True`, then this
                functions returns a `Tuple(image_processor, unused_kwargs)` where *unused_kwargs* is a dictionary
                consisting of the key/value pairs whose keys are not image processor attributes: i.e., the part of
                `kwargs` which has not been used to update `image_processor` and is otherwise ignored.
            trust_remote_code (`bool`, *optional*, defaults to `False`):
                Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
                should only be set to `True` for repositories you trust and in which you have read the code, as it will
                execute code present on the Hub on your local machine.
            image_processor_filename (`str`, *optional*, defaults to `"config.json"`):
                The name of the file in the model directory to use for the image processor config.
            kwargs (`dict[str, Any]`, *optional*):
                The values in kwargs of any keys which are image processor attributes will be used to override the
                loaded values. Behavior concerning key/value pairs whose keys are *not* image processor attributes is
                controlled by the `return_unused_kwargs` keyword parameter.

        <Tip>

        Passing `token=True` is required when you want to use a private model.

        </Tip>

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor

        >>> # Download image processor from huggingface.co and cache.
        >>> image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")

        >>> # If image processor files are in a directory (e.g. image processor was saved using *save_pretrained('./test/saved_model/')*)
        >>> # image_processor = AutoImageProcessor.from_pretrained("./test/saved_model/")
        ```"""
        use_auth_token = kwargs.pop("use_auth_token", None)
        if use_auth_token is not None:
            warnings.warn(
                "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
                FutureWarning,
            )
            if kwargs.get("token") is not None:
                raise ValueError(
                    "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
                )
            kwargs["token"] = use_auth_token

        config = kwargs.pop("config", None)
        # TODO: @yoni, change in v4.48 (use_fast set to True by default)
        use_fast = kwargs.pop("use_fast", None)
        trust_remote_code = kwargs.pop("trust_remote_code", None)
        kwargs["_from_auto"] = True

        # Resolve the image processor config filename
        if "image_processor_filename" in kwargs:
            image_processor_filename = kwargs.pop("image_processor_filename")
        elif is_timm_local_checkpoint(pretrained_model_name_or_path):
            image_processor_filename = CONFIG_NAME
        else:
            image_processor_filename = IMAGE_PROCESSOR_NAME

        # Load the image processor config
        try:
            # Main path for all transformers models and local TimmWrapper checkpoints
            config_dict, _ = ImageProcessingMixin.get_image_processor_dict(
                pretrained_model_name_or_path, image_processor_filename=image_processor_filename, **kwargs
            )
        except Exception as initial_exception:
            # Fallback path for Hub TimmWrapper checkpoints. Timm models' image processing is saved in `config.json`
            # instead of `preprocessor_config.json`. Because this is an Auto class and we don't have any information
            # except the model name, the only way to check if a remote checkpoint is a timm model is to try to
            # load `config.json` and if it fails with some error, we raise the initial exception.
            try:
                config_dict, _ = ImageProcessingMixin.get_image_processor_dict(
                    pretrained_model_name_or_path, image_processor_filename=CONFIG_NAME, **kwargs
                )
            except Exception:
                raise initial_exception

            # In case we have a config_dict, but it's not a timm config dict, we raise the initial exception,
            # because only timm models have image processing in `config.json`.
            if not is_timm_config_dict(config_dict):
                raise initial_exception

        image_processor_type = config_dict.get("image_processor_type", None)
        image_processor_auto_map = None
        if "AutoImageProcessor" in config_dict.get("auto_map", {}):
            image_processor_auto_map = config_dict["auto_map"]["AutoImageProcessor"]

        # If we still don't have the image processor class, check if we're loading from a previous feature extractor config
        # and if so, infer the image processor class from there.
        if image_processor_type is None and image_processor_auto_map is None:
            feature_extractor_class = config_dict.pop("feature_extractor_type", None)
            if feature_extractor_class is not None:
                image_processor_type = feature_extractor_class.replace("FeatureExtractor", "ImageProcessor")
            if "AutoFeatureExtractor" in config_dict.get("auto_map", {}):
                feature_extractor_auto_map = config_dict["auto_map"]["AutoFeatureExtractor"]
                image_processor_auto_map = feature_extractor_auto_map.replace("FeatureExtractor", "ImageProcessor")

        # If we don't find the image processor class in the image processor config, let's try the model config.
        if image_processor_type is None and image_processor_auto_map is None:
            if not isinstance(config, PretrainedConfig):
                config = AutoConfig.from_pretrained(
                    pretrained_model_name_or_path,
                    trust_remote_code=trust_remote_code,
                    **kwargs,
                )
            # It could be in `config.image_processor_type`
            image_processor_type = getattr(config, "image_processor_type", None)
            if hasattr(config, "auto_map") and "AutoImageProcessor" in config.auto_map:
                image_processor_auto_map = config.auto_map["AutoImageProcessor"]

        image_processor_class = None
        # TODO: @yoni, change logic in v4.52 (when use_fast set to True by default)
        if image_processor_type is not None:
            # if use_fast is not set and the processor was saved with a fast processor, we use it, otherwise we use the slow processor.
            if use_fast is None:
                use_fast = image_processor_type.endswith("Fast")
                if not use_fast and image_processor_type in FORCE_FAST_IMAGE_PROCESSOR and is_torchvision_available():
                    use_fast = True
                    logger.warning_once(
                        f"The image processor of type `{image_processor_type}` is now loaded as a fast processor by default, even if the model checkpoint was saved with a slow processor. "
                        "This is a breaking change and may produce slightly different outputs. To continue using the slow processor, instantiate this class with `use_fast=False`. "
                        "Note that this behavior will be extended to all models in a future release."
                    )
                if not use_fast:
                    logger.warning_once(
                        "Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. "
                        "`use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. "
                        "This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`."
                    )
            if use_fast and not image_processor_type.endswith("Fast"):
                image_processor_type += "Fast"
            if use_fast and not is_torchvision_available():
                # check if there is a slow image processor class to fallback to
                image_processor_class = get_image_processor_class_from_name(image_processor_type[:-4])
                if image_processor_class is None:
                    raise ValueError(
                        f"`{image_processor_type}` requires `torchvision` to be installed. Please install `torchvision` and try again."
                    )
                logger.warning_once(
                    "Using `use_fast=True` but `torchvision` is not available. Falling back to the slow image processor."
                )
                use_fast = False
            if use_fast:
                for image_processors in IMAGE_PROCESSOR_MAPPING_NAMES.values():
                    if image_processor_type in image_processors:
                        break
                else:
                    image_processor_type = image_processor_type[:-4]
                    use_fast = False
                    logger.warning_once(
                        "`use_fast` is set to `True` but the image processor class does not have a fast version. "
                        " Falling back to the slow version."
                    )
                image_processor_class = get_image_processor_class_from_name(image_processor_type)
            else:
                image_processor_type_slow = image_processor_type.removesuffix("Fast")
                image_processor_class = get_image_processor_class_from_name(image_processor_type_slow)
                if image_processor_class is None and image_processor_type.endswith("Fast"):
                    raise ValueError(
                        f"`{image_processor_type}` does not have a slow version. Please set `use_fast=True` when instantiating the processor."
                    )

        has_remote_code = image_processor_auto_map is not None
        has_local_code = image_processor_class is not None or type(config) in IMAGE_PROCESSOR_MAPPING
        if has_remote_code:
            if image_processor_auto_map is not None and not isinstance(image_processor_auto_map, tuple):
                # In some configs, only the slow image processor class is stored
                image_processor_auto_map = (image_processor_auto_map, None)
            if use_fast and image_processor_auto_map[1] is not None:
                class_ref = image_processor_auto_map[1]
            else:
                class_ref = image_processor_auto_map[0]
            if "--" in class_ref:
                upstream_repo = class_ref.split("--")[0]
            else:
                upstream_repo = None
            trust_remote_code = resolve_trust_remote_code(
                trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code, upstream_repo
            )

        if has_remote_code and trust_remote_code:
            if not use_fast and image_processor_auto_map[1] is not None:
                _warning_fast_image_processor_available(image_processor_auto_map[1])

            image_processor_class = get_class_from_dynamic_module(class_ref, pretrained_model_name_or_path, **kwargs)
            _ = kwargs.pop("code_revision", None)
            image_processor_class.register_for_auto_class()
            return image_processor_class.from_dict(config_dict, **kwargs)
        elif image_processor_class is not None:
            return image_processor_class.from_dict(config_dict, **kwargs)
        # Last try: we use the IMAGE_PROCESSOR_MAPPING.
        elif type(config) in IMAGE_PROCESSOR_MAPPING:
            image_processor_tuple = IMAGE_PROCESSOR_MAPPING[type(config)]

            image_processor_class_py, image_processor_class_fast = image_processor_tuple

            if not use_fast and image_processor_class_fast is not None:
                _warning_fast_image_processor_available(image_processor_class_fast)

            if image_processor_class_fast and (use_fast or image_processor_class_py is None):
                return image_processor_class_fast.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
            else:
                if image_processor_class_py is not None:
                    return image_processor_class_py.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
                else:
                    raise ValueError(
                        "This image processor cannot be instantiated. Please make sure you have `Pillow` installed."
                    )
        raise ValueError(
            f"Unrecognized image processor in {pretrained_model_name_or_path}. Should have a "
            f"`image_processor_type` key in its {IMAGE_PROCESSOR_NAME} of {CONFIG_NAME}, or one of the following "
            f"`model_type` keys in its {CONFIG_NAME}: {', '.join(c for c in IMAGE_PROCESSOR_MAPPING_NAMES)}"
        )

    @staticmethod
    def register(
        config_class,
        image_processor_class=None,
        slow_image_processor_class=None,
        fast_image_processor_class=None,
        exist_ok=False,
    ):
        """
        Register a new image processor for this class.

        Args:
            config_class ([`PretrainedConfig`]):
                The configuration corresponding to the model to register.
            image_processor_class ([`ImageProcessingMixin`]): The image processor to register.
        """
        if image_processor_class is not None:
            if slow_image_processor_class is not None:
                raise ValueError("Cannot specify both image_processor_class and slow_image_processor_class")
            warnings.warn(
                "The image_processor_class argument is deprecated and will be removed in v4.42. Please use `slow_image_processor_class`, or `fast_image_processor_class` instead",
                FutureWarning,
            )
            slow_image_processor_class = image_processor_class

        if slow_image_processor_class is None and fast_image_processor_class is None:
            raise ValueError("You need to specify either slow_image_processor_class or fast_image_processor_class")
        if slow_image_processor_class is not None and issubclass(slow_image_processor_class, BaseImageProcessorFast):
            raise ValueError("You passed a fast image processor in as the `slow_image_processor_class`.")
        if fast_image_processor_class is not None and not issubclass(
            fast_image_processor_class, BaseImageProcessorFast
        ):
            raise ValueError("The `fast_image_processor_class` should inherit from `BaseImageProcessorFast`.")

        if (
            slow_image_processor_class is not None
            and fast_image_processor_class is not None
            and issubclass(fast_image_processor_class, BaseImageProcessorFast)
            and fast_image_processor_class.slow_image_processor_class != slow_image_processor_class
        ):
            raise ValueError(
                "The fast processor class you are passing has a `slow_image_processor_class` attribute that is not "
                "consistent with the slow processor class you passed (fast tokenizer has "
                f"{fast_image_processor_class.slow_image_processor_class} and you passed {slow_image_processor_class}. Fix one of those "
                "so they match!"
            )

        # Avoid resetting a set slow/fast image processor if we are passing just the other ones.
        if config_class in IMAGE_PROCESSOR_MAPPING._extra_content:
            existing_slow, existing_fast = IMAGE_PROCESSOR_MAPPING[config_class]
            if slow_image_processor_class is None:
                slow_image_processor_class = existing_slow
            if fast_image_processor_class is None:
                fast_image_processor_class = existing_fast

        IMAGE_PROCESSOR_MAPPING.register(
            config_class, (slow_image_processor_class, fast_image_processor_class), exist_ok=exist_ok
        )


__all__ = ["IMAGE_PROCESSOR_MAPPING", "AutoImageProcessor"]
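As with the feature extractor counterpart, the class above is normally used through `from_pretrained`; the sketch below illustrates the `use_fast` switch and the slow/fast registration path. It is not part of the file; the `MyVisionConfig` / `MyImageProcessor` / `MyImageProcessorFast` names in the commented part are hypothetical.

# Illustrative sketch: choosing between the slow and fast processor variants.
from transformers import AutoImageProcessor

# use_fast=True selects the torchvision-backed *ImageProcessorFast class when one exists
# in IMAGE_PROCESSOR_MAPPING_NAMES; use_fast=False keeps the numpy/PIL-based one.
fast_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k", use_fast=True)
slow_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k", use_fast=False)
print(type(fast_processor).__name__, type(slow_processor).__name__)

# Custom pairs are registered per config class; either side may be omitted and the
# existing entry is preserved (see the register() logic above).
# AutoImageProcessor.register(
#     MyVisionConfig,
#     slow_image_processor_class=MyImageProcessor,
#     fast_image_processor_class=MyImageProcessorFast,
# )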
venv/lib/python3.13/site-packages/transformers/models/auto/modeling_auto.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
venv/lib/python3.13/site-packages/transformers/models/auto/modeling_flax_auto.py
ADDED
|
@@ -0,0 +1,413 @@
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2018 The Google Flax Team Authors and The HuggingFace Inc. team.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Auto Model class."""
|
| 16 |
+
|
| 17 |
+
from collections import OrderedDict
|
| 18 |
+
|
| 19 |
+
from ...utils import logging
|
| 20 |
+
from .auto_factory import _BaseAutoModelClass, _LazyAutoMapping, auto_class_update
|
| 21 |
+
from .configuration_auto import CONFIG_MAPPING_NAMES
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
logger = logging.get_logger(__name__)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
FLAX_MODEL_MAPPING_NAMES = OrderedDict(
|
| 28 |
+
[
|
| 29 |
+
# Base model mapping
|
| 30 |
+
("albert", "FlaxAlbertModel"),
|
| 31 |
+
("bart", "FlaxBartModel"),
|
| 32 |
+
("beit", "FlaxBeitModel"),
|
| 33 |
+
("bert", "FlaxBertModel"),
|
| 34 |
+
("big_bird", "FlaxBigBirdModel"),
|
| 35 |
+
("blenderbot", "FlaxBlenderbotModel"),
|
| 36 |
+
("blenderbot-small", "FlaxBlenderbotSmallModel"),
|
| 37 |
+
("bloom", "FlaxBloomModel"),
|
| 38 |
+
("clip", "FlaxCLIPModel"),
|
| 39 |
+
("dinov2", "FlaxDinov2Model"),
|
| 40 |
+
("distilbert", "FlaxDistilBertModel"),
|
| 41 |
+
("electra", "FlaxElectraModel"),
|
| 42 |
+
("gemma", "FlaxGemmaModel"),
|
| 43 |
+
("gpt-sw3", "FlaxGPT2Model"),
|
| 44 |
+
("gpt2", "FlaxGPT2Model"),
|
| 45 |
+
("gpt_neo", "FlaxGPTNeoModel"),
|
| 46 |
+
("gptj", "FlaxGPTJModel"),
|
| 47 |
+
("llama", "FlaxLlamaModel"),
|
| 48 |
+
("longt5", "FlaxLongT5Model"),
|
| 49 |
+
("marian", "FlaxMarianModel"),
|
| 50 |
+
("mbart", "FlaxMBartModel"),
|
| 51 |
+
("mistral", "FlaxMistralModel"),
|
| 52 |
+
("mt5", "FlaxMT5Model"),
|
| 53 |
+
("opt", "FlaxOPTModel"),
|
| 54 |
+
("pegasus", "FlaxPegasusModel"),
|
| 55 |
+
("regnet", "FlaxRegNetModel"),
|
| 56 |
+
("resnet", "FlaxResNetModel"),
|
| 57 |
+
("roberta", "FlaxRobertaModel"),
|
| 58 |
+
("roberta-prelayernorm", "FlaxRobertaPreLayerNormModel"),
|
| 59 |
+
("roformer", "FlaxRoFormerModel"),
|
| 60 |
+
("t5", "FlaxT5Model"),
|
| 61 |
+
("vision-text-dual-encoder", "FlaxVisionTextDualEncoderModel"),
|
| 62 |
+
("vit", "FlaxViTModel"),
|
| 63 |
+
("wav2vec2", "FlaxWav2Vec2Model"),
|
| 64 |
+
("whisper", "FlaxWhisperModel"),
|
| 65 |
+
("xglm", "FlaxXGLMModel"),
|
| 66 |
+
("xlm-roberta", "FlaxXLMRobertaModel"),
|
| 67 |
+
]
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
+
FLAX_MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
|
| 71 |
+
[
|
| 72 |
+
# Model for pre-training mapping
|
| 73 |
+
("albert", "FlaxAlbertForPreTraining"),
|
| 74 |
+
("bart", "FlaxBartForConditionalGeneration"),
|
| 75 |
+
("bert", "FlaxBertForPreTraining"),
|
| 76 |
+
("big_bird", "FlaxBigBirdForPreTraining"),
|
| 77 |
+
("electra", "FlaxElectraForPreTraining"),
|
| 78 |
+
("longt5", "FlaxLongT5ForConditionalGeneration"),
|
| 79 |
+
("mbart", "FlaxMBartForConditionalGeneration"),
|
| 80 |
+
("mt5", "FlaxMT5ForConditionalGeneration"),
|
| 81 |
+
("roberta", "FlaxRobertaForMaskedLM"),
|
| 82 |
+
("roberta-prelayernorm", "FlaxRobertaPreLayerNormForMaskedLM"),
|
| 83 |
+
("roformer", "FlaxRoFormerForMaskedLM"),
|
| 84 |
+
("t5", "FlaxT5ForConditionalGeneration"),
|
| 85 |
+
("wav2vec2", "FlaxWav2Vec2ForPreTraining"),
|
| 86 |
+
("whisper", "FlaxWhisperForConditionalGeneration"),
|
| 87 |
+
("xlm-roberta", "FlaxXLMRobertaForMaskedLM"),
|
| 88 |
+
]
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
FLAX_MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
|
| 92 |
+
[
|
| 93 |
+
# Model for Masked LM mapping
|
| 94 |
+
("albert", "FlaxAlbertForMaskedLM"),
|
| 95 |
+
("bart", "FlaxBartForConditionalGeneration"),
|
| 96 |
+
("bert", "FlaxBertForMaskedLM"),
|
| 97 |
+
("big_bird", "FlaxBigBirdForMaskedLM"),
|
| 98 |
+
("distilbert", "FlaxDistilBertForMaskedLM"),
|
| 99 |
+
("electra", "FlaxElectraForMaskedLM"),
|
| 100 |
+
("mbart", "FlaxMBartForConditionalGeneration"),
|
| 101 |
+
("roberta", "FlaxRobertaForMaskedLM"),
|
| 102 |
+
("roberta-prelayernorm", "FlaxRobertaPreLayerNormForMaskedLM"),
|
| 103 |
+
("roformer", "FlaxRoFormerForMaskedLM"),
|
| 104 |
+
("xlm-roberta", "FlaxXLMRobertaForMaskedLM"),
|
| 105 |
+
]
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
|
| 109 |
+
[
|
| 110 |
+
# Model for Seq2Seq Causal LM mapping
|
| 111 |
+
("bart", "FlaxBartForConditionalGeneration"),
|
| 112 |
+
("blenderbot", "FlaxBlenderbotForConditionalGeneration"),
|
| 113 |
+
("blenderbot-small", "FlaxBlenderbotSmallForConditionalGeneration"),
|
| 114 |
+
("encoder-decoder", "FlaxEncoderDecoderModel"),
|
| 115 |
+
("longt5", "FlaxLongT5ForConditionalGeneration"),
|
| 116 |
+
("marian", "FlaxMarianMTModel"),
|
| 117 |
+
("mbart", "FlaxMBartForConditionalGeneration"),
|
| 118 |
+
("mt5", "FlaxMT5ForConditionalGeneration"),
|
| 119 |
+
("pegasus", "FlaxPegasusForConditionalGeneration"),
|
| 120 |
+
("t5", "FlaxT5ForConditionalGeneration"),
|
| 121 |
+
]
|
| 122 |
+
)
|
| 123 |
+
|
| 124 |
+
FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
|
| 125 |
+
[
|
| 126 |
+
# Model for Image-classification
|
| 127 |
+
("beit", "FlaxBeitForImageClassification"),
|
| 128 |
+
("dinov2", "FlaxDinov2ForImageClassification"),
|
| 129 |
+
("regnet", "FlaxRegNetForImageClassification"),
|
| 130 |
+
("resnet", "FlaxResNetForImageClassification"),
|
| 131 |
+
("vit", "FlaxViTForImageClassification"),
|
| 132 |
+
]
|
| 133 |
+
)
|
| 134 |
+
|
| 135 |
+
FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict(
|
| 136 |
+
[
|
| 137 |
+
("vision-encoder-decoder", "FlaxVisionEncoderDecoderModel"),
|
| 138 |
+
]
|
| 139 |
+
)
|
| 140 |
+
|
| 141 |
+
FLAX_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
|
| 142 |
+
[
|
| 143 |
+
# Model for Causal LM mapping
|
| 144 |
+
("bart", "FlaxBartForCausalLM"),
|
| 145 |
+
("bert", "FlaxBertForCausalLM"),
|
| 146 |
+
("big_bird", "FlaxBigBirdForCausalLM"),
|
| 147 |
+
("bloom", "FlaxBloomForCausalLM"),
|
| 148 |
+
("electra", "FlaxElectraForCausalLM"),
|
| 149 |
+
("gemma", "FlaxGemmaForCausalLM"),
|
| 150 |
+
("gpt-sw3", "FlaxGPT2LMHeadModel"),
|
| 151 |
+
("gpt2", "FlaxGPT2LMHeadModel"),
|
| 152 |
+
("gpt_neo", "FlaxGPTNeoForCausalLM"),
|
| 153 |
+
("gptj", "FlaxGPTJForCausalLM"),
|
| 154 |
+
("llama", "FlaxLlamaForCausalLM"),
|
| 155 |
+
("mistral", "FlaxMistralForCausalLM"),
|
| 156 |
+
("opt", "FlaxOPTForCausalLM"),
|
| 157 |
+
("roberta", "FlaxRobertaForCausalLM"),
|
| 158 |
+
("roberta-prelayernorm", "FlaxRobertaPreLayerNormForCausalLM"),
|
| 159 |
+
("xglm", "FlaxXGLMForCausalLM"),
|
| 160 |
+
("xlm-roberta", "FlaxXLMRobertaForCausalLM"),
|
| 161 |
+
]
|
| 162 |
+
)
|
| 163 |
+
|
| 164 |
+
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
|
| 165 |
+
[
|
| 166 |
+
# Model for Sequence Classification mapping
|
| 167 |
+
("albert", "FlaxAlbertForSequenceClassification"),
|
| 168 |
+
("bart", "FlaxBartForSequenceClassification"),
|
| 169 |
+
("bert", "FlaxBertForSequenceClassification"),
|
| 170 |
+
("big_bird", "FlaxBigBirdForSequenceClassification"),
|
| 171 |
+
("distilbert", "FlaxDistilBertForSequenceClassification"),
|
| 172 |
+
("electra", "FlaxElectraForSequenceClassification"),
|
| 173 |
+
("mbart", "FlaxMBartForSequenceClassification"),
|
| 174 |
+
("roberta", "FlaxRobertaForSequenceClassification"),
|
| 175 |
+
("roberta-prelayernorm", "FlaxRobertaPreLayerNormForSequenceClassification"),
|
| 176 |
+
("roformer", "FlaxRoFormerForSequenceClassification"),
|
| 177 |
+
("xlm-roberta", "FlaxXLMRobertaForSequenceClassification"),
|
| 178 |
+
]
|
| 179 |
+
)
|
| 180 |
+
|
| 181 |
+
FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
|
| 182 |
+
[
|
| 183 |
+
# Model for Question Answering mapping
|
| 184 |
+
("albert", "FlaxAlbertForQuestionAnswering"),
|
| 185 |
+
("bart", "FlaxBartForQuestionAnswering"),
|
| 186 |
+
("bert", "FlaxBertForQuestionAnswering"),
|
| 187 |
+
("big_bird", "FlaxBigBirdForQuestionAnswering"),
|
| 188 |
+
("distilbert", "FlaxDistilBertForQuestionAnswering"),
|
| 189 |
+
("electra", "FlaxElectraForQuestionAnswering"),
|
| 190 |
+
("mbart", "FlaxMBartForQuestionAnswering"),
|
| 191 |
+
("roberta", "FlaxRobertaForQuestionAnswering"),
|
| 192 |
+
("roberta-prelayernorm", "FlaxRobertaPreLayerNormForQuestionAnswering"),
|
| 193 |
+
("roformer", "FlaxRoFormerForQuestionAnswering"),
|
| 194 |
+
("xlm-roberta", "FlaxXLMRobertaForQuestionAnswering"),
|
| 195 |
+
]
|
| 196 |
+
)
|
| 197 |
+
|
| 198 |
+
FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
|
| 199 |
+
[
|
| 200 |
+
# Model for Token Classification mapping
|
| 201 |
+
("albert", "FlaxAlbertForTokenClassification"),
|
| 202 |
+
("bert", "FlaxBertForTokenClassification"),
|
| 203 |
+
("big_bird", "FlaxBigBirdForTokenClassification"),
|
| 204 |
+
("distilbert", "FlaxDistilBertForTokenClassification"),
|
| 205 |
+
("electra", "FlaxElectraForTokenClassification"),
|
| 206 |
+
("roberta", "FlaxRobertaForTokenClassification"),
|
| 207 |
+
("roberta-prelayernorm", "FlaxRobertaPreLayerNormForTokenClassification"),
|
| 208 |
+
("roformer", "FlaxRoFormerForTokenClassification"),
|
| 209 |
+
("xlm-roberta", "FlaxXLMRobertaForTokenClassification"),
|
| 210 |
+
]
|
| 211 |
+
)
|
| 212 |
+
|
| 213 |
+
FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict(
|
| 214 |
+
[
|
| 215 |
+
# Model for Multiple Choice mapping
|
| 216 |
+
("albert", "FlaxAlbertForMultipleChoice"),
|
| 217 |
+
("bert", "FlaxBertForMultipleChoice"),
|
| 218 |
+
("big_bird", "FlaxBigBirdForMultipleChoice"),
|
| 219 |
+
("distilbert", "FlaxDistilBertForMultipleChoice"),
|
| 220 |
+
("electra", "FlaxElectraForMultipleChoice"),
|
| 221 |
+
("roberta", "FlaxRobertaForMultipleChoice"),
|
| 222 |
+
("roberta-prelayernorm", "FlaxRobertaPreLayerNormForMultipleChoice"),
|
| 223 |
+
("roformer", "FlaxRoFormerForMultipleChoice"),
|
| 224 |
+
("xlm-roberta", "FlaxXLMRobertaForMultipleChoice"),
|
| 225 |
+
]
|
| 226 |
+
)
|
| 227 |
+
|
| 228 |
+
FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES = OrderedDict(
|
| 229 |
+
[
|
| 230 |
+
("bert", "FlaxBertForNextSentencePrediction"),
|
| 231 |
+
]
|
| 232 |
+
)
|
| 233 |
+
|
| 234 |
+
FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict(
|
| 235 |
+
[
|
| 236 |
+
("speech-encoder-decoder", "FlaxSpeechEncoderDecoderModel"),
|
| 237 |
+
("whisper", "FlaxWhisperForConditionalGeneration"),
|
| 238 |
+
]
|
| 239 |
+
)
|
| 240 |
+
|
| 241 |
+
FLAX_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
|
| 242 |
+
[
|
| 243 |
+
("whisper", "FlaxWhisperForAudioClassification"),
|
| 244 |
+
]
|
| 245 |
+
)
|
| 246 |
+
|
| 247 |
+
FLAX_MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_MAPPING_NAMES)
|
| 248 |
+
FLAX_MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_PRETRAINING_MAPPING_NAMES)
|
| 249 |
+
FLAX_MODEL_FOR_MASKED_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_MASKED_LM_MAPPING_NAMES)
|
| 250 |
+
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _LazyAutoMapping(
|
| 251 |
+
CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
|
| 252 |
+
)
|
| 253 |
+
FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
|
| 254 |
+
CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
|
| 255 |
+
)
|
| 256 |
+
FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES)
|
| 257 |
+
FLAX_MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES)
|
| 258 |
+
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
|
| 259 |
+
CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
|
| 260 |
+
)
|
| 261 |
+
FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
|
| 262 |
+
CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
|
| 263 |
+
)
|
| 264 |
+
FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = _LazyAutoMapping(
|
| 265 |
+
CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES
|
| 266 |
+
)
|
| 267 |
+
FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING = _LazyAutoMapping(
|
| 268 |
+
CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES
|
| 269 |
+
)
|
| 270 |
+
FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = _LazyAutoMapping(
|
| 271 |
+
CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES
|
| 272 |
+
)
|
| 273 |
+
FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING = _LazyAutoMapping(
|
| 274 |
+
CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES
|
| 275 |
+
)
|
| 276 |
+
FLAX_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING = _LazyAutoMapping(
|
| 277 |
+
CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES
|
| 278 |
+
)
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
class FlaxAutoModel(_BaseAutoModelClass):
|
| 282 |
+
_model_mapping = FLAX_MODEL_MAPPING
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
FlaxAutoModel = auto_class_update(FlaxAutoModel)
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
class FlaxAutoModelForPreTraining(_BaseAutoModelClass):
|
| 289 |
+
_model_mapping = FLAX_MODEL_FOR_PRETRAINING_MAPPING
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
FlaxAutoModelForPreTraining = auto_class_update(FlaxAutoModelForPreTraining, head_doc="pretraining")
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
class FlaxAutoModelForCausalLM(_BaseAutoModelClass):
|
| 296 |
+
_model_mapping = FLAX_MODEL_FOR_CAUSAL_LM_MAPPING
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
FlaxAutoModelForCausalLM = auto_class_update(FlaxAutoModelForCausalLM, head_doc="causal language modeling")
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
class FlaxAutoModelForMaskedLM(_BaseAutoModelClass):
|
| 303 |
+
_model_mapping = FLAX_MODEL_FOR_MASKED_LM_MAPPING
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
FlaxAutoModelForMaskedLM = auto_class_update(FlaxAutoModelForMaskedLM, head_doc="masked language modeling")
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
class FlaxAutoModelForSeq2SeqLM(_BaseAutoModelClass):
|
| 310 |
+
_model_mapping = FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
FlaxAutoModelForSeq2SeqLM = auto_class_update(
|
| 314 |
+
FlaxAutoModelForSeq2SeqLM,
|
| 315 |
+
head_doc="sequence-to-sequence language modeling",
|
| 316 |
+
checkpoint_for_example="google-t5/t5-base",
|
| 317 |
+
)
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
class FlaxAutoModelForSequenceClassification(_BaseAutoModelClass):
|
| 321 |
+
_model_mapping = FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
FlaxAutoModelForSequenceClassification = auto_class_update(
|
| 325 |
+
FlaxAutoModelForSequenceClassification, head_doc="sequence classification"
|
| 326 |
+
)
|
| 327 |
+
|
| 328 |
+
|
| 329 |
+
class FlaxAutoModelForQuestionAnswering(_BaseAutoModelClass):
|
| 330 |
+
_model_mapping = FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
FlaxAutoModelForQuestionAnswering = auto_class_update(FlaxAutoModelForQuestionAnswering, head_doc="question answering")
|
| 334 |
+
|
| 335 |
+
|
| 336 |
+
class FlaxAutoModelForTokenClassification(_BaseAutoModelClass):
|
| 337 |
+
_model_mapping = FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
FlaxAutoModelForTokenClassification = auto_class_update(
|
| 341 |
+
FlaxAutoModelForTokenClassification, head_doc="token classification"
|
| 342 |
+
)
|
| 343 |
+
|
| 344 |
+
|
| 345 |
+
class FlaxAutoModelForMultipleChoice(_BaseAutoModelClass):
|
| 346 |
+
_model_mapping = FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING
|
| 347 |
+
|
| 348 |
+
|
| 349 |
+
FlaxAutoModelForMultipleChoice = auto_class_update(FlaxAutoModelForMultipleChoice, head_doc="multiple choice")
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
class FlaxAutoModelForNextSentencePrediction(_BaseAutoModelClass):
|
| 353 |
+
_model_mapping = FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING
|
| 354 |
+
|
| 355 |
+
|
| 356 |
+
FlaxAutoModelForNextSentencePrediction = auto_class_update(
|
| 357 |
+
FlaxAutoModelForNextSentencePrediction, head_doc="next sentence prediction"
|
| 358 |
+
)
|
| 359 |
+
|
| 360 |
+
|
| 361 |
+
class FlaxAutoModelForImageClassification(_BaseAutoModelClass):
|
| 362 |
+
_model_mapping = FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
|
| 363 |
+
|
| 364 |
+
|
| 365 |
+
FlaxAutoModelForImageClassification = auto_class_update(
|
| 366 |
+
FlaxAutoModelForImageClassification, head_doc="image classification"
|
| 367 |
+
)
|
| 368 |
+
|
| 369 |
+
|
| 370 |
+
class FlaxAutoModelForVision2Seq(_BaseAutoModelClass):
|
| 371 |
+
_model_mapping = FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING
|
| 372 |
+
|
| 373 |
+
|
| 374 |
+
FlaxAutoModelForVision2Seq = auto_class_update(FlaxAutoModelForVision2Seq, head_doc="vision-to-text modeling")
|
| 375 |
+
|
| 376 |
+
|
| 377 |
+
class FlaxAutoModelForSpeechSeq2Seq(_BaseAutoModelClass):
|
| 378 |
+
_model_mapping = FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING
|
| 379 |
+
|
| 380 |
+
|
| 381 |
+
FlaxAutoModelForSpeechSeq2Seq = auto_class_update(
|
| 382 |
+
FlaxAutoModelForSpeechSeq2Seq, head_doc="sequence-to-sequence speech-to-text modeling"
|
| 383 |
+
)
|
| 384 |
+
|
| 385 |
+
__all__ = [
|
| 386 |
+
"FLAX_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING",
|
| 387 |
+
"FLAX_MODEL_FOR_CAUSAL_LM_MAPPING",
|
| 388 |
+
"FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING",
|
| 389 |
+
"FLAX_MODEL_FOR_MASKED_LM_MAPPING",
|
| 390 |
+
"FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING",
|
| 391 |
+
"FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING",
|
| 392 |
+
"FLAX_MODEL_FOR_PRETRAINING_MAPPING",
|
| 393 |
+
"FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING",
|
| 394 |
+
"FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING",
|
| 395 |
+
"FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING",
|
| 396 |
+
"FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING",
|
| 397 |
+
"FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
|
| 398 |
+
"FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING",
|
| 399 |
+
"FLAX_MODEL_MAPPING",
|
| 400 |
+
"FlaxAutoModel",
|
| 401 |
+
"FlaxAutoModelForCausalLM",
|
| 402 |
+
"FlaxAutoModelForImageClassification",
|
| 403 |
+
"FlaxAutoModelForMaskedLM",
|
| 404 |
+
"FlaxAutoModelForMultipleChoice",
|
| 405 |
+
"FlaxAutoModelForNextSentencePrediction",
|
| 406 |
+
"FlaxAutoModelForPreTraining",
|
| 407 |
+
"FlaxAutoModelForQuestionAnswering",
|
| 408 |
+
"FlaxAutoModelForSeq2SeqLM",
|
| 409 |
+
"FlaxAutoModelForSequenceClassification",
|
| 410 |
+
"FlaxAutoModelForSpeechSeq2Seq",
|
| 411 |
+
"FlaxAutoModelForTokenClassification",
|
| 412 |
+
"FlaxAutoModelForVision2Seq",
|
| 413 |
+
]
|
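Note (reviewer aside, not part of the uploaded modeling_flax_auto.py): the file follows the standard transformers auto-class pattern — each `*_MAPPING_NAMES` OrderedDict maps a model-type string to a class name, `_LazyAutoMapping` re-keys it by config class so model modules are only imported on first access, and `auto_class_update` fills in the generated docstrings. A minimal usage sketch, assuming JAX/Flax is installed and the checkpoint ships Flax weights ("distilbert-base-uncased" is only an illustrative checkpoint name):

from transformers import AutoTokenizer, FlaxAutoModelForSequenceClassification

# The auto class reads the checkpoint's config (model_type "distilbert") and, via the
# ("distilbert", "FlaxDistilBertForSequenceClassification") entry above, instantiates
# that concrete class.
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
model = FlaxAutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased")

inputs = tokenizer("Auto classes dispatch on the config type.", return_tensors="np")
logits = model(**inputs).logits  # shape (1, num_labels); the classification head is freshly initialized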
venv/lib/python3.13/site-packages/transformers/models/auto/modeling_tf_auto.py
ADDED
@@ -0,0 +1,776 @@
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Auto Model class."""

import warnings
from collections import OrderedDict

from ...utils import logging
from .auto_factory import _BaseAutoModelClass, _LazyAutoMapping, auto_class_update
from .configuration_auto import CONFIG_MAPPING_NAMES


logger = logging.get_logger(__name__)


TF_MODEL_MAPPING_NAMES = OrderedDict(
    [
        # Base model mapping
        ("albert", "TFAlbertModel"),
        ("bart", "TFBartModel"),
        ("bert", "TFBertModel"),
        ("blenderbot", "TFBlenderbotModel"),
        ("blenderbot-small", "TFBlenderbotSmallModel"),
        ("blip", "TFBlipModel"),
        ("camembert", "TFCamembertModel"),
        ("clip", "TFCLIPModel"),
        ("convbert", "TFConvBertModel"),
        ("convnext", "TFConvNextModel"),
        ("convnextv2", "TFConvNextV2Model"),
        ("ctrl", "TFCTRLModel"),
        ("cvt", "TFCvtModel"),
        ("data2vec-vision", "TFData2VecVisionModel"),
        ("deberta", "TFDebertaModel"),
        ("deberta-v2", "TFDebertaV2Model"),
        ("deit", "TFDeiTModel"),
        ("distilbert", "TFDistilBertModel"),
        ("dpr", "TFDPRQuestionEncoder"),
        ("efficientformer", "TFEfficientFormerModel"),
        ("electra", "TFElectraModel"),
        ("esm", "TFEsmModel"),
        ("flaubert", "TFFlaubertModel"),
        ("funnel", ("TFFunnelModel", "TFFunnelBaseModel")),
        ("gpt-sw3", "TFGPT2Model"),
        ("gpt2", "TFGPT2Model"),
        ("gptj", "TFGPTJModel"),
        ("groupvit", "TFGroupViTModel"),
        ("hubert", "TFHubertModel"),
        ("idefics", "TFIdeficsModel"),
        ("layoutlm", "TFLayoutLMModel"),
        ("layoutlmv3", "TFLayoutLMv3Model"),
        ("led", "TFLEDModel"),
        ("longformer", "TFLongformerModel"),
        ("lxmert", "TFLxmertModel"),
        ("marian", "TFMarianModel"),
        ("mbart", "TFMBartModel"),
        ("mistral", "TFMistralModel"),
        ("mobilebert", "TFMobileBertModel"),
        ("mobilevit", "TFMobileViTModel"),
        ("mpnet", "TFMPNetModel"),
        ("mt5", "TFMT5Model"),
        ("openai-gpt", "TFOpenAIGPTModel"),
        ("opt", "TFOPTModel"),
        ("pegasus", "TFPegasusModel"),
        ("regnet", "TFRegNetModel"),
        ("rembert", "TFRemBertModel"),
        ("resnet", "TFResNetModel"),
        ("roberta", "TFRobertaModel"),
        ("roberta-prelayernorm", "TFRobertaPreLayerNormModel"),
        ("roformer", "TFRoFormerModel"),
        ("sam", "TFSamModel"),
        ("sam_vision_model", "TFSamVisionModel"),
        ("segformer", "TFSegformerModel"),
        ("speech_to_text", "TFSpeech2TextModel"),
        ("swiftformer", "TFSwiftFormerModel"),
        ("swin", "TFSwinModel"),
        ("t5", "TFT5Model"),
        ("tapas", "TFTapasModel"),
        ("transfo-xl", "TFTransfoXLModel"),
        ("vision-text-dual-encoder", "TFVisionTextDualEncoderModel"),
        ("vit", "TFViTModel"),
        ("vit_mae", "TFViTMAEModel"),
        ("wav2vec2", "TFWav2Vec2Model"),
        ("whisper", "TFWhisperModel"),
        ("xglm", "TFXGLMModel"),
        ("xlm", "TFXLMModel"),
        ("xlm-roberta", "TFXLMRobertaModel"),
        ("xlnet", "TFXLNetModel"),
    ]
)

TF_MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
    [
        # Model for pre-training mapping
        ("albert", "TFAlbertForPreTraining"),
        ("bart", "TFBartForConditionalGeneration"),
        ("bert", "TFBertForPreTraining"),
        ("camembert", "TFCamembertForMaskedLM"),
        ("ctrl", "TFCTRLLMHeadModel"),
        ("distilbert", "TFDistilBertForMaskedLM"),
        ("electra", "TFElectraForPreTraining"),
        ("flaubert", "TFFlaubertWithLMHeadModel"),
        ("funnel", "TFFunnelForPreTraining"),
        ("gpt-sw3", "TFGPT2LMHeadModel"),
        ("gpt2", "TFGPT2LMHeadModel"),
        ("idefics", "TFIdeficsForVisionText2Text"),
        ("layoutlm", "TFLayoutLMForMaskedLM"),
        ("lxmert", "TFLxmertForPreTraining"),
        ("mobilebert", "TFMobileBertForPreTraining"),
        ("mpnet", "TFMPNetForMaskedLM"),
        ("openai-gpt", "TFOpenAIGPTLMHeadModel"),
        ("roberta", "TFRobertaForMaskedLM"),
        ("roberta-prelayernorm", "TFRobertaPreLayerNormForMaskedLM"),
        ("t5", "TFT5ForConditionalGeneration"),
        ("tapas", "TFTapasForMaskedLM"),
        ("transfo-xl", "TFTransfoXLLMHeadModel"),
        ("vit_mae", "TFViTMAEForPreTraining"),
        ("xlm", "TFXLMWithLMHeadModel"),
        ("xlm-roberta", "TFXLMRobertaForMaskedLM"),
        ("xlnet", "TFXLNetLMHeadModel"),
    ]
)

TF_MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
    [
        # Model with LM heads mapping
        ("albert", "TFAlbertForMaskedLM"),
        ("bart", "TFBartForConditionalGeneration"),
        ("bert", "TFBertForMaskedLM"),
        ("camembert", "TFCamembertForMaskedLM"),
        ("convbert", "TFConvBertForMaskedLM"),
        ("ctrl", "TFCTRLLMHeadModel"),
        ("distilbert", "TFDistilBertForMaskedLM"),
        ("electra", "TFElectraForMaskedLM"),
        ("esm", "TFEsmForMaskedLM"),
        ("flaubert", "TFFlaubertWithLMHeadModel"),
        ("funnel", "TFFunnelForMaskedLM"),
        ("gpt-sw3", "TFGPT2LMHeadModel"),
        ("gpt2", "TFGPT2LMHeadModel"),
        ("gptj", "TFGPTJForCausalLM"),
        ("layoutlm", "TFLayoutLMForMaskedLM"),
        ("led", "TFLEDForConditionalGeneration"),
        ("longformer", "TFLongformerForMaskedLM"),
        ("marian", "TFMarianMTModel"),
        ("mobilebert", "TFMobileBertForMaskedLM"),
        ("mpnet", "TFMPNetForMaskedLM"),
        ("openai-gpt", "TFOpenAIGPTLMHeadModel"),
        ("rembert", "TFRemBertForMaskedLM"),
        ("roberta", "TFRobertaForMaskedLM"),
        ("roberta-prelayernorm", "TFRobertaPreLayerNormForMaskedLM"),
        ("roformer", "TFRoFormerForMaskedLM"),
        ("speech_to_text", "TFSpeech2TextForConditionalGeneration"),
        ("t5", "TFT5ForConditionalGeneration"),
        ("tapas", "TFTapasForMaskedLM"),
        ("transfo-xl", "TFTransfoXLLMHeadModel"),
        ("whisper", "TFWhisperForConditionalGeneration"),
        ("xlm", "TFXLMWithLMHeadModel"),
        ("xlm-roberta", "TFXLMRobertaForMaskedLM"),
        ("xlnet", "TFXLNetLMHeadModel"),
    ]
)

TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
    [
        # Model for Causal LM mapping
        ("bert", "TFBertLMHeadModel"),
        ("camembert", "TFCamembertForCausalLM"),
        ("ctrl", "TFCTRLLMHeadModel"),
        ("gpt-sw3", "TFGPT2LMHeadModel"),
        ("gpt2", "TFGPT2LMHeadModel"),
        ("gptj", "TFGPTJForCausalLM"),
        ("mistral", "TFMistralForCausalLM"),
        ("openai-gpt", "TFOpenAIGPTLMHeadModel"),
        ("opt", "TFOPTForCausalLM"),
        ("rembert", "TFRemBertForCausalLM"),
        ("roberta", "TFRobertaForCausalLM"),
        ("roberta-prelayernorm", "TFRobertaPreLayerNormForCausalLM"),
        ("roformer", "TFRoFormerForCausalLM"),
        ("transfo-xl", "TFTransfoXLLMHeadModel"),
        ("xglm", "TFXGLMForCausalLM"),
        ("xlm", "TFXLMWithLMHeadModel"),
        ("xlm-roberta", "TFXLMRobertaForCausalLM"),
        ("xlnet", "TFXLNetLMHeadModel"),
    ]
)

TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES = OrderedDict(
    [
        ("deit", "TFDeiTForMaskedImageModeling"),
        ("swin", "TFSwinForMaskedImageModeling"),
    ]
)

TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Image classification
        ("convnext", "TFConvNextForImageClassification"),
        ("convnextv2", "TFConvNextV2ForImageClassification"),
        ("cvt", "TFCvtForImageClassification"),
        ("data2vec-vision", "TFData2VecVisionForImageClassification"),
        ("deit", ("TFDeiTForImageClassification", "TFDeiTForImageClassificationWithTeacher")),
        (
            "efficientformer",
            ("TFEfficientFormerForImageClassification", "TFEfficientFormerForImageClassificationWithTeacher"),
        ),
        ("mobilevit", "TFMobileViTForImageClassification"),
        ("regnet", "TFRegNetForImageClassification"),
        ("resnet", "TFResNetForImageClassification"),
        ("segformer", "TFSegformerForImageClassification"),
        ("swiftformer", "TFSwiftFormerForImageClassification"),
        ("swin", "TFSwinForImageClassification"),
        ("vit", "TFViTForImageClassification"),
    ]
)


TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Zero Shot Image Classification mapping
        ("blip", "TFBlipModel"),
        ("clip", "TFCLIPModel"),
    ]
)


TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Semantic Segmentation mapping
        ("data2vec-vision", "TFData2VecVisionForSemanticSegmentation"),
        ("mobilevit", "TFMobileViTForSemanticSegmentation"),
        ("segformer", "TFSegformerForSemanticSegmentation"),
    ]
)

TF_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict(
    [
        ("blip", "TFBlipForConditionalGeneration"),
        ("vision-encoder-decoder", "TFVisionEncoderDecoderModel"),
    ]
)

TF_MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
    [
        # Model for Masked LM mapping
        ("albert", "TFAlbertForMaskedLM"),
        ("bert", "TFBertForMaskedLM"),
        ("camembert", "TFCamembertForMaskedLM"),
        ("convbert", "TFConvBertForMaskedLM"),
        ("deberta", "TFDebertaForMaskedLM"),
        ("deberta-v2", "TFDebertaV2ForMaskedLM"),
        ("distilbert", "TFDistilBertForMaskedLM"),
        ("electra", "TFElectraForMaskedLM"),
        ("esm", "TFEsmForMaskedLM"),
        ("flaubert", "TFFlaubertWithLMHeadModel"),
        ("funnel", "TFFunnelForMaskedLM"),
        ("layoutlm", "TFLayoutLMForMaskedLM"),
        ("longformer", "TFLongformerForMaskedLM"),
        ("mobilebert", "TFMobileBertForMaskedLM"),
        ("mpnet", "TFMPNetForMaskedLM"),
        ("rembert", "TFRemBertForMaskedLM"),
        ("roberta", "TFRobertaForMaskedLM"),
        ("roberta-prelayernorm", "TFRobertaPreLayerNormForMaskedLM"),
        ("roformer", "TFRoFormerForMaskedLM"),
        ("tapas", "TFTapasForMaskedLM"),
        ("xlm", "TFXLMWithLMHeadModel"),
        ("xlm-roberta", "TFXLMRobertaForMaskedLM"),
    ]
)

TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
    [
        # Model for Seq2Seq Causal LM mapping
        ("bart", "TFBartForConditionalGeneration"),
        ("blenderbot", "TFBlenderbotForConditionalGeneration"),
        ("blenderbot-small", "TFBlenderbotSmallForConditionalGeneration"),
        ("encoder-decoder", "TFEncoderDecoderModel"),
        ("led", "TFLEDForConditionalGeneration"),
        ("marian", "TFMarianMTModel"),
        ("mbart", "TFMBartForConditionalGeneration"),
        ("mt5", "TFMT5ForConditionalGeneration"),
        ("pegasus", "TFPegasusForConditionalGeneration"),
        ("t5", "TFT5ForConditionalGeneration"),
    ]
)

TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict(
    [
        ("speech_to_text", "TFSpeech2TextForConditionalGeneration"),
        ("whisper", "TFWhisperForConditionalGeneration"),
    ]
)

TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Sequence Classification mapping
        ("albert", "TFAlbertForSequenceClassification"),
        ("bart", "TFBartForSequenceClassification"),
        ("bert", "TFBertForSequenceClassification"),
        ("camembert", "TFCamembertForSequenceClassification"),
        ("convbert", "TFConvBertForSequenceClassification"),
        ("ctrl", "TFCTRLForSequenceClassification"),
        ("deberta", "TFDebertaForSequenceClassification"),
        ("deberta-v2", "TFDebertaV2ForSequenceClassification"),
        ("distilbert", "TFDistilBertForSequenceClassification"),
        ("electra", "TFElectraForSequenceClassification"),
        ("esm", "TFEsmForSequenceClassification"),
        ("flaubert", "TFFlaubertForSequenceClassification"),
        ("funnel", "TFFunnelForSequenceClassification"),
        ("gpt-sw3", "TFGPT2ForSequenceClassification"),
        ("gpt2", "TFGPT2ForSequenceClassification"),
        ("gptj", "TFGPTJForSequenceClassification"),
        ("layoutlm", "TFLayoutLMForSequenceClassification"),
        ("layoutlmv3", "TFLayoutLMv3ForSequenceClassification"),
        ("longformer", "TFLongformerForSequenceClassification"),
        ("mistral", "TFMistralForSequenceClassification"),
        ("mobilebert", "TFMobileBertForSequenceClassification"),
        ("mpnet", "TFMPNetForSequenceClassification"),
        ("openai-gpt", "TFOpenAIGPTForSequenceClassification"),
        ("rembert", "TFRemBertForSequenceClassification"),
        ("roberta", "TFRobertaForSequenceClassification"),
        ("roberta-prelayernorm", "TFRobertaPreLayerNormForSequenceClassification"),
        ("roformer", "TFRoFormerForSequenceClassification"),
        ("tapas", "TFTapasForSequenceClassification"),
        ("transfo-xl", "TFTransfoXLForSequenceClassification"),
        ("xlm", "TFXLMForSequenceClassification"),
        ("xlm-roberta", "TFXLMRobertaForSequenceClassification"),
        ("xlnet", "TFXLNetForSequenceClassification"),
    ]
)

TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
    [
        # Model for Question Answering mapping
        ("albert", "TFAlbertForQuestionAnswering"),
        ("bert", "TFBertForQuestionAnswering"),
        ("camembert", "TFCamembertForQuestionAnswering"),
        ("convbert", "TFConvBertForQuestionAnswering"),
        ("deberta", "TFDebertaForQuestionAnswering"),
        ("deberta-v2", "TFDebertaV2ForQuestionAnswering"),
        ("distilbert", "TFDistilBertForQuestionAnswering"),
        ("electra", "TFElectraForQuestionAnswering"),
        ("flaubert", "TFFlaubertForQuestionAnsweringSimple"),
        ("funnel", "TFFunnelForQuestionAnswering"),
        ("gptj", "TFGPTJForQuestionAnswering"),
        ("layoutlmv3", "TFLayoutLMv3ForQuestionAnswering"),
        ("longformer", "TFLongformerForQuestionAnswering"),
        ("mobilebert", "TFMobileBertForQuestionAnswering"),
        ("mpnet", "TFMPNetForQuestionAnswering"),
        ("rembert", "TFRemBertForQuestionAnswering"),
        ("roberta", "TFRobertaForQuestionAnswering"),
        ("roberta-prelayernorm", "TFRobertaPreLayerNormForQuestionAnswering"),
        ("roformer", "TFRoFormerForQuestionAnswering"),
        ("xlm", "TFXLMForQuestionAnsweringSimple"),
        ("xlm-roberta", "TFXLMRobertaForQuestionAnswering"),
        ("xlnet", "TFXLNetForQuestionAnsweringSimple"),
    ]
)
TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = OrderedDict([("wav2vec2", "TFWav2Vec2ForSequenceClassification")])

TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
    [
        ("layoutlm", "TFLayoutLMForQuestionAnswering"),
        ("layoutlmv3", "TFLayoutLMv3ForQuestionAnswering"),
    ]
)


TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
    [
        # Model for Table Question Answering mapping
        ("tapas", "TFTapasForQuestionAnswering"),
    ]
)

TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Token Classification mapping
        ("albert", "TFAlbertForTokenClassification"),
        ("bert", "TFBertForTokenClassification"),
        ("camembert", "TFCamembertForTokenClassification"),
        ("convbert", "TFConvBertForTokenClassification"),
        ("deberta", "TFDebertaForTokenClassification"),
        ("deberta-v2", "TFDebertaV2ForTokenClassification"),
        ("distilbert", "TFDistilBertForTokenClassification"),
        ("electra", "TFElectraForTokenClassification"),
        ("esm", "TFEsmForTokenClassification"),
        ("flaubert", "TFFlaubertForTokenClassification"),
        ("funnel", "TFFunnelForTokenClassification"),
        ("layoutlm", "TFLayoutLMForTokenClassification"),
        ("layoutlmv3", "TFLayoutLMv3ForTokenClassification"),
        ("longformer", "TFLongformerForTokenClassification"),
        ("mobilebert", "TFMobileBertForTokenClassification"),
        ("mpnet", "TFMPNetForTokenClassification"),
        ("rembert", "TFRemBertForTokenClassification"),
        ("roberta", "TFRobertaForTokenClassification"),
        ("roberta-prelayernorm", "TFRobertaPreLayerNormForTokenClassification"),
        ("roformer", "TFRoFormerForTokenClassification"),
        ("xlm", "TFXLMForTokenClassification"),
        ("xlm-roberta", "TFXLMRobertaForTokenClassification"),
        ("xlnet", "TFXLNetForTokenClassification"),
    ]
)

TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict(
    [
        # Model for Multiple Choice mapping
        ("albert", "TFAlbertForMultipleChoice"),
        ("bert", "TFBertForMultipleChoice"),
        ("camembert", "TFCamembertForMultipleChoice"),
        ("convbert", "TFConvBertForMultipleChoice"),
        ("deberta-v2", "TFDebertaV2ForMultipleChoice"),
        ("distilbert", "TFDistilBertForMultipleChoice"),
        ("electra", "TFElectraForMultipleChoice"),
        ("flaubert", "TFFlaubertForMultipleChoice"),
        ("funnel", "TFFunnelForMultipleChoice"),
        ("longformer", "TFLongformerForMultipleChoice"),
        ("mobilebert", "TFMobileBertForMultipleChoice"),
        ("mpnet", "TFMPNetForMultipleChoice"),
        ("rembert", "TFRemBertForMultipleChoice"),
        ("roberta", "TFRobertaForMultipleChoice"),
        ("roberta-prelayernorm", "TFRobertaPreLayerNormForMultipleChoice"),
        ("roformer", "TFRoFormerForMultipleChoice"),
        ("xlm", "TFXLMForMultipleChoice"),
        ("xlm-roberta", "TFXLMRobertaForMultipleChoice"),
        ("xlnet", "TFXLNetForMultipleChoice"),
    ]
)

TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES = OrderedDict(
    [
        ("bert", "TFBertForNextSentencePrediction"),
        ("mobilebert", "TFMobileBertForNextSentencePrediction"),
    ]
)
TF_MODEL_FOR_MASK_GENERATION_MAPPING_NAMES = OrderedDict(
    [
        ("sam", "TFSamModel"),
    ]
)
TF_MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES = OrderedDict(
    [
        ("albert", "TFAlbertModel"),
        ("bert", "TFBertModel"),
        ("convbert", "TFConvBertModel"),
        ("deberta", "TFDebertaModel"),
        ("deberta-v2", "TFDebertaV2Model"),
        ("distilbert", "TFDistilBertModel"),
        ("electra", "TFElectraModel"),
        ("flaubert", "TFFlaubertModel"),
        ("longformer", "TFLongformerModel"),
        ("mobilebert", "TFMobileBertModel"),
        ("mt5", "TFMT5EncoderModel"),
        ("rembert", "TFRemBertModel"),
        ("roberta", "TFRobertaModel"),
        ("roberta-prelayernorm", "TFRobertaPreLayerNormModel"),
        ("roformer", "TFRoFormerModel"),
        ("t5", "TFT5EncoderModel"),
        ("xlm", "TFXLMModel"),
        ("xlm-roberta", "TFXLMRobertaModel"),
    ]
)

TF_MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_MAPPING_NAMES)
TF_MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_PRETRAINING_MAPPING_NAMES)
TF_MODEL_WITH_LM_HEAD_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_WITH_LM_HEAD_MAPPING_NAMES)
TF_MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES)
TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES
)
TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
)
TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES
)
TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES
)
TF_MODEL_FOR_VISION_2_SEQ_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES)
TF_MODEL_FOR_MASKED_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_MASKED_LM_MAPPING_NAMES)
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
)
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
)
TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES
)
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
)
TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES
)
TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES
)
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES
)
TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES
)
TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES
)
TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES
)

TF_MODEL_FOR_MASK_GENERATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, TF_MODEL_FOR_MASK_GENERATION_MAPPING_NAMES
)

TF_MODEL_FOR_TEXT_ENCODING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES)


class TFAutoModelForMaskGeneration(_BaseAutoModelClass):
    _model_mapping = TF_MODEL_FOR_MASK_GENERATION_MAPPING


class TFAutoModelForTextEncoding(_BaseAutoModelClass):
    _model_mapping = TF_MODEL_FOR_TEXT_ENCODING_MAPPING


class TFAutoModel(_BaseAutoModelClass):
    _model_mapping = TF_MODEL_MAPPING


TFAutoModel = auto_class_update(TFAutoModel)


class TFAutoModelForAudioClassification(_BaseAutoModelClass):
    _model_mapping = TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING


TFAutoModelForAudioClassification = auto_class_update(
    TFAutoModelForAudioClassification, head_doc="audio classification"
)


class TFAutoModelForPreTraining(_BaseAutoModelClass):
    _model_mapping = TF_MODEL_FOR_PRETRAINING_MAPPING


TFAutoModelForPreTraining = auto_class_update(TFAutoModelForPreTraining, head_doc="pretraining")


# Private on purpose, the public class will add the deprecation warnings.
class _TFAutoModelWithLMHead(_BaseAutoModelClass):
    _model_mapping = TF_MODEL_WITH_LM_HEAD_MAPPING


_TFAutoModelWithLMHead = auto_class_update(_TFAutoModelWithLMHead, head_doc="language modeling")


class TFAutoModelForCausalLM(_BaseAutoModelClass):
    _model_mapping = TF_MODEL_FOR_CAUSAL_LM_MAPPING


TFAutoModelForCausalLM = auto_class_update(TFAutoModelForCausalLM, head_doc="causal language modeling")


class TFAutoModelForMaskedImageModeling(_BaseAutoModelClass):
    _model_mapping = TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING


TFAutoModelForMaskedImageModeling = auto_class_update(
    TFAutoModelForMaskedImageModeling, head_doc="masked image modeling"
)


class TFAutoModelForImageClassification(_BaseAutoModelClass):
    _model_mapping = TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING


TFAutoModelForImageClassification = auto_class_update(
    TFAutoModelForImageClassification, head_doc="image classification"
)


class TFAutoModelForZeroShotImageClassification(_BaseAutoModelClass):
    _model_mapping = TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING


TFAutoModelForZeroShotImageClassification = auto_class_update(
    TFAutoModelForZeroShotImageClassification, head_doc="zero-shot image classification"
)


class TFAutoModelForSemanticSegmentation(_BaseAutoModelClass):
    _model_mapping = TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING


TFAutoModelForSemanticSegmentation = auto_class_update(
    TFAutoModelForSemanticSegmentation, head_doc="semantic segmentation"
)


class TFAutoModelForVision2Seq(_BaseAutoModelClass):
    _model_mapping = TF_MODEL_FOR_VISION_2_SEQ_MAPPING


TFAutoModelForVision2Seq = auto_class_update(TFAutoModelForVision2Seq, head_doc="vision-to-text modeling")


class TFAutoModelForMaskedLM(_BaseAutoModelClass):
    _model_mapping = TF_MODEL_FOR_MASKED_LM_MAPPING


TFAutoModelForMaskedLM = auto_class_update(TFAutoModelForMaskedLM, head_doc="masked language modeling")


class TFAutoModelForSeq2SeqLM(_BaseAutoModelClass):
    _model_mapping = TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING


TFAutoModelForSeq2SeqLM = auto_class_update(
    TFAutoModelForSeq2SeqLM,
    head_doc="sequence-to-sequence language modeling",
    checkpoint_for_example="google-t5/t5-base",
)


class TFAutoModelForSequenceClassification(_BaseAutoModelClass):
    _model_mapping = TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING


TFAutoModelForSequenceClassification = auto_class_update(
    TFAutoModelForSequenceClassification, head_doc="sequence classification"
)


class TFAutoModelForQuestionAnswering(_BaseAutoModelClass):
    _model_mapping = TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING


TFAutoModelForQuestionAnswering = auto_class_update(TFAutoModelForQuestionAnswering, head_doc="question answering")


class TFAutoModelForDocumentQuestionAnswering(_BaseAutoModelClass):
    _model_mapping = TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING


TFAutoModelForDocumentQuestionAnswering = auto_class_update(
    TFAutoModelForDocumentQuestionAnswering,
    head_doc="document question answering",
    checkpoint_for_example='impira/layoutlm-document-qa", revision="52e01b3',
)


class TFAutoModelForTableQuestionAnswering(_BaseAutoModelClass):
    _model_mapping = TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING


TFAutoModelForTableQuestionAnswering = auto_class_update(
    TFAutoModelForTableQuestionAnswering,
    head_doc="table question answering",
    checkpoint_for_example="google/tapas-base-finetuned-wtq",
)


class TFAutoModelForTokenClassification(_BaseAutoModelClass):
    _model_mapping = TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING


TFAutoModelForTokenClassification = auto_class_update(
    TFAutoModelForTokenClassification, head_doc="token classification"
)


class TFAutoModelForMultipleChoice(_BaseAutoModelClass):
    _model_mapping = TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING


TFAutoModelForMultipleChoice = auto_class_update(TFAutoModelForMultipleChoice, head_doc="multiple choice")


class TFAutoModelForNextSentencePrediction(_BaseAutoModelClass):
    _model_mapping = TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING


TFAutoModelForNextSentencePrediction = auto_class_update(
    TFAutoModelForNextSentencePrediction, head_doc="next sentence prediction"
)


class TFAutoModelForSpeechSeq2Seq(_BaseAutoModelClass):
    _model_mapping = TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING


TFAutoModelForSpeechSeq2Seq = auto_class_update(
    TFAutoModelForSpeechSeq2Seq, head_doc="sequence-to-sequence speech-to-text modeling"
)


class TFAutoModelWithLMHead(_TFAutoModelWithLMHead):
    @classmethod
    def from_config(cls, config):
        warnings.warn(
            "The class `TFAutoModelWithLMHead` is deprecated and will be removed in a future version. Please use"
            " `TFAutoModelForCausalLM` for causal language models, `TFAutoModelForMaskedLM` for masked language models"
            " and `TFAutoModelForSeq2SeqLM` for encoder-decoder models.",
            FutureWarning,
        )
        return super().from_config(config)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        warnings.warn(
            "The class `TFAutoModelWithLMHead` is deprecated and will be removed in a future version. Please use"
            " `TFAutoModelForCausalLM` for causal language models, `TFAutoModelForMaskedLM` for masked language models"
            " and `TFAutoModelForSeq2SeqLM` for encoder-decoder models.",
            FutureWarning,
        )
        return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)


__all__ = [
    "TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING",
    "TF_MODEL_FOR_CAUSAL_LM_MAPPING",
    "TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING",
    "TF_MODEL_FOR_MASK_GENERATION_MAPPING",
    "TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING",
    "TF_MODEL_FOR_MASKED_LM_MAPPING",
    "TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING",
    "TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING",
    "TF_MODEL_FOR_PRETRAINING_MAPPING",
    "TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING",
    "TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING",
    "TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING",
    "TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING",
    "TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING",
    "TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING",
    "TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING",
    "TF_MODEL_FOR_TEXT_ENCODING_MAPPING",
    "TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
    "TF_MODEL_FOR_VISION_2_SEQ_MAPPING",
    "TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING",
    "TF_MODEL_MAPPING",
    "TF_MODEL_WITH_LM_HEAD_MAPPING",
    "TFAutoModel",
    "TFAutoModelForAudioClassification",
    "TFAutoModelForCausalLM",
    "TFAutoModelForImageClassification",
    "TFAutoModelForMaskedImageModeling",
    "TFAutoModelForMaskedLM",
    "TFAutoModelForMaskGeneration",
    "TFAutoModelForMultipleChoice",
    "TFAutoModelForNextSentencePrediction",
    "TFAutoModelForPreTraining",
    "TFAutoModelForDocumentQuestionAnswering",
    "TFAutoModelForQuestionAnswering",
    "TFAutoModelForSemanticSegmentation",
    "TFAutoModelForSeq2SeqLM",
    "TFAutoModelForSequenceClassification",
    "TFAutoModelForSpeechSeq2Seq",
    "TFAutoModelForTableQuestionAnswering",
    "TFAutoModelForTextEncoding",
    "TFAutoModelForTokenClassification",
    "TFAutoModelForVision2Seq",
    "TFAutoModelForZeroShotImageClassification",
    "TFAutoModelWithLMHead",
]
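Note (reviewer aside, not part of the uploaded modeling_tf_auto.py): because `_LazyAutoMapping` is keyed by config class, the auto classes can also dispatch from a bare config without downloading any weights; the `TFAutoModelWithLMHead` wrapper above exists only to emit a `FutureWarning` before delegating to the private class. A minimal sketch, assuming TensorFlow is installed:

from transformers import BertConfig, TFAutoModel

config = BertConfig()  # model_type "bert" -> "TFBertModel" in TF_MODEL_MAPPING_NAMES above
model = TFAutoModel.from_config(config)  # randomly initialized; nothing is downloaded
print(type(model).__name__)  # TFBertModel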
venv/lib/python3.13/site-packages/transformers/models/auto/processing_auto.py
ADDED
@@ -0,0 +1,443 @@
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""AutoProcessor class."""

import importlib
import inspect
import json
import warnings
from collections import OrderedDict

# Build the list of all feature extractors
from ...configuration_utils import PretrainedConfig
from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
from ...feature_extraction_utils import FeatureExtractionMixin
from ...image_processing_utils import ImageProcessingMixin
from ...processing_utils import ProcessorMixin
from ...tokenization_utils import TOKENIZER_CONFIG_FILE
from ...utils import FEATURE_EXTRACTOR_NAME, PROCESSOR_NAME, VIDEO_PROCESSOR_NAME, cached_file, logging
from ...video_processing_utils import BaseVideoProcessor
from .auto_factory import _LazyAutoMapping
from .configuration_auto import (
    CONFIG_MAPPING_NAMES,
    AutoConfig,
    model_type_to_module_name,
    replace_list_option_in_docstrings,
)
from .feature_extraction_auto import AutoFeatureExtractor
from .image_processing_auto import AutoImageProcessor
from .tokenization_auto import AutoTokenizer


logger = logging.get_logger(__name__)

PROCESSOR_MAPPING_NAMES = OrderedDict(
    [
        ("aimv2", "CLIPProcessor"),
        ("align", "AlignProcessor"),
        ("altclip", "AltCLIPProcessor"),
        ("aria", "AriaProcessor"),
        ("aya_vision", "AyaVisionProcessor"),
        ("bark", "BarkProcessor"),
        ("blip", "BlipProcessor"),
        ("blip-2", "Blip2Processor"),
        ("bridgetower", "BridgeTowerProcessor"),
        ("chameleon", "ChameleonProcessor"),
        ("chinese_clip", "ChineseCLIPProcessor"),
        ("clap", "ClapProcessor"),
        ("clip", "CLIPProcessor"),
        ("clipseg", "CLIPSegProcessor"),
        ("clvp", "ClvpProcessor"),
        ("cohere2_vision", "Cohere2VisionProcessor"),
        ("colpali", "ColPaliProcessor"),
        ("colqwen2", "ColQwen2Processor"),
        ("deepseek_vl", "DeepseekVLProcessor"),
        ("deepseek_vl_hybrid", "DeepseekVLHybridProcessor"),
        ("dia", "DiaProcessor"),
        ("edgetam", "Sam2Processor"),
        ("emu3", "Emu3Processor"),
        ("evolla", "EvollaProcessor"),
        ("flava", "FlavaProcessor"),
        ("florence2", "Florence2Processor"),
        ("fuyu", "FuyuProcessor"),
        ("gemma3", "Gemma3Processor"),
        ("gemma3n", "Gemma3nProcessor"),
        ("git", "GitProcessor"),
        ("glm4v", "Glm4vProcessor"),
        ("glm4v_moe", "Glm4vProcessor"),
        ("got_ocr2", "GotOcr2Processor"),
        ("granite_speech", "GraniteSpeechProcessor"),
        ("grounding-dino", "GroundingDinoProcessor"),
        ("groupvit", "CLIPProcessor"),
        ("hubert", "Wav2Vec2Processor"),
        ("idefics", "IdeficsProcessor"),
        ("idefics2", "Idefics2Processor"),
        ("idefics3", "Idefics3Processor"),
        ("instructblip", "InstructBlipProcessor"),
        ("instructblipvideo", "InstructBlipVideoProcessor"),
        ("internvl", "InternVLProcessor"),
        ("janus", "JanusProcessor"),
        ("kosmos-2", "Kosmos2Processor"),
        ("kosmos-2.5", "Kosmos2_5Processor"),
        ("kyutai_speech_to_text", "KyutaiSpeechToTextProcessor"),
        ("layoutlmv2", "LayoutLMv2Processor"),
        ("layoutlmv3", "LayoutLMv3Processor"),
        ("lfm2_vl", "Lfm2VlProcessor"),
        ("llama4", "Llama4Processor"),
        ("llava", "LlavaProcessor"),
        ("llava_next", "LlavaNextProcessor"),
        ("llava_next_video", "LlavaNextVideoProcessor"),
        ("llava_onevision", "LlavaOnevisionProcessor"),
        ("markuplm", "MarkupLMProcessor"),
        ("mctct", "MCTCTProcessor"),
        ("metaclip_2", "CLIPProcessor"),
        ("mgp-str", "MgpstrProcessor"),
        ("mistral3", "PixtralProcessor"),
        ("mllama", "MllamaProcessor"),
        ("mm-grounding-dino", "GroundingDinoProcessor"),
        ("moonshine", "Wav2Vec2Processor"),
        ("oneformer", "OneFormerProcessor"),
        ("ovis2", "Ovis2Processor"),
        ("owlv2", "Owlv2Processor"),
        ("owlvit", "OwlViTProcessor"),
        ("paligemma", "PaliGemmaProcessor"),
        ("perception_lm", "PerceptionLMProcessor"),
        ("phi4_multimodal", "Phi4MultimodalProcessor"),
        ("pix2struct", "Pix2StructProcessor"),
        ("pixtral", "PixtralProcessor"),
        ("pop2piano", "Pop2PianoProcessor"),
        ("qwen2_5_omni", "Qwen2_5OmniProcessor"),
        ("qwen2_5_vl", "Qwen2_5_VLProcessor"),
        ("qwen2_audio", "Qwen2AudioProcessor"),
        ("qwen2_vl", "Qwen2VLProcessor"),
        ("qwen3_omni_moe", "Qwen3OmniMoeProcessor"),
        ("qwen3_vl", "Qwen3VLProcessor"),
        ("qwen3_vl_moe", "Qwen3VLProcessor"),
        ("sam", "SamProcessor"),
        ("sam2", "Sam2Processor"),
        ("sam_hq", "SamHQProcessor"),
        ("seamless_m4t", "SeamlessM4TProcessor"),
        ("sew", "Wav2Vec2Processor"),
        ("sew-d", "Wav2Vec2Processor"),
        ("shieldgemma2", "ShieldGemma2Processor"),
        ("siglip", "SiglipProcessor"),
        ("siglip2", "Siglip2Processor"),
        ("smolvlm", "SmolVLMProcessor"),
        ("speech_to_text", "Speech2TextProcessor"),
        ("speech_to_text_2", "Speech2Text2Processor"),
        ("speecht5", "SpeechT5Processor"),
        ("trocr", "TrOCRProcessor"),
        ("tvlt", "TvltProcessor"),
        ("tvp", "TvpProcessor"),
        ("udop", "UdopProcessor"),
        ("unispeech", "Wav2Vec2Processor"),
        ("unispeech-sat", "Wav2Vec2Processor"),
        ("video_llava", "VideoLlavaProcessor"),
        ("vilt", "ViltProcessor"),
        ("vipllava", "LlavaProcessor"),
        ("vision-text-dual-encoder", "VisionTextDualEncoderProcessor"),
        ("voxtral", "VoxtralProcessor"),
        ("wav2vec2", "Wav2Vec2Processor"),
        ("wav2vec2-bert", "Wav2Vec2Processor"),
        ("wav2vec2-conformer", "Wav2Vec2Processor"),
        ("wavlm", "Wav2Vec2Processor"),
        ("whisper", "WhisperProcessor"),
|
| 157 |
+
("xclip", "XCLIPProcessor"),
|
| 158 |
+
]
|
| 159 |
+
)
|
| 160 |
+
|
| 161 |
+
PROCESSOR_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, PROCESSOR_MAPPING_NAMES)
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def processor_class_from_name(class_name: str):
|
| 165 |
+
for module_name, processors in PROCESSOR_MAPPING_NAMES.items():
|
| 166 |
+
if class_name in processors:
|
| 167 |
+
module_name = model_type_to_module_name(module_name)
|
| 168 |
+
|
| 169 |
+
module = importlib.import_module(f".{module_name}", "transformers.models")
|
| 170 |
+
try:
|
| 171 |
+
return getattr(module, class_name)
|
| 172 |
+
except AttributeError:
|
| 173 |
+
continue
|
| 174 |
+
|
| 175 |
+
for processor in PROCESSOR_MAPPING._extra_content.values():
|
| 176 |
+
if getattr(processor, "__name__", None) == class_name:
|
| 177 |
+
return processor
|
| 178 |
+
|
| 179 |
+
# We did not fine the class, but maybe it's because a dep is missing. In that case, the class will be in the main
|
| 180 |
+
# init and we return the proper dummy to get an appropriate error message.
|
| 181 |
+
main_module = importlib.import_module("transformers")
|
| 182 |
+
if hasattr(main_module, class_name):
|
| 183 |
+
return getattr(main_module, class_name)
|
| 184 |
+
|
| 185 |
+
return None
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
class AutoProcessor:
|
| 189 |
+
r"""
|
| 190 |
+
This is a generic processor class that will be instantiated as one of the processor classes of the library when
|
| 191 |
+
created with the [`AutoProcessor.from_pretrained`] class method.
|
| 192 |
+
|
| 193 |
+
This class cannot be instantiated directly using `__init__()` (throws an error).
|
| 194 |
+
"""
|
| 195 |
+
|
| 196 |
+
def __init__(self):
|
| 197 |
+
raise OSError(
|
| 198 |
+
"AutoProcessor is designed to be instantiated "
|
| 199 |
+
"using the `AutoProcessor.from_pretrained(pretrained_model_name_or_path)` method."
|
| 200 |
+
)
|
| 201 |
+
|
| 202 |
+
@classmethod
|
| 203 |
+
@replace_list_option_in_docstrings(PROCESSOR_MAPPING_NAMES)
|
| 204 |
+
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
|
| 205 |
+
r"""
|
| 206 |
+
Instantiate one of the processor classes of the library from a pretrained model vocabulary.
|
| 207 |
+
|
| 208 |
+
The processor class to instantiate is selected based on the `model_type` property of the config object (either
|
| 209 |
+
passed as an argument or loaded from `pretrained_model_name_or_path` if possible):
|
| 210 |
+
|
| 211 |
+
List options
|
| 212 |
+
|
| 213 |
+
Params:
|
| 214 |
+
pretrained_model_name_or_path (`str` or `os.PathLike`):
|
| 215 |
+
This can be either:
|
| 216 |
+
|
| 217 |
+
- a string, the *model id* of a pretrained feature_extractor hosted inside a model repo on
|
| 218 |
+
huggingface.co.
|
| 219 |
+
- a path to a *directory* containing a processor files saved using the `save_pretrained()` method,
|
| 220 |
+
e.g., `./my_model_directory/`.
|
| 221 |
+
cache_dir (`str` or `os.PathLike`, *optional*):
|
| 222 |
+
Path to a directory in which a downloaded pretrained model feature extractor should be cached if the
|
| 223 |
+
standard cache should not be used.
|
| 224 |
+
force_download (`bool`, *optional*, defaults to `False`):
|
| 225 |
+
Whether or not to force to (re-)download the feature extractor files and override the cached versions
|
| 226 |
+
if they exist.
|
| 227 |
+
resume_download:
|
| 228 |
+
Deprecated and ignored. All downloads are now resumed by default when possible.
|
| 229 |
+
Will be removed in v5 of Transformers.
|
| 230 |
+
proxies (`dict[str, str]`, *optional*):
|
| 231 |
+
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
|
| 232 |
+
'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
|
| 233 |
+
token (`str` or *bool*, *optional*):
|
| 234 |
+
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
|
| 235 |
+
when running `hf auth login` (stored in `~/.huggingface`).
|
| 236 |
+
revision (`str`, *optional*, defaults to `"main"`):
|
| 237 |
+
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
|
| 238 |
+
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
|
| 239 |
+
identifier allowed by git.
|
| 240 |
+
return_unused_kwargs (`bool`, *optional*, defaults to `False`):
|
| 241 |
+
If `False`, then this function returns just the final feature extractor object. If `True`, then this
|
| 242 |
+
functions returns a `Tuple(feature_extractor, unused_kwargs)` where *unused_kwargs* is a dictionary
|
| 243 |
+
consisting of the key/value pairs whose keys are not feature extractor attributes: i.e., the part of
|
| 244 |
+
`kwargs` which has not been used to update `feature_extractor` and is otherwise ignored.
|
| 245 |
+
trust_remote_code (`bool`, *optional*, defaults to `False`):
|
| 246 |
+
Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
|
| 247 |
+
should only be set to `True` for repositories you trust and in which you have read the code, as it will
|
| 248 |
+
execute code present on the Hub on your local machine.
|
| 249 |
+
kwargs (`dict[str, Any]`, *optional*):
|
| 250 |
+
The values in kwargs of any keys which are feature extractor attributes will be used to override the
|
| 251 |
+
loaded values. Behavior concerning key/value pairs whose keys are *not* feature extractor attributes is
|
| 252 |
+
controlled by the `return_unused_kwargs` keyword parameter.
|
| 253 |
+
|
| 254 |
+
<Tip>
|
| 255 |
+
|
| 256 |
+
Passing `token=True` is required when you want to use a private model.
|
| 257 |
+
|
| 258 |
+
</Tip>
|
| 259 |
+
|
| 260 |
+
Examples:
|
| 261 |
+
|
| 262 |
+
```python
|
| 263 |
+
>>> from transformers import AutoProcessor
|
| 264 |
+
|
| 265 |
+
>>> # Download processor from huggingface.co and cache.
|
| 266 |
+
>>> processor = AutoProcessor.from_pretrained("facebook/wav2vec2-base-960h")
|
| 267 |
+
|
| 268 |
+
>>> # If processor files are in a directory (e.g. processor was saved using *save_pretrained('./test/saved_model/')*)
|
| 269 |
+
>>> # processor = AutoProcessor.from_pretrained("./test/saved_model/")
|
| 270 |
+
```"""
|
| 271 |
+
use_auth_token = kwargs.pop("use_auth_token", None)
|
| 272 |
+
if use_auth_token is not None:
|
| 273 |
+
warnings.warn(
|
| 274 |
+
"The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
|
| 275 |
+
FutureWarning,
|
| 276 |
+
)
|
| 277 |
+
if kwargs.get("token") is not None:
|
| 278 |
+
raise ValueError(
|
| 279 |
+
"`token` and `use_auth_token` are both specified. Please set only the argument `token`."
|
| 280 |
+
)
|
| 281 |
+
kwargs["token"] = use_auth_token
|
| 282 |
+
|
| 283 |
+
config = kwargs.pop("config", None)
|
| 284 |
+
trust_remote_code = kwargs.pop("trust_remote_code", None)
|
| 285 |
+
kwargs["_from_auto"] = True
|
| 286 |
+
|
| 287 |
+
processor_class = None
|
| 288 |
+
processor_auto_map = None
|
| 289 |
+
|
| 290 |
+
# First, let's see if we have a processor or preprocessor config.
|
| 291 |
+
# Filter the kwargs for `cached_file`.
|
| 292 |
+
cached_file_kwargs = {key: kwargs[key] for key in inspect.signature(cached_file).parameters if key in kwargs}
|
| 293 |
+
# We don't want to raise
|
| 294 |
+
cached_file_kwargs.update(
|
| 295 |
+
{
|
| 296 |
+
"_raise_exceptions_for_gated_repo": False,
|
| 297 |
+
"_raise_exceptions_for_missing_entries": False,
|
| 298 |
+
"_raise_exceptions_for_connection_errors": False,
|
| 299 |
+
}
|
| 300 |
+
)
|
| 301 |
+
|
| 302 |
+
# Let's start by checking whether the processor class is saved in a processor config
|
| 303 |
+
processor_config_file = cached_file(pretrained_model_name_or_path, PROCESSOR_NAME, **cached_file_kwargs)
|
| 304 |
+
if processor_config_file is not None:
|
| 305 |
+
config_dict, _ = ProcessorMixin.get_processor_dict(pretrained_model_name_or_path, **kwargs)
|
| 306 |
+
processor_class = config_dict.get("processor_class", None)
|
| 307 |
+
if "AutoProcessor" in config_dict.get("auto_map", {}):
|
| 308 |
+
processor_auto_map = config_dict["auto_map"]["AutoProcessor"]
|
| 309 |
+
|
| 310 |
+
if processor_class is None:
|
| 311 |
+
# If not found, let's check whether the processor class is saved in an image processor config
|
| 312 |
+
preprocessor_config_file = cached_file(
|
| 313 |
+
pretrained_model_name_or_path, FEATURE_EXTRACTOR_NAME, **cached_file_kwargs
|
| 314 |
+
)
|
| 315 |
+
if preprocessor_config_file is not None:
|
| 316 |
+
config_dict, _ = ImageProcessingMixin.get_image_processor_dict(pretrained_model_name_or_path, **kwargs)
|
| 317 |
+
processor_class = config_dict.get("processor_class", None)
|
| 318 |
+
if "AutoProcessor" in config_dict.get("auto_map", {}):
|
| 319 |
+
processor_auto_map = config_dict["auto_map"]["AutoProcessor"]
|
| 320 |
+
|
| 321 |
+
# Saved as video processor
|
| 322 |
+
if preprocessor_config_file is None:
|
| 323 |
+
preprocessor_config_file = cached_file(
|
| 324 |
+
pretrained_model_name_or_path, VIDEO_PROCESSOR_NAME, **cached_file_kwargs
|
| 325 |
+
)
|
| 326 |
+
if preprocessor_config_file is not None:
|
| 327 |
+
config_dict, _ = BaseVideoProcessor.get_video_processor_dict(
|
| 328 |
+
pretrained_model_name_or_path, **kwargs
|
| 329 |
+
)
|
| 330 |
+
processor_class = config_dict.get("processor_class", None)
|
| 331 |
+
if "AutoProcessor" in config_dict.get("auto_map", {}):
|
| 332 |
+
processor_auto_map = config_dict["auto_map"]["AutoProcessor"]
|
| 333 |
+
|
| 334 |
+
# Saved as feature extractor
|
| 335 |
+
if preprocessor_config_file is None:
|
| 336 |
+
preprocessor_config_file = cached_file(
|
| 337 |
+
pretrained_model_name_or_path, FEATURE_EXTRACTOR_NAME, **cached_file_kwargs
|
| 338 |
+
)
|
| 339 |
+
if preprocessor_config_file is not None and processor_class is None:
|
| 340 |
+
config_dict, _ = FeatureExtractionMixin.get_feature_extractor_dict(
|
| 341 |
+
pretrained_model_name_or_path, **kwargs
|
| 342 |
+
)
|
| 343 |
+
processor_class = config_dict.get("processor_class", None)
|
| 344 |
+
if "AutoProcessor" in config_dict.get("auto_map", {}):
|
| 345 |
+
processor_auto_map = config_dict["auto_map"]["AutoProcessor"]
|
| 346 |
+
|
| 347 |
+
if processor_class is None:
|
| 348 |
+
# Next, let's check whether the processor class is saved in a tokenizer
|
| 349 |
+
tokenizer_config_file = cached_file(
|
| 350 |
+
pretrained_model_name_or_path, TOKENIZER_CONFIG_FILE, **cached_file_kwargs
|
| 351 |
+
)
|
| 352 |
+
if tokenizer_config_file is not None:
|
| 353 |
+
with open(tokenizer_config_file, encoding="utf-8") as reader:
|
| 354 |
+
config_dict = json.load(reader)
|
| 355 |
+
|
| 356 |
+
processor_class = config_dict.get("processor_class", None)
|
| 357 |
+
if "AutoProcessor" in config_dict.get("auto_map", {}):
|
| 358 |
+
processor_auto_map = config_dict["auto_map"]["AutoProcessor"]
|
| 359 |
+
|
| 360 |
+
if processor_class is None:
|
| 361 |
+
# Otherwise, load config, if it can be loaded.
|
| 362 |
+
if not isinstance(config, PretrainedConfig):
|
| 363 |
+
config = AutoConfig.from_pretrained(
|
| 364 |
+
pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
|
| 365 |
+
)
|
| 366 |
+
|
| 367 |
+
# And check if the config contains the processor class.
|
| 368 |
+
processor_class = getattr(config, "processor_class", None)
|
| 369 |
+
if hasattr(config, "auto_map") and "AutoProcessor" in config.auto_map:
|
| 370 |
+
processor_auto_map = config.auto_map["AutoProcessor"]
|
| 371 |
+
|
| 372 |
+
if processor_class is not None:
|
| 373 |
+
processor_class = processor_class_from_name(processor_class)
|
| 374 |
+
|
| 375 |
+
has_remote_code = processor_auto_map is not None
|
| 376 |
+
has_local_code = processor_class is not None or type(config) in PROCESSOR_MAPPING
|
| 377 |
+
if has_remote_code:
|
| 378 |
+
if "--" in processor_auto_map:
|
| 379 |
+
upstream_repo = processor_auto_map.split("--")[0]
|
| 380 |
+
else:
|
| 381 |
+
upstream_repo = None
|
| 382 |
+
trust_remote_code = resolve_trust_remote_code(
|
| 383 |
+
trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code, upstream_repo
|
| 384 |
+
)
|
| 385 |
+
|
| 386 |
+
if has_remote_code and trust_remote_code:
|
| 387 |
+
processor_class = get_class_from_dynamic_module(
|
| 388 |
+
processor_auto_map, pretrained_model_name_or_path, **kwargs
|
| 389 |
+
)
|
| 390 |
+
_ = kwargs.pop("code_revision", None)
|
| 391 |
+
processor_class.register_for_auto_class()
|
| 392 |
+
return processor_class.from_pretrained(
|
| 393 |
+
pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
|
| 394 |
+
)
|
| 395 |
+
elif processor_class is not None:
|
| 396 |
+
return processor_class.from_pretrained(
|
| 397 |
+
pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
|
| 398 |
+
)
|
| 399 |
+
# Last try: we use the PROCESSOR_MAPPING.
|
| 400 |
+
elif type(config) in PROCESSOR_MAPPING:
|
| 401 |
+
return PROCESSOR_MAPPING[type(config)].from_pretrained(pretrained_model_name_or_path, **kwargs)
|
| 402 |
+
|
| 403 |
+
# At this stage, there doesn't seem to be a `Processor` class available for this model, so let's try a
|
| 404 |
+
# tokenizer.
|
| 405 |
+
try:
|
| 406 |
+
return AutoTokenizer.from_pretrained(
|
| 407 |
+
pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
|
| 408 |
+
)
|
| 409 |
+
except Exception:
|
| 410 |
+
try:
|
| 411 |
+
return AutoImageProcessor.from_pretrained(
|
| 412 |
+
pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
|
| 413 |
+
)
|
| 414 |
+
except Exception:
|
| 415 |
+
pass
|
| 416 |
+
|
| 417 |
+
try:
|
| 418 |
+
return AutoFeatureExtractor.from_pretrained(
|
| 419 |
+
pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
|
| 420 |
+
)
|
| 421 |
+
except Exception:
|
| 422 |
+
pass
|
| 423 |
+
|
| 424 |
+
raise ValueError(
|
| 425 |
+
f"Unrecognized processing class in {pretrained_model_name_or_path}. Can't instantiate a processor, a "
|
| 426 |
+
"tokenizer, an image processor or a feature extractor for this model. Make sure the repository contains "
|
| 427 |
+
"the files of at least one of those processing classes."
|
| 428 |
+
)
|
| 429 |
+
|
| 430 |
+
@staticmethod
|
| 431 |
+
def register(config_class, processor_class, exist_ok=False):
|
| 432 |
+
"""
|
| 433 |
+
Register a new processor for this class.
|
| 434 |
+
|
| 435 |
+
Args:
|
| 436 |
+
config_class ([`PretrainedConfig`]):
|
| 437 |
+
The configuration corresponding to the model to register.
|
| 438 |
+
processor_class ([`ProcessorMixin`]): The processor to register.
|
| 439 |
+
"""
|
| 440 |
+
PROCESSOR_MAPPING.register(config_class, processor_class, exist_ok=exist_ok)
|
| 441 |
+
|
| 442 |
+
|
| 443 |
+
__all__ = ["PROCESSOR_MAPPING", "AutoProcessor"]
|
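The `from_pretrained` fallback chain above (processor config, then image/video processor config, then feature-extractor config, then tokenizer config, then the model config, and finally the tokenizer/image-processor/feature-extractor autoclasses as a last resort) is easiest to see from the caller's side. Below is a minimal usage sketch under stated assumptions: the checkpoint name comes from the docstring example, while `MyConfig` and `MyProcessor` are hypothetical placeholder classes used only to illustrate `AutoProcessor.register`; they are not part of the library.

```python
from transformers import AutoProcessor, PretrainedConfig, ProcessorMixin

# Resolved through PROCESSOR_MAPPING: the wav2vec2 model type maps to Wav2Vec2Processor.
processor = AutoProcessor.from_pretrained("facebook/wav2vec2-base-960h")

# Hypothetical custom config/processor pair; registering it lets AutoProcessor
# resolve checkpoints whose config is a MyConfig. Sketch only, never instantiated here.
class MyConfig(PretrainedConfig):
    model_type = "my-model"  # placeholder model type, not a real one

class MyProcessor(ProcessorMixin):
    pass  # a real processor would wire up its tokenizer / feature extractor attributes

AutoProcessor.register(MyConfig, MyProcessor)
```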
venv/lib/python3.13/site-packages/transformers/models/auto/tokenization_auto.py
ADDED
@@ -0,0 +1,1235 @@
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Auto Tokenizer class."""

import importlib
import json
import os
import warnings
from collections import OrderedDict
from typing import Any, Optional, Union

from transformers.utils.import_utils import is_mistral_common_available

from ...configuration_utils import PretrainedConfig
from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
from ...modeling_gguf_pytorch_utils import load_gguf_checkpoint
from ...tokenization_utils import PreTrainedTokenizer
from ...tokenization_utils_base import TOKENIZER_CONFIG_FILE
from ...utils import (
    cached_file,
    extract_commit_hash,
    is_g2p_en_available,
    is_sentencepiece_available,
    is_tokenizers_available,
    logging,
)
from ..encoder_decoder import EncoderDecoderConfig
from .auto_factory import _LazyAutoMapping
from .configuration_auto import (
    CONFIG_MAPPING_NAMES,
    AutoConfig,
    config_class_to_model_type,
    model_type_to_module_name,
    replace_list_option_in_docstrings,
)


if is_tokenizers_available():
    from ...tokenization_utils_fast import PreTrainedTokenizerFast
else:
    PreTrainedTokenizerFast = None


logger = logging.get_logger(__name__)
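Each value in the `TOKENIZER_MAPPING_NAMES` mapping defined below is a `(slow_tokenizer_class, fast_tokenizer_class)` pair of class names, with either side set to `None` when its optional backend (`sentencepiece`, `tokenizers`, `g2p_en`, `mistral-common`) is not installed. The helper below is purely an illustrative sketch of how such a pair could be inspected; it is not part of this module, and it assumes the mapping below is in scope.

```python
# Illustrative only: report which tokenizer classes a model type maps to,
# assuming TOKENIZER_MAPPING_NAMES (defined below) has been imported.
def available_tokenizer_classes(model_type: str) -> list[str]:
    slow_class, fast_class = TOKENIZER_MAPPING_NAMES.get(model_type, (None, None))
    # None entries mean the corresponding optional backend is missing (or no such variant exists).
    return [name for name in (slow_class, fast_class) if name is not None]

# e.g. available_tokenizer_classes("albert") -> ["AlbertTokenizer", "AlbertTokenizerFast"]
# when both sentencepiece and tokenizers are installed.
```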
# Explicit rather than inferred generics to significantly improves completion suggestion performance for language servers.
TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]](
    [
        ("aimv2", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)),
        ("albert", ("AlbertTokenizer" if is_sentencepiece_available() else None, "AlbertTokenizerFast" if is_tokenizers_available() else None)),
        ("align", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
        ("arcee", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
        ("aria", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
        ("aya_vision", (None, "CohereTokenizerFast" if is_tokenizers_available() else None)),
        ("bark", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
        ("bart", ("BartTokenizer", "BartTokenizerFast")),
        ("barthez", ("BarthezTokenizer" if is_sentencepiece_available() else None, "BarthezTokenizerFast" if is_tokenizers_available() else None)),
        ("bartpho", ("BartphoTokenizer", None)),
        ("bert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
        ("bert-generation", ("BertGenerationTokenizer" if is_sentencepiece_available() else None, None)),
        ("bert-japanese", ("BertJapaneseTokenizer", None)),
        ("bertweet", ("BertweetTokenizer", None)),
        ("big_bird", ("BigBirdTokenizer" if is_sentencepiece_available() else None, "BigBirdTokenizerFast" if is_tokenizers_available() else None)),
        ("bigbird_pegasus", ("PegasusTokenizer", "PegasusTokenizerFast" if is_tokenizers_available() else None)),
        ("biogpt", ("BioGptTokenizer", None)),
        ("bitnet", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
        ("blenderbot", ("BlenderbotTokenizer", "BlenderbotTokenizerFast")),
        ("blenderbot-small", ("BlenderbotSmallTokenizer", None)),
        ("blip", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
        ("blip-2", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
        ("bloom", (None, "BloomTokenizerFast" if is_tokenizers_available() else None)),
        ("blt", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
        ("bridgetower", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
        ("bros", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
        ("byt5", ("ByT5Tokenizer", None)),
        ("camembert", ("CamembertTokenizer" if is_sentencepiece_available() else None, "CamembertTokenizerFast" if is_tokenizers_available() else None)),
        ("canine", ("CanineTokenizer", None)),
        ("chameleon", ("LlamaTokenizer" if is_sentencepiece_available() else None, "LlamaTokenizerFast" if is_tokenizers_available() else None)),
        ("chinese_clip", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
        ("clap", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
        ("clip", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)),
        ("clipseg", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)),
        ("clvp", ("ClvpTokenizer", None)),
        ("code_llama", ("CodeLlamaTokenizer" if is_sentencepiece_available() else None, "CodeLlamaTokenizerFast" if is_tokenizers_available() else None)),
        ("codegen", ("CodeGenTokenizer", "CodeGenTokenizerFast" if is_tokenizers_available() else None)),
        ("cohere", (None, "CohereTokenizerFast" if is_tokenizers_available() else None)),
        ("cohere2", (None, "CohereTokenizerFast" if is_tokenizers_available() else None)),
        ("colpali", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
        ("colqwen2", ("Qwen2Tokenizer", "Qwen2TokenizerFast" if is_tokenizers_available() else None)),
        ("convbert", ("ConvBertTokenizer", "ConvBertTokenizerFast" if is_tokenizers_available() else None)),
        ("cpm", ("CpmTokenizer" if is_sentencepiece_available() else None, "CpmTokenizerFast" if is_tokenizers_available() else None)),
        ("cpmant", ("CpmAntTokenizer", None)),
        ("csm", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
        ("ctrl", ("CTRLTokenizer", None)),
        ("data2vec-audio", ("Wav2Vec2CTCTokenizer", None)),
        ("data2vec-text", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
        ("dbrx", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
        ("deberta", ("DebertaTokenizer", "DebertaTokenizerFast" if is_tokenizers_available() else None)),
        ("deberta-v2", ("DebertaV2Tokenizer" if is_sentencepiece_available() else None, "DebertaV2TokenizerFast" if is_tokenizers_available() else None)),
        ("deepseek_v2", ("LlamaTokenizer" if is_sentencepiece_available() else None, "LlamaTokenizerFast" if is_tokenizers_available() else None)),
        ("deepseek_v3", ("LlamaTokenizer" if is_sentencepiece_available() else None, "LlamaTokenizerFast" if is_tokenizers_available() else None)),
        ("deepseek_vl", ("LlamaTokenizer" if is_sentencepiece_available() else None, "LlamaTokenizerFast" if is_tokenizers_available() else None)),
        ("deepseek_vl_hybrid", ("LlamaTokenizer" if is_sentencepiece_available() else None, "LlamaTokenizerFast" if is_tokenizers_available() else None)),
        ("dia", ("DiaTokenizer", None)),
        ("diffllama", ("LlamaTokenizer" if is_sentencepiece_available() else None, "LlamaTokenizerFast" if is_tokenizers_available() else None)),
        ("distilbert", ("DistilBertTokenizer", "DistilBertTokenizerFast" if is_tokenizers_available() else None)),
        ("dpr", ("DPRQuestionEncoderTokenizer", "DPRQuestionEncoderTokenizerFast" if is_tokenizers_available() else None)),
        ("electra", ("ElectraTokenizer", "ElectraTokenizerFast" if is_tokenizers_available() else None)),
        ("emu3", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
        ("ernie", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
        ("ernie4_5", (None, "LlamaTokenizerFast" if is_tokenizers_available() else None)),
        ("ernie4_5_moe", (None, "LlamaTokenizerFast" if is_tokenizers_available() else None)),
        ("ernie_m", ("ErnieMTokenizer" if is_sentencepiece_available() else None, None)),
        ("esm", ("EsmTokenizer", None)),
        ("exaone4", ("GPT2Tokenizer" if is_tokenizers_available() else None, "GPT2TokenizerFast" if is_tokenizers_available() else None)),
        ("falcon", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
        ("falcon_mamba", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
        ("fastspeech2_conformer", ("FastSpeech2ConformerTokenizer" if is_g2p_en_available() else None, None)),
        ("flaubert", ("FlaubertTokenizer", None)),
        ("flex_olmo", (None, "GPT2TokenizerFast" if is_tokenizers_available() else None)),
        ("fnet", ("FNetTokenizer", "FNetTokenizerFast" if is_tokenizers_available() else None)),
        ("fsmt", ("FSMTTokenizer", None)),
        ("funnel", ("FunnelTokenizer", "FunnelTokenizerFast" if is_tokenizers_available() else None)),
        ("gemma", ("GemmaTokenizer" if is_sentencepiece_available() else None, "GemmaTokenizerFast" if is_tokenizers_available() else None)),
        ("gemma2", ("GemmaTokenizer" if is_sentencepiece_available() else None, "GemmaTokenizerFast" if is_tokenizers_available() else None)),
        ("gemma3", ("GemmaTokenizer" if is_sentencepiece_available() else None, "GemmaTokenizerFast" if is_tokenizers_available() else None)),
        ("gemma3_text", ("GemmaTokenizer" if is_sentencepiece_available() else None, "GemmaTokenizerFast" if is_tokenizers_available() else None)),
        ("gemma3n", ("GemmaTokenizer" if is_sentencepiece_available() else None, "GemmaTokenizerFast" if is_tokenizers_available() else None)),
        ("gemma3n_text", ("GemmaTokenizer" if is_sentencepiece_available() else None, "GemmaTokenizerFast" if is_tokenizers_available() else None)),
        ("git", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
        ("glm", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
        ("glm4", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
        ("glm4_moe", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
        ("glm4v", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
        ("glm4v_moe", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
        ("gpt-sw3", ("GPTSw3Tokenizer" if is_sentencepiece_available() else None, None)),
        ("gpt2", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
        ("gpt_bigcode", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
        ("gpt_neo", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
        ("gpt_neox", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
        ("gpt_neox_japanese", ("GPTNeoXJapaneseTokenizer", None)),
        ("gpt_oss", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
        ("gptj", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
        ("gptsan-japanese", ("GPTSanJapaneseTokenizer", None)),
        ("granite", ("GPT2Tokenizer", None)),
        ("granitemoe", ("GPT2Tokenizer", None)),
        ("granitemoehybrid", ("GPT2Tokenizer", None)),
        ("granitemoeshared", ("GPT2Tokenizer", None)),
        ("grounding-dino", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
        ("groupvit", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)),
        ("helium", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
        ("herbert", ("HerbertTokenizer", "HerbertTokenizerFast" if is_tokenizers_available() else None)),
        ("hubert", ("Wav2Vec2CTCTokenizer", None)),
        ("ibert", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
        ("idefics", (None, "LlamaTokenizerFast" if is_tokenizers_available() else None)),
        ("idefics2", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
        ("idefics3", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
        ("instructblip", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
        ("instructblipvideo", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
        ("internvl", ("Qwen2Tokenizer", "Qwen2TokenizerFast" if is_tokenizers_available() else None)),
        ("jamba", ("LlamaTokenizer" if is_sentencepiece_available() else None, "LlamaTokenizerFast" if is_tokenizers_available() else None)),
        ("janus", (None, "LlamaTokenizerFast" if is_tokenizers_available() else None)),
        ("jetmoe", ("LlamaTokenizer" if is_sentencepiece_available() else None, "LlamaTokenizerFast" if is_tokenizers_available() else None)),
        ("jukebox", ("JukeboxTokenizer", None)),
        ("kosmos-2", ("XLMRobertaTokenizer" if is_sentencepiece_available() else None, "XLMRobertaTokenizerFast" if is_tokenizers_available() else None)),
        ("kosmos-2.5", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
        ("layoutlm", ("LayoutLMTokenizer", "LayoutLMTokenizerFast" if is_tokenizers_available() else None)),
        ("layoutlmv2", ("LayoutLMv2Tokenizer", "LayoutLMv2TokenizerFast" if is_tokenizers_available() else None)),
        ("layoutlmv3", ("LayoutLMv3Tokenizer", "LayoutLMv3TokenizerFast" if is_tokenizers_available() else None)),
        ("layoutxlm", ("LayoutXLMTokenizer", "LayoutXLMTokenizerFast" if is_tokenizers_available() else None)),
        ("led", ("LEDTokenizer", "LEDTokenizerFast" if is_tokenizers_available() else None)),
        ("lilt", ("LayoutLMv3Tokenizer", "LayoutLMv3TokenizerFast" if is_tokenizers_available() else None)),
        ("llama", ("LlamaTokenizer" if is_sentencepiece_available() else None, "LlamaTokenizerFast" if is_tokenizers_available() else None)),
        ("llama4", ("LlamaTokenizer" if is_sentencepiece_available() else None, "LlamaTokenizerFast" if is_tokenizers_available() else None)),
        ("llama4_text", ("LlamaTokenizer" if is_sentencepiece_available() else None, "LlamaTokenizerFast" if is_tokenizers_available() else None)),
        ("llava", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
        ("llava_next", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
        ("llava_next_video", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
        ("llava_onevision", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
        ("longformer", ("LongformerTokenizer", "LongformerTokenizerFast" if is_tokenizers_available() else None)),
        ("longt5", ("T5Tokenizer" if is_sentencepiece_available() else None, "T5TokenizerFast" if is_tokenizers_available() else None)),
        ("luke", ("LukeTokenizer", None)),
        ("lxmert", ("LxmertTokenizer", "LxmertTokenizerFast" if is_tokenizers_available() else None)),
        ("m2m_100", ("M2M100Tokenizer" if is_sentencepiece_available() else None, None)),
        ("mamba", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
        ("mamba2", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
        ("marian", ("MarianTokenizer" if is_sentencepiece_available() else None, None)),
        ("mbart", ("MBartTokenizer" if is_sentencepiece_available() else None, "MBartTokenizerFast" if is_tokenizers_available() else None)),
        ("mbart50", ("MBart50Tokenizer" if is_sentencepiece_available() else None, "MBart50TokenizerFast" if is_tokenizers_available() else None)),
        ("mega", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
        ("megatron-bert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
        ("metaclip_2", ("XLMRobertaTokenizer", "XLMRobertaTokenizerFast" if is_tokenizers_available() else None)),
        ("mgp-str", ("MgpstrTokenizer", None)),
        ("minimax", ("GPT2Tokenizer" if is_sentencepiece_available() else None, "GPT2TokenizerFast" if is_tokenizers_available() else None)),
        ("ministral", ("MistralCommonTokenizer" if is_mistral_common_available() else ("LlamaTokenizer" if is_sentencepiece_available() else None), "LlamaTokenizerFast" if is_tokenizers_available() and not is_mistral_common_available() else None)),
        ("mistral", ("MistralCommonTokenizer" if is_mistral_common_available() else ("LlamaTokenizer" if is_sentencepiece_available() else None), "LlamaTokenizerFast" if is_tokenizers_available() and not is_mistral_common_available() else None)),
        ("mistral3", ("MistralCommonTokenizer" if is_mistral_common_available() else ("LlamaTokenizer" if is_sentencepiece_available() else None), "LlamaTokenizerFast" if is_tokenizers_available() and not is_mistral_common_available() else None)),
        ("mixtral", ("MistralCommonTokenizer" if is_mistral_common_available() else ("LlamaTokenizer" if is_sentencepiece_available() else None), "LlamaTokenizerFast" if is_tokenizers_available() and not is_mistral_common_available() else None)),
        ("mllama", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
        ("mluke", ("MLukeTokenizer" if is_sentencepiece_available() else None, None)),
        ("mm-grounding-dino", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
        ("mobilebert", ("MobileBertTokenizer", "MobileBertTokenizerFast" if is_tokenizers_available() else None)),
        ("modernbert", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
        ("moonshine", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
        ("moshi", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
        ("mpnet", ("MPNetTokenizer", "MPNetTokenizerFast" if is_tokenizers_available() else None)),
        ("mpt", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
        ("mra", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
        ("mt5", ("MT5Tokenizer" if is_sentencepiece_available() else None, "MT5TokenizerFast" if is_tokenizers_available() else None)),
        ("musicgen", ("T5Tokenizer", "T5TokenizerFast" if is_tokenizers_available() else None)),
        ("musicgen_melody", ("T5Tokenizer", "T5TokenizerFast" if is_tokenizers_available() else None)),
        ("mvp", ("MvpTokenizer", "MvpTokenizerFast" if is_tokenizers_available() else None)),
        ("myt5", ("MyT5Tokenizer", None)),
        ("nemotron", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
        ("nezha", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
        ("nllb", ("NllbTokenizer" if is_sentencepiece_available() else None, "NllbTokenizerFast" if is_tokenizers_available() else None)),
        ("nllb-moe", ("NllbTokenizer" if is_sentencepiece_available() else None, "NllbTokenizerFast" if is_tokenizers_available() else None)),
        ("nystromformer", ("AlbertTokenizer" if is_sentencepiece_available() else None, "AlbertTokenizerFast" if is_tokenizers_available() else None)),
        ("olmo", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
        ("olmo2", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
        ("olmo3", (None, "GPT2TokenizerFast" if is_tokenizers_available() else None)),
        ("olmoe", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
        ("omdet-turbo", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)),
        ("oneformer", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)),
        ("openai-gpt", ("OpenAIGPTTokenizer", "OpenAIGPTTokenizerFast" if is_tokenizers_available() else None)),
        ("opt", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
        ("owlv2", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)),
        ("owlvit", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)),
        ("paligemma", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
        ("parakeet", ("ParakeetCTCTokenizer", None)),
        ("pegasus", ("PegasusTokenizer" if is_sentencepiece_available() else None, "PegasusTokenizerFast" if is_tokenizers_available() else None)),
        ("pegasus_x", ("PegasusTokenizer" if is_sentencepiece_available() else None, "PegasusTokenizerFast" if is_tokenizers_available() else None)),
        ("perceiver", ("PerceiverTokenizer", None)),
        ("persimmon", ("LlamaTokenizer" if is_sentencepiece_available() else None, "LlamaTokenizerFast" if is_tokenizers_available() else None)),
        ("phi", ("CodeGenTokenizer", "CodeGenTokenizerFast" if is_tokenizers_available() else None)),
        ("phi3", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
        ("phimoe", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
        ("phobert", ("PhobertTokenizer", None)),
        ("pix2struct", ("T5Tokenizer", "T5TokenizerFast" if is_tokenizers_available() else None)),
        ("pixtral", (None, "MistralCommonTokenizer" if is_mistral_common_available() else ("PreTrainedTokenizerFast" if is_tokenizers_available() else None))),
        ("plbart", ("PLBartTokenizer" if is_sentencepiece_available() else None, None)),
        ("prophetnet", ("ProphetNetTokenizer", None)),
        ("qdqbert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
        ("qwen2", ("Qwen2Tokenizer", "Qwen2TokenizerFast" if is_tokenizers_available() else None)),
        ("qwen2_5_omni", ("Qwen2Tokenizer", "Qwen2TokenizerFast" if is_tokenizers_available() else None)),
        ("qwen2_5_vl", ("Qwen2Tokenizer", "Qwen2TokenizerFast" if is_tokenizers_available() else None)),
        ("qwen2_audio", ("Qwen2Tokenizer", "Qwen2TokenizerFast" if is_tokenizers_available() else None)),
        ("qwen2_moe", ("Qwen2Tokenizer", "Qwen2TokenizerFast" if is_tokenizers_available() else None)),
        ("qwen2_vl", ("Qwen2Tokenizer", "Qwen2TokenizerFast" if is_tokenizers_available() else None)),
        ("qwen3", ("Qwen2Tokenizer", "Qwen2TokenizerFast" if is_tokenizers_available() else None)),
        ("qwen3_moe", ("Qwen2Tokenizer", "Qwen2TokenizerFast" if is_tokenizers_available() else None)),
        ("qwen3_next", ("Qwen2Tokenizer", "Qwen2TokenizerFast" if is_tokenizers_available() else None)),
        ("qwen3_omni_moe", ("Qwen2Tokenizer", "Qwen2TokenizerFast" if is_tokenizers_available() else None)),
        ("qwen3_vl", ("Qwen2Tokenizer", "Qwen2TokenizerFast" if is_tokenizers_available() else None)),
        ("qwen3_vl_moe", ("Qwen2Tokenizer", "Qwen2TokenizerFast" if is_tokenizers_available() else None)),
        ("rag", ("RagTokenizer", None)),
        ("realm", ("RealmTokenizer", "RealmTokenizerFast" if is_tokenizers_available() else None)),
        ("recurrent_gemma", ("GemmaTokenizer" if is_sentencepiece_available() else None, "GemmaTokenizerFast" if is_tokenizers_available() else None)),
        ("reformer", ("ReformerTokenizer" if is_sentencepiece_available() else None, "ReformerTokenizerFast" if is_tokenizers_available() else None)),
        ("rembert", ("RemBertTokenizer" if is_sentencepiece_available() else None, "RemBertTokenizerFast" if is_tokenizers_available() else None)),
        ("retribert", ("RetriBertTokenizer", "RetriBertTokenizerFast" if is_tokenizers_available() else None)),
        ("roberta", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
        ("roberta-prelayernorm", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
        ("roc_bert", ("RoCBertTokenizer", None)),
        ("roformer", ("RoFormerTokenizer", "RoFormerTokenizerFast" if is_tokenizers_available() else None)),
        ("rwkv", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
        ("seamless_m4t", ("SeamlessM4TTokenizer" if is_sentencepiece_available() else None, "SeamlessM4TTokenizerFast" if is_tokenizers_available() else None)),
        ("seamless_m4t_v2", ("SeamlessM4TTokenizer" if is_sentencepiece_available() else None, "SeamlessM4TTokenizerFast" if is_tokenizers_available() else None)),
        ("shieldgemma2", ("GemmaTokenizer" if is_sentencepiece_available() else None, "GemmaTokenizerFast" if is_tokenizers_available() else None)),
        ("siglip", ("SiglipTokenizer" if is_sentencepiece_available() else None, None)),
        ("siglip2", ("GemmaTokenizer" if is_sentencepiece_available() else None, "GemmaTokenizerFast" if is_tokenizers_available() else None)),
        ("smollm3", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
        ("speech_to_text", ("Speech2TextTokenizer" if is_sentencepiece_available() else None, None)),
        ("speech_to_text_2", ("Speech2Text2Tokenizer", None)),
        ("speecht5", ("SpeechT5Tokenizer" if is_sentencepiece_available() else None, None)),
        ("splinter", ("SplinterTokenizer", "SplinterTokenizerFast")),
        ("squeezebert", ("SqueezeBertTokenizer", "SqueezeBertTokenizerFast" if is_tokenizers_available() else None)),
        ("stablelm", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
        ("starcoder2", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
        ("switch_transformers", ("T5Tokenizer" if is_sentencepiece_available() else None, "T5TokenizerFast" if is_tokenizers_available() else None)),
        ("t5", ("T5Tokenizer" if is_sentencepiece_available() else None, "T5TokenizerFast" if is_tokenizers_available() else None)),
        ("t5gemma", ("GemmaTokenizer" if is_sentencepiece_available() else None, "GemmaTokenizerFast" if is_tokenizers_available() else None)),
        ("tapas", ("TapasTokenizer", None)),
        ("tapex", ("TapexTokenizer", None)),
        ("transfo-xl", ("TransfoXLTokenizer", None)),
        ("tvp", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
        ("udop", ("UdopTokenizer" if is_sentencepiece_available() else None, "UdopTokenizerFast" if is_tokenizers_available() else None)),
        ("umt5", ("T5Tokenizer" if is_sentencepiece_available() else None, "T5TokenizerFast" if is_tokenizers_available() else None)),
        ("video_llava", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
        ("vilt", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
        ("vipllava", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
        ("visual_bert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
        ("vits", ("VitsTokenizer", None)),
| 727 |
+
(
|
| 728 |
+
"voxtral",
|
| 729 |
+
(
|
| 730 |
+
"MistralCommonTokenizer" if is_mistral_common_available() else None,
|
| 731 |
+
"PreTrainedTokenizerFast" if is_tokenizers_available() and not is_mistral_common_available() else None,
|
| 732 |
+
),
|
| 733 |
+
),
|
| 734 |
+
("wav2vec2", ("Wav2Vec2CTCTokenizer", None)),
|
| 735 |
+
("wav2vec2-bert", ("Wav2Vec2CTCTokenizer", None)),
|
| 736 |
+
("wav2vec2-conformer", ("Wav2Vec2CTCTokenizer", None)),
|
| 737 |
+
("wav2vec2_phoneme", ("Wav2Vec2PhonemeCTCTokenizer", None)),
|
| 738 |
+
("whisper", ("WhisperTokenizer", "WhisperTokenizerFast" if is_tokenizers_available() else None)),
|
| 739 |
+
("xclip", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)),
|
| 740 |
+
(
|
| 741 |
+
"xglm",
|
| 742 |
+
(
|
| 743 |
+
"XGLMTokenizer" if is_sentencepiece_available() else None,
|
| 744 |
+
"XGLMTokenizerFast" if is_tokenizers_available() else None,
|
| 745 |
+
),
|
| 746 |
+
),
|
| 747 |
+
("xlm", ("XLMTokenizer", None)),
|
| 748 |
+
("xlm-prophetnet", ("XLMProphetNetTokenizer" if is_sentencepiece_available() else None, None)),
|
| 749 |
+
(
|
| 750 |
+
"xlm-roberta",
|
| 751 |
+
(
|
| 752 |
+
"XLMRobertaTokenizer" if is_sentencepiece_available() else None,
|
| 753 |
+
"XLMRobertaTokenizerFast" if is_tokenizers_available() else None,
|
| 754 |
+
),
|
| 755 |
+
),
|
| 756 |
+
(
|
| 757 |
+
"xlm-roberta-xl",
|
| 758 |
+
(
|
| 759 |
+
"XLMRobertaTokenizer" if is_sentencepiece_available() else None,
|
| 760 |
+
"XLMRobertaTokenizerFast" if is_tokenizers_available() else None,
|
| 761 |
+
),
|
| 762 |
+
),
|
| 763 |
+
(
|
| 764 |
+
"xlnet",
|
| 765 |
+
(
|
| 766 |
+
"XLNetTokenizer" if is_sentencepiece_available() else None,
|
| 767 |
+
"XLNetTokenizerFast" if is_tokenizers_available() else None,
|
| 768 |
+
),
|
| 769 |
+
),
|
| 770 |
+
("xlstm", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
|
| 771 |
+
(
|
| 772 |
+
"xmod",
|
| 773 |
+
(
|
| 774 |
+
"XLMRobertaTokenizer" if is_sentencepiece_available() else None,
|
| 775 |
+
"XLMRobertaTokenizerFast" if is_tokenizers_available() else None,
|
| 776 |
+
),
|
| 777 |
+
),
|
| 778 |
+
(
|
| 779 |
+
"yoso",
|
| 780 |
+
(
|
| 781 |
+
"AlbertTokenizer" if is_sentencepiece_available() else None,
|
| 782 |
+
"AlbertTokenizerFast" if is_tokenizers_available() else None,
|
| 783 |
+
),
|
| 784 |
+
),
|
| 785 |
+
(
|
| 786 |
+
"zamba",
|
| 787 |
+
(
|
| 788 |
+
"LlamaTokenizer" if is_sentencepiece_available() else None,
|
| 789 |
+
"LlamaTokenizerFast" if is_tokenizers_available() else None,
|
| 790 |
+
),
|
| 791 |
+
),
|
| 792 |
+
(
|
| 793 |
+
"zamba2",
|
| 794 |
+
(
|
| 795 |
+
"LlamaTokenizer" if is_sentencepiece_available() else None,
|
| 796 |
+
"LlamaTokenizerFast" if is_tokenizers_available() else None,
|
| 797 |
+
),
|
| 798 |
+
),
|
| 799 |
+
]
|
| 800 |
+
)
|
| 801 |
+
|
| 802 |
+
TOKENIZER_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TOKENIZER_MAPPING_NAMES)
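# Illustrative lookup (a hypothetical snippet, not part of this module): indexing the
# lazy mapping with a config class imports and returns the (slow, fast) tokenizer
# classes on first access, e.g.
#     from transformers import BertConfig
#     slow_cls, fast_cls = TOKENIZER_MAPPING[BertConfig]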

CONFIG_TO_TYPE = {v: k for k, v in CONFIG_MAPPING_NAMES.items()}


def tokenizer_class_from_name(class_name: str) -> Union[type[Any], None]:
    if class_name == "PreTrainedTokenizerFast":
        return PreTrainedTokenizerFast

    for module_name, tokenizers in TOKENIZER_MAPPING_NAMES.items():
        if class_name in tokenizers:
            module_name = model_type_to_module_name(module_name)
            if module_name in ["mistral", "mixtral", "ministral"] and class_name == "MistralCommonTokenizer":
                module = importlib.import_module(".tokenization_mistral_common", "transformers")
            else:
                module = importlib.import_module(f".{module_name}", "transformers.models")
            try:
                return getattr(module, class_name)
            except AttributeError:
                continue

    for tokenizers in TOKENIZER_MAPPING._extra_content.values():
        for tokenizer in tokenizers:
            if getattr(tokenizer, "__name__", None) == class_name:
                return tokenizer

    # We did not find the class, but maybe it's because a dep is missing. In that case, the class will be in the main
    # init and we return the proper dummy to get an appropriate error message.
    main_module = importlib.import_module("transformers")
    if hasattr(main_module, class_name):
        return getattr(main_module, class_name)

    return None
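# Illustrative only (hypothetical call, not part of the module): resolves a class-name
# string to the class object, or None when the name is unknown, e.g.
#     tokenizer_class_from_name("BertTokenizerFast")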


def get_tokenizer_config(
    pretrained_model_name_or_path: Union[str, os.PathLike[str]],
    cache_dir: Optional[Union[str, os.PathLike[str]]] = None,
    force_download: bool = False,
    resume_download: Optional[bool] = None,
    proxies: Optional[dict[str, str]] = None,
    token: Optional[Union[bool, str]] = None,
    revision: Optional[str] = None,
    local_files_only: bool = False,
    subfolder: str = "",
    **kwargs,
) -> dict[str, Any]:
    """
    Loads the tokenizer configuration from a pretrained model tokenizer configuration.

    Args:
        pretrained_model_name_or_path (`str` or `os.PathLike`):
            This can be either:

            - a string, the *model id* of a pretrained model configuration hosted inside a model repo on
              huggingface.co.
            - a path to a *directory* containing a configuration file saved using the
              [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.

        cache_dir (`str` or `os.PathLike`, *optional*):
            Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
            cache should not be used.
        force_download (`bool`, *optional*, defaults to `False`):
            Whether or not to force (re-)downloading the configuration files and overriding the cached versions if
            they exist.
        resume_download:
            Deprecated and ignored. All downloads are now resumed by default when possible.
            Will be removed in v5 of Transformers.
        proxies (`dict[str, str]`, *optional*):
            A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
            'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
        token (`str` or *bool*, *optional*):
            The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
            when running `hf auth login` (stored in `~/.huggingface`).
        revision (`str`, *optional*, defaults to `"main"`):
            The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
            git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
            identifier allowed by git.
        local_files_only (`bool`, *optional*, defaults to `False`):
            If `True`, will only try to load the tokenizer configuration from local files.
        subfolder (`str`, *optional*, defaults to `""`):
            In case the tokenizer config is located inside a subfolder of the model repo on huggingface.co, you can
            specify the folder name here.

    <Tip>

    Passing `token=True` is required when you want to use a private model.

    </Tip>

    Returns:
        `dict`: The configuration of the tokenizer.

    Examples:

    ```python
    # Download configuration from huggingface.co and cache.
    tokenizer_config = get_tokenizer_config("google-bert/bert-base-uncased")
    # This model does not have a tokenizer config so the result will be an empty dict.
    tokenizer_config = get_tokenizer_config("FacebookAI/xlm-roberta-base")

    # Save a pretrained tokenizer locally and you can reload its config
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-cased")
    tokenizer.save_pretrained("tokenizer-test")
    tokenizer_config = get_tokenizer_config("tokenizer-test")
    ```"""
    use_auth_token = kwargs.pop("use_auth_token", None)
    if use_auth_token is not None:
        warnings.warn(
            "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
            FutureWarning,
        )
        if token is not None:
            raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
        token = use_auth_token

    commit_hash = kwargs.get("_commit_hash")
    resolved_config_file = cached_file(
        pretrained_model_name_or_path,
        TOKENIZER_CONFIG_FILE,
        cache_dir=cache_dir,
        force_download=force_download,
        resume_download=resume_download,
        proxies=proxies,
        token=token,
        revision=revision,
        local_files_only=local_files_only,
        subfolder=subfolder,
        _raise_exceptions_for_gated_repo=False,
        _raise_exceptions_for_missing_entries=False,
        _raise_exceptions_for_connection_errors=False,
        _commit_hash=commit_hash,
    )
    if resolved_config_file is None:
        logger.info("Could not locate the tokenizer configuration file, will try to use the model config instead.")
        return {}
    commit_hash = extract_commit_hash(resolved_config_file, commit_hash)

    with open(resolved_config_file, encoding="utf-8") as reader:
        result = json.load(reader)
    result["_commit_hash"] = commit_hash
    return result


class AutoTokenizer:
    r"""
    This is a generic tokenizer class that will be instantiated as one of the tokenizer classes of the library when
    created with the [`AutoTokenizer.from_pretrained`] class method.

    This class cannot be instantiated directly using `__init__()` (throws an error).
    """

    def __init__(self):
        raise OSError(
            "AutoTokenizer is designed to be instantiated "
            "using the `AutoTokenizer.from_pretrained(pretrained_model_name_or_path)` method."
        )

    @classmethod
    @replace_list_option_in_docstrings(TOKENIZER_MAPPING_NAMES)
    def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
        r"""
        Instantiate one of the tokenizer classes of the library from a pretrained model vocabulary.

        The tokenizer class to instantiate is selected based on the `model_type` property of the config object (either
        passed as an argument or loaded from `pretrained_model_name_or_path` if possible), or when it's missing, by
        falling back to using pattern matching on `pretrained_model_name_or_path`:

        List options

        Params:
            pretrained_model_name_or_path (`str` or `os.PathLike`):
                Can be either:

                - A string, the *model id* of a predefined tokenizer hosted inside a model repo on huggingface.co.
                - A path to a *directory* containing vocabulary files required by the tokenizer, for instance saved
                  using the [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.
                - A path or url to a single saved vocabulary file if and only if the tokenizer only requires a
                  single vocabulary file (like Bert or XLNet), e.g.: `./my_model_directory/vocab.txt`. (Not
                  applicable to all derived classes)
            inputs (additional positional arguments, *optional*):
                Will be passed along to the Tokenizer `__init__()` method.
            config ([`PretrainedConfig`], *optional*):
                The configuration object used to determine the tokenizer class to instantiate.
            cache_dir (`str` or `os.PathLike`, *optional*):
                Path to a directory in which a downloaded pretrained model configuration should be cached if the
                standard cache should not be used.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files and override
                the cached versions if they exist.
            resume_download:
                Deprecated and ignored. All downloads are now resumed by default when possible.
                Will be removed in v5 of Transformers.
            proxies (`dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
                git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
                identifier allowed by git.
            subfolder (`str`, *optional*):
                In case the relevant files are located inside a subfolder of the model repo on huggingface.co (e.g. for
                facebook/rag-token-base), specify it here.
            use_fast (`bool`, *optional*, defaults to `True`):
                Use a [fast Rust-based tokenizer](https://huggingface.co/docs/tokenizers/index) if it is supported for
                a given model. If a fast tokenizer is not available for a given model, a normal Python-based tokenizer
                is returned instead.
            tokenizer_type (`str`, *optional*):
                Tokenizer type to be loaded.
            trust_remote_code (`bool`, *optional*, defaults to `False`):
                Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
                should only be set to `True` for repositories you trust and in which you have read the code, as it will
                execute code present on the Hub on your local machine.
            kwargs (additional keyword arguments, *optional*):
                Will be passed to the Tokenizer `__init__()` method. Can be used to set special tokens like
                `bos_token`, `eos_token`, `unk_token`, `sep_token`, `pad_token`, `cls_token`, `mask_token`,
                `additional_special_tokens`. See parameters in the `__init__()` for more details.

        Examples:

        ```python
        >>> from transformers import AutoTokenizer

        >>> # Download vocabulary from huggingface.co and cache.
        >>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")

        >>> # Download vocabulary from huggingface.co (user-uploaded) and cache.
        >>> tokenizer = AutoTokenizer.from_pretrained("dbmdz/bert-base-german-cased")

        >>> # If vocabulary files are in a directory (e.g. tokenizer was saved using *save_pretrained('./test/saved_model/')*)
        >>> # tokenizer = AutoTokenizer.from_pretrained("./test/bert_saved_model/")

        >>> # Download vocabulary from huggingface.co and define model-specific arguments
        >>> tokenizer = AutoTokenizer.from_pretrained("FacebookAI/roberta-base", add_prefix_space=True)
        ```"""
        use_auth_token = kwargs.pop("use_auth_token", None)
        if use_auth_token is not None:
            warnings.warn(
                "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
                FutureWarning,
            )
            if kwargs.get("token") is not None:
                raise ValueError(
                    "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
                )
            kwargs["token"] = use_auth_token

        config = kwargs.pop("config", None)
        kwargs["_from_auto"] = True

        use_fast = kwargs.pop("use_fast", True)
        tokenizer_type = kwargs.pop("tokenizer_type", None)
        trust_remote_code = kwargs.pop("trust_remote_code", None)
        gguf_file = kwargs.get("gguf_file")

        # First, let's see whether the tokenizer_type is passed so that we can leverage it
        if tokenizer_type is not None:
            tokenizer_class = None
            tokenizer_class_tuple = TOKENIZER_MAPPING_NAMES.get(tokenizer_type, None)

            if tokenizer_class_tuple is None:
                raise ValueError(
                    f"Passed `tokenizer_type` {tokenizer_type} does not exist. `tokenizer_type` should be one of "
                    f"{', '.join(c for c in TOKENIZER_MAPPING_NAMES)}."
                )

            tokenizer_class_name, tokenizer_fast_class_name = tokenizer_class_tuple

            if use_fast:
                if tokenizer_fast_class_name is not None:
                    tokenizer_class = tokenizer_class_from_name(tokenizer_fast_class_name)
                else:
                    logger.warning(
                        "`use_fast` is set to `True` but the tokenizer class does not have a fast version. "
                        "Falling back to the slow version."
                    )
            if tokenizer_class is None:
                tokenizer_class = tokenizer_class_from_name(tokenizer_class_name)

            if tokenizer_class is None:
                raise ValueError(f"Tokenizer class {tokenizer_class_name} is not currently imported.")

            return tokenizer_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)

        # Next, let's try to use the tokenizer_config file to get the tokenizer class.
        tokenizer_config = get_tokenizer_config(pretrained_model_name_or_path, **kwargs)
        if "_commit_hash" in tokenizer_config:
            kwargs["_commit_hash"] = tokenizer_config["_commit_hash"]
        config_tokenizer_class = tokenizer_config.get("tokenizer_class")
        tokenizer_auto_map = None
        if "auto_map" in tokenizer_config:
            if isinstance(tokenizer_config["auto_map"], (tuple, list)):
                # Legacy format for dynamic tokenizers
                tokenizer_auto_map = tokenizer_config["auto_map"]
            else:
                tokenizer_auto_map = tokenizer_config["auto_map"].get("AutoTokenizer", None)

        # If that did not work, let's try to use the config.
        if config_tokenizer_class is None:
            if not isinstance(config, PretrainedConfig):
                if gguf_file:
                    gguf_path = cached_file(pretrained_model_name_or_path, gguf_file, **kwargs)
                    config_dict = load_gguf_checkpoint(gguf_path, return_tensors=False)["config"]
                    config = AutoConfig.for_model(**config_dict)
                else:
                    config = AutoConfig.from_pretrained(
                        pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
                    )
            config_tokenizer_class = config.tokenizer_class
            if hasattr(config, "auto_map") and "AutoTokenizer" in config.auto_map:
                tokenizer_auto_map = config.auto_map["AutoTokenizer"]

        has_remote_code = tokenizer_auto_map is not None
        has_local_code = type(config) in TOKENIZER_MAPPING or (
            config_tokenizer_class is not None
            and (
                tokenizer_class_from_name(config_tokenizer_class) is not None
                or tokenizer_class_from_name(config_tokenizer_class + "Fast") is not None
            )
        )
        if has_remote_code:
            if use_fast and tokenizer_auto_map[1] is not None:
                class_ref = tokenizer_auto_map[1]
            else:
                class_ref = tokenizer_auto_map[0]
            if "--" in class_ref:
                upstream_repo = class_ref.split("--")[0]
            else:
                upstream_repo = None
            trust_remote_code = resolve_trust_remote_code(
                trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code, upstream_repo
            )

        if has_remote_code and trust_remote_code:
            tokenizer_class = get_class_from_dynamic_module(class_ref, pretrained_model_name_or_path, **kwargs)
            _ = kwargs.pop("code_revision", None)
            tokenizer_class.register_for_auto_class()
            return tokenizer_class.from_pretrained(
                pretrained_model_name_or_path, *inputs, trust_remote_code=trust_remote_code, **kwargs
            )
        elif config_tokenizer_class is not None:
            tokenizer_class = None
            if use_fast and not config_tokenizer_class.endswith("Fast"):
                tokenizer_class_candidate = f"{config_tokenizer_class}Fast"
                tokenizer_class = tokenizer_class_from_name(tokenizer_class_candidate)
            if tokenizer_class is None:
                tokenizer_class_candidate = config_tokenizer_class
                tokenizer_class = tokenizer_class_from_name(tokenizer_class_candidate)
            if tokenizer_class is None:
                raise ValueError(
                    f"Tokenizer class {tokenizer_class_candidate} does not exist or is not currently imported."
                )
            return tokenizer_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)

        # Otherwise we have to be creative.
        # if model is an encoder decoder, the encoder tokenizer class is used by default
        if isinstance(config, EncoderDecoderConfig):
            if type(config.decoder) is not type(config.encoder):
                logger.warning(
                    f"The encoder model config class: {config.encoder.__class__} is different from the decoder model "
                    f"config class: {config.decoder.__class__}. It is not recommended to use the "
                    "`AutoTokenizer.from_pretrained()` method in this case. Please use the encoder and decoder "
                    "specific tokenizer classes."
                )
            config = config.encoder

        model_type = config_class_to_model_type(type(config).__name__)
        if model_type is not None:
            tokenizer_class_py, tokenizer_class_fast = TOKENIZER_MAPPING[type(config)]

            if tokenizer_class_fast and (use_fast or tokenizer_class_py is None):
                return tokenizer_class_fast.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
            else:
                if tokenizer_class_py is not None:
                    return tokenizer_class_py.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
                else:
                    raise ValueError(
                        "This tokenizer cannot be instantiated. Please make sure you have `sentencepiece` installed "
                        "in order to use this tokenizer."
                    )

        raise ValueError(
            f"Unrecognized configuration class {config.__class__} to build an AutoTokenizer.\n"
            f"Model type should be one of {', '.join(c.__name__ for c in TOKENIZER_MAPPING)}."
        )

    @staticmethod
    def register(config_class, slow_tokenizer_class=None, fast_tokenizer_class=None, exist_ok=False):
        """
        Register a new tokenizer in this mapping.

        Args:
            config_class ([`PretrainedConfig`]):
                The configuration corresponding to the model to register.
            slow_tokenizer_class ([`PretrainedTokenizer`], *optional*):
                The slow tokenizer to register.
            fast_tokenizer_class ([`PretrainedTokenizerFast`], *optional*):
                The fast tokenizer to register.
        """
        if slow_tokenizer_class is None and fast_tokenizer_class is None:
            raise ValueError("You need to pass either a `slow_tokenizer_class` or a `fast_tokenizer_class`.")
        if slow_tokenizer_class is not None and issubclass(slow_tokenizer_class, PreTrainedTokenizerFast):
            raise ValueError("You passed a fast tokenizer in the `slow_tokenizer_class`.")
        if fast_tokenizer_class is not None and issubclass(fast_tokenizer_class, PreTrainedTokenizer):
            raise ValueError("You passed a slow tokenizer in the `fast_tokenizer_class`.")

        if (
            slow_tokenizer_class is not None
            and fast_tokenizer_class is not None
            and issubclass(fast_tokenizer_class, PreTrainedTokenizerFast)
            and fast_tokenizer_class.slow_tokenizer_class != slow_tokenizer_class
        ):
            raise ValueError(
                "The fast tokenizer class you are passing has a `slow_tokenizer_class` attribute that is not "
                "consistent with the slow tokenizer class you passed (fast tokenizer has "
                f"{fast_tokenizer_class.slow_tokenizer_class} and you passed {slow_tokenizer_class}). Fix one of those "
                "so they match!"
            )

        # Avoid resetting a set slow/fast tokenizer if we are passing just the other ones.
        if config_class in TOKENIZER_MAPPING._extra_content:
            existing_slow, existing_fast = TOKENIZER_MAPPING[config_class]
            if slow_tokenizer_class is None:
                slow_tokenizer_class = existing_slow
            if fast_tokenizer_class is None:
                fast_tokenizer_class = existing_fast

        TOKENIZER_MAPPING.register(config_class, (slow_tokenizer_class, fast_tokenizer_class), exist_ok=exist_ok)


__all__ = ["TOKENIZER_MAPPING", "AutoTokenizer"]
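A minimal usage sketch of the two public entry points above. The model id is taken from the docstring examples; `MyConfig` and `MySlowTokenizer` are hypothetical stand-ins for a custom registration:

```python
from transformers import AutoTokenizer

# Resolution prefers the fast (Rust) tokenizer; use_fast=False forces the Python one.
tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased", use_fast=False)

# tokenizer_type skips config resolution and indexes TOKENIZER_MAPPING_NAMES directly.
tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased", tokenizer_type="bert")

# Registering a custom pair (MyConfig / MySlowTokenizer are hypothetical classes):
# AutoTokenizer.register(MyConfig, slow_tokenizer_class=MySlowTokenizer)
```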
venv/lib/python3.13/site-packages/transformers/models/auto/video_processing_auto.py
ADDED
@@ -0,0 +1,393 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""AutoVideoProcessor class."""

import importlib
import json
import os
import warnings
from collections import OrderedDict
from typing import TYPE_CHECKING, Optional, Union

# Build the list of all video processors
from ...configuration_utils import PretrainedConfig
from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
from ...utils import CONFIG_NAME, VIDEO_PROCESSOR_NAME, cached_file, is_torchvision_available, logging
from ...utils.import_utils import requires
from ...video_processing_utils import BaseVideoProcessor
from .auto_factory import _LazyAutoMapping
from .configuration_auto import (
    CONFIG_MAPPING_NAMES,
    AutoConfig,
    model_type_to_module_name,
    replace_list_option_in_docstrings,
)


logger = logging.get_logger(__name__)


if TYPE_CHECKING:
    # This significantly improves completion suggestion performance when
    # the transformers package is used with Microsoft's Pylance language server.
    VIDEO_PROCESSOR_MAPPING_NAMES: OrderedDict[str, tuple[Optional[str], Optional[str]]] = OrderedDict()
else:
    VIDEO_PROCESSOR_MAPPING_NAMES = OrderedDict(
        [
            ("glm4v", "Glm4vVideoProcessor"),
            ("instructblip", "InstructBlipVideoVideoProcessor"),
            ("instructblipvideo", "InstructBlipVideoVideoProcessor"),
            ("internvl", "InternVLVideoProcessor"),
            ("llava_next_video", "LlavaNextVideoVideoProcessor"),
            ("llava_onevision", "LlavaOnevisionVideoProcessor"),
            ("perception_lm", "PerceptionLMVideoProcessor"),
            ("qwen2_5_omni", "Qwen2VLVideoProcessor"),
            ("qwen2_5_vl", "Qwen2VLVideoProcessor"),
            ("qwen2_vl", "Qwen2VLVideoProcessor"),
            ("qwen3_omni_moe", "Qwen2VLVideoProcessor"),
            ("qwen3_vl", "Qwen3VLVideoProcessor"),
            ("qwen3_vl_moe", "Qwen3VLVideoProcessor"),
            ("sam2_video", "Sam2VideoVideoProcessor"),
            ("smolvlm", "SmolVLMVideoProcessor"),
            ("video_llava", "VideoLlavaVideoProcessor"),
            ("vjepa2", "VJEPA2VideoProcessor"),
        ]
    )

for model_type, video_processors in VIDEO_PROCESSOR_MAPPING_NAMES.items():
    fast_video_processor_class = video_processors

    # If torchvision is not available, we set the class to None
    if not is_torchvision_available():
        fast_video_processor_class = None

    VIDEO_PROCESSOR_MAPPING_NAMES[model_type] = fast_video_processor_class

VIDEO_PROCESSOR_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, VIDEO_PROCESSOR_MAPPING_NAMES)


def video_processor_class_from_name(class_name: str):
    for module_name, extractors in VIDEO_PROCESSOR_MAPPING_NAMES.items():
        if class_name in extractors:
            module_name = model_type_to_module_name(module_name)

            module = importlib.import_module(f".{module_name}", "transformers.models")
            try:
                return getattr(module, class_name)
            except AttributeError:
                continue

    for extractor in VIDEO_PROCESSOR_MAPPING._extra_content.values():
        if getattr(extractor, "__name__", None) == class_name:
            return extractor

    # We did not find the class, but maybe it's because a dep is missing. In that case, the class will be in the main
    # init and we return the proper dummy to get an appropriate error message.
    main_module = importlib.import_module("transformers")
    if hasattr(main_module, class_name):
        return getattr(main_module, class_name)

    return None


def get_video_processor_config(
    pretrained_model_name_or_path: Union[str, os.PathLike],
    cache_dir: Optional[Union[str, os.PathLike]] = None,
    force_download: bool = False,
    resume_download: Optional[bool] = None,
    proxies: Optional[dict[str, str]] = None,
    token: Optional[Union[bool, str]] = None,
    revision: Optional[str] = None,
    local_files_only: bool = False,
    **kwargs,
):
    """
    Loads the video processor configuration from a pretrained model video processor configuration.

    Args:
        pretrained_model_name_or_path (`str` or `os.PathLike`):
            This can be either:

            - a string, the *model id* of a pretrained model configuration hosted inside a model repo on
              huggingface.co.
            - a path to a *directory* containing a configuration file saved using the
              [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.

        cache_dir (`str` or `os.PathLike`, *optional*):
            Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
            cache should not be used.
        force_download (`bool`, *optional*, defaults to `False`):
            Whether or not to force (re-)downloading the configuration files and overriding the cached versions if
            they exist.
        resume_download:
            Deprecated and ignored. All downloads are now resumed by default when possible.
            Will be removed in v5 of Transformers.
        proxies (`dict[str, str]`, *optional*):
            A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
            'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
        token (`str` or *bool*, *optional*):
            The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
            when running `hf auth login` (stored in `~/.huggingface`).
        revision (`str`, *optional*, defaults to `"main"`):
            The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
            git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
            identifier allowed by git.
        local_files_only (`bool`, *optional*, defaults to `False`):
            If `True`, will only try to load the video processor configuration from local files.

    <Tip>

    Passing `token=True` is required when you want to use a private model.

    </Tip>

    Returns:
        `Dict`: The configuration of the video processor.

    Examples:

    ```python
    # Download configuration from huggingface.co and cache.
    video_processor_config = get_video_processor_config("llava-hf/llava-onevision-qwen2-0.5b-ov-hf")
    # This model does not have a video processor config so the result will be an empty dict.
    video_processor_config = get_video_processor_config("FacebookAI/xlm-roberta-base")

    # Save a pretrained video processor locally and you can reload its config
    from transformers import AutoVideoProcessor

    video_processor = AutoVideoProcessor.from_pretrained("llava-hf/llava-onevision-qwen2-0.5b-ov-hf")
    video_processor.save_pretrained("video-processor-test")
    video_processor_config = get_video_processor_config("video-processor-test")
    ```"""
    use_auth_token = kwargs.pop("use_auth_token", None)
    if use_auth_token is not None:
        warnings.warn(
            "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
            FutureWarning,
        )
        if token is not None:
            raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
        token = use_auth_token

    resolved_config_file = cached_file(
        pretrained_model_name_or_path,
        VIDEO_PROCESSOR_NAME,
        cache_dir=cache_dir,
        force_download=force_download,
        resume_download=resume_download,
        proxies=proxies,
        token=token,
        revision=revision,
        local_files_only=local_files_only,
    )
    if resolved_config_file is None:
        logger.info(
            "Could not locate the video processor configuration file, will try to use the model config instead."
        )
        return {}

    with open(resolved_config_file, encoding="utf-8") as reader:
        return json.load(reader)


@requires(backends=("vision", "torchvision"))
class AutoVideoProcessor:
    r"""
    This is a generic video processor class that will be instantiated as one of the video processor classes of the
    library when created with the [`AutoVideoProcessor.from_pretrained`] class method.

    This class cannot be instantiated directly using `__init__()` (throws an error).
    """

    def __init__(self):
        raise OSError(
            "AutoVideoProcessor is designed to be instantiated "
            "using the `AutoVideoProcessor.from_pretrained(pretrained_model_name_or_path)` method."
        )

    @classmethod
    @replace_list_option_in_docstrings(VIDEO_PROCESSOR_MAPPING_NAMES)
    def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
        r"""
        Instantiate one of the video processor classes of the library from a pretrained model vocabulary.

        The video processor class to instantiate is selected based on the `model_type` property of the config object
        (either passed as an argument or loaded from `pretrained_model_name_or_path` if possible), or when it's
        missing, by falling back to using pattern matching on `pretrained_model_name_or_path`:

        List options

        Params:
            pretrained_model_name_or_path (`str` or `os.PathLike`):
                This can be either:

                - a string, the *model id* of a pretrained video_processor hosted inside a model repo on
                  huggingface.co.
                - a path to a *directory* containing a video processor file saved using the
                  [`~video_processing_utils.BaseVideoProcessor.save_pretrained`] method, e.g.,
                  `./my_model_directory/`.
                - a path or url to a saved video processor JSON *file*, e.g.,
                  `./my_model_directory/preprocessor_config.json`.
            cache_dir (`str` or `os.PathLike`, *optional*):
                Path to a directory in which a downloaded pretrained model video processor should be cached if the
                standard cache should not be used.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force (re-)downloading the video processor files and overriding the cached versions
                if they exist.
            resume_download:
                Deprecated and ignored. All downloads are now resumed by default when possible.
                Will be removed in v5 of Transformers.
            proxies (`dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            token (`str` or *bool*, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
                when running `hf auth login` (stored in `~/.huggingface`).
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
                git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
                identifier allowed by git.
            return_unused_kwargs (`bool`, *optional*, defaults to `False`):
                If `False`, then this function returns just the final video processor object. If `True`, then this
                function returns a `Tuple(video_processor, unused_kwargs)` where *unused_kwargs* is a dictionary
                consisting of the key/value pairs whose keys are not video processor attributes: i.e., the part of
                `kwargs` which has not been used to update `video_processor` and is otherwise ignored.
            trust_remote_code (`bool`, *optional*, defaults to `False`):
                Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
                should only be set to `True` for repositories you trust and in which you have read the code, as it
                will execute code present on the Hub on your local machine.
            kwargs (`dict[str, Any]`, *optional*):
                The values in kwargs of any keys which are video processor attributes will be used to override the
                loaded values. Behavior concerning key/value pairs whose keys are *not* video processor attributes is
                controlled by the `return_unused_kwargs` keyword parameter.

        <Tip>

        Passing `token=True` is required when you want to use a private model.

        </Tip>

        Examples:

        ```python
        >>> from transformers import AutoVideoProcessor

        >>> # Download video processor from huggingface.co and cache.
        >>> video_processor = AutoVideoProcessor.from_pretrained("llava-hf/llava-onevision-qwen2-0.5b-ov-hf")

        >>> # If video processor files are in a directory (e.g. video processor was saved using *save_pretrained('./test/saved_model/')*)
        >>> # video_processor = AutoVideoProcessor.from_pretrained("./test/saved_model/")
        ```"""
        use_auth_token = kwargs.pop("use_auth_token", None)
        if use_auth_token is not None:
            warnings.warn(
                "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
                FutureWarning,
            )
            if kwargs.get("token") is not None:
                raise ValueError(
                    "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
                )
            kwargs["token"] = use_auth_token

        config = kwargs.pop("config", None)
        trust_remote_code = kwargs.pop("trust_remote_code", None)
        kwargs["_from_auto"] = True

        config_dict, _ = BaseVideoProcessor.get_video_processor_dict(pretrained_model_name_or_path, **kwargs)
        video_processor_class = config_dict.get("video_processor_type", None)
        video_processor_auto_map = None
        if "AutoVideoProcessor" in config_dict.get("auto_map", {}):
            video_processor_auto_map = config_dict["auto_map"]["AutoVideoProcessor"]

        # If we still don't have the video processor class, check if we're loading from a previous image processor config
        # and if so, infer the video processor class from there.
        if video_processor_class is None and video_processor_auto_map is None:
            image_processor_class = config_dict.pop("image_processor_type", None)
            if image_processor_class is not None:
                video_processor_class_inferred = image_processor_class.replace("ImageProcessor", "VideoProcessor")

                # Some models have different image processors, e.g. InternVL uses GotOCRImageProcessor
                # We cannot use GotOCRVideoProcessor when falling back for BC and should try to infer from config later on
                if video_processor_class_inferred in VIDEO_PROCESSOR_MAPPING_NAMES.values():
                    video_processor_class = video_processor_class_inferred
            if "AutoImageProcessor" in config_dict.get("auto_map", {}):
                image_processor_auto_map = config_dict["auto_map"]["AutoImageProcessor"]
                video_processor_auto_map = image_processor_auto_map.replace("ImageProcessor", "VideoProcessor")

        # If we don't find the video processor class in the video processor config, let's try the model config.
        if video_processor_class is None and video_processor_auto_map is None:
            if not isinstance(config, PretrainedConfig):
                config = AutoConfig.from_pretrained(
                    pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
                )
            # It could be in `config.video_processor_type`
            video_processor_class = getattr(config, "video_processor_type", None)
            if hasattr(config, "auto_map") and "AutoVideoProcessor" in config.auto_map:
                video_processor_auto_map = config.auto_map["AutoVideoProcessor"]

        if video_processor_class is not None:
            video_processor_class = video_processor_class_from_name(video_processor_class)

        has_remote_code = video_processor_auto_map is not None
        has_local_code = video_processor_class is not None or type(config) in VIDEO_PROCESSOR_MAPPING
        trust_remote_code = resolve_trust_remote_code(
            trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code
        )

        if has_remote_code and trust_remote_code:
            class_ref = video_processor_auto_map
            video_processor_class = get_class_from_dynamic_module(class_ref, pretrained_model_name_or_path, **kwargs)
            _ = kwargs.pop("code_revision", None)
            video_processor_class.register_for_auto_class()
            return video_processor_class.from_dict(config_dict, **kwargs)
        elif video_processor_class is not None:
            return video_processor_class.from_dict(config_dict, **kwargs)
        # Last try: we use the VIDEO_PROCESSOR_MAPPING.
        elif type(config) in VIDEO_PROCESSOR_MAPPING:
            video_processor_class = VIDEO_PROCESSOR_MAPPING[type(config)]

            if video_processor_class is not None:
                return video_processor_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
            else:
                raise ValueError(
                    "This video processor cannot be instantiated. Please make sure you have `torchvision` installed."
                )

        raise ValueError(
            f"Unrecognized video processor in {pretrained_model_name_or_path}. Should have a "
            f"`video_processor_type` key in its {VIDEO_PROCESSOR_NAME} of {CONFIG_NAME}, or one of the following "
            f"`model_type` keys in its {CONFIG_NAME}: {', '.join(c for c in VIDEO_PROCESSOR_MAPPING_NAMES)}"
        )

    @staticmethod
    def register(
        config_class,
        video_processor_class,
        exist_ok=False,
    ):
        """
        Register a new video processor for this class.

        Args:
            config_class ([`PretrainedConfig`]):
                The configuration corresponding to the model to register.
            video_processor_class ([`BaseVideoProcessor`]):
                The video processor to register.
        """
        VIDEO_PROCESSOR_MAPPING.register(config_class, video_processor_class, exist_ok=exist_ok)


__all__ = ["VIDEO_PROCESSOR_MAPPING", "AutoVideoProcessor"]
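A short usage sketch for the class above; the checkpoint is the one already used in the docstring examples, and `MyConfig` / `MyVideoProcessor` are hypothetical:

```python
from transformers import AutoVideoProcessor

# Resolution order implemented above: `video_processor_type` in the processor config,
# then a legacy `image_processor_type` fallback, then the model config's `model_type`.
video_processor = AutoVideoProcessor.from_pretrained("llava-hf/llava-onevision-qwen2-0.5b-ov-hf")

# Registering a custom processor for a custom config class:
# AutoVideoProcessor.register(MyConfig, MyVideoProcessor)
```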
venv/lib/python3.13/site-packages/transformers/models/aya_vision/__init__.py
ADDED
@@ -0,0 +1,28 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    from .configuration_aya_vision import *
    from .modeling_aya_vision import *
    from .processing_aya_vision import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
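A brief sketch of what the lazy-module indirection above buys (assuming transformers is installed with this model available): importing the package stays cheap, and the submodule is only loaded on first attribute access.

```python
# The attribute access below triggers _LazyModule to import
# configuration_aya_vision on demand, not at package-import time.
from transformers.models.aya_vision import AyaVisionConfig

config = AyaVisionConfig()  # per the docstring below, defaults mirror CohereForAI/aya-vision-8b
```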
venv/lib/python3.13/site-packages/transformers/models/aya_vision/configuration_aya_vision.py
ADDED
@@ -0,0 +1,110 @@
# coding=utf-8
# Copyright 2025 Cohere team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""AyaVision model configuration"""

from ...configuration_utils import PretrainedConfig
from ...utils import logging
from ..auto import CONFIG_MAPPING, AutoConfig


logger = logging.get_logger(__name__)


class AyaVisionConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of an [`AyaVisionForConditionalGeneration`]. It is used to instantiate an
    AyaVision model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of AyaVision.
    e.g. [CohereForAI/aya-vision-8b](https://huggingface.co/CohereForAI/aya-vision-8b)

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vision_config (`Union[AutoConfig, dict]`, *optional*, defaults to `SiglipVisionConfig`):
            The config object or dictionary of the vision backbone.
        text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `Cohere2Config`):
            The config object or dictionary of the text backbone.
        vision_feature_select_strategy (`str`, *optional*, defaults to `"full"`):
            The feature selection strategy used to select the vision feature from the vision backbone.
            Can be one of `"default"` or `"full"`. If `"default"`, the CLS token is removed from the vision features.
            If `"full"`, the full vision features are used.
        vision_feature_layer (`int`, *optional*, defaults to -1):
            The index of the layer to select the vision feature.
        downsample_factor (`int`, *optional*, defaults to 2):
            The downsample factor to apply to the vision features.
        adapter_layer_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon value used for layer normalization in the adapter.
        image_token_index (`int`, *optional*, defaults to 255036):
            The image token index to encode the image prompt.
    """

    model_type = "aya_vision"
    attribute_map = {
        "image_token_id": "image_token_index",
    }
    sub_configs = {"text_config": AutoConfig, "vision_config": AutoConfig}

    def __init__(
        self,
        vision_config=None,
        text_config=None,
        vision_feature_select_strategy="full",
        vision_feature_layer=-1,
        downsample_factor=2,
        adapter_layer_norm_eps=1e-6,
        image_token_index=255036,
        **kwargs,
    ):
        self.image_token_index = image_token_index
        self.downsample_factor = downsample_factor
        self.adapter_layer_norm_eps = adapter_layer_norm_eps
        if vision_feature_select_strategy not in ["default", "full"]:
            raise ValueError(
                "vision_feature_select_strategy should be one of 'default', 'full'. "
                f"Got: {vision_feature_select_strategy}"
            )

        self.vision_feature_select_strategy = vision_feature_select_strategy
        self.vision_feature_layer = vision_feature_layer

        if isinstance(vision_config, dict):
            vision_config["model_type"] = vision_config.get("model_type", "siglip_vision_model")
            vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config)
        elif vision_config is None:
elif vision_config is None:
|
| 87 |
+
vision_config = CONFIG_MAPPING["siglip_vision_model"](
|
| 88 |
+
hidden_size=1152,
|
| 89 |
+
intermediate_size=4304,
|
| 90 |
+
patch_size=14,
|
| 91 |
+
image_size=384,
|
| 92 |
+
num_hidden_layers=26,
|
| 93 |
+
num_attention_heads=14,
|
| 94 |
+
vision_use_head=False,
|
| 95 |
+
)
|
| 96 |
+
|
| 97 |
+
self.vision_config = vision_config
|
| 98 |
+
|
| 99 |
+
if isinstance(text_config, dict):
|
| 100 |
+
text_config["model_type"] = text_config.get("model_type", "cohere2")
|
| 101 |
+
text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
|
| 102 |
+
elif text_config is None:
|
| 103 |
+
text_config = CONFIG_MAPPING["cohere2"]()
|
| 104 |
+
|
| 105 |
+
self.text_config = text_config
|
| 106 |
+
|
| 107 |
+
super().__init__(**kwargs)
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
__all__ = ["AyaVisionConfig"]
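A short usage sketch for the composed config above, assuming a transformers build that ships `aya_vision`; the printed defaults follow directly from the constructor shown:

```python
from transformers import AyaVisionConfig

# Defaults per the constructor above: SigLIP vision tower + Cohere2 text backbone.
config = AyaVisionConfig()
print(type(config.vision_config).__name__)    # SiglipVisionConfig
print(config.vision_feature_select_strategy)  # full
print(config.image_token_id)                  # 255036, via the attribute_map alias

# Dict sub-configs are resolved through CONFIG_MAPPING by their "model_type" key.
custom = AyaVisionConfig(
    vision_config={"model_type": "siglip_vision_model", "image_size": 384},
    text_config={"model_type": "cohere2"},
)
```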
venv/lib/python3.13/site-packages/transformers/models/aya_vision/modeling_aya_vision.py
ADDED
@@ -0,0 +1,518 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/aya_vision/modular_aya_vision.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_aya_vision.py file directly. One of our CI checks enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2025 the Cohere Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import Optional, Union

import torch
from torch import nn

from ...activations import ACT2FN
from ...cache_utils import Cache
from ...generation import GenerationMixin
from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput
from ...modeling_utils import PreTrainedModel
from ...processing_utils import Unpack
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
from ...utils.generic import check_model_inputs
from ..auto import AutoModel
from .configuration_aya_vision import AyaVisionConfig


class AyaVisionMultiModalProjector(nn.Module):
    def __init__(self, config: AyaVisionConfig):
        super().__init__()
        self.config = config
        self.downsample_factor = config.downsample_factor
        self.alignment_intermediate_size = getattr(
            config, "alignment_intermediate_size", config.text_config.hidden_size
        )
        self.layernorm = nn.LayerNorm(
            config.vision_config.hidden_size * (config.downsample_factor**2), eps=config.adapter_layer_norm_eps
        )

        self.linear_1 = nn.Linear(
            config.vision_config.hidden_size * (config.downsample_factor**2),
            self.alignment_intermediate_size,
            bias=True,
        )

        self.act = ACT2FN["silu"]  # SwiGLU uses SiLU activation
        # For SwiGLU, project down to half size since we split intermediate dim
        self.linear_2 = nn.Linear(self.alignment_intermediate_size // 2, config.text_config.hidden_size, bias=True)

    def forward(self, image_features):
        image_features = self.pixel_shuffle(image_features)
        image_features = self.layernorm(image_features)
        hidden_states = self.linear_1(image_features)

        # Split along last dimension and apply SwiGLU
        x, gate = hidden_states.chunk(2, dim=-1)
        hidden_states = self.act(gate) * x

        hidden_states = self.linear_2(hidden_states)
        return hidden_states

    def pixel_shuffle(self, image_features):  # B, S, D
        batch_size, seq_length, feature_dim = image_features.shape
        height = width = int(seq_length**0.5)
        image_features = image_features.reshape(image_features.shape[0], width, height, -1)
        channels = image_features.shape[-1]
        image_features = image_features.reshape(
            batch_size, width, int(height / self.downsample_factor), int(channels * self.downsample_factor)
        )
        image_features = image_features.permute(0, 2, 1, 3)
        image_features = image_features.reshape(
            batch_size, int(height / self.downsample_factor), int(width / self.downsample_factor), -1
        )
        image_features = image_features.permute(0, 2, 1, 3)
        return image_features


@auto_docstring
class AyaVisionPreTrainedModel(PreTrainedModel):
    config: AyaVisionConfig
    base_model_prefix = ""
    supports_gradient_checkpointing = True
    _skip_keys_device_placement = "past_key_values"

    _supports_flash_attn = True
    _supports_sdpa = True
    _can_compile_fullgraph = False
    _supports_flex_attn = True
    _supports_attention_backend = True
    _can_record_outputs = {
        "hidden_states": "DecoderLayer",
        "attentions": "Attention",
    }


@dataclass
@auto_docstring(
    custom_intro="""
    Base class for AyaVision causal language model (or autoregressive) outputs.
    """
)
class AyaVisionCausalLMOutputWithPast(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Language modeling loss (for next-token prediction).
    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).

        Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
        `past_key_values` input) to speed up sequential decoding.
    image_hidden_states (`torch.FloatTensor`, *optional*):
        A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
        Image hidden states produced by the vision encoder after projecting its last hidden state.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    past_key_values: Optional[Cache] = None
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
    attentions: Optional[tuple[torch.FloatTensor]] = None
    image_hidden_states: Optional[torch.FloatTensor] = None


@dataclass
@auto_docstring(
    custom_intro="""
    Base class for AyaVision outputs, with hidden states and attentions.
    """
)
class AyaVisionModelOutputWithPast(BaseModelOutputWithPast):
    r"""
    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).

        Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
        `past_key_values` input) to speed up sequential decoding.
    image_hidden_states (`torch.FloatTensor`, *optional*):
        A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
        Image hidden states produced by the vision encoder after projecting its last hidden state.
    """

    image_hidden_states: Optional[torch.FloatTensor] = None


@auto_docstring(
    custom_intro="""
    The AyaVision model which consists of a vision backbone and a language model, without a language modeling head.
    """
)
class AyaVisionModel(AyaVisionPreTrainedModel):
    _checkpoint_conversion_mapping = {"language_model.model": "language_model"}

    def __init__(self, config: AyaVisionConfig):
        super().__init__(config)
        self.vision_tower = AutoModel.from_config(config.vision_config)

        self.multi_modal_projector = AyaVisionMultiModalProjector(config)
        self.language_model = AutoModel.from_config(config.text_config)
        self.post_init()

    def get_input_embeddings(self):
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.language_model.set_input_embeddings(value)

    def set_decoder(self, decoder):
        self.language_model = decoder

    def get_decoder(self):
        return self.language_model

    def get_image_features(
        self,
        pixel_values: torch.FloatTensor,
        vision_feature_layer: Optional[Union[int, list[int]]] = None,
        vision_feature_select_strategy: Optional[str] = None,
        **kwargs,
    ):
        """
        Obtains image last hidden states from the vision tower and applies multimodal projection.

        Args:
            pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`):
                The tensors corresponding to the input images.
            vision_feature_layer (`Union[int, list[int]]`, *optional*):
                The index of the layer to select the vision feature. If multiple indices are provided,
                the vision feature of the corresponding indices will be concatenated to form the
                vision features.
            vision_feature_select_strategy (`str`, *optional*):
                The feature selection strategy used to select the vision feature from the vision backbone.
                Can be one of `"default"` or `"full"`
        Returns:
            image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`.
        """
        vision_feature_layer = (
            vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
        )
        vision_feature_select_strategy = (
            vision_feature_select_strategy
            if vision_feature_select_strategy is not None
            else self.config.vision_feature_select_strategy
        )

        if vision_feature_select_strategy not in ["default", "full"]:
            raise ValueError(f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}")

        kwargs = {k: v for k, v in kwargs.items() if v is not None}
        # This is not memory efficient at all: output_hidden_states=True saves all the hidden states.
        image_outputs = self.vision_tower(pixel_values, output_hidden_states=True, **kwargs)

        # If we have one vision feature layer, return the corresponding hidden states,
        # otherwise, select the hidden states of each feature layer and concatenate them
        if isinstance(vision_feature_layer, int):
            selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
            if vision_feature_select_strategy == "default":
                selected_image_feature = selected_image_feature[:, 1:]
        else:
            hs_pool = [image_outputs.hidden_states[layer_idx] for layer_idx in vision_feature_layer]
            # For default, crop CLS from each hidden state in the hidden state pool
            if vision_feature_select_strategy == "default":
                hs_pool = [hs[:, 1:] for hs in hs_pool]
            selected_image_feature = torch.cat(hs_pool, dim=-1)

        image_features = self.multi_modal_projector(selected_image_feature)
        return image_features

    def get_placeholder_mask(
        self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor
    ):
        """
        Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
        equal to the length of multimodal features. If the lengths are different, an error is raised.
        """
        if input_ids is None:
            special_image_mask = inputs_embeds == self.get_input_embeddings()(
                torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
            )
            special_image_mask = special_image_mask.all(-1)
        else:
            special_image_mask = input_ids == self.config.image_token_id

        n_image_tokens = special_image_mask.sum()
        special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
        n_image_features = image_features.shape[0] * image_features.shape[1]
        if inputs_embeds[special_image_mask].numel() != image_features.numel():
            raise ValueError(
                f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
            )
        return special_image_mask

    @check_model_inputs()
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        vision_feature_layer: Optional[Union[int, list[int]]] = None,
        vision_feature_select_strategy: Optional[str] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Union[tuple, AyaVisionModelOutputWithPast]:
        vision_feature_layer = (
            vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
        )
        vision_feature_select_strategy = (
            vision_feature_select_strategy
            if vision_feature_select_strategy is not None
            else self.config.vision_feature_select_strategy
        )

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(input_ids)

        if pixel_values is not None:
            image_features = self.get_image_features(
                pixel_values=pixel_values,
                vision_feature_layer=vision_feature_layer,
                vision_feature_select_strategy=vision_feature_select_strategy,
            )
            image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
            special_image_mask = self.get_placeholder_mask(
                input_ids, inputs_embeds=inputs_embeds, image_features=image_features
            )
            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

        outputs = self.language_model(
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )

        return AyaVisionModelOutputWithPast(
            last_hidden_state=outputs.last_hidden_state,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            image_hidden_states=image_features if pixel_values is not None else None,
        )


@auto_docstring(
    custom_intro="""
    The AYA_VISION model which consists of a vision backbone and a language model.
    """
)
class AyaVisionForConditionalGeneration(AyaVisionPreTrainedModel, GenerationMixin):
    _checkpoint_conversion_mapping = {
        "^language_model.model": "model.language_model",
        "^vision_tower": "model.vision_tower",
        "^multi_modal_projector": "model.multi_modal_projector",
        "^language_model.lm_head": "lm_head",
    }
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: AyaVisionConfig):
        super().__init__(config)
        self.model = AyaVisionModel(config)
        self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
        self.post_init()

    def get_input_embeddings(self):
        return self.model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.model.set_input_embeddings(value)

    def get_output_embeddings(self) -> nn.Module:
        return self.lm_head

    def set_decoder(self, decoder):
        self.model.set_decoder(decoder)

    def get_decoder(self):
        return self.model.get_decoder()

    def get_image_features(
        self,
        pixel_values: torch.FloatTensor,
        vision_feature_layer: Optional[Union[int, list[int]]] = None,
        vision_feature_select_strategy: Optional[str] = None,
        **kwargs,
    ):
        return self.model.get_image_features(
            pixel_values=pixel_values,
            vision_feature_layer=vision_feature_layer,
            vision_feature_select_strategy=vision_feature_select_strategy,
            **kwargs,
        )

    # Make modules available through conditional class for BC
    @property
    def language_model(self):
        return self.model.language_model

    @property
    def vision_tower(self):
        return self.model.vision_tower

    @property
    def multi_modal_projector(self):
        return self.model.multi_modal_projector

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        vision_feature_layer: Optional[Union[int, list[int]]] = None,
        vision_feature_select_strategy: Optional[str] = None,
        labels: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        image_sizes: Optional[torch.Tensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Union[tuple, AyaVisionCausalLMOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Example:

        ```python
        >>> from transformers import AutoProcessor, AyaVisionForConditionalGeneration
        >>> import torch

        >>> torch_device = "cuda:0"
        >>> processor = AutoProcessor.from_pretrained("CohereForAI/aya-vision-8b", use_fast=True)
        >>> model = AyaVisionForConditionalGeneration.from_pretrained("CohereForAI/aya-vision-8b", device_map=torch_device)

        >>> messages = [
        ...     {
        ...         "role": "user",
        ...         "content": [
        ...             {
        ...                 "type": "image",
        ...                 "url": "https://pbs.twimg.com/media/Fx7YvfQWYAIp6rZ?format=jpg&name=medium",
        ...             },
        ...             {"type": "text", "text": "चित्र में लिखा पाठ क्या कहता है?"},
        ...         ],
        ...     }
        ... ]

        >>> inputs = processor.apply_chat_template(
        ...     messages, padding=True, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt", device=torch_device
        ... ).to(model.device)

        >>> gen_tokens = model.generate(**inputs, max_new_tokens=300, do_sample=True, temperature=0.3)
        >>> processor.tokenizer.decode(gen_tokens[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
        ```"""
        vision_feature_layer = (
            vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
        )
        vision_feature_select_strategy = (
            vision_feature_select_strategy
            if vision_feature_select_strategy is not None
            else self.config.vision_feature_select_strategy
        )

        outputs = self.model(
            input_ids=input_ids,
            pixel_values=pixel_values,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            vision_feature_layer=vision_feature_layer,
            vision_feature_select_strategy=vision_feature_select_strategy,
            cache_position=cache_position,
            image_sizes=image_sizes,
            **kwargs,
        )

        hidden_states = outputs[0]
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(
                logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
            )

        return AyaVisionCausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            image_hidden_states=outputs.image_hidden_states,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        inputs_embeds=None,
        pixel_values=None,
        attention_mask=None,
        cache_position=None,
        logits_to_keep=None,
        **kwargs,
    ):
        # Overwritten -- in specific circumstances we don't want to forward image inputs to the model

        model_inputs = super().prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            logits_to_keep=logits_to_keep,
            **kwargs,
        )

        if cache_position[0] == 0:
            # Pixel values are only forwarded on the first (prefill) step; in the cached decoding stage
            # the input ids no longer contain the special image token, so pixel values stay None there.
            model_inputs["pixel_values"] = pixel_values

        return model_inputs


__all__ = ["AyaVisionForConditionalGeneration", "AyaVisionPreTrainedModel", "AyaVisionModel"]
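The projector in the file above is easiest to understand by shape: `pixel_shuffle` folds each `downsample_factor x downsample_factor` block of patch tokens into the channel dimension, and the SwiGLU MLP then maps the widened channels into the text model's hidden size. A standalone shape check with toy sizes; the `SimpleNamespace` stands in for `AyaVisionConfig` (real checkpoints use much larger dimensions), and the import assumes a transformers build that ships `aya_vision`:

```python
from types import SimpleNamespace

import torch
from transformers.models.aya_vision.modeling_aya_vision import AyaVisionMultiModalProjector

# Toy stand-in for AyaVisionConfig; only the attributes read by the projector.
cfg = SimpleNamespace(
    downsample_factor=2,
    adapter_layer_norm_eps=1e-6,
    alignment_intermediate_size=64,  # must be even: SwiGLU splits it in half
    vision_config=SimpleNamespace(hidden_size=8),
    text_config=SimpleNamespace(hidden_size=16),
)
projector = AyaVisionMultiModalProjector(cfg)

# 36 patch tokens form a 6x6 grid; pixel shuffle folds each 2x2 block into channels:
# (1, 36, 8) -> (1, 3, 3, 32), then SwiGLU projects to the text hidden size 16.
out = projector(torch.randn(1, 36, 8))
print(out.shape)  # torch.Size([1, 3, 3, 16])
```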
venv/lib/python3.13/site-packages/transformers/models/aya_vision/modular_aya_vision.py
ADDED
@@ -0,0 +1,297 @@
# coding=utf-8
# Copyright 2025 the Cohere Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch AyaVision model."""

from typing import Optional, Union

import torch
from torch import nn

from transformers.models.llava.modeling_llava import (
    LlavaCausalLMOutputWithPast,
    LlavaForConditionalGeneration,
    LlavaModel,
    LlavaModelOutputWithPast,
    LlavaPreTrainedModel,
    TransformersKwargs,
)

from ...activations import ACT2FN
from ...cache_utils import Cache
from ...processing_utils import Unpack
from ...utils import auto_docstring, logging
from ...utils.generic import check_model_inputs
from .configuration_aya_vision import AyaVisionConfig


logger = logging.get_logger(__name__)


class AyaVisionMultiModalProjector(nn.Module):
    def __init__(self, config: AyaVisionConfig):
        super().__init__()
        self.config = config
        self.downsample_factor = config.downsample_factor
        self.alignment_intermediate_size = getattr(
            config, "alignment_intermediate_size", config.text_config.hidden_size
        )
        self.layernorm = nn.LayerNorm(
            config.vision_config.hidden_size * (config.downsample_factor**2), eps=config.adapter_layer_norm_eps
        )

        self.linear_1 = nn.Linear(
            config.vision_config.hidden_size * (config.downsample_factor**2),
            self.alignment_intermediate_size,
            bias=True,
        )

        self.act = ACT2FN["silu"]  # SwiGLU uses SiLU activation
        # For SwiGLU, project down to half size since we split intermediate dim
        self.linear_2 = nn.Linear(self.alignment_intermediate_size // 2, config.text_config.hidden_size, bias=True)

    def forward(self, image_features):
        image_features = self.pixel_shuffle(image_features)
        image_features = self.layernorm(image_features)
        hidden_states = self.linear_1(image_features)

        # Split along last dimension and apply SwiGLU
        x, gate = hidden_states.chunk(2, dim=-1)
        hidden_states = self.act(gate) * x

        hidden_states = self.linear_2(hidden_states)
        return hidden_states

    def pixel_shuffle(self, image_features):  # B, S, D
        batch_size, seq_length, feature_dim = image_features.shape
        height = width = int(seq_length**0.5)
        image_features = image_features.reshape(image_features.shape[0], width, height, -1)
        channels = image_features.shape[-1]
        image_features = image_features.reshape(
            batch_size, width, int(height / self.downsample_factor), int(channels * self.downsample_factor)
        )
        image_features = image_features.permute(0, 2, 1, 3)
        image_features = image_features.reshape(
            batch_size, int(height / self.downsample_factor), int(width / self.downsample_factor), -1
        )
        image_features = image_features.permute(0, 2, 1, 3)
        return image_features


class AyaVisionPreTrainedModel(LlavaPreTrainedModel):
    _can_compile_fullgraph = False
    _can_record_outputs = {
        "hidden_states": "DecoderLayer",
        "attentions": "Attention",
    }


class AyaVisionCausalLMOutputWithPast(LlavaCausalLMOutputWithPast):
    pass


class AyaVisionModelOutputWithPast(LlavaModelOutputWithPast):
    pass


class AyaVisionModel(LlavaModel):
    # Unlike LLaVA, the model doesn't have to deal with Pixtral-style image states
    def get_image_features(
        self,
        pixel_values: torch.FloatTensor,
        vision_feature_layer: Optional[Union[int, list[int]]] = None,
        vision_feature_select_strategy: Optional[str] = None,
        **kwargs,
    ):
        """
        Obtains image last hidden states from the vision tower and applies multimodal projection.

        Args:
            pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`):
                The tensors corresponding to the input images.
            vision_feature_layer (`Union[int, list[int]]`, *optional*):
                The index of the layer to select the vision feature. If multiple indices are provided,
                the vision feature of the corresponding indices will be concatenated to form the
                vision features.
            vision_feature_select_strategy (`str`, *optional*):
                The feature selection strategy used to select the vision feature from the vision backbone.
                Can be one of `"default"` or `"full"`
        Returns:
            image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`.
        """
        vision_feature_layer = (
            vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
        )
        vision_feature_select_strategy = (
            vision_feature_select_strategy
            if vision_feature_select_strategy is not None
            else self.config.vision_feature_select_strategy
        )

        if vision_feature_select_strategy not in ["default", "full"]:
            raise ValueError(f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}")

        kwargs = {k: v for k, v in kwargs.items() if v is not None}
        # This is not memory efficient at all: output_hidden_states=True saves all the hidden states.
        image_outputs = self.vision_tower(pixel_values, output_hidden_states=True, **kwargs)

        # If we have one vision feature layer, return the corresponding hidden states,
        # otherwise, select the hidden states of each feature layer and concatenate them
        if isinstance(vision_feature_layer, int):
            selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
            if vision_feature_select_strategy == "default":
                selected_image_feature = selected_image_feature[:, 1:]
        else:
            hs_pool = [image_outputs.hidden_states[layer_idx] for layer_idx in vision_feature_layer]
            # For default, crop CLS from each hidden state in the hidden state pool
            if vision_feature_select_strategy == "default":
                hs_pool = [hs[:, 1:] for hs in hs_pool]
            selected_image_feature = torch.cat(hs_pool, dim=-1)

        image_features = self.multi_modal_projector(selected_image_feature)
        return image_features

    @check_model_inputs()
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        vision_feature_layer: Optional[Union[int, list[int]]] = None,
        vision_feature_select_strategy: Optional[str] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Union[tuple, AyaVisionModelOutputWithPast]:
        vision_feature_layer = (
            vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
        )
        vision_feature_select_strategy = (
            vision_feature_select_strategy
            if vision_feature_select_strategy is not None
            else self.config.vision_feature_select_strategy
        )

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(input_ids)

        if pixel_values is not None:
            image_features = self.get_image_features(
                pixel_values=pixel_values,
                vision_feature_layer=vision_feature_layer,
                vision_feature_select_strategy=vision_feature_select_strategy,
            )
            image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
            special_image_mask = self.get_placeholder_mask(
                input_ids, inputs_embeds=inputs_embeds, image_features=image_features
            )
            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

        outputs = self.language_model(
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )

        return AyaVisionModelOutputWithPast(
            last_hidden_state=outputs.last_hidden_state,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            image_hidden_states=image_features if pixel_values is not None else None,
        )


class AyaVisionForConditionalGeneration(LlavaForConditionalGeneration):
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        vision_feature_layer: Optional[Union[int, list[int]]] = None,
        vision_feature_select_strategy: Optional[str] = None,
        labels: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        image_sizes: Optional[torch.Tensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Union[tuple, AyaVisionCausalLMOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Example:

        ```python
        >>> from transformers import AutoProcessor, AyaVisionForConditionalGeneration
        >>> import torch

        >>> torch_device = "cuda:0"
        >>> processor = AutoProcessor.from_pretrained("CohereForAI/aya-vision-8b", use_fast=True)
        >>> model = AyaVisionForConditionalGeneration.from_pretrained("CohereForAI/aya-vision-8b", device_map=torch_device)

        >>> messages = [
        ...     {
        ...         "role": "user",
        ...         "content": [
        ...             {
        ...                 "type": "image",
        ...                 "url": "https://pbs.twimg.com/media/Fx7YvfQWYAIp6rZ?format=jpg&name=medium",
        ...             },
        ...             {"type": "text", "text": "चित्र में लिखा पाठ क्या कहता है?"},
        ...         ],
        ...     }
        ... ]

        >>> inputs = processor.apply_chat_template(
        ...     messages, padding=True, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt", device=torch_device
        ... ).to(model.device)

        >>> gen_tokens = model.generate(**inputs, max_new_tokens=300, do_sample=True, temperature=0.3)
        >>> processor.tokenizer.decode(gen_tokens[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
        ```"""
        # The modular converter expands this super().forward(...) call into the full parent
        # implementation when generating modeling_aya_vision.py, so no return is needed here.
        super().forward(
            input_ids=input_ids,
            pixel_values=pixel_values,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            vision_feature_layer=vision_feature_layer,
            vision_feature_select_strategy=vision_feature_select_strategy,
            labels=labels,
            cache_position=cache_position,
            logits_to_keep=logits_to_keep,
            image_sizes=image_sizes,
            **kwargs,
        )


__all__ = ["AyaVisionForConditionalGeneration", "AyaVisionPreTrainedModel", "AyaVisionModel"]
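Both `forward` implementations above inject the projected image features into the token stream with `masked_scatter` over the placeholder-token positions. A toy, self-contained illustration of that step; the token id, vocabulary, and tensor sizes are made up for the demo:

```python
import torch

image_token_id = 9
input_ids = torch.tensor([[1, 9, 9, 9, 2]])  # 3 image-placeholder positions
inputs_embeds = torch.zeros(1, 5, 4)         # (batch, seq, hidden)
image_features = torch.ones(1, 3, 4)         # (num_images, image_length, hidden)

# Expand the boolean token mask to the embedding shape, then scatter the image
# features into exactly those positions, in order.
special_image_mask = (input_ids == image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
print(inputs_embeds[0, :, 0])  # tensor([0., 1., 1., 1., 0.]) -> image rows injected
```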
venv/lib/python3.13/site-packages/transformers/models/aya_vision/processing_aya_vision.py
ADDED
@@ -0,0 +1,257 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Union

import numpy as np

from ...image_processing_utils import BatchFeature
from ...image_utils import ImageInput, make_flat_list_of_images
from ...processing_utils import ImagesKwargs, MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack
from ...tokenization_utils_base import PreTokenizedInput, TextInput


class AyaVisionImagesKwargs(ImagesKwargs, total=False):
    crop_to_patches: Optional[bool]
    min_patches: Optional[int]
    max_patches: Optional[int]


class AyaVisionProcessorKwargs(ProcessingKwargs, total=False):
    images_kwargs: AyaVisionImagesKwargs
    _defaults = {
        "text_kwargs": {
            "padding_side": "left",
            "padding": True,
            "return_mm_token_type_ids": False,
        },
        "images_kwargs": {
            "crop_to_patches": True,
        },
    }


class AyaVisionProcessor(ProcessorMixin):
    r"""
    Constructs an AyaVision processor which wraps a [`AutoImageProcessor`] and
    [`PretrainedTokenizerFast`] tokenizer into a single processor that inherits both the image processor and
    tokenizer functionalities. See the [`~AyaVisionProcessor.__call__`] and [`~AyaVisionProcessor.decode`] for more information.
    Args:
        image_processor ([`AutoImageProcessor`], *optional*):
            The image processor is a required input.
        tokenizer ([`PreTrainedTokenizer`, `PreTrainedTokenizerFast`], *optional*):
            The tokenizer is a required input.
        patch_size (`int`, *optional*, defaults to 28):
            The size of image patches for tokenization.
        img_size (`int`, *optional*, defaults to 364):
            The size of the image to be tokenized. This should correspond to the size given to the image processor.
        image_token (`str`, *optional*, defaults to `"<image>"`):
            The token to be used to represent an image in the text.
        downsample_factor (`int`, *optional*, defaults to 1):
            The factor by which to scale the patch size.
        start_of_img_token (`str`, *optional*, defaults to `"<|START_OF_IMG|>"`):
            The token to be used to represent the start of an image in the text.
        end_of_img_token (`str`, *optional*, defaults to `"<|END_OF_IMG|>"`):
            The token to be used to represent the end of an image in the text.
        img_patch_token (`str`, *optional*, defaults to `"<|IMG_PATCH|>"`):
            The token to be used to represent an image patch in the text.
        img_line_break_token (`str`, *optional*, defaults to `"<|IMG_LINE_BREAK|>"`):
            The token to be used to represent a line break in the text.
        tile_token (`str`, *optional*, defaults to `"TILE"`):
            The token to be used to represent an image patch in the text.
        tile_global_token (`str`, *optional*, defaults to `"TILE_GLOBAL"`):
            The token to be used to represent the cover image in the text.
        chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
            in a chat into a tokenizable string.
    """

    attributes = ["image_processor", "tokenizer"]
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = "AutoTokenizer"

    def __init__(
        self,
        image_processor=None,
        tokenizer=None,
        patch_size: int = 28,
        img_size: int = 364,
        image_token="<image>",  # set the default and let users change if they have peculiar special tokens in rare cases
        downsample_factor: int = 1,
        start_of_img_token="<|START_OF_IMG|>",
        end_of_img_token="<|END_OF_IMG|>",
        img_patch_token="<|IMG_PATCH|>",
        img_line_break_token="<|IMG_LINE_BREAK|>",
        tile_token="TILE",
        tile_global_token="TILE_GLOBAL",
        chat_template=None,
        **kwargs,
    ):
        super().__init__(image_processor, tokenizer, chat_template=chat_template)

        self.image_token = image_token
        self.patch_size = patch_size * downsample_factor
        self.img_size = img_size

        self.start_of_img_token = start_of_img_token
        self.end_of_img_token = end_of_img_token
        self.img_patch_token = img_patch_token
        self.img_line_break_token = img_line_break_token
        self.tile_token = tile_token
        self.tile_global_token = tile_global_token
        self.image_token_id = tokenizer.convert_tokens_to_ids(self.img_patch_token)
        self.image_ids = tokenizer.convert_tokens_to_ids(
            [img_patch_token, tile_token, tile_global_token, start_of_img_token, end_of_img_token]
        )

    def _prompt_split_image(self, num_patches):
        """
        Create a structured string representation of image tokens

        Args:
            num_patches: Number of patches in the image

        Returns:
            String with appropriate image tokens
        """

        img_patches_per_tile = (self.img_size // self.patch_size) ** 2
        img_string = f"{self.start_of_img_token}"
        if num_patches > 1:
            for idx in range(1, num_patches):
                img_string += f"{self.tile_token}_{idx}" + f"{self.img_patch_token}" * img_patches_per_tile

        img_string += f"{self.tile_global_token}" + f"{self.img_patch_token}" * img_patches_per_tile
        img_string += f"{self.end_of_img_token}"
        return img_string

    def __call__(
        self,
        images: Optional[ImageInput] = None,
        text: Optional[Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]]] = None,
        audio=None,
        videos=None,
        **kwargs: Unpack[AyaVisionProcessorKwargs],
    ) -> BatchFeature:
        """
        Main method to prepare for the model one or several sequence(s) and image(s). This method forwards the `text`
        and `kwargs` arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizerFast.__call__`] to encode the text.
        To prepare the vision inputs, this method forwards the `images` and `kwargs` arguments to
        GotOcr2ImageProcessor's [`~GotOcr2ImageProcessor.__call__`] if `images` is not `None`.

        Args:
            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`, `list[torch.Tensor]`):
                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
                tensor. Both channels-first and channels-last formats are supported.
            text (`str`, `list[str]`, `list[list[str]]`):
                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
                (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
                `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
            return_tensors (`str` or [`~utils.TensorType`], *optional*):
                If set, will return tensors of a particular framework. Acceptable values are:
                - `'tf'`: Return TensorFlow `tf.constant` objects.
                - `'pt'`: Return PyTorch `torch.Tensor` objects.
                - `'np'`: Return NumPy `np.ndarray` objects.
                - `'jax'`: Return JAX `jnp.ndarray` objects.

        Returns:
            [`BatchFeature`]: A [`BatchFeature`] with the following fields:

            - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
              `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
              `None`).
            - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
        """
        if text is None:
            raise ValueError("You have to specify text.")

        output_kwargs = self._merge_kwargs(
            AyaVisionProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )

        if not isinstance(text, (list, tuple)):
            text = [text]

        # Process images
        image_inputs = {}
        if images is not None:
            images = self.image_processor.fetch_images(images)
            images = make_flat_list_of_images(images)
            image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"])
            num_patches = image_inputs.pop("num_patches")
            image_index = 0
            processed_text = []
            for prompt in text:
                new_prompt = prompt
                while "<image>" in new_prompt:
                    # Replace the image placeholder with structured image tokens
                    image_tokens = self._prompt_split_image(num_patches[image_index])
|
| 203 |
+
new_prompt = new_prompt.replace("<image>", image_tokens, 1)
|
| 204 |
+
image_index += 1
|
| 205 |
+
processed_text.append(new_prompt)
|
| 206 |
+
|
| 207 |
+
if image_index != len(images):
|
| 208 |
+
raise ValueError("Number of image placeholders in the prompt does not match the number of images.")
|
| 209 |
+
|
| 210 |
+
text = processed_text
|
| 211 |
+
|
| 212 |
+
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
|
| 213 |
+
return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False)
|
| 214 |
+
text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"], return_tensors=None)
|
| 215 |
+
|
| 216 |
+
if return_mm_token_type_ids:
|
| 217 |
+
array_ids = np.array(text_inputs["input_ids"])
|
| 218 |
+
mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
|
| 219 |
+
mm_token_type_ids[np.isin(array_ids, self.image_ids)] = 1
|
| 220 |
+
text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()
|
| 221 |
+
|
| 222 |
+
return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors)
|
| 223 |
+
|
| 224 |
+
def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
|
| 225 |
+
"""
|
| 226 |
+
Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
|
| 227 |
+
|
| 228 |
+
Args:
|
| 229 |
+
image_sizes (`list[list[int]]`, *optional*):
|
| 230 |
+
The input sizes formatted as (height, width) per each image.
|
| 231 |
+
|
| 232 |
+
Returns:
|
| 233 |
+
`MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
|
| 234 |
+
input modalities, along with other useful data.
|
| 235 |
+
"""
|
| 236 |
+
|
| 237 |
+
vision_data = {}
|
| 238 |
+
if image_sizes is not None:
|
| 239 |
+
images_kwargs = AyaVisionProcessorKwargs._defaults.get("images_kwargs", {})
|
| 240 |
+
images_kwargs.update(kwargs)
|
| 241 |
+
|
| 242 |
+
num_image_patches = [
|
| 243 |
+
self.image_processor.get_number_of_image_patches(*image_size, images_kwargs)
|
| 244 |
+
for image_size in image_sizes
|
| 245 |
+
]
|
| 246 |
+
|
| 247 |
+
token_per_patch = (self.img_size // self.patch_size) ** 2
|
| 248 |
+
num_image_tokens = [
|
| 249 |
+
token_per_patch + 3 + sum(token_per_patch + 1 for _ in range(1, num_patches))
|
| 250 |
+
for num_patches in num_image_patches
|
| 251 |
+
] # Add +3 and +1 for BOI/EOI and image tile tokens
|
| 252 |
+
vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
|
| 253 |
+
|
| 254 |
+
return MultiModalData(**vision_data)
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
__all__ = ["AyaVisionProcessor"]
|
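A minimal standalone sketch (not part of the diff above) of the image-token layout that `_prompt_split_image` builds and the per-image token count used in `_get_num_multimodal_tokens`. It re-implements the logic locally rather than importing transformers; the default sizes (364, 28) come from the processor above, while `num_patches = 3` is a hypothetical example value.

# Illustrative only: mirrors AyaVisionProcessor._prompt_split_image with assumed defaults.
IMG_SIZE, PATCH_SIZE = 364, 28  # processor defaults (downsample_factor = 1)
START, END, PATCH = "<|START_OF_IMG|>", "<|END_OF_IMG|>", "<|IMG_PATCH|>"
TILE, TILE_GLOBAL = "TILE", "TILE_GLOBAL"

def prompt_split_image(num_patches: int) -> str:
    patches_per_tile = (IMG_SIZE // PATCH_SIZE) ** 2  # 13 * 13 = 169 patch tokens per tile
    out = START
    if num_patches > 1:
        for idx in range(1, num_patches):  # one local tile per extra patch
            out += f"{TILE}_{idx}" + PATCH * patches_per_tile
    out += TILE_GLOBAL + PATCH * patches_per_tile  # global ("cover") tile comes last
    return out + END

s = prompt_split_image(3)  # hypothetical image split into 3 patches
# Matching count formula from _get_num_multimodal_tokens:
# 169 + 3 + 2 * (169 + 1) = 512 tokens for this image (+3/+1 cover BOI/EOI and tile markers).
print(s.count(PATCH), s.startswith(START), s.endswith(END))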
venv/lib/python3.13/site-packages/transformers/models/barthez/__init__.py
ADDED
@@ -0,0 +1,27 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    from .tokenization_barthez import *
    from .tokenization_barthez_fast import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
venv/lib/python3.13/site-packages/transformers/models/barthez/tokenization_barthez.py
ADDED
@@ -0,0 +1,291 @@
# coding=utf-8
# Copyright 2020 Ecole Polytechnique and the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
"""Tokenization classes for the BARThez model."""

import os
from shutil import copyfile
from typing import Any, Optional

import sentencepiece as spm

from ...tokenization_utils import AddedToken, PreTrainedTokenizer
from ...utils import logging
from ...utils.import_utils import requires


logger = logging.get_logger(__name__)

VOCAB_FILES_NAMES = {"vocab_file": "sentencepiece.bpe.model"}


SPIECE_UNDERLINE = "▁"

# TODO this class is useless. This is the most standard sentencepiece model. Let's find which one is closest and nuke this.


@requires(backends=("sentencepiece",))
class BarthezTokenizer(PreTrainedTokenizer):
    """
    Adapted from [`CamembertTokenizer`] and [`BartTokenizer`]. Construct a BARThez tokenizer. Based on
    [SentencePiece](https://github.com/google/sentencepiece).

    This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
    this superclass for more information regarding those methods.

    Args:
        vocab_file (`str`):
            [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that
            contains the vocabulary necessary to instantiate a tokenizer.
        bos_token (`str`, *optional*, defaults to `"<s>"`):
            The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier token.

            <Tip>

            When building a sequence using special tokens, this is not the token that is used for the beginning of
            sequence. The token used is the `cls_token`.

            </Tip>

        eos_token (`str`, *optional*, defaults to `"</s>"`):
            The end of sequence token.

            <Tip>

            When building a sequence using special tokens, this is not the token that is used for the end of sequence.
            The token used is the `sep_token`.

            </Tip>

        sep_token (`str`, *optional*, defaults to `"</s>"`):
            The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
            sequence classification or for a text and a question for question answering. It is also used as the last
            token of a sequence built with special tokens.
        cls_token (`str`, *optional*, defaults to `"<s>"`):
            The classifier token which is used when doing sequence classification (classification of the whole sequence
            instead of per-token classification). It is the first token of the sequence when built with special tokens.
        unk_token (`str`, *optional*, defaults to `"<unk>"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
        pad_token (`str`, *optional*, defaults to `"<pad>"`):
            The token used for padding, for example when batching sequences of different lengths.
        mask_token (`str`, *optional*, defaults to `"<mask>"`):
            The token used for masking values. This is the token used when training this model with masked language
            modeling. This is the token which the model will try to predict.
        sp_model_kwargs (`dict`, *optional*):
            Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
            SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
            to set:

            - `enable_sampling`: Enable subword regularization.
            - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.

              - `nbest_size = {0,1}`: No sampling is performed.
              - `nbest_size > 1`: samples from the nbest_size results.
              - `nbest_size < 0`: assuming that nbest_size is infinite and samples from all hypotheses (lattice)
                using forward-filtering-and-backward-sampling algorithm.

            - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
              BPE-dropout.

    Attributes:
        sp_model (`SentencePieceProcessor`):
            The *SentencePiece* processor that is used for every conversion (string, tokens and IDs).
    """

    vocab_files_names = VOCAB_FILES_NAMES
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        vocab_file,
        bos_token="<s>",
        eos_token="</s>",
        sep_token="</s>",
        cls_token="<s>",
        unk_token="<unk>",
        pad_token="<pad>",
        mask_token="<mask>",
        sp_model_kwargs: Optional[dict[str, Any]] = None,
        **kwargs,
    ) -> None:
        # Mask token behaves like a normal word, i.e. includes the space before it. Will have normalized=False by default this way
        mask_token = AddedToken(mask_token, lstrip=True, special=True) if isinstance(mask_token, str) else mask_token

        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs

        self.vocab_file = vocab_file
        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        self.sp_model.Load(str(vocab_file))
        super().__init__(
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            sep_token=sep_token,
            cls_token=cls_token,
            pad_token=pad_token,
            mask_token=mask_token,
            sp_model_kwargs=self.sp_model_kwargs,
            **kwargs,
        )

    def build_inputs_with_special_tokens(
        self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
    ) -> list[int]:
        """
        Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
        adding special tokens. A BARThez sequence has the following format:

        - single sequence: `<s> X </s>`
        - pair of sequences: `<s> A </s></s> B </s>`

        Args:
            token_ids_0 (`list[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (`list[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `list[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
        """

        if token_ids_1 is None:
            return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
        cls = [self.cls_token_id]
        sep = [self.sep_token_id]
        return cls + token_ids_0 + sep + sep + token_ids_1 + sep

    def get_special_tokens_mask(
        self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None, already_has_special_tokens: bool = False
    ) -> list[int]:
        """
        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer `prepare_for_model` method.

        Args:
            token_ids_0 (`list[int]`):
                List of IDs.
            token_ids_1 (`list[int]`, *optional*):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
                Whether or not the token list is already formatted with special tokens for the model.

        Returns:
            `list[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
            )

        if token_ids_1 is None:
            return [1] + ([0] * len(token_ids_0)) + [1]
        return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1]

    def create_token_type_ids_from_sequences(
        self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
    ) -> list[int]:
        """
        Create a mask from the two sequences passed to be used in a sequence-pair classification task.

        Args:
            token_ids_0 (`list[int]`):
                List of IDs.
            token_ids_1 (`list[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `list[int]`: List of zeros.
        """
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]

        if token_ids_1 is None:
            return len(cls + token_ids_0 + sep) * [0]
        return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]

    @property
    def vocab_size(self):
        return len(self.sp_model)

    def get_vocab(self):
        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
        vocab.update(self.added_tokens_encoder)
        return vocab

    def _tokenize(self, text: str) -> list[str]:
        return self.sp_model.encode(text, out_type=str)

    def _convert_token_to_id(self, token):
        """Converts a token (str) in an id using the vocab."""
        return self.sp_model.PieceToId(token)

    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        return self.sp_model.IdToPiece(index)

    # Copied from transformers.models.albert.tokenization_albert.AlbertTokenizer.convert_tokens_to_string
    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (string) in a single string."""
        current_sub_tokens = []
        out_string = ""
        prev_is_special = False
        for token in tokens:
            # make sure that special tokens are not decoded using sentencepiece model
            if token in self.all_special_tokens:
                if not prev_is_special:
                    out_string += " "
                out_string += self.sp_model.decode(current_sub_tokens) + token
                prev_is_special = True
                current_sub_tokens = []
            else:
                current_sub_tokens.append(token)
                prev_is_special = False
        out_string += self.sp_model.decode(current_sub_tokens)
        return out_string.strip()

    def __getstate__(self):
        state = self.__dict__.copy()
        state["sp_model"] = None
        return state

    def __setstate__(self, d):
        self.__dict__ = d

        # for backward compatibility
        if not hasattr(self, "sp_model_kwargs"):
            self.sp_model_kwargs = {}

        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        self.sp_model.Load(self.vocab_file)

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            return
        out_vocab_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
        )

        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
            copyfile(self.vocab_file, out_vocab_file)
        elif not os.path.isfile(self.vocab_file):
            with open(out_vocab_file, "wb") as fi:
                content_spiece_model = self.sp_model.serialized_model_proto()
                fi.write(content_spiece_model)

        return (out_vocab_file,)


__all__ = ["BarthezTokenizer"]
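A small standalone sketch (not part of the diff above) of the special-token layout documented in `build_inputs_with_special_tokens` (`<s> X </s>` for one sequence, `<s> A </s></s> B </s>` for a pair). The token IDs below are placeholders, not the real BARThez sentencepiece vocabulary.

# Illustrative only: mirrors BarthezTokenizer.build_inputs_with_special_tokens with made-up IDs.
CLS_ID, SEP_ID = 0, 2  # assumed placeholder ids for <s> and </s>

def build_inputs(ids_a, ids_b=None):
    if ids_b is None:
        return [CLS_ID] + ids_a + [SEP_ID]                          # <s> A </s>
    return [CLS_ID] + ids_a + [SEP_ID, SEP_ID] + ids_b + [SEP_ID]   # <s> A </s></s> B </s>

print(build_inputs([10, 11]))        # [0, 10, 11, 2]
print(build_inputs([10, 11], [20]))  # [0, 10, 11, 2, 2, 20, 2]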
venv/lib/python3.13/site-packages/transformers/models/barthez/tokenization_barthez_fast.py
ADDED
@@ -0,0 +1,193 @@
# coding=utf-8
# Copyright 2020 Ecole Polytechnique and the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
"""Tokenization classes for the BARThez model."""

import os
from shutil import copyfile
from typing import Optional

from ...tokenization_utils import AddedToken
from ...tokenization_utils_fast import PreTrainedTokenizerFast
from ...utils import is_sentencepiece_available, logging


if is_sentencepiece_available():
    from .tokenization_barthez import BarthezTokenizer
else:
    BarthezTokenizer = None

logger = logging.get_logger(__name__)

VOCAB_FILES_NAMES = {"vocab_file": "sentencepiece.bpe.model", "tokenizer_file": "tokenizer.json"}


SPIECE_UNDERLINE = "▁"


class BarthezTokenizerFast(PreTrainedTokenizerFast):
    """
    Adapted from [`CamembertTokenizer`] and [`BartTokenizer`]. Construct a "fast" BARThez tokenizer. Based on
    [SentencePiece](https://github.com/google/sentencepiece).

    This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
    refer to this superclass for more information regarding those methods.

    Args:
        vocab_file (`str`):
            [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that
            contains the vocabulary necessary to instantiate a tokenizer.
        bos_token (`str`, *optional*, defaults to `"<s>"`):
            The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier token.

            <Tip>

            When building a sequence using special tokens, this is not the token that is used for the beginning of
            sequence. The token used is the `cls_token`.

            </Tip>

        eos_token (`str`, *optional*, defaults to `"</s>"`):
            The end of sequence token.

            <Tip>

            When building a sequence using special tokens, this is not the token that is used for the end of sequence.
            The token used is the `sep_token`.

            </Tip>

        sep_token (`str`, *optional*, defaults to `"</s>"`):
            The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
            sequence classification or for a text and a question for question answering. It is also used as the last
            token of a sequence built with special tokens.
        cls_token (`str`, *optional*, defaults to `"<s>"`):
            The classifier token which is used when doing sequence classification (classification of the whole sequence
            instead of per-token classification). It is the first token of the sequence when built with special tokens.
        unk_token (`str`, *optional*, defaults to `"<unk>"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
        pad_token (`str`, *optional*, defaults to `"<pad>"`):
            The token used for padding, for example when batching sequences of different lengths.
        mask_token (`str`, *optional*, defaults to `"<mask>"`):
            The token used for masking values. This is the token used when training this model with masked language
            modeling. This is the token which the model will try to predict.
        additional_special_tokens (`list[str]`, *optional*, defaults to `["<s>NOTUSED", "</s>NOTUSED"]`):
            Additional special tokens used by the tokenizer.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    model_input_names = ["input_ids", "attention_mask"]
    slow_tokenizer_class = BarthezTokenizer

    def __init__(
        self,
        vocab_file=None,
        tokenizer_file=None,
        bos_token="<s>",
        eos_token="</s>",
        sep_token="</s>",
        cls_token="<s>",
        unk_token="<unk>",
        pad_token="<pad>",
        mask_token="<mask>",
        **kwargs,
    ):
        # Mask token behaves like a normal word, i.e. includes the space before it
        mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token

        super().__init__(
            vocab_file,
            tokenizer_file=tokenizer_file,
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            sep_token=sep_token,
            cls_token=cls_token,
            pad_token=pad_token,
            mask_token=mask_token,
            **kwargs,
        )

        self.vocab_file = vocab_file

    def build_inputs_with_special_tokens(
        self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
    ) -> list[int]:
        """
        Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
        adding special tokens. A BARThez sequence has the following format:

        - single sequence: `<s> X </s>`
        - pair of sequences: `<s> A </s></s> B </s>`

        Args:
            token_ids_0 (`list[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (`list[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `list[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
        """

        if token_ids_1 is None:
            return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
        cls = [self.cls_token_id]
        sep = [self.sep_token_id]
        return cls + token_ids_0 + sep + sep + token_ids_1 + sep

    def create_token_type_ids_from_sequences(
        self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
    ) -> list[int]:
        """
        Create a mask from the two sequences passed to be used in a sequence-pair classification task.

        Args:
            token_ids_0 (`list[int]`):
                List of IDs.
            token_ids_1 (`list[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `list[int]`: List of zeros.
        """
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]

        if token_ids_1 is None:
            return len(cls + token_ids_0 + sep) * [0]
        return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
        if not self.can_save_slow_tokenizer:
            raise ValueError(
                "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
                "tokenizer."
            )

        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            return
        out_vocab_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
        )

        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
            copyfile(self.vocab_file, out_vocab_file)

        return (out_vocab_file,)


__all__ = ["BarthezTokenizerFast"]
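For context, a typical way to exercise the fast tokenizer above through the auto classes. This is a sketch, not part of the diff: the checkpoint id `moussaKam/barthez` is an assumption (substitute any BARThez checkpoint), and the snippet needs the `tokenizers` backend plus network access to download the files.

# Illustrative only: loading the fast BARThez tokenizer via AutoTokenizer (assumed checkpoint id).
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("moussaKam/barthez", use_fast=True)
enc = tok("Paris est la capitale de la France.")
print(enc["input_ids"][:5], tok.convert_ids_to_tokens(enc["input_ids"])[:5])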
venv/lib/python3.13/site-packages/transformers/models/bert_japanese/__init__.py
ADDED
@@ -0,0 +1,26 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    from .tokenization_bert_japanese import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
venv/lib/python3.13/site-packages/transformers/models/bert_japanese/tokenization_bert_japanese.py
ADDED
@@ -0,0 +1,952 @@
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Tokenization classes."""
|
| 16 |
+
|
| 17 |
+
import collections
|
| 18 |
+
import copy
|
| 19 |
+
import os
|
| 20 |
+
import unicodedata
|
| 21 |
+
from typing import Any, Optional
|
| 22 |
+
|
| 23 |
+
from ...tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace
|
| 24 |
+
from ...utils import is_sentencepiece_available, is_sudachi_projection_available, logging
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
if is_sentencepiece_available():
|
| 28 |
+
import sentencepiece as spm
|
| 29 |
+
else:
|
| 30 |
+
spm = None
|
| 31 |
+
|
| 32 |
+
logger = logging.get_logger(__name__)
|
| 33 |
+
|
| 34 |
+
VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "spm_file": "spiece.model"}
|
| 35 |
+
|
| 36 |
+
SPIECE_UNDERLINE = "▁"
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
# Copied from transformers.models.bert.tokenization_bert.load_vocab
|
| 40 |
+
def load_vocab(vocab_file):
|
| 41 |
+
"""Loads a vocabulary file into a dictionary."""
|
| 42 |
+
vocab = collections.OrderedDict()
|
| 43 |
+
with open(vocab_file, "r", encoding="utf-8") as reader:
|
| 44 |
+
tokens = reader.readlines()
|
| 45 |
+
for index, token in enumerate(tokens):
|
| 46 |
+
token = token.rstrip("\n")
|
| 47 |
+
vocab[token] = index
|
| 48 |
+
return vocab
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
# Copied from transformers.models.bert.tokenization_bert.whitespace_tokenize
|
| 52 |
+
def whitespace_tokenize(text):
|
| 53 |
+
"""Runs basic whitespace cleaning and splitting on a piece of text."""
|
| 54 |
+
text = text.strip()
|
| 55 |
+
if not text:
|
| 56 |
+
return []
|
| 57 |
+
tokens = text.split()
|
| 58 |
+
return tokens
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class BertJapaneseTokenizer(PreTrainedTokenizer):
|
| 62 |
+
r"""
|
| 63 |
+
Construct a BERT tokenizer for Japanese text.
|
| 64 |
+
|
| 65 |
+
This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer
|
| 66 |
+
to: this superclass for more information regarding those methods.
|
| 67 |
+
|
| 68 |
+
Args:
|
| 69 |
+
vocab_file (`str`):
|
| 70 |
+
Path to a one-wordpiece-per-line vocabulary file.
|
| 71 |
+
spm_file (`str`, *optional*):
|
| 72 |
+
Path to [SentencePiece](https://github.com/google/sentencepiece) file (generally has a .spm or .model
|
| 73 |
+
extension) that contains the vocabulary.
|
| 74 |
+
do_lower_case (`bool`, *optional*, defaults to `True`):
|
| 75 |
+
Whether to lower case the input. Only has an effect when do_basic_tokenize=True.
|
| 76 |
+
do_word_tokenize (`bool`, *optional*, defaults to `True`):
|
| 77 |
+
Whether to do word tokenization.
|
| 78 |
+
do_subword_tokenize (`bool`, *optional*, defaults to `True`):
|
| 79 |
+
Whether to do subword tokenization.
|
| 80 |
+
word_tokenizer_type (`str`, *optional*, defaults to `"basic"`):
|
| 81 |
+
Type of word tokenizer. Choose from ["basic", "mecab", "sudachi", "jumanpp"].
|
| 82 |
+
subword_tokenizer_type (`str`, *optional*, defaults to `"wordpiece"`):
|
| 83 |
+
Type of subword tokenizer. Choose from ["wordpiece", "character", "sentencepiece",].
|
| 84 |
+
mecab_kwargs (`dict`, *optional*):
|
| 85 |
+
Dictionary passed to the `MecabTokenizer` constructor.
|
| 86 |
+
sudachi_kwargs (`dict`, *optional*):
|
| 87 |
+
Dictionary passed to the `SudachiTokenizer` constructor.
|
| 88 |
+
jumanpp_kwargs (`dict`, *optional*):
|
| 89 |
+
Dictionary passed to the `JumanppTokenizer` constructor.
|
| 90 |
+
"""
|
| 91 |
+
|
| 92 |
+
vocab_files_names = VOCAB_FILES_NAMES
|
| 93 |
+
|
| 94 |
+
def __init__(
|
| 95 |
+
self,
|
| 96 |
+
vocab_file,
|
| 97 |
+
spm_file=None,
|
| 98 |
+
do_lower_case=False,
|
| 99 |
+
do_word_tokenize=True,
|
| 100 |
+
do_subword_tokenize=True,
|
| 101 |
+
word_tokenizer_type="basic",
|
| 102 |
+
subword_tokenizer_type="wordpiece",
|
| 103 |
+
never_split=None,
|
| 104 |
+
unk_token="[UNK]",
|
| 105 |
+
sep_token="[SEP]",
|
| 106 |
+
pad_token="[PAD]",
|
| 107 |
+
cls_token="[CLS]",
|
| 108 |
+
mask_token="[MASK]",
|
| 109 |
+
mecab_kwargs=None,
|
| 110 |
+
sudachi_kwargs=None,
|
| 111 |
+
jumanpp_kwargs=None,
|
| 112 |
+
**kwargs,
|
| 113 |
+
):
|
| 114 |
+
if subword_tokenizer_type == "sentencepiece":
|
| 115 |
+
if not os.path.isfile(spm_file):
|
| 116 |
+
raise ValueError(
|
| 117 |
+
f"Can't find a vocabulary file at path '{spm_file}'. To load the vocabulary from a Google"
|
| 118 |
+
" pretrained model use `tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
|
| 119 |
+
)
|
| 120 |
+
self.spm_file = spm_file
|
| 121 |
+
else:
|
| 122 |
+
if not os.path.isfile(vocab_file):
|
| 123 |
+
raise ValueError(
|
| 124 |
+
f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google"
|
| 125 |
+
" pretrained model use `tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
|
| 126 |
+
)
|
| 127 |
+
self.vocab = load_vocab(vocab_file)
|
| 128 |
+
self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
|
| 129 |
+
|
| 130 |
+
self.do_word_tokenize = do_word_tokenize
|
| 131 |
+
self.word_tokenizer_type = word_tokenizer_type
|
| 132 |
+
self.lower_case = do_lower_case
|
| 133 |
+
self.never_split = never_split
|
| 134 |
+
self.mecab_kwargs = copy.deepcopy(mecab_kwargs)
|
| 135 |
+
self.sudachi_kwargs = copy.deepcopy(sudachi_kwargs)
|
| 136 |
+
self.jumanpp_kwargs = copy.deepcopy(jumanpp_kwargs)
|
| 137 |
+
if do_word_tokenize:
|
| 138 |
+
if word_tokenizer_type == "basic":
|
| 139 |
+
self.word_tokenizer = BasicTokenizer(
|
| 140 |
+
do_lower_case=do_lower_case, never_split=never_split, tokenize_chinese_chars=False
|
| 141 |
+
)
|
| 142 |
+
elif word_tokenizer_type == "mecab":
|
| 143 |
+
self.word_tokenizer = MecabTokenizer(
|
| 144 |
+
do_lower_case=do_lower_case, never_split=never_split, **(mecab_kwargs or {})
|
| 145 |
+
)
|
| 146 |
+
elif word_tokenizer_type == "sudachi":
|
| 147 |
+
self.word_tokenizer = SudachiTokenizer(
|
| 148 |
+
do_lower_case=do_lower_case, never_split=never_split, **(sudachi_kwargs or {})
|
| 149 |
+
)
|
| 150 |
+
elif word_tokenizer_type == "jumanpp":
|
| 151 |
+
self.word_tokenizer = JumanppTokenizer(
|
| 152 |
+
do_lower_case=do_lower_case, never_split=never_split, **(jumanpp_kwargs or {})
|
| 153 |
+
)
|
| 154 |
+
else:
|
| 155 |
+
raise ValueError(f"Invalid word_tokenizer_type '{word_tokenizer_type}' is specified.")
|
| 156 |
+
|
| 157 |
+
self.do_subword_tokenize = do_subword_tokenize
|
| 158 |
+
self.subword_tokenizer_type = subword_tokenizer_type
|
| 159 |
+
if do_subword_tokenize:
|
| 160 |
+
if subword_tokenizer_type == "wordpiece":
|
| 161 |
+
self.subword_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=str(unk_token))
|
| 162 |
+
elif subword_tokenizer_type == "character":
|
| 163 |
+
self.subword_tokenizer = CharacterTokenizer(vocab=self.vocab, unk_token=str(unk_token))
|
| 164 |
+
elif subword_tokenizer_type == "sentencepiece":
|
| 165 |
+
self.subword_tokenizer = SentencepieceTokenizer(vocab=self.spm_file, unk_token=str(unk_token))
|
| 166 |
+
else:
|
| 167 |
+
raise ValueError(f"Invalid subword_tokenizer_type '{subword_tokenizer_type}' is specified.")
|
| 168 |
+
super().__init__(
|
| 169 |
+
spm_file=spm_file,
|
| 170 |
+
unk_token=unk_token,
|
| 171 |
+
sep_token=sep_token,
|
| 172 |
+
pad_token=pad_token,
|
| 173 |
+
cls_token=cls_token,
|
| 174 |
+
mask_token=mask_token,
|
| 175 |
+
do_lower_case=do_lower_case,
|
| 176 |
+
do_word_tokenize=do_word_tokenize,
|
| 177 |
+
do_subword_tokenize=do_subword_tokenize,
|
| 178 |
+
word_tokenizer_type=word_tokenizer_type,
|
| 179 |
+
subword_tokenizer_type=subword_tokenizer_type,
|
| 180 |
+
never_split=never_split,
|
| 181 |
+
mecab_kwargs=mecab_kwargs,
|
| 182 |
+
sudachi_kwargs=sudachi_kwargs,
|
| 183 |
+
jumanpp_kwargs=jumanpp_kwargs,
|
| 184 |
+
**kwargs,
|
| 185 |
+
)
|
| 186 |
+
|
| 187 |
+
@property
|
| 188 |
+
def do_lower_case(self):
|
| 189 |
+
return self.lower_case
|
| 190 |
+
|
| 191 |
+
def __getstate__(self):
|
| 192 |
+
state = dict(self.__dict__)
|
| 193 |
+
if self.word_tokenizer_type in ["mecab", "sudachi", "jumanpp"]:
|
| 194 |
+
del state["word_tokenizer"]
|
| 195 |
+
return state
|
| 196 |
+
|
| 197 |
+
def __setstate__(self, state):
|
| 198 |
+
self.__dict__ = state
|
| 199 |
+
if self.word_tokenizer_type == "mecab":
|
| 200 |
+
self.word_tokenizer = MecabTokenizer(
|
| 201 |
+
do_lower_case=self.do_lower_case, never_split=self.never_split, **(self.mecab_kwargs or {})
|
| 202 |
+
)
|
| 203 |
+
elif self.word_tokenizer_type == "sudachi":
|
| 204 |
+
self.word_tokenizer = SudachiTokenizer(
|
| 205 |
+
do_lower_case=self.do_lower_case, never_split=self.never_split, **(self.sudachi_kwargs or {})
|
| 206 |
+
)
|
| 207 |
+
elif self.word_tokenizer_type == "jumanpp":
|
| 208 |
+
self.word_tokenizer = JumanppTokenizer(
|
| 209 |
+
do_lower_case=self.do_lower_case, never_split=self.never_split, **(self.jumanpp_kwargs or {})
|
| 210 |
+
)
|
| 211 |
+
|
| 212 |
+
def _tokenize(self, text):
|
| 213 |
+
if self.do_word_tokenize:
|
| 214 |
+
tokens = self.word_tokenizer.tokenize(text, never_split=self.all_special_tokens)
|
| 215 |
+
else:
|
| 216 |
+
tokens = [text]
|
| 217 |
+
|
| 218 |
+
if self.do_subword_tokenize:
|
| 219 |
+
split_tokens = [sub_token for token in tokens for sub_token in self.subword_tokenizer.tokenize(token)]
|
| 220 |
+
else:
|
| 221 |
+
split_tokens = tokens
|
| 222 |
+
|
| 223 |
+
return split_tokens
|
| 224 |
+
|
| 225 |
+
@property
|
| 226 |
+
def vocab_size(self):
|
| 227 |
+
if self.subword_tokenizer_type == "sentencepiece":
|
| 228 |
+
return len(self.subword_tokenizer.sp_model)
|
| 229 |
+
return len(self.vocab)
|
| 230 |
+
|
| 231 |
+
def get_vocab(self):
|
| 232 |
+
if self.subword_tokenizer_type == "sentencepiece":
|
| 233 |
+
vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
|
| 234 |
+
vocab.update(self.added_tokens_encoder)
|
| 235 |
+
return vocab
|
| 236 |
+
return dict(self.vocab, **self.added_tokens_encoder)
|
| 237 |
+
|
| 238 |
+
def _convert_token_to_id(self, token):
|
| 239 |
+
"""Converts a token (str) in an id using the vocab."""
|
| 240 |
+
if self.subword_tokenizer_type == "sentencepiece":
|
| 241 |
+
return self.subword_tokenizer.sp_model.PieceToId(token)
|
| 242 |
+
return self.vocab.get(token, self.vocab.get(self.unk_token))
|
| 243 |
+
|
| 244 |
+
def _convert_id_to_token(self, index):
|
| 245 |
+
"""Converts an index (integer) in a token (str) using the vocab."""
|
| 246 |
+
if self.subword_tokenizer_type == "sentencepiece":
|
| 247 |
+
return self.subword_tokenizer.sp_model.IdToPiece(index)
|
| 248 |
+
return self.ids_to_tokens.get(index, self.unk_token)
|
| 249 |
+
|
| 250 |
+
def convert_tokens_to_string(self, tokens):
|
| 251 |
+
"""Converts a sequence of tokens (string) in a single string."""
|
| 252 |
+
if self.subword_tokenizer_type == "sentencepiece":
|
| 253 |
+
return self.subword_tokenizer.sp_model.decode(tokens)
|
| 254 |
+
out_string = " ".join(tokens).replace(" ##", "").strip()
|
| 255 |
+
return out_string
|
| 256 |
+
|
| 257 |
+
# Copied from transformers.models.bert.tokenization_bert.BertTokenizer.build_inputs_with_special_tokens
|
| 258 |
+
def build_inputs_with_special_tokens(
|
| 259 |
+
self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
|
| 260 |
+
) -> list[int]:
|
| 261 |
+
"""
|
| 262 |
+
Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
|
| 263 |
+
adding special tokens. A BERT sequence has the following format:
|
| 264 |
+
|
| 265 |
+
- single sequence: `[CLS] X [SEP]`
|
| 266 |
+
- pair of sequences: `[CLS] A [SEP] B [SEP]`
|
| 267 |
+
|
| 268 |
+
Args:
|
| 269 |
+
token_ids_0 (`List[int]`):
|
| 270 |
+
List of IDs to which the special tokens will be added.
|
| 271 |
+
token_ids_1 (`List[int]`, *optional*):
|
| 272 |
+
Optional second list of IDs for sequence pairs.
|
| 273 |
+
|
| 274 |
+
Returns:
|
| 275 |
+
`List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
|
| 276 |
+
"""
|
| 277 |
+
if token_ids_1 is None:
|
| 278 |
+
return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
|
| 279 |
+
cls = [self.cls_token_id]
|
| 280 |
+
sep = [self.sep_token_id]
|
| 281 |
+
return cls + token_ids_0 + sep + token_ids_1 + sep
|
| 282 |
+
|
| 283 |
+
# Copied from transformers.models.bert.tokenization_bert.BertTokenizer.get_special_tokens_mask
|
| 284 |
+
def get_special_tokens_mask(
|
| 285 |
+
self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None, already_has_special_tokens: bool = False
|
| 286 |
+
) -> list[int]:
|
| 287 |
+
"""
|
| 288 |
+
Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
|
| 289 |
+
special tokens using the tokenizer `prepare_for_model` method.
|
| 290 |
+
|
| 291 |
+
Args:
|
| 292 |
+
token_ids_0 (`List[int]`):
|
| 293 |
+
List of IDs.
|
| 294 |
+
token_ids_1 (`List[int]`, *optional*):
|
| 295 |
+
Optional second list of IDs for sequence pairs.
|
| 296 |
+
already_has_special_tokens (`bool`, *optional*, defaults to `False`):
|
| 297 |
+
Whether or not the token list is already formatted with special tokens for the model.
|
| 298 |
+
|
| 299 |
+
Returns:
|
| 300 |
+
`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
|
| 301 |
+
"""
|
| 302 |
+
|
| 303 |
+
if already_has_special_tokens:
|
| 304 |
+
return super().get_special_tokens_mask(
|
| 305 |
+
token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
|
| 306 |
+
)
|
| 307 |
+
|
| 308 |
+
if token_ids_1 is not None:
|
| 309 |
+
return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
|
| 310 |
+
return [1] + ([0] * len(token_ids_0)) + [1]
|
| 311 |
+
|
| 312 |
+
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
|
| 313 |
+
if os.path.isdir(save_directory):
|
| 314 |
+
if self.subword_tokenizer_type == "sentencepiece":
|
| 315 |
+
vocab_file = os.path.join(
|
| 316 |
+
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["spm_file"]
|
| 317 |
+
)
|
| 318 |
+
else:
|
| 319 |
+
vocab_file = os.path.join(
|
| 320 |
+
save_directory,
|
| 321 |
+
(filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"],
|
| 322 |
+
)
|
| 323 |
+
else:
|
| 324 |
+
vocab_file = (filename_prefix + "-" if filename_prefix else "") + save_directory
|
| 325 |
+
|
| 326 |
+
if self.subword_tokenizer_type == "sentencepiece":
|
| 327 |
+
with open(vocab_file, "wb") as writer:
|
| 328 |
+
content_spiece_model = self.subword_tokenizer.sp_model.serialized_model_proto()
|
| 329 |
+
writer.write(content_spiece_model)
|
| 330 |
+
else:
|
| 331 |
+
with open(vocab_file, "w", encoding="utf-8") as writer:
|
| 332 |
+
index = 0
|
| 333 |
+
for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
|
| 334 |
+
if index != token_index:
|
| 335 |
+
logger.warning(
|
| 336 |
+
f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive."
|
| 337 |
+
" Please check that the vocabulary is not corrupted!"
|
| 338 |
+
)
|
| 339 |
+
index = token_index
|
| 340 |
+
writer.write(token + "\n")
|
| 341 |
+
index += 1
|
| 342 |
+
return (vocab_file,)

class MecabTokenizer:
    """Runs basic tokenization with MeCab morphological parser."""

    def __init__(
        self,
        do_lower_case=False,
        never_split=None,
        normalize_text=True,
        mecab_dic: Optional[str] = "unidic_lite",
        mecab_option: Optional[str] = None,
    ):
        """
        Constructs a MecabTokenizer.

        Args:
            **do_lower_case**: (*optional*) boolean (default False)
                Whether to lowercase the input.
            **never_split**: (*optional*) list of str
                Kept for backward compatibility purposes. Now implemented directly at the base class level (see
                [`PreTrainedTokenizer.tokenize`]) List of tokens not to split.
            **normalize_text**: (*optional*) boolean (default True)
                Whether to apply unicode normalization to text before tokenization.
            **mecab_dic**: (*optional*) string (default "unidic_lite")
                Name of dictionary to be used for MeCab initialization. If you are using a system-installed dictionary,
                set this option to `None` and modify *mecab_option*.
            **mecab_option**: (*optional*) string
                String passed to MeCab constructor.
        """
        self.do_lower_case = do_lower_case
        self.never_split = never_split if never_split is not None else []
        self.normalize_text = normalize_text

        try:
            import fugashi
        except ModuleNotFoundError as error:
            raise error.__class__(
                "You need to install fugashi to use MecabTokenizer. "
                "See https://pypi.org/project/fugashi/ for installation."
            )

        mecab_option = mecab_option or ""

        if mecab_dic is not None:
            if mecab_dic == "ipadic":
                try:
                    import ipadic
                except ModuleNotFoundError as error:
                    raise error.__class__(
                        "The ipadic dictionary is not installed. "
                        "See https://github.com/polm/ipadic-py for installation."
                    )

                dic_dir = ipadic.DICDIR

            elif mecab_dic == "unidic_lite":
                try:
                    import unidic_lite
                except ModuleNotFoundError as error:
                    raise error.__class__(
                        "The unidic_lite dictionary is not installed. "
                        "See https://github.com/polm/unidic-lite for installation."
                    )

                dic_dir = unidic_lite.DICDIR

            elif mecab_dic == "unidic":
                try:
                    import unidic
                except ModuleNotFoundError as error:
                    raise error.__class__(
                        "The unidic dictionary is not installed. "
                        "See https://github.com/polm/unidic-py for installation."
                    )

                dic_dir = unidic.DICDIR
                if not os.path.isdir(dic_dir):
                    raise RuntimeError(
                        "The unidic dictionary itself is not found. "
                        "See https://github.com/polm/unidic-py for installation."
                    )

            else:
                raise ValueError("Invalid mecab_dic is specified.")

            mecabrc = os.path.join(dic_dir, "mecabrc")
            mecab_option = f'-d "{dic_dir}" -r "{mecabrc}" ' + mecab_option

        self.mecab = fugashi.GenericTagger(mecab_option)

    def tokenize(self, text, never_split=None, **kwargs):
        """Tokenizes a piece of text."""
        if self.normalize_text:
            text = unicodedata.normalize("NFKC", text)

        never_split = self.never_split + (never_split if never_split is not None else [])
        tokens = []

        for word in self.mecab(text):
            token = word.surface

            if self.do_lower_case and token not in never_split:
                token = token.lower()

            tokens.append(token)

        return tokens
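
A usage sketch for the MeCab path, assuming `fugashi` and the `unidic_lite` dictionary are installed; the sample sentence and its segmentation are illustrative:

    # pip install fugashi unidic_lite
    from transformers.models.bert_japanese.tokenization_bert_japanese import MecabTokenizer

    tokenizer = MecabTokenizer(mecab_dic="unidic_lite")
    print(tokenizer.tokenize("こんにちは、世界。"))  # e.g. ['こんにちは', '、', '世界', '。']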

class SudachiTokenizer:
    """Runs basic tokenization with Sudachi morphological parser."""

    def __init__(
        self,
        do_lower_case=False,
        never_split=None,
        normalize_text=True,
        trim_whitespace=False,
        sudachi_split_mode="A",
        sudachi_config_path=None,
        sudachi_resource_dir=None,
        sudachi_dict_type="core",
        sudachi_projection=None,
    ):
        """
        Constructs a SudachiTokenizer.

        Args:
            **do_lower_case**: (*optional*) boolean (default False)
                Whether to lowercase the input.
            **never_split**: (*optional*) list of str
                Kept for backward compatibility purposes. Now implemented directly at the base class level (see
                [`PreTrainedTokenizer.tokenize`]) List of tokens not to split.
            **normalize_text**: (*optional*) boolean (default True)
                Whether to apply unicode normalization to text before tokenization.
            **trim_whitespace**: (*optional*) boolean (default False)
                Whether to trim all whitespace, tab, newline from tokens.
            **sudachi_split_mode**: (*optional*) string
                Split mode of sudachi, choose from `["A", "B", "C"]`.
            **sudachi_config_path**: (*optional*) string
                Path to a Sudachi configuration file.
            **sudachi_resource_dir**: (*optional*) string
                Path to a directory containing Sudachi resources.
            **sudachi_dict_type**: (*optional*) string
                dict type of sudachi, choose from `["small", "core", "full"]`.
            **sudachi_projection**: (*optional*) string
                Word projection mode of sudachi, choose from `["surface", "normalized", "reading", "dictionary", "dictionary_and_surface", "normalized_and_surface", "normalized_nouns"]`.
        """

        self.do_lower_case = do_lower_case
        self.never_split = never_split if never_split is not None else []
        self.normalize_text = normalize_text
        self.trim_whitespace = trim_whitespace

        try:
            from sudachipy import dictionary, tokenizer
        except ImportError:
            raise ImportError(
                "You need to install sudachipy to use SudachiTokenizer. "
                "See https://github.com/WorksApplications/SudachiPy for installation."
            )

        if sudachi_split_mode == "A":
            self.split_mode = tokenizer.Tokenizer.SplitMode.A
        elif sudachi_split_mode == "B":
            self.split_mode = tokenizer.Tokenizer.SplitMode.B
        elif sudachi_split_mode == "C":
            self.split_mode = tokenizer.Tokenizer.SplitMode.C
        else:
            raise ValueError("Invalid sudachi_split_mode is specified.")

        self.projection = sudachi_projection

        sudachi_dictionary = dictionary.Dictionary(
            config_path=sudachi_config_path, resource_dir=sudachi_resource_dir, dict=sudachi_dict_type
        )
        if is_sudachi_projection_available():
            self.sudachi = sudachi_dictionary.create(self.split_mode, projection=self.projection)
        elif self.projection is not None:
            raise ImportError("You need to install sudachipy>=0.6.8 to specify `projection` field in sudachi_kwargs.")
        else:
            self.sudachi = sudachi_dictionary.create(self.split_mode)

    def tokenize(self, text, never_split=None, **kwargs):
        """Tokenizes a piece of text."""
        if self.normalize_text:
            text = unicodedata.normalize("NFKC", text)

        never_split = self.never_split + (never_split if never_split is not None else [])
        tokens = []

        for word in self.sudachi.tokenize(text):
            token = word.surface()

            if self.do_lower_case and token not in never_split:
                token = token.lower()

            if self.trim_whitespace:
                if token.strip() == "":
                    continue
                else:
                    token = token.strip()

            tokens.append(token)

        return tokens
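
The Sudachi path follows the same pattern; a sketch assuming `sudachipy` and the `sudachidict_core` dictionary are installed (the segmentation shown is illustrative and depends on the split mode):

    # pip install sudachipy sudachidict_core
    from transformers.models.bert_japanese.tokenization_bert_japanese import SudachiTokenizer

    tokenizer = SudachiTokenizer(sudachi_split_mode="C", sudachi_dict_type="core")
    print(tokenizer.tokenize("国会議事堂前駅"))  # mode C keeps the compound, e.g. ['国会議事堂前駅']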

class JumanppTokenizer:
    """Runs basic tokenization with jumanpp morphological parser."""

    def __init__(
        self,
        do_lower_case=False,
        never_split=None,
        normalize_text=True,
        trim_whitespace=False,
    ):
        """
        Constructs a JumanppTokenizer.

        Args:
            **do_lower_case**: (*optional*) boolean (default False)
                Whether to lowercase the input.
            **never_split**: (*optional*) list of str
                Kept for backward compatibility purposes. Now implemented directly at the base class level (see
                [`PreTrainedTokenizer.tokenize`]) List of tokens not to split.
            **normalize_text**: (*optional*) boolean (default True)
                Whether to apply unicode normalization to text before tokenization.
            **trim_whitespace**: (*optional*) boolean (default False)
                Whether to trim all whitespace, tab, newline from tokens.
        """

        self.do_lower_case = do_lower_case
        self.never_split = never_split if never_split is not None else []
        self.normalize_text = normalize_text
        self.trim_whitespace = trim_whitespace

        try:
            import rhoknp
        except ImportError:
            raise ImportError(
                "You need to install rhoknp to use JumanppTokenizer. "
                "See https://github.com/ku-nlp/rhoknp for installation."
            )

        self.juman = rhoknp.Jumanpp()

    def tokenize(self, text, never_split=None, **kwargs):
        """Tokenizes a piece of text."""
        if self.normalize_text:
            text = unicodedata.normalize("NFKC", text)

        text = text.strip()

        never_split = self.never_split + (never_split if never_split is not None else [])
        tokens = []

        for mrph in self.juman.apply_to_sentence(text).morphemes:
            token = mrph.text

            if self.do_lower_case and token not in never_split:
                token = token.lower()

            if self.trim_whitespace:
                if token.strip() == "":
                    continue
                else:
                    token = token.strip()

            tokens.append(token)

        return tokens

class CharacterTokenizer:
    """Runs Character tokenization."""

    def __init__(self, vocab, unk_token, normalize_text=True):
        """
        Constructs a CharacterTokenizer.

        Args:
            **vocab**:
                Vocabulary object.
            **unk_token**: str
                A special symbol for out-of-vocabulary token.
            **normalize_text**: (`optional`) boolean (default True)
                Whether to apply unicode normalization to text before tokenization.
        """
        self.vocab = vocab
        self.unk_token = unk_token
        self.normalize_text = normalize_text

    def tokenize(self, text):
        """
        Tokenizes a piece of text into characters.

        For example, `input = "apple"` will return as output `["a", "p", "p", "l", "e"]`.

        Args:
            text: A single token or whitespace separated tokens.
                This should have already been passed through *BasicTokenizer*.

        Returns:
            A list of characters.
        """
        if self.normalize_text:
            text = unicodedata.normalize("NFKC", text)

        output_tokens = []
        for char in text:
            if char not in self.vocab:
                output_tokens.append(self.unk_token)
                continue

            output_tokens.append(char)

        return output_tokens
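
Because `CharacterTokenizer` only needs a plain dict, it can be exercised without any model files; a small sketch with a made-up three-character vocabulary:

    from transformers.models.bert_japanese.tokenization_bert_japanese import CharacterTokenizer

    vocab = {"日": 1, "本": 2, "語": 3}  # toy vocabulary, for illustration only
    tokenizer = CharacterTokenizer(vocab=vocab, unk_token="[UNK]")
    print(tokenizer.tokenize("日本語です"))  # ['日', '本', '語', '[UNK]', '[UNK]']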

# Copied from transformers.models.bert.tokenization_bert.BasicTokenizer
class BasicTokenizer:
    """
    Constructs a BasicTokenizer that will run basic tokenization (punctuation splitting, lower casing, etc.).

    Args:
        do_lower_case (`bool`, *optional*, defaults to `True`):
            Whether or not to lowercase the input when tokenizing.
        never_split (`Iterable`, *optional*):
            Collection of tokens which will never be split during tokenization. Only has an effect when
            `do_basic_tokenize=True`
        tokenize_chinese_chars (`bool`, *optional*, defaults to `True`):
            Whether or not to tokenize Chinese characters.

            This should likely be deactivated for Japanese (see this
            [issue](https://github.com/huggingface/transformers/issues/328)).
        strip_accents (`bool`, *optional*):
            Whether or not to strip all accents. If this option is not specified, then it will be determined by the
            value for `lowercase` (as in the original BERT).
        do_split_on_punc (`bool`, *optional*, defaults to `True`):
            In some instances we want to skip the basic punctuation splitting so that later tokenization can capture
            the full context of the words, such as contractions.
    """

    def __init__(
        self,
        do_lower_case=True,
        never_split=None,
        tokenize_chinese_chars=True,
        strip_accents=None,
        do_split_on_punc=True,
    ):
        if never_split is None:
            never_split = []
        self.do_lower_case = do_lower_case
        self.never_split = set(never_split)
        self.tokenize_chinese_chars = tokenize_chinese_chars
        self.strip_accents = strip_accents
        self.do_split_on_punc = do_split_on_punc

    def tokenize(self, text, never_split=None):
        """
        Basic Tokenization of a piece of text. For sub-word tokenization, see WordPieceTokenizer.

        Args:
            never_split (`List[str]`, *optional*):
                Kept for backward compatibility purposes. Now implemented directly at the base class level (see
                [`PreTrainedTokenizer.tokenize`]) List of tokens not to split.
        """
        # union() returns a new set by concatenating the two sets.
        never_split = self.never_split.union(set(never_split)) if never_split else self.never_split
        text = self._clean_text(text)

        # This was added on November 1st, 2018 for the multilingual and Chinese
        # models. This is also applied to the English models now, but it doesn't
        # matter since the English models were not trained on any Chinese data
        # and generally don't have any Chinese data in them (there are Chinese
        # characters in the vocabulary because Wikipedia does have some Chinese
        # words in the English Wikipedia.).
        if self.tokenize_chinese_chars:
            text = self._tokenize_chinese_chars(text)
        # prevents treating the same character with different unicode codepoints as different characters
        unicode_normalized_text = unicodedata.normalize("NFC", text)
        orig_tokens = whitespace_tokenize(unicode_normalized_text)
        split_tokens = []
        for token in orig_tokens:
            if token not in never_split:
                if self.do_lower_case:
                    token = token.lower()
                    if self.strip_accents is not False:
                        token = self._run_strip_accents(token)
                elif self.strip_accents:
                    token = self._run_strip_accents(token)
            split_tokens.extend(self._run_split_on_punc(token, never_split))

        output_tokens = whitespace_tokenize(" ".join(split_tokens))
        return output_tokens

    def _run_strip_accents(self, text):
        """Strips accents from a piece of text."""
        text = unicodedata.normalize("NFD", text)
        output = []
        for char in text:
            cat = unicodedata.category(char)
            if cat == "Mn":
                continue
            output.append(char)
        return "".join(output)

    def _run_split_on_punc(self, text, never_split=None):
        """Splits punctuation on a piece of text."""
        if not self.do_split_on_punc or (never_split is not None and text in never_split):
            return [text]
        chars = list(text)
        i = 0
        start_new_word = True
        output = []
        while i < len(chars):
            char = chars[i]
            if _is_punctuation(char):
                output.append([char])
                start_new_word = True
            else:
                if start_new_word:
                    output.append([])
                start_new_word = False
                output[-1].append(char)
            i += 1

        return ["".join(x) for x in output]

    def _tokenize_chinese_chars(self, text):
        """Adds whitespace around any CJK character."""
        output = []
        for char in text:
            cp = ord(char)
            if self._is_chinese_char(cp):
                output.append(" ")
                output.append(char)
                output.append(" ")
            else:
                output.append(char)
        return "".join(output)

    def _is_chinese_char(self, cp):
        """Checks whether CP is the codepoint of a CJK character."""
        # This defines a "chinese character" as anything in the CJK Unicode block:
        #   https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
        #
        # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
        # despite its name. The modern Korean Hangul alphabet is a different block,
        # as is Japanese Hiragana and Katakana. Those alphabets are used to write
        # space-separated words, so they are not treated specially and are handled
        # like all of the other languages.
        if (
            (cp >= 0x4E00 and cp <= 0x9FFF)
            or (cp >= 0x3400 and cp <= 0x4DBF)
            or (cp >= 0x20000 and cp <= 0x2A6DF)
            or (cp >= 0x2A700 and cp <= 0x2B73F)
            or (cp >= 0x2B740 and cp <= 0x2B81F)
            or (cp >= 0x2B820 and cp <= 0x2CEAF)
            or (cp >= 0xF900 and cp <= 0xFAFF)
            or (cp >= 0x2F800 and cp <= 0x2FA1F)
        ):
            return True

        return False

    def _clean_text(self, text):
        """Performs invalid character removal and whitespace cleanup on text."""
        output = []
        for char in text:
            cp = ord(char)
            if cp == 0 or cp == 0xFFFD or _is_control(char):
                continue
            if _is_whitespace(char):
                output.append(" ")
            else:
                output.append(char)
        return "".join(output)

# Copied from transformers.models.bert.tokenization_bert.WordpieceTokenizer
class WordpieceTokenizer:
    """Runs WordPiece tokenization."""

    def __init__(self, vocab, unk_token, max_input_chars_per_word=100):
        self.vocab = vocab
        self.unk_token = unk_token
        self.max_input_chars_per_word = max_input_chars_per_word

    def tokenize(self, text):
        """
        Tokenizes a piece of text into its word pieces. This uses a greedy longest-match-first algorithm to perform
        tokenization using the given vocabulary.

        For example, `input = "unaffable"` will return as output `["un", "##aff", "##able"]`.

        Args:
            text: A single token or whitespace separated tokens. This should have
                already been passed through *BasicTokenizer*.

        Returns:
            A list of wordpiece tokens.
        """

        output_tokens = []
        for token in whitespace_tokenize(text):
            chars = list(token)
            if len(chars) > self.max_input_chars_per_word:
                output_tokens.append(self.unk_token)
                continue

            is_bad = False
            start = 0
            sub_tokens = []
            while start < len(chars):
                end = len(chars)
                cur_substr = None
                while start < end:
                    substr = "".join(chars[start:end])
                    if start > 0:
                        substr = "##" + substr
                    if substr in self.vocab:
                        cur_substr = substr
                        break
                    end -= 1
                if cur_substr is None:
                    is_bad = True
                    break
                sub_tokens.append(cur_substr)
                start = end

            if is_bad:
                output_tokens.append(self.unk_token)
            else:
                output_tokens.extend(sub_tokens)
        return output_tokens
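
The greedy longest-match-first behavior is easy to see with a toy vocabulary; the vocabulary below is invented for illustration, and any container supporting `in` works:

    from transformers.models.bert_japanese.tokenization_bert_japanese import WordpieceTokenizer

    vocab = {"un", "##aff", "##able"}  # toy vocabulary
    tokenizer = WordpieceTokenizer(vocab=vocab, unk_token="[UNK]")
    print(tokenizer.tokenize("unaffable"))     # ['un', '##aff', '##able']
    print(tokenizer.tokenize("unbelievable"))  # ['[UNK]'] (the toy vocab cannot cover this word)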

class SentencepieceTokenizer:
    """
    Runs sentencepiece tokenization. Based on transformers.models.albert.tokenization_albert.AlbertTokenizer.
    """

    def __init__(
        self,
        vocab,
        unk_token,
        do_lower_case=False,
        remove_space=True,
        keep_accents=True,
        sp_model_kwargs: Optional[dict[str, Any]] = None,
    ):
        self.vocab = vocab
        self.unk_token = unk_token
        self.do_lower_case = do_lower_case
        self.remove_space = remove_space
        self.keep_accents = keep_accents

        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        self.sp_model.Load(self.vocab)

    def preprocess_text(self, inputs):
        if self.remove_space:
            outputs = " ".join(inputs.strip().split())
        else:
            outputs = inputs
        outputs = outputs.replace("``", '"').replace("''", '"')

        if not self.keep_accents:
            outputs = unicodedata.normalize("NFKD", outputs)
            outputs = "".join([c for c in outputs if not unicodedata.combining(c)])
        if self.do_lower_case:
            outputs = outputs.lower()

        return outputs

    def tokenize(self, text):
        """
        Tokenizes text by sentencepiece. Based on [SentencePiece](https://github.com/google/sentencepiece).
        Tokenization needs the given vocabulary.

        Args:
            text: A string that needs to be tokenized.

        Returns:
            A list of sentencepiece tokens.
        """
        text = self.preprocess_text(text)
        pieces = self.sp_model.encode(text, out_type=str)
        new_pieces = []
        for piece in pieces:
            if len(piece) > 1 and piece[-1] == "," and piece[-2].isdigit():
                cur_pieces = self.sp_model.EncodeAsPieces(piece[:-1].replace(SPIECE_UNDERLINE, ""))
                if piece[0] != SPIECE_UNDERLINE and cur_pieces[0][0] == SPIECE_UNDERLINE:
                    if len(cur_pieces[0]) == 1:
                        cur_pieces = cur_pieces[1:]
                    else:
                        cur_pieces[0] = cur_pieces[0][1:]
                cur_pieces.append(piece[-1])
                new_pieces.extend(cur_pieces)
            else:
                new_pieces.append(piece)

        return new_pieces
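
The `preprocess_text` step is plain string processing and can be checked without loading a SentencePiece model; a standalone restatement of the same steps, assuming `remove_space=True`, `keep_accents=False`, and `do_lower_case=True`:

    import unicodedata

    text = "  Café   ``quoted''  "
    out = " ".join(text.strip().split())             # collapse whitespace runs
    out = out.replace("``", '"').replace("''", '"')  # normalize LaTeX-style quotes
    out = unicodedata.normalize("NFKD", out)
    out = "".join(c for c in out if not unicodedata.combining(c))  # strip accents
    print(out.lower())  # 'cafe "quoted"'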

__all__ = ["BertJapaneseTokenizer", "CharacterTokenizer", "MecabTokenizer"]
venv/lib/python3.13/site-packages/transformers/models/bertweet/__init__.py
ADDED
@@ -0,0 +1,26 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    from .tokenization_bertweet import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
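
The branch above implements the repo-wide lazy-import pattern: at type-checking time the submodule is imported eagerly for static analyzers, while at runtime the package object is swapped for a `_LazyModule` that resolves attributes on first access. From a caller's point of view nothing changes; a usage sketch:

    # The first attribute access triggers the real submodule import.
    from transformers.models.bertweet import BertweetTokenizer

    print(BertweetTokenizer.__name__)  # 'BertweetTokenizer'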
venv/lib/python3.13/site-packages/transformers/models/bertweet/tokenization_bertweet.py
ADDED
@@ -0,0 +1,769 @@
# coding=utf-8
# Copyright (c) 2020, VinAI Research and the HuggingFace Inc. team.
# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for BERTweet"""

import html
import os
import re
from shutil import copyfile
from typing import Optional

import regex

from ...tokenization_utils import PreTrainedTokenizer
from ...utils import logging


logger = logging.get_logger(__name__)

VOCAB_FILES_NAMES = {
    "vocab_file": "vocab.txt",
    "merges_file": "bpe.codes",
}


def get_pairs(word):
    """
    Return set of symbol pairs in a word.

    Word is represented as tuple of symbols (symbols being variable-length strings).
    """
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char

    pairs = set(pairs)
    return pairs


class BertweetTokenizer(PreTrainedTokenizer):
    """
    Constructs a BERTweet tokenizer, using Byte-Pair-Encoding.

    This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
    this superclass for more information regarding those methods.

    Args:
        vocab_file (`str`):
            Path to the vocabulary file.
        merges_file (`str`):
            Path to the merges file.
        normalization (`bool`, *optional*, defaults to `False`):
            Whether or not to apply a normalization preprocess.
        bos_token (`str`, *optional*, defaults to `"<s>"`):
            The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier
            token.

            <Tip>

            When building a sequence using special tokens, this is not the token that is used for the beginning of
            sequence. The token used is the `cls_token`.

            </Tip>

        eos_token (`str`, *optional*, defaults to `"</s>"`):
            The end of sequence token.

            <Tip>

            When building a sequence using special tokens, this is not the token that is used for the end of sequence.
            The token used is the `sep_token`.

            </Tip>

        sep_token (`str`, *optional*, defaults to `"</s>"`):
            The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
            sequence classification or for a text and a question for question answering. It is also used as the last
            token of a sequence built with special tokens.
        cls_token (`str`, *optional*, defaults to `"<s>"`):
            The classifier token which is used when doing sequence classification (classification of the whole sequence
            instead of per-token classification). It is the first token of the sequence when built with special tokens.
        unk_token (`str`, *optional*, defaults to `"<unk>"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
        pad_token (`str`, *optional*, defaults to `"<pad>"`):
            The token used for padding, for example when batching sequences of different lengths.
        mask_token (`str`, *optional*, defaults to `"<mask>"`):
            The token used for masking values. This is the token used when training this model with masked language
            modeling. This is the token which the model will try to predict.
    """

    vocab_files_names = VOCAB_FILES_NAMES

    def __init__(
        self,
        vocab_file,
        merges_file,
        normalization=False,
        bos_token="<s>",
        eos_token="</s>",
        sep_token="</s>",
        cls_token="<s>",
        unk_token="<unk>",
        pad_token="<pad>",
        mask_token="<mask>",
        **kwargs,
    ):
        try:
            from emoji import demojize

            self.demojizer = demojize
        except ImportError:
            logger.warning(
                "emoji is not installed, thus not converting emoticons or emojis into text. Install emoji: pip3"
                " install emoji==0.6.0"
            )
            self.demojizer = None

        self.vocab_file = vocab_file
        self.merges_file = merges_file

        self.encoder = {}
        self.encoder[str(bos_token)] = 0
        self.encoder[str(pad_token)] = 1
        self.encoder[str(eos_token)] = 2
        self.encoder[str(unk_token)] = 3

        self.add_from_file(vocab_file)

        self.decoder = {v: k for k, v in self.encoder.items()}

        with open(merges_file, encoding="utf-8") as merges_handle:
            merges = merges_handle.read().split("\n")[:-1]
        merges = [tuple(merge.split()[:-1]) for merge in merges]
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        self.cache = {}

        self.normalization = normalization
        self.tweetPreprocessor = TweetTokenizer()
        self.special_puncts = {"’": "'", "…": "..."}

        super().__init__(
            normalization=normalization,
            bos_token=bos_token,
            eos_token=eos_token,
            sep_token=sep_token,
            cls_token=cls_token,
            unk_token=unk_token,
            pad_token=pad_token,
            mask_token=mask_token,
            **kwargs,
        )

    def build_inputs_with_special_tokens(
        self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
    ) -> list[int]:
        """
        Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating
        and adding special tokens. A BERTweet sequence has the following format:

        - single sequence: `<s> X </s>`
        - pair of sequences: `<s> A </s></s> B </s>`

        Args:
            token_ids_0 (`list[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (`list[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `list[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
        """

        if token_ids_1 is None:
            return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
        cls = [self.cls_token_id]
        sep = [self.sep_token_id]
        return cls + token_ids_0 + sep + sep + token_ids_1 + sep

    def get_special_tokens_mask(
        self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None, already_has_special_tokens: bool = False
    ) -> list[int]:
        """
        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer `prepare_for_model` method.

        Args:
            token_ids_0 (`list[int]`):
                List of IDs.
            token_ids_1 (`list[int]`, *optional*):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
                Whether or not the token list is already formatted with special tokens for the model.

        Returns:
            `list[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """

        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
            )

        if token_ids_1 is None:
            return [1] + ([0] * len(token_ids_0)) + [1]
        return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1]

    def create_token_type_ids_from_sequences(
        self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
    ) -> list[int]:
        """
        Create a mask from the two sequences passed to be used in a sequence-pair classification task. BERTweet does
        not make use of token type ids, therefore a list of zeros is returned.

        Args:
            token_ids_0 (`list[int]`):
                List of IDs.
            token_ids_1 (`list[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `list[int]`: List of zeros.
        """

        sep = [self.sep_token_id]
        cls = [self.cls_token_id]

        if token_ids_1 is None:
            return len(cls + token_ids_0 + sep) * [0]
        return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]
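
A minimal sketch of the `<s> A </s></s> B </s>` layout these three methods agree on, using placeholder IDs (no vocabulary files needed):

    # Illustrative placeholder IDs for <s> (cls) and </s> (sep).
    cls_id, sep_id = 0, 2
    a, b = [7, 8], [9]

    inputs = [cls_id] + a + [sep_id] + [sep_id] + b + [sep_id]
    mask = [1] + [0] * len(a) + [1, 1] + [0] * len(b) + [1]
    token_type_ids = [0] * len(inputs)  # BERTweet does not use token types

    print(inputs)  # [0, 7, 8, 2, 2, 9, 2]
    print(mask)    # [1, 0, 0, 1, 1, 0, 1]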

    @property
    def vocab_size(self):
        return len(self.encoder)

    def get_vocab(self):
        return dict(self.encoder, **self.added_tokens_encoder)

    def bpe(self, token):
        if token in self.cache:
            return self.cache[token]
        word = tuple(token)
        word = tuple(list(word[:-1]) + [word[-1] + "</w>"])
        pairs = get_pairs(word)

        if not pairs:
            return token

        while True:
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                except ValueError:
                    new_word.extend(word[i:])
                    break
                else:
                    new_word.extend(word[i:j])
                    i = j

                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = "@@ ".join(word)
        word = word[:-4]
        self.cache[token] = word
        return word
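
To see the merge loop in action without any vocabulary files, here is a standalone toy run with a hand-made merge table (the ranks and symbols are invented for illustration; the loop mirrors the logic of `bpe()` above):

    def get_pairs(word):
        return {(a, b) for a, b in zip(word, word[1:])}

    bpe_ranks = {("l", "o"): 0, ("lo", "w</w>"): 1}  # invented merge table
    word = ("l", "o", "w</w>")  # "low" with the end-of-word marker

    while True:
        bigram = min(get_pairs(word), key=lambda p: bpe_ranks.get(p, float("inf")))
        if bigram not in bpe_ranks:
            break
        first, second = bigram
        merged, i = [], 0
        while i < len(word):
            if i < len(word) - 1 and (word[i], word[i + 1]) == (first, second):
                merged.append(first + second)  # apply the highest-priority merge
                i += 2
            else:
                merged.append(word[i])
                i += 1
        word = tuple(merged)
        if len(word) == 1:
            break

    print(word)  # ('low</w>',)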

    def _tokenize(self, text):
        """Tokenize a string."""
        if self.normalization:  # Perform Tweet normalization before performing BPE
            text = self.normalizeTweet(text)

        split_tokens = []
        words = re.findall(r"\S+\n?", text)
        for token in words:
            split_tokens.extend(list(self.bpe(token).split(" ")))
        return split_tokens

    def normalizeTweet(self, tweet):
        """
        Normalize a raw Tweet
        """
        for punct in self.special_puncts:
            tweet = tweet.replace(punct, self.special_puncts[punct])

        tokens = self.tweetPreprocessor.tokenize(tweet)
        normTweet = " ".join([self.normalizeToken(token) for token in tokens])

        normTweet = (
            normTweet.replace("cannot ", "can not ")
            .replace("n't ", " n't ")
            .replace("n 't ", " n't ")
            .replace("ca n't", "can't")
            .replace("ai n't", "ain't")
        )
        normTweet = (
            normTweet.replace("'m ", " 'm ")
            .replace("'re ", " 're ")
            .replace("'s ", " 's ")
            .replace("'ll ", " 'll ")
            .replace("'d ", " 'd ")
            .replace("'ve ", " 've ")
        )
        normTweet = (
            normTweet.replace(" p . m .", " p.m.")
            .replace(" p . m ", " p.m ")
            .replace(" a . m .", " a.m.")
            .replace(" a . m ", " a.m ")
        )

        return " ".join(normTweet.split())

    def normalizeToken(self, token):
        """
        Normalize tokens in a Tweet
        """
        lowercased_token = token.lower()
        if token.startswith("@"):
            return "@USER"
        elif lowercased_token.startswith("http") or lowercased_token.startswith("www"):
            return "HTTPURL"
        elif len(token) == 1:
            if token in self.special_puncts:
                return self.special_puncts[token]
            if self.demojizer is not None:
                return self.demojizer(token)
            else:
                return token
        else:
            return token
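
The per-token rules are easy to check in isolation; a standalone restatement of the user-handle and URL branches (the emoji and punctuation steps are dropped here for brevity):

    def normalize_token(token):
        lowered = token.lower()
        if token.startswith("@"):
            return "@USER"
        if lowered.startswith("http") or lowered.startswith("www"):
            return "HTTPURL"
        return token

    print([normalize_token(t) for t in ["@remy", "https://x.co", "hello"]])
    # ['@USER', 'HTTPURL', 'hello']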

    def _convert_token_to_id(self, token):
        """Converts a token (str) to an id using the vocab."""
        return self.encoder.get(token, self.encoder.get(self.unk_token))

    def _convert_id_to_token(self, index):
        """Converts an index (integer) to a token (str) using the vocab."""
        return self.decoder.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (string) into a single string."""
        out_string = " ".join(tokens).replace("@@ ", "").strip()
        return out_string

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            return
        out_vocab_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
        )
        out_merge_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"]
        )

        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
            copyfile(self.vocab_file, out_vocab_file)
        elif not os.path.isfile(self.vocab_file):
            with open(out_vocab_file, "wb") as fi:
                content_spiece_model = self.sp_model.serialized_model_proto()
                fi.write(content_spiece_model)

        if os.path.abspath(self.merges_file) != os.path.abspath(out_merge_file):
            copyfile(self.merges_file, out_merge_file)

        return out_vocab_file, out_merge_file

    # def decode(self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True):
    #     filtered_tokens = ' '.join(self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens))
    #     tokens_generated_so_far = re.sub('(@@ )', '', string=filtered_tokens)
    #     tokens_generated_so_far = re.sub('(@@ ?$)', '', string=tokens_generated_so_far)
    #     return ''.join(tokens_generated_so_far)

    def add_from_file(self, f):
        """
        Loads a pre-existing dictionary from a text file and adds its symbols to this instance.
        """
        if isinstance(f, str):
            try:
                with open(f, "r", encoding="utf-8") as fd:
                    self.add_from_file(fd)
            except FileNotFoundError as fnfe:
                raise fnfe
            except UnicodeError:
                raise Exception(f"Incorrect encoding detected in {f}, please rebuild the dataset")
            return

        lines = f.readlines()
        for lineTmp in lines:
            line = lineTmp.strip()
            idx = line.rfind(" ")
            if idx == -1:
                raise ValueError("Incorrect dictionary format, expected '<token> <cnt>'")
            word = line[:idx]
            self.encoder[word] = len(self.encoder)


# Natural Language Toolkit: Twitter Tokenizer
#
# Copyright (C) 2001-2020 NLTK Project
# Author: Christopher Potts <cgpotts@stanford.edu>
#         Ewan Klein <ewan@inf.ed.ac.uk> (modifications)
#         Pierpaolo Pantone <> (modifications)
# URL: http://nltk.org/
# For license information, see LICENSE.TXT
#


"""
Twitter-aware tokenizer, designed to be flexible and easy to adapt to new domains and tasks. The basic logic is this:

1. The tuple regex_strings defines a list of regular expression strings.

2. The regex_strings strings are put, in order, into a compiled regular expression object called word_re.

3. The tokenization is done by word_re.findall(s), where s is the user-supplied string, inside the tokenize() method
   of the class TweetTokenizer.

4. When instantiating Tokenizer objects, there is a single option: preserve_case. By default, it is set to True. If it
   is set to False, then the tokenizer will lowercase everything except for emoticons.

"""


######################################################################
#
# import regex  # https://github.com/nltk/nltk/issues/2409
# import html
#
######################################################################
# The following strings are components in the regular expression
# that is used for tokenizing. It's important that phone_number
# appears first in the final regex (since it can contain whitespace).
# It also could matter that tags comes after emoticons, due to the
# possibility of having text like
#
#     <:| and some text >:)
#
# Most importantly, the final element should always be last, since it
# does a last ditch whitespace-based tokenization of whatever is left.

# ToDo: Update with http://en.wikipedia.org/wiki/List_of_emoticons ?

# This particular element is used in a couple ways, so we define it
# with a name:
# docstyle-ignore
EMOTICONS = r"""
    (?:
      [<>]?
      [:;=8]                     # eyes
      [\-o\*\']?                 # optional nose
      [\)\]\(\[dDpP/\:\}\{@\|\\] # mouth
      |
      [\)\]\(\[dDpP/\:\}\{@\|\\] # mouth
      [\-o\*\']?                 # optional nose
      [:;=8]                     # eyes
      [<>]?
      |
      <3                         # heart
    )"""

# URL pattern due to John Gruber, modified by Tom Winzig. See
# https://gist.github.com/winzig/8894715
# docstyle-ignore
URLS = r"""          # Capture 1: entire matched URL
  (?:
  https?:            # URL protocol and colon
    (?:
      /{1,3}         # 1-3 slashes
      |              #   or
      [a-z0-9%]      # Single letter or digit or '%'
                     # (Trying not to match e.g. "URI::Escape")
    )
    |                #   or
    # looks like domain name followed by a slash:
    [a-z0-9.\-]+[.]
    (?:[a-z]{2,13})
    /
  )
  (?:                                # One or more:
    [^\s()<>{}\[\]]+                 # Run of non-space, non-()<>{}[]
    |                                #   or
    \([^\s()]*?\([^\s()]+\)[^\s()]*?\) # balanced parens, one level deep: (...(...)...)
    |
    \([^\s]+?\)                      # balanced parens, non-recursive: (...)
  )+
  (?:                                # End with:
    \([^\s()]*?\([^\s()]+\)[^\s()]*?\) # balanced parens, one level deep: (...(...)...)
    |
    \([^\s]+?\)                      # balanced parens, non-recursive: (...)
    |                                #   or
    [^\s`!()\[\]{};:'".,<>?«»“”‘’]   # not a space or one of these punct chars
  )
  |                  # OR, the following to match naked domains:
  (?:
    (?<!@)           # not preceded by a @, avoid matching foo@_gmail.com_
    [a-z0-9]+
    (?:[.\-][a-z0-9]+)*
    [.]
    (?:[a-z]{2,13})
    \b
    /?
    (?!@)            # not succeeded by a @,
                     # avoid matching "foo.na" in "foo.na@example.com"
  )
"""

# docstyle-ignore
# The components of the tokenizer:
REGEXPS = (
    URLS,
    # Phone numbers:
    r"""
    (?:
      (?:            # (international)
        \+?[01]
        [ *\-.\)]*
      )?
      (?:            # (area code)
        [\(]?
        \d{3}
        [ *\-.\)]*
      )?
      \d{3}          # exchange
      [ *\-.\)]*
      \d{4}          # base
    )""",
    # ASCII Emoticons
    EMOTICONS,
    # HTML tags:
    r"""<[^>\s]+>""",
    # ASCII Arrows
    r"""[\-]+>|<[\-]+""",
    # Twitter username:
    r"""(?:@[\w_]+)""",
    # Twitter hashtags:
    r"""(?:\#+[\w_]+[\w\'_\-]*[\w_]+)""",
    # email addresses
    r"""[\w.+-]+@[\w-]+\.(?:[\w-]\.?)+[\w-]""",
    # docstyle-ignore
    # Remaining word types:
    r"""
    (?:[^\W\d_](?:[^\W\d_]|['\-_])+[^\W\d_]) # Words with apostrophes or dashes.
    |
    (?:[+\-]?\d+[,/.:-]\d+[+\-]?)            # Numbers, including fractions, decimals.
    |
    (?:[\w_]+)                               # Words without apostrophes or dashes.
    |
    (?:\.(?:\s*\.){1,})                      # Ellipsis dots.
    |
    (?:\S)                                   # Everything else that isn't whitespace.
    """,
)

######################################################################
# This is the core tokenizing regex:

WORD_RE = regex.compile(r"""(%s)""" % "|".join(REGEXPS), regex.VERBOSE | regex.I | regex.UNICODE)

# WORD_RE performs poorly on these patterns:
HANG_RE = regex.compile(r"([^a-zA-Z0-9])\1{3,}")

# The emoticon string gets its own regex so that we can preserve case for
# them as needed:
EMOTICON_RE = regex.compile(EMOTICONS, regex.VERBOSE | regex.I | regex.UNICODE)

# These are for regularizing HTML entities to Unicode:
ENT_RE = regex.compile(r"&(#?(x?))([^&;\s]+);")


######################################################################
# Functions for converting html entities
######################################################################


def _str_to_unicode(text, encoding=None, errors="strict"):
    if encoding is None:
        encoding = "utf-8"
    if isinstance(text, bytes):
        return text.decode(encoding, errors)
    return text


def _replace_html_entities(text, keep=(), remove_illegal=True, encoding="utf-8"):
    """
    Remove entities from text by converting them to their corresponding unicode character.

    Args:
        text:
            A unicode string or a byte string encoded in the given *encoding* (which defaults to 'utf-8').
        keep (list):
            List of entity names which should not be replaced. This supports both numeric entities (`&#nnnn;` and
            `&#hhhh;`) and named entities (such as `&nbsp;` or `&gt;`).
        remove_illegal (bool):
            If `True`, entities that can't be converted are removed. Otherwise, entities that can't be converted are
            kept "as is".

    Returns: A unicode string with the entities removed.

    See https://github.com/scrapy/w3lib/blob/master/w3lib/html.py

    Examples:

    ```python
    >>> from nltk.tokenize.casual import _replace_html_entities

    >>> _replace_html_entities(b"Price: &pound;100")
    'Price: \\xa3100'

    >>> print(_replace_html_entities(b"Price: &pound;100"))
    Price: £100
    ```"""

    def _convert_entity(match):
        entity_body = match.group(3)
        if match.group(1):
            try:
                if match.group(2):
                    number = int(entity_body, 16)
                else:
                    number = int(entity_body, 10)
                # Numeric character references in the 80-9F range are typically
                # interpreted by browsers as representing the characters mapped
                # to bytes 80-9F in the Windows-1252 encoding. For more info
                # see: https://en.wikipedia.org/wiki/ISO/IEC_8859-1#Similar_character_sets
                if 0x80 <= number <= 0x9F:
                    return bytes((number,)).decode("cp1252")
            except ValueError:
                number = None
        else:
            if entity_body in keep:
                return match.group(0)
            else:
                number = html.entities.name2codepoint.get(entity_body)
        if number is not None:
            try:
                return chr(number)
            except (ValueError, OverflowError):
                pass

        return "" if remove_illegal else match.group(0)

    return ENT_RE.sub(_convert_entity, _str_to_unicode(text, encoding))


######################################################################


class TweetTokenizer:
    r"""
    Examples:

    ```python
    >>> # Tokenizer for tweets.
    >>> from nltk.tokenize import TweetTokenizer

    >>> tknzr = TweetTokenizer()
    >>> s0 = "This is a cooool #dummysmiley: :-) :-P <3 and some arrows < > -> <--"
    >>> tknzr.tokenize(s0)
    ['This', 'is', 'a', 'cooool', '#dummysmiley', ':', ':-)', ':-P', '<3', 'and', 'some', 'arrows', '<', '>', '->', '<--']

    >>> # Examples using *strip_handles* and *reduce_len parameters*:
    >>> tknzr = TweetTokenizer(strip_handles=True, reduce_len=True)
    >>> s1 = "@remy: This is waaaaayyyy too much for you!!!!!!"
    >>> tknzr.tokenize(s1)
    [':', 'This', 'is', 'waaayyy', 'too', 'much', 'for', 'you', '!', '!', '!']
    ```"""

    def __init__(self, preserve_case=True, reduce_len=False, strip_handles=False):
        self.preserve_case = preserve_case
        self.reduce_len = reduce_len
        self.strip_handles = strip_handles

    def tokenize(self, text):
        """
        Args:
            text: str

        Returns: list(str) A tokenized list of strings; concatenating this list returns the original string if
        `preserve_case=False`
        """
        # Fix HTML character entities:
        text = _replace_html_entities(text)
        # Remove username handles
        if self.strip_handles:
            text = remove_handles(text)
        # Normalize word lengthening
        if self.reduce_len:
            text = reduce_lengthening(text)
        # Shorten problematic sequences of characters
        safe_text = HANG_RE.sub(r"\1\1\1", text)
        # Tokenize:
        words = WORD_RE.findall(safe_text)
        # Possibly alter the case, but avoid changing emoticons like :D into :d:
        if not self.preserve_case:
            words = [x if EMOTICON_RE.search(x) else x.lower() for x in words]
        return words


######################################################################
# Normalization Functions
######################################################################


def reduce_lengthening(text):
    """
    Replace repeated character sequences of length 3 or greater with sequences of length 3.
    """
    pattern = regex.compile(r"(.)\1{2,}")
    return pattern.sub(r"\1\1\1", text)
|
| 739 |
+
|
| 740 |
+
|
| 741 |
+
def remove_handles(text):
|
| 742 |
+
"""
|
| 743 |
+
Remove Twitter username handles from text.
|
| 744 |
+
"""
|
| 745 |
+
pattern = regex.compile(
|
| 746 |
+
r"(?<![A-Za-z0-9_!@#\$%&*])@(([A-Za-z0-9_]){20}(?!@))|(?<![A-Za-z0-9_!@#\$%&*])@(([A-Za-z0-9_]){1,19})(?![A-Za-z0-9_]*@)"
|
| 747 |
+
)
|
| 748 |
+
# Substitute handles with ' ' to ensure that text on either side of removed handles are tokenized correctly
|
| 749 |
+
return pattern.sub(" ", text)
|
| 750 |
+
|
| 751 |
+
|
| 752 |
+
######################################################################
|
| 753 |
+
# Tokenization Function
|
| 754 |
+
######################################################################
|
| 755 |
+
|
| 756 |
+
|
| 757 |
+
def casual_tokenize(text, preserve_case=True, reduce_len=False, strip_handles=False):
|
| 758 |
+
"""
|
| 759 |
+
Convenience function for wrapping the tokenizer.
|
| 760 |
+
"""
|
| 761 |
+
return TweetTokenizer(preserve_case=preserve_case, reduce_len=reduce_len, strip_handles=strip_handles).tokenize(
|
| 762 |
+
text
|
| 763 |
+
)
|
| 764 |
+
|
| 765 |
+
|
| 766 |
+
###############################################################################
|
| 767 |
+
|
| 768 |
+
|
| 769 |
+
__all__ = ["BertweetTokenizer"]
|
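The three normalization passes in `TweetTokenizer.tokenize` run in a fixed order: handle stripping, length reduction, then `HANG_RE` shortening, and only then does `WORD_RE` split the text. The sketch below re-creates that pipeline with simplified standalone regexes (the handle pattern is a reduced version of `remove_handles`, and the sample tweet is invented), purely to illustrate the order of operations.

```python
# Minimal sketch of the normalization order in TweetTokenizer.tokenize (illustrative only).
import regex

HANG_RE = regex.compile(r"([^a-zA-Z0-9])\1{3,}")  # 4+ repeats of a non-alphanumeric char
LENGTHENING_RE = regex.compile(r"(.)\1{2,}")  # 3+ repeats of any char
HANDLE_RE = regex.compile(r"(?<![A-Za-z0-9_!@#\$%&*])@[A-Za-z0-9_]{1,19}(?![A-Za-z0-9_]*@)")  # simplified


def normalize(tweet: str) -> str:
    tweet = HANDLE_RE.sub(" ", tweet)  # strip_handles step
    tweet = LENGTHENING_RE.sub(r"\1\1\1", tweet)  # reduce_len step: "waaaaayyyy" -> "waaayyy"
    return HANG_RE.sub(r"\1\1\1", tweet)  # always applied before WORD_RE.findall


print(normalize("@remy: This is waaaaayyyy too much for you!!!!!!"))
# -> " : This is waaayyy too much for you!!!"
```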
venv/lib/python3.13/site-packages/transformers/models/biogpt/__init__.py
ADDED
@@ -0,0 +1,28 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    from .configuration_biogpt import *
    from .modeling_biogpt import *
    from .tokenization_biogpt import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
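The `else` branch above swaps the package module for a `_LazyModule`, so importing `transformers.models.biogpt` stays cheap and the heavy submodules are only loaded when one of their attributes is first touched. Below is a rough, self-contained sketch of that general idea; it is not the actual `_LazyModule` implementation, and the package/module names in the usage comment are made up.

```python
# Rough sketch of lazy attribute resolution (hypothetical; not the real _LazyModule).
import importlib
import types


class LazyModule(types.ModuleType):
    """Defer importing heavy submodules until an attribute is first accessed."""

    def __init__(self, name: str, attr_to_submodule: dict):
        super().__init__(name)
        self._attr_to_submodule = attr_to_submodule

    def __getattr__(self, attr: str):
        submodule_name = self._attr_to_submodule.get(attr)
        if submodule_name is None:
            raise AttributeError(f"module {self.__name__!r} has no attribute {attr!r}")
        value = getattr(importlib.import_module(submodule_name), attr)
        setattr(self, attr, value)  # cache so the submodule import happens only once
        return value


# Usage sketch (hypothetical package layout):
# sys.modules["mypkg.biogpt_like"] = LazyModule(
#     "mypkg.biogpt_like", {"MyConfig": "mypkg.biogpt_like.configuration"}
# )
```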
venv/lib/python3.13/site-packages/transformers/models/biogpt/configuration_biogpt.py
ADDED
@@ -0,0 +1,134 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Team and Microsoft Research AI4Science All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""BioGPT model configuration"""

from ...configuration_utils import PretrainedConfig
from ...utils import logging


logger = logging.get_logger(__name__)


class BioGptConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`BioGptModel`]. It is used to instantiate a
    BioGPT model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of the BioGPT
    [microsoft/biogpt](https://huggingface.co/microsoft/biogpt) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        vocab_size (`int`, *optional*, defaults to 42384):
            Vocabulary size of the BioGPT model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`BioGptModel`].
        hidden_size (`int`, *optional*, defaults to 1024):
            Dimension of the encoder layers and the pooler layer.
        num_hidden_layers (`int`, *optional*, defaults to 24):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 16):
            Number of attention heads for each attention layer in the Transformer encoder.
        intermediate_size (`int`, *optional*, defaults to 4096):
            Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"selu"` and `"gelu_new"` are supported.
        hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
            The dropout ratio for the attention probabilities.
        max_position_embeddings (`int`, *optional*, defaults to 1024):
            The maximum sequence length that this model might ever be used with. Typically set this to something large
            just in case (e.g., 512 or 1024 or 2048).
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        layer_norm_eps (`float`, *optional*, defaults to 1e-12):
            The epsilon used by the layer normalization layers.
        scale_embedding (`bool`, *optional*, defaults to `True`):
            Scale embeddings by dividing by sqrt(d_model).
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        layerdrop (`float`, *optional*, defaults to 0.0):
            Please refer to the paper about LayerDrop: https://huggingface.co/papers/1909.11556 for further details
        activation_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for activations inside the fully connected layer.
        pad_token_id (`int`, *optional*, defaults to 1):
            Padding token id.
        bos_token_id (`int`, *optional*, defaults to 0):
            Beginning of stream token id.
        eos_token_id (`int`, *optional*, defaults to 2):
            End of stream token id.

    Example:

    ```python
    >>> from transformers import BioGptModel, BioGptConfig

    >>> # Initializing a BioGPT microsoft/biogpt style configuration
    >>> configuration = BioGptConfig()

    >>> # Initializing a model from the microsoft/biogpt style configuration
    >>> model = BioGptModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "biogpt"

    def __init__(
        self,
        vocab_size=42384,
        hidden_size=1024,
        num_hidden_layers=24,
        num_attention_heads=16,
        intermediate_size=4096,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=1024,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        scale_embedding=True,
        use_cache=True,
        layerdrop=0.0,
        activation_dropout=0.0,
        pad_token_id=1,
        bos_token_id=0,
        eos_token_id=2,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.scale_embedding = scale_embedding
        self.use_cache = use_cache
        self.layerdrop = layerdrop
        self.activation_dropout = activation_dropout
        super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)


__all__ = ["BioGptConfig"]
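All of the sizing knobs above feed directly into the modeling file that follows, so a scaled-down configuration is often handy for quick experiments. The snippet below is a small sketch using only parameters documented in the class above; the chosen values are arbitrary and are not a recommended configuration.

```python
# Sketch: a deliberately tiny BioGptConfig for quick CPU smoke tests (values are arbitrary).
from transformers import BioGptConfig, BioGptModel

tiny_config = BioGptConfig(
    vocab_size=1024,
    hidden_size=64,
    num_hidden_layers=2,
    num_attention_heads=4,
    intermediate_size=128,
    max_position_embeddings=128,
)
model = BioGptModel(tiny_config)
print(sum(p.numel() for p in model.parameters()))  # rough parameter count for the tiny variant
```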
venv/lib/python3.13/site-packages/transformers/models/biogpt/modeling_biogpt.py
ADDED
@@ -0,0 +1,967 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/biogpt/modular_biogpt.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_biogpt.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2022 The HuggingFace Team and Microsoft Research AI4Science All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import Callable, Optional, Union

import torch
import torch.nn as nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
    BaseModelOutputWithPastAndCrossAttentions,
    CausalLMOutputWithCrossAttentions,
    SequenceClassifierOutputWithPast,
    TokenClassifierOutput,
)
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import TransformersKwargs, auto_docstring, is_torch_flex_attn_available, logging
from ...utils.deprecation import deprecate_kwarg
from .configuration_biogpt import BioGptConfig


if is_torch_flex_attn_available():
    from ...integrations.flex_attention import BlockMask, make_flex_block_causal_mask


logger = logging.get_logger(__name__)


class BioGptLearnedPositionalEmbedding(nn.Embedding):
    """
    This module learns positional embeddings up to a fixed maximum size.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int):
        # BIOGPT is set up so that if padding_idx is specified then offset the embedding ids by 2
        # and adjust num_embeddings appropriately. Other models don't have this hack
        self.offset = 2
        super().__init__(num_embeddings + self.offset, embedding_dim)

    def forward(
        self,
        attention_mask: torch.LongTensor,
        past_key_values_length: int = 0,
        position_ids: Optional[torch.LongTensor] = None,
    ):
        """`input_ids_shape` is expected to be [bsz x seqlen]."""

        if position_ids is None:
            position_ids = torch.cumsum(attention_mask, dim=1)
            position_ids = (position_ids * attention_mask - 1).long()
            # cut positions if `past_key_values_length` is > 0
            position_ids = position_ids[:, past_key_values_length:]

        return super().forward(position_ids + self.offset)

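# Worked example of the position-id computation above (illustrative editorial comment):
# with left padding and no cache (past_key_values_length = 0),
#   attention_mask        = [[0, 0, 1, 1, 1]]
#   cumsum(mask, dim=1)   = [[0, 0, 1, 2, 3]]
#   cumsum * mask - 1     = [[-1, -1, 0, 1, 2]]   # padding positions collapse to -1
#   + self.offset (2)     = [[ 1,  1, 2, 3, 4]]   # real tokens start at row 2 of the table
# which is why __init__ allocates num_embeddings + self.offset rows in the embedding matrix.
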
| 83 |
+
class BioGptScaledWordEmbedding(nn.Embedding):
|
| 84 |
+
"""
|
| 85 |
+
This module overrides nn.Embeddings' forward by multiplying with embeddings scale.
|
| 86 |
+
"""
|
| 87 |
+
|
| 88 |
+
def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0):
|
| 89 |
+
super().__init__(num_embeddings, embedding_dim, padding_idx)
|
| 90 |
+
self.embed_scale = embed_scale
|
| 91 |
+
|
| 92 |
+
def forward(self, input_ids: torch.Tensor):
|
| 93 |
+
return super().forward(input_ids) * self.embed_scale
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def eager_attention_forward(
|
| 97 |
+
module: nn.Module,
|
| 98 |
+
query: torch.Tensor,
|
| 99 |
+
key: torch.Tensor,
|
| 100 |
+
value: torch.Tensor,
|
| 101 |
+
attention_mask: Optional[torch.Tensor],
|
| 102 |
+
scaling: Optional[float] = None,
|
| 103 |
+
dropout: float = 0.0,
|
| 104 |
+
head_mask: Optional[torch.Tensor] = None,
|
| 105 |
+
**kwargs,
|
| 106 |
+
):
|
| 107 |
+
if scaling is None:
|
| 108 |
+
scaling = query.size(-1) ** -0.5
|
| 109 |
+
|
| 110 |
+
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
|
| 111 |
+
if attention_mask is not None:
|
| 112 |
+
attn_weights = attn_weights + attention_mask
|
| 113 |
+
|
| 114 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
| 115 |
+
|
| 116 |
+
if head_mask is not None:
|
| 117 |
+
attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)
|
| 118 |
+
|
| 119 |
+
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
| 120 |
+
attn_output = torch.matmul(attn_weights, value)
|
| 121 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
| 122 |
+
|
| 123 |
+
return attn_output, attn_weights
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
class BioGptAttention(nn.Module):
|
| 127 |
+
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
| 128 |
+
|
| 129 |
+
def __init__(
|
| 130 |
+
self,
|
| 131 |
+
embed_dim: int,
|
| 132 |
+
num_heads: int,
|
| 133 |
+
dropout: float = 0.0,
|
| 134 |
+
is_decoder: bool = False,
|
| 135 |
+
bias: bool = True,
|
| 136 |
+
is_causal: bool = False,
|
| 137 |
+
config: Optional[BioGptConfig] = None,
|
| 138 |
+
layer_idx: Optional[int] = None,
|
| 139 |
+
):
|
| 140 |
+
super().__init__()
|
| 141 |
+
self.embed_dim = embed_dim
|
| 142 |
+
self.num_heads = num_heads
|
| 143 |
+
self.dropout = dropout
|
| 144 |
+
self.head_dim = embed_dim // num_heads
|
| 145 |
+
self.config = config
|
| 146 |
+
|
| 147 |
+
if (self.head_dim * num_heads) != self.embed_dim:
|
| 148 |
+
raise ValueError(
|
| 149 |
+
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
|
| 150 |
+
f" and `num_heads`: {num_heads})."
|
| 151 |
+
)
|
| 152 |
+
self.scaling = self.head_dim**-0.5
|
| 153 |
+
self.is_decoder = is_decoder
|
| 154 |
+
self.is_causal = is_causal
|
| 155 |
+
self.layer_idx = layer_idx
|
| 156 |
+
if layer_idx is None and self.is_decoder:
|
| 157 |
+
logger.warning_once(
|
| 158 |
+
f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and "
|
| 159 |
+
"will lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
|
| 160 |
+
"when creating this class."
|
| 161 |
+
)
|
| 162 |
+
|
| 163 |
+
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
| 164 |
+
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
| 165 |
+
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
| 166 |
+
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
| 167 |
+
|
| 168 |
+
@deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
|
| 169 |
+
def forward(
|
| 170 |
+
self,
|
| 171 |
+
hidden_states: torch.Tensor,
|
| 172 |
+
key_value_states: Optional[torch.Tensor] = None,
|
| 173 |
+
past_key_values: Optional[Cache] = None,
|
| 174 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 175 |
+
layer_head_mask: Optional[torch.Tensor] = None,
|
| 176 |
+
output_attentions: bool = False,
|
| 177 |
+
cache_position: Optional[torch.Tensor] = None,
|
| 178 |
+
# TODO: we need a refactor so that the different attention modules can get their specific kwargs
|
| 179 |
+
# ATM, we have mixed things encoder, decoder, and encoder-decoder attn
|
| 180 |
+
**kwargs: Unpack[FlashAttentionKwargs],
|
| 181 |
+
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
|
| 182 |
+
"""Input shape: Batch x Time x Channel"""
|
| 183 |
+
|
| 184 |
+
# if key_value_states are provided this layer is used as a cross-attention layer
|
| 185 |
+
# for the decoder
|
| 186 |
+
is_cross_attention = key_value_states is not None
|
| 187 |
+
|
| 188 |
+
# determine input shapes
|
| 189 |
+
bsz, tgt_len = hidden_states.shape[:-1]
|
| 190 |
+
src_len = key_value_states.shape[1] if is_cross_attention else tgt_len
|
| 191 |
+
|
| 192 |
+
q_input_shape = (bsz, tgt_len, -1, self.head_dim)
|
| 193 |
+
kv_input_shape = (bsz, src_len, -1, self.head_dim)
|
| 194 |
+
|
| 195 |
+
# get query proj
|
| 196 |
+
query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2)
|
| 197 |
+
|
| 198 |
+
is_updated = False
|
| 199 |
+
if past_key_values is not None:
|
| 200 |
+
if isinstance(past_key_values, EncoderDecoderCache):
|
| 201 |
+
is_updated = past_key_values.is_updated.get(self.layer_idx)
|
| 202 |
+
if is_cross_attention:
|
| 203 |
+
# after the first generated id, we can subsequently re-use all key/value_states from cache
|
| 204 |
+
curr_past_key_value = past_key_values.cross_attention_cache
|
| 205 |
+
else:
|
| 206 |
+
curr_past_key_value = past_key_values.self_attention_cache
|
| 207 |
+
else:
|
| 208 |
+
curr_past_key_value = past_key_values
|
| 209 |
+
|
| 210 |
+
current_states = key_value_states if is_cross_attention else hidden_states
|
| 211 |
+
if is_cross_attention and past_key_values is not None and is_updated:
|
| 212 |
+
# reuse k,v, cross_attentions
|
| 213 |
+
key_states = curr_past_key_value.layers[self.layer_idx].keys
|
| 214 |
+
value_states = curr_past_key_value.layers[self.layer_idx].values
|
| 215 |
+
else:
|
| 216 |
+
key_states = self.k_proj(current_states)
|
| 217 |
+
value_states = self.v_proj(current_states)
|
| 218 |
+
key_states = key_states.view(*kv_input_shape).transpose(1, 2)
|
| 219 |
+
value_states = value_states.view(*kv_input_shape).transpose(1, 2)
|
| 220 |
+
|
| 221 |
+
if past_key_values is not None:
|
| 222 |
+
# save all key/value_states to cache to be re-used for fast auto-regressive generation
|
| 223 |
+
cache_position = cache_position if not is_cross_attention else None
|
| 224 |
+
key_states, value_states = curr_past_key_value.update(
|
| 225 |
+
key_states, value_states, self.layer_idx, {"cache_position": cache_position}
|
| 226 |
+
)
|
| 227 |
+
# set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
|
| 228 |
+
if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache):
|
| 229 |
+
past_key_values.is_updated[self.layer_idx] = True
|
| 230 |
+
|
| 231 |
+
attention_interface: Callable = eager_attention_forward
|
| 232 |
+
if self.config._attn_implementation != "eager":
|
| 233 |
+
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
| 234 |
+
|
| 235 |
+
attn_output, attn_weights = attention_interface(
|
| 236 |
+
self,
|
| 237 |
+
query_states,
|
| 238 |
+
key_states,
|
| 239 |
+
value_states,
|
| 240 |
+
attention_mask,
|
| 241 |
+
dropout=0.0 if not self.training else self.dropout,
|
| 242 |
+
scaling=self.scaling,
|
| 243 |
+
output_attentions=output_attentions,
|
| 244 |
+
head_mask=layer_head_mask,
|
| 245 |
+
**kwargs,
|
| 246 |
+
)
|
| 247 |
+
|
| 248 |
+
attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
|
| 249 |
+
attn_output = self.out_proj(attn_output)
|
| 250 |
+
|
| 251 |
+
return attn_output, attn_weights
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
class BioGptDecoderLayer(GradientCheckpointingLayer):
|
| 255 |
+
def __init__(self, config: BioGptConfig, layer_idx: Optional[int] = None):
|
| 256 |
+
super().__init__()
|
| 257 |
+
self.embed_dim = config.hidden_size
|
| 258 |
+
|
| 259 |
+
self.self_attn = BioGptAttention(
|
| 260 |
+
embed_dim=self.embed_dim,
|
| 261 |
+
num_heads=config.num_attention_heads,
|
| 262 |
+
dropout=config.attention_probs_dropout_prob,
|
| 263 |
+
is_decoder=True,
|
| 264 |
+
is_causal=True,
|
| 265 |
+
config=config,
|
| 266 |
+
layer_idx=layer_idx,
|
| 267 |
+
)
|
| 268 |
+
self.dropout = config.hidden_dropout_prob
|
| 269 |
+
self.activation_fn = ACT2FN[config.hidden_act]
|
| 270 |
+
self.activation_dropout = config.activation_dropout
|
| 271 |
+
|
| 272 |
+
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
|
| 273 |
+
|
| 274 |
+
self.fc1 = nn.Linear(self.embed_dim, config.intermediate_size)
|
| 275 |
+
self.fc2 = nn.Linear(config.intermediate_size, self.embed_dim)
|
| 276 |
+
self.final_layer_norm = nn.LayerNorm(self.embed_dim)
|
| 277 |
+
|
| 278 |
+
@deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
|
| 279 |
+
def forward(
|
| 280 |
+
self,
|
| 281 |
+
hidden_states: torch.Tensor,
|
| 282 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 283 |
+
layer_head_mask: Optional[torch.Tensor] = None,
|
| 284 |
+
past_key_values: Optional[Cache] = None,
|
| 285 |
+
output_attentions: Optional[bool] = False,
|
| 286 |
+
use_cache: Optional[bool] = True,
|
| 287 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 288 |
+
cache_position: Optional[torch.Tensor] = None,
|
| 289 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 290 |
+
) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
| 291 |
+
"""
|
| 292 |
+
Args:
|
| 293 |
+
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
| 294 |
+
attention_mask (`torch.FloatTensor`): attention mask of size
|
| 295 |
+
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
| 296 |
+
layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
|
| 297 |
+
`(encoder_attention_heads,)`.
|
| 298 |
+
past_key_values (`Cache`): cached past key and value projection states
|
| 299 |
+
output_attentions (`bool`, *optional*):
|
| 300 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
| 301 |
+
returned tensors for more detail.
|
| 302 |
+
use_cache (`bool`, *optional*):
|
| 303 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
| 304 |
+
(see `past_key_values`).
|
| 305 |
+
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
| 306 |
+
Indices depicting the position of the input sequence tokens in the sequence. It is used to update the
|
| 307 |
+
cache in the correct position and to infer the complete sequence length.
|
| 308 |
+
"""
|
| 309 |
+
residual = hidden_states
|
| 310 |
+
|
| 311 |
+
hidden_states = self.self_attn_layer_norm(hidden_states)
|
| 312 |
+
|
| 313 |
+
# Self Attention
|
| 314 |
+
hidden_states, self_attn_weights = self.self_attn(
|
| 315 |
+
hidden_states=hidden_states,
|
| 316 |
+
past_key_values=past_key_values,
|
| 317 |
+
attention_mask=attention_mask,
|
| 318 |
+
layer_head_mask=layer_head_mask,
|
| 319 |
+
output_attentions=output_attentions,
|
| 320 |
+
position_ids=position_ids,
|
| 321 |
+
cache_position=cache_position,
|
| 322 |
+
**kwargs,
|
| 323 |
+
)
|
| 324 |
+
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
| 325 |
+
hidden_states = residual + hidden_states
|
| 326 |
+
|
| 327 |
+
# Fully Connected
|
| 328 |
+
residual = hidden_states
|
| 329 |
+
hidden_states = self.final_layer_norm(hidden_states)
|
| 330 |
+
hidden_states = self.fc1(hidden_states)
|
| 331 |
+
hidden_states = self.activation_fn(hidden_states)
|
| 332 |
+
hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
|
| 333 |
+
hidden_states = self.fc2(hidden_states)
|
| 334 |
+
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
| 335 |
+
hidden_states = residual + hidden_states
|
| 336 |
+
|
| 337 |
+
outputs = (hidden_states,)
|
| 338 |
+
|
| 339 |
+
if output_attentions:
|
| 340 |
+
outputs += (self_attn_weights,)
|
| 341 |
+
|
| 342 |
+
return outputs
|
| 343 |
+
|
| 344 |
+
|
| 345 |
+
@auto_docstring
|
| 346 |
+
class BioGptPreTrainedModel(PreTrainedModel):
|
| 347 |
+
config: BioGptConfig
|
| 348 |
+
base_model_prefix = "biogpt"
|
| 349 |
+
supports_gradient_checkpointing = True
|
| 350 |
+
_supports_flash_attn = True
|
| 351 |
+
_supports_sdpa = True
|
| 352 |
+
_supports_flex_attn = True
|
| 353 |
+
|
| 354 |
+
_can_compile_fullgraph = True
|
| 355 |
+
|
| 356 |
+
# Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_causal_mask
|
| 357 |
+
def _update_causal_mask(
|
| 358 |
+
self,
|
| 359 |
+
attention_mask: Optional[Union[torch.Tensor, "BlockMask"]],
|
| 360 |
+
input_tensor: torch.Tensor,
|
| 361 |
+
cache_position: torch.Tensor,
|
| 362 |
+
past_key_values: Cache,
|
| 363 |
+
):
|
| 364 |
+
if self.config._attn_implementation == "flex_attention":
|
| 365 |
+
if isinstance(attention_mask, torch.Tensor):
|
| 366 |
+
attention_mask = make_flex_block_causal_mask(attention_mask)
|
| 367 |
+
# Other attention flavors support in-built causal (when `mask is None`)
|
| 368 |
+
# while we need to create our specific block mask regardless
|
| 369 |
+
elif attention_mask is None:
|
| 370 |
+
attention_mask = make_flex_block_causal_mask(
|
| 371 |
+
torch.ones(
|
| 372 |
+
size=(input_tensor.shape[0], input_tensor.shape[1]),
|
| 373 |
+
device=attention_mask.device,
|
| 374 |
+
)
|
| 375 |
+
)
|
| 376 |
+
return attention_mask
|
| 377 |
+
|
| 378 |
+
if self.config._attn_implementation == "flash_attention_2":
|
| 379 |
+
if attention_mask is not None and (attention_mask == 0.0).any():
|
| 380 |
+
return attention_mask
|
| 381 |
+
return None
|
| 382 |
+
|
| 383 |
+
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
| 384 |
+
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
| 385 |
+
# to infer the attention mask.
|
| 386 |
+
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
| 387 |
+
using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
|
| 388 |
+
|
| 389 |
+
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
|
| 390 |
+
if self.config._attn_implementation == "sdpa" and not using_compilable_cache:
|
| 391 |
+
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
| 392 |
+
attention_mask,
|
| 393 |
+
inputs_embeds=input_tensor,
|
| 394 |
+
past_key_values_length=past_seen_tokens,
|
| 395 |
+
is_training=self.training,
|
| 396 |
+
):
|
| 397 |
+
return None
|
| 398 |
+
|
| 399 |
+
dtype = input_tensor.dtype
|
| 400 |
+
sequence_length = input_tensor.shape[1]
|
| 401 |
+
if using_compilable_cache:
|
| 402 |
+
target_length = past_key_values.get_max_cache_shape()
|
| 403 |
+
else:
|
| 404 |
+
target_length = (
|
| 405 |
+
attention_mask.shape[-1]
|
| 406 |
+
if isinstance(attention_mask, torch.Tensor)
|
| 407 |
+
else past_seen_tokens + sequence_length + 1
|
| 408 |
+
)
|
| 409 |
+
|
| 410 |
+
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
|
| 411 |
+
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
|
| 412 |
+
attention_mask,
|
| 413 |
+
sequence_length=sequence_length,
|
| 414 |
+
target_length=target_length,
|
| 415 |
+
dtype=dtype,
|
| 416 |
+
cache_position=cache_position,
|
| 417 |
+
batch_size=input_tensor.shape[0],
|
| 418 |
+
)
|
| 419 |
+
|
| 420 |
+
if (
|
| 421 |
+
self.config._attn_implementation == "sdpa"
|
| 422 |
+
and attention_mask is not None
|
| 423 |
+
and attention_mask.device.type in ["cuda", "xpu", "npu"]
|
| 424 |
+
):
|
| 425 |
+
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
|
| 426 |
+
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
| 427 |
+
# Details: https://github.com/pytorch/pytorch/issues/110213
|
| 428 |
+
min_dtype = torch.finfo(dtype).min
|
| 429 |
+
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
|
| 430 |
+
|
| 431 |
+
return causal_mask
|
| 432 |
+
|
| 433 |
+
@staticmethod
|
| 434 |
+
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
|
| 435 |
+
def _prepare_4d_causal_attention_mask_with_cache_position(
|
| 436 |
+
attention_mask: torch.Tensor,
|
| 437 |
+
sequence_length: int,
|
| 438 |
+
target_length: int,
|
| 439 |
+
dtype: torch.dtype,
|
| 440 |
+
cache_position: torch.Tensor,
|
| 441 |
+
batch_size: int,
|
| 442 |
+
**kwargs,
|
| 443 |
+
):
|
| 444 |
+
"""
|
| 445 |
+
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
| 446 |
+
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
| 447 |
+
|
| 448 |
+
Args:
|
| 449 |
+
attention_mask (`torch.Tensor`):
|
| 450 |
+
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
|
| 451 |
+
`(batch_size, 1, query_length, key_value_length)`.
|
| 452 |
+
sequence_length (`int`):
|
| 453 |
+
The sequence length being processed.
|
| 454 |
+
target_length (`int`):
|
| 455 |
+
The target length: when generating with static cache, the mask should be as long as the static cache,
|
| 456 |
+
to account for the 0 padding, the part of the cache that is not filled yet.
|
| 457 |
+
dtype (`torch.dtype`):
|
| 458 |
+
The dtype to use for the 4D attention mask.
|
| 459 |
+
cache_position (`torch.Tensor`):
|
| 460 |
+
Indices depicting the position of the input sequence tokens in the sequence.
|
| 461 |
+
batch_size (`torch.Tensor`):
|
| 462 |
+
Batch size.
|
| 463 |
+
"""
|
| 464 |
+
if attention_mask is not None and attention_mask.dim() == 4:
|
| 465 |
+
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
| 466 |
+
causal_mask = attention_mask
|
| 467 |
+
else:
|
| 468 |
+
min_dtype = torch.finfo(dtype).min
|
| 469 |
+
causal_mask = torch.full(
|
| 470 |
+
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
|
| 471 |
+
)
|
| 472 |
+
if sequence_length != 1:
|
| 473 |
+
causal_mask = torch.triu(causal_mask, diagonal=1)
|
| 474 |
+
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
|
| 475 |
+
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
| 476 |
+
if attention_mask is not None:
|
| 477 |
+
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
| 478 |
+
mask_length = attention_mask.shape[-1]
|
| 479 |
+
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
|
| 480 |
+
causal_mask.device
|
| 481 |
+
)
|
| 482 |
+
padding_mask = padding_mask == 0
|
| 483 |
+
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
| 484 |
+
padding_mask, min_dtype
|
| 485 |
+
)
|
| 486 |
+
|
| 487 |
+
return causal_mask
|
| 488 |
+
|
| 489 |
+
|
| 490 |
+
@auto_docstring
|
| 491 |
+
class BioGptModel(BioGptPreTrainedModel):
|
| 492 |
+
def __init__(self, config: BioGptConfig):
|
| 493 |
+
super().__init__(config)
|
| 494 |
+
self.config = config
|
| 495 |
+
self.layerdrop = config.layerdrop
|
| 496 |
+
self.dropout = config.hidden_dropout_prob
|
| 497 |
+
self.embed_dim = config.hidden_size
|
| 498 |
+
self.padding_idx = config.pad_token_id
|
| 499 |
+
embed_scale = math.sqrt(config.hidden_size) if config.scale_embedding else 1.0
|
| 500 |
+
|
| 501 |
+
self.embed_tokens = BioGptScaledWordEmbedding(
|
| 502 |
+
config.vocab_size, self.embed_dim, self.padding_idx, embed_scale=embed_scale
|
| 503 |
+
)
|
| 504 |
+
self.embed_positions = BioGptLearnedPositionalEmbedding(config.max_position_embeddings, self.embed_dim)
|
| 505 |
+
|
| 506 |
+
self.layers = nn.ModuleList([BioGptDecoderLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)])
|
| 507 |
+
self.layer_norm = nn.LayerNorm(self.embed_dim)
|
| 508 |
+
|
| 509 |
+
self.gradient_checkpointing = False
|
| 510 |
+
# Initialize weights and apply final processing
|
| 511 |
+
self.post_init()
|
| 512 |
+
|
| 513 |
+
@auto_docstring
|
| 514 |
+
def forward(
|
| 515 |
+
self,
|
| 516 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 517 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
| 518 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
| 519 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 520 |
+
past_key_values: Optional[Cache] = None,
|
| 521 |
+
use_cache: Optional[bool] = None,
|
| 522 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 523 |
+
output_attentions: Optional[bool] = None,
|
| 524 |
+
output_hidden_states: Optional[bool] = None,
|
| 525 |
+
return_dict: Optional[bool] = None,
|
| 526 |
+
cache_position: Optional[torch.Tensor] = None,
|
| 527 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 528 |
+
) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]:
|
| 529 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 530 |
+
output_hidden_states = (
|
| 531 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 532 |
+
)
|
| 533 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
| 534 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 535 |
+
|
| 536 |
+
# retrieve input_ids and inputs_embeds
|
| 537 |
+
if (input_ids is None) ^ (inputs_embeds is not None):
|
| 538 |
+
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
|
| 539 |
+
elif input_ids is not None:
|
| 540 |
+
input = input_ids
|
| 541 |
+
input_shape = input.shape
|
| 542 |
+
input_ids = input_ids.view(-1, input_shape[-1])
|
| 543 |
+
elif inputs_embeds is not None:
|
| 544 |
+
input_shape = inputs_embeds.size()[:-1]
|
| 545 |
+
input = inputs_embeds[:, :, -1]
|
| 546 |
+
else:
|
| 547 |
+
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
|
| 548 |
+
|
| 549 |
+
if inputs_embeds is None:
|
| 550 |
+
inputs_embeds = self.embed_tokens(input)
|
| 551 |
+
|
| 552 |
+
if self.gradient_checkpointing and self.training:
|
| 553 |
+
if use_cache:
|
| 554 |
+
logger.warning_once(
|
| 555 |
+
"`use_cache=True` is incompatible with gradient checkpointing`. Setting `use_cache=False`..."
|
| 556 |
+
)
|
| 557 |
+
use_cache = False
|
| 558 |
+
|
| 559 |
+
# initialize past_key_values
|
| 560 |
+
if use_cache and past_key_values is None:
|
| 561 |
+
past_key_values = DynamicCache(config=self.config)
|
| 562 |
+
if use_cache and isinstance(past_key_values, tuple):
|
| 563 |
+
logger.warning_once(
|
| 564 |
+
"Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
|
| 565 |
+
"You should pass an instance of `DynamicCache` instead, e.g. "
|
| 566 |
+
"`past_key_values=DynamicCache.from_legacy_cache(past_key_values)`."
|
| 567 |
+
)
|
| 568 |
+
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
| 569 |
+
|
| 570 |
+
batch_size, seq_length = inputs_embeds.size()[:-1]
|
| 571 |
+
past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
|
| 572 |
+
if cache_position is None:
|
| 573 |
+
cache_position = torch.arange(
|
| 574 |
+
past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device
|
| 575 |
+
)
|
| 576 |
+
|
| 577 |
+
if attention_mask is None:
|
| 578 |
+
# required mask seq length can be calculated via length of past cache
|
| 579 |
+
mask_seq_length = past_key_values_length + seq_length
|
| 580 |
+
attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
|
| 581 |
+
|
| 582 |
+
self_attn_cache = past_key_values
|
| 583 |
+
|
| 584 |
+
causal_mask = self._update_causal_mask(
|
| 585 |
+
attention_mask,
|
| 586 |
+
inputs_embeds,
|
| 587 |
+
cache_position,
|
| 588 |
+
self_attn_cache,
|
| 589 |
+
)
|
| 590 |
+
|
| 591 |
+
# embed positions
|
| 592 |
+
if position_ids is None:
|
| 593 |
+
# position_ids = cache_position.unsqueeze(0)
|
| 594 |
+
position_ids = torch.cumsum(attention_mask, dim=1)
|
| 595 |
+
position_ids = (position_ids * attention_mask - 1).long()
|
| 596 |
+
# cut positions if `past_seen_tokens` is > 0
|
| 597 |
+
position_ids = position_ids[:, past_key_values_length:]
|
| 598 |
+
|
| 599 |
+
positions = self.embed_positions(attention_mask, past_key_values_length, position_ids=position_ids)
|
| 600 |
+
hidden_states = inputs_embeds + positions
|
| 601 |
+
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
| 602 |
+
|
| 603 |
+
if self.gradient_checkpointing and self.training:
|
| 604 |
+
if use_cache:
|
| 605 |
+
logger.warning_once(
|
| 606 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
| 607 |
+
)
|
| 608 |
+
use_cache = False
|
| 609 |
+
|
| 610 |
+
all_hidden_states = () if output_hidden_states else None
|
| 611 |
+
all_self_attns = () if output_attentions else None
|
| 612 |
+
all_cross_attentions = None
|
| 613 |
+
|
| 614 |
+
for idx, decoder_layer in enumerate(self.layers):
|
| 615 |
+
# add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
|
| 616 |
+
if output_hidden_states:
|
| 617 |
+
all_hidden_states += (hidden_states,)
|
| 618 |
+
if self.training:
|
| 619 |
+
dropout_probability = torch.rand([])
|
| 620 |
+
if dropout_probability < self.layerdrop:
|
| 621 |
+
continue
|
| 622 |
+
|
| 623 |
+
layer_outputs = decoder_layer(
|
| 624 |
+
hidden_states,
|
| 625 |
+
attention_mask=causal_mask,
|
| 626 |
+
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
| 627 |
+
past_key_values=past_key_values,
|
| 628 |
+
output_attentions=output_attentions,
|
| 629 |
+
use_cache=use_cache,
|
| 630 |
+
position_ids=position_ids,
|
| 631 |
+
cache_position=cache_position,
|
| 632 |
+
**kwargs,
|
| 633 |
+
)
|
| 634 |
+
|
| 635 |
+
hidden_states = layer_outputs[0]
|
| 636 |
+
|
| 637 |
+
if output_attentions:
|
| 638 |
+
all_self_attns += (layer_outputs[1],)
|
| 639 |
+
|
| 640 |
+
# add hidden states from the last decoder layer
|
| 641 |
+
if output_hidden_states:
|
| 642 |
+
all_hidden_states += (hidden_states,)
|
| 643 |
+
|
| 644 |
+
hidden_states = self.layer_norm(hidden_states)
|
| 645 |
+
|
| 646 |
+
if not return_dict:
|
| 647 |
+
return tuple(
|
| 648 |
+
v
|
| 649 |
+
for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns, all_cross_attentions]
|
| 650 |
+
if v is not None
|
| 651 |
+
)
|
| 652 |
+
return BaseModelOutputWithPastAndCrossAttentions(
|
| 653 |
+
last_hidden_state=hidden_states,
|
| 654 |
+
past_key_values=past_key_values,
|
| 655 |
+
hidden_states=all_hidden_states,
|
| 656 |
+
attentions=all_self_attns,
|
| 657 |
+
cross_attentions=all_cross_attentions,
|
| 658 |
+
)
|
| 659 |
+
|
| 660 |
+
|
| 661 |
+
@auto_docstring(
|
| 662 |
+
custom_intro="""
|
| 663 |
+
BioGPT Model with a `language modeling` head on top for CLM fine-tuning.
|
| 664 |
+
"""
|
| 665 |
+
)
|
| 666 |
+
class BioGptForCausalLM(BioGptPreTrainedModel, GenerationMixin):
|
| 667 |
+
_tied_weights_keys = ["output_projection.weight"]
|
| 668 |
+
|
| 669 |
+
def __init__(self, config):
|
| 670 |
+
super().__init__(config)
|
| 671 |
+
|
| 672 |
+
self.biogpt = BioGptModel(config)
|
| 673 |
+
self.output_projection = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 674 |
+
|
| 675 |
+
# Initialize weights and apply final processing
|
| 676 |
+
self.post_init()
|
| 677 |
+
|
| 678 |
+
def get_output_embeddings(self):
|
| 679 |
+
return self.output_projection
|
| 680 |
+
|
| 681 |
+
def set_output_embeddings(self, new_embeddings):
|
| 682 |
+
self.output_projection = new_embeddings
|
| 683 |
+
|
| 684 |
+
@auto_docstring
|
| 685 |
+
def forward(
|
| 686 |
+
self,
|
| 687 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 688 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
| 689 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
| 690 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 691 |
+
past_key_values: Optional[Cache] = None,
|
| 692 |
+
labels: Optional[torch.LongTensor] = None,
|
| 693 |
+
use_cache: Optional[bool] = None,
|
| 694 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 695 |
+
output_attentions: Optional[bool] = None,
|
| 696 |
+
output_hidden_states: Optional[bool] = None,
|
| 697 |
+
return_dict: Optional[bool] = None,
|
| 698 |
+
cache_position: Optional[torch.Tensor] = None,
|
| 699 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 700 |
+
) -> Union[tuple, CausalLMOutputWithCrossAttentions]:
|
| 701 |
+
r"""
|
| 702 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 703 |
+
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
|
| 704 |
+
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
|
| 705 |
+
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
|
| 706 |
+
"""
|
| 707 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 708 |
+
|
| 709 |
+
outputs = self.biogpt(
|
| 710 |
+
input_ids,
|
| 711 |
+
attention_mask=attention_mask,
|
| 712 |
+
head_mask=head_mask,
|
| 713 |
+
inputs_embeds=inputs_embeds,
|
| 714 |
+
past_key_values=past_key_values,
|
| 715 |
+
use_cache=use_cache,
|
| 716 |
+
position_ids=position_ids,
|
| 717 |
+
output_attentions=output_attentions,
|
| 718 |
+
output_hidden_states=output_hidden_states,
|
| 719 |
+
return_dict=return_dict,
|
| 720 |
+
cache_position=cache_position,
|
| 721 |
+
**kwargs,
|
| 722 |
+
)
|
| 723 |
+
|
| 724 |
+
sequence_output = outputs[0]
|
| 725 |
+
prediction_scores = self.output_projection(sequence_output)
|
| 726 |
+
|
| 727 |
+
lm_loss = None
|
| 728 |
+
if labels is not None:
|
| 729 |
+
lm_loss = self.loss_function(
|
| 730 |
+
prediction_scores,
|
| 731 |
+
labels,
|
| 732 |
+
vocab_size=self.config.vocab_size,
|
| 733 |
+
**kwargs,
|
| 734 |
+
)
|
| 735 |
+
|
| 736 |
+
if not return_dict:
|
| 737 |
+
output = (prediction_scores,) + outputs[1:]
|
| 738 |
+
return ((lm_loss,) + output) if lm_loss is not None else output
|
| 739 |
+
|
| 740 |
+
return CausalLMOutputWithCrossAttentions(
|
| 741 |
+
loss=lm_loss,
|
| 742 |
+
logits=prediction_scores,
|
| 743 |
+
past_key_values=outputs.past_key_values,
|
| 744 |
+
hidden_states=outputs.hidden_states,
|
| 745 |
+
attentions=outputs.attentions,
|
| 746 |
+
cross_attentions=outputs.cross_attentions,
|
| 747 |
+
)
|
| 748 |
+
|
| 749 |
+
|
| 750 |
+
@auto_docstring
|
| 751 |
+
class BioGptForTokenClassification(BioGptPreTrainedModel):
|
| 752 |
+
def __init__(self, config):
|
| 753 |
+
super().__init__(config)
|
| 754 |
+
self.num_labels = config.num_labels
|
| 755 |
+
|
| 756 |
+
self.biogpt = BioGptModel(config)
|
| 757 |
+
if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None:
|
| 758 |
+
classifier_dropout = config.classifier_dropout
|
| 759 |
+
else:
|
| 760 |
+
classifier_dropout = config.hidden_dropout_prob
|
| 761 |
+
self.dropout = nn.Dropout(classifier_dropout)
|
| 762 |
+
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
|
| 763 |
+
|
| 764 |
+
self.post_init()
|
| 765 |
+
|
| 766 |
+
@auto_docstring
|
| 767 |
+
def forward(
|
| 768 |
+
self,
|
| 769 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 770 |
+
token_type_ids: Optional[torch.LongTensor] = None,
|
| 771 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
| 772 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
| 773 |
+
past_key_values: Optional[Cache] = None,
|
| 774 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 775 |
+
labels: Optional[torch.LongTensor] = None,
|
| 776 |
+
use_cache: Optional[bool] = None,
|
| 777 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 778 |
+
output_attentions: Optional[bool] = None,
|
| 779 |
+
output_hidden_states: Optional[bool] = None,
|
| 780 |
+
return_dict: Optional[bool] = None,
|
| 781 |
+
cache_position: Optional[torch.Tensor] = None,
|
| 782 |
+
) -> Union[tuple, TokenClassifierOutput]:
|
| 783 |
+
r"""
|
| 784 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 785 |
+
Labels for computing the token classification loss. Indices should be in `[0, ...,
|
| 786 |
+
config.num_labels - 1]`. A cross-entropy loss is computed over the active tokens; positions where
|
| 787 |
+
`attention_mask == 0` are ignored.
|
| 788 |
+
"""
|
| 789 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 790 |
+
|
| 791 |
+
transformer_outputs = self.biogpt(
|
| 792 |
+
input_ids,
|
| 793 |
+
past_key_values=past_key_values,
|
| 794 |
+
attention_mask=attention_mask,
|
| 795 |
+
head_mask=head_mask,
|
| 796 |
+
inputs_embeds=inputs_embeds,
|
| 797 |
+
use_cache=use_cache,
|
| 798 |
+
position_ids=position_ids,
|
| 799 |
+
output_attentions=output_attentions,
|
| 800 |
+
output_hidden_states=output_hidden_states,
|
| 801 |
+
return_dict=return_dict,
|
| 802 |
+
cache_position=cache_position,
|
| 803 |
+
)
|
| 804 |
+
|
| 805 |
+
hidden_states = transformer_outputs[0]
|
| 806 |
+
hidden_states = self.dropout(hidden_states)
|
| 807 |
+
logits = self.classifier(hidden_states)
|
| 808 |
+
|
| 809 |
+
loss = None
|
| 810 |
+
if labels is not None:
|
| 811 |
+
loss_fct = CrossEntropyLoss()
|
| 812 |
+
# Only keep active parts of the loss
|
| 813 |
+
if attention_mask is not None:
|
| 814 |
+
active_loss = attention_mask.view(-1) == 1
|
| 815 |
+
active_logits = logits.view(-1, self.num_labels)
|
| 816 |
+
active_labels = torch.where(
|
| 817 |
+
active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
|
| 818 |
+
)
|
| 819 |
+
loss = loss_fct(active_logits, active_labels)
|
| 820 |
+
else:
|
| 821 |
+
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
| 822 |
+
|
| 823 |
+
if not return_dict:
|
| 824 |
+
output = (logits,) + transformer_outputs[2:]
|
| 825 |
+
return ((loss,) + output) if loss is not None else output
|
| 826 |
+
|
| 827 |
+
return TokenClassifierOutput(
|
| 828 |
+
loss=loss,
|
| 829 |
+
logits=logits,
|
| 830 |
+
hidden_states=transformer_outputs.hidden_states,
|
| 831 |
+
attentions=transformer_outputs.attentions,
|
| 832 |
+
)
|
| 833 |
+
|
| 834 |
+
|
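The active-loss masking in the token-classification forward above can be replayed in isolation. The tensors below are made up purely to show how padded positions are mapped to ignore_index (-100) and therefore do not contribute to the cross-entropy.

import torch
from torch.nn import CrossEntropyLoss

loss_fct = CrossEntropyLoss()  # ignore_index defaults to -100

# Toy batch: 2 sequences, 4 tokens, 3 label classes (all values are illustrative).
logits = torch.randn(2, 4, 3)
labels = torch.tensor([[0, 1, 2, 2], [1, 0, 2, 1]])
attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])  # trailing zeros mark padding

active_loss = attention_mask.view(-1) == 1
active_logits = logits.view(-1, 3)
# Padded positions are replaced by ignore_index so CrossEntropyLoss skips them.
active_labels = torch.where(
    active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
)
loss = loss_fct(active_logits, active_labels)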
| 835 |
+
@auto_docstring(
|
| 836 |
+
custom_intro="""
|
| 837 |
+
The BioGpt Model transformer with a sequence classification head on top (linear layer).
|
| 838 |
+
|
| 839 |
+
[`BioGptForSequenceClassification`] uses the last token in order to do the classification, as other causal models
|
| 840 |
+
(e.g. GPT-2) do.
|
| 841 |
+
|
| 842 |
+
Since it does classification on the last token, it is required to know the position of the last token. If a
|
| 843 |
+
`pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
|
| 844 |
+
no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
|
| 845 |
+
padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
|
| 846 |
+
each row of the batch).
|
| 847 |
+
"""
|
| 848 |
+
)
|
| 849 |
+
class BioGptForSequenceClassification(BioGptPreTrainedModel):
|
| 850 |
+
def __init__(self, config: BioGptConfig):
|
| 851 |
+
super().__init__(config)
|
| 852 |
+
self.num_labels = config.num_labels
|
| 853 |
+
self.biogpt = BioGptModel(config)
|
| 854 |
+
self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
|
| 855 |
+
|
| 856 |
+
# Initialize weights and apply final processing
|
| 857 |
+
self.post_init()
|
| 858 |
+
|
| 859 |
+
@auto_docstring
|
| 860 |
+
def forward(
|
| 861 |
+
self,
|
| 862 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 863 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
| 864 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
| 865 |
+
past_key_values: Optional[Cache] = None,
|
| 866 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 867 |
+
labels: Optional[torch.LongTensor] = None,
|
| 868 |
+
use_cache: Optional[bool] = None,
|
| 869 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 870 |
+
output_attentions: Optional[bool] = None,
|
| 871 |
+
output_hidden_states: Optional[bool] = None,
|
| 872 |
+
return_dict: Optional[bool] = None,
|
| 873 |
+
cache_position: Optional[torch.Tensor] = None,
|
| 874 |
+
logits_to_keep: Union[int, torch.Tensor] = 0,
|
| 875 |
+
) -> Union[tuple, SequenceClassifierOutputWithPast]:
|
| 876 |
+
r"""
|
| 877 |
+
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
| 878 |
+
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
| 879 |
+
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss). If
|
| 880 |
+
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
| 881 |
+
"""
|
| 882 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 883 |
+
|
| 884 |
+
transformer_outputs = self.biogpt(
|
| 885 |
+
input_ids,
|
| 886 |
+
past_key_values=past_key_values,
|
| 887 |
+
attention_mask=attention_mask,
|
| 888 |
+
head_mask=head_mask,
|
| 889 |
+
inputs_embeds=inputs_embeds,
|
| 890 |
+
use_cache=use_cache,
|
| 891 |
+
position_ids=position_ids,
|
| 892 |
+
output_attentions=output_attentions,
|
| 893 |
+
output_hidden_states=output_hidden_states,
|
| 894 |
+
return_dict=return_dict,
|
| 895 |
+
cache_position=cache_position,
|
| 896 |
+
)
|
| 897 |
+
hidden_states = transformer_outputs[0]
|
| 898 |
+
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
| 899 |
+
logits = self.score(hidden_states[:, slice_indices, :])
|
| 900 |
+
|
| 901 |
+
if input_ids is not None:
|
| 902 |
+
batch_size, sequence_length = input_ids.shape[:2]
|
| 903 |
+
else:
|
| 904 |
+
batch_size, sequence_length = inputs_embeds.shape[:2]
|
| 905 |
+
|
| 906 |
+
if self.config.pad_token_id is None:
|
| 907 |
+
sequence_length = -1
|
| 908 |
+
else:
|
| 909 |
+
if input_ids is not None:
|
| 910 |
+
sequence_length = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
|
| 911 |
+
else:
|
| 912 |
+
sequence_length = -1
|
| 913 |
+
logger.warning_once(
|
| 914 |
+
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
|
| 915 |
+
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
|
| 916 |
+
)
|
| 917 |
+
|
| 918 |
+
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_length]
|
| 919 |
+
|
| 920 |
+
loss = None
|
| 921 |
+
if labels is not None:
|
| 922 |
+
if self.config.problem_type is None:
|
| 923 |
+
if self.num_labels == 1:
|
| 924 |
+
self.config.problem_type = "regression"
|
| 925 |
+
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
| 926 |
+
self.config.problem_type = "single_label_classification"
|
| 927 |
+
else:
|
| 928 |
+
self.config.problem_type = "multi_label_classification"
|
| 929 |
+
|
| 930 |
+
if self.config.problem_type == "regression":
|
| 931 |
+
loss_fct = MSELoss()
|
| 932 |
+
if self.num_labels == 1:
|
| 933 |
+
loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
|
| 934 |
+
else:
|
| 935 |
+
loss = loss_fct(pooled_logits, labels)
|
| 936 |
+
elif self.config.problem_type == "single_label_classification":
|
| 937 |
+
loss_fct = CrossEntropyLoss()
|
| 938 |
+
loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
|
| 939 |
+
elif self.config.problem_type == "multi_label_classification":
|
| 940 |
+
loss_fct = BCEWithLogitsLoss()
|
| 941 |
+
loss = loss_fct(pooled_logits, labels)
|
| 942 |
+
if not return_dict:
|
| 943 |
+
output = (pooled_logits,) + transformer_outputs[1:]
|
| 944 |
+
return ((loss,) + output) if loss is not None else output
|
| 945 |
+
|
| 946 |
+
return SequenceClassifierOutputWithPast(
|
| 947 |
+
loss=loss,
|
| 948 |
+
logits=pooled_logits,
|
| 949 |
+
past_key_values=transformer_outputs.past_key_values,
|
| 950 |
+
hidden_states=transformer_outputs.hidden_states,
|
| 951 |
+
attentions=transformer_outputs.attentions,
|
| 952 |
+
)
|
| 953 |
+
|
| 954 |
+
def get_input_embeddings(self):
|
| 955 |
+
return self.biogpt.embed_tokens
|
| 956 |
+
|
| 957 |
+
def set_input_embeddings(self, value):
|
| 958 |
+
self.biogpt.embed_tokens = value
|
| 959 |
+
|
| 960 |
+
|
| 961 |
+
__all__ = [
|
| 962 |
+
"BioGptForCausalLM",
|
| 963 |
+
"BioGptForTokenClassification",
|
| 964 |
+
"BioGptForSequenceClassification",
|
| 965 |
+
"BioGptModel",
|
| 966 |
+
"BioGptPreTrainedModel",
|
| 967 |
+
]
|
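Before the next file, a small standalone sketch of the pooling trick used by BioGptForSequenceClassification above: the index of the last non-padding token is recovered from input_ids and used to gather one logit vector per sequence. The pad id, shapes, and values below are illustrative only.

import torch

pad_token_id = 1  # illustrative pad id
input_ids = torch.tensor([[5, 8, 9, 1, 1], [7, 2, 1, 1, 1]])  # right-padded toy batch
logits = torch.randn(2, 5, 3)  # (batch, seq_len, num_labels)

# Index of the last non-pad token in each row, as computed in the forward above.
last_non_pad = torch.ne(input_ids, pad_token_id).sum(-1) - 1  # tensor([2, 1])
pooled_logits = logits[torch.arange(input_ids.shape[0]), last_non_pad]
print(pooled_logits.shape)  # torch.Size([2, 3])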
venv/lib/python3.13/site-packages/transformers/models/biogpt/modular_biogpt.py
ADDED
|
@@ -0,0 +1,789 @@
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2022 The HuggingFace Team and Microsoft Research AI4Science All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""PyTorch BioGPT model."""
|
| 16 |
+
|
| 17 |
+
import math
|
| 18 |
+
from typing import Optional, Union
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
import torch.nn as nn
|
| 22 |
+
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
| 23 |
+
|
| 24 |
+
from ...activations import ACT2FN
|
| 25 |
+
from ...cache_utils import Cache, DynamicCache
|
| 26 |
+
from ...generation import GenerationMixin
|
| 27 |
+
from ...modeling_attn_mask_utils import (
|
| 28 |
+
AttentionMaskConverter,
|
| 29 |
+
)
|
| 30 |
+
from ...modeling_outputs import (
|
| 31 |
+
BaseModelOutputWithPastAndCrossAttentions,
|
| 32 |
+
CausalLMOutputWithCrossAttentions,
|
| 33 |
+
SequenceClassifierOutputWithPast,
|
| 34 |
+
TokenClassifierOutput,
|
| 35 |
+
)
|
| 36 |
+
from ...modeling_utils import PreTrainedModel
|
| 37 |
+
from ...processing_utils import Unpack
|
| 38 |
+
from ...utils import (
|
| 39 |
+
TransformersKwargs,
|
| 40 |
+
auto_docstring,
|
| 41 |
+
is_torch_flex_attn_available,
|
| 42 |
+
logger,
|
| 43 |
+
)
|
| 44 |
+
from ...utils.deprecation import deprecate_kwarg
|
| 45 |
+
from ..bart.modeling_bart import (
|
| 46 |
+
BartAttention,
|
| 47 |
+
BartDecoderLayer,
|
| 48 |
+
BartScaledWordEmbedding,
|
| 49 |
+
)
|
| 50 |
+
from ..opt.modeling_opt import OPTLearnedPositionalEmbedding
|
| 51 |
+
from .configuration_biogpt import BioGptConfig
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
if is_torch_flex_attn_available():
|
| 55 |
+
from ...integrations.flex_attention import BlockMask, make_flex_block_causal_mask
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class BioGptLearnedPositionalEmbedding(OPTLearnedPositionalEmbedding):
|
| 59 |
+
def forward(
|
| 60 |
+
self,
|
| 61 |
+
attention_mask: torch.LongTensor,
|
| 62 |
+
past_key_values_length: int = 0,
|
| 63 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 64 |
+
):
|
| 65 |
+
"""`input_ids_shape` is expected to be [bsz x seqlen]."""
|
| 66 |
+
return super().forward(attention_mask, past_key_values_length, position_ids)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class BioGptScaledWordEmbedding(BartScaledWordEmbedding):
|
| 70 |
+
pass
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class BioGptAttention(BartAttention):
|
| 74 |
+
pass
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
class BioGptDecoderLayer(BartDecoderLayer):
|
| 78 |
+
def __init__(self, config: BioGptConfig, layer_idx: Optional[int] = None):
|
| 79 |
+
super().__init__(config)
|
| 80 |
+
self.embed_dim = config.hidden_size
|
| 81 |
+
|
| 82 |
+
self.self_attn = BioGptAttention(
|
| 83 |
+
embed_dim=self.embed_dim,
|
| 84 |
+
num_heads=config.num_attention_heads,
|
| 85 |
+
dropout=config.attention_probs_dropout_prob,
|
| 86 |
+
is_decoder=True,
|
| 87 |
+
is_causal=True,
|
| 88 |
+
config=config,
|
| 89 |
+
layer_idx=layer_idx,
|
| 90 |
+
)
|
| 91 |
+
self.dropout = config.hidden_dropout_prob
|
| 92 |
+
self.activation_fn = ACT2FN[config.hidden_act]
|
| 93 |
+
|
| 94 |
+
self.fc1 = nn.Linear(self.embed_dim, config.intermediate_size)
|
| 95 |
+
self.fc2 = nn.Linear(config.intermediate_size, self.embed_dim)
|
| 96 |
+
|
| 97 |
+
del self.encoder_attn
|
| 98 |
+
del self.encoder_attn_layer_norm
|
| 99 |
+
|
| 100 |
+
@deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
|
| 101 |
+
def forward(
|
| 102 |
+
self,
|
| 103 |
+
hidden_states: torch.Tensor,
|
| 104 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 105 |
+
layer_head_mask: Optional[torch.Tensor] = None,
|
| 106 |
+
past_key_values: Optional[Cache] = None,
|
| 107 |
+
output_attentions: Optional[bool] = False,
|
| 108 |
+
use_cache: Optional[bool] = True,
|
| 109 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 110 |
+
cache_position: Optional[torch.Tensor] = None,
|
| 111 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 112 |
+
) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
| 113 |
+
"""
|
| 114 |
+
Args:
|
| 115 |
+
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
| 116 |
+
attention_mask (`torch.FloatTensor`): attention mask of size
|
| 117 |
+
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
| 118 |
+
layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
|
| 119 |
+
`(encoder_attention_heads,)`.
|
| 120 |
+
past_key_values (`Cache`): cached past key and value projection states
|
| 121 |
+
output_attentions (`bool`, *optional*):
|
| 122 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
| 123 |
+
returned tensors for more detail.
|
| 124 |
+
use_cache (`bool`, *optional*):
|
| 125 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
| 126 |
+
(see `past_key_values`).
|
| 127 |
+
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
| 128 |
+
Indices depicting the position of the input sequence tokens in the sequence. It is used to update the
|
| 129 |
+
cache in the correct position and to infer the complete sequence length.
|
| 130 |
+
"""
|
| 131 |
+
residual = hidden_states
|
| 132 |
+
|
| 133 |
+
hidden_states = self.self_attn_layer_norm(hidden_states)
|
| 134 |
+
|
| 135 |
+
# Self Attention
|
| 136 |
+
hidden_states, self_attn_weights = self.self_attn(
|
| 137 |
+
hidden_states=hidden_states,
|
| 138 |
+
past_key_values=past_key_values,
|
| 139 |
+
attention_mask=attention_mask,
|
| 140 |
+
layer_head_mask=layer_head_mask,
|
| 141 |
+
output_attentions=output_attentions,
|
| 142 |
+
position_ids=position_ids,
|
| 143 |
+
cache_position=cache_position,
|
| 144 |
+
**kwargs,
|
| 145 |
+
)
|
| 146 |
+
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
| 147 |
+
hidden_states = residual + hidden_states
|
| 148 |
+
|
| 149 |
+
# Fully Connected
|
| 150 |
+
residual = hidden_states
|
| 151 |
+
hidden_states = self.final_layer_norm(hidden_states)
|
| 152 |
+
hidden_states = self.fc1(hidden_states)
|
| 153 |
+
hidden_states = self.activation_fn(hidden_states)
|
| 154 |
+
hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
|
| 155 |
+
hidden_states = self.fc2(hidden_states)
|
| 156 |
+
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
| 157 |
+
hidden_states = residual + hidden_states
|
| 158 |
+
|
| 159 |
+
outputs = (hidden_states,)
|
| 160 |
+
|
| 161 |
+
if output_attentions:
|
| 162 |
+
outputs += (self_attn_weights,)
|
| 163 |
+
|
| 164 |
+
return outputs
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
@auto_docstring
|
| 168 |
+
class BioGptPreTrainedModel(PreTrainedModel):
|
| 169 |
+
config: BioGptConfig
|
| 170 |
+
base_model_prefix = "biogpt"
|
| 171 |
+
supports_gradient_checkpointing = True
|
| 172 |
+
_supports_flash_attn = True
|
| 173 |
+
_supports_sdpa = True
|
| 174 |
+
_supports_flex_attn = True
|
| 175 |
+
|
| 176 |
+
_can_compile_fullgraph = True
|
| 177 |
+
|
| 178 |
+
# Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_causal_mask
|
| 179 |
+
def _update_causal_mask(
|
| 180 |
+
self,
|
| 181 |
+
attention_mask: Optional[Union[torch.Tensor, "BlockMask"]],
|
| 182 |
+
input_tensor: torch.Tensor,
|
| 183 |
+
cache_position: torch.Tensor,
|
| 184 |
+
past_key_values: Cache,
|
| 185 |
+
):
|
| 186 |
+
if self.config._attn_implementation == "flex_attention":
|
| 187 |
+
if isinstance(attention_mask, torch.Tensor):
|
| 188 |
+
attention_mask = make_flex_block_causal_mask(attention_mask)
|
| 189 |
+
# Other attention flavors support in-built causal (when `mask is None`)
|
| 190 |
+
# while we need to create our specific block mask regardless
|
| 191 |
+
elif attention_mask is None:
|
| 192 |
+
attention_mask = make_flex_block_causal_mask(
|
| 193 |
+
torch.ones(
|
| 194 |
+
size=(input_tensor.shape[0], input_tensor.shape[1]),
|
| 195 |
+
device=input_tensor.device,
|
| 196 |
+
)
|
| 197 |
+
)
|
| 198 |
+
return attention_mask
|
| 199 |
+
|
| 200 |
+
if self.config._attn_implementation == "flash_attention_2":
|
| 201 |
+
if attention_mask is not None and (attention_mask == 0.0).any():
|
| 202 |
+
return attention_mask
|
| 203 |
+
return None
|
| 204 |
+
|
| 205 |
+
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
| 206 |
+
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
| 207 |
+
# to infer the attention mask.
|
| 208 |
+
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
| 209 |
+
using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
|
| 210 |
+
|
| 211 |
+
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
|
| 212 |
+
if self.config._attn_implementation == "sdpa" and not using_compilable_cache:
|
| 213 |
+
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
| 214 |
+
attention_mask,
|
| 215 |
+
inputs_embeds=input_tensor,
|
| 216 |
+
past_key_values_length=past_seen_tokens,
|
| 217 |
+
is_training=self.training,
|
| 218 |
+
):
|
| 219 |
+
return None
|
| 220 |
+
|
| 221 |
+
dtype = input_tensor.dtype
|
| 222 |
+
sequence_length = input_tensor.shape[1]
|
| 223 |
+
if using_compilable_cache:
|
| 224 |
+
target_length = past_key_values.get_max_cache_shape()
|
| 225 |
+
else:
|
| 226 |
+
target_length = (
|
| 227 |
+
attention_mask.shape[-1]
|
| 228 |
+
if isinstance(attention_mask, torch.Tensor)
|
| 229 |
+
else past_seen_tokens + sequence_length + 1
|
| 230 |
+
)
|
| 231 |
+
|
| 232 |
+
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
|
| 233 |
+
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
|
| 234 |
+
attention_mask,
|
| 235 |
+
sequence_length=sequence_length,
|
| 236 |
+
target_length=target_length,
|
| 237 |
+
dtype=dtype,
|
| 238 |
+
cache_position=cache_position,
|
| 239 |
+
batch_size=input_tensor.shape[0],
|
| 240 |
+
)
|
| 241 |
+
|
| 242 |
+
if (
|
| 243 |
+
self.config._attn_implementation == "sdpa"
|
| 244 |
+
and attention_mask is not None
|
| 245 |
+
and attention_mask.device.type in ["cuda", "xpu", "npu"]
|
| 246 |
+
):
|
| 247 |
+
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
|
| 248 |
+
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
| 249 |
+
# Details: https://github.com/pytorch/pytorch/issues/110213
|
| 250 |
+
min_dtype = torch.finfo(dtype).min
|
| 251 |
+
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
|
| 252 |
+
|
| 253 |
+
return causal_mask
|
| 254 |
+
|
| 255 |
+
@staticmethod
|
| 256 |
+
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
|
| 257 |
+
def _prepare_4d_causal_attention_mask_with_cache_position(
|
| 258 |
+
attention_mask: torch.Tensor,
|
| 259 |
+
sequence_length: int,
|
| 260 |
+
target_length: int,
|
| 261 |
+
dtype: torch.dtype,
|
| 262 |
+
cache_position: torch.Tensor,
|
| 263 |
+
batch_size: int,
|
| 264 |
+
**kwargs,
|
| 265 |
+
):
|
| 266 |
+
"""
|
| 267 |
+
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
| 268 |
+
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
| 269 |
+
|
| 270 |
+
Args:
|
| 271 |
+
attention_mask (`torch.Tensor`):
|
| 272 |
+
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
|
| 273 |
+
`(batch_size, 1, query_length, key_value_length)`.
|
| 274 |
+
sequence_length (`int`):
|
| 275 |
+
The sequence length being processed.
|
| 276 |
+
target_length (`int`):
|
| 277 |
+
The target length: when generating with static cache, the mask should be as long as the static cache,
|
| 278 |
+
to account for the 0 padding, the part of the cache that is not filled yet.
|
| 279 |
+
dtype (`torch.dtype`):
|
| 280 |
+
The dtype to use for the 4D attention mask.
|
| 281 |
+
cache_position (`torch.Tensor`):
|
| 282 |
+
Indices depicting the position of the input sequence tokens in the sequence.
|
| 283 |
+
batch_size (`torch.Tensor`):
|
| 284 |
+
Batch size.
|
| 285 |
+
"""
|
| 286 |
+
if attention_mask is not None and attention_mask.dim() == 4:
|
| 287 |
+
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
| 288 |
+
causal_mask = attention_mask
|
| 289 |
+
else:
|
| 290 |
+
min_dtype = torch.finfo(dtype).min
|
| 291 |
+
causal_mask = torch.full(
|
| 292 |
+
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
|
| 293 |
+
)
|
| 294 |
+
if sequence_length != 1:
|
| 295 |
+
causal_mask = torch.triu(causal_mask, diagonal=1)
|
| 296 |
+
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
|
| 297 |
+
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
| 298 |
+
if attention_mask is not None:
|
| 299 |
+
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
| 300 |
+
mask_length = attention_mask.shape[-1]
|
| 301 |
+
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
|
| 302 |
+
causal_mask.device
|
| 303 |
+
)
|
| 304 |
+
padding_mask = padding_mask == 0
|
| 305 |
+
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
| 306 |
+
padding_mask, min_dtype
|
| 307 |
+
)
|
| 308 |
+
|
| 309 |
+
return causal_mask
|
| 310 |
+
|
| 311 |
+
|
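The mask construction in _prepare_4d_causal_attention_mask_with_cache_position above can be visualised on small tensors. This standalone sketch only mirrors the no-cache case with a 2D padding mask; the sizes and mask values are made up for illustration.

import torch

dtype = torch.float32
min_dtype = torch.finfo(dtype).min
sequence_length, target_length, batch_size = 3, 3, 1
cache_position = torch.arange(sequence_length)
attention_mask = torch.tensor([[1, 1, 0]])  # last key position is padding

# Future positions (strict upper triangle) get the large negative fill value.
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype)
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1).clone()

# Fold the 2D padding mask in: padded key positions also become min_dtype.
padding_mask = (causal_mask + attention_mask[:, None, None, :]) == 0
causal_mask = causal_mask.masked_fill(padding_mask, min_dtype)
print(causal_mask[0, 0])  # 0.0 where attention is allowed, min_dtype elsewhere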
| 312 |
+
@auto_docstring
|
| 313 |
+
class BioGptModel(BioGptPreTrainedModel):
|
| 314 |
+
def __init__(self, config: BioGptConfig):
|
| 315 |
+
super().__init__(config)
|
| 316 |
+
self.config = config
|
| 317 |
+
self.layerdrop = config.layerdrop
|
| 318 |
+
self.dropout = config.hidden_dropout_prob
|
| 319 |
+
self.embed_dim = config.hidden_size
|
| 320 |
+
self.padding_idx = config.pad_token_id
|
| 321 |
+
embed_scale = math.sqrt(config.hidden_size) if config.scale_embedding else 1.0
|
| 322 |
+
|
| 323 |
+
self.embed_tokens = BioGptScaledWordEmbedding(
|
| 324 |
+
config.vocab_size, self.embed_dim, self.padding_idx, embed_scale=embed_scale
|
| 325 |
+
)
|
| 326 |
+
self.embed_positions = BioGptLearnedPositionalEmbedding(config.max_position_embeddings, self.embed_dim)
|
| 327 |
+
|
| 328 |
+
self.layers = nn.ModuleList([BioGptDecoderLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)])
|
| 329 |
+
self.layer_norm = nn.LayerNorm(self.embed_dim)
|
| 330 |
+
|
| 331 |
+
self.gradient_checkpointing = False
|
| 332 |
+
# Initialize weights and apply final processing
|
| 333 |
+
self.post_init()
|
| 334 |
+
|
| 335 |
+
@auto_docstring
|
| 336 |
+
def forward(
|
| 337 |
+
self,
|
| 338 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 339 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
| 340 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
| 341 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 342 |
+
past_key_values: Optional[Cache] = None,
|
| 343 |
+
use_cache: Optional[bool] = None,
|
| 344 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 345 |
+
output_attentions: Optional[bool] = None,
|
| 346 |
+
output_hidden_states: Optional[bool] = None,
|
| 347 |
+
return_dict: Optional[bool] = None,
|
| 348 |
+
cache_position: Optional[torch.Tensor] = None,
|
| 349 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 350 |
+
) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]:
|
| 351 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 352 |
+
output_hidden_states = (
|
| 353 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 354 |
+
)
|
| 355 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
| 356 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 357 |
+
|
| 358 |
+
# retrieve input_ids and inputs_embeds
|
| 359 |
+
if (input_ids is None) ^ (inputs_embeds is not None):
|
| 360 |
+
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
|
| 361 |
+
elif input_ids is not None:
|
| 362 |
+
input = input_ids
|
| 363 |
+
input_shape = input.shape
|
| 364 |
+
input_ids = input_ids.view(-1, input_shape[-1])
|
| 365 |
+
elif inputs_embeds is not None:
|
| 366 |
+
input_shape = inputs_embeds.size()[:-1]
|
| 367 |
+
input = inputs_embeds[:, :, -1]
|
| 368 |
+
else:
|
| 369 |
+
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
|
| 370 |
+
|
| 371 |
+
if inputs_embeds is None:
|
| 372 |
+
inputs_embeds = self.embed_tokens(input)
|
| 373 |
+
|
| 374 |
+
if self.gradient_checkpointing and self.training:
|
| 375 |
+
if use_cache:
|
| 376 |
+
logger.warning_once(
|
| 377 |
+
"`use_cache=True` is incompatible with gradient checkpointing`. Setting `use_cache=False`..."
|
| 378 |
+
)
|
| 379 |
+
use_cache = False
|
| 380 |
+
|
| 381 |
+
# initialize past_key_values
|
| 382 |
+
if use_cache and past_key_values is None:
|
| 383 |
+
past_key_values = DynamicCache(config=self.config)
|
| 384 |
+
if use_cache and isinstance(past_key_values, tuple):
|
| 385 |
+
logger.warning_once(
|
| 386 |
+
"Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
|
| 387 |
+
"You should pass an instance of `DynamicCache` instead, e.g. "
|
| 388 |
+
"`past_key_values=DynamicCache.from_legacy_cache(past_key_values)`."
|
| 389 |
+
)
|
| 390 |
+
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
| 391 |
+
|
| 392 |
+
batch_size, seq_length = inputs_embeds.size()[:-1]
|
| 393 |
+
past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
|
| 394 |
+
if cache_position is None:
|
| 395 |
+
cache_position = torch.arange(
|
| 396 |
+
past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device
|
| 397 |
+
)
|
| 398 |
+
|
| 399 |
+
if attention_mask is None:
|
| 400 |
+
# required mask seq length can be calculated via length of past cache
|
| 401 |
+
mask_seq_length = past_key_values_length + seq_length
|
| 402 |
+
attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
|
| 403 |
+
|
| 404 |
+
self_attn_cache = past_key_values
|
| 405 |
+
|
| 406 |
+
causal_mask = self._update_causal_mask(
|
| 407 |
+
attention_mask,
|
| 408 |
+
inputs_embeds,
|
| 409 |
+
cache_position,
|
| 410 |
+
self_attn_cache,
|
| 411 |
+
)
|
| 412 |
+
|
| 413 |
+
# embed positions
|
| 414 |
+
if position_ids is None:
|
| 415 |
+
# position_ids = cache_position.unsqueeze(0)
|
| 416 |
+
position_ids = torch.cumsum(attention_mask, dim=1)
|
| 417 |
+
position_ids = (position_ids * attention_mask - 1).long()
|
| 418 |
+
# cut positions if `past_seen_tokens` is > 0
|
| 419 |
+
position_ids = position_ids[:, past_key_values_length:]
|
| 420 |
+
|
| 421 |
+
positions = self.embed_positions(attention_mask, past_key_values_length, position_ids=position_ids)
|
| 422 |
+
hidden_states = inputs_embeds + positions
|
| 423 |
+
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
| 424 |
+
|
| 425 |
+
if self.gradient_checkpointing and self.training:
|
| 426 |
+
if use_cache:
|
| 427 |
+
logger.warning_once(
|
| 428 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
| 429 |
+
)
|
| 430 |
+
use_cache = False
|
| 431 |
+
|
| 432 |
+
all_hidden_states = () if output_hidden_states else None
|
| 433 |
+
all_self_attns = () if output_attentions else None
|
| 434 |
+
all_cross_attentions = None
|
| 435 |
+
|
| 436 |
+
for idx, decoder_layer in enumerate(self.layers):
|
| 437 |
+
# add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
|
| 438 |
+
if output_hidden_states:
|
| 439 |
+
all_hidden_states += (hidden_states,)
|
| 440 |
+
if self.training:
|
| 441 |
+
dropout_probability = torch.rand([])
|
| 442 |
+
if dropout_probability < self.layerdrop:
|
| 443 |
+
continue
|
| 444 |
+
|
| 445 |
+
layer_outputs = decoder_layer(
|
| 446 |
+
hidden_states,
|
| 447 |
+
attention_mask=causal_mask,
|
| 448 |
+
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
| 449 |
+
past_key_values=past_key_values,
|
| 450 |
+
output_attentions=output_attentions,
|
| 451 |
+
use_cache=use_cache,
|
| 452 |
+
position_ids=position_ids,
|
| 453 |
+
cache_position=cache_position,
|
| 454 |
+
**kwargs,
|
| 455 |
+
)
|
| 456 |
+
|
| 457 |
+
hidden_states = layer_outputs[0]
|
| 458 |
+
|
| 459 |
+
if output_attentions:
|
| 460 |
+
all_self_attns += (layer_outputs[1],)
|
| 461 |
+
|
| 462 |
+
# add hidden states from the last decoder layer
|
| 463 |
+
if output_hidden_states:
|
| 464 |
+
all_hidden_states += (hidden_states,)
|
| 465 |
+
|
| 466 |
+
hidden_states = self.layer_norm(hidden_states)
|
| 467 |
+
|
| 468 |
+
if not return_dict:
|
| 469 |
+
return tuple(
|
| 470 |
+
v
|
| 471 |
+
for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns, all_cross_attentions]
|
| 472 |
+
if v is not None
|
| 473 |
+
)
|
| 474 |
+
return BaseModelOutputWithPastAndCrossAttentions(
|
| 475 |
+
last_hidden_state=hidden_states,
|
| 476 |
+
past_key_values=past_key_values,
|
| 477 |
+
hidden_states=all_hidden_states,
|
| 478 |
+
attentions=all_self_attns,
|
| 479 |
+
cross_attentions=all_cross_attentions,
|
| 480 |
+
)
|
| 481 |
+
|
| 482 |
+
|
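One detail of the forward above that is easy to miss is how position_ids are derived from the attention mask when none are provided: a cumulative sum assigns consecutive positions to real tokens while padded positions end up at -1. The toy left-padded batch below is purely illustrative.

import torch

attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                               [1, 1, 1, 1, 1]])  # 0 marks padding (illustrative)
past_key_values_length = 0

position_ids = torch.cumsum(attention_mask, dim=1)
position_ids = (position_ids * attention_mask - 1).long()
# -> tensor([[-1, -1,  0,  1,  2],
#            [ 0,  1,  2,  3,  4]])
position_ids = position_ids[:, past_key_values_length:]  # cut cached positions when resuming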
| 483 |
+
@auto_docstring(
|
| 484 |
+
custom_intro="""
|
| 485 |
+
BioGPT Model with a `language modeling` head on top for CLM fine-tuning.
|
| 486 |
+
"""
|
| 487 |
+
)
|
| 488 |
+
class BioGptForCausalLM(BioGptPreTrainedModel, GenerationMixin):
|
| 489 |
+
_tied_weights_keys = ["output_projection.weight"]
|
| 490 |
+
|
| 491 |
+
def __init__(self, config):
|
| 492 |
+
super().__init__(config)
|
| 493 |
+
|
| 494 |
+
self.biogpt = BioGptModel(config)
|
| 495 |
+
self.output_projection = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 496 |
+
|
| 497 |
+
# Initialize weights and apply final processing
|
| 498 |
+
self.post_init()
|
| 499 |
+
|
| 500 |
+
def get_output_embeddings(self):
|
| 501 |
+
return self.output_projection
|
| 502 |
+
|
| 503 |
+
def set_output_embeddings(self, new_embeddings):
|
| 504 |
+
self.output_projection = new_embeddings
|
| 505 |
+
|
| 506 |
+
@auto_docstring
|
| 507 |
+
def forward(
|
| 508 |
+
self,
|
| 509 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 510 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
| 511 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
| 512 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 513 |
+
past_key_values: Optional[Cache] = None,
|
| 514 |
+
labels: Optional[torch.LongTensor] = None,
|
| 515 |
+
use_cache: Optional[bool] = None,
|
| 516 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 517 |
+
output_attentions: Optional[bool] = None,
|
| 518 |
+
output_hidden_states: Optional[bool] = None,
|
| 519 |
+
return_dict: Optional[bool] = None,
|
| 520 |
+
cache_position: Optional[torch.Tensor] = None,
|
| 521 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 522 |
+
) -> Union[tuple, CausalLMOutputWithCrossAttentions]:
|
| 523 |
+
r"""
|
| 524 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 525 |
+
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
|
| 526 |
+
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
|
| 527 |
+
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
|
| 528 |
+
"""
|
| 529 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 530 |
+
|
| 531 |
+
outputs = self.biogpt(
|
| 532 |
+
input_ids,
|
| 533 |
+
attention_mask=attention_mask,
|
| 534 |
+
head_mask=head_mask,
|
| 535 |
+
inputs_embeds=inputs_embeds,
|
| 536 |
+
past_key_values=past_key_values,
|
| 537 |
+
use_cache=use_cache,
|
| 538 |
+
position_ids=position_ids,
|
| 539 |
+
output_attentions=output_attentions,
|
| 540 |
+
output_hidden_states=output_hidden_states,
|
| 541 |
+
return_dict=return_dict,
|
| 542 |
+
cache_position=cache_position,
|
| 543 |
+
**kwargs,
|
| 544 |
+
)
|
| 545 |
+
|
| 546 |
+
sequence_output = outputs[0]
|
| 547 |
+
prediction_scores = self.output_projection(sequence_output)
|
| 548 |
+
|
| 549 |
+
lm_loss = None
|
| 550 |
+
if labels is not None:
|
| 551 |
+
lm_loss = self.loss_function(
|
| 552 |
+
prediction_scores,
|
| 553 |
+
labels,
|
| 554 |
+
vocab_size=self.config.vocab_size,
|
| 555 |
+
**kwargs,
|
| 556 |
+
)
|
| 557 |
+
|
| 558 |
+
if not return_dict:
|
| 559 |
+
output = (prediction_scores,) + outputs[1:]
|
| 560 |
+
return ((lm_loss,) + output) if lm_loss is not None else output
|
| 561 |
+
|
| 562 |
+
return CausalLMOutputWithCrossAttentions(
|
| 563 |
+
loss=lm_loss,
|
| 564 |
+
logits=prediction_scores,
|
| 565 |
+
past_key_values=outputs.past_key_values,
|
| 566 |
+
hidden_states=outputs.hidden_states,
|
| 567 |
+
attentions=outputs.attentions,
|
| 568 |
+
cross_attentions=outputs.cross_attentions,
|
| 569 |
+
)
|
| 570 |
+
|
| 571 |
+
|
| 572 |
+
@auto_docstring
|
| 573 |
+
class BioGptForTokenClassification(BioGptPreTrainedModel):
|
| 574 |
+
def __init__(self, config):
|
| 575 |
+
super().__init__(config)
|
| 576 |
+
self.num_labels = config.num_labels
|
| 577 |
+
|
| 578 |
+
self.biogpt = BioGptModel(config)
|
| 579 |
+
if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None:
|
| 580 |
+
classifier_dropout = config.classifier_dropout
|
| 581 |
+
else:
|
| 582 |
+
classifier_dropout = config.hidden_dropout_prob
|
| 583 |
+
self.dropout = nn.Dropout(classifier_dropout)
|
| 584 |
+
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
|
| 585 |
+
|
| 586 |
+
self.post_init()
|
| 587 |
+
|
| 588 |
+
@auto_docstring
|
| 589 |
+
def forward(
|
| 590 |
+
self,
|
| 591 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 592 |
+
token_type_ids: Optional[torch.LongTensor] = None,
|
| 593 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
| 594 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
| 595 |
+
past_key_values: Optional[Cache] = None,
|
| 596 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 597 |
+
labels: Optional[torch.LongTensor] = None,
|
| 598 |
+
use_cache: Optional[bool] = None,
|
| 599 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 600 |
+
output_attentions: Optional[bool] = None,
|
| 601 |
+
output_hidden_states: Optional[bool] = None,
|
| 602 |
+
return_dict: Optional[bool] = None,
|
| 603 |
+
cache_position: Optional[torch.Tensor] = None,
|
| 604 |
+
) -> Union[tuple, TokenClassifierOutput]:
|
| 605 |
+
r"""
|
| 606 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 607 |
+
Labels for computing the token classification loss. Indices should be in `[0, ...,
|
| 608 |
+
config.num_labels - 1]`. A cross-entropy loss is computed over the active tokens; positions where
|
| 609 |
+
`attention_mask == 0` are ignored.
|
| 610 |
+
"""
|
| 611 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 612 |
+
|
| 613 |
+
transformer_outputs = self.biogpt(
|
| 614 |
+
input_ids,
|
| 615 |
+
past_key_values=past_key_values,
|
| 616 |
+
attention_mask=attention_mask,
|
| 617 |
+
head_mask=head_mask,
|
| 618 |
+
inputs_embeds=inputs_embeds,
|
| 619 |
+
use_cache=use_cache,
|
| 620 |
+
position_ids=position_ids,
|
| 621 |
+
output_attentions=output_attentions,
|
| 622 |
+
output_hidden_states=output_hidden_states,
|
| 623 |
+
return_dict=return_dict,
|
| 624 |
+
cache_position=cache_position,
|
| 625 |
+
)
|
| 626 |
+
|
| 627 |
+
hidden_states = transformer_outputs[0]
|
| 628 |
+
hidden_states = self.dropout(hidden_states)
|
| 629 |
+
logits = self.classifier(hidden_states)
|
| 630 |
+
|
| 631 |
+
loss = None
|
| 632 |
+
if labels is not None:
|
| 633 |
+
loss_fct = CrossEntropyLoss()
|
| 634 |
+
# Only keep active parts of the loss
|
| 635 |
+
if attention_mask is not None:
|
| 636 |
+
active_loss = attention_mask.view(-1) == 1
|
| 637 |
+
active_logits = logits.view(-1, self.num_labels)
|
| 638 |
+
active_labels = torch.where(
|
| 639 |
+
active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
|
| 640 |
+
)
|
| 641 |
+
loss = loss_fct(active_logits, active_labels)
|
| 642 |
+
else:
|
| 643 |
+
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
| 644 |
+
|
| 645 |
+
if not return_dict:
|
| 646 |
+
output = (logits,) + transformer_outputs[2:]
|
| 647 |
+
return ((loss,) + output) if loss is not None else output
|
| 648 |
+
|
| 649 |
+
return TokenClassifierOutput(
|
| 650 |
+
loss=loss,
|
| 651 |
+
logits=logits,
|
| 652 |
+
hidden_states=transformer_outputs.hidden_states,
|
| 653 |
+
attentions=transformer_outputs.attentions,
|
| 654 |
+
)
|
| 655 |
+
|
| 656 |
+
|
| 657 |
+
@auto_docstring(
|
| 658 |
+
custom_intro="""
|
| 659 |
+
The BioGpt Model transformer with a sequence classification head on top (linear layer).
|
| 660 |
+
|
| 661 |
+
[`BioGptForSequenceClassification`] uses the last token in order to do the classification, as other causal models
|
| 662 |
+
(e.g. GPT-2) do.
|
| 663 |
+
|
| 664 |
+
Since it does classification on the last token, it is required to know the position of the last token. If a
|
| 665 |
+
`pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
|
| 666 |
+
no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
|
| 667 |
+
padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
|
| 668 |
+
each row of the batch).
|
| 669 |
+
"""
|
| 670 |
+
)
|
| 671 |
+
class BioGptForSequenceClassification(BioGptPreTrainedModel):
|
| 672 |
+
def __init__(self, config: BioGptConfig):
|
| 673 |
+
super().__init__(config)
|
| 674 |
+
self.num_labels = config.num_labels
|
| 675 |
+
self.biogpt = BioGptModel(config)
|
| 676 |
+
self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
|
| 677 |
+
|
| 678 |
+
# Initialize weights and apply final processing
|
| 679 |
+
self.post_init()
|
| 680 |
+
|
| 681 |
+
@auto_docstring
|
| 682 |
+
def forward(
|
| 683 |
+
self,
|
| 684 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 685 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
| 686 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
| 687 |
+
past_key_values: Optional[Cache] = None,
|
| 688 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 689 |
+
labels: Optional[torch.LongTensor] = None,
|
| 690 |
+
use_cache: Optional[bool] = None,
|
| 691 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 692 |
+
output_attentions: Optional[bool] = None,
|
| 693 |
+
output_hidden_states: Optional[bool] = None,
|
| 694 |
+
return_dict: Optional[bool] = None,
|
| 695 |
+
cache_position: Optional[torch.Tensor] = None,
|
| 696 |
+
logits_to_keep: Union[int, torch.Tensor] = 0,
|
| 697 |
+
) -> Union[tuple, SequenceClassifierOutputWithPast]:
|
| 698 |
+
r"""
|
| 699 |
+
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
| 700 |
+
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
| 701 |
+
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss). If
|
| 702 |
+
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
| 703 |
+
"""
|
| 704 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 705 |
+
|
| 706 |
+
transformer_outputs = self.biogpt(
|
| 707 |
+
input_ids,
|
| 708 |
+
past_key_values=past_key_values,
|
| 709 |
+
attention_mask=attention_mask,
|
| 710 |
+
head_mask=head_mask,
|
| 711 |
+
inputs_embeds=inputs_embeds,
|
| 712 |
+
use_cache=use_cache,
|
| 713 |
+
position_ids=position_ids,
|
| 714 |
+
output_attentions=output_attentions,
|
| 715 |
+
output_hidden_states=output_hidden_states,
|
| 716 |
+
return_dict=return_dict,
|
| 717 |
+
cache_position=cache_position,
|
| 718 |
+
)
|
| 719 |
+
hidden_states = transformer_outputs[0]
|
| 720 |
+
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
| 721 |
+
logits = self.score(hidden_states[:, slice_indices, :])
|
| 722 |
+
|
| 723 |
+
if input_ids is not None:
|
| 724 |
+
batch_size, sequence_length = input_ids.shape[:2]
|
| 725 |
+
else:
|
| 726 |
+
batch_size, sequence_length = inputs_embeds.shape[:2]
|
| 727 |
+
|
| 728 |
+
if self.config.pad_token_id is None:
|
| 729 |
+
sequence_length = -1
|
| 730 |
+
else:
|
| 731 |
+
if input_ids is not None:
|
| 732 |
+
sequence_length = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
|
| 733 |
+
else:
|
| 734 |
+
sequence_length = -1
|
| 735 |
+
logger.warning_once(
|
| 736 |
+
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
|
| 737 |
+
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
|
| 738 |
+
)
|
| 739 |
+
|
| 740 |
+
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_length]
|
| 741 |
+
|
| 742 |
+
loss = None
|
| 743 |
+
if labels is not None:
|
| 744 |
+
if self.config.problem_type is None:
|
| 745 |
+
if self.num_labels == 1:
|
| 746 |
+
self.config.problem_type = "regression"
|
| 747 |
+
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
| 748 |
+
self.config.problem_type = "single_label_classification"
|
| 749 |
+
else:
|
| 750 |
+
self.config.problem_type = "multi_label_classification"
|
| 751 |
+
|
| 752 |
+
if self.config.problem_type == "regression":
|
| 753 |
+
loss_fct = MSELoss()
|
| 754 |
+
if self.num_labels == 1:
|
| 755 |
+
loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
|
| 756 |
+
else:
|
| 757 |
+
loss = loss_fct(pooled_logits, labels)
|
| 758 |
+
elif self.config.problem_type == "single_label_classification":
|
| 759 |
+
loss_fct = CrossEntropyLoss()
|
| 760 |
+
loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
|
| 761 |
+
elif self.config.problem_type == "multi_label_classification":
|
| 762 |
+
loss_fct = BCEWithLogitsLoss()
|
| 763 |
+
loss = loss_fct(pooled_logits, labels)
|
| 764 |
+
if not return_dict:
|
| 765 |
+
output = (pooled_logits,) + transformer_outputs[1:]
|
| 766 |
+
return ((loss,) + output) if loss is not None else output
|
| 767 |
+
|
| 768 |
+
return SequenceClassifierOutputWithPast(
|
| 769 |
+
loss=loss,
|
| 770 |
+
logits=pooled_logits,
|
| 771 |
+
past_key_values=transformer_outputs.past_key_values,
|
| 772 |
+
hidden_states=transformer_outputs.hidden_states,
|
| 773 |
+
attentions=transformer_outputs.attentions,
|
| 774 |
+
)
|
| 775 |
+
|
| 776 |
+
def get_input_embeddings(self):
|
| 777 |
+
return self.biogpt.embed_tokens
|
| 778 |
+
|
| 779 |
+
def set_input_embeddings(self, value):
|
| 780 |
+
self.biogpt.embed_tokens = value
|
| 781 |
+
|
| 782 |
+
|
| 783 |
+
__all__ = [
|
| 784 |
+
"BioGptForCausalLM",
|
| 785 |
+
"BioGptForTokenClassification",
|
| 786 |
+
"BioGptForSequenceClassification",
|
| 787 |
+
"BioGptModel",
|
| 788 |
+
"BioGptPreTrainedModel",
|
| 789 |
+
]
|
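To close out the modular file, a hedged sketch of how the cache machinery above is typically exercised outside of generate: a first forward populates past_key_values (a DynamicCache is created internally when use_cache=True), and a second forward feeds only the new token plus the cache. The checkpoint name and prompt are assumptions; in practice model.generate handles this loop automatically.

import torch
from transformers import AutoTokenizer, BioGptForCausalLM

tokenizer = AutoTokenizer.from_pretrained("microsoft/biogpt")  # assumed checkpoint name
model = BioGptForCausalLM.from_pretrained("microsoft/biogpt").eval()

inputs = tokenizer("Aspirin is", return_tensors="pt")
with torch.no_grad():
    out = model(**inputs, use_cache=True)            # first pass fills the cache
    next_token = out.logits[:, -1].argmax(dim=-1, keepdim=True)
    out = model(input_ids=next_token,                 # second pass reuses cached keys/values
                past_key_values=out.past_key_values,
                use_cache=True)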
venv/lib/python3.13/site-packages/transformers/models/biogpt/tokenization_biogpt.py
ADDED
|
@@ -0,0 +1,331 @@
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2022 The HuggingFace Team and Microsoft Research AI4Science. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Tokenization classes for BioGPT."""
|
| 16 |
+
|
| 17 |
+
import json
|
| 18 |
+
import os
|
| 19 |
+
from typing import Optional
|
| 20 |
+
|
| 21 |
+
from ...tokenization_utils import PreTrainedTokenizer
|
| 22 |
+
from ...utils import logging
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
logger = logging.get_logger(__name__)
|
| 26 |
+
|
| 27 |
+
VOCAB_FILES_NAMES = {
|
| 28 |
+
"vocab_file": "vocab.json",
|
| 29 |
+
"merges_file": "merges.txt",
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def get_pairs(word):
|
| 34 |
+
"""
|
| 35 |
+
Return set of symbol pairs in a word. word is represented as tuple of symbols (symbols being variable-length
|
| 36 |
+
strings)
|
| 37 |
+
"""
|
| 38 |
+
pairs = set()
|
| 39 |
+
prev_char = word[0]
|
| 40 |
+
for char in word[1:]:
|
| 41 |
+
pairs.add((prev_char, char))
|
| 42 |
+
prev_char = char
|
| 43 |
+
return pairs
|
| 44 |
+
|
| 45 |
+
|
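A quick illustration of get_pairs above on a made-up symbol tuple (set ordering may vary when printed):

print(get_pairs(("g", "e", "ne</w>")))  # -> {("g", "e"), ("e", "ne</w>")}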
| 46 |
+
class BioGptTokenizer(PreTrainedTokenizer):
|
| 47 |
+
"""
|
| 48 |
+
Construct a FAIRSEQ Transformer tokenizer. Moses tokenization followed by Byte-Pair Encoding.
|
| 49 |
+
|
| 50 |
+
This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
|
| 51 |
+
this superclass for more information regarding those methods.
|
| 52 |
+
|
| 53 |
+
Args:
|
| 54 |
+
vocab_file (`str`):
|
| 55 |
+
Path to the vocabulary file.
|
| 56 |
+
merges_file (`str`):
|
| 57 |
+
Merges file.
|
| 58 |
+
unk_token (`str`, *optional*, defaults to `"<unk>"`):
|
| 59 |
+
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
|
| 60 |
+
token instead.
|
| 61 |
+
bos_token (`str`, *optional*, defaults to `"<s>"`):
|
| 62 |
+
The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier token.
|
| 63 |
+
|
| 64 |
+
<Tip>
|
| 65 |
+
|
| 66 |
+
When building a sequence using special tokens, this is not the token that is used for the beginning of
|
| 67 |
+
sequence. The token used is the `cls_token`.
|
| 68 |
+
|
| 69 |
+
</Tip>
|
| 70 |
+
|
| 71 |
+
eos_token (`str`, *optional*, defaults to `"</s>"`):
|
| 72 |
+
The end of sequence token.
|
| 73 |
+
|
| 74 |
+
<Tip>
|
| 75 |
+
|
| 76 |
+
When building a sequence using special tokens, this is not the token that is used for the end of sequence.
|
| 77 |
+
The token used is the `sep_token`.
|
| 78 |
+
|
| 79 |
+
</Tip>
|
| 80 |
+
|
| 81 |
+
sep_token (`str`, *optional*, defaults to `"</s>"`):
|
| 82 |
+
The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
|
| 83 |
+
sequence classification or for a text and a question for question answering. It is also used as the last
|
| 84 |
+
token of a sequence built with special tokens.
|
| 85 |
+
pad_token (`str`, *optional*, defaults to `"<pad>"`):
|
| 86 |
+
The token used for padding, for example when batching sequences of different lengths.
|
| 87 |
+
"""
|
| 88 |
+
|
    vocab_files_names = VOCAB_FILES_NAMES
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        vocab_file,
        merges_file,
        unk_token="<unk>",
        bos_token="<s>",
        eos_token="</s>",
        sep_token="</s>",
        pad_token="<pad>",
        **kwargs,
    ):
        try:
            import sacremoses
        except ImportError:
            raise ImportError(
                "You need to install sacremoses to use BioGptTokenizer. "
                "See https://pypi.org/project/sacremoses/ for installation."
            )

        self.lang = "en"
        self.sm = sacremoses
        # cache of sm.MosesTokenizer instance
        self.cache_moses_tokenizer = {}
        self.cache_moses_detokenizer = {}

        """ Initialisation"""
        with open(vocab_file, encoding="utf-8") as vocab_handle:
            self.encoder = json.load(vocab_handle)
        self.decoder = {v: k for k, v in self.encoder.items()}
        with open(merges_file, encoding="utf-8") as merges_handle:
            merges = merges_handle.read().split("\n")[:-1]
        merges = [tuple(merge.split()[:2]) for merge in merges]
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        self.cache = {}

        super().__init__(
            bos_token=bos_token,
            eos_token=eos_token,
            sep_token=sep_token,
            unk_token=unk_token,
            pad_token=pad_token,
            **kwargs,
        )

    @property
    def vocab_size(self):
        """Returns vocab size"""
        return len(self.encoder)

    def get_vocab(self):
        return dict(self.encoder, **self.added_tokens_encoder)

    def moses_tokenize(self, text, lang):
        if lang not in self.cache_moses_tokenizer:
            moses_tokenizer = self.sm.MosesTokenizer(lang=lang)
            self.cache_moses_tokenizer[lang] = moses_tokenizer
        return self.cache_moses_tokenizer[lang].tokenize(
            text, aggressive_dash_splits=True, return_str=False, escape=True
        )

    def moses_detokenize(self, tokens, lang):
        if lang not in self.cache_moses_detokenizer:
            moses_detokenizer = self.sm.MosesDetokenizer(lang=lang)
            self.cache_moses_detokenizer[lang] = moses_detokenizer
        return self.cache_moses_detokenizer[lang].detokenize(tokens)
    def bpe(self, token):
        word = tuple(token[:-1]) + (token[-1] + "</w>",)
        if token in self.cache:
            return self.cache[token]
        pairs = get_pairs(word)

        if not pairs:
            return token + "</w>"

        while True:
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                except ValueError:
                    new_word.extend(word[i:])
                    break
                else:
                    new_word.extend(word[i:j])
                    i = j

                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = " ".join(word)
        if word == "\n  </w>":
            word = "\n</w>"
        self.cache[token] = word
        return word
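
    # Worked example for bpe() above (illustrative; the merge ranks are hypothetical): for the
    # token "low" the initial word is ("l", "o", "w</w>"). If ("l", "o") is the lowest-ranked
    # pair in bpe_ranks it is merged first, giving ("lo", "w</w>"); if that pair is also a known
    # merge the result is "low</w>", otherwise the loop stops and bpe("low") returns "lo w</w>".
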
    def _tokenize(self, text, bypass_tokenizer=False):
        """Returns a tokenized string."""
        if bypass_tokenizer:
            text = text.split()
        else:
            text = self.moses_tokenize(text, self.lang)

        split_tokens = []
        for token in text:
            if token:
                split_tokens.extend(list(self.bpe(token).split(" ")))

        return split_tokens

    def _convert_token_to_id(self, token):
        """Converts a token (str) in an id using the vocab."""
        return self.encoder.get(token, self.encoder.get(self.unk_token))

    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        return self.decoder.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (string) in a single string."""
        # remove BPE
        tokens = [t.replace(" ", "").replace("</w>", " ") for t in tokens]
        tokens = "".join(tokens).split()
        # detokenize
        text = self.moses_detokenize(tokens, self.lang)
        return text

    def build_inputs_with_special_tokens(
        self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
    ) -> list[int]:
        """
        Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
        adding special tokens. A BioGPT sequence has the following format:

        - single sequence: `</s> X `
        - pair of sequences: `</s> A </s> B `

        Args:
            token_ids_0 (`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
        """
        if token_ids_1 is None:
            return [self.sep_token_id] + token_ids_0
        sep = [self.sep_token_id]
        return sep + token_ids_0 + sep + token_ids_1
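
    # Example for build_inputs_with_special_tokens above: with a hypothetical sep_token_id of 2,
    # a single sequence [5, 6] becomes [2, 5, 6] and a pair ([5, 6], [7]) becomes [2, 5, 6, 2, 7].
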
    def get_special_tokens_mask(
        self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None, already_has_special_tokens: bool = False
    ) -> list[int]:
        """
        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer `prepare_for_model` method.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
                Whether or not the token list is already formatted with special tokens for the model.

        Returns:
            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
            )
        # no bos used in fairseq
        if token_ids_1 is not None:
            return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1))
        return [1] + ([0] * len(token_ids_0))

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            return
        vocab_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
        )
        merge_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"]
        )

        with open(vocab_file, "w", encoding="utf-8") as f:
            f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")

        index = 0
        with open(merge_file, "w", encoding="utf-8") as writer:
            for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
                if index != token_index:
                    logger.warning(
                        f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive."
                        " Please check that the tokenizer is not corrupted!"
                    )
                    index = token_index
                writer.write(" ".join(bpe_tokens) + "\n")
                index += 1

        return vocab_file, merge_file

    def __getstate__(self):
        state = self.__dict__.copy()
        state["sm"] = None
        return state

    def __setstate__(self, d):
        self.__dict__ = d

        try:
            import sacremoses
        except ImportError:
            raise ImportError(
                "You need to install sacremoses to use BioGptTokenizer. "
                "See https://pypi.org/project/sacremoses/ for installation."
            )

        self.sm = sacremoses


__all__ = ["BioGptTokenizer"]
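As a quick orientation for the tokenizer file above, here is a minimal usage sketch. It assumes the publicly available "microsoft/biogpt" checkpoint (any checkpoint that ships a compatible vocab.json and merges.txt behaves the same) and an installed sacremoses; it is an illustration, not part of the diffed file.

    from transformers import BioGptTokenizer

    tokenizer = BioGptTokenizer.from_pretrained("microsoft/biogpt")

    text = "BioGPT was pretrained on biomedical literature."
    tokens = tokenizer.tokenize(text)                            # Moses tokenization followed by BPE sub-tokens
    ids = tokenizer(text)["input_ids"]                           # build_inputs_with_special_tokens prepends sep_token_id (</s>)
    restored = tokenizer.decode(ids, skip_special_tokens=True)   # BPE undone, then Moses detokenization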
venv/lib/python3.13/site-packages/transformers/models/bit/__init__.py
ADDED
@@ -0,0 +1,29 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    from .configuration_bit import *
    from .image_processing_bit import *
    from .image_processing_bit_fast import *
    from .modeling_bit import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
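For reference, the `_LazyModule` indirection above defers importing the heavier submodules until one of their attributes is first accessed; the `TYPE_CHECKING` branch only exists so static type checkers and IDEs can see the re-exported names. A hedged usage sketch (assuming the usual exported names such as `BitConfig` and `BitModel` from the referenced submodules):

    # Accessing the names makes _LazyModule import the real submodules
    # (configuration_bit, modeling_bit) on first use.
    from transformers.models.bit import BitConfig, BitModel

    config = BitConfig()
    model = BitModel(config)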