dheeena committed on
Commit 91b3b8c · verified · 1 Parent(s): 52197b1

Add files using upload-large-folder tool

This view is limited to 50 files because it contains too many changes. See the raw diff for the complete change set.

Files changed (50)
  1. venv/lib/python3.13/site-packages/transformers/models/albert/__init__.py +31 -0
  2. venv/lib/python3.13/site-packages/transformers/models/albert/configuration_albert.py +170 -0
  3. venv/lib/python3.13/site-packages/transformers/models/albert/modeling_albert.py +1349 -0
  4. venv/lib/python3.13/site-packages/transformers/models/albert/modeling_flax_albert.py +1132 -0
  5. venv/lib/python3.13/site-packages/transformers/models/albert/modeling_tf_albert.py +1572 -0
  6. venv/lib/python3.13/site-packages/transformers/models/albert/tokenization_albert.py +320 -0
  7. venv/lib/python3.13/site-packages/transformers/models/albert/tokenization_albert_fast.py +178 -0
  8. venv/lib/python3.13/site-packages/transformers/models/apertus/__init__.py +32 -0
  9. venv/lib/python3.13/site-packages/transformers/models/apertus/configuration_apertus.py +214 -0
  10. venv/lib/python3.13/site-packages/transformers/models/apertus/modeling_apertus.py +488 -0
  11. venv/lib/python3.13/site-packages/transformers/models/apertus/modular_apertus.py +371 -0
  12. venv/lib/python3.13/site-packages/transformers/models/arcee/__init__.py +27 -0
  13. venv/lib/python3.13/site-packages/transformers/models/arcee/configuration_arcee.py +201 -0
  14. venv/lib/python3.13/site-packages/transformers/models/arcee/modeling_arcee.py +506 -0
  15. venv/lib/python3.13/site-packages/transformers/models/arcee/modular_arcee.py +225 -0
  16. venv/lib/python3.13/site-packages/transformers/models/aria/__init__.py +30 -0
  17. venv/lib/python3.13/site-packages/transformers/models/aria/configuration_aria.py +307 -0
  18. venv/lib/python3.13/site-packages/transformers/models/aria/image_processing_aria.py +527 -0
  19. venv/lib/python3.13/site-packages/transformers/models/aria/modeling_aria.py +1275 -0
  20. venv/lib/python3.13/site-packages/transformers/models/aria/modular_aria.py +1610 -0
  21. venv/lib/python3.13/site-packages/transformers/models/aria/processing_aria.py +189 -0
  22. venv/lib/python3.13/site-packages/transformers/models/auto/__init__.py +35 -0
  23. venv/lib/python3.13/site-packages/transformers/models/auto/auto_factory.py +882 -0
  24. venv/lib/python3.13/site-packages/transformers/models/auto/configuration_auto.py +1404 -0
  25. venv/lib/python3.13/site-packages/transformers/models/auto/feature_extraction_auto.py +422 -0
  26. venv/lib/python3.13/site-packages/transformers/models/auto/image_processing_auto.py +688 -0
  27. venv/lib/python3.13/site-packages/transformers/models/auto/modeling_auto.py +0 -0
  28. venv/lib/python3.13/site-packages/transformers/models/auto/modeling_flax_auto.py +413 -0
  29. venv/lib/python3.13/site-packages/transformers/models/auto/modeling_tf_auto.py +776 -0
  30. venv/lib/python3.13/site-packages/transformers/models/auto/processing_auto.py +443 -0
  31. venv/lib/python3.13/site-packages/transformers/models/auto/tokenization_auto.py +1235 -0
  32. venv/lib/python3.13/site-packages/transformers/models/auto/video_processing_auto.py +393 -0
  33. venv/lib/python3.13/site-packages/transformers/models/aya_vision/__init__.py +28 -0
  34. venv/lib/python3.13/site-packages/transformers/models/aya_vision/configuration_aya_vision.py +110 -0
  35. venv/lib/python3.13/site-packages/transformers/models/aya_vision/modeling_aya_vision.py +518 -0
  36. venv/lib/python3.13/site-packages/transformers/models/aya_vision/modular_aya_vision.py +297 -0
  37. venv/lib/python3.13/site-packages/transformers/models/aya_vision/processing_aya_vision.py +257 -0
  38. venv/lib/python3.13/site-packages/transformers/models/barthez/__init__.py +27 -0
  39. venv/lib/python3.13/site-packages/transformers/models/barthez/tokenization_barthez.py +291 -0
  40. venv/lib/python3.13/site-packages/transformers/models/barthez/tokenization_barthez_fast.py +193 -0
  41. venv/lib/python3.13/site-packages/transformers/models/bert_japanese/__init__.py +26 -0
  42. venv/lib/python3.13/site-packages/transformers/models/bert_japanese/tokenization_bert_japanese.py +952 -0
  43. venv/lib/python3.13/site-packages/transformers/models/bertweet/__init__.py +26 -0
  44. venv/lib/python3.13/site-packages/transformers/models/bertweet/tokenization_bertweet.py +769 -0
  45. venv/lib/python3.13/site-packages/transformers/models/biogpt/__init__.py +28 -0
  46. venv/lib/python3.13/site-packages/transformers/models/biogpt/configuration_biogpt.py +134 -0
  47. venv/lib/python3.13/site-packages/transformers/models/biogpt/modeling_biogpt.py +967 -0
  48. venv/lib/python3.13/site-packages/transformers/models/biogpt/modular_biogpt.py +789 -0
  49. venv/lib/python3.13/site-packages/transformers/models/biogpt/tokenization_biogpt.py +331 -0
  50. venv/lib/python3.13/site-packages/transformers/models/bit/__init__.py +29 -0
venv/lib/python3.13/site-packages/transformers/models/albert/__init__.py ADDED
@@ -0,0 +1,31 @@
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    from .configuration_albert import *
    from .modeling_albert import *
    from .modeling_flax_albert import *
    from .modeling_tf_albert import *
    from .tokenization_albert import *
    from .tokenization_albert_fast import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
venv/lib/python3.13/site-packages/transformers/models/albert/configuration_albert.py ADDED
@@ -0,0 +1,170 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ALBERT model configuration"""

from collections import OrderedDict
from collections.abc import Mapping

from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfig


class AlbertConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`AlbertModel`] or a [`TFAlbertModel`]. It is used
    to instantiate an ALBERT model according to the specified arguments, defining the model architecture. Instantiating
    a configuration with the defaults will yield a similar configuration to that of the ALBERT
    [albert/albert-xxlarge-v2](https://huggingface.co/albert/albert-xxlarge-v2) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 30000):
            Vocabulary size of the ALBERT model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`AlbertModel`] or [`TFAlbertModel`].
        embedding_size (`int`, *optional*, defaults to 128):
            Dimensionality of vocabulary embeddings.
        hidden_size (`int`, *optional*, defaults to 4096):
            Dimensionality of the encoder layers and the pooler layer.
        num_hidden_layers (`int`, *optional*, defaults to 12):
            Number of hidden layers in the Transformer encoder.
        num_hidden_groups (`int`, *optional*, defaults to 1):
            Number of groups for the hidden layers, parameters in the same group are shared.
        num_attention_heads (`int`, *optional*, defaults to 64):
            Number of attention heads for each attention layer in the Transformer encoder.
        intermediate_size (`int`, *optional*, defaults to 16384):
            The dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
        inner_group_num (`int`, *optional*, defaults to 1):
            The number of inner repetition of attention and ffn.
        hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu_new"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"silu"` and `"gelu_new"` are supported.
        hidden_dropout_prob (`float`, *optional*, defaults to 0):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        attention_probs_dropout_prob (`float`, *optional*, defaults to 0):
            The dropout ratio for the attention probabilities.
        max_position_embeddings (`int`, *optional*, defaults to 512):
            The maximum sequence length that this model might ever be used with. Typically set this to something large
            (e.g., 512 or 1024 or 2048).
        type_vocab_size (`int`, *optional*, defaults to 2):
            The vocabulary size of the `token_type_ids` passed when calling [`AlbertModel`] or [`TFAlbertModel`].
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        layer_norm_eps (`float`, *optional*, defaults to 1e-12):
            The epsilon used by the layer normalization layers.
        classifier_dropout_prob (`float`, *optional*, defaults to 0.1):
            The dropout ratio for attached classifiers.
        position_embedding_type (`str`, *optional*, defaults to `"absolute"`):
            Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For
            positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to
            [Self-Attention with Relative Position Representations (Shaw et al.)](https://huggingface.co/papers/1803.02155).
            For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models
            with Better Relative Position Embeddings (Huang et al.)](https://huggingface.co/papers/2009.13658).
        pad_token_id (`int`, *optional*, defaults to 0):
            Padding token id.
        bos_token_id (`int`, *optional*, defaults to 2):
            Beginning of stream token id.
        eos_token_id (`int`, *optional*, defaults to 3):
            End of stream token id.

    Examples:

    ```python
    >>> from transformers import AlbertConfig, AlbertModel

    >>> # Initializing an ALBERT-xxlarge style configuration
    >>> albert_xxlarge_configuration = AlbertConfig()

    >>> # Initializing an ALBERT-base style configuration
    >>> albert_base_configuration = AlbertConfig(
    ...     hidden_size=768,
    ...     num_attention_heads=12,
    ...     intermediate_size=3072,
    ... )

    >>> # Initializing a model (with random weights) from the ALBERT-base style configuration
    >>> model = AlbertModel(albert_xxlarge_configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "albert"

    def __init__(
        self,
        vocab_size=30000,
        embedding_size=128,
        hidden_size=4096,
        num_hidden_layers=12,
        num_hidden_groups=1,
        num_attention_heads=64,
        intermediate_size=16384,
        inner_group_num=1,
        hidden_act="gelu_new",
        hidden_dropout_prob=0,
        attention_probs_dropout_prob=0,
        max_position_embeddings=512,
        type_vocab_size=2,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        classifier_dropout_prob=0.1,
        position_embedding_type="absolute",
        pad_token_id=0,
        bos_token_id=2,
        eos_token_id=3,
        **kwargs,
    ):
        super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)

        self.vocab_size = vocab_size
        self.embedding_size = embedding_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_hidden_groups = num_hidden_groups
        self.num_attention_heads = num_attention_heads
        self.inner_group_num = inner_group_num
        self.hidden_act = hidden_act
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.classifier_dropout_prob = classifier_dropout_prob
        self.position_embedding_type = position_embedding_type


# Copied from transformers.models.bert.configuration_bert.BertOnnxConfig with Roberta->Albert
class AlbertOnnxConfig(OnnxConfig):
    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        if self.task == "multiple-choice":
            dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
        else:
            dynamic_axis = {0: "batch", 1: "sequence"}
        return OrderedDict(
            [
                ("input_ids", dynamic_axis),
                ("attention_mask", dynamic_axis),
                ("token_type_ids", dynamic_axis),
            ]
        )


__all__ = ["AlbertConfig", "AlbertOnnxConfig"]
venv/lib/python3.13/site-packages/transformers/models/albert/modeling_albert.py ADDED
@@ -0,0 +1,1349 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 Google AI, Google Brain and the HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch ALBERT model."""
16
+
17
+ import math
18
+ import os
19
+ from dataclasses import dataclass
20
+ from typing import Optional, Union
21
+
22
+ import torch
23
+ from torch import nn
24
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
25
+
26
+ from ...activations import ACT2FN
27
+ from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa
28
+ from ...modeling_outputs import (
29
+ BaseModelOutput,
30
+ BaseModelOutputWithPooling,
31
+ MaskedLMOutput,
32
+ MultipleChoiceModelOutput,
33
+ QuestionAnsweringModelOutput,
34
+ SequenceClassifierOutput,
35
+ TokenClassifierOutput,
36
+ )
37
+ from ...modeling_utils import PreTrainedModel
38
+ from ...pytorch_utils import (
39
+ apply_chunking_to_forward,
40
+ find_pruneable_heads_and_indices,
41
+ prune_linear_layer,
42
+ )
43
+ from ...utils import ModelOutput, auto_docstring, logging
44
+ from .configuration_albert import AlbertConfig
45
+
46
+
47
+ logger = logging.get_logger(__name__)
48
+
49
+
50
+ def load_tf_weights_in_albert(model, config, tf_checkpoint_path):
51
+ """Load tf checkpoints in a pytorch model."""
52
+ try:
53
+ import re
54
+
55
+ import numpy as np
56
+ import tensorflow as tf
57
+ except ImportError:
58
+ logger.error(
59
+ "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
60
+ "https://www.tensorflow.org/install/ for installation instructions."
61
+ )
62
+ raise
63
+ tf_path = os.path.abspath(tf_checkpoint_path)
64
+ logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
65
+ # Load weights from TF model
66
+ init_vars = tf.train.list_variables(tf_path)
67
+ names = []
68
+ arrays = []
69
+ for name, shape in init_vars:
70
+ logger.info(f"Loading TF weight {name} with shape {shape}")
71
+ array = tf.train.load_variable(tf_path, name)
72
+ names.append(name)
73
+ arrays.append(array)
74
+
75
+ for name, array in zip(names, arrays):
76
+ print(name)
77
+
78
+ for name, array in zip(names, arrays):
79
+ original_name = name
80
+
81
+ # If saved from the TF HUB module
82
+ name = name.replace("module/", "")
83
+
84
+ # Renaming and simplifying
85
+ name = name.replace("ffn_1", "ffn")
86
+ name = name.replace("bert/", "albert/")
87
+ name = name.replace("attention_1", "attention")
88
+ name = name.replace("transform/", "")
89
+ name = name.replace("LayerNorm_1", "full_layer_layer_norm")
90
+ name = name.replace("LayerNorm", "attention/LayerNorm")
91
+ name = name.replace("transformer/", "")
92
+
93
+ # The feed forward layer had an 'intermediate' step which has been abstracted away
94
+ name = name.replace("intermediate/dense/", "")
95
+ name = name.replace("ffn/intermediate/output/dense/", "ffn_output/")
96
+
97
+ # ALBERT attention was split between self and output which have been abstracted away
98
+ name = name.replace("/output/", "/")
99
+ name = name.replace("/self/", "/")
100
+
101
+ # The pooler is a linear layer
102
+ name = name.replace("pooler/dense", "pooler")
103
+
104
+ # The classifier was simplified to predictions from cls/predictions
105
+ name = name.replace("cls/predictions", "predictions")
106
+ name = name.replace("predictions/attention", "predictions")
107
+
108
+ # Naming was changed to be more explicit
109
+ name = name.replace("embeddings/attention", "embeddings")
110
+ name = name.replace("inner_group_", "albert_layers/")
111
+ name = name.replace("group_", "albert_layer_groups/")
112
+
113
+ # Classifier
114
+ if len(name.split("/")) == 1 and ("output_bias" in name or "output_weights" in name):
115
+ name = "classifier/" + name
116
+
117
+ # No ALBERT model currently handles the next sentence prediction task
118
+ if "seq_relationship" in name:
119
+ name = name.replace("seq_relationship/output_", "sop_classifier/classifier/")
120
+ name = name.replace("weights", "weight")
121
+
122
+ name = name.split("/")
123
+
124
+ # Ignore the gradients applied by the LAMB/ADAM optimizers.
125
+ if (
126
+ "adam_m" in name
127
+ or "adam_v" in name
128
+ or "AdamWeightDecayOptimizer" in name
129
+ or "AdamWeightDecayOptimizer_1" in name
130
+ or "global_step" in name
131
+ ):
132
+ logger.info(f"Skipping {'/'.join(name)}")
133
+ continue
134
+
135
+ pointer = model
136
+ for m_name in name:
137
+ if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
138
+ scope_names = re.split(r"_(\d+)", m_name)
139
+ else:
140
+ scope_names = [m_name]
141
+
142
+ if scope_names[0] == "kernel" or scope_names[0] == "gamma":
143
+ pointer = getattr(pointer, "weight")
144
+ elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
145
+ pointer = getattr(pointer, "bias")
146
+ elif scope_names[0] == "output_weights":
147
+ pointer = getattr(pointer, "weight")
148
+ elif scope_names[0] == "squad":
149
+ pointer = getattr(pointer, "classifier")
150
+ else:
151
+ try:
152
+ pointer = getattr(pointer, scope_names[0])
153
+ except AttributeError:
154
+ logger.info(f"Skipping {'/'.join(name)}")
155
+ continue
156
+ if len(scope_names) >= 2:
157
+ num = int(scope_names[1])
158
+ pointer = pointer[num]
159
+
160
+ if m_name[-11:] == "_embeddings":
161
+ pointer = getattr(pointer, "weight")
162
+ elif m_name == "kernel":
163
+ array = np.transpose(array)
164
+ try:
165
+ if pointer.shape != array.shape:
166
+ raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
167
+ except ValueError as e:
168
+ e.args += (pointer.shape, array.shape)
169
+ raise
170
+ print(f"Initialize PyTorch weight {name} from {original_name}")
171
+ pointer.data = torch.from_numpy(array)
172
+
173
+ return model
174
+
175
+
176
+ class AlbertEmbeddings(nn.Module):
177
+ """
178
+ Construct the embeddings from word, position and token_type embeddings.
179
+ """
180
+
181
+ def __init__(self, config: AlbertConfig):
182
+ super().__init__()
183
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.embedding_size, padding_idx=config.pad_token_id)
184
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.embedding_size)
185
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.embedding_size)
186
+
187
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
188
+ # any TensorFlow checkpoint file
189
+ self.LayerNorm = nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps)
190
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
191
+
192
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
193
+ self.register_buffer(
194
+ "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
195
+ )
196
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
197
+ self.register_buffer(
198
+ "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
199
+ )
200
+
201
+ # Copied from transformers.models.bert.modeling_bert.BertEmbeddings.forward
202
+ def forward(
203
+ self,
204
+ input_ids: Optional[torch.LongTensor] = None,
205
+ token_type_ids: Optional[torch.LongTensor] = None,
206
+ position_ids: Optional[torch.LongTensor] = None,
207
+ inputs_embeds: Optional[torch.FloatTensor] = None,
208
+ past_key_values_length: int = 0,
209
+ ) -> torch.Tensor:
210
+ if input_ids is not None:
211
+ input_shape = input_ids.size()
212
+ else:
213
+ input_shape = inputs_embeds.size()[:-1]
214
+
215
+ seq_length = input_shape[1]
216
+
217
+ if position_ids is None:
218
+ position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
219
+
220
+ # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
221
+ # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
222
+ # issue #5664
223
+ if token_type_ids is None:
224
+ if hasattr(self, "token_type_ids"):
225
+ buffered_token_type_ids = self.token_type_ids[:, :seq_length]
226
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
227
+ token_type_ids = buffered_token_type_ids_expanded
228
+ else:
229
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
230
+
231
+ if inputs_embeds is None:
232
+ inputs_embeds = self.word_embeddings(input_ids)
233
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
234
+
235
+ embeddings = inputs_embeds + token_type_embeddings
236
+ if self.position_embedding_type == "absolute":
237
+ position_embeddings = self.position_embeddings(position_ids)
238
+ embeddings += position_embeddings
239
+ embeddings = self.LayerNorm(embeddings)
240
+ embeddings = self.dropout(embeddings)
241
+ return embeddings
242
+
243
+
244
+ class AlbertAttention(nn.Module):
245
+ def __init__(self, config: AlbertConfig):
246
+ super().__init__()
247
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
248
+ raise ValueError(
249
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
250
+ f"heads ({config.num_attention_heads}"
251
+ )
252
+
253
+ self.num_attention_heads = config.num_attention_heads
254
+ self.hidden_size = config.hidden_size
255
+ self.attention_head_size = config.hidden_size // config.num_attention_heads
256
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
257
+
258
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
259
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
260
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
261
+
262
+ self.attention_dropout = nn.Dropout(config.attention_probs_dropout_prob)
263
+ self.output_dropout = nn.Dropout(config.hidden_dropout_prob)
264
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
265
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
266
+ self.pruned_heads = set()
267
+
268
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
269
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
270
+ self.max_position_embeddings = config.max_position_embeddings
271
+ self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
272
+
273
+ def prune_heads(self, heads: list[int]) -> None:
274
+ if len(heads) == 0:
275
+ return
276
+ heads, index = find_pruneable_heads_and_indices(
277
+ heads, self.num_attention_heads, self.attention_head_size, self.pruned_heads
278
+ )
279
+
280
+ # Prune linear layers
281
+ self.query = prune_linear_layer(self.query, index)
282
+ self.key = prune_linear_layer(self.key, index)
283
+ self.value = prune_linear_layer(self.value, index)
284
+ self.dense = prune_linear_layer(self.dense, index, dim=1)
285
+
286
+ # Update hyper params and store pruned heads
287
+ self.num_attention_heads = self.num_attention_heads - len(heads)
288
+ self.all_head_size = self.attention_head_size * self.num_attention_heads
289
+ self.pruned_heads = self.pruned_heads.union(heads)
290
+
291
+ def forward(
292
+ self,
293
+ hidden_states: torch.Tensor,
294
+ attention_mask: Optional[torch.FloatTensor] = None,
295
+ head_mask: Optional[torch.FloatTensor] = None,
296
+ output_attentions: bool = False,
297
+ ) -> Union[tuple[torch.Tensor], tuple[torch.Tensor, torch.Tensor]]:
298
+ batch_size, seq_length, _ = hidden_states.shape
299
+ query_layer = self.query(hidden_states)
300
+ key_layer = self.key(hidden_states)
301
+ value_layer = self.value(hidden_states)
302
+ query_layer = query_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(
303
+ 1, 2
304
+ )
305
+ key_layer = key_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2)
306
+ value_layer = value_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(
307
+ 1, 2
308
+ )
309
+
310
+ # Take the dot product between "query" and "key" to get the raw attention scores.
311
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
312
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
313
+
314
+ if attention_mask is not None:
315
+ # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
316
+ attention_scores = attention_scores + attention_mask
317
+
318
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
319
+ seq_length = hidden_states.size()[1]
320
+ position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
321
+ position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
322
+ distance = position_ids_l - position_ids_r
323
+ positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
324
+ positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
325
+
326
+ if self.position_embedding_type == "relative_key":
327
+ relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
328
+ attention_scores = attention_scores + relative_position_scores
329
+ elif self.position_embedding_type == "relative_key_query":
330
+ relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
331
+ relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
332
+ attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
333
+
334
+ # Normalize the attention scores to probabilities.
335
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
336
+
337
+ # This is actually dropping out entire tokens to attend to, which might
338
+ # seem a bit unusual, but is taken from the original Transformer paper.
339
+ attention_probs = self.attention_dropout(attention_probs)
340
+
341
+ # Mask heads if we want to
342
+ if head_mask is not None:
343
+ attention_probs = attention_probs * head_mask
344
+
345
+ context_layer = torch.matmul(attention_probs, value_layer)
346
+ context_layer = context_layer.transpose(2, 1).flatten(2)
347
+
348
+ projected_context_layer = self.dense(context_layer)
349
+ projected_context_layer_dropout = self.output_dropout(projected_context_layer)
350
+ layernormed_context_layer = self.LayerNorm(hidden_states + projected_context_layer_dropout)
351
+ return (layernormed_context_layer, attention_probs) if output_attentions else (layernormed_context_layer,)
352
+
353
+
354
+ class AlbertSdpaAttention(AlbertAttention):
355
+ def __init__(self, config):
356
+ super().__init__(config)
357
+ self.dropout_prob = config.attention_probs_dropout_prob
358
+
359
+ def forward(
360
+ self,
361
+ hidden_states: torch.Tensor,
362
+ attention_mask: Optional[torch.FloatTensor] = None,
363
+ head_mask: Optional[torch.FloatTensor] = None,
364
+ output_attentions: bool = False,
365
+ ) -> Union[tuple[torch.Tensor], tuple[torch.Tensor, torch.Tensor]]:
366
+ if self.position_embedding_type != "absolute" or output_attentions:
367
+ logger.warning(
368
+ "AlbertSdpaAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
369
+ "non-absolute `position_embedding_type` or `output_attentions=True` . Falling back to "
370
+ "the eager attention implementation, but specifying the eager implementation will be required from "
371
+ "Transformers version v5.0.0 onwards. This warning can be removed using the argument "
372
+ '`attn_implementation="eager"` when loading the model.'
373
+ )
374
+ return super().forward(hidden_states, attention_mask, output_attentions=output_attentions)
375
+
376
+ batch_size, seq_len, _ = hidden_states.size()
377
+ query_layer = (
378
+ self.query(hidden_states)
379
+ .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
380
+ .transpose(1, 2)
381
+ )
382
+ key_layer = (
383
+ self.key(hidden_states)
384
+ .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
385
+ .transpose(1, 2)
386
+ )
387
+ value_layer = (
388
+ self.value(hidden_states)
389
+ .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
390
+ .transpose(1, 2)
391
+ )
392
+
393
+ attention_output = torch.nn.functional.scaled_dot_product_attention(
394
+ query=query_layer,
395
+ key=key_layer,
396
+ value=value_layer,
397
+ attn_mask=attention_mask,
398
+ dropout_p=self.dropout_prob if self.training else 0.0,
399
+ is_causal=False,
400
+ )
401
+
402
+ attention_output = attention_output.transpose(1, 2)
403
+ attention_output = attention_output.reshape(batch_size, seq_len, self.all_head_size)
404
+
405
+ projected_context_layer = self.dense(attention_output)
406
+ projected_context_layer_dropout = self.output_dropout(projected_context_layer)
407
+ layernormed_context_layer = self.LayerNorm(hidden_states + projected_context_layer_dropout)
408
+ return (layernormed_context_layer,)
409
+
410
+
411
+ ALBERT_ATTENTION_CLASSES = {
412
+ "eager": AlbertAttention,
413
+ "sdpa": AlbertSdpaAttention,
414
+ }
415
+
416
+
417
+ class AlbertLayer(nn.Module):
418
+ def __init__(self, config: AlbertConfig):
419
+ super().__init__()
420
+
421
+ self.config = config
422
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
423
+ self.seq_len_dim = 1
424
+ self.full_layer_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
425
+ self.attention = ALBERT_ATTENTION_CLASSES[config._attn_implementation](config)
426
+ self.ffn = nn.Linear(config.hidden_size, config.intermediate_size)
427
+ self.ffn_output = nn.Linear(config.intermediate_size, config.hidden_size)
428
+ self.activation = ACT2FN[config.hidden_act]
429
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
430
+
431
+ def forward(
432
+ self,
433
+ hidden_states: torch.Tensor,
434
+ attention_mask: Optional[torch.FloatTensor] = None,
435
+ head_mask: Optional[torch.FloatTensor] = None,
436
+ output_attentions: bool = False,
437
+ output_hidden_states: bool = False,
438
+ ) -> tuple[torch.Tensor, torch.Tensor]:
439
+ attention_output = self.attention(hidden_states, attention_mask, head_mask, output_attentions)
440
+
441
+ ffn_output = apply_chunking_to_forward(
442
+ self.ff_chunk,
443
+ self.chunk_size_feed_forward,
444
+ self.seq_len_dim,
445
+ attention_output[0],
446
+ )
447
+ hidden_states = self.full_layer_layer_norm(ffn_output + attention_output[0])
448
+
449
+ return (hidden_states,) + attention_output[1:] # add attentions if we output them
450
+
451
+ def ff_chunk(self, attention_output: torch.Tensor) -> torch.Tensor:
452
+ ffn_output = self.ffn(attention_output)
453
+ ffn_output = self.activation(ffn_output)
454
+ ffn_output = self.ffn_output(ffn_output)
455
+ return ffn_output
456
+
457
+
458
+ class AlbertLayerGroup(nn.Module):
459
+ def __init__(self, config: AlbertConfig):
460
+ super().__init__()
461
+
462
+ self.albert_layers = nn.ModuleList([AlbertLayer(config) for _ in range(config.inner_group_num)])
463
+
464
+ def forward(
465
+ self,
466
+ hidden_states: torch.Tensor,
467
+ attention_mask: Optional[torch.FloatTensor] = None,
468
+ head_mask: Optional[torch.FloatTensor] = None,
469
+ output_attentions: bool = False,
470
+ output_hidden_states: bool = False,
471
+ ) -> tuple[Union[torch.Tensor, tuple[torch.Tensor]], ...]:
472
+ layer_hidden_states = ()
473
+ layer_attentions = ()
474
+
475
+ for layer_index, albert_layer in enumerate(self.albert_layers):
476
+ layer_output = albert_layer(hidden_states, attention_mask, head_mask[layer_index], output_attentions)
477
+ hidden_states = layer_output[0]
478
+
479
+ if output_attentions:
480
+ layer_attentions = layer_attentions + (layer_output[1],)
481
+
482
+ if output_hidden_states:
483
+ layer_hidden_states = layer_hidden_states + (hidden_states,)
484
+
485
+ outputs = (hidden_states,)
486
+ if output_hidden_states:
487
+ outputs = outputs + (layer_hidden_states,)
488
+ if output_attentions:
489
+ outputs = outputs + (layer_attentions,)
490
+ return outputs # last-layer hidden state, (layer hidden states), (layer attentions)
491
+
492
+
493
+ class AlbertTransformer(nn.Module):
494
+ def __init__(self, config: AlbertConfig):
495
+ super().__init__()
496
+
497
+ self.config = config
498
+ self.embedding_hidden_mapping_in = nn.Linear(config.embedding_size, config.hidden_size)
499
+ self.albert_layer_groups = nn.ModuleList([AlbertLayerGroup(config) for _ in range(config.num_hidden_groups)])
500
+
501
+ def forward(
502
+ self,
503
+ hidden_states: torch.Tensor,
504
+ attention_mask: Optional[torch.FloatTensor] = None,
505
+ head_mask: Optional[torch.FloatTensor] = None,
506
+ output_attentions: bool = False,
507
+ output_hidden_states: bool = False,
508
+ return_dict: bool = True,
509
+ ) -> Union[BaseModelOutput, tuple]:
510
+ hidden_states = self.embedding_hidden_mapping_in(hidden_states)
511
+
512
+ all_hidden_states = (hidden_states,) if output_hidden_states else None
513
+ all_attentions = () if output_attentions else None
514
+
515
+ head_mask = [None] * self.config.num_hidden_layers if head_mask is None else head_mask
516
+
517
+ for i in range(self.config.num_hidden_layers):
518
+ # Number of layers in a hidden group
519
+ layers_per_group = int(self.config.num_hidden_layers / self.config.num_hidden_groups)
520
+
521
+ # Index of the hidden group
522
+ group_idx = int(i / (self.config.num_hidden_layers / self.config.num_hidden_groups))
523
+
524
+ layer_group_output = self.albert_layer_groups[group_idx](
525
+ hidden_states,
526
+ attention_mask,
527
+ head_mask[group_idx * layers_per_group : (group_idx + 1) * layers_per_group],
528
+ output_attentions,
529
+ output_hidden_states,
530
+ )
531
+ hidden_states = layer_group_output[0]
532
+
533
+ if output_attentions:
534
+ all_attentions = all_attentions + layer_group_output[-1]
535
+
536
+ if output_hidden_states:
537
+ all_hidden_states = all_hidden_states + (hidden_states,)
538
+
539
+ if not return_dict:
540
+ return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
541
+ return BaseModelOutput(
542
+ last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
543
+ )
544
+
545
+
546
+ @auto_docstring
547
+ class AlbertPreTrainedModel(PreTrainedModel):
548
+ config: AlbertConfig
549
+ load_tf_weights = load_tf_weights_in_albert
550
+ base_model_prefix = "albert"
551
+ _supports_sdpa = True
552
+
553
+ def _init_weights(self, module):
554
+ """Initialize the weights."""
555
+ if isinstance(module, nn.Linear):
556
+ # Slightly different from the TF version which uses truncated_normal for initialization
557
+ # cf https://github.com/pytorch/pytorch/pull/5617
558
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
559
+ if module.bias is not None:
560
+ module.bias.data.zero_()
561
+ elif isinstance(module, nn.Embedding):
562
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
563
+ if module.padding_idx is not None:
564
+ module.weight.data[module.padding_idx].zero_()
565
+ elif isinstance(module, nn.LayerNorm):
566
+ module.bias.data.zero_()
567
+ module.weight.data.fill_(1.0)
568
+ elif isinstance(module, AlbertMLMHead):
569
+ module.bias.data.zero_()
570
+
571
+
572
+ @dataclass
573
+ @auto_docstring(
574
+ custom_intro="""
575
+ Output type of [`AlbertForPreTraining`].
576
+ """
577
+ )
578
+ class AlbertForPreTrainingOutput(ModelOutput):
579
+ r"""
580
+ loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
581
+ Total loss as the sum of the masked language modeling loss and the next sequence prediction
582
+ (classification) loss.
583
+ prediction_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
584
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
585
+ sop_logits (`torch.FloatTensor` of shape `(batch_size, 2)`):
586
+ Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
587
+ before SoftMax).
588
+ """
589
+
590
+ loss: Optional[torch.FloatTensor] = None
591
+ prediction_logits: Optional[torch.FloatTensor] = None
592
+ sop_logits: Optional[torch.FloatTensor] = None
593
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
594
+ attentions: Optional[tuple[torch.FloatTensor]] = None
595
+
596
+
597
+ @auto_docstring
598
+ class AlbertModel(AlbertPreTrainedModel):
599
+ config: AlbertConfig
600
+ base_model_prefix = "albert"
601
+
602
+ def __init__(self, config: AlbertConfig, add_pooling_layer: bool = True):
603
+ r"""
604
+ add_pooling_layer (bool, *optional*, defaults to `True`):
605
+ Whether to add a pooling layer
606
+ """
607
+ super().__init__(config)
608
+
609
+ self.config = config
610
+ self.embeddings = AlbertEmbeddings(config)
611
+ self.encoder = AlbertTransformer(config)
612
+ if add_pooling_layer:
613
+ self.pooler = nn.Linear(config.hidden_size, config.hidden_size)
614
+ self.pooler_activation = nn.Tanh()
615
+ else:
616
+ self.pooler = None
617
+ self.pooler_activation = None
618
+
619
+ self.attn_implementation = config._attn_implementation
620
+ self.position_embedding_type = config.position_embedding_type
621
+
622
+ # Initialize weights and apply final processing
623
+ self.post_init()
624
+
625
+ def get_input_embeddings(self) -> nn.Embedding:
626
+ return self.embeddings.word_embeddings
627
+
628
+ def set_input_embeddings(self, value: nn.Embedding) -> None:
629
+ self.embeddings.word_embeddings = value
630
+
631
+ def _prune_heads(self, heads_to_prune: dict[int, list[int]]) -> None:
632
+ """
633
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} ALBERT has
634
+ a different architecture in that its layers are shared across groups, which then has inner groups. If an ALBERT
635
+ model has 12 hidden layers and 2 hidden groups, with two inner groups, there is a total of 4 different layers.
636
+
637
+ These layers are flattened: the indices [0,1] correspond to the two inner groups of the first hidden layer,
638
+ while [2,3] correspond to the two inner groups of the second hidden layer.
639
+
640
+ Any layer with in index other than [0,1,2,3] will result in an error. See base class PreTrainedModel for more
641
+ information about head pruning
642
+ """
643
+ for layer, heads in heads_to_prune.items():
644
+ group_idx = int(layer / self.config.inner_group_num)
645
+ inner_group_idx = int(layer - group_idx * self.config.inner_group_num)
646
+ self.encoder.albert_layer_groups[group_idx].albert_layers[inner_group_idx].attention.prune_heads(heads)
647
+
648
+ @auto_docstring
649
+ def forward(
650
+ self,
651
+ input_ids: Optional[torch.LongTensor] = None,
652
+ attention_mask: Optional[torch.FloatTensor] = None,
653
+ token_type_ids: Optional[torch.LongTensor] = None,
654
+ position_ids: Optional[torch.LongTensor] = None,
655
+ head_mask: Optional[torch.FloatTensor] = None,
656
+ inputs_embeds: Optional[torch.FloatTensor] = None,
657
+ output_attentions: Optional[bool] = None,
658
+ output_hidden_states: Optional[bool] = None,
659
+ return_dict: Optional[bool] = None,
660
+ ) -> Union[BaseModelOutputWithPooling, tuple]:
661
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
662
+ output_hidden_states = (
663
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
664
+ )
665
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
666
+
667
+ if input_ids is not None and inputs_embeds is not None:
668
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
669
+ elif input_ids is not None:
670
+ self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
671
+ input_shape = input_ids.size()
672
+ elif inputs_embeds is not None:
673
+ input_shape = inputs_embeds.size()[:-1]
674
+ else:
675
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
676
+
677
+ batch_size, seq_length = input_shape
678
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
679
+
680
+ if attention_mask is None:
681
+ attention_mask = torch.ones(input_shape, device=device)
682
+ if token_type_ids is None:
683
+ if hasattr(self.embeddings, "token_type_ids"):
684
+ buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
685
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
686
+ token_type_ids = buffered_token_type_ids_expanded
687
+ else:
688
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
689
+
690
+ embedding_output = self.embeddings(
691
+ input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
692
+ )
693
+
694
+ use_sdpa_attention_mask = (
695
+ self.attn_implementation == "sdpa"
696
+ and self.position_embedding_type == "absolute"
697
+ and head_mask is None
698
+ and not output_attentions
699
+ )
700
+
701
+ if use_sdpa_attention_mask:
702
+ extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(
703
+ attention_mask, embedding_output.dtype, tgt_len=seq_length
704
+ )
705
+ else:
706
+ extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
707
+ extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
708
+ extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(self.dtype).min
709
+
710
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
711
+
712
+ encoder_outputs = self.encoder(
713
+ embedding_output,
714
+ extended_attention_mask,
715
+ head_mask=head_mask,
716
+ output_attentions=output_attentions,
717
+ output_hidden_states=output_hidden_states,
718
+ return_dict=return_dict,
719
+ )
720
+
721
+ sequence_output = encoder_outputs[0]
722
+
723
+ pooled_output = self.pooler_activation(self.pooler(sequence_output[:, 0])) if self.pooler is not None else None
724
+
725
+ if not return_dict:
726
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
727
+
728
+ return BaseModelOutputWithPooling(
729
+ last_hidden_state=sequence_output,
730
+ pooler_output=pooled_output,
731
+ hidden_states=encoder_outputs.hidden_states,
732
+ attentions=encoder_outputs.attentions,
733
+ )
734
+
735
+
736
+ @auto_docstring(
737
+ custom_intro="""
738
+ Albert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a
739
+ `sentence order prediction (classification)` head.
740
+ """
741
+ )
742
+ class AlbertForPreTraining(AlbertPreTrainedModel):
743
+ _tied_weights_keys = ["predictions.decoder.bias", "predictions.decoder.weight"]
744
+
745
+ def __init__(self, config: AlbertConfig):
746
+ super().__init__(config)
747
+
748
+ self.albert = AlbertModel(config)
749
+ self.predictions = AlbertMLMHead(config)
750
+ self.sop_classifier = AlbertSOPHead(config)
751
+
752
+ # Initialize weights and apply final processing
753
+ self.post_init()
754
+
755
+ def get_output_embeddings(self) -> nn.Linear:
756
+ return self.predictions.decoder
757
+
758
+ def set_output_embeddings(self, new_embeddings: nn.Linear) -> None:
759
+ self.predictions.decoder = new_embeddings
760
+
761
+ def get_input_embeddings(self) -> nn.Embedding:
762
+ return self.albert.embeddings.word_embeddings
763
+
764
+ @auto_docstring
765
+ def forward(
766
+ self,
767
+ input_ids: Optional[torch.LongTensor] = None,
768
+ attention_mask: Optional[torch.FloatTensor] = None,
769
+ token_type_ids: Optional[torch.LongTensor] = None,
770
+ position_ids: Optional[torch.LongTensor] = None,
771
+ head_mask: Optional[torch.FloatTensor] = None,
772
+ inputs_embeds: Optional[torch.FloatTensor] = None,
773
+ labels: Optional[torch.LongTensor] = None,
774
+ sentence_order_label: Optional[torch.LongTensor] = None,
775
+ output_attentions: Optional[bool] = None,
776
+ output_hidden_states: Optional[bool] = None,
777
+ return_dict: Optional[bool] = None,
778
+ ) -> Union[AlbertForPreTrainingOutput, tuple]:
779
+ r"""
780
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
781
+ Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
782
+ config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
783
+ loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
784
+ sentence_order_label (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
785
+ Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
786
+ (see `input_ids` docstring) Indices should be in `[0, 1]`. `0` indicates original order (sequence A, then
787
+ sequence B), `1` indicates switched order (sequence B, then sequence A).
788
+
789
+ Example:
790
+
791
+ ```python
792
+ >>> from transformers import AutoTokenizer, AlbertForPreTraining
793
+ >>> import torch
794
+
795
+ >>> tokenizer = AutoTokenizer.from_pretrained("albert/albert-base-v2")
796
+ >>> model = AlbertForPreTraining.from_pretrained("albert/albert-base-v2")
797
+
798
+ >>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)
799
+ >>> # Batch size 1
800
+ >>> outputs = model(input_ids)
801
+
802
+ >>> prediction_logits = outputs.prediction_logits
803
+ >>> sop_logits = outputs.sop_logits
804
+ ```"""
805
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
806
+
807
+ outputs = self.albert(
808
+ input_ids,
809
+ attention_mask=attention_mask,
810
+ token_type_ids=token_type_ids,
811
+ position_ids=position_ids,
812
+ head_mask=head_mask,
813
+ inputs_embeds=inputs_embeds,
814
+ output_attentions=output_attentions,
815
+ output_hidden_states=output_hidden_states,
816
+ return_dict=return_dict,
817
+ )
818
+
819
+ sequence_output, pooled_output = outputs[:2]
820
+
821
+ prediction_scores = self.predictions(sequence_output)
822
+ sop_scores = self.sop_classifier(pooled_output)
823
+
824
+ total_loss = None
825
+ if labels is not None and sentence_order_label is not None:
826
+ loss_fct = CrossEntropyLoss()
827
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
828
+ sentence_order_loss = loss_fct(sop_scores.view(-1, 2), sentence_order_label.view(-1))
829
+ total_loss = masked_lm_loss + sentence_order_loss
830
+
831
+ if not return_dict:
832
+ output = (prediction_scores, sop_scores) + outputs[2:]
833
+ return ((total_loss,) + output) if total_loss is not None else output
834
+
835
+ return AlbertForPreTrainingOutput(
836
+ loss=total_loss,
837
+ prediction_logits=prediction_scores,
838
+ sop_logits=sop_scores,
839
+ hidden_states=outputs.hidden_states,
840
+ attentions=outputs.attentions,
841
+ )
842
+
843
+
844
+ class AlbertMLMHead(nn.Module):
845
+ def __init__(self, config: AlbertConfig):
846
+ super().__init__()
847
+
848
+ self.LayerNorm = nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps)
849
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
850
+ self.dense = nn.Linear(config.hidden_size, config.embedding_size)
851
+ self.decoder = nn.Linear(config.embedding_size, config.vocab_size)
852
+ self.activation = ACT2FN[config.hidden_act]
853
+ self.decoder.bias = self.bias
854
+
855
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
856
+ hidden_states = self.dense(hidden_states)
857
+ hidden_states = self.activation(hidden_states)
858
+ hidden_states = self.LayerNorm(hidden_states)
859
+ hidden_states = self.decoder(hidden_states)
860
+
861
+ prediction_scores = hidden_states
862
+
863
+ return prediction_scores
864
+
865
+ def _tie_weights(self) -> None:
866
+ # For accelerate compatibility and to not break backward compatibility
867
+ if self.decoder.bias.device.type == "meta":
868
+ self.decoder.bias = self.bias
869
+ else:
870
+ # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
871
+ self.bias = self.decoder.bias
872
+
873
+
874
+ class AlbertSOPHead(nn.Module):
875
+ def __init__(self, config: AlbertConfig):
876
+ super().__init__()
877
+
878
+ self.dropout = nn.Dropout(config.classifier_dropout_prob)
879
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
880
+
881
+ def forward(self, pooled_output: torch.Tensor) -> torch.Tensor:
882
+ dropout_pooled_output = self.dropout(pooled_output)
883
+ logits = self.classifier(dropout_pooled_output)
884
+ return logits
885
+
886
+
887
+ @auto_docstring
888
+ class AlbertForMaskedLM(AlbertPreTrainedModel):
889
+ _tied_weights_keys = ["predictions.decoder.bias", "predictions.decoder.weight"]
890
+
891
+ def __init__(self, config):
892
+ super().__init__(config)
893
+
894
+ self.albert = AlbertModel(config, add_pooling_layer=False)
895
+ self.predictions = AlbertMLMHead(config)
896
+
897
+ # Initialize weights and apply final processing
898
+ self.post_init()
899
+
900
+ def get_output_embeddings(self) -> nn.Linear:
901
+ return self.predictions.decoder
902
+
903
+ def set_output_embeddings(self, new_embeddings: nn.Linear) -> None:
904
+ self.predictions.decoder = new_embeddings
905
+ self.predictions.bias = new_embeddings.bias
906
+
907
+ def get_input_embeddings(self) -> nn.Embedding:
908
+ return self.albert.embeddings.word_embeddings
909
+
910
+ @auto_docstring
911
+ def forward(
912
+ self,
913
+ input_ids: Optional[torch.LongTensor] = None,
914
+ attention_mask: Optional[torch.FloatTensor] = None,
915
+ token_type_ids: Optional[torch.LongTensor] = None,
916
+ position_ids: Optional[torch.LongTensor] = None,
917
+ head_mask: Optional[torch.FloatTensor] = None,
918
+ inputs_embeds: Optional[torch.FloatTensor] = None,
919
+ labels: Optional[torch.LongTensor] = None,
920
+ output_attentions: Optional[bool] = None,
921
+ output_hidden_states: Optional[bool] = None,
922
+ return_dict: Optional[bool] = None,
923
+ ) -> Union[MaskedLMOutput, tuple]:
924
+ r"""
925
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
926
+ Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
927
+ config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
928
+ loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
929
+
930
+ Example:
931
+
932
+ ```python
933
+ >>> import torch
934
+ >>> from transformers import AutoTokenizer, AlbertForMaskedLM
935
+
936
+ >>> tokenizer = AutoTokenizer.from_pretrained("albert/albert-base-v2")
937
+ >>> model = AlbertForMaskedLM.from_pretrained("albert/albert-base-v2")
938
+
939
+ >>> # add mask_token
940
+ >>> inputs = tokenizer("The capital of [MASK] is Paris.", return_tensors="pt")
941
+ >>> with torch.no_grad():
942
+ ... logits = model(**inputs).logits
943
+
944
+ >>> # retrieve index of [MASK]
945
+ >>> mask_token_index = (inputs.input_ids == tokenizer.mask_token_id)[0].nonzero(as_tuple=True)[0]
946
+ >>> predicted_token_id = logits[0, mask_token_index].argmax(axis=-1)
947
+ >>> tokenizer.decode(predicted_token_id)
948
+ 'france'
949
+ ```
950
+
951
+ ```python
952
+ >>> labels = tokenizer("The capital of France is Paris.", return_tensors="pt")["input_ids"]
953
+ >>> labels = torch.where(inputs.input_ids == tokenizer.mask_token_id, labels, -100)
954
+ >>> outputs = model(**inputs, labels=labels)
955
+ >>> round(outputs.loss.item(), 2)
956
+ 0.81
957
+ ```
958
+ """
959
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
960
+
961
+ outputs = self.albert(
962
+ input_ids=input_ids,
963
+ attention_mask=attention_mask,
964
+ token_type_ids=token_type_ids,
965
+ position_ids=position_ids,
966
+ head_mask=head_mask,
967
+ inputs_embeds=inputs_embeds,
968
+ output_attentions=output_attentions,
969
+ output_hidden_states=output_hidden_states,
970
+ return_dict=return_dict,
971
+ )
972
+ sequence_outputs = outputs[0]
973
+
974
+ prediction_scores = self.predictions(sequence_outputs)
975
+
976
+ masked_lm_loss = None
977
+ if labels is not None:
978
+ loss_fct = CrossEntropyLoss()
979
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
980
+
981
+ if not return_dict:
982
+ output = (prediction_scores,) + outputs[2:]
983
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
984
+
985
+ return MaskedLMOutput(
986
+ loss=masked_lm_loss,
987
+ logits=prediction_scores,
988
+ hidden_states=outputs.hidden_states,
989
+ attentions=outputs.attentions,
990
+ )
991
+
992
+
993
+ @auto_docstring(
994
+ custom_intro="""
995
+ Albert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
996
+ output) e.g. for GLUE tasks.
997
+ """
998
+ )
999
+ class AlbertForSequenceClassification(AlbertPreTrainedModel):
1000
+ def __init__(self, config: AlbertConfig):
1001
+ super().__init__(config)
1002
+ self.num_labels = config.num_labels
1003
+ self.config = config
1004
+
1005
+ self.albert = AlbertModel(config)
1006
+ self.dropout = nn.Dropout(config.classifier_dropout_prob)
1007
+ self.classifier = nn.Linear(config.hidden_size, self.config.num_labels)
1008
+
1009
+ # Initialize weights and apply final processing
1010
+ self.post_init()
1011
+
1012
+ @auto_docstring
1013
+ def forward(
1014
+ self,
1015
+ input_ids: Optional[torch.LongTensor] = None,
1016
+ attention_mask: Optional[torch.FloatTensor] = None,
1017
+ token_type_ids: Optional[torch.LongTensor] = None,
1018
+ position_ids: Optional[torch.LongTensor] = None,
1019
+ head_mask: Optional[torch.FloatTensor] = None,
1020
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1021
+ labels: Optional[torch.LongTensor] = None,
1022
+ output_attentions: Optional[bool] = None,
1023
+ output_hidden_states: Optional[bool] = None,
1024
+ return_dict: Optional[bool] = None,
1025
+ ) -> Union[SequenceClassifierOutput, tuple]:
1026
+ r"""
1027
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1028
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1029
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1030
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1031
+ """
1032
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1033
+
1034
+ outputs = self.albert(
1035
+ input_ids=input_ids,
1036
+ attention_mask=attention_mask,
1037
+ token_type_ids=token_type_ids,
1038
+ position_ids=position_ids,
1039
+ head_mask=head_mask,
1040
+ inputs_embeds=inputs_embeds,
1041
+ output_attentions=output_attentions,
1042
+ output_hidden_states=output_hidden_states,
1043
+ return_dict=return_dict,
1044
+ )
1045
+
1046
+ pooled_output = outputs[1]
1047
+
1048
+ pooled_output = self.dropout(pooled_output)
1049
+ logits = self.classifier(pooled_output)
1050
+
1051
+ loss = None
1052
+ if labels is not None:
1053
+ if self.config.problem_type is None:
1054
+ if self.num_labels == 1:
1055
+ self.config.problem_type = "regression"
1056
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1057
+ self.config.problem_type = "single_label_classification"
1058
+ else:
1059
+ self.config.problem_type = "multi_label_classification"
1060
+
1061
+ if self.config.problem_type == "regression":
1062
+ loss_fct = MSELoss()
1063
+ if self.num_labels == 1:
1064
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
1065
+ else:
1066
+ loss = loss_fct(logits, labels)
1067
+ elif self.config.problem_type == "single_label_classification":
1068
+ loss_fct = CrossEntropyLoss()
1069
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1070
+ elif self.config.problem_type == "multi_label_classification":
1071
+ loss_fct = BCEWithLogitsLoss()
1072
+ loss = loss_fct(logits, labels)
1073
+
1074
+ if not return_dict:
1075
+ output = (logits,) + outputs[2:]
1076
+ return ((loss,) + output) if loss is not None else output
1077
+
1078
+ return SequenceClassifierOutput(
1079
+ loss=loss,
1080
+ logits=logits,
1081
+ hidden_states=outputs.hidden_states,
1082
+ attentions=outputs.attentions,
1083
+ )
1084
+
1085
+
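The loss selection above falls back to inferring `problem_type` from `num_labels` and the label dtype when the config does not set it explicitly. A small sketch of the three branches with hypothetical label tensors (illustration only, not part of the library):

```python
import torch

# num_labels == 1 -> "regression": MSELoss on the squeezed logits.
regression_labels = torch.tensor([0.7, 1.3])                      # float targets

# num_labels > 1 with integer labels -> "single_label_classification": CrossEntropyLoss.
single_label_labels = torch.tensor([0, 2, 1])                     # class indices (torch.long)

# num_labels > 1 with float labels -> "multi_label_classification": BCEWithLogitsLoss.
multi_label_labels = torch.tensor([[1.0, 0.0, 1.0],
                                   [0.0, 1.0, 0.0]])              # multi-hot float targets
```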
1086
+ @auto_docstring
1087
+ class AlbertForTokenClassification(AlbertPreTrainedModel):
1088
+ def __init__(self, config: AlbertConfig):
1089
+ super().__init__(config)
1090
+ self.num_labels = config.num_labels
1091
+
1092
+ self.albert = AlbertModel(config, add_pooling_layer=False)
1093
+ classifier_dropout_prob = (
1094
+ config.classifier_dropout_prob
1095
+ if config.classifier_dropout_prob is not None
1096
+ else config.hidden_dropout_prob
1097
+ )
1098
+ self.dropout = nn.Dropout(classifier_dropout_prob)
1099
+ self.classifier = nn.Linear(config.hidden_size, self.config.num_labels)
1100
+
1101
+ # Initialize weights and apply final processing
1102
+ self.post_init()
1103
+
1104
+ @auto_docstring
1105
+ def forward(
1106
+ self,
1107
+ input_ids: Optional[torch.LongTensor] = None,
1108
+ attention_mask: Optional[torch.FloatTensor] = None,
1109
+ token_type_ids: Optional[torch.LongTensor] = None,
1110
+ position_ids: Optional[torch.LongTensor] = None,
1111
+ head_mask: Optional[torch.FloatTensor] = None,
1112
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1113
+ labels: Optional[torch.LongTensor] = None,
1114
+ output_attentions: Optional[bool] = None,
1115
+ output_hidden_states: Optional[bool] = None,
1116
+ return_dict: Optional[bool] = None,
1117
+ ) -> Union[TokenClassifierOutput, tuple]:
1118
+ r"""
1119
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1120
+ Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
1121
+ """
1122
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1123
+
1124
+ outputs = self.albert(
1125
+ input_ids,
1126
+ attention_mask=attention_mask,
1127
+ token_type_ids=token_type_ids,
1128
+ position_ids=position_ids,
1129
+ head_mask=head_mask,
1130
+ inputs_embeds=inputs_embeds,
1131
+ output_attentions=output_attentions,
1132
+ output_hidden_states=output_hidden_states,
1133
+ return_dict=return_dict,
1134
+ )
1135
+
1136
+ sequence_output = outputs[0]
1137
+
1138
+ sequence_output = self.dropout(sequence_output)
1139
+ logits = self.classifier(sequence_output)
1140
+
1141
+ loss = None
1142
+ if labels is not None:
1143
+ loss_fct = CrossEntropyLoss()
1144
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1145
+
1146
+ if not return_dict:
1147
+ output = (logits,) + outputs[2:]
1148
+ return ((loss,) + output) if loss is not None else output
1149
+
1150
+ return TokenClassifierOutput(
1151
+ loss=loss,
1152
+ logits=logits,
1153
+ hidden_states=outputs.hidden_states,
1154
+ attentions=outputs.attentions,
1155
+ )
1156
+
1157
+
1158
+ @auto_docstring
1159
+ class AlbertForQuestionAnswering(AlbertPreTrainedModel):
1160
+ def __init__(self, config: AlbertConfig):
1161
+ super().__init__(config)
1162
+ self.num_labels = config.num_labels
1163
+
1164
+ self.albert = AlbertModel(config, add_pooling_layer=False)
1165
+ self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
1166
+
1167
+ # Initialize weights and apply final processing
1168
+ self.post_init()
1169
+
1170
+ @auto_docstring
1171
+ def forward(
1172
+ self,
1173
+ input_ids: Optional[torch.LongTensor] = None,
1174
+ attention_mask: Optional[torch.FloatTensor] = None,
1175
+ token_type_ids: Optional[torch.LongTensor] = None,
1176
+ position_ids: Optional[torch.LongTensor] = None,
1177
+ head_mask: Optional[torch.FloatTensor] = None,
1178
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1179
+ start_positions: Optional[torch.LongTensor] = None,
1180
+ end_positions: Optional[torch.LongTensor] = None,
1181
+ output_attentions: Optional[bool] = None,
1182
+ output_hidden_states: Optional[bool] = None,
1183
+ return_dict: Optional[bool] = None,
1184
+ ) -> Union[QuestionAnsweringModelOutput, tuple]:
1185
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1186
+
1187
+ outputs = self.albert(
1188
+ input_ids=input_ids,
1189
+ attention_mask=attention_mask,
1190
+ token_type_ids=token_type_ids,
1191
+ position_ids=position_ids,
1192
+ head_mask=head_mask,
1193
+ inputs_embeds=inputs_embeds,
1194
+ output_attentions=output_attentions,
1195
+ output_hidden_states=output_hidden_states,
1196
+ return_dict=return_dict,
1197
+ )
1198
+
1199
+ sequence_output = outputs[0]
1200
+
1201
+ logits: torch.Tensor = self.qa_outputs(sequence_output)
1202
+ start_logits, end_logits = logits.split(1, dim=-1)
1203
+ start_logits = start_logits.squeeze(-1).contiguous()
1204
+ end_logits = end_logits.squeeze(-1).contiguous()
1205
+
1206
+ total_loss = None
1207
+ if start_positions is not None and end_positions is not None:
1208
+ # If we are on multi-GPU, split adds an extra dimension; squeeze it
1209
+ if len(start_positions.size()) > 1:
1210
+ start_positions = start_positions.squeeze(-1)
1211
+ if len(end_positions.size()) > 1:
1212
+ end_positions = end_positions.squeeze(-1)
1213
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
1214
+ ignored_index = start_logits.size(1)
1215
+ start_positions = start_positions.clamp(0, ignored_index)
1216
+ end_positions = end_positions.clamp(0, ignored_index)
1217
+
1218
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
1219
+ start_loss = loss_fct(start_logits, start_positions)
1220
+ end_loss = loss_fct(end_logits, end_positions)
1221
+ total_loss = (start_loss + end_loss) / 2
1222
+
1223
+ if not return_dict:
1224
+ output = (start_logits, end_logits) + outputs[2:]
1225
+ return ((total_loss,) + output) if total_loss is not None else output
1226
+
1227
+ return QuestionAnsweringModelOutput(
1228
+ loss=total_loss,
1229
+ start_logits=start_logits,
1230
+ end_logits=end_logits,
1231
+ hidden_states=outputs.hidden_states,
1232
+ attentions=outputs.attentions,
1233
+ )
1234
+
1235
+
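At inference time the answer span is usually recovered by taking the argmax over the start and end logits and decoding the tokens in between. A minimal sketch, assuming the `albert/albert-base-v2` checkpoint; its QA head is randomly initialized, so the extracted span only illustrates the mechanics:

```python
import torch
from transformers import AutoTokenizer, AlbertForQuestionAnswering

tokenizer = AutoTokenizer.from_pretrained("albert/albert-base-v2")
model = AlbertForQuestionAnswering.from_pretrained("albert/albert-base-v2")

question, context = "Where is the Eiffel Tower?", "The Eiffel Tower is located in Paris."
inputs = tokenizer(question, context, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

# Most likely start/end positions of the answer span.
start_index = outputs.start_logits.argmax(dim=-1).item()
end_index = outputs.end_logits.argmax(dim=-1).item()
answer = tokenizer.decode(inputs.input_ids[0, start_index : end_index + 1])
print(answer)
```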
1236
+ @auto_docstring
1237
+ class AlbertForMultipleChoice(AlbertPreTrainedModel):
1238
+ def __init__(self, config: AlbertConfig):
1239
+ super().__init__(config)
1240
+
1241
+ self.albert = AlbertModel(config)
1242
+ self.dropout = nn.Dropout(config.classifier_dropout_prob)
1243
+ self.classifier = nn.Linear(config.hidden_size, 1)
1244
+
1245
+ # Initialize weights and apply final processing
1246
+ self.post_init()
1247
+
1248
+ @auto_docstring
1249
+ def forward(
1250
+ self,
1251
+ input_ids: Optional[torch.LongTensor] = None,
1252
+ attention_mask: Optional[torch.FloatTensor] = None,
1253
+ token_type_ids: Optional[torch.LongTensor] = None,
1254
+ position_ids: Optional[torch.LongTensor] = None,
1255
+ head_mask: Optional[torch.FloatTensor] = None,
1256
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1257
+ labels: Optional[torch.LongTensor] = None,
1258
+ output_attentions: Optional[bool] = None,
1259
+ output_hidden_states: Optional[bool] = None,
1260
+ return_dict: Optional[bool] = None,
1261
+ ) -> Union[MultipleChoiceModelOutput, tuple]:
1262
+ r"""
1263
+ input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`):
1264
+ Indices of input sequence tokens in the vocabulary.
1265
+
1266
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and
1267
+ [`PreTrainedTokenizer.encode`] for details.
1268
+
1269
+ [What are input IDs?](../glossary#input-ids)
1270
+ token_type_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
1271
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
1272
+ 1]`:
1273
+
1274
+ - 0 corresponds to a *sentence A* token,
1275
+ - 1 corresponds to a *sentence B* token.
1276
+
1277
+ [What are token type IDs?](../glossary#token-type-ids)
1278
+ position_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
1279
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
1280
+ config.max_position_embeddings - 1]`.
1281
+
1282
+ [What are position IDs?](../glossary#position-ids)
1283
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, hidden_size)`, *optional*):
1284
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
1285
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
1286
+ model's internal embedding lookup matrix.
1287
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1288
+ Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
1289
+ num_choices-1]` where *num_choices* is the size of the second dimension of the input tensors. (see
1290
+ *input_ids* above)
1291
+ """
1292
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1293
+ num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
1294
+
1295
+ input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
1296
+ attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
1297
+ token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
1298
+ position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
1299
+ inputs_embeds = (
1300
+ inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
1301
+ if inputs_embeds is not None
1302
+ else None
1303
+ )
1304
+ outputs = self.albert(
1305
+ input_ids,
1306
+ attention_mask=attention_mask,
1307
+ token_type_ids=token_type_ids,
1308
+ position_ids=position_ids,
1309
+ head_mask=head_mask,
1310
+ inputs_embeds=inputs_embeds,
1311
+ output_attentions=output_attentions,
1312
+ output_hidden_states=output_hidden_states,
1313
+ return_dict=return_dict,
1314
+ )
1315
+
1316
+ pooled_output = outputs[1]
1317
+
1318
+ pooled_output = self.dropout(pooled_output)
1319
+ logits: torch.Tensor = self.classifier(pooled_output)
1320
+ reshaped_logits = logits.view(-1, num_choices)
1321
+
1322
+ loss = None
1323
+ if labels is not None:
1324
+ loss_fct = CrossEntropyLoss()
1325
+ loss = loss_fct(reshaped_logits, labels)
1326
+
1327
+ if not return_dict:
1328
+ output = (reshaped_logits,) + outputs[2:]
1329
+ return ((loss,) + output) if loss is not None else output
1330
+
1331
+ return MultipleChoiceModelOutput(
1332
+ loss=loss,
1333
+ logits=reshaped_logits,
1334
+ hidden_states=outputs.hidden_states,
1335
+ attentions=outputs.attentions,
1336
+ )
1337
+
1338
+
1339
+ __all__ = [
1340
+ "load_tf_weights_in_albert",
1341
+ "AlbertPreTrainedModel",
1342
+ "AlbertModel",
1343
+ "AlbertForPreTraining",
1344
+ "AlbertForMaskedLM",
1345
+ "AlbertForSequenceClassification",
1346
+ "AlbertForTokenClassification",
1347
+ "AlbertForQuestionAnswering",
1348
+ "AlbertForMultipleChoice",
1349
+ ]
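`AlbertForMultipleChoice` expects `(batch_size, num_choices, sequence_length)` inputs, flattens them to `(batch_size * num_choices, sequence_length)` for the encoder, and reshapes the per-choice scores back to `(batch_size, num_choices)`. A minimal shape sketch with a small, hypothetical config and random weights (illustration only):

```python
import torch
from transformers import AlbertConfig, AlbertForMultipleChoice

# Tiny hypothetical config so the sketch runs quickly; real checkpoints are much larger.
config = AlbertConfig(hidden_size=64, intermediate_size=128, num_attention_heads=4, num_hidden_layers=2)
model = AlbertForMultipleChoice(config)

batch_size, num_choices, seq_len = 2, 4, 16
input_ids = torch.randint(0, config.vocab_size, (batch_size, num_choices, seq_len))

with torch.no_grad():
    logits = model(input_ids=input_ids).logits

print(logits.shape)  # torch.Size([2, 4]) -> one score per choice
```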
venv/lib/python3.13/site-packages/transformers/models/albert/modeling_flax_albert.py ADDED
@@ -0,0 +1,1132 @@
1
+ # coding=utf-8
2
+ # Copyright 2021 Google AI, Google Brain and the HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import Callable, Optional
17
+
18
+ import flax
19
+ import flax.linen as nn
20
+ import jax
21
+ import jax.numpy as jnp
22
+ import numpy as np
23
+ from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
24
+ from flax.linen.attention import dot_product_attention_weights
25
+ from flax.traverse_util import flatten_dict, unflatten_dict
26
+ from jax import lax
27
+
28
+ from ...modeling_flax_outputs import (
29
+ FlaxBaseModelOutput,
30
+ FlaxBaseModelOutputWithPooling,
31
+ FlaxMaskedLMOutput,
32
+ FlaxMultipleChoiceModelOutput,
33
+ FlaxQuestionAnsweringModelOutput,
34
+ FlaxSequenceClassifierOutput,
35
+ FlaxTokenClassifierOutput,
36
+ )
37
+ from ...modeling_flax_utils import (
38
+ ACT2FN,
39
+ FlaxPreTrainedModel,
40
+ append_call_sample_docstring,
41
+ append_replace_return_docstrings,
42
+ overwrite_call_docstring,
43
+ )
44
+ from ...utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, logging
45
+ from .configuration_albert import AlbertConfig
46
+
47
+
48
+ logger = logging.get_logger(__name__)
49
+
50
+ _CHECKPOINT_FOR_DOC = "albert/albert-base-v2"
51
+ _CONFIG_FOR_DOC = "AlbertConfig"
52
+
53
+
54
+ @flax.struct.dataclass
55
+ class FlaxAlbertForPreTrainingOutput(ModelOutput):
56
+ """
57
+ Output type of [`FlaxAlbertForPreTraining`].
58
+
59
+ Args:
60
+ prediction_logits (`jnp.ndarray` of shape `(batch_size, sequence_length, config.vocab_size)`):
61
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
62
+ sop_logits (`jnp.ndarray` of shape `(batch_size, 2)`):
63
+ Prediction scores of the sentence order prediction (classification) head (scores of original/swapped order
64
+ before SoftMax).
65
+ hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
66
+ Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
67
+ `(batch_size, sequence_length, hidden_size)`.
68
+
69
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
70
+ attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
71
+ Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
72
+ sequence_length)`.
73
+
74
+ Attention weights after the attention softmax, used to compute the weighted average in the self-attention
75
+ heads.
76
+ """
77
+
78
+ prediction_logits: jnp.ndarray = None
79
+ sop_logits: jnp.ndarray = None
80
+ hidden_states: Optional[tuple[jnp.ndarray]] = None
81
+ attentions: Optional[tuple[jnp.ndarray]] = None
82
+
83
+
84
+ ALBERT_START_DOCSTRING = r"""
85
+
86
+ This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
87
+ library implements for all its models (such as downloading, saving and converting weights from PyTorch models).
88
+
89
+ This model is also a
90
+ [flax.linen.Module](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html) subclass. Use it as
91
+ a regular Flax linen Module and refer to the Flax documentation for all matters related to general usage and
92
+ behavior.
93
+
94
+ Finally, this model supports inherent JAX features such as:
95
+
96
+ - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
97
+ - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
98
+ - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
99
+ - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
100
+
101
+ Parameters:
102
+ config ([`AlbertConfig`]): Model configuration class with all the parameters of the model.
103
+ Initializing with a config file does not load the weights associated with the model, only the
104
+ configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
105
+ dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
106
+ The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
107
+ `jax.numpy.bfloat16` (on TPUs).
108
+
109
+ This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
110
+ specified, all the computation will be performed with the given `dtype`.
111
+
112
+ **Note that this only specifies the dtype of the computation and does not influence the dtype of model
113
+ parameters.**
114
+
115
+ If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
116
+ [`~FlaxPreTrainedModel.to_bf16`].
117
+ """
118
+
119
+ ALBERT_INPUTS_DOCSTRING = r"""
120
+ Args:
121
+ input_ids (`numpy.ndarray` of shape `({0})`):
122
+ Indices of input sequence tokens in the vocabulary.
123
+
124
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
125
+ [`PreTrainedTokenizer.__call__`] for details.
126
+
127
+ [What are input IDs?](../glossary#input-ids)
128
+ attention_mask (`numpy.ndarray` of shape `({0})`, *optional*):
129
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
130
+
131
+ - 1 for tokens that are **not masked**,
132
+ - 0 for tokens that are **masked**.
133
+
134
+ [What are attention masks?](../glossary#attention-mask)
135
+ token_type_ids (`numpy.ndarray` of shape `({0})`, *optional*):
136
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
137
+ 1]`:
138
+
139
+ - 0 corresponds to a *sentence A* token,
140
+ - 1 corresponds to a *sentence B* token.
141
+
142
+ [What are token type IDs?](../glossary#token-type-ids)
143
+ position_ids (`numpy.ndarray` of shape `({0})`, *optional*):
144
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
145
+ config.max_position_embeddings - 1]`.
146
+ return_dict (`bool`, *optional*):
147
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
148
+
149
+ """
150
+
151
+
152
+ class FlaxAlbertEmbeddings(nn.Module):
153
+ """Construct the embeddings from word, position and token_type embeddings."""
154
+
155
+ config: AlbertConfig
156
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
157
+
158
+ def setup(self):
159
+ self.word_embeddings = nn.Embed(
160
+ self.config.vocab_size,
161
+ self.config.embedding_size,
162
+ embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
163
+ )
164
+ self.position_embeddings = nn.Embed(
165
+ self.config.max_position_embeddings,
166
+ self.config.embedding_size,
167
+ embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
168
+ )
169
+ self.token_type_embeddings = nn.Embed(
170
+ self.config.type_vocab_size,
171
+ self.config.embedding_size,
172
+ embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
173
+ )
174
+ self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
175
+ self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
176
+
177
+ def __call__(self, input_ids, token_type_ids, position_ids, deterministic: bool = True):
178
+ # Embed
179
+ inputs_embeds = self.word_embeddings(input_ids.astype("i4"))
180
+ position_embeds = self.position_embeddings(position_ids.astype("i4"))
181
+ token_type_embeddings = self.token_type_embeddings(token_type_ids.astype("i4"))
182
+
183
+ # Sum all embeddings
184
+ hidden_states = inputs_embeds + token_type_embeddings + position_embeds
185
+
186
+ # Layer Norm
187
+ hidden_states = self.LayerNorm(hidden_states)
188
+ hidden_states = self.dropout(hidden_states, deterministic=deterministic)
189
+ return hidden_states
190
+
191
+
192
+ class FlaxAlbertSelfAttention(nn.Module):
193
+ config: AlbertConfig
194
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
195
+
196
+ def setup(self):
197
+ if self.config.hidden_size % self.config.num_attention_heads != 0:
198
+ raise ValueError(
199
+ f"`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads` "
200
+ f" : {self.config.num_attention_heads}"
201
+ )
202
+
203
+ self.query = nn.Dense(
204
+ self.config.hidden_size,
205
+ dtype=self.dtype,
206
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
207
+ )
208
+ self.key = nn.Dense(
209
+ self.config.hidden_size,
210
+ dtype=self.dtype,
211
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
212
+ )
213
+ self.value = nn.Dense(
214
+ self.config.hidden_size,
215
+ dtype=self.dtype,
216
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
217
+ )
218
+ self.dense = nn.Dense(
219
+ self.config.hidden_size,
220
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
221
+ dtype=self.dtype,
222
+ )
223
+ self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
224
+ self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
225
+
226
+ def __call__(self, hidden_states, attention_mask, deterministic=True, output_attentions: bool = False):
227
+ head_dim = self.config.hidden_size // self.config.num_attention_heads
228
+
229
+ query_states = self.query(hidden_states).reshape(
230
+ hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
231
+ )
232
+ value_states = self.value(hidden_states).reshape(
233
+ hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
234
+ )
235
+ key_states = self.key(hidden_states).reshape(
236
+ hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
237
+ )
238
+
239
+ # Convert the boolean attention mask to an attention bias.
240
+ if attention_mask is not None:
241
+ # attention mask in the form of attention bias
242
+ attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
243
+ attention_bias = lax.select(
244
+ attention_mask > 0,
245
+ jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
246
+ jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
247
+ )
248
+ else:
249
+ attention_bias = None
250
+
251
+ dropout_rng = None
252
+ if not deterministic and self.config.attention_probs_dropout_prob > 0.0:
253
+ dropout_rng = self.make_rng("dropout")
254
+
255
+ attn_weights = dot_product_attention_weights(
256
+ query_states,
257
+ key_states,
258
+ bias=attention_bias,
259
+ dropout_rng=dropout_rng,
260
+ dropout_rate=self.config.attention_probs_dropout_prob,
261
+ broadcast_dropout=True,
262
+ deterministic=deterministic,
263
+ dtype=self.dtype,
264
+ precision=None,
265
+ )
266
+
267
+ attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
268
+ attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,))
269
+
270
+ projected_attn_output = self.dense(attn_output)
271
+ projected_attn_output = self.dropout(projected_attn_output, deterministic=deterministic)
272
+ layernormed_attn_output = self.LayerNorm(projected_attn_output + hidden_states)
273
+ outputs = (layernormed_attn_output, attn_weights) if output_attentions else (layernormed_attn_output,)
274
+ return outputs
275
+
276
+
277
+ class FlaxAlbertLayer(nn.Module):
278
+ config: AlbertConfig
279
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
280
+
281
+ def setup(self):
282
+ self.attention = FlaxAlbertSelfAttention(self.config, dtype=self.dtype)
283
+ self.ffn = nn.Dense(
284
+ self.config.intermediate_size,
285
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
286
+ dtype=self.dtype,
287
+ )
288
+ self.activation = ACT2FN[self.config.hidden_act]
289
+ self.ffn_output = nn.Dense(
290
+ self.config.hidden_size,
291
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
292
+ dtype=self.dtype,
293
+ )
294
+ self.full_layer_layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
295
+ self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
296
+
297
+ def __call__(
298
+ self,
299
+ hidden_states,
300
+ attention_mask,
301
+ deterministic: bool = True,
302
+ output_attentions: bool = False,
303
+ ):
304
+ attention_outputs = self.attention(
305
+ hidden_states, attention_mask, deterministic=deterministic, output_attentions=output_attentions
306
+ )
307
+ attention_output = attention_outputs[0]
308
+ ffn_output = self.ffn(attention_output)
309
+ ffn_output = self.activation(ffn_output)
310
+ ffn_output = self.ffn_output(ffn_output)
311
+ ffn_output = self.dropout(ffn_output, deterministic=deterministic)
312
+ hidden_states = self.full_layer_layer_norm(ffn_output + attention_output)
313
+
314
+ outputs = (hidden_states,)
315
+
316
+ if output_attentions:
317
+ outputs += (attention_outputs[1],)
318
+ return outputs
319
+
320
+
321
+ class FlaxAlbertLayerCollection(nn.Module):
322
+ config: AlbertConfig
323
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
324
+
325
+ def setup(self):
326
+ self.layers = [
327
+ FlaxAlbertLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.inner_group_num)
328
+ ]
329
+
330
+ def __call__(
331
+ self,
332
+ hidden_states,
333
+ attention_mask,
334
+ deterministic: bool = True,
335
+ output_attentions: bool = False,
336
+ output_hidden_states: bool = False,
337
+ ):
338
+ layer_hidden_states = ()
339
+ layer_attentions = ()
340
+
341
+ for layer_index, albert_layer in enumerate(self.layers):
342
+ layer_output = albert_layer(
343
+ hidden_states,
344
+ attention_mask,
345
+ deterministic=deterministic,
346
+ output_attentions=output_attentions,
347
+ )
348
+ hidden_states = layer_output[0]
349
+
350
+ if output_attentions:
351
+ layer_attentions = layer_attentions + (layer_output[1],)
352
+
353
+ if output_hidden_states:
354
+ layer_hidden_states = layer_hidden_states + (hidden_states,)
355
+
356
+ outputs = (hidden_states,)
357
+ if output_hidden_states:
358
+ outputs = outputs + (layer_hidden_states,)
359
+ if output_attentions:
360
+ outputs = outputs + (layer_attentions,)
361
+ return outputs # last-layer hidden state, (layer hidden states), (layer attentions)
362
+
363
+
364
+ class FlaxAlbertLayerCollections(nn.Module):
365
+ config: AlbertConfig
366
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
367
+ layer_index: Optional[str] = None
368
+
369
+ def setup(self):
370
+ self.albert_layers = FlaxAlbertLayerCollection(self.config, dtype=self.dtype)
371
+
372
+ def __call__(
373
+ self,
374
+ hidden_states,
375
+ attention_mask,
376
+ deterministic: bool = True,
377
+ output_attentions: bool = False,
378
+ output_hidden_states: bool = False,
379
+ ):
380
+ outputs = self.albert_layers(
381
+ hidden_states,
382
+ attention_mask,
383
+ deterministic=deterministic,
384
+ output_attentions=output_attentions,
385
+ output_hidden_states=output_hidden_states,
386
+ )
387
+ return outputs
388
+
389
+
390
+ class FlaxAlbertLayerGroups(nn.Module):
391
+ config: AlbertConfig
392
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
393
+
394
+ def setup(self):
395
+ self.layers = [
396
+ FlaxAlbertLayerCollections(self.config, name=str(i), layer_index=str(i), dtype=self.dtype)
397
+ for i in range(self.config.num_hidden_groups)
398
+ ]
399
+
400
+ def __call__(
401
+ self,
402
+ hidden_states,
403
+ attention_mask,
404
+ deterministic: bool = True,
405
+ output_attentions: bool = False,
406
+ output_hidden_states: bool = False,
407
+ return_dict: bool = True,
408
+ ):
409
+ all_attentions = () if output_attentions else None
410
+ all_hidden_states = (hidden_states,) if output_hidden_states else None
411
+
412
+ for i in range(self.config.num_hidden_layers):
413
+ # Index of the hidden group
414
+ group_idx = int(i / (self.config.num_hidden_layers / self.config.num_hidden_groups))
415
+ layer_group_output = self.layers[group_idx](
416
+ hidden_states,
417
+ attention_mask,
418
+ deterministic=deterministic,
419
+ output_attentions=output_attentions,
420
+ output_hidden_states=output_hidden_states,
421
+ )
422
+ hidden_states = layer_group_output[0]
423
+
424
+ if output_attentions:
425
+ all_attentions = all_attentions + layer_group_output[-1]
426
+
427
+ if output_hidden_states:
428
+ all_hidden_states = all_hidden_states + (hidden_states,)
429
+
430
+ if not return_dict:
431
+ return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
432
+ return FlaxBaseModelOutput(
433
+ last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
434
+ )
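The loop in `FlaxAlbertLayerGroups` above maps each of the `num_hidden_layers` iterations onto one of the `num_hidden_groups` shared parameter groups via `group_idx = int(i / (num_hidden_layers / num_hidden_groups))`, which is how ALBERT reuses the same layer parameters many times. A small worked example of that mapping (values chosen for illustration; the ALBERT default is a single group):

```python
# With 12 layers and 1 group (the ALBERT default), every iteration reuses group 0.
# With 12 layers and 3 groups, layers 0-3 use group 0, 4-7 use group 1, 8-11 use group 2.
num_hidden_layers, num_hidden_groups = 12, 3
mapping = [int(i / (num_hidden_layers / num_hidden_groups)) for i in range(num_hidden_layers)]
print(mapping)  # [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]
```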
435
+
436
+
437
+ class FlaxAlbertEncoder(nn.Module):
438
+ config: AlbertConfig
439
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
440
+
441
+ def setup(self):
442
+ self.embedding_hidden_mapping_in = nn.Dense(
443
+ self.config.hidden_size,
444
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
445
+ dtype=self.dtype,
446
+ )
447
+ self.albert_layer_groups = FlaxAlbertLayerGroups(self.config, dtype=self.dtype)
448
+
449
+ def __call__(
450
+ self,
451
+ hidden_states,
452
+ attention_mask,
453
+ deterministic: bool = True,
454
+ output_attentions: bool = False,
455
+ output_hidden_states: bool = False,
456
+ return_dict: bool = True,
457
+ ):
458
+ hidden_states = self.embedding_hidden_mapping_in(hidden_states)
459
+ return self.albert_layer_groups(
460
+ hidden_states,
461
+ attention_mask,
462
+ deterministic=deterministic,
463
+ output_attentions=output_attentions,
464
+ output_hidden_states=output_hidden_states,
465
+ )
466
+
467
+
468
+ class FlaxAlbertOnlyMLMHead(nn.Module):
469
+ config: AlbertConfig
470
+ dtype: jnp.dtype = jnp.float32
471
+ bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros
472
+
473
+ def setup(self):
474
+ self.dense = nn.Dense(self.config.embedding_size, dtype=self.dtype)
475
+ self.activation = ACT2FN[self.config.hidden_act]
476
+ self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
477
+ self.decoder = nn.Dense(self.config.vocab_size, dtype=self.dtype, use_bias=False)
478
+ self.bias = self.param("bias", self.bias_init, (self.config.vocab_size,))
479
+
480
+ def __call__(self, hidden_states, shared_embedding=None):
481
+ hidden_states = self.dense(hidden_states)
482
+ hidden_states = self.activation(hidden_states)
483
+ hidden_states = self.LayerNorm(hidden_states)
484
+
485
+ if shared_embedding is not None:
486
+ hidden_states = self.decoder.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
487
+ else:
488
+ hidden_states = self.decoder(hidden_states)
489
+
490
+ hidden_states += self.bias
491
+ return hidden_states
492
+
493
+
494
+ class FlaxAlbertSOPHead(nn.Module):
495
+ config: AlbertConfig
496
+ dtype: jnp.dtype = jnp.float32
497
+
498
+ def setup(self):
499
+ self.dropout = nn.Dropout(self.config.classifier_dropout_prob)
500
+ self.classifier = nn.Dense(2, dtype=self.dtype)
501
+
502
+ def __call__(self, pooled_output, deterministic=True):
503
+ pooled_output = self.dropout(pooled_output, deterministic=deterministic)
504
+ logits = self.classifier(pooled_output)
505
+ return logits
506
+
507
+
508
+ class FlaxAlbertPreTrainedModel(FlaxPreTrainedModel):
509
+ """
510
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
511
+ models.
512
+ """
513
+
514
+ config_class = AlbertConfig
515
+ base_model_prefix = "albert"
516
+ module_class: nn.Module = None
517
+
518
+ def __init__(
519
+ self,
520
+ config: AlbertConfig,
521
+ input_shape: tuple = (1, 1),
522
+ seed: int = 0,
523
+ dtype: jnp.dtype = jnp.float32,
524
+ _do_init: bool = True,
525
+ **kwargs,
526
+ ):
527
+ module = self.module_class(config=config, dtype=dtype, **kwargs)
528
+ super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
529
+
530
+ def init_weights(self, rng: jax.random.PRNGKey, input_shape: tuple, params: FrozenDict = None) -> FrozenDict:
531
+ # init input tensors
532
+ input_ids = jnp.zeros(input_shape, dtype="i4")
533
+ token_type_ids = jnp.zeros_like(input_ids)
534
+ position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)
535
+ attention_mask = jnp.ones_like(input_ids)
536
+
537
+ params_rng, dropout_rng = jax.random.split(rng)
538
+ rngs = {"params": params_rng, "dropout": dropout_rng}
539
+
540
+ random_params = self.module.init(
541
+ rngs, input_ids, attention_mask, token_type_ids, position_ids, return_dict=False
542
+ )["params"]
543
+
544
+ if params is not None:
545
+ random_params = flatten_dict(unfreeze(random_params))
546
+ params = flatten_dict(unfreeze(params))
547
+ for missing_key in self._missing_keys:
548
+ params[missing_key] = random_params[missing_key]
549
+ self._missing_keys = set()
550
+ return freeze(unflatten_dict(params))
551
+ else:
552
+ return random_params
553
+
554
+ @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
555
+ def __call__(
556
+ self,
557
+ input_ids,
558
+ attention_mask=None,
559
+ token_type_ids=None,
560
+ position_ids=None,
561
+ params: Optional[dict] = None,
562
+ dropout_rng: jax.random.PRNGKey = None,
563
+ train: bool = False,
564
+ output_attentions: Optional[bool] = None,
565
+ output_hidden_states: Optional[bool] = None,
566
+ return_dict: Optional[bool] = None,
567
+ ):
568
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
569
+ output_hidden_states = (
570
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
571
+ )
572
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
573
+
574
+ # init input tensors if not passed
575
+ if token_type_ids is None:
576
+ token_type_ids = jnp.zeros_like(input_ids)
577
+
578
+ if position_ids is None:
579
+ position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
580
+
581
+ if attention_mask is None:
582
+ attention_mask = jnp.ones_like(input_ids)
583
+
584
+ # Handle any PRNG if needed
585
+ rngs = {}
586
+ if dropout_rng is not None:
587
+ rngs["dropout"] = dropout_rng
588
+
589
+ return self.module.apply(
590
+ {"params": params or self.params},
591
+ jnp.array(input_ids, dtype="i4"),
592
+ jnp.array(attention_mask, dtype="i4"),
593
+ jnp.array(token_type_ids, dtype="i4"),
594
+ jnp.array(position_ids, dtype="i4"),
595
+ not train,
596
+ output_attentions,
597
+ output_hidden_states,
598
+ return_dict,
599
+ rngs=rngs,
600
+ )
601
+
602
+
603
+ class FlaxAlbertModule(nn.Module):
604
+ config: AlbertConfig
605
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
606
+ add_pooling_layer: bool = True
607
+
608
+ def setup(self):
609
+ self.embeddings = FlaxAlbertEmbeddings(self.config, dtype=self.dtype)
610
+ self.encoder = FlaxAlbertEncoder(self.config, dtype=self.dtype)
611
+ if self.add_pooling_layer:
612
+ self.pooler = nn.Dense(
613
+ self.config.hidden_size,
614
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
615
+ dtype=self.dtype,
616
+ name="pooler",
617
+ )
618
+ self.pooler_activation = nn.tanh
619
+ else:
620
+ self.pooler = None
621
+ self.pooler_activation = None
622
+
623
+ def __call__(
624
+ self,
625
+ input_ids,
626
+ attention_mask,
627
+ token_type_ids: Optional[np.ndarray] = None,
628
+ position_ids: Optional[np.ndarray] = None,
629
+ deterministic: bool = True,
630
+ output_attentions: bool = False,
631
+ output_hidden_states: bool = False,
632
+ return_dict: bool = True,
633
+ ):
634
+ # make sure `token_type_ids` is correctly initialized when not passed
635
+ if token_type_ids is None:
636
+ token_type_ids = jnp.zeros_like(input_ids)
637
+
638
+ # make sure `position_ids` is correctly initialized when not passed
639
+ if position_ids is None:
640
+ position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
641
+
642
+ hidden_states = self.embeddings(input_ids, token_type_ids, position_ids, deterministic=deterministic)
643
+
644
+ outputs = self.encoder(
645
+ hidden_states,
646
+ attention_mask,
647
+ deterministic=deterministic,
648
+ output_attentions=output_attentions,
649
+ output_hidden_states=output_hidden_states,
650
+ return_dict=return_dict,
651
+ )
652
+ hidden_states = outputs[0]
653
+ if self.add_pooling_layer:
654
+ pooled = self.pooler(hidden_states[:, 0])
655
+ pooled = self.pooler_activation(pooled)
656
+ else:
657
+ pooled = None
658
+
659
+ if not return_dict:
660
+ # if pooled is None, don't return it
661
+ if pooled is None:
662
+ return (hidden_states,) + outputs[1:]
663
+ return (hidden_states, pooled) + outputs[1:]
664
+
665
+ return FlaxBaseModelOutputWithPooling(
666
+ last_hidden_state=hidden_states,
667
+ pooler_output=pooled,
668
+ hidden_states=outputs.hidden_states,
669
+ attentions=outputs.attentions,
670
+ )
671
+
672
+
673
+ @add_start_docstrings(
674
+ "The bare Albert Model transformer outputting raw hidden-states without any specific head on top.",
675
+ ALBERT_START_DOCSTRING,
676
+ )
677
+ class FlaxAlbertModel(FlaxAlbertPreTrainedModel):
678
+ module_class = FlaxAlbertModule
679
+
680
+
681
+ append_call_sample_docstring(FlaxAlbertModel, _CHECKPOINT_FOR_DOC, FlaxBaseModelOutputWithPooling, _CONFIG_FOR_DOC)
682
+
683
+
684
+ class FlaxAlbertForPreTrainingModule(nn.Module):
685
+ config: AlbertConfig
686
+ dtype: jnp.dtype = jnp.float32
687
+
688
+ def setup(self):
689
+ self.albert = FlaxAlbertModule(config=self.config, dtype=self.dtype)
690
+ self.predictions = FlaxAlbertOnlyMLMHead(config=self.config, dtype=self.dtype)
691
+ self.sop_classifier = FlaxAlbertSOPHead(config=self.config, dtype=self.dtype)
692
+
693
+ def __call__(
694
+ self,
695
+ input_ids,
696
+ attention_mask,
697
+ token_type_ids,
698
+ position_ids,
699
+ deterministic: bool = True,
700
+ output_attentions: bool = False,
701
+ output_hidden_states: bool = False,
702
+ return_dict: bool = True,
703
+ ):
704
+ # Model
705
+ outputs = self.albert(
706
+ input_ids,
707
+ attention_mask,
708
+ token_type_ids,
709
+ position_ids,
710
+ deterministic=deterministic,
711
+ output_attentions=output_attentions,
712
+ output_hidden_states=output_hidden_states,
713
+ return_dict=return_dict,
714
+ )
715
+
716
+ if self.config.tie_word_embeddings:
717
+ shared_embedding = self.albert.variables["params"]["embeddings"]["word_embeddings"]["embedding"]
718
+ else:
719
+ shared_embedding = None
720
+
721
+ hidden_states = outputs[0]
722
+ pooled_output = outputs[1]
723
+
724
+ prediction_scores = self.predictions(hidden_states, shared_embedding=shared_embedding)
725
+ sop_scores = self.sop_classifier(pooled_output, deterministic=deterministic)
726
+
727
+ if not return_dict:
728
+ return (prediction_scores, sop_scores) + outputs[2:]
729
+
730
+ return FlaxAlbertForPreTrainingOutput(
731
+ prediction_logits=prediction_scores,
732
+ sop_logits=sop_scores,
733
+ hidden_states=outputs.hidden_states,
734
+ attentions=outputs.attentions,
735
+ )
736
+
737
+
738
+ @add_start_docstrings(
739
+ """
740
+ Albert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a
741
+ `sentence order prediction (classification)` head.
742
+ """,
743
+ ALBERT_START_DOCSTRING,
744
+ )
745
+ class FlaxAlbertForPreTraining(FlaxAlbertPreTrainedModel):
746
+ module_class = FlaxAlbertForPreTrainingModule
747
+
748
+
749
+ FLAX_ALBERT_FOR_PRETRAINING_DOCSTRING = """
750
+ Returns:
751
+
752
+ Example:
753
+
754
+ ```python
755
+ >>> from transformers import AutoTokenizer, FlaxAlbertForPreTraining
756
+
757
+ >>> tokenizer = AutoTokenizer.from_pretrained("albert/albert-base-v2")
758
+ >>> model = FlaxAlbertForPreTraining.from_pretrained("albert/albert-base-v2")
759
+
760
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="np")
761
+ >>> outputs = model(**inputs)
762
+
763
+ >>> prediction_logits = outputs.prediction_logits
764
+ >>> seq_relationship_logits = outputs.sop_logits
765
+ ```
766
+ """
767
+
768
+ overwrite_call_docstring(
769
+ FlaxAlbertForPreTraining,
770
+ ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length") + FLAX_ALBERT_FOR_PRETRAINING_DOCSTRING,
771
+ )
772
+ append_replace_return_docstrings(
773
+ FlaxAlbertForPreTraining, output_type=FlaxAlbertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC
774
+ )
775
+
776
+
777
+ class FlaxAlbertForMaskedLMModule(nn.Module):
778
+ config: AlbertConfig
779
+ dtype: jnp.dtype = jnp.float32
780
+
781
+ def setup(self):
782
+ self.albert = FlaxAlbertModule(config=self.config, add_pooling_layer=False, dtype=self.dtype)
783
+ self.predictions = FlaxAlbertOnlyMLMHead(config=self.config, dtype=self.dtype)
784
+
785
+ def __call__(
786
+ self,
787
+ input_ids,
788
+ attention_mask,
789
+ token_type_ids,
790
+ position_ids,
791
+ deterministic: bool = True,
792
+ output_attentions: bool = False,
793
+ output_hidden_states: bool = False,
794
+ return_dict: bool = True,
795
+ ):
796
+ # Model
797
+ outputs = self.albert(
798
+ input_ids,
799
+ attention_mask,
800
+ token_type_ids,
801
+ position_ids,
802
+ deterministic=deterministic,
803
+ output_attentions=output_attentions,
804
+ output_hidden_states=output_hidden_states,
805
+ return_dict=return_dict,
806
+ )
807
+
808
+ hidden_states = outputs[0]
809
+ if self.config.tie_word_embeddings:
810
+ shared_embedding = self.albert.variables["params"]["embeddings"]["word_embeddings"]["embedding"]
811
+ else:
812
+ shared_embedding = None
813
+
814
+ # Compute the prediction scores
815
+ logits = self.predictions(hidden_states, shared_embedding=shared_embedding)
816
+
817
+ if not return_dict:
818
+ return (logits,) + outputs[1:]
819
+
820
+ return FlaxMaskedLMOutput(
821
+ logits=logits,
822
+ hidden_states=outputs.hidden_states,
823
+ attentions=outputs.attentions,
824
+ )
825
+
826
+
827
+ @add_start_docstrings("""Albert Model with a `language modeling` head on top.""", ALBERT_START_DOCSTRING)
828
+ class FlaxAlbertForMaskedLM(FlaxAlbertPreTrainedModel):
829
+ module_class = FlaxAlbertForMaskedLMModule
830
+
831
+
832
+ append_call_sample_docstring(
833
+ FlaxAlbertForMaskedLM, _CHECKPOINT_FOR_DOC, FlaxMaskedLMOutput, _CONFIG_FOR_DOC, revision="refs/pr/11"
834
+ )
835
+
836
+
837
+ class FlaxAlbertForSequenceClassificationModule(nn.Module):
838
+ config: AlbertConfig
839
+ dtype: jnp.dtype = jnp.float32
840
+
841
+ def setup(self):
842
+ self.albert = FlaxAlbertModule(config=self.config, dtype=self.dtype)
843
+ classifier_dropout = (
844
+ self.config.classifier_dropout_prob
845
+ if self.config.classifier_dropout_prob is not None
846
+ else self.config.hidden_dropout_prob
847
+ )
848
+ self.dropout = nn.Dropout(rate=classifier_dropout)
849
+ self.classifier = nn.Dense(
850
+ self.config.num_labels,
851
+ dtype=self.dtype,
852
+ )
853
+
854
+ def __call__(
855
+ self,
856
+ input_ids,
857
+ attention_mask,
858
+ token_type_ids,
859
+ position_ids,
860
+ deterministic: bool = True,
861
+ output_attentions: bool = False,
862
+ output_hidden_states: bool = False,
863
+ return_dict: bool = True,
864
+ ):
865
+ # Model
866
+ outputs = self.albert(
867
+ input_ids,
868
+ attention_mask,
869
+ token_type_ids,
870
+ position_ids,
871
+ deterministic=deterministic,
872
+ output_attentions=output_attentions,
873
+ output_hidden_states=output_hidden_states,
874
+ return_dict=return_dict,
875
+ )
876
+
877
+ pooled_output = outputs[1]
878
+ pooled_output = self.dropout(pooled_output, deterministic=deterministic)
879
+ logits = self.classifier(pooled_output)
880
+
881
+ if not return_dict:
882
+ return (logits,) + outputs[2:]
883
+
884
+ return FlaxSequenceClassifierOutput(
885
+ logits=logits,
886
+ hidden_states=outputs.hidden_states,
887
+ attentions=outputs.attentions,
888
+ )
889
+
890
+
891
+ @add_start_docstrings(
892
+ """
893
+ Albert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
894
+ output) e.g. for GLUE tasks.
895
+ """,
896
+ ALBERT_START_DOCSTRING,
897
+ )
898
+ class FlaxAlbertForSequenceClassification(FlaxAlbertPreTrainedModel):
899
+ module_class = FlaxAlbertForSequenceClassificationModule
900
+
901
+
902
+ append_call_sample_docstring(
903
+ FlaxAlbertForSequenceClassification,
904
+ _CHECKPOINT_FOR_DOC,
905
+ FlaxSequenceClassifierOutput,
906
+ _CONFIG_FOR_DOC,
907
+ )
908
+
909
+
910
+ class FlaxAlbertForMultipleChoiceModule(nn.Module):
911
+ config: AlbertConfig
912
+ dtype: jnp.dtype = jnp.float32
913
+
914
+ def setup(self):
915
+ self.albert = FlaxAlbertModule(config=self.config, dtype=self.dtype)
916
+ self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
917
+ self.classifier = nn.Dense(1, dtype=self.dtype)
918
+
919
+ def __call__(
920
+ self,
921
+ input_ids,
922
+ attention_mask,
923
+ token_type_ids,
924
+ position_ids,
925
+ deterministic: bool = True,
926
+ output_attentions: bool = False,
927
+ output_hidden_states: bool = False,
928
+ return_dict: bool = True,
929
+ ):
930
+ num_choices = input_ids.shape[1]
931
+ input_ids = input_ids.reshape(-1, input_ids.shape[-1]) if input_ids is not None else None
932
+ attention_mask = attention_mask.reshape(-1, attention_mask.shape[-1]) if attention_mask is not None else None
933
+ token_type_ids = token_type_ids.reshape(-1, token_type_ids.shape[-1]) if token_type_ids is not None else None
934
+ position_ids = position_ids.reshape(-1, position_ids.shape[-1]) if position_ids is not None else None
935
+
936
+ # Model
937
+ outputs = self.albert(
938
+ input_ids,
939
+ attention_mask,
940
+ token_type_ids,
941
+ position_ids,
942
+ deterministic=deterministic,
943
+ output_attentions=output_attentions,
944
+ output_hidden_states=output_hidden_states,
945
+ return_dict=return_dict,
946
+ )
947
+
948
+ pooled_output = outputs[1]
949
+ pooled_output = self.dropout(pooled_output, deterministic=deterministic)
950
+ logits = self.classifier(pooled_output)
951
+
952
+ reshaped_logits = logits.reshape(-1, num_choices)
953
+
954
+ if not return_dict:
955
+ return (reshaped_logits,) + outputs[2:]
956
+
957
+ return FlaxMultipleChoiceModelOutput(
958
+ logits=reshaped_logits,
959
+ hidden_states=outputs.hidden_states,
960
+ attentions=outputs.attentions,
961
+ )
962
+
963
+
964
+ @add_start_docstrings(
965
+ """
966
+ Albert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
967
+ softmax) e.g. for RocStories/SWAG tasks.
968
+ """,
969
+ ALBERT_START_DOCSTRING,
970
+ )
971
+ class FlaxAlbertForMultipleChoice(FlaxAlbertPreTrainedModel):
972
+ module_class = FlaxAlbertForMultipleChoiceModule
973
+
974
+
975
+ overwrite_call_docstring(
976
+ FlaxAlbertForMultipleChoice, ALBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
977
+ )
978
+ append_call_sample_docstring(
979
+ FlaxAlbertForMultipleChoice,
980
+ _CHECKPOINT_FOR_DOC,
981
+ FlaxMultipleChoiceModelOutput,
982
+ _CONFIG_FOR_DOC,
983
+ )
984
+
985
+
986
+ class FlaxAlbertForTokenClassificationModule(nn.Module):
987
+ config: AlbertConfig
988
+ dtype: jnp.dtype = jnp.float32
989
+
990
+ def setup(self):
991
+ self.albert = FlaxAlbertModule(config=self.config, dtype=self.dtype, add_pooling_layer=False)
992
+ classifier_dropout = (
993
+ self.config.classifier_dropout_prob
994
+ if self.config.classifier_dropout_prob is not None
995
+ else self.config.hidden_dropout_prob
996
+ )
997
+ self.dropout = nn.Dropout(rate=classifier_dropout)
998
+ self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype)
999
+
1000
+ def __call__(
1001
+ self,
1002
+ input_ids,
1003
+ attention_mask,
1004
+ token_type_ids,
1005
+ position_ids,
1006
+ deterministic: bool = True,
1007
+ output_attentions: bool = False,
1008
+ output_hidden_states: bool = False,
1009
+ return_dict: bool = True,
1010
+ ):
1011
+ # Model
1012
+ outputs = self.albert(
1013
+ input_ids,
1014
+ attention_mask,
1015
+ token_type_ids,
1016
+ position_ids,
1017
+ deterministic=deterministic,
1018
+ output_attentions=output_attentions,
1019
+ output_hidden_states=output_hidden_states,
1020
+ return_dict=return_dict,
1021
+ )
1022
+
1023
+ hidden_states = outputs[0]
1024
+ hidden_states = self.dropout(hidden_states, deterministic=deterministic)
1025
+ logits = self.classifier(hidden_states)
1026
+
1027
+ if not return_dict:
1028
+ return (logits,) + outputs[1:]
1029
+
1030
+ return FlaxTokenClassifierOutput(
1031
+ logits=logits,
1032
+ hidden_states=outputs.hidden_states,
1033
+ attentions=outputs.attentions,
1034
+ )
1035
+
1036
+
1037
+ @add_start_docstrings(
1038
+ """
1039
+ Albert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
1040
+ Named-Entity-Recognition (NER) tasks.
1041
+ """,
1042
+ ALBERT_START_DOCSTRING,
1043
+ )
1044
+ class FlaxAlbertForTokenClassification(FlaxAlbertPreTrainedModel):
1045
+ module_class = FlaxAlbertForTokenClassificationModule
1046
+
1047
+
1048
+ append_call_sample_docstring(
1049
+ FlaxAlbertForTokenClassification,
1050
+ _CHECKPOINT_FOR_DOC,
1051
+ FlaxTokenClassifierOutput,
1052
+ _CONFIG_FOR_DOC,
1053
+ )
1054
+
1055
+
1056
+ class FlaxAlbertForQuestionAnsweringModule(nn.Module):
1057
+ config: AlbertConfig
1058
+ dtype: jnp.dtype = jnp.float32
1059
+
1060
+ def setup(self):
1061
+ self.albert = FlaxAlbertModule(config=self.config, dtype=self.dtype, add_pooling_layer=False)
1062
+ self.qa_outputs = nn.Dense(self.config.num_labels, dtype=self.dtype)
1063
+
1064
+ def __call__(
1065
+ self,
1066
+ input_ids,
1067
+ attention_mask,
1068
+ token_type_ids,
1069
+ position_ids,
1070
+ deterministic: bool = True,
1071
+ output_attentions: bool = False,
1072
+ output_hidden_states: bool = False,
1073
+ return_dict: bool = True,
1074
+ ):
1075
+ # Model
1076
+ outputs = self.albert(
1077
+ input_ids,
1078
+ attention_mask,
1079
+ token_type_ids,
1080
+ position_ids,
1081
+ deterministic=deterministic,
1082
+ output_attentions=output_attentions,
1083
+ output_hidden_states=output_hidden_states,
1084
+ return_dict=return_dict,
1085
+ )
1086
+
1087
+ hidden_states = outputs[0]
1088
+
1089
+ logits = self.qa_outputs(hidden_states)
1090
+ start_logits, end_logits = jnp.split(logits, self.config.num_labels, axis=-1)
1091
+ start_logits = start_logits.squeeze(-1)
1092
+ end_logits = end_logits.squeeze(-1)
1093
+
1094
+ if not return_dict:
1095
+ return (start_logits, end_logits) + outputs[1:]
1096
+
1097
+ return FlaxQuestionAnsweringModelOutput(
1098
+ start_logits=start_logits,
1099
+ end_logits=end_logits,
1100
+ hidden_states=outputs.hidden_states,
1101
+ attentions=outputs.attentions,
1102
+ )
1103
+
1104
+
1105
+ @add_start_docstrings(
1106
+ """
1107
+ Albert Model with a span classification head on top for extractive question-answering tasks like SQuAD (linear
1108
+ layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
1109
+ """,
1110
+ ALBERT_START_DOCSTRING,
1111
+ )
1112
+ class FlaxAlbertForQuestionAnswering(FlaxAlbertPreTrainedModel):
1113
+ module_class = FlaxAlbertForQuestionAnsweringModule
1114
+
1115
+
1116
+ append_call_sample_docstring(
1117
+ FlaxAlbertForQuestionAnswering,
1118
+ _CHECKPOINT_FOR_DOC,
1119
+ FlaxQuestionAnsweringModelOutput,
1120
+ _CONFIG_FOR_DOC,
1121
+ )
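As a hedged usage sketch (hand-made logits, no model call) of what the question-answering head produces: one start score and one end score per token, which are typically decoded by taking the argmax of each.

import jax.numpy as jnp

start_logits = jnp.array([[0.1, 2.3, 0.4, -1.0]])    # shape (batch_size, seq_length)
end_logits = jnp.array([[-0.5, 0.2, 3.1, 0.0]])

start_index = int(jnp.argmax(start_logits, axis=-1)[0])   # 1
end_index = int(jnp.argmax(end_logits, axis=-1)[0])       # 2
predicted_span = (start_index, end_index)                  # tokens 1..2 form the answer span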
1122
+
1123
+ __all__ = [
1124
+ "FlaxAlbertPreTrainedModel",
1125
+ "FlaxAlbertModel",
1126
+ "FlaxAlbertForPreTraining",
1127
+ "FlaxAlbertForMaskedLM",
1128
+ "FlaxAlbertForSequenceClassification",
1129
+ "FlaxAlbertForMultipleChoice",
1130
+ "FlaxAlbertForTokenClassification",
1131
+ "FlaxAlbertForQuestionAnswering",
1132
+ ]
venv/lib/python3.13/site-packages/transformers/models/albert/modeling_tf_albert.py ADDED
@@ -0,0 +1,1572 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """TF 2.0 ALBERT model."""
17
+
18
+ from __future__ import annotations
19
+
20
+ import math
21
+ from dataclasses import dataclass
22
+
23
+ import numpy as np
24
+ import tensorflow as tf
25
+
26
+ from ...activations_tf import get_tf_activation
27
+ from ...modeling_tf_outputs import (
28
+ TFBaseModelOutput,
29
+ TFBaseModelOutputWithPooling,
30
+ TFMaskedLMOutput,
31
+ TFMultipleChoiceModelOutput,
32
+ TFQuestionAnsweringModelOutput,
33
+ TFSequenceClassifierOutput,
34
+ TFTokenClassifierOutput,
35
+ )
36
+ from ...modeling_tf_utils import (
37
+ TFMaskedLanguageModelingLoss,
38
+ TFModelInputType,
39
+ TFMultipleChoiceLoss,
40
+ TFPreTrainedModel,
41
+ TFQuestionAnsweringLoss,
42
+ TFSequenceClassificationLoss,
43
+ TFTokenClassificationLoss,
44
+ get_initializer,
45
+ keras,
46
+ keras_serializable,
47
+ unpack_inputs,
48
+ )
49
+ from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax
50
+ from ...utils import (
51
+ ModelOutput,
52
+ add_code_sample_docstrings,
53
+ add_start_docstrings,
54
+ add_start_docstrings_to_model_forward,
55
+ logging,
56
+ replace_return_docstrings,
57
+ )
58
+ from .configuration_albert import AlbertConfig
59
+
60
+
61
+ logger = logging.get_logger(__name__)
62
+
63
+ _CHECKPOINT_FOR_DOC = "albert/albert-base-v2"
64
+ _CONFIG_FOR_DOC = "AlbertConfig"
65
+
66
+
67
+ class TFAlbertPreTrainingLoss:
68
+ """
69
+ Loss function suitable for ALBERT pretraining, that is, the task of pretraining a language model by combining SOP +
70
+ MLM. .. note:: Any label of -100 will be ignored (along with the corresponding logits) in the loss computation.
71
+ """
72
+
73
+ def hf_compute_loss(self, labels: tf.Tensor, logits: tf.Tensor) -> tf.Tensor:
74
+ loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction=keras.losses.Reduction.NONE)
75
+ if self.config.tf_legacy_loss:
76
+ # make sure only labels that are not equal to -100
77
+ # are taken into account as loss
78
+ masked_lm_active_loss = tf.not_equal(tf.reshape(tensor=labels["labels"], shape=(-1,)), -100)
79
+ masked_lm_reduced_logits = tf.boolean_mask(
80
+ tensor=tf.reshape(tensor=logits[0], shape=(-1, shape_list(logits[0])[2])),
81
+ mask=masked_lm_active_loss,
82
+ )
83
+ masked_lm_labels = tf.boolean_mask(
84
+ tensor=tf.reshape(tensor=labels["labels"], shape=(-1,)), mask=masked_lm_active_loss
85
+ )
86
+ sentence_order_active_loss = tf.not_equal(
87
+ tf.reshape(tensor=labels["sentence_order_label"], shape=(-1,)), -100
88
+ )
89
+ sentence_order_reduced_logits = tf.boolean_mask(
90
+ tensor=tf.reshape(tensor=logits[1], shape=(-1, 2)), mask=sentence_order_active_loss
91
+ )
92
+ sentence_order_label = tf.boolean_mask(
93
+ tensor=tf.reshape(tensor=labels["sentence_order_label"], shape=(-1,)), mask=sentence_order_active_loss
94
+ )
95
+ masked_lm_loss = loss_fn(y_true=masked_lm_labels, y_pred=masked_lm_reduced_logits)
96
+ sentence_order_loss = loss_fn(y_true=sentence_order_label, y_pred=sentence_order_reduced_logits)
97
+ masked_lm_loss = tf.reshape(tensor=masked_lm_loss, shape=(-1, shape_list(sentence_order_loss)[0]))
98
+ masked_lm_loss = tf.reduce_mean(input_tensor=masked_lm_loss, axis=0)
99
+
100
+ return masked_lm_loss + sentence_order_loss
101
+
102
+ # Clip negative labels to zero here to avoid NaNs and errors - those positions will get masked later anyway
103
+ unmasked_lm_losses = loss_fn(y_true=tf.nn.relu(labels["labels"]), y_pred=logits[0])
104
+ # make sure only labels that are not equal to -100
105
+ # are taken into account for the loss computation
106
+ lm_loss_mask = tf.cast(labels["labels"] != -100, dtype=unmasked_lm_losses.dtype)
107
+ masked_lm_losses = unmasked_lm_losses * lm_loss_mask
108
+ reduced_masked_lm_loss = tf.reduce_sum(masked_lm_losses) / tf.reduce_sum(lm_loss_mask)
109
+
110
+ sop_logits = tf.reshape(logits[1], (-1, 2))
111
+ # Clip negative labels to zero here to avoid NaNs and errors - those positions will get masked later anyway
112
+ unmasked_sop_loss = loss_fn(y_true=tf.nn.relu(labels["sentence_order_label"]), y_pred=sop_logits)
113
+ sop_loss_mask = tf.cast(labels["sentence_order_label"] != -100, dtype=unmasked_sop_loss.dtype)
114
+
115
+ masked_sop_loss = unmasked_sop_loss * sop_loss_mask
116
+ reduced_masked_sop_loss = tf.reduce_sum(masked_sop_loss) / tf.reduce_sum(sop_loss_mask)
117
+
118
+ return tf.reshape(reduced_masked_lm_loss + reduced_masked_sop_loss, (1,))
119
+
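The -100 convention used by the loss above can be illustrated with a small, self-contained TensorFlow snippet (made-up labels and per-token losses, not library code): positions labelled -100 get zero weight, and the loss is averaged only over the remaining positions.

import tensorflow as tf

labels = tf.constant([[5, -100, 7, -100]], dtype=tf.int32)
per_token_loss = tf.constant([[0.4, 9.9, 0.6, 9.9]])          # pretend per-token losses
mask = tf.cast(labels != -100, per_token_loss.dtype)           # [[1., 0., 1., 0.]]
mean_loss = tf.reduce_sum(per_token_loss * mask) / tf.reduce_sum(mask)   # (0.4 + 0.6) / 2 = 0.5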
120
+
121
+ class TFAlbertEmbeddings(keras.layers.Layer):
122
+ """Construct the embeddings from word, position and token_type embeddings."""
123
+
124
+ def __init__(self, config: AlbertConfig, **kwargs):
125
+ super().__init__(**kwargs)
126
+
127
+ self.config = config
128
+ self.embedding_size = config.embedding_size
129
+ self.max_position_embeddings = config.max_position_embeddings
130
+ self.initializer_range = config.initializer_range
131
+ self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
132
+ self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
133
+
134
+ def build(self, input_shape=None):
135
+ with tf.name_scope("word_embeddings"):
136
+ self.weight = self.add_weight(
137
+ name="weight",
138
+ shape=[self.config.vocab_size, self.embedding_size],
139
+ initializer=get_initializer(self.initializer_range),
140
+ )
141
+
142
+ with tf.name_scope("token_type_embeddings"):
143
+ self.token_type_embeddings = self.add_weight(
144
+ name="embeddings",
145
+ shape=[self.config.type_vocab_size, self.embedding_size],
146
+ initializer=get_initializer(self.initializer_range),
147
+ )
148
+
149
+ with tf.name_scope("position_embeddings"):
150
+ self.position_embeddings = self.add_weight(
151
+ name="embeddings",
152
+ shape=[self.max_position_embeddings, self.embedding_size],
153
+ initializer=get_initializer(self.initializer_range),
154
+ )
155
+
156
+ if self.built:
157
+ return
158
+ self.built = True
159
+ if getattr(self, "LayerNorm", None) is not None:
160
+ with tf.name_scope(self.LayerNorm.name):
161
+ self.LayerNorm.build([None, None, self.config.embedding_size])
162
+
163
+ # Copied from transformers.models.bert.modeling_tf_bert.TFBertEmbeddings.call
164
+ def call(
165
+ self,
166
+ input_ids: tf.Tensor | None = None,
167
+ position_ids: tf.Tensor | None = None,
168
+ token_type_ids: tf.Tensor | None = None,
169
+ inputs_embeds: tf.Tensor | None = None,
170
+ past_key_values_length=0,
171
+ training: bool = False,
172
+ ) -> tf.Tensor:
173
+ """
174
+ Applies embedding based on inputs tensor.
175
+
176
+ Returns:
177
+ final_embeddings (`tf.Tensor`): output embedding tensor.
178
+ """
179
+ if input_ids is None and inputs_embeds is None:
180
+ raise ValueError("Need to provide either `input_ids` or `inputs_embeds`.")
181
+
182
+ if input_ids is not None:
183
+ check_embeddings_within_bounds(input_ids, self.config.vocab_size)
184
+ inputs_embeds = tf.gather(params=self.weight, indices=input_ids)
185
+
186
+ input_shape = shape_list(inputs_embeds)[:-1]
187
+
188
+ if token_type_ids is None:
189
+ token_type_ids = tf.fill(dims=input_shape, value=0)
190
+
191
+ if position_ids is None:
192
+ position_ids = tf.expand_dims(
193
+ tf.range(start=past_key_values_length, limit=input_shape[1] + past_key_values_length), axis=0
194
+ )
195
+
196
+ position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids)
197
+ token_type_embeds = tf.gather(params=self.token_type_embeddings, indices=token_type_ids)
198
+ final_embeddings = inputs_embeds + position_embeds + token_type_embeds
199
+ final_embeddings = self.LayerNorm(inputs=final_embeddings)
200
+ final_embeddings = self.dropout(inputs=final_embeddings, training=training)
201
+
202
+ return final_embeddings
203
+
204
+
205
+ class TFAlbertAttention(keras.layers.Layer):
206
+ """Contains the complete attention sublayer, including both dropouts and layer norm."""
207
+
208
+ def __init__(self, config: AlbertConfig, **kwargs):
209
+ super().__init__(**kwargs)
210
+
211
+ if config.hidden_size % config.num_attention_heads != 0:
212
+ raise ValueError(
213
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number "
214
+ f"of attention heads ({config.num_attention_heads})"
215
+ )
216
+
217
+ self.num_attention_heads = config.num_attention_heads
218
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
219
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
220
+ self.sqrt_att_head_size = math.sqrt(self.attention_head_size)
221
+ self.output_attentions = config.output_attentions
222
+
223
+ self.query = keras.layers.Dense(
224
+ units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query"
225
+ )
226
+ self.key = keras.layers.Dense(
227
+ units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key"
228
+ )
229
+ self.value = keras.layers.Dense(
230
+ units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value"
231
+ )
232
+ self.dense = keras.layers.Dense(
233
+ units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
234
+ )
235
+ self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
236
+ # Two different dropout probabilities; see https://github.com/google-research/albert/blob/master/modeling.py#L971-L993
237
+ self.attention_dropout = keras.layers.Dropout(rate=config.attention_probs_dropout_prob)
238
+ self.output_dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
239
+ self.config = config
240
+
241
+ def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor:
242
+ # Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size]
243
+ tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size))
244
+
245
+ # Transpose the tensor from [batch_size, seq_length, num_attention_heads, attention_head_size] to [batch_size, num_attention_heads, seq_length, attention_head_size]
246
+ return tf.transpose(tensor, perm=[0, 2, 1, 3])
247
+
248
+ def call(
249
+ self,
250
+ input_tensor: tf.Tensor,
251
+ attention_mask: tf.Tensor,
252
+ head_mask: tf.Tensor,
253
+ output_attentions: bool,
254
+ training: bool = False,
255
+ ) -> tuple[tf.Tensor]:
256
+ batch_size = shape_list(input_tensor)[0]
257
+ mixed_query_layer = self.query(inputs=input_tensor)
258
+ mixed_key_layer = self.key(inputs=input_tensor)
259
+ mixed_value_layer = self.value(inputs=input_tensor)
260
+ query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)
261
+ key_layer = self.transpose_for_scores(mixed_key_layer, batch_size)
262
+ value_layer = self.transpose_for_scores(mixed_value_layer, batch_size)
263
+
264
+ # Take the dot product between "query" and "key" to get the raw attention scores.
265
+ # (batch size, num_heads, seq_len_q, seq_len_k)
266
+ attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
267
+ dk = tf.cast(self.sqrt_att_head_size, dtype=attention_scores.dtype)
268
+ attention_scores = tf.divide(attention_scores, dk)
269
+
270
+ if attention_mask is not None:
271
+ # Apply the attention mask is (precomputed for all layers in TFAlbertModel call() function)
272
+ attention_scores = tf.add(attention_scores, attention_mask)
273
+
274
+ # Normalize the attention scores to probabilities.
275
+ attention_probs = stable_softmax(logits=attention_scores, axis=-1)
276
+
277
+ # This is actually dropping out entire tokens to attend to, which might
278
+ # seem a bit unusual, but is taken from the original Transformer paper.
279
+ attention_probs = self.attention_dropout(inputs=attention_probs, training=training)
280
+
281
+ # Mask heads if we want to
282
+ if head_mask is not None:
283
+ attention_probs = tf.multiply(attention_probs, head_mask)
284
+
285
+ context_layer = tf.matmul(attention_probs, value_layer)
286
+ context_layer = tf.transpose(context_layer, perm=[0, 2, 1, 3])
287
+
288
+ # (batch_size, seq_len_q, all_head_size)
289
+ context_layer = tf.reshape(tensor=context_layer, shape=(batch_size, -1, self.all_head_size))
290
+ self_outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
291
+ hidden_states = self_outputs[0]
292
+ hidden_states = self.dense(inputs=hidden_states)
293
+ hidden_states = self.output_dropout(inputs=hidden_states, training=training)
294
+ attention_output = self.LayerNorm(inputs=hidden_states + input_tensor)
295
+
296
+ # add attentions if we output them
297
+ outputs = (attention_output,) + self_outputs[1:]
298
+
299
+ return outputs
300
+
301
+ def build(self, input_shape=None):
302
+ if self.built:
303
+ return
304
+ self.built = True
305
+ if getattr(self, "query", None) is not None:
306
+ with tf.name_scope(self.query.name):
307
+ self.query.build([None, None, self.config.hidden_size])
308
+ if getattr(self, "key", None) is not None:
309
+ with tf.name_scope(self.key.name):
310
+ self.key.build([None, None, self.config.hidden_size])
311
+ if getattr(self, "value", None) is not None:
312
+ with tf.name_scope(self.value.name):
313
+ self.value.build([None, None, self.config.hidden_size])
314
+ if getattr(self, "dense", None) is not None:
315
+ with tf.name_scope(self.dense.name):
316
+ self.dense.build([None, None, self.config.hidden_size])
317
+ if getattr(self, "LayerNorm", None) is not None:
318
+ with tf.name_scope(self.LayerNorm.name):
319
+ self.LayerNorm.build([None, None, self.config.hidden_size])
320
+
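A quick shape check (dummy tensors, hypothetical head sizes) of what `transpose_for_scores` does inside the attention layer above: the fused projection of shape (batch, seq_len, all_head_size) is split into per-head slices of shape (batch, num_heads, seq_len, head_size) so the attention matmul runs independently per head.

import tensorflow as tf

batch, seq_len, num_heads, head_size = 2, 5, 12, 64
fused = tf.zeros((batch, seq_len, num_heads * head_size))
split = tf.reshape(fused, (batch, -1, num_heads, head_size))
per_head = tf.transpose(split, perm=[0, 2, 1, 3])
print(per_head.shape)        # (2, 12, 5, 64)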
321
+
322
+ class TFAlbertLayer(keras.layers.Layer):
323
+ def __init__(self, config: AlbertConfig, **kwargs):
324
+ super().__init__(**kwargs)
325
+
326
+ self.attention = TFAlbertAttention(config, name="attention")
327
+ self.ffn = keras.layers.Dense(
328
+ units=config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="ffn"
329
+ )
330
+
331
+ if isinstance(config.hidden_act, str):
332
+ self.activation = get_tf_activation(config.hidden_act)
333
+ else:
334
+ self.activation = config.hidden_act
335
+
336
+ self.ffn_output = keras.layers.Dense(
337
+ units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="ffn_output"
338
+ )
339
+ self.full_layer_layer_norm = keras.layers.LayerNormalization(
340
+ epsilon=config.layer_norm_eps, name="full_layer_layer_norm"
341
+ )
342
+ self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
343
+ self.config = config
344
+
345
+ def call(
346
+ self,
347
+ hidden_states: tf.Tensor,
348
+ attention_mask: tf.Tensor,
349
+ head_mask: tf.Tensor,
350
+ output_attentions: bool,
351
+ training: bool = False,
352
+ ) -> tuple[tf.Tensor]:
353
+ attention_outputs = self.attention(
354
+ input_tensor=hidden_states,
355
+ attention_mask=attention_mask,
356
+ head_mask=head_mask,
357
+ output_attentions=output_attentions,
358
+ training=training,
359
+ )
360
+ ffn_output = self.ffn(inputs=attention_outputs[0])
361
+ ffn_output = self.activation(ffn_output)
362
+ ffn_output = self.ffn_output(inputs=ffn_output)
363
+ ffn_output = self.dropout(inputs=ffn_output, training=training)
364
+ hidden_states = self.full_layer_layer_norm(inputs=ffn_output + attention_outputs[0])
365
+
366
+ # add attentions if we output them
367
+ outputs = (hidden_states,) + attention_outputs[1:]
368
+
369
+ return outputs
370
+
371
+ def build(self, input_shape=None):
372
+ if self.built:
373
+ return
374
+ self.built = True
375
+ if getattr(self, "attention", None) is not None:
376
+ with tf.name_scope(self.attention.name):
377
+ self.attention.build(None)
378
+ if getattr(self, "ffn", None) is not None:
379
+ with tf.name_scope(self.ffn.name):
380
+ self.ffn.build([None, None, self.config.hidden_size])
381
+ if getattr(self, "ffn_output", None) is not None:
382
+ with tf.name_scope(self.ffn_output.name):
383
+ self.ffn_output.build([None, None, self.config.intermediate_size])
384
+ if getattr(self, "full_layer_layer_norm", None) is not None:
385
+ with tf.name_scope(self.full_layer_layer_norm.name):
386
+ self.full_layer_layer_norm.build([None, None, self.config.hidden_size])
387
+
388
+
389
+ class TFAlbertLayerGroup(keras.layers.Layer):
390
+ def __init__(self, config: AlbertConfig, **kwargs):
391
+ super().__init__(**kwargs)
392
+
393
+ self.albert_layers = [
394
+ TFAlbertLayer(config, name=f"albert_layers_._{i}") for i in range(config.inner_group_num)
395
+ ]
396
+
397
+ def call(
398
+ self,
399
+ hidden_states: tf.Tensor,
400
+ attention_mask: tf.Tensor,
401
+ head_mask: tf.Tensor,
402
+ output_attentions: bool,
403
+ output_hidden_states: bool,
404
+ training: bool = False,
405
+ ) -> TFBaseModelOutput | tuple[tf.Tensor]:
406
+ layer_hidden_states = () if output_hidden_states else None
407
+ layer_attentions = () if output_attentions else None
408
+
409
+ for layer_index, albert_layer in enumerate(self.albert_layers):
410
+ if output_hidden_states:
411
+ layer_hidden_states = layer_hidden_states + (hidden_states,)
412
+
413
+ layer_output = albert_layer(
414
+ hidden_states=hidden_states,
415
+ attention_mask=attention_mask,
416
+ head_mask=head_mask[layer_index],
417
+ output_attentions=output_attentions,
418
+ training=training,
419
+ )
420
+ hidden_states = layer_output[0]
421
+
422
+ if output_attentions:
423
+ layer_attentions = layer_attentions + (layer_output[1],)
424
+
425
+ # Add last layer
426
+ if output_hidden_states:
427
+ layer_hidden_states = layer_hidden_states + (hidden_states,)
428
+
429
+ return tuple(v for v in [hidden_states, layer_hidden_states, layer_attentions] if v is not None)
430
+
431
+ def build(self, input_shape=None):
432
+ if self.built:
433
+ return
434
+ self.built = True
435
+ if getattr(self, "albert_layers", None) is not None:
436
+ for layer in self.albert_layers:
437
+ with tf.name_scope(layer.name):
438
+ layer.build(None)
439
+
440
+
441
+ class TFAlbertTransformer(keras.layers.Layer):
442
+ def __init__(self, config: AlbertConfig, **kwargs):
443
+ super().__init__(**kwargs)
444
+
445
+ self.num_hidden_layers = config.num_hidden_layers
446
+ self.num_hidden_groups = config.num_hidden_groups
447
+ # Number of layers in a hidden group
448
+ self.layers_per_group = int(config.num_hidden_layers / config.num_hidden_groups)
449
+ self.embedding_hidden_mapping_in = keras.layers.Dense(
450
+ units=config.hidden_size,
451
+ kernel_initializer=get_initializer(config.initializer_range),
452
+ name="embedding_hidden_mapping_in",
453
+ )
454
+ self.albert_layer_groups = [
455
+ TFAlbertLayerGroup(config, name=f"albert_layer_groups_._{i}") for i in range(config.num_hidden_groups)
456
+ ]
457
+ self.config = config
458
+
459
+ def call(
460
+ self,
461
+ hidden_states: tf.Tensor,
462
+ attention_mask: tf.Tensor,
463
+ head_mask: tf.Tensor,
464
+ output_attentions: bool,
465
+ output_hidden_states: bool,
466
+ return_dict: bool,
467
+ training: bool = False,
468
+ ) -> TFBaseModelOutput | tuple[tf.Tensor]:
469
+ hidden_states = self.embedding_hidden_mapping_in(inputs=hidden_states)
470
+ all_attentions = () if output_attentions else None
471
+ all_hidden_states = (hidden_states,) if output_hidden_states else None
472
+
473
+ for i in range(self.num_hidden_layers):
474
+ # Index of the hidden group
475
+ group_idx = int(i / (self.num_hidden_layers / self.num_hidden_groups))
476
+ layer_group_output = self.albert_layer_groups[group_idx](
477
+ hidden_states=hidden_states,
478
+ attention_mask=attention_mask,
479
+ head_mask=head_mask[group_idx * self.layers_per_group : (group_idx + 1) * self.layers_per_group],
480
+ output_attentions=output_attentions,
481
+ output_hidden_states=output_hidden_states,
482
+ training=training,
483
+ )
484
+ hidden_states = layer_group_output[0]
485
+
486
+ if output_attentions:
487
+ all_attentions = all_attentions + layer_group_output[-1]
488
+
489
+ if output_hidden_states:
490
+ all_hidden_states = all_hidden_states + (hidden_states,)
491
+
492
+ if not return_dict:
493
+ return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
494
+
495
+ return TFBaseModelOutput(
496
+ last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
497
+ )
498
+
499
+ def build(self, input_shape=None):
500
+ if self.built:
501
+ return
502
+ self.built = True
503
+ if getattr(self, "embedding_hidden_mapping_in", None) is not None:
504
+ with tf.name_scope(self.embedding_hidden_mapping_in.name):
505
+ self.embedding_hidden_mapping_in.build([None, None, self.config.embedding_size])
506
+ if getattr(self, "albert_layer_groups", None) is not None:
507
+ for layer in self.albert_layer_groups:
508
+ with tf.name_scope(layer.name):
509
+ layer.build(None)
510
+
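The group indexing in the transformer loop above is easiest to see with concrete numbers (hypothetical config values, sketch only): with ALBERT's default of 12 hidden layers shared across a single hidden group, every layer index maps to group 0, so all layers reuse the same parameters; with two groups, the first half would map to group 0 and the second half to group 1.

num_hidden_layers, num_hidden_groups = 12, 1
layers_per_group = int(num_hidden_layers / num_hidden_groups)                          # 12
group_indices = [int(i / (num_hidden_layers / num_hidden_groups)) for i in range(num_hidden_layers)]
print(group_indices)    # [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] -> one shared layer group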
511
+
512
+ class TFAlbertPreTrainedModel(TFPreTrainedModel):
513
+ """
514
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
515
+ models.
516
+ """
517
+
518
+ config_class = AlbertConfig
519
+ base_model_prefix = "albert"
520
+
521
+
522
+ class TFAlbertMLMHead(keras.layers.Layer):
523
+ def __init__(self, config: AlbertConfig, input_embeddings: keras.layers.Layer, **kwargs):
524
+ super().__init__(**kwargs)
525
+
526
+ self.config = config
527
+ self.embedding_size = config.embedding_size
528
+ self.dense = keras.layers.Dense(
529
+ config.embedding_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
530
+ )
531
+ if isinstance(config.hidden_act, str):
532
+ self.activation = get_tf_activation(config.hidden_act)
533
+ else:
534
+ self.activation = config.hidden_act
535
+
536
+ self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
537
+
538
+ # The output weights are the same as the input embeddings, but there is
539
+ # an output-only bias for each token.
540
+ self.decoder = input_embeddings
541
+
542
+ def build(self, input_shape=None):
543
+ self.bias = self.add_weight(shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="bias")
544
+ self.decoder_bias = self.add_weight(
545
+ shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="decoder/bias"
546
+ )
547
+
548
+ if self.built:
549
+ return
550
+ self.built = True
551
+ if getattr(self, "dense", None) is not None:
552
+ with tf.name_scope(self.dense.name):
553
+ self.dense.build([None, None, self.config.hidden_size])
554
+ if getattr(self, "LayerNorm", None) is not None:
555
+ with tf.name_scope(self.LayerNorm.name):
556
+ self.LayerNorm.build([None, None, self.config.embedding_size])
557
+
558
+ def get_output_embeddings(self) -> keras.layers.Layer:
559
+ return self.decoder
560
+
561
+ def set_output_embeddings(self, value: tf.Variable):
562
+ self.decoder.weight = value
563
+ self.decoder.vocab_size = shape_list(value)[0]
564
+
565
+ def get_bias(self) -> dict[str, tf.Variable]:
566
+ return {"bias": self.bias, "decoder_bias": self.decoder_bias}
567
+
568
+ def set_bias(self, value: tf.Variable):
569
+ self.bias = value["bias"]
570
+ self.decoder_bias = value["decoder_bias"]
571
+ self.config.vocab_size = shape_list(value["bias"])[0]
572
+
573
+ def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
574
+ hidden_states = self.dense(inputs=hidden_states)
575
+ hidden_states = self.activation(hidden_states)
576
+ hidden_states = self.LayerNorm(inputs=hidden_states)
577
+ seq_length = shape_list(tensor=hidden_states)[1]
578
+ hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, self.embedding_size])
579
+ hidden_states = tf.matmul(a=hidden_states, b=self.decoder.weight, transpose_b=True)
580
+ hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, seq_length, self.config.vocab_size])
581
+ hidden_states = tf.nn.bias_add(value=hidden_states, bias=self.decoder_bias)
582
+
583
+ return hidden_states
584
+
585
+
586
+ @keras_serializable
587
+ class TFAlbertMainLayer(keras.layers.Layer):
588
+ config_class = AlbertConfig
589
+
590
+ def __init__(self, config: AlbertConfig, add_pooling_layer: bool = True, **kwargs):
591
+ super().__init__(**kwargs)
592
+
593
+ self.config = config
594
+
595
+ self.embeddings = TFAlbertEmbeddings(config, name="embeddings")
596
+ self.encoder = TFAlbertTransformer(config, name="encoder")
597
+ self.pooler = (
598
+ keras.layers.Dense(
599
+ units=config.hidden_size,
600
+ kernel_initializer=get_initializer(config.initializer_range),
601
+ activation="tanh",
602
+ name="pooler",
603
+ )
604
+ if add_pooling_layer
605
+ else None
606
+ )
607
+
608
+ def get_input_embeddings(self) -> keras.layers.Layer:
609
+ return self.embeddings
610
+
611
+ def set_input_embeddings(self, value: tf.Variable):
612
+ self.embeddings.weight = value
613
+ self.embeddings.vocab_size = shape_list(value)[0]
614
+
615
+ def _prune_heads(self, heads_to_prune):
616
+ """
617
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
618
+ class PreTrainedModel
619
+ """
620
+ raise NotImplementedError
621
+
622
+ @unpack_inputs
623
+ def call(
624
+ self,
625
+ input_ids: TFModelInputType | None = None,
626
+ attention_mask: np.ndarray | tf.Tensor | None = None,
627
+ token_type_ids: np.ndarray | tf.Tensor | None = None,
628
+ position_ids: np.ndarray | tf.Tensor | None = None,
629
+ head_mask: np.ndarray | tf.Tensor | None = None,
630
+ inputs_embeds: np.ndarray | tf.Tensor | None = None,
631
+ output_attentions: bool | None = None,
632
+ output_hidden_states: bool | None = None,
633
+ return_dict: bool | None = None,
634
+ training: bool = False,
635
+ ) -> TFBaseModelOutputWithPooling | tuple[tf.Tensor]:
636
+ if input_ids is not None and inputs_embeds is not None:
637
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
638
+ elif input_ids is not None:
639
+ input_shape = shape_list(input_ids)
640
+ elif inputs_embeds is not None:
641
+ input_shape = shape_list(inputs_embeds)[:-1]
642
+ else:
643
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
644
+
645
+ if attention_mask is None:
646
+ attention_mask = tf.fill(dims=input_shape, value=1)
647
+
648
+ if token_type_ids is None:
649
+ token_type_ids = tf.fill(dims=input_shape, value=0)
650
+
651
+ embedding_output = self.embeddings(
652
+ input_ids=input_ids,
653
+ position_ids=position_ids,
654
+ token_type_ids=token_type_ids,
655
+ inputs_embeds=inputs_embeds,
656
+ training=training,
657
+ )
658
+
659
+ # We create a 3D attention mask from a 2D tensor mask.
660
+ # Sizes are [batch_size, 1, 1, to_seq_length]
661
+ # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
662
+ # this attention mask is simpler than the triangular masking of causal attention
663
+ # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
664
+ extended_attention_mask = tf.reshape(attention_mask, (input_shape[0], 1, 1, input_shape[1]))
665
+
666
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
667
+ # masked positions, this operation will create a tensor which is 0.0 for
668
+ # positions we want to attend and -10000.0 for masked positions.
669
+ # Since we are adding it to the raw scores before the softmax, this is
670
+ # effectively the same as removing these entirely.
671
+ extended_attention_mask = tf.cast(extended_attention_mask, dtype=embedding_output.dtype)
672
+ one_cst = tf.constant(1.0, dtype=embedding_output.dtype)
673
+ ten_thousand_cst = tf.constant(-10000.0, dtype=embedding_output.dtype)
674
+ extended_attention_mask = tf.multiply(tf.subtract(one_cst, extended_attention_mask), ten_thousand_cst)
675
+
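# Worked example of the additive mask trick above (toy values, illustration only):
# a 2-D attention_mask row [1, 1, 0] becomes, after (1 - mask) * -10000,
# [0.0, 0.0, -10000.0]; adding this to the raw attention scores leaves attended
# positions unchanged and drives masked positions toward -inf before the softmax.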
676
+ # Prepare head mask if needed
677
+ # 1.0 in head_mask indicate we keep the head
678
+ # attention_probs has shape bsz x n_heads x N x N
679
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
680
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
681
+ if head_mask is not None:
682
+ raise NotImplementedError
683
+ else:
684
+ head_mask = [None] * self.config.num_hidden_layers
685
+
686
+ encoder_outputs = self.encoder(
687
+ hidden_states=embedding_output,
688
+ attention_mask=extended_attention_mask,
689
+ head_mask=head_mask,
690
+ output_attentions=output_attentions,
691
+ output_hidden_states=output_hidden_states,
692
+ return_dict=return_dict,
693
+ training=training,
694
+ )
695
+
696
+ sequence_output = encoder_outputs[0]
697
+ pooled_output = self.pooler(inputs=sequence_output[:, 0]) if self.pooler is not None else None
698
+
699
+ if not return_dict:
700
+ return (
701
+ sequence_output,
702
+ pooled_output,
703
+ ) + encoder_outputs[1:]
704
+
705
+ return TFBaseModelOutputWithPooling(
706
+ last_hidden_state=sequence_output,
707
+ pooler_output=pooled_output,
708
+ hidden_states=encoder_outputs.hidden_states,
709
+ attentions=encoder_outputs.attentions,
710
+ )
711
+
712
+ def build(self, input_shape=None):
713
+ if self.built:
714
+ return
715
+ self.built = True
716
+ if getattr(self, "embeddings", None) is not None:
717
+ with tf.name_scope(self.embeddings.name):
718
+ self.embeddings.build(None)
719
+ if getattr(self, "encoder", None) is not None:
720
+ with tf.name_scope(self.encoder.name):
721
+ self.encoder.build(None)
722
+ if getattr(self, "pooler", None) is not None:
723
+ with tf.name_scope(self.pooler.name):
724
+ self.pooler.build([None, None, self.config.hidden_size])
725
+
726
+
727
+ @dataclass
728
+ class TFAlbertForPreTrainingOutput(ModelOutput):
729
+ """
730
+ Output type of [`TFAlbertForPreTraining`].
731
+
732
+ Args:
733
+ prediction_logits (`tf.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
734
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
735
+ sop_logits (`tf.Tensor` of shape `(batch_size, 2)`):
736
+ Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
737
+ before SoftMax).
738
+ hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
739
+ Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
740
+ `(batch_size, sequence_length, hidden_size)`.
741
+
742
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
743
+ attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
744
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
745
+ sequence_length)`.
746
+
747
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
748
+ heads.
749
+ """
750
+
751
+ loss: tf.Tensor | None = None
752
+ prediction_logits: tf.Tensor | None = None
753
+ sop_logits: tf.Tensor | None = None
754
+ hidden_states: tuple[tf.Tensor] | None = None
755
+ attentions: tuple[tf.Tensor] | None = None
756
+
757
+
758
+ ALBERT_START_DOCSTRING = r"""
759
+
760
+ This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
761
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
762
+ etc.)
763
+
764
+ This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
765
+ as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and
766
+ behavior.
767
+
768
+ <Tip>
769
+
770
+ TensorFlow models and layers in `transformers` accept two formats as input:
771
+
772
+ - having all inputs as keyword arguments (like PyTorch models), or
773
+ - having all inputs as a list, tuple or dict in the first positional argument.
774
+
775
+ The reason the second format is supported is that Keras methods prefer this format when passing inputs to models
776
+ and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just
777
+ pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second
778
+ format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with
779
+ the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first
780
+ positional argument:
781
+
782
+ - a single Tensor with `input_ids` only and nothing else: `model(input_ids)`
783
+ - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
784
+ `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])`
785
+ - a dictionary with one or several input Tensors associated to the input names given in the docstring:
786
+ `model({"input_ids": input_ids, "token_type_ids": token_type_ids})`
787
+
788
+ Note that when creating models and layers with
789
+ [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry
790
+ about any of this, as you can just pass inputs like you would to any other Python function!
791
+
792
+ </Tip>
793
+
794
+ Args:
795
+ config ([`AlbertConfig`]): Model configuration class with all the parameters of the model.
796
+ Initializing with a config file does not load the weights associated with the model, only the
797
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
798
+ """
799
+
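To make the three accepted input formats described in the docstring above concrete, a hedged sketch (it assumes a loaded `model = TFAlbertModel.from_pretrained(...)` and already-tokenized `input_ids` / `attention_mask` tensors, which are not defined here):

outputs = model(input_ids)                                                       # single tensor
outputs = model([input_ids, attention_mask])                                     # list, in docstring order
outputs = model({"input_ids": input_ids, "attention_mask": attention_mask})      # dict keyed by input name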
800
+ ALBERT_INPUTS_DOCSTRING = r"""
801
+ Args:
802
+ input_ids (`Numpy array` or `tf.Tensor` of shape `({0})`):
803
+ Indices of input sequence tokens in the vocabulary.
804
+
805
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and
806
+ [`PreTrainedTokenizer.encode`] for details.
807
+
808
+ [What are input IDs?](../glossary#input-ids)
809
+ attention_mask (`Numpy array` or `tf.Tensor` of shape `({0})`, *optional*):
810
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
811
+
812
+ - 1 for tokens that are **not masked**,
813
+ - 0 for tokens that are **masked**.
814
+
815
+ [What are attention masks?](../glossary#attention-mask)
816
+ token_type_ids (`Numpy array` or `tf.Tensor` of shape `({0})`, *optional*):
817
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
818
+ 1]`:
819
+
820
+ - 0 corresponds to a *sentence A* token,
821
+ - 1 corresponds to a *sentence B* token.
822
+
823
+ [What are token type IDs?](../glossary#token-type-ids)
824
+ position_ids (`Numpy array` or `tf.Tensor` of shape `({0})`, *optional*):
825
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
826
+ config.max_position_embeddings - 1]`.
827
+
828
+ [What are position IDs?](../glossary#position-ids)
829
+ head_mask (`Numpy array` or `tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
830
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
831
+
832
+ - 1 indicates the head is **not masked**,
833
+ - 0 indicates the head is **masked**.
834
+
835
+ inputs_embeds (`tf.Tensor` of shape `({0}, hidden_size)`, *optional*):
836
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
837
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
838
+ model's internal embedding lookup matrix.
839
+ output_attentions (`bool`, *optional*):
840
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
841
+ tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the
842
+ config will be used instead.
843
+ output_hidden_states (`bool`, *optional*):
844
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
845
+ more detail. This argument can be used only in eager mode, in graph mode the value in the config will be
846
+ used instead.
847
+ return_dict (`bool`, *optional*):
848
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in
849
+ eager mode, in graph mode the value will always be set to True.
850
+ training (`bool`, *optional*, defaults to `False`):
851
+ Whether or not to use the model in training mode (some modules like dropout modules have different
852
+ behaviors between training and evaluation).
853
+ """
854
+
855
+
856
+ @add_start_docstrings(
857
+ "The bare Albert Model transformer outputting raw hidden-states without any specific head on top.",
858
+ ALBERT_START_DOCSTRING,
859
+ )
860
+ class TFAlbertModel(TFAlbertPreTrainedModel):
861
+ def __init__(self, config: AlbertConfig, *inputs, **kwargs):
862
+ super().__init__(config, *inputs, **kwargs)
863
+
864
+ self.albert = TFAlbertMainLayer(config, name="albert")
865
+
866
+ @unpack_inputs
867
+ @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
868
+ @add_code_sample_docstrings(
869
+ checkpoint=_CHECKPOINT_FOR_DOC,
870
+ output_type=TFBaseModelOutputWithPooling,
871
+ config_class=_CONFIG_FOR_DOC,
872
+ )
873
+ def call(
874
+ self,
875
+ input_ids: TFModelInputType | None = None,
876
+ attention_mask: np.ndarray | tf.Tensor | None = None,
877
+ token_type_ids: np.ndarray | tf.Tensor | None = None,
878
+ position_ids: np.ndarray | tf.Tensor | None = None,
879
+ head_mask: np.ndarray | tf.Tensor | None = None,
880
+ inputs_embeds: np.ndarray | tf.Tensor | None = None,
881
+ output_attentions: bool | None = None,
882
+ output_hidden_states: bool | None = None,
883
+ return_dict: bool | None = None,
884
+ training: bool | None = False,
885
+ ) -> TFBaseModelOutputWithPooling | tuple[tf.Tensor]:
886
+ outputs = self.albert(
887
+ input_ids=input_ids,
888
+ attention_mask=attention_mask,
889
+ token_type_ids=token_type_ids,
890
+ position_ids=position_ids,
891
+ head_mask=head_mask,
892
+ inputs_embeds=inputs_embeds,
893
+ output_attentions=output_attentions,
894
+ output_hidden_states=output_hidden_states,
895
+ return_dict=return_dict,
896
+ training=training,
897
+ )
898
+
899
+ return outputs
900
+
901
+ def build(self, input_shape=None):
902
+ if self.built:
903
+ return
904
+ self.built = True
905
+ if getattr(self, "albert", None) is not None:
906
+ with tf.name_scope(self.albert.name):
907
+ self.albert.build(None)
908
+
909
+
910
+ @add_start_docstrings(
911
+ """
912
+ Albert Model with two heads on top for pretraining: a `masked language modeling` head and a `sentence order
913
+ prediction` (classification) head.
914
+ """,
915
+ ALBERT_START_DOCSTRING,
916
+ )
917
+ class TFAlbertForPreTraining(TFAlbertPreTrainedModel, TFAlbertPreTrainingLoss):
918
+ # names with a '.' represent the authorized unexpected/missing layers when a TF model is loaded from a PT model
919
+ _keys_to_ignore_on_load_unexpected = [r"predictions.decoder.weight"]
920
+
921
+ def __init__(self, config: AlbertConfig, *inputs, **kwargs):
922
+ super().__init__(config, *inputs, **kwargs)
923
+
924
+ self.num_labels = config.num_labels
925
+
926
+ self.albert = TFAlbertMainLayer(config, name="albert")
927
+ self.predictions = TFAlbertMLMHead(config, input_embeddings=self.albert.embeddings, name="predictions")
928
+ self.sop_classifier = TFAlbertSOPHead(config, name="sop_classifier")
929
+
930
+ def get_lm_head(self) -> keras.layers.Layer:
931
+ return self.predictions
932
+
933
+ @unpack_inputs
934
+ @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
935
+ @replace_return_docstrings(output_type=TFAlbertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
936
+ def call(
937
+ self,
938
+ input_ids: TFModelInputType | None = None,
939
+ attention_mask: np.ndarray | tf.Tensor | None = None,
940
+ token_type_ids: np.ndarray | tf.Tensor | None = None,
941
+ position_ids: np.ndarray | tf.Tensor | None = None,
942
+ head_mask: np.ndarray | tf.Tensor | None = None,
943
+ inputs_embeds: np.ndarray | tf.Tensor | None = None,
944
+ output_attentions: bool | None = None,
945
+ output_hidden_states: bool | None = None,
946
+ return_dict: bool | None = None,
947
+ labels: np.ndarray | tf.Tensor | None = None,
948
+ sentence_order_label: np.ndarray | tf.Tensor | None = None,
949
+ training: bool | None = False,
950
+ ) -> TFAlbertForPreTrainingOutput | tuple[tf.Tensor]:
951
+ r"""
952
+ Return:
953
+
954
+ Example:
955
+
956
+ ```python
957
+ >>> import tensorflow as tf
958
+ >>> from transformers import AutoTokenizer, TFAlbertForPreTraining
959
+
960
+ >>> tokenizer = AutoTokenizer.from_pretrained("albert/albert-base-v2")
961
+ >>> model = TFAlbertForPreTraining.from_pretrained("albert/albert-base-v2")
962
+
963
+ >>> input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[None, :]
964
+ >>> # Batch size 1
965
+ >>> outputs = model(input_ids)
966
+
967
+ >>> prediction_logits = outputs.prediction_logits
968
+ >>> sop_logits = outputs.sop_logits
969
+ ```"""
970
+
971
+ outputs = self.albert(
972
+ input_ids=input_ids,
973
+ attention_mask=attention_mask,
974
+ token_type_ids=token_type_ids,
975
+ position_ids=position_ids,
976
+ head_mask=head_mask,
977
+ inputs_embeds=inputs_embeds,
978
+ output_attentions=output_attentions,
979
+ output_hidden_states=output_hidden_states,
980
+ return_dict=return_dict,
981
+ training=training,
982
+ )
983
+ sequence_output, pooled_output = outputs[:2]
984
+ prediction_scores = self.predictions(hidden_states=sequence_output)
985
+ sop_scores = self.sop_classifier(pooled_output=pooled_output, training=training)
986
+ total_loss = None
987
+
988
+ if labels is not None and sentence_order_label is not None:
989
+ d_labels = {"labels": labels}
990
+ d_labels["sentence_order_label"] = sentence_order_label
991
+ total_loss = self.hf_compute_loss(labels=d_labels, logits=(prediction_scores, sop_scores))
992
+
993
+ if not return_dict:
994
+ output = (prediction_scores, sop_scores) + outputs[2:]
995
+ return ((total_loss,) + output) if total_loss is not None else output
996
+
997
+ return TFAlbertForPreTrainingOutput(
998
+ loss=total_loss,
999
+ prediction_logits=prediction_scores,
1000
+ sop_logits=sop_scores,
1001
+ hidden_states=outputs.hidden_states,
1002
+ attentions=outputs.attentions,
1003
+ )
1004
+
1005
+ def build(self, input_shape=None):
1006
+ if self.built:
1007
+ return
1008
+ self.built = True
1009
+ if getattr(self, "albert", None) is not None:
1010
+ with tf.name_scope(self.albert.name):
1011
+ self.albert.build(None)
1012
+ if getattr(self, "predictions", None) is not None:
1013
+ with tf.name_scope(self.predictions.name):
1014
+ self.predictions.build(None)
1015
+ if getattr(self, "sop_classifier", None) is not None:
1016
+ with tf.name_scope(self.sop_classifier.name):
1017
+ self.sop_classifier.build(None)
1018
+
1019
+
1020
+ class TFAlbertSOPHead(keras.layers.Layer):
1021
+ def __init__(self, config: AlbertConfig, **kwargs):
1022
+ super().__init__(**kwargs)
1023
+
1024
+ self.dropout = keras.layers.Dropout(rate=config.classifier_dropout_prob)
1025
+ self.classifier = keras.layers.Dense(
1026
+ units=config.num_labels,
1027
+ kernel_initializer=get_initializer(config.initializer_range),
1028
+ name="classifier",
1029
+ )
1030
+ self.config = config
1031
+
1032
+ def call(self, pooled_output: tf.Tensor, training: bool) -> tf.Tensor:
1033
+ dropout_pooled_output = self.dropout(inputs=pooled_output, training=training)
1034
+ logits = self.classifier(inputs=dropout_pooled_output)
1035
+
1036
+ return logits
1037
+
1038
+ def build(self, input_shape=None):
1039
+ if self.built:
1040
+ return
1041
+ self.built = True
1042
+ if getattr(self, "classifier", None) is not None:
1043
+ with tf.name_scope(self.classifier.name):
1044
+ self.classifier.build([None, None, self.config.hidden_size])
1045
+
1046
+
1047
+ @add_start_docstrings("""Albert Model with a `language modeling` head on top.""", ALBERT_START_DOCSTRING)
1048
+ class TFAlbertForMaskedLM(TFAlbertPreTrainedModel, TFMaskedLanguageModelingLoss):
1049
+ # names with a '.' represent the authorized unexpected/missing layers when a TF model is loaded from a PT model
1050
+ _keys_to_ignore_on_load_unexpected = [r"pooler", r"predictions.decoder.weight"]
1051
+
1052
+ def __init__(self, config: AlbertConfig, *inputs, **kwargs):
1053
+ super().__init__(config, *inputs, **kwargs)
1054
+
1055
+ self.albert = TFAlbertMainLayer(config, add_pooling_layer=False, name="albert")
1056
+ self.predictions = TFAlbertMLMHead(config, input_embeddings=self.albert.embeddings, name="predictions")
1057
+
1058
+ def get_lm_head(self) -> keras.layers.Layer:
1059
+ return self.predictions
1060
+
1061
+ @unpack_inputs
1062
+ @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1063
+ @replace_return_docstrings(output_type=TFMaskedLMOutput, config_class=_CONFIG_FOR_DOC)
1064
+ def call(
1065
+ self,
1066
+ input_ids: TFModelInputType | None = None,
1067
+ attention_mask: np.ndarray | tf.Tensor | None = None,
1068
+ token_type_ids: np.ndarray | tf.Tensor | None = None,
1069
+ position_ids: np.ndarray | tf.Tensor | None = None,
1070
+ head_mask: np.ndarray | tf.Tensor | None = None,
1071
+ inputs_embeds: np.ndarray | tf.Tensor | None = None,
1072
+ output_attentions: bool | None = None,
1073
+ output_hidden_states: bool | None = None,
1074
+ return_dict: bool | None = None,
1075
+ labels: np.ndarray | tf.Tensor | None = None,
1076
+ training: bool | None = False,
1077
+ ) -> TFMaskedLMOutput | tuple[tf.Tensor]:
1078
+ r"""
1079
+ labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1080
+ Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
1081
+ config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
1082
+ loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
1083
+
1084
+ Returns:
1085
+
1086
+ Example:
1087
+
1088
+ ```python
1089
+ >>> import tensorflow as tf
1090
+ >>> from transformers import AutoTokenizer, TFAlbertForMaskedLM
1091
+
1092
+ >>> tokenizer = AutoTokenizer.from_pretrained("albert/albert-base-v2")
1093
+ >>> model = TFAlbertForMaskedLM.from_pretrained("albert/albert-base-v2")
1094
+
1095
+ >>> # add mask_token
1096
+ >>> inputs = tokenizer("The capital of [MASK] is Paris.", return_tensors="tf")
1097
+ >>> logits = model(**inputs).logits
1098
+
1099
+ >>> # retrieve index of [MASK]
1100
+ >>> mask_token_index = tf.where(inputs.input_ids == tokenizer.mask_token_id)[0][1]
1101
+ >>> predicted_token_id = tf.math.argmax(logits[0, mask_token_index], axis=-1)
1102
+ >>> tokenizer.decode(predicted_token_id)
1103
+ 'france'
1104
+ ```
1105
+
1106
+ ```python
1107
+ >>> labels = tokenizer("The capital of France is Paris.", return_tensors="tf")["input_ids"]
1108
+ >>> labels = tf.where(inputs.input_ids == tokenizer.mask_token_id, labels, -100)
1109
+ >>> outputs = model(**inputs, labels=labels)
1110
+ >>> round(float(outputs.loss), 2)
1111
+ 0.81
1112
+ ```
1113
+ """
1114
+ outputs = self.albert(
1115
+ input_ids=input_ids,
1116
+ attention_mask=attention_mask,
1117
+ token_type_ids=token_type_ids,
1118
+ position_ids=position_ids,
1119
+ head_mask=head_mask,
1120
+ inputs_embeds=inputs_embeds,
1121
+ output_attentions=output_attentions,
1122
+ output_hidden_states=output_hidden_states,
1123
+ return_dict=return_dict,
1124
+ training=training,
1125
+ )
1126
+ sequence_output = outputs[0]
1127
+ prediction_scores = self.predictions(hidden_states=sequence_output, training=training)
1128
+ loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=prediction_scores)
1129
+
1130
+ if not return_dict:
1131
+ output = (prediction_scores,) + outputs[2:]
1132
+
1133
+ return ((loss,) + output) if loss is not None else output
1134
+
1135
+ return TFMaskedLMOutput(
1136
+ loss=loss,
1137
+ logits=prediction_scores,
1138
+ hidden_states=outputs.hidden_states,
1139
+ attentions=outputs.attentions,
1140
+ )
1141
+
1142
+ def build(self, input_shape=None):
1143
+ if self.built:
1144
+ return
1145
+ self.built = True
1146
+ if getattr(self, "albert", None) is not None:
1147
+ with tf.name_scope(self.albert.name):
1148
+ self.albert.build(None)
1149
+ if getattr(self, "predictions", None) is not None:
1150
+ with tf.name_scope(self.predictions.name):
1151
+ self.predictions.build(None)
1152
+
1153
+
1154
+ @add_start_docstrings(
1155
+ """
1156
+ Albert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
1157
+ output) e.g. for GLUE tasks.
1158
+ """,
1159
+ ALBERT_START_DOCSTRING,
1160
+ )
1161
+ class TFAlbertForSequenceClassification(TFAlbertPreTrainedModel, TFSequenceClassificationLoss):
1162
+ # names with a '.' represent the authorized unexpected/missing layers when a TF model is loaded from a PT model
1163
+ _keys_to_ignore_on_load_unexpected = [r"predictions"]
1164
+ _keys_to_ignore_on_load_missing = [r"dropout"]
1165
+
1166
+ def __init__(self, config: AlbertConfig, *inputs, **kwargs):
1167
+ super().__init__(config, *inputs, **kwargs)
1168
+
1169
+ self.num_labels = config.num_labels
1170
+
1171
+ self.albert = TFAlbertMainLayer(config, name="albert")
1172
+ self.dropout = keras.layers.Dropout(rate=config.classifier_dropout_prob)
1173
+ self.classifier = keras.layers.Dense(
1174
+ units=config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
1175
+ )
1176
+ self.config = config
1177
+
1178
+ @unpack_inputs
1179
+ @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1180
+ @add_code_sample_docstrings(
1181
+ checkpoint="vumichien/albert-base-v2-imdb",
1182
+ output_type=TFSequenceClassifierOutput,
1183
+ config_class=_CONFIG_FOR_DOC,
1184
+ expected_output="'LABEL_1'",
1185
+ expected_loss=0.12,
1186
+ )
1187
+ def call(
1188
+ self,
1189
+ input_ids: TFModelInputType | None = None,
1190
+ attention_mask: np.ndarray | tf.Tensor | None = None,
1191
+ token_type_ids: np.ndarray | tf.Tensor | None = None,
1192
+ position_ids: np.ndarray | tf.Tensor | None = None,
1193
+ head_mask: np.ndarray | tf.Tensor | None = None,
1194
+ inputs_embeds: np.ndarray | tf.Tensor | None = None,
1195
+ output_attentions: bool | None = None,
1196
+ output_hidden_states: bool | None = None,
1197
+ return_dict: bool | None = None,
1198
+ labels: np.ndarray | tf.Tensor | None = None,
1199
+ training: bool | None = False,
1200
+ ) -> TFSequenceClassifierOutput | tuple[tf.Tensor]:
1201
+ r"""
1202
+ labels (`tf.Tensor` of shape `(batch_size,)`, *optional*):
1203
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1204
+ config.num_labels - 1]`. If `config.num_labels == 1`, a regression loss is computed (Mean-Square loss); if
1205
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1206
+ """
1207
+ outputs = self.albert(
1208
+ input_ids=input_ids,
1209
+ attention_mask=attention_mask,
1210
+ token_type_ids=token_type_ids,
1211
+ position_ids=position_ids,
1212
+ head_mask=head_mask,
1213
+ inputs_embeds=inputs_embeds,
1214
+ output_attentions=output_attentions,
1215
+ output_hidden_states=output_hidden_states,
1216
+ return_dict=return_dict,
1217
+ training=training,
1218
+ )
1219
+ pooled_output = outputs[1]
1220
+ pooled_output = self.dropout(inputs=pooled_output, training=training)
1221
+ logits = self.classifier(inputs=pooled_output)
1222
+ loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits)
1223
+
1224
+ if not return_dict:
1225
+ output = (logits,) + outputs[2:]
1226
+
1227
+ return ((loss,) + output) if loss is not None else output
1228
+
1229
+ return TFSequenceClassifierOutput(
1230
+ loss=loss,
1231
+ logits=logits,
1232
+ hidden_states=outputs.hidden_states,
1233
+ attentions=outputs.attentions,
1234
+ )
1235
+
1236
+ def build(self, input_shape=None):
1237
+ if self.built:
1238
+ return
1239
+ self.built = True
1240
+ if getattr(self, "albert", None) is not None:
1241
+ with tf.name_scope(self.albert.name):
1242
+ self.albert.build(None)
1243
+ if getattr(self, "classifier", None) is not None:
1244
+ with tf.name_scope(self.classifier.name):
1245
+ self.classifier.build([None, None, self.config.hidden_size])
1246
+
1247
+
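As a quick orientation for the sequence-classification head defined above, here is a minimal inference sketch. It assumes the fine-tuned `vumichien/albert-base-v2-imdb` checkpoint referenced in the code-sample decorator; the label names come from that checkpoint's config.

```python
import tensorflow as tf
from transformers import AutoTokenizer, TFAlbertForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained("vumichien/albert-base-v2-imdb")
model = TFAlbertForSequenceClassification.from_pretrained("vumichien/albert-base-v2-imdb")

inputs = tokenizer("This movie was surprisingly good.", return_tensors="tf")
logits = model(**inputs).logits  # shape: (batch_size, num_labels)

# map the highest-scoring logit back to its label name
predicted_class_id = int(tf.math.argmax(logits, axis=-1)[0])
print(model.config.id2label[predicted_class_id])
```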
1248
+ @add_start_docstrings(
1249
+ """
1250
+ Albert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
1251
+ Named-Entity-Recognition (NER) tasks.
1252
+ """,
1253
+ ALBERT_START_DOCSTRING,
1254
+ )
1255
+ class TFAlbertForTokenClassification(TFAlbertPreTrainedModel, TFTokenClassificationLoss):
1256
+ # Names with a '.' represent the authorized unexpected/missing layers when a TF model is loaded from a PT model
1257
+ _keys_to_ignore_on_load_unexpected = [r"pooler", r"predictions"]
1258
+ _keys_to_ignore_on_load_missing = [r"dropout"]
1259
+
1260
+ def __init__(self, config: AlbertConfig, *inputs, **kwargs):
1261
+ super().__init__(config, *inputs, **kwargs)
1262
+
1263
+ self.num_labels = config.num_labels
1264
+
1265
+ self.albert = TFAlbertMainLayer(config, add_pooling_layer=False, name="albert")
1266
+ classifier_dropout_prob = (
1267
+ config.classifier_dropout_prob
1268
+ if config.classifier_dropout_prob is not None
1269
+ else config.hidden_dropout_prob
1270
+ )
1271
+ self.dropout = keras.layers.Dropout(rate=classifier_dropout_prob)
1272
+ self.classifier = keras.layers.Dense(
1273
+ units=config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
1274
+ )
1275
+ self.config = config
1276
+
1277
+ @unpack_inputs
1278
+ @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1279
+ @add_code_sample_docstrings(
1280
+ checkpoint=_CHECKPOINT_FOR_DOC,
1281
+ output_type=TFTokenClassifierOutput,
1282
+ config_class=_CONFIG_FOR_DOC,
1283
+ )
1284
+ def call(
1285
+ self,
1286
+ input_ids: TFModelInputType | None = None,
1287
+ attention_mask: np.ndarray | tf.Tensor | None = None,
1288
+ token_type_ids: np.ndarray | tf.Tensor | None = None,
1289
+ position_ids: np.ndarray | tf.Tensor | None = None,
1290
+ head_mask: np.ndarray | tf.Tensor | None = None,
1291
+ inputs_embeds: np.ndarray | tf.Tensor | None = None,
1292
+ output_attentions: bool | None = None,
1293
+ output_hidden_states: bool | None = None,
1294
+ return_dict: bool | None = None,
1295
+ labels: np.ndarray | tf.Tensor | None = None,
1296
+ training: bool | None = False,
1297
+ ) -> TFTokenClassifierOutput | tuple[tf.Tensor]:
1298
+ r"""
1299
+ labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1300
+ Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
1301
+ """
1302
+ outputs = self.albert(
1303
+ input_ids=input_ids,
1304
+ attention_mask=attention_mask,
1305
+ token_type_ids=token_type_ids,
1306
+ position_ids=position_ids,
1307
+ head_mask=head_mask,
1308
+ inputs_embeds=inputs_embeds,
1309
+ output_attentions=output_attentions,
1310
+ output_hidden_states=output_hidden_states,
1311
+ return_dict=return_dict,
1312
+ training=training,
1313
+ )
1314
+ sequence_output = outputs[0]
1315
+ sequence_output = self.dropout(inputs=sequence_output, training=training)
1316
+ logits = self.classifier(inputs=sequence_output)
1317
+ loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits)
1318
+
1319
+ if not return_dict:
1320
+ output = (logits,) + outputs[2:]
1321
+
1322
+ return ((loss,) + output) if loss is not None else output
1323
+
1324
+ return TFTokenClassifierOutput(
1325
+ loss=loss,
1326
+ logits=logits,
1327
+ hidden_states=outputs.hidden_states,
1328
+ attentions=outputs.attentions,
1329
+ )
1330
+
1331
+ def build(self, input_shape=None):
1332
+ if self.built:
1333
+ return
1334
+ self.built = True
1335
+ if getattr(self, "albert", None) is not None:
1336
+ with tf.name_scope(self.albert.name):
1337
+ self.albert.build(None)
1338
+ if getattr(self, "classifier", None) is not None:
1339
+ with tf.name_scope(self.classifier.name):
1340
+ self.classifier.build([None, None, self.config.hidden_size])
1341
+
1342
+
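A shape-oriented sketch of the token-classification head, assuming the plain `albert/albert-base-v2` checkpoint; since that checkpoint ships no trained classifier, the per-token predictions below are illustrative only.

```python
import tensorflow as tf
from transformers import AutoTokenizer, TFAlbertForTokenClassification

tokenizer = AutoTokenizer.from_pretrained("albert/albert-base-v2")
model = TFAlbertForTokenClassification.from_pretrained("albert/albert-base-v2", num_labels=9)

inputs = tokenizer("HuggingFace is based in New York City.", return_tensors="tf")
logits = model(**inputs).logits  # shape: (batch_size, sequence_length, num_labels)

# one label id per (sub)token; an untuned classification head yields arbitrary classes
predicted_ids = tf.math.argmax(logits, axis=-1)
```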
1343
+ @add_start_docstrings(
1344
+ """
1345
+ Albert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
1346
+ layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
1347
+ """,
1348
+ ALBERT_START_DOCSTRING,
1349
+ )
1350
+ class TFAlbertForQuestionAnswering(TFAlbertPreTrainedModel, TFQuestionAnsweringLoss):
1351
+ # Names with a '.' represent the authorized unexpected/missing layers when a TF model is loaded from a PT model
1352
+ _keys_to_ignore_on_load_unexpected = [r"pooler", r"predictions"]
1353
+
1354
+ def __init__(self, config: AlbertConfig, *inputs, **kwargs):
1355
+ super().__init__(config, *inputs, **kwargs)
1356
+
1357
+ self.num_labels = config.num_labels
1358
+
1359
+ self.albert = TFAlbertMainLayer(config, add_pooling_layer=False, name="albert")
1360
+ self.qa_outputs = keras.layers.Dense(
1361
+ units=config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
1362
+ )
1363
+ self.config = config
1364
+
1365
+ @unpack_inputs
1366
+ @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1367
+ @add_code_sample_docstrings(
1368
+ checkpoint="vumichien/albert-base-v2-squad2",
1369
+ output_type=TFQuestionAnsweringModelOutput,
1370
+ config_class=_CONFIG_FOR_DOC,
1371
+ qa_target_start_index=12,
1372
+ qa_target_end_index=13,
1373
+ expected_output="'a nice puppet'",
1374
+ expected_loss=7.36,
1375
+ )
1376
+ def call(
1377
+ self,
1378
+ input_ids: TFModelInputType | None = None,
1379
+ attention_mask: np.ndarray | tf.Tensor | None = None,
1380
+ token_type_ids: np.ndarray | tf.Tensor | None = None,
1381
+ position_ids: np.ndarray | tf.Tensor | None = None,
1382
+ head_mask: np.ndarray | tf.Tensor | None = None,
1383
+ inputs_embeds: np.ndarray | tf.Tensor | None = None,
1384
+ output_attentions: bool | None = None,
1385
+ output_hidden_states: bool | None = None,
1386
+ return_dict: bool | None = None,
1387
+ start_positions: np.ndarray | tf.Tensor | None = None,
1388
+ end_positions: np.ndarray | tf.Tensor | None = None,
1389
+ training: bool | None = False,
1390
+ ) -> TFQuestionAnsweringModelOutput | tuple[tf.Tensor]:
1391
+ r"""
1392
+ start_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*):
1393
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
1394
+ Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
1395
+ are not taken into account for computing the loss.
1396
+ end_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*):
1397
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
1398
+ Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
1399
+ are not taken into account for computing the loss.
1400
+ """
1401
+ outputs = self.albert(
1402
+ input_ids=input_ids,
1403
+ attention_mask=attention_mask,
1404
+ token_type_ids=token_type_ids,
1405
+ position_ids=position_ids,
1406
+ head_mask=head_mask,
1407
+ inputs_embeds=inputs_embeds,
1408
+ output_attentions=output_attentions,
1409
+ output_hidden_states=output_hidden_states,
1410
+ return_dict=return_dict,
1411
+ training=training,
1412
+ )
1413
+ sequence_output = outputs[0]
1414
+ logits = self.qa_outputs(inputs=sequence_output)
1415
+ start_logits, end_logits = tf.split(value=logits, num_or_size_splits=2, axis=-1)
1416
+ start_logits = tf.squeeze(input=start_logits, axis=-1)
1417
+ end_logits = tf.squeeze(input=end_logits, axis=-1)
1418
+ loss = None
1419
+
1420
+ if start_positions is not None and end_positions is not None:
1421
+ labels = {"start_position": start_positions}
1422
+ labels["end_position"] = end_positions
1423
+ loss = self.hf_compute_loss(labels=labels, logits=(start_logits, end_logits))
1424
+
1425
+ if not return_dict:
1426
+ output = (start_logits, end_logits) + outputs[2:]
1427
+
1428
+ return ((loss,) + output) if loss is not None else output
1429
+
1430
+ return TFQuestionAnsweringModelOutput(
1431
+ loss=loss,
1432
+ start_logits=start_logits,
1433
+ end_logits=end_logits,
1434
+ hidden_states=outputs.hidden_states,
1435
+ attentions=outputs.attentions,
1436
+ )
1437
+
1438
+ def build(self, input_shape=None):
1439
+ if self.built:
1440
+ return
1441
+ self.built = True
1442
+ if getattr(self, "albert", None) is not None:
1443
+ with tf.name_scope(self.albert.name):
1444
+ self.albert.build(None)
1445
+ if getattr(self, "qa_outputs", None) is not None:
1446
+ with tf.name_scope(self.qa_outputs.name):
1447
+ self.qa_outputs.build([None, None, self.config.hidden_size])
1448
+
1449
+
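A span-extraction sketch for the question-answering head, assuming the `vumichien/albert-base-v2-squad2` checkpoint named in the docstring; the decoded answer depends entirely on that checkpoint's weights.

```python
import tensorflow as tf
from transformers import AutoTokenizer, TFAlbertForQuestionAnswering

tokenizer = AutoTokenizer.from_pretrained("vumichien/albert-base-v2-squad2")
model = TFAlbertForQuestionAnswering.from_pretrained("vumichien/albert-base-v2-squad2")

question, context = "Who was Jim Henson?", "Jim Henson was a nice puppet"
inputs = tokenizer(question, context, return_tensors="tf")
outputs = model(**inputs)

# the answer span runs from the most likely start index to the most likely end index
start = int(tf.math.argmax(outputs.start_logits, axis=-1)[0])
end = int(tf.math.argmax(outputs.end_logits, axis=-1)[0])
answer = tokenizer.decode(inputs["input_ids"][0, start : end + 1])
```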
1450
+ @add_start_docstrings(
1451
+ """
1452
+ Albert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
1453
+ softmax) e.g. for RocStories/SWAG tasks.
1454
+ """,
1455
+ ALBERT_START_DOCSTRING,
1456
+ )
1457
+ class TFAlbertForMultipleChoice(TFAlbertPreTrainedModel, TFMultipleChoiceLoss):
1458
+ # Names with a '.' represent the authorized unexpected/missing layers when a TF model is loaded from a PT model
1459
+ _keys_to_ignore_on_load_unexpected = [r"pooler", r"predictions"]
1460
+ _keys_to_ignore_on_load_missing = [r"dropout"]
1461
+
1462
+ def __init__(self, config: AlbertConfig, *inputs, **kwargs):
1463
+ super().__init__(config, *inputs, **kwargs)
1464
+
1465
+ self.albert = TFAlbertMainLayer(config, name="albert")
1466
+ self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
1467
+ self.classifier = keras.layers.Dense(
1468
+ units=1, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
1469
+ )
1470
+ self.config = config
1471
+
1472
+ @unpack_inputs
1473
+ @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
1474
+ @add_code_sample_docstrings(
1475
+ checkpoint=_CHECKPOINT_FOR_DOC,
1476
+ output_type=TFMultipleChoiceModelOutput,
1477
+ config_class=_CONFIG_FOR_DOC,
1478
+ )
1479
+ def call(
1480
+ self,
1481
+ input_ids: TFModelInputType | None = None,
1482
+ attention_mask: np.ndarray | tf.Tensor | None = None,
1483
+ token_type_ids: np.ndarray | tf.Tensor | None = None,
1484
+ position_ids: np.ndarray | tf.Tensor | None = None,
1485
+ head_mask: np.ndarray | tf.Tensor | None = None,
1486
+ inputs_embeds: np.ndarray | tf.Tensor | None = None,
1487
+ output_attentions: bool | None = None,
1488
+ output_hidden_states: bool | None = None,
1489
+ return_dict: bool | None = None,
1490
+ labels: np.ndarray | tf.Tensor | None = None,
1491
+ training: bool | None = False,
1492
+ ) -> TFMultipleChoiceModelOutput | tuple[tf.Tensor]:
1493
+ r"""
1494
+ labels (`tf.Tensor` of shape `(batch_size,)`, *optional*):
1495
+ Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]`
1496
+ where `num_choices` is the size of the second dimension of the input tensors. (See `input_ids` above)
1497
+ """
1498
+
1499
+ if input_ids is not None:
1500
+ num_choices = shape_list(input_ids)[1]
1501
+ seq_length = shape_list(input_ids)[2]
1502
+ else:
1503
+ num_choices = shape_list(inputs_embeds)[1]
1504
+ seq_length = shape_list(inputs_embeds)[2]
1505
+
1506
+ flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
1507
+ flat_attention_mask = (
1508
+ tf.reshape(tensor=attention_mask, shape=(-1, seq_length)) if attention_mask is not None else None
1509
+ )
1510
+ flat_token_type_ids = (
1511
+ tf.reshape(tensor=token_type_ids, shape=(-1, seq_length)) if token_type_ids is not None else None
1512
+ )
1513
+ flat_position_ids = (
1514
+ tf.reshape(tensor=position_ids, shape=(-1, seq_length)) if position_ids is not None else None
1515
+ )
1516
+ flat_inputs_embeds = (
1517
+ tf.reshape(tensor=inputs_embeds, shape=(-1, seq_length, shape_list(inputs_embeds)[3]))
1518
+ if inputs_embeds is not None
1519
+ else None
1520
+ )
1521
+ outputs = self.albert(
1522
+ input_ids=flat_input_ids,
1523
+ attention_mask=flat_attention_mask,
1524
+ token_type_ids=flat_token_type_ids,
1525
+ position_ids=flat_position_ids,
1526
+ head_mask=head_mask,
1527
+ inputs_embeds=flat_inputs_embeds,
1528
+ output_attentions=output_attentions,
1529
+ output_hidden_states=output_hidden_states,
1530
+ return_dict=return_dict,
1531
+ training=training,
1532
+ )
1533
+ pooled_output = outputs[1]
1534
+ pooled_output = self.dropout(inputs=pooled_output, training=training)
1535
+ logits = self.classifier(inputs=pooled_output)
1536
+ reshaped_logits = tf.reshape(tensor=logits, shape=(-1, num_choices))
1537
+ loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=reshaped_logits)
1538
+
1539
+ if not return_dict:
1540
+ output = (reshaped_logits,) + outputs[2:]
1541
+ return ((loss,) + output) if loss is not None else output
1542
+
1543
+ return TFMultipleChoiceModelOutput(
1544
+ loss=loss,
1545
+ logits=reshaped_logits,
1546
+ hidden_states=outputs.hidden_states,
1547
+ attentions=outputs.attentions,
1548
+ )
1549
+
1550
+ def build(self, input_shape=None):
1551
+ if self.built:
1552
+ return
1553
+ self.built = True
1554
+ if getattr(self, "albert", None) is not None:
1555
+ with tf.name_scope(self.albert.name):
1556
+ self.albert.build(None)
1557
+ if getattr(self, "classifier", None) is not None:
1558
+ with tf.name_scope(self.classifier.name):
1559
+ self.classifier.build([None, None, self.config.hidden_size])
1560
+
1561
+
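The multiple-choice head expects inputs of shape `(batch_size, num_choices, sequence_length)`, flattens them to `(batch_size * num_choices, sequence_length)` for the encoder, and reshapes the single-unit classifier output back to `(batch_size, num_choices)`. A minimal sketch of that layout, assuming the plain `albert/albert-base-v2` checkpoint (whose choice classifier is untrained, so only the shapes are meaningful):

```python
import tensorflow as tf
from transformers import AutoTokenizer, TFAlbertForMultipleChoice

tokenizer = AutoTokenizer.from_pretrained("albert/albert-base-v2")
model = TFAlbertForMultipleChoice.from_pretrained("albert/albert-base-v2")

prompt = "The glass fell off the table,"
choices = ["so it shattered on the floor.", "so it flew out of the window."]

# encode one (prompt, choice) pair per candidate, then add a leading batch axis of 1
encoding = tokenizer([prompt, prompt], choices, return_tensors="tf", padding=True)
inputs = {k: tf.expand_dims(v, 0) for k, v in encoding.items()}

logits = model(**inputs).logits  # shape: (1, num_choices)
predicted_choice = int(tf.math.argmax(logits, axis=-1)[0])
```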
1562
+ __all__ = [
1563
+ "TFAlbertPreTrainedModel",
1564
+ "TFAlbertModel",
1565
+ "TFAlbertForPreTraining",
1566
+ "TFAlbertForMaskedLM",
1567
+ "TFAlbertForSequenceClassification",
1568
+ "TFAlbertForTokenClassification",
1569
+ "TFAlbertForQuestionAnswering",
1570
+ "TFAlbertForMultipleChoice",
1571
+ "TFAlbertMainLayer",
1572
+ ]
venv/lib/python3.13/site-packages/transformers/models/albert/tokenization_albert.py ADDED
@@ -0,0 +1,320 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 Google AI, Google Brain and the HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Tokenization classes for ALBERT model."""
16
+
17
+ import os
18
+ import unicodedata
19
+ from shutil import copyfile
20
+ from typing import Any, Optional
21
+
22
+ import sentencepiece as spm
23
+
24
+ from ...tokenization_utils import AddedToken, PreTrainedTokenizer
25
+ from ...utils import logging
26
+ from ...utils.import_utils import requires
27
+
28
+
29
+ logger = logging.get_logger(__name__)
30
+ VOCAB_FILES_NAMES = {"vocab_file": "spiece.model"}
31
+
32
+
33
+ SPIECE_UNDERLINE = "▁"
34
+
35
+
36
+ @requires(backends=("sentencepiece",))
37
+ class AlbertTokenizer(PreTrainedTokenizer):
38
+ """
39
+ Construct an ALBERT tokenizer. Based on [SentencePiece](https://github.com/google/sentencepiece).
40
+
41
+ This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
42
+ this superclass for more information regarding those methods.
43
+
44
+ Args:
45
+ vocab_file (`str`):
46
+ [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that
47
+ contains the vocabulary necessary to instantiate a tokenizer.
48
+ do_lower_case (`bool`, *optional*, defaults to `True`):
49
+ Whether or not to lowercase the input when tokenizing.
50
+ remove_space (`bool`, *optional*, defaults to `True`):
51
+ Whether or not to strip the text when tokenizing (removing excess spaces before and after the string).
52
+ keep_accents (`bool`, *optional*, defaults to `False`):
53
+ Whether or not to keep accents when tokenizing.
54
+ bos_token (`str`, *optional*, defaults to `"[CLS]"`):
55
+ The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier token.
56
+
57
+ <Tip>
58
+
59
+ When building a sequence using special tokens, this is not the token that is used for the beginning of
60
+ sequence. The token used is the `cls_token`.
61
+
62
+ </Tip>
63
+
64
+ eos_token (`str`, *optional*, defaults to `"[SEP]"`):
65
+ The end of sequence token.
66
+
67
+ <Tip>
68
+
69
+ When building a sequence using special tokens, this is not the token that is used for the end of sequence.
70
+ The token used is the `sep_token`.
71
+
72
+ </Tip>
73
+
74
+ unk_token (`str`, *optional*, defaults to `"<unk>"`):
75
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
76
+ token instead.
77
+ sep_token (`str`, *optional*, defaults to `"[SEP]"`):
78
+ The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
79
+ sequence classification or for a text and a question for question answering. It is also used as the last
80
+ token of a sequence built with special tokens.
81
+ pad_token (`str`, *optional*, defaults to `"<pad>"`):
82
+ The token used for padding, for example when batching sequences of different lengths.
83
+ cls_token (`str`, *optional*, defaults to `"[CLS]"`):
84
+ The classifier token which is used when doing sequence classification (classification of the whole sequence
85
+ instead of per-token classification). It is the first token of the sequence when built with special tokens.
86
+ mask_token (`str`, *optional*, defaults to `"[MASK]"`):
87
+ The token used for masking values. This is the token used when training this model with masked language
88
+ modeling. This is the token which the model will try to predict.
89
+ sp_model_kwargs (`dict`, *optional*):
90
+ Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
91
+ SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
92
+ to set:
93
+
94
+ - `enable_sampling`: Enable subword regularization.
95
+ - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.
96
+
97
+ - `nbest_size = {0,1}`: No sampling is performed.
98
+ - `nbest_size > 1`: samples from the nbest_size results.
99
+ - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice)
100
+ using forward-filtering-and-backward-sampling algorithm.
101
+
102
+ - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
103
+ BPE-dropout.
104
+
105
+ Attributes:
106
+ sp_model (`SentencePieceProcessor`):
107
+ The *SentencePiece* processor that is used for every conversion (string, tokens and IDs).
108
+ """
109
+
110
+ vocab_files_names = VOCAB_FILES_NAMES
111
+
112
+ def __init__(
113
+ self,
114
+ vocab_file,
115
+ do_lower_case=True,
116
+ remove_space=True,
117
+ keep_accents=False,
118
+ bos_token="[CLS]",
119
+ eos_token="[SEP]",
120
+ unk_token="<unk>",
121
+ sep_token="[SEP]",
122
+ pad_token="<pad>",
123
+ cls_token="[CLS]",
124
+ mask_token="[MASK]",
125
+ sp_model_kwargs: Optional[dict[str, Any]] = None,
126
+ **kwargs,
127
+ ) -> None:
128
+ # The mask token behaves like a normal word, i.e. it includes the space before it and
129
+ # is included in the raw text, so there should be a match in a non-normalized sentence.
130
+ mask_token = (
131
+ AddedToken(mask_token, lstrip=True, rstrip=False, normalized=False)
132
+ if isinstance(mask_token, str)
133
+ else mask_token
134
+ )
135
+
136
+ self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
137
+
138
+ self.do_lower_case = do_lower_case
139
+ self.remove_space = remove_space
140
+ self.keep_accents = keep_accents
141
+ self.vocab_file = vocab_file
142
+
143
+ self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
144
+ self.sp_model.Load(vocab_file)
145
+
146
+ super().__init__(
147
+ do_lower_case=do_lower_case,
148
+ remove_space=remove_space,
149
+ keep_accents=keep_accents,
150
+ bos_token=bos_token,
151
+ eos_token=eos_token,
152
+ unk_token=unk_token,
153
+ sep_token=sep_token,
154
+ pad_token=pad_token,
155
+ cls_token=cls_token,
156
+ mask_token=mask_token,
157
+ sp_model_kwargs=self.sp_model_kwargs,
158
+ **kwargs,
159
+ )
160
+
161
+ @property
162
+ def vocab_size(self) -> int:
163
+ return len(self.sp_model)
164
+
165
+ def get_vocab(self) -> dict[str, int]:
166
+ vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
167
+ vocab.update(self.added_tokens_encoder)
168
+ return vocab
169
+
170
+ def __getstate__(self):
171
+ state = self.__dict__.copy()
172
+ state["sp_model"] = None
173
+ return state
174
+
175
+ def __setstate__(self, d):
176
+ self.__dict__ = d
177
+
178
+ # for backward compatibility
179
+ if not hasattr(self, "sp_model_kwargs"):
180
+ self.sp_model_kwargs = {}
181
+
182
+ self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
183
+ self.sp_model.Load(self.vocab_file)
184
+
185
+ def preprocess_text(self, inputs):
186
+ if self.remove_space:
187
+ outputs = " ".join(inputs.strip().split())
188
+ else:
189
+ outputs = inputs
190
+ outputs = outputs.replace("``", '"').replace("''", '"')
191
+
192
+ if not self.keep_accents:
193
+ outputs = unicodedata.normalize("NFKD", outputs)
194
+ outputs = "".join([c for c in outputs if not unicodedata.combining(c)])
195
+ if self.do_lower_case:
196
+ outputs = outputs.lower()
197
+
198
+ return outputs
199
+
200
+ def _tokenize(self, text: str) -> list[str]:
201
+ """Tokenize a string."""
202
+ text = self.preprocess_text(text)
203
+ pieces = self.sp_model.encode(text, out_type=str)
204
+ new_pieces = []
205
+ for piece in pieces:
206
+ if len(piece) > 1 and piece[-1] == "," and piece[-2].isdigit():
207
+ # Logic to handle special cases see https://github.com/google-research/bert/blob/master/README.md#tokenization
208
+ # `9,9` -> ['▁9', ',', '9'] instead of [`_9,`, '9']
209
+ cur_pieces = self.sp_model.EncodeAsPieces(piece[:-1].replace(SPIECE_UNDERLINE, ""))
210
+ if piece[0] != SPIECE_UNDERLINE and cur_pieces[0][0] == SPIECE_UNDERLINE:
211
+ if len(cur_pieces[0]) == 1:
212
+ cur_pieces = cur_pieces[1:]
213
+ else:
214
+ cur_pieces[0] = cur_pieces[0][1:]
215
+ cur_pieces.append(piece[-1])
216
+ new_pieces.extend(cur_pieces)
217
+ else:
218
+ new_pieces.append(piece)
219
+
220
+ return new_pieces
221
+
222
+ def _convert_token_to_id(self, token):
223
+ """Converts a token (str) in an id using the vocab."""
224
+ return self.sp_model.PieceToId(token)
225
+
226
+ def _convert_id_to_token(self, index):
227
+ """Converts an index (integer) in a token (str) using the vocab."""
228
+ return self.sp_model.IdToPiece(index)
229
+
230
+ def convert_tokens_to_string(self, tokens):
231
+ """Converts a sequence of tokens (string) in a single string."""
232
+ current_sub_tokens = []
233
+ out_string = ""
234
+ prev_is_special = False
235
+ for token in tokens:
236
+ # make sure that special tokens are not decoded using sentencepiece model
237
+ if token in self.all_special_tokens:
238
+ if not prev_is_special:
239
+ out_string += " "
240
+ out_string += self.sp_model.decode(current_sub_tokens) + token
241
+ prev_is_special = True
242
+ current_sub_tokens = []
243
+ else:
244
+ current_sub_tokens.append(token)
245
+ prev_is_special = False
246
+ out_string += self.sp_model.decode(current_sub_tokens)
247
+ return out_string.strip()
248
+
249
+ def build_inputs_with_special_tokens(
250
+ self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
251
+ ) -> list[int]:
252
+ """
253
+ Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
254
+ adding special tokens. An ALBERT sequence has the following format:
255
+
256
+ - single sequence: `[CLS] X [SEP]`
257
+ - pair of sequences: `[CLS] A [SEP] B [SEP]`
258
+
259
+ Args:
260
+ token_ids_0 (`List[int]`):
261
+ List of IDs to which the special tokens will be added.
262
+ token_ids_1 (`List[int]`, *optional*):
263
+ Optional second list of IDs for sequence pairs.
264
+
265
+ Returns:
266
+ `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
267
+ """
268
+ sep = [self.sep_token_id]
269
+ cls = [self.cls_token_id]
270
+ if token_ids_1 is None:
271
+ return cls + token_ids_0 + sep
272
+ return cls + token_ids_0 + sep + token_ids_1 + sep
273
+
274
+ def get_special_tokens_mask(
275
+ self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None, already_has_special_tokens: bool = False
276
+ ) -> list[int]:
277
+ """
278
+ Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
279
+ special tokens using the tokenizer `prepare_for_model` method.
280
+
281
+ Args:
282
+ token_ids_0 (`List[int]`):
283
+ List of IDs.
284
+ token_ids_1 (`List[int]`, *optional*):
285
+ Optional second list of IDs for sequence pairs.
286
+ already_has_special_tokens (`bool`, *optional*, defaults to `False`):
287
+ Whether or not the token list is already formatted with special tokens for the model.
288
+
289
+ Returns:
290
+ `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
291
+ """
292
+
293
+ if already_has_special_tokens:
294
+ return super().get_special_tokens_mask(
295
+ token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
296
+ )
297
+
298
+ if token_ids_1 is not None:
299
+ return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
300
+ return [1] + ([0] * len(token_ids_0)) + [1]
301
+
302
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
303
+ if not os.path.isdir(save_directory):
304
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
305
+ return
306
+ out_vocab_file = os.path.join(
307
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
308
+ )
309
+
310
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
311
+ copyfile(self.vocab_file, out_vocab_file)
312
+ elif not os.path.isfile(self.vocab_file):
313
+ with open(out_vocab_file, "wb") as fi:
314
+ content_spiece_model = self.sp_model.serialized_model_proto()
315
+ fi.write(content_spiece_model)
316
+
317
+ return (out_vocab_file,)
318
+
319
+
320
+ __all__ = ["AlbertTokenizer"]
venv/lib/python3.13/site-packages/transformers/models/albert/tokenization_albert_fast.py ADDED
@@ -0,0 +1,178 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 Google AI, Google Brain and the HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Tokenization classes for ALBERT model."""
16
+
17
+ import os
18
+ from shutil import copyfile
19
+ from typing import Optional
20
+
21
+ from ...tokenization_utils import AddedToken
22
+ from ...tokenization_utils_fast import PreTrainedTokenizerFast
23
+ from ...utils import is_sentencepiece_available, logging
24
+
25
+
26
+ if is_sentencepiece_available():
27
+ from .tokenization_albert import AlbertTokenizer
28
+ else:
29
+ AlbertTokenizer = None
30
+
31
+ logger = logging.get_logger(__name__)
32
+ VOCAB_FILES_NAMES = {"vocab_file": "spiece.model", "tokenizer_file": "tokenizer.json"}
33
+
34
+
35
+ SPIECE_UNDERLINE = "▁"
36
+
37
+
38
+ class AlbertTokenizerFast(PreTrainedTokenizerFast):
39
+ """
40
+ Construct a "fast" ALBERT tokenizer (backed by HuggingFace's *tokenizers* library). Based on
41
+ [Unigram](https://huggingface.co/docs/tokenizers/python/latest/components.html?highlight=unigram#models). This
42
+ tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should refer to
43
+ this superclass for more information regarding those methods.
44
+
45
+ Args:
46
+ vocab_file (`str`):
47
+ [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that
48
+ contains the vocabulary necessary to instantiate a tokenizer.
49
+ do_lower_case (`bool`, *optional*, defaults to `True`):
50
+ Whether or not to lowercase the input when tokenizing.
51
+ remove_space (`bool`, *optional*, defaults to `True`):
52
+ Whether or not to strip the text when tokenizing (removing excess spaces before and after the string).
53
+ keep_accents (`bool`, *optional*, defaults to `False`):
54
+ Whether or not to keep accents when tokenizing.
55
+ bos_token (`str`, *optional*, defaults to `"[CLS]"`):
56
+ The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier token.
57
+
58
+ <Tip>
59
+
60
+ When building a sequence using special tokens, this is not the token that is used for the beginning of
61
+ sequence. The token used is the `cls_token`.
62
+
63
+ </Tip>
64
+
65
+ eos_token (`str`, *optional*, defaults to `"[SEP]"`):
66
+ The end of sequence token. Note that when building a sequence using special tokens, this is not the token
67
+ that is used for the end of sequence. The token used is the `sep_token`.
68
+ unk_token (`str`, *optional*, defaults to `"<unk>"`):
69
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
70
+ token instead.
71
+ sep_token (`str`, *optional*, defaults to `"[SEP]"`):
72
+ The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
73
+ sequence classification or for a text and a question for question answering. It is also used as the last
74
+ token of a sequence built with special tokens.
75
+ pad_token (`str`, *optional*, defaults to `"<pad>"`):
76
+ The token used for padding, for example when batching sequences of different lengths.
77
+ cls_token (`str`, *optional*, defaults to `"[CLS]"`):
78
+ The classifier token which is used when doing sequence classification (classification of the whole sequence
79
+ instead of per-token classification). It is the first token of the sequence when built with special tokens.
80
+ mask_token (`str`, *optional*, defaults to `"[MASK]"`):
81
+ The token used for masking values. This is the token used when training this model with masked language
82
+ modeling. This is the token which the model will try to predict.
83
+ """
84
+
85
+ vocab_files_names = VOCAB_FILES_NAMES
86
+ slow_tokenizer_class = AlbertTokenizer
87
+
88
+ def __init__(
89
+ self,
90
+ vocab_file=None,
91
+ tokenizer_file=None,
92
+ do_lower_case=True,
93
+ remove_space=True,
94
+ keep_accents=False,
95
+ bos_token="[CLS]",
96
+ eos_token="[SEP]",
97
+ unk_token="<unk>",
98
+ sep_token="[SEP]",
99
+ pad_token="<pad>",
100
+ cls_token="[CLS]",
101
+ mask_token="[MASK]",
102
+ **kwargs,
103
+ ):
104
+ # The mask token behaves like a normal word, i.e. it includes the space before it and
105
+ # is included in the raw text, so there should be a match in a non-normalized sentence.
106
+ mask_token = (
107
+ AddedToken(mask_token, lstrip=True, rstrip=False, normalized=False)
108
+ if isinstance(mask_token, str)
109
+ else mask_token
110
+ )
111
+
112
+ super().__init__(
113
+ vocab_file,
114
+ tokenizer_file=tokenizer_file,
115
+ do_lower_case=do_lower_case,
116
+ remove_space=remove_space,
117
+ keep_accents=keep_accents,
118
+ bos_token=bos_token,
119
+ eos_token=eos_token,
120
+ unk_token=unk_token,
121
+ sep_token=sep_token,
122
+ pad_token=pad_token,
123
+ cls_token=cls_token,
124
+ mask_token=mask_token,
125
+ **kwargs,
126
+ )
127
+
128
+ self.do_lower_case = do_lower_case
129
+ self.remove_space = remove_space
130
+ self.keep_accents = keep_accents
131
+ self.vocab_file = vocab_file
132
+
133
+ def build_inputs_with_special_tokens(
134
+ self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
135
+ ) -> list[int]:
136
+ """
137
+ Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
138
+ adding special tokens. An ALBERT sequence has the following format:
139
+
140
+ - single sequence: `[CLS] X [SEP]`
141
+ - pair of sequences: `[CLS] A [SEP] B [SEP]`
142
+
143
+ Args:
144
+ token_ids_0 (`List[int]`):
145
+ List of IDs to which the special tokens will be added
146
+ token_ids_1 (`List[int]`, *optional*):
147
+ Optional second list of IDs for sequence pairs.
148
+
149
+ Returns:
150
+ `List[int]`: list of [input IDs](../glossary#input-ids) with the appropriate special tokens.
151
+ """
152
+ sep = [self.sep_token_id]
153
+ cls = [self.cls_token_id]
154
+ if token_ids_1 is None:
155
+ return cls + token_ids_0 + sep
156
+ return cls + token_ids_0 + sep + token_ids_1 + sep
157
+
158
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
159
+ if not self.can_save_slow_tokenizer:
160
+ raise ValueError(
161
+ "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
162
+ "tokenizer."
163
+ )
164
+
165
+ if not os.path.isdir(save_directory):
166
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
167
+ return
168
+ out_vocab_file = os.path.join(
169
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
170
+ )
171
+
172
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
173
+ copyfile(self.vocab_file, out_vocab_file)
174
+
175
+ return (out_vocab_file,)
176
+
177
+
178
+ __all__ = ["AlbertTokenizerFast"]
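A minimal sketch of the fast tokenizer producing the same special-token layout as the slow one, assuming the `albert/albert-base-v2` checkpoint; `save_pretrained` writes `tokenizer.json`, and the slow `spiece.model` only when it is available locally.

```python
from transformers import AlbertTokenizerFast

tokenizer = AlbertTokenizerFast.from_pretrained("albert/albert-base-v2")

# a pair of segments follows the same [CLS] A [SEP] B [SEP] layout as the slow tokenizer
encoded = tokenizer("A quick test.", "A second segment.")
tokens = encoded.tokens()
assert tokens[0] == "[CLS]" and tokens[-1] == "[SEP]"

# persists the fast tokenizer.json and, when spiece.model is available, the slow vocabulary
tokenizer.save_pretrained("./albert-tokenizer")
```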
venv/lib/python3.13/site-packages/transformers/models/apertus/__init__.py ADDED
@@ -0,0 +1,32 @@
1
+ # coding=utf-8
2
+ # Copyright 2025 The HuggingFace Inc. team and the Swiss AI Initiative. All rights reserved.
3
+ #
4
+ # This code is based on HuggingFace's LLaMA implementation in this library.
5
+ # It has been modified from its original forms to accommodate the architectural
6
+ # differences made by the Swiss AI Initiative that trained the model.
7
+ #
8
+ # Licensed under the Apache License, Version 2.0 (the "License");
9
+ # you may not use this file except in compliance with the License.
10
+ # You may obtain a copy of the License at
11
+ #
12
+ # http://www.apache.org/licenses/LICENSE-2.0
13
+ #
14
+ # Unless required by applicable law or agreed to in writing, software
15
+ # distributed under the License is distributed on an "AS IS" BASIS,
16
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17
+ # See the License for the specific language governing permissions and
18
+ # limitations under the License.
19
+ from typing import TYPE_CHECKING
20
+
21
+ from ...utils import _LazyModule
22
+ from ...utils.import_utils import define_import_structure
23
+
24
+
25
+ if TYPE_CHECKING:
26
+ from .configuration_apertus import *
27
+ from .modeling_apertus import *
28
+ else:
29
+ import sys
30
+
31
+ _file = globals()["__file__"]
32
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
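The `_LazyModule` indirection above means the Apertus submodules are only imported when first accessed; a sketch of the two equivalent import paths, assuming a `transformers` release that ships Apertus:

```python
# resolved lazily on first attribute access
from transformers.models.apertus import ApertusConfig, ApertusModel

# or via the top-level package, which uses the same lazy machinery
from transformers import ApertusConfig
```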
venv/lib/python3.13/site-packages/transformers/models/apertus/configuration_apertus.py ADDED
@@ -0,0 +1,214 @@
1
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
2
+ # This file was automatically generated from src/transformers/models/apertus/modular_apertus.py.
3
+ # Do NOT edit this file manually as any edits will be overwritten by the generation of
4
+ # the file from the modular. If any change should be done, please apply the change to the
5
+ # modular_apertus.py file directly. One of our CI enforces this.
6
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
7
+ # coding=utf-8
8
+ # Copyright 2025 the HuggingFace Inc. team and the Swiss AI Initiative. All rights reserved.
9
+ #
10
+ #
11
+ # Licensed under the Apache License, Version 2.0 (the "License");
12
+ # you may not use this file except in compliance with the License.
13
+ # You may obtain a copy of the License at
14
+ #
15
+ # http://www.apache.org/licenses/LICENSE-2.0
16
+ #
17
+ # Unless required by applicable law or agreed to in writing, software
18
+ # distributed under the License is distributed on an "AS IS" BASIS,
19
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20
+ # See the License for the specific language governing permissions and
21
+ # limitations under the License.
22
+
23
+ from ...configuration_utils import PretrainedConfig
24
+ from ...modeling_rope_utils import rope_config_validation
25
+
26
+
27
+ class ApertusConfig(PretrainedConfig):
28
+ r"""
29
+ This is the configuration class to store the configuration of an [`ApertusModel`]. It is used to instantiate an Apertus
30
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
31
+ defaults will yield a similar configuration to that of the Apertus-8B.
32
+ e.g. [swiss-ai/Apertus-8B](https://huggingface.co/swiss-ai/Apertus-8B)
33
+
34
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
35
+ documentation from [`PretrainedConfig`] for more information.
36
+
37
+
38
+ Args:
39
+ vocab_size (`int`, *optional*, defaults to 131072):
40
+ Vocabulary size of the Apertus model. Defines the number of different tokens that can be represented by the
41
+ `input_ids` passed when calling [`ApertusModel`].
42
+ hidden_size (`int`, *optional*, defaults to 4096):
43
+ Dimension of the hidden representations.
44
+ intermediate_size (`int`, *optional*, defaults to 14336):
45
+ Dimension of the MLP representations.
46
+ num_hidden_layers (`int`, *optional*, defaults to 32):
47
+ Number of hidden layers in the Transformer decoder.
48
+ num_attention_heads (`int`, *optional*, defaults to 32):
49
+ Number of attention heads for each attention layer in the Transformer decoder.
50
+ num_key_value_heads (`int`, *optional*):
51
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
52
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
53
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
54
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
55
+ by meanpooling all the original heads within that group. For more details, check out [this
56
+ paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
57
+ `num_attention_heads`.
58
+ hidden_act (`str` or `function`, *optional*, defaults to `"xielu"`):
59
+ The non-linear activation function (function or string) in the decoder.
60
+ max_position_embeddings (`int`, *optional*, defaults to 65536):
61
+ The maximum sequence length that this model might ever be used with. Apertus supports up to 65536 tokens.
62
+ initializer_range (`float`, *optional*, defaults to 0.02):
63
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
64
+ rms_norm_eps (`float`, *optional*, defaults to 1e-05):
65
+ The epsilon used by the rms normalization layers.
66
+ use_cache (`bool`, *optional*, defaults to `True`):
67
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
68
+ relevant if `config.is_decoder=True`.
69
+ pad_token_id (`int`, *optional*, defaults to 3):
70
+ Padding token id.
71
+ bos_token_id (`int`, *optional*, defaults to 1):
72
+ Beginning of stream token id.
73
+ eos_token_id (`int`, *optional*, defaults to 2):
74
+ End of stream token id.
75
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
76
+ Whether to tie weight embeddings
77
+ rope_theta (`float`, *optional*, defaults to 12000000.0):
78
+ The base period of the RoPE embeddings.
79
+ rope_scaling (`Dict`, *optional*):
80
+ Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply a new rope type
81
+ and expect the model to work with a longer `max_position_embeddings`, we recommend updating this value
82
+ accordingly.
83
+ Expected contents:
84
+ `rope_type` (`str`):
85
+ The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
86
+ 'llama3'], with 'default' being the original RoPE implementation.
87
+ `factor` (`float`, *optional*):
88
+ Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
89
+ most scaling types, a `factor` of x will enable the model to handle sequences of length x *
90
+ original maximum pre-trained length.
91
+ `original_max_position_embeddings` (`int`, *optional*):
92
+ Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
93
+ pretraining.
94
+ `attention_factor` (`float`, *optional*):
95
+ Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
96
+ computation. If unspecified, it defaults to value recommended by the implementation, using the
97
+ `factor` field to infer the suggested value.
98
+ `beta_fast` (`float`, *optional*):
99
+ Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
100
+ ramp function. If unspecified, it defaults to 32.
101
+ `beta_slow` (`float`, *optional*):
102
+ Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
103
+ ramp function. If unspecified, it defaults to 1.
104
+ `short_factor` (`list[float]`, *optional*):
105
+ Only used with 'longrope'. The scaling factor to be applied to short contexts (<
106
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
107
+ size divided by the number of attention heads divided by 2
108
+ `long_factor` (`list[float]`, *optional*):
109
+ Only used with 'longrope'. The scaling factor to be applied to long contexts (>
110
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
111
+ size divided by the number of attention heads divided by 2
112
+ `low_freq_factor` (`float`, *optional*):
113
+ Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
114
+ `high_freq_factor` (`float`, *optional*):
115
+ Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
116
+ attention_bias (`bool`, *optional*, defaults to `False`):
117
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
118
+ attention_dropout (`float`, *optional*, defaults to 0.0):
119
+ The dropout ratio for the attention probabilities.
120
+
121
+ ```python
122
+ >>> from transformers import ApertusModel, ApertusConfig
123
+
124
+ >>> # Initializing a Apertus-8B style configuration
125
+ >>> configuration = ApertusConfig()
126
+
127
+ >>> # Initializing a model from the Apertus-8B style configuration
128
+ >>> model = ApertusModel(configuration)
129
+
130
+ >>> # Accessing the model configuration
131
+ >>> configuration = model.config
132
+ ```"""
133
+
134
+ model_type = "apertus"
135
+ keys_to_ignore_at_inference = ["past_key_values"]
136
+ base_model_tp_plan = {
137
+ "layers.*.self_attn.q_proj": "colwise_rep", # we need to replicate here due to the added norm on q and k
138
+ "layers.*.self_attn.k_proj": "colwise_rep", # we need to replicate here due to the added norm on q and k
139
+ "layers.*.self_attn.v_proj": "colwise_rep", # we need to replicate here due to the added norm on q and k
140
+ "layers.*.self_attn.o_proj": "rowwise_rep", # we need to replicate here due to the added norm on q and k
141
+ "layers.*.mlp.up_proj": "colwise",
142
+ "layers.*.mlp.down_proj": "rowwise",
143
+ "layers.*.mlp.gate_proj": "colwise",
144
+ }
145
+ base_model_pp_plan = {
146
+ "embed_tokens": (["input_ids"], ["inputs_embeds"]),
147
+ "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
148
+ "norm": (["hidden_states"], ["hidden_states"]),
149
+ }
150
+
151
+ def __init__(
152
+ self,
153
+ vocab_size=131072,
154
+ hidden_size=4096,
155
+ intermediate_size=14336,
156
+ num_hidden_layers=32,
157
+ num_attention_heads=32,
158
+ num_key_value_heads=None,
159
+ hidden_act="xielu",
160
+ max_position_embeddings=65536,
161
+ initializer_range=0.02,
162
+ rms_norm_eps=1e-5,
163
+ use_cache=True,
164
+ pad_token_id=3,
165
+ bos_token_id=1,
166
+ eos_token_id=2,
167
+ tie_word_embeddings=False,
168
+ rope_theta=12000000.0,
169
+ rope_scaling={
170
+ "rope_type": "llama3",
171
+ "factor": 8.0,
172
+ "original_max_position_embeddings": 8192,
173
+ "low_freq_factor": 1.0,
174
+ "high_freq_factor": 4.0,
175
+ },
176
+ attention_bias=False,
177
+ attention_dropout=0.0,
178
+ **kwargs,
179
+ ):
180
+ super().__init__(
181
+ pad_token_id=pad_token_id,
182
+ bos_token_id=bos_token_id,
183
+ eos_token_id=eos_token_id,
184
+ tie_word_embeddings=tie_word_embeddings,
185
+ **kwargs,
186
+ )
187
+ self.vocab_size = vocab_size
188
+ self.max_position_embeddings = max_position_embeddings
189
+ self.hidden_size = hidden_size
190
+ self.intermediate_size = intermediate_size
191
+ self.num_hidden_layers = num_hidden_layers
192
+ self.num_attention_heads = num_attention_heads
193
+
194
+ # for backward compatibility
195
+ if num_key_value_heads is None:
196
+ num_key_value_heads = num_attention_heads
197
+
198
+ self.num_key_value_heads = num_key_value_heads
199
+ self.hidden_act = hidden_act
200
+ self.initializer_range = initializer_range
201
+ self.rms_norm_eps = rms_norm_eps
202
+ self.use_cache = use_cache
203
+ self.rope_theta = rope_theta
204
+ self.rope_scaling = rope_scaling
205
+ self.attention_bias = attention_bias
206
+ self.attention_dropout = attention_dropout
207
+ # Validate the correctness of rotary position embeddings parameters
208
+ # BC: if there is a 'type' field, copy it to 'rope_type'.
209
+ if self.rope_scaling is not None and "type" in self.rope_scaling:
210
+ self.rope_scaling["rope_type"] = self.rope_scaling["type"]
211
+ rope_config_validation(self)
212
+
213
+
214
+ __all__ = ["ApertusConfig"]
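The `rope_scaling` argument above is easiest to read with a concrete override. A hedged sketch that swaps the default `llama3` scaling for YaRN; the factor values are purely illustrative, not recommended settings:

```python
from transformers import ApertusConfig

configuration = ApertusConfig(
    rope_scaling={
        "rope_type": "yarn",
        "factor": 4.0,
        "original_max_position_embeddings": 8192,
    },
)

# rope_config_validation in __init__ checks the dict against the chosen rope_type
print(configuration.rope_scaling["rope_type"])  # 'yarn'
```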
venv/lib/python3.13/site-packages/transformers/models/apertus/modeling_apertus.py ADDED
@@ -0,0 +1,488 @@
1
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
2
+ # This file was automatically generated from src/transformers/models/apertus/modular_apertus.py.
3
+ # Do NOT edit this file manually as any edits will be overwritten by the generation of
4
+ # the file from the modular. If any change should be done, please apply the change to the
5
+ # modular_apertus.py file directly. One of our CI enforces this.
6
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
7
+ # coding=utf-8
8
+ # Copyright 2025 the HuggingFace Inc. team and the Swiss AI Initiative. All rights reserved.
9
+ #
10
+ #
11
+ # Licensed under the Apache License, Version 2.0 (the "License");
12
+ # you may not use this file except in compliance with the License.
13
+ # You may obtain a copy of the License at
14
+ #
15
+ # http://www.apache.org/licenses/LICENSE-2.0
16
+ #
17
+ # Unless required by applicable law or agreed to in writing, software
18
+ # distributed under the License is distributed on an "AS IS" BASIS,
19
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20
+ # See the License for the specific language governing permissions and
21
+ # limitations under the License.
22
+ from typing import Callable, Optional, Union
23
+
24
+ import torch
25
+ from torch import nn
26
+
27
+ from ...activations import ACT2FN
28
+ from ...cache_utils import Cache, DynamicCache
29
+ from ...generation import GenerationMixin
30
+ from ...integrations import use_kernel_forward_from_hub
31
+ from ...masking_utils import create_causal_mask
32
+ from ...modeling_layers import GenericForTokenClassification, GradientCheckpointingLayer
33
+ from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
34
+ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
35
+ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
36
+ from ...processing_utils import Unpack
37
+ from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
38
+ from ...utils.deprecation import deprecate_kwarg
39
+ from ...utils.generic import check_model_inputs
40
+ from .configuration_apertus import ApertusConfig
41
+
42
+
43
+ class ApertusMLP(nn.Module):
44
+ def __init__(self, config):
45
+ super().__init__()
46
+ self.config = config
47
+ self.hidden_size = config.hidden_size
48
+ self.intermediate_size = config.intermediate_size
49
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
50
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
51
+ self.act_fn = ACT2FN[config.hidden_act]
52
+
53
+ def forward(self, x):
54
+ return self.down_proj(self.act_fn(self.up_proj(x)))
55
+
56
+
57
+ @use_kernel_forward_from_hub("RMSNorm")
58
+ class ApertusRMSNorm(nn.Module):
59
+ def __init__(self, hidden_size, eps=1e-6):
60
+ """
61
+ ApertusRMSNorm is equivalent to T5LayerNorm
62
+ """
63
+ super().__init__()
64
+ self.weight = nn.Parameter(torch.ones(hidden_size))
65
+ self.variance_epsilon = eps
66
+
67
+ def forward(self, hidden_states):
68
+ input_dtype = hidden_states.dtype
69
+ hidden_states = hidden_states.to(torch.float32)
70
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
71
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
72
+ return self.weight * hidden_states.to(input_dtype)
73
+
74
+ def extra_repr(self):
75
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
76
+
77
+
78
+ class ApertusRotaryEmbedding(nn.Module):
79
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
80
+
81
+ def __init__(self, config: ApertusConfig, device=None):
82
+ super().__init__()
83
+ # BC: "rope_type" was originally "type"
84
+ if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
85
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
86
+ else:
87
+ self.rope_type = "default"
88
+ self.max_seq_len_cached = config.max_position_embeddings
89
+ self.original_max_seq_len = config.max_position_embeddings
90
+
91
+ self.config = config
92
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
93
+
94
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
95
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
96
+ self.original_inv_freq = self.inv_freq
97
+
98
+ @torch.no_grad()
99
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
100
+ def forward(self, x, position_ids):
101
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
102
+ position_ids_expanded = position_ids[:, None, :].float()
103
+
104
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
105
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
106
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
107
+ emb = torch.cat((freqs, freqs), dim=-1)
108
+ cos = emb.cos() * self.attention_scaling
109
+ sin = emb.sin() * self.attention_scaling
110
+
111
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
112
+
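+ # Illustrative usage sketch (toy shapes, assuming the default ApertusConfig with hidden_size=4096,
+ # 32 attention heads and therefore head_dim=128). The rotary embedding only uses `x` for its
+ # device and dtype; cos/sin are computed from `position_ids`:
+ #
+ #     >>> rope = ApertusRotaryEmbedding(ApertusConfig())
+ #     >>> x = torch.zeros(1, 10, 4096)                    # (batch, seq_len, hidden_size)
+ #     >>> position_ids = torch.arange(10)[None, :]        # (batch, seq_len)
+ #     >>> cos, sin = rope(x, position_ids)
+ #     >>> cos.shape, sin.shape
+ #     (torch.Size([1, 10, 128]), torch.Size([1, 10, 128]))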
113
+
114
+ def rotate_half(x):
115
+ """Rotates half the hidden dims of the input."""
116
+ x1 = x[..., : x.shape[-1] // 2]
117
+ x2 = x[..., x.shape[-1] // 2 :]
118
+ return torch.cat((-x2, x1), dim=-1)
119
+
120
+
121
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
122
+ """Applies Rotary Position Embedding to the query and key tensors.
123
+
124
+ Args:
125
+ q (`torch.Tensor`): The query tensor.
126
+ k (`torch.Tensor`): The key tensor.
127
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
128
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
129
+ position_ids (`torch.Tensor`, *optional*):
130
+ Deprecated and unused.
131
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
132
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
133
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
134
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
135
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
136
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
137
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
138
+ Returns:
139
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
140
+ """
141
+ cos = cos.unsqueeze(unsqueeze_dim)
142
+ sin = sin.unsqueeze(unsqueeze_dim)
143
+ q_embed = (q * cos) + (rotate_half(q) * sin)
144
+ k_embed = (k * cos) + (rotate_half(k) * sin)
145
+ return q_embed, k_embed
146
+
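+ # Shape sketch for the broadcasting described in the docstring above (toy values, head_dim=128):
+ #
+ #     >>> q = torch.randn(1, 32, 10, 128)     # (batch, num_attention_heads, seq_len, head_dim)
+ #     >>> k = torch.randn(1, 8, 10, 128)      # (batch, num_key_value_heads, seq_len, head_dim)
+ #     >>> cos, sin = rope(q, torch.arange(10)[None, :])   # (1, 10, 128) each, see sketch above
+ #     >>> q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)
+ #     >>> q_rot.shape, k_rot.shape            # unsqueeze_dim=1 broadcasts cos/sin over the head axis
+ #     (torch.Size([1, 32, 10, 128]), torch.Size([1, 8, 10, 128]))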
147
+
148
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
149
+ """
150
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
151
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
152
+ """
153
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
154
+ if n_rep == 1:
155
+ return hidden_states
156
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
157
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
158
+
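+ # Shape sketch: with 32 query heads and 8 key/value heads, n_rep = 32 // 8 = 4, so every
+ # key/value head is repeated 4 times to line up with the query heads:
+ #
+ #     >>> kv = torch.randn(1, 8, 10, 128)     # (batch, num_key_value_heads, seq_len, head_dim)
+ #     >>> repeat_kv(kv, 4).shape
+ #     torch.Size([1, 32, 10, 128])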
159
+
160
+ def eager_attention_forward(
161
+ module: nn.Module,
162
+ query: torch.Tensor,
163
+ key: torch.Tensor,
164
+ value: torch.Tensor,
165
+ attention_mask: Optional[torch.Tensor],
166
+ scaling: float,
167
+ dropout: float = 0.0,
168
+ **kwargs: Unpack[TransformersKwargs],
169
+ ):
170
+ key_states = repeat_kv(key, module.num_key_value_groups)
171
+ value_states = repeat_kv(value, module.num_key_value_groups)
172
+
173
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
174
+ if attention_mask is not None:
175
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
176
+ attn_weights = attn_weights + causal_mask
177
+
178
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
179
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
180
+ attn_output = torch.matmul(attn_weights, value_states)
181
+ attn_output = attn_output.transpose(1, 2).contiguous()
182
+
183
+ return attn_output, attn_weights
184
+
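+ # Minimal sketch of driving the eager path directly (illustration only): the function reads only
+ # `num_key_value_groups` and `training` from `module`, so a bare namespace is enough here:
+ #
+ #     >>> from types import SimpleNamespace
+ #     >>> mod = SimpleNamespace(num_key_value_groups=4, training=False)
+ #     >>> q = torch.randn(1, 32, 10, 128)
+ #     >>> k, v = torch.randn(1, 8, 10, 128), torch.randn(1, 8, 10, 128)
+ #     >>> out, w = eager_attention_forward(mod, q, k, v, attention_mask=None, scaling=128 ** -0.5)
+ #     >>> out.shape, w.shape                  # output comes back as (batch, seq_len, heads, head_dim)
+ #     (torch.Size([1, 10, 32, 128]), torch.Size([1, 32, 10, 10]))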
185
+
186
+ class ApertusAttention(nn.Module):
187
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
188
+
189
+ def __init__(self, config: ApertusConfig, layer_idx: Optional[int] = None):
190
+ super().__init__()
191
+ self.config = config
192
+ self.layer_idx = layer_idx
193
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
194
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
195
+ self.scaling = self.head_dim**-0.5
196
+ self.attention_dropout = config.attention_dropout
197
+ self.is_causal = True
198
+
199
+ self.q_proj = nn.Linear(
200
+ config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
201
+ )
202
+ self.k_proj = nn.Linear(
203
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
204
+ )
205
+ self.v_proj = nn.Linear(
206
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
207
+ )
208
+ self.o_proj = nn.Linear(
209
+ config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
210
+ )
211
+ self.q_norm = ApertusRMSNorm(self.head_dim, config.rms_norm_eps)
212
+ self.k_norm = ApertusRMSNorm(self.head_dim, config.rms_norm_eps)
213
+
214
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
215
+ def forward(
216
+ self,
217
+ hidden_states: torch.Tensor,
218
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
219
+ attention_mask: Optional[torch.Tensor],
220
+ past_key_values: Optional[Cache] = None,
221
+ cache_position: Optional[torch.LongTensor] = None,
222
+ **kwargs: Unpack[TransformersKwargs],
223
+ ) -> tuple[torch.Tensor, torch.Tensor]:
224
+ input_shape = hidden_states.shape[:-1]
225
+ hidden_shape = (*input_shape, -1, self.head_dim)
226
+
227
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
228
+ key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
229
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
230
+ query_states = self.q_norm(query_states)
231
+ key_states = self.k_norm(key_states)
232
+
233
+ cos, sin = position_embeddings
234
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
235
+
236
+ if past_key_values is not None:
237
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
238
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
239
+
240
+ attention_interface: Callable = eager_attention_forward
241
+ if self.config._attn_implementation != "eager":
242
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
243
+
244
+ attn_output, attn_weights = attention_interface(
245
+ self,
246
+ query_states,
247
+ key_states,
248
+ value_states,
249
+ attention_mask,
250
+ dropout=0.0 if not self.training else self.attention_dropout,
251
+ scaling=self.scaling,
252
+ **kwargs,
253
+ )
254
+
255
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
256
+ attn_output = self.o_proj(attn_output)
257
+ return attn_output, attn_weights
258
+
259
+
260
+ class ApertusDecoderLayer(GradientCheckpointingLayer):
261
+ def __init__(self, config: ApertusConfig, layer_idx: int):
262
+ super().__init__()
263
+ self.hidden_size = config.hidden_size
264
+
265
+ self.self_attn = ApertusAttention(config=config, layer_idx=layer_idx)
266
+
267
+ self.mlp = ApertusMLP(config)
268
+ self.attention_layernorm = ApertusRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
269
+ self.feedforward_layernorm = ApertusRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
270
+
271
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
272
+ def forward(
273
+ self,
274
+ hidden_states: torch.Tensor,
275
+ attention_mask: Optional[torch.Tensor] = None,
276
+ position_ids: Optional[torch.LongTensor] = None,
277
+ past_key_values: Optional[Cache] = None,
278
+ use_cache: Optional[bool] = False,
279
+ cache_position: Optional[torch.LongTensor] = None,
280
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
281
+ **kwargs: Unpack[TransformersKwargs],
282
+ ) -> tuple[torch.Tensor]:
283
+ residual = hidden_states
284
+ hidden_states = self.attention_layernorm(hidden_states)
285
+ hidden_states, _ = self.self_attn(
286
+ hidden_states=hidden_states,
287
+ attention_mask=attention_mask,
288
+ position_ids=position_ids,
289
+ past_key_values=past_key_values,
290
+ use_cache=use_cache,
291
+ cache_position=cache_position,
292
+ position_embeddings=position_embeddings,
293
+ **kwargs,
294
+ )
295
+ hidden_states = residual + hidden_states
296
+
297
+ # Fully Connected
298
+ residual = hidden_states
299
+ hidden_states = self.feedforward_layernorm(hidden_states)
300
+ hidden_states = self.mlp(hidden_states)
301
+ hidden_states = residual + hidden_states
302
+ return hidden_states
303
+
304
+
305
+ @auto_docstring
306
+ class ApertusPreTrainedModel(PreTrainedModel):
307
+ config: ApertusConfig
308
+ base_model_prefix = "model"
309
+ supports_gradient_checkpointing = True
310
+ _no_split_modules = ["ApertusDecoderLayer"]
311
+ _skip_keys_device_placement = ["past_key_values"]
312
+ _supports_flash_attn = True
313
+ _supports_sdpa = True
314
+ _supports_flex_attn = True
315
+
316
+ _can_compile_fullgraph = True
317
+ _supports_attention_backend = True
318
+ _can_record_outputs = {
319
+ "hidden_states": ApertusDecoderLayer,
320
+ "attentions": ApertusAttention,
321
+ }
322
+
323
+
324
+ @auto_docstring
325
+ class ApertusModel(ApertusPreTrainedModel):
326
+ def __init__(self, config: ApertusConfig):
327
+ super().__init__(config)
328
+ self.padding_idx = config.pad_token_id
329
+ self.vocab_size = config.vocab_size
330
+
331
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
332
+ self.layers = nn.ModuleList(
333
+ [ApertusDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
334
+ )
335
+ self.norm = ApertusRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
336
+ self.rotary_emb = ApertusRotaryEmbedding(config=config)
337
+ self.gradient_checkpointing = False
338
+
339
+ # Initialize weights and apply final processing
340
+ self.post_init()
341
+
342
+ @check_model_inputs()
343
+ @auto_docstring
344
+ def forward(
345
+ self,
346
+ input_ids: Optional[torch.LongTensor] = None,
347
+ attention_mask: Optional[torch.Tensor] = None,
348
+ position_ids: Optional[torch.LongTensor] = None,
349
+ past_key_values: Optional[Cache] = None,
350
+ inputs_embeds: Optional[torch.FloatTensor] = None,
351
+ cache_position: Optional[torch.LongTensor] = None,
352
+ use_cache: Optional[bool] = None,
353
+ **kwargs: Unpack[TransformersKwargs],
354
+ ) -> BaseModelOutputWithPast:
355
+ if (input_ids is None) ^ (inputs_embeds is not None):
356
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
357
+
358
+ if inputs_embeds is None:
359
+ inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)
360
+
361
+ if use_cache and past_key_values is None:
362
+ past_key_values = DynamicCache(config=self.config)
363
+
364
+ if cache_position is None:
365
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
366
+ cache_position: torch.Tensor = torch.arange(
367
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
368
+ )
369
+
370
+ if position_ids is None:
371
+ position_ids = cache_position.unsqueeze(0)
372
+
373
+ causal_mask = create_causal_mask(
374
+ config=self.config,
375
+ input_embeds=inputs_embeds,
376
+ attention_mask=attention_mask,
377
+ cache_position=cache_position,
378
+ past_key_values=past_key_values,
379
+ position_ids=position_ids,
380
+ )
381
+
382
+ hidden_states = inputs_embeds
383
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
384
+
385
+ for decoder_layer in self.layers[: self.config.num_hidden_layers]:
386
+ hidden_states = decoder_layer(
387
+ hidden_states,
388
+ attention_mask=causal_mask,
389
+ position_ids=position_ids,
390
+ past_key_values=past_key_values,
391
+ cache_position=cache_position,
392
+ position_embeddings=position_embeddings,
393
+ **kwargs,
394
+ )
395
+
396
+ hidden_states = self.norm(hidden_states)
397
+ return BaseModelOutputWithPast(
398
+ last_hidden_state=hidden_states,
399
+ past_key_values=past_key_values,
400
+ )
401
+
402
+
403
+ @auto_docstring
404
+ class ApertusForCausalLM(ApertusPreTrainedModel, GenerationMixin):
405
+ _tied_weights_keys = ["lm_head.weight"]
406
+ _tp_plan = {"lm_head": "colwise_rep"}
407
+ _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
408
+
409
+ def __init__(self, config):
410
+ super().__init__(config)
411
+ self.model = ApertusModel(config)
412
+ self.vocab_size = config.vocab_size
413
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
414
+
415
+ # Initialize weights and apply final processing
416
+ self.post_init()
417
+
418
+ @can_return_tuple
419
+ @auto_docstring
420
+ def forward(
421
+ self,
422
+ input_ids: Optional[torch.LongTensor] = None,
423
+ attention_mask: Optional[torch.Tensor] = None,
424
+ position_ids: Optional[torch.LongTensor] = None,
425
+ past_key_values: Optional[Cache] = None,
426
+ inputs_embeds: Optional[torch.FloatTensor] = None,
427
+ labels: Optional[torch.LongTensor] = None,
428
+ use_cache: Optional[bool] = None,
429
+ cache_position: Optional[torch.LongTensor] = None,
430
+ logits_to_keep: Union[int, torch.Tensor] = 0,
431
+ **kwargs: Unpack[TransformersKwargs],
432
+ ) -> CausalLMOutputWithPast:
433
+ r"""
434
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
435
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
436
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
437
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
438
+
439
+ Example:
440
+
441
+ ```python
442
+ >>> from transformers import AutoTokenizer, ApertusForCausalLM
443
+
444
+ >>> model = ApertusForCausalLM.from_pretrained("swiss-ai/Apertus-8B")
445
+ >>> tokenizer = AutoTokenizer.from_pretrained("swiss-ai/Apertus-8B")
446
+
447
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
448
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
449
+
450
+ >>> # Generate
451
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
452
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
453
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
454
+ ```"""
455
+ outputs: BaseModelOutputWithPast = self.model(
456
+ input_ids=input_ids,
457
+ attention_mask=attention_mask,
458
+ position_ids=position_ids,
459
+ past_key_values=past_key_values,
460
+ inputs_embeds=inputs_embeds,
461
+ use_cache=use_cache,
462
+ cache_position=cache_position,
463
+ **kwargs,
464
+ )
465
+
466
+ hidden_states = outputs.last_hidden_state
467
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
468
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
469
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
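+ # e.g. during generation `logits_to_keep` is typically 1, so `slice(-1, None)` keeps only the last
+ # position and the lm_head runs on a (batch, 1, hidden_size) slice instead of the full sequence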
470
+
471
+ loss = None
472
+ if labels is not None:
473
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
474
+
475
+ return CausalLMOutputWithPast(
476
+ loss=loss,
477
+ logits=logits,
478
+ past_key_values=outputs.past_key_values,
479
+ hidden_states=outputs.hidden_states,
480
+ attentions=outputs.attentions,
481
+ )
482
+
483
+
484
+ class ApertusForTokenClassification(GenericForTokenClassification, ApertusPreTrainedModel):
485
+ pass
486
+
487
+
488
+ __all__ = ["ApertusModel", "ApertusForCausalLM", "ApertusForTokenClassification", "ApertusPreTrainedModel"]
venv/lib/python3.13/site-packages/transformers/models/apertus/modular_apertus.py ADDED
@@ -0,0 +1,371 @@
1
+ # coding=utf-8
2
+ # Copyright 2025 the HuggingFace Inc. team and the Swiss AI Initiative. All rights reserved.
3
+ #
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ from typing import Callable, Optional
17
+
18
+ import torch
19
+ from torch import nn
20
+
21
+ from ...cache_utils import Cache
22
+ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
23
+ from ...processing_utils import Unpack
24
+ from ...utils import TransformersKwargs, logging
25
+ from ..llama.configuration_llama import LlamaConfig
26
+ from ..llama.modeling_llama import (
27
+ LlamaAttention,
28
+ LlamaDecoderLayer,
29
+ LlamaForCausalLM,
30
+ LlamaForTokenClassification,
31
+ LlamaModel,
32
+ LlamaPreTrainedModel,
33
+ LlamaRMSNorm,
34
+ LlamaRotaryEmbedding,
35
+ apply_rotary_pos_emb,
36
+ eager_attention_forward,
37
+ )
38
+ from ..nemotron.modeling_nemotron import NemotronMLP
39
+
40
+
41
+ logger = logging.get_logger(__name__)
42
+
43
+
44
+ class ApertusConfig(LlamaConfig):
45
+ r"""
46
+ This is the configuration class to store the configuration of an [`ApertusModel`]. It is used to instantiate an Apertus
47
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
48
+ defaults will yield a similar configuration to that of the Apertus-8B.
49
+ e.g. [swiss-ai/Apertus-8B](https://huggingface.co/swiss-ai/Apertus-8B)
50
+
51
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
52
+ documentation from [`PretrainedConfig`] for more information.
53
+
54
+
55
+ Args:
56
+ vocab_size (`int`, *optional*, defaults to 131072):
57
+ Vocabulary size of the Apertus model. Defines the number of different tokens that can be represented by the
58
+ `inputs_ids` passed when calling [`ApertusModel`]
59
+ hidden_size (`int`, *optional*, defaults to 4096):
60
+ Dimension of the hidden representations.
61
+ intermediate_size (`int`, *optional*, defaults to 14336):
62
+ Dimension of the MLP representations.
63
+ num_hidden_layers (`int`, *optional*, defaults to 32):
64
+ Number of hidden layers in the Transformer decoder.
65
+ num_attention_heads (`int`, *optional*, defaults to 32):
66
+ Number of attention heads for each attention layer in the Transformer decoder.
67
+ num_key_value_heads (`int`, *optional*):
68
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
69
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA); if
70
+ `num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used. When
71
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
72
+ by meanpooling all the original heads within that group. For more details, check out [this
73
+ paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
74
+ `num_attention_heads`.
75
+ hidden_act (`str` or `function`, *optional*, defaults to `"xielu"`):
76
+ The non-linear activation function (function or string) in the decoder.
77
+ max_position_embeddings (`int`, *optional*, defaults to 65536):
78
+ The maximum sequence length that this model might ever be used with. Apertus supports up to 65536 tokens.
79
+ initializer_range (`float`, *optional*, defaults to 0.02):
80
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
81
+ rms_norm_eps (`float`, *optional*, defaults to 1e-05):
82
+ The epsilon used by the rms normalization layers.
83
+ use_cache (`bool`, *optional*, defaults to `True`):
84
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
85
+ relevant if `config.is_decoder=True`.
86
+ pad_token_id (`int`, *optional*, defaults to 3):
87
+ Padding token id.
88
+ bos_token_id (`int`, *optional*, defaults to 1):
89
+ Beginning of stream token id.
90
+ eos_token_id (`int`, *optional*, defaults to 2):
91
+ End of stream token id.
92
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
93
+ Whether to tie weight embeddings
94
+ rope_theta (`float`, *optional*, defaults to 12000000.0):
95
+ The base period of the RoPE embeddings.
96
+ rope_scaling (`Dict`, *optional*):
97
+ Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply a new rope type
98
+ and expect the model to work on longer `max_position_embeddings`, we recommend updating this value
99
+ accordingly.
100
+ Expected contents:
101
+ `rope_type` (`str`):
102
+ The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
103
+ 'llama3'], with 'default' being the original RoPE implementation.
104
+ `factor` (`float`, *optional*):
105
+ Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
106
+ most scaling types, a `factor` of x will enable the model to handle sequences of length x *
107
+ original maximum pre-trained length.
108
+ `original_max_position_embeddings` (`int`, *optional*):
109
+ Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
110
+ pretraining.
111
+ `attention_factor` (`float`, *optional*):
112
+ Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
113
+ computation. If unspecified, it defaults to value recommended by the implementation, using the
114
+ `factor` field to infer the suggested value.
115
+ `beta_fast` (`float`, *optional*):
116
+ Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
117
+ ramp function. If unspecified, it defaults to 32.
118
+ `beta_slow` (`float`, *optional*):
119
+ Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
120
+ ramp function. If unspecified, it defaults to 1.
121
+ `short_factor` (`list[float]`, *optional*):
122
+ Only used with 'longrope'. The scaling factor to be applied to short contexts (<
123
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
124
+ size divided by the number of attention heads divided by 2
125
+ `long_factor` (`list[float]`, *optional*):
126
+ Only used with 'longrope'. The scaling factor to be applied to long contexts (>
127
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
128
+ size divided by the number of attention heads divided by 2
129
+ `low_freq_factor` (`float`, *optional*):
130
+ Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
131
+ `high_freq_factor` (`float`, *optional*):
132
+ Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
133
+ attention_bias (`bool`, *optional*, defaults to `False`):
134
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
135
+ attention_dropout (`float`, *optional*, defaults to 0.0):
136
+ The dropout ratio for the attention probabilities.
137
+
138
+ ```python
139
+ >>> from transformers import ApertusModel, ApertusConfig
140
+
141
+ >>> # Initializing an Apertus-8B style configuration
142
+ >>> configuration = ApertusConfig()
143
+
144
+ >>> # Initializing a model from the Apertus-8B style configuration
145
+ >>> model = ApertusModel(configuration)
146
+
147
+ >>> # Accessing the model configuration
148
+ >>> configuration = model.config
149
+ ```"""
150
+
151
+ model_type = "apertus"
152
+ base_model_tp_plan = {
153
+ "layers.*.self_attn.q_proj": "colwise_rep", # we need to replicate here due to the added norm on q and k
154
+ "layers.*.self_attn.k_proj": "colwise_rep", # we need to replicate here due to the added norm on q and k
155
+ "layers.*.self_attn.v_proj": "colwise_rep", # we need to replicate here due to the added norm on q and k
156
+ "layers.*.self_attn.o_proj": "rowwise_rep", # we need to replicate here due to the added norm on q and k
157
+ "layers.*.mlp.up_proj": "colwise",
158
+ "layers.*.mlp.down_proj": "rowwise",
159
+ "layers.*.mlp.gate_proj": "colwise",
160
+ }
161
+
162
+ def __init__(
163
+ self,
164
+ vocab_size=131072,
165
+ hidden_size=4096,
166
+ intermediate_size=14336,
167
+ num_hidden_layers=32,
168
+ num_attention_heads=32,
169
+ num_key_value_heads=None,
170
+ hidden_act="xielu",
171
+ max_position_embeddings=65536,
172
+ initializer_range=0.02,
173
+ rms_norm_eps=1e-5,
174
+ use_cache=True,
175
+ pad_token_id=3,
176
+ bos_token_id=1,
177
+ eos_token_id=2,
178
+ tie_word_embeddings=False,
179
+ rope_theta=12000000.0,
180
+ rope_scaling={
181
+ "rope_type": "llama3",
182
+ "factor": 8.0,
183
+ "original_max_position_embeddings": 8192,
184
+ "low_freq_factor": 1.0,
185
+ "high_freq_factor": 4.0,
186
+ },
187
+ attention_bias=False,
188
+ attention_dropout=0.0,
189
+ **kwargs,
190
+ ):
191
+ super().__init__(
192
+ vocab_size=vocab_size,
193
+ hidden_size=hidden_size,
194
+ intermediate_size=intermediate_size,
195
+ num_hidden_layers=num_hidden_layers,
196
+ num_attention_heads=num_attention_heads,
197
+ num_key_value_heads=num_key_value_heads,
198
+ hidden_act=hidden_act,
199
+ max_position_embeddings=max_position_embeddings,
200
+ initializer_range=initializer_range,
201
+ rms_norm_eps=rms_norm_eps,
202
+ use_cache=use_cache,
203
+ pad_token_id=pad_token_id,
204
+ bos_token_id=bos_token_id,
205
+ eos_token_id=eos_token_id,
206
+ tie_word_embeddings=tie_word_embeddings,
207
+ rope_theta=rope_theta,
208
+ rope_scaling=rope_scaling,
209
+ attention_bias=attention_bias,
210
+ attention_dropout=attention_dropout,
211
+ **kwargs,
212
+ )
213
+ del self.pretraining_tp
214
+ del self.mlp_bias
215
+ del self.head_dim
216
+
217
+
218
+ class ApertusMLP(NemotronMLP):
219
+ def __init__(self, config):
220
+ super().__init__(config)
221
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
222
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
223
+
224
+
225
+ class ApertusRMSNorm(LlamaRMSNorm):
226
+ pass
227
+
228
+
229
+ class ApertusRotaryEmbedding(LlamaRotaryEmbedding):
230
+ pass
231
+
232
+
233
+ class ApertusAttention(LlamaAttention):
234
+ def __init__(self, config: ApertusConfig, layer_idx: Optional[int] = None):
235
+ super().__init__(config, layer_idx)
236
+ self.q_norm = ApertusRMSNorm(self.head_dim, config.rms_norm_eps)
237
+ self.k_norm = ApertusRMSNorm(self.head_dim, config.rms_norm_eps)
238
+
239
+ def forward(
240
+ self,
241
+ hidden_states: torch.Tensor,
242
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
243
+ attention_mask: Optional[torch.Tensor],
244
+ past_key_values: Optional[Cache] = None,
245
+ cache_position: Optional[torch.LongTensor] = None,
246
+ **kwargs: Unpack[TransformersKwargs],
247
+ ) -> tuple[torch.Tensor, torch.Tensor]:
248
+ input_shape = hidden_states.shape[:-1]
249
+ hidden_shape = (*input_shape, -1, self.head_dim)
250
+
251
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
252
+ key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
253
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
254
+ query_states = self.q_norm(query_states)
255
+ key_states = self.k_norm(key_states)
256
+
257
+ cos, sin = position_embeddings
258
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
259
+
260
+ if past_key_values is not None:
261
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
262
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
263
+
264
+ attention_interface: Callable = eager_attention_forward
265
+ if self.config._attn_implementation != "eager":
266
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
267
+
268
+ attn_output, attn_weights = attention_interface(
269
+ self,
270
+ query_states,
271
+ key_states,
272
+ value_states,
273
+ attention_mask,
274
+ dropout=0.0 if not self.training else self.attention_dropout,
275
+ scaling=self.scaling,
276
+ **kwargs,
277
+ )
278
+
279
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
280
+ attn_output = self.o_proj(attn_output)
281
+ return attn_output, attn_weights
282
+
283
+
284
+ class ApertusDecoderLayer(LlamaDecoderLayer):
285
+ def __init__(self, config: ApertusConfig, layer_idx: int):
286
+ super().__init__(config, layer_idx)
287
+ self.attention_layernorm = ApertusRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
288
+ self.feedforward_layernorm = ApertusRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
289
+
290
+ del self.input_layernorm
291
+ del self.post_attention_layernorm
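+ # The parent LlamaDecoderLayer registers `input_layernorm` and `post_attention_layernorm`;
+ # Apertus uses the renamed `attention_layernorm`/`feedforward_layernorm` modules above, so the
+ # inherited Llama attributes are deleted to keep them out of the generated model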
292
+
293
+ def forward(
294
+ self,
295
+ hidden_states: torch.Tensor,
296
+ attention_mask: Optional[torch.Tensor] = None,
297
+ position_ids: Optional[torch.LongTensor] = None,
298
+ past_key_values: Optional[Cache] = None,
299
+ use_cache: Optional[bool] = False,
300
+ cache_position: Optional[torch.LongTensor] = None,
301
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
302
+ **kwargs: Unpack[TransformersKwargs],
303
+ ) -> tuple[torch.Tensor]:
304
+ residual = hidden_states
305
+ hidden_states = self.attention_layernorm(hidden_states)
306
+ hidden_states, _ = self.self_attn(
307
+ hidden_states=hidden_states,
308
+ attention_mask=attention_mask,
309
+ position_ids=position_ids,
310
+ past_key_values=past_key_values,
311
+ use_cache=use_cache,
312
+ cache_position=cache_position,
313
+ position_embeddings=position_embeddings,
314
+ **kwargs,
315
+ )
316
+ hidden_states = residual + hidden_states
317
+
318
+ # Fully Connected
319
+ residual = hidden_states
320
+ hidden_states = self.feedforward_layernorm(hidden_states)
321
+ hidden_states = self.mlp(hidden_states)
322
+ hidden_states = residual + hidden_states
323
+ return hidden_states
324
+
325
+
326
+ class ApertusPreTrainedModel(LlamaPreTrainedModel):
327
+ pass
328
+
329
+
330
+ class ApertusModel(LlamaModel):
331
+ pass
332
+
333
+
334
+ class ApertusForCausalLM(LlamaForCausalLM):
335
+ def forward(self, **super_kwargs):
336
+ r"""
337
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
338
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
339
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
340
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
341
+
342
+ Example:
343
+
344
+ ```python
345
+ >>> from transformers import AutoTokenizer, ApertusForCausalLM
346
+
347
+ >>> model = ApertusForCausalLM.from_pretrained("swiss-ai/Apertus-8B")
348
+ >>> tokenizer = AutoTokenizer.from_pretrained("swiss-ai/Apertus-8B")
349
+
350
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
351
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
352
+
353
+ >>> # Generate
354
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
355
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
356
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
357
+ ```"""
358
+ return super().forward(**super_kwargs)
359
+
360
+
361
+ class ApertusForTokenClassification(LlamaForTokenClassification):
362
+ pass
363
+
364
+
365
+ __all__ = [
366
+ "ApertusConfig",
367
+ "ApertusModel",
368
+ "ApertusForCausalLM",
369
+ "ApertusForTokenClassification",
370
+ "ApertusPreTrainedModel",
371
+ ]
venv/lib/python3.13/site-packages/transformers/models/arcee/__init__.py ADDED
@@ -0,0 +1,27 @@
1
+ # Copyright 2025 Arcee AI and the HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import TYPE_CHECKING
15
+
16
+ from ...utils import _LazyModule
17
+ from ...utils.import_utils import define_import_structure
18
+
19
+
20
+ if TYPE_CHECKING:
21
+ from .configuration_arcee import *
22
+ from .modeling_arcee import *
23
+ else:
24
+ import sys
25
+
26
+ _file = globals()["__file__"]
27
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
venv/lib/python3.13/site-packages/transformers/models/arcee/configuration_arcee.py ADDED
@@ -0,0 +1,201 @@
1
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
2
+ # This file was automatically generated from src/transformers/models/arcee/modular_arcee.py.
3
+ # Do NOT edit this file manually as any edits will be overwritten by the generation of
4
+ # the file from the modular. If any change should be done, please apply the change to the
5
+ # modular_arcee.py file directly. One of our CI enforces this.
6
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
7
+ # coding=utf-8
8
+ # Copyright 2025 Arcee AI and the HuggingFace Inc. team. All rights reserved.
9
+ #
10
+ # Licensed under the Apache License, Version 2.0 (the "License");
11
+ # you may not use this file except in compliance with the License.
12
+ # You may obtain a copy of the License at
13
+ #
14
+ # http://www.apache.org/licenses/LICENSE-2.0
15
+ #
16
+ # Unless required by applicable law or agreed to in writing, software
17
+ # distributed under the License is distributed on an "AS IS" BASIS,
18
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19
+ # See the License for the specific language governing permissions and
20
+ # limitations under the License.
21
+
22
+ from ...configuration_utils import PretrainedConfig
23
+ from ...modeling_rope_utils import rope_config_validation
24
+
25
+
26
+ class ArceeConfig(PretrainedConfig):
27
+ r"""
28
+ This is the configuration class to store the configuration of an [`ArceeModel`]. It is used to instantiate an Arcee
29
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
30
+ defaults will yield a similar configuration to that of the AFM-4.5B-Base.
31
+
32
+ Pre-trained weights are available at
33
+ [arcee-ai/AFM-4.5B](https://huggingface.co/arcee-ai/AFM-4.5B)
34
+ and were used to build the examples below.
35
+
36
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
37
+ documentation from [`PretrainedConfig`] for more information.
38
+
39
+ Args:
40
+ vocab_size (`int`, *optional*, defaults to 32000):
41
+ Vocabulary size of the Arcee model. Defines the number of different tokens that can be represented by the
42
+ `inputs_ids` passed when calling [`ArceeModel`]
43
+ hidden_size (`int`, *optional*, defaults to 2560):
44
+ Dimension of the hidden representations.
45
+ intermediate_size (`int`, *optional*, defaults to 18432):
46
+ Dimension of the MLP representations.
47
+ num_hidden_layers (`int`, *optional*, defaults to 32):
48
+ Number of hidden layers in the Transformer decoder.
49
+ num_attention_heads (`int`, *optional*, defaults to 32):
50
+ Number of attention heads for each attention layer in the Transformer decoder.
51
+ num_key_value_heads (`int`, *optional*):
52
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
53
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA); if
54
+ `num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used. When
55
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
56
+ by meanpooling all the original heads within that group. For more details, check out [this
57
+ paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
58
+ `num_attention_heads`.
59
+ hidden_act (`str` or `function`, *optional*, defaults to `"relu2"`):
60
+ The non-linear activation function (function or string) in the decoder.
61
+ max_position_embeddings (`int`, *optional*, defaults to 4096):
62
+ The maximum sequence length that this model might ever be used with. AFM-4.5B-Base supports up to 16384 tokens.
63
+ initializer_range (`float`, *optional*, defaults to 0.02):
64
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
65
+ rms_norm_eps (`float`, *optional*, defaults to 1e-05):
66
+ The epsilon used by the rms normalization layers.
67
+ use_cache (`bool`, *optional*, defaults to `True`):
68
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
69
+ relevant if `config.is_decoder=True`.
70
+ pad_token_id (`int`, *optional*):
71
+ Padding token id.
72
+ bos_token_id (`int`, *optional*, defaults to 128000):
73
+ Beginning of stream token id.
74
+ eos_token_id (`int`, *optional*, defaults to 128001):
75
+ End of stream token id.
76
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
77
+ Whether to tie weight embeddings
78
+ rope_theta (`float`, *optional*, defaults to 10000.0):
79
+ The base period of the RoPE embeddings.
80
+ rope_scaling (`Dict`, *optional*):
81
+ Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply a new rope type
82
+ and expect the model to work on longer `max_position_embeddings`, we recommend updating this value
83
+ accordingly.
84
+ Expected contents:
85
+ `rope_type` (`str`):
86
+ The sub-variant of RoPE to use. Can be one of ['default', 'yarn'], with 'default' being the original RoPE implementation.
87
+ `factor` (`float`, *optional*):
88
+ Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
89
+ most scaling types, a `factor` of x will enable the model to handle sequences of length x *
90
+ original maximum pre-trained length.
91
+ `original_max_position_embeddings` (`int`, *optional*):
92
+ Used with 'yarn'. The original max position embeddings used during pretraining.
93
+ `attention_factor` (`float`, *optional*):
94
+ Used with 'yarn'. The scaling factor to be applied on the attention computation. If unspecified,
95
+ it defaults to value recommended by the implementation, using the `factor` field to infer the suggested value.
96
+ `beta_fast` (`float`, *optional*):
97
+ Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
98
+ ramp function. If unspecified, it defaults to 32.
99
+ `beta_slow` (`float`, *optional*):
100
+ Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
101
+ ramp function. If unspecified, it defaults to 1.
102
+ attention_bias (`bool`, *optional*, defaults to `False`):
103
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
104
+ attention_dropout (`float`, *optional*, defaults to 0.0):
105
+ The dropout ratio for the attention probabilities.
106
+ mlp_bias (`bool`, *optional*, defaults to `False`):
107
+ Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers.
108
+ head_dim (`int`, *optional*):
109
+ The attention head dimension. If None, it will default to hidden_size // num_attention_heads
110
+
111
+ ```python
112
+ >>> from transformers import ArceeModel, ArceeConfig
113
+
114
+ >>> # Initializing an Arcee AFM-4.5B-Base style configuration
115
+ >>> configuration = ArceeConfig()
116
+
117
+ >>> # Initializing a model from the AFM-4.5B-Base style configuration
118
+ >>> model = ArceeModel(configuration)
119
+
120
+ >>> # Accessing the model configuration
121
+ >>> configuration = model.config
122
+ ```"""
123
+
124
+ model_type = "arcee"
125
+ keys_to_ignore_at_inference = ["past_key_values"]
126
+ base_model_tp_plan = {
127
+ "layers.*.self_attn.q_proj": "colwise",
128
+ "layers.*.self_attn.k_proj": "colwise",
129
+ "layers.*.self_attn.v_proj": "colwise",
130
+ "layers.*.self_attn.o_proj": "rowwise",
131
+ "layers.*.mlp.up_proj": "colwise",
132
+ "layers.*.mlp.down_proj": "rowwise",
133
+ }
134
+ base_model_pp_plan = {
135
+ "embed_tokens": (["input_ids"], ["inputs_embeds"]),
136
+ "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
137
+ "norm": (["hidden_states"], ["hidden_states"]),
138
+ }
139
+
140
+ def __init__(
141
+ self,
142
+ vocab_size=32000,
143
+ hidden_size=2560,
144
+ intermediate_size=18432,
145
+ num_hidden_layers=32,
146
+ num_attention_heads=32,
147
+ num_key_value_heads=None,
148
+ hidden_act="relu2",
149
+ max_position_embeddings=4096,
150
+ initializer_range=0.02,
151
+ rms_norm_eps=1e-5,
152
+ use_cache=True,
153
+ pad_token_id=None,
154
+ bos_token_id=128000,
155
+ eos_token_id=128001,
156
+ tie_word_embeddings=False,
157
+ rope_theta=10000.0,
158
+ rope_scaling=None,
159
+ attention_bias=False,
160
+ attention_dropout=0.0,
161
+ mlp_bias=False,
162
+ head_dim=None,
163
+ **kwargs,
164
+ ):
165
+ super().__init__(
166
+ pad_token_id=pad_token_id,
167
+ bos_token_id=bos_token_id,
168
+ eos_token_id=eos_token_id,
169
+ tie_word_embeddings=tie_word_embeddings,
170
+ **kwargs,
171
+ )
172
+ self.vocab_size = vocab_size
173
+ self.max_position_embeddings = max_position_embeddings
174
+ self.hidden_size = hidden_size
175
+ self.intermediate_size = intermediate_size
176
+ self.num_hidden_layers = num_hidden_layers
177
+ self.num_attention_heads = num_attention_heads
178
+
179
+ # for backward compatibility
180
+ if num_key_value_heads is None:
181
+ num_key_value_heads = num_attention_heads
182
+
183
+ self.num_key_value_heads = num_key_value_heads
184
+ self.hidden_act = hidden_act
185
+ self.initializer_range = initializer_range
186
+ self.rms_norm_eps = rms_norm_eps
187
+ self.use_cache = use_cache
188
+ self.rope_theta = rope_theta
189
+ self.rope_scaling = rope_scaling
190
+ self.attention_bias = attention_bias
191
+ self.attention_dropout = attention_dropout
192
+ self.mlp_bias = mlp_bias
193
+ self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads
194
+ # Validate the correctness of rotary position embeddings parameters
195
+ # BC: if there is a 'type' field, copy it to 'rope_type'.
196
+ if self.rope_scaling is not None and "type" in self.rope_scaling:
197
+ self.rope_scaling["rope_type"] = self.rope_scaling["type"]
198
+ rope_config_validation(self)
199
+
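+ # Sketch of the backward-compatibility path above: a legacy `{"type": ...}` rope dict is mirrored
+ # into `rope_type` before validation, e.g.
+ #
+ #     >>> cfg = ArceeConfig(rope_scaling={"type": "yarn", "factor": 2.0})
+ #     >>> cfg.rope_scaling["rope_type"]
+ #     'yarn'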
200
+
201
+ __all__ = ["ArceeConfig"]
venv/lib/python3.13/site-packages/transformers/models/arcee/modeling_arcee.py ADDED
@@ -0,0 +1,506 @@
1
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
2
+ # This file was automatically generated from src/transformers/models/arcee/modular_arcee.py.
3
+ # Do NOT edit this file manually as any edits will be overwritten by the generation of
4
+ # the file from the modular. If any change should be done, please apply the change to the
5
+ # modular_arcee.py file directly. One of our CI enforces this.
6
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
7
+ # coding=utf-8
8
+ # Copyright 2025 Arcee AI and the HuggingFace Inc. team. All rights reserved.
9
+ #
10
+ # Licensed under the Apache License, Version 2.0 (the "License");
11
+ # you may not use this file except in compliance with the License.
12
+ # You may obtain a copy of the License at
13
+ #
14
+ # http://www.apache.org/licenses/LICENSE-2.0
15
+ #
16
+ # Unless required by applicable law or agreed to in writing, software
17
+ # distributed under the License is distributed on an "AS IS" BASIS,
18
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19
+ # See the License for the specific language governing permissions and
20
+ # limitations under the License.
21
+
22
+ from typing import Callable, Optional, Union
23
+
24
+ import torch
25
+ from torch import nn
26
+
27
+ from transformers.utils import auto_docstring
28
+
29
+ from ...activations import ACT2FN
30
+ from ...cache_utils import Cache, DynamicCache
31
+ from ...generation import GenerationMixin
32
+ from ...integrations import use_kernel_forward_from_hub
33
+ from ...masking_utils import create_causal_mask
34
+ from ...modeling_layers import (
35
+ GenericForQuestionAnswering,
36
+ GenericForSequenceClassification,
37
+ GenericForTokenClassification,
38
+ GradientCheckpointingLayer,
39
+ )
40
+ from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
41
+ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
42
+ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
43
+ from ...processing_utils import Unpack
44
+ from ...utils import TransformersKwargs, can_return_tuple
45
+ from ...utils.deprecation import deprecate_kwarg
46
+ from ...utils.generic import check_model_inputs
47
+ from .configuration_arcee import ArceeConfig
48
+
49
+
50
+ class ArceeMLP(nn.Module):
51
+ def __init__(self, config):
52
+ super().__init__()
53
+ self.config = config
54
+ self.hidden_size = config.hidden_size
55
+ self.intermediate_size = config.intermediate_size
56
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
57
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
58
+ self.act_fn = ACT2FN[config.hidden_act]
59
+
60
+ def forward(self, x):
61
+ return self.down_proj(self.act_fn(self.up_proj(x)))
62
+
63
+
64
+ @use_kernel_forward_from_hub("RMSNorm")
65
+ class ArceeRMSNorm(nn.Module):
66
+ def __init__(self, hidden_size, eps=1e-6):
67
+ """
68
+ ArceeRMSNorm is equivalent to T5LayerNorm
69
+ """
70
+ super().__init__()
71
+ self.weight = nn.Parameter(torch.ones(hidden_size))
72
+ self.variance_epsilon = eps
73
+
74
+ def forward(self, hidden_states):
75
+ input_dtype = hidden_states.dtype
76
+ hidden_states = hidden_states.to(torch.float32)
77
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
78
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
79
+ return self.weight * hidden_states.to(input_dtype)
80
+
81
+ def extra_repr(self):
82
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
83
+
84
+
85
+ class ArceeRotaryEmbedding(nn.Module):
86
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
87
+
88
+ def __init__(self, config: ArceeConfig, device=None):
89
+ super().__init__()
90
+ # BC: "rope_type" was originally "type"
91
+ if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
92
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
93
+ else:
94
+ self.rope_type = "default"
95
+ self.max_seq_len_cached = config.max_position_embeddings
96
+ self.original_max_seq_len = config.max_position_embeddings
97
+
98
+ self.config = config
99
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
100
+
101
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
102
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
103
+ self.original_inv_freq = self.inv_freq
104
+
105
+ @torch.no_grad()
106
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
107
+ def forward(self, x, position_ids):
108
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
109
+ position_ids_expanded = position_ids[:, None, :].float()
110
+
111
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
112
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
113
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
114
+ emb = torch.cat((freqs, freqs), dim=-1)
115
+ cos = emb.cos() * self.attention_scaling
116
+ sin = emb.sin() * self.attention_scaling
117
+
118
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
119
+
120
+
121
+ def rotate_half(x):
122
+ """Rotates half the hidden dims of the input."""
123
+ x1 = x[..., : x.shape[-1] // 2]
124
+ x2 = x[..., x.shape[-1] // 2 :]
125
+ return torch.cat((-x2, x1), dim=-1)
126
+
127
+
128
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
129
+ """Applies Rotary Position Embedding to the query and key tensors.
130
+
131
+ Args:
132
+ q (`torch.Tensor`): The query tensor.
133
+ k (`torch.Tensor`): The key tensor.
134
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
135
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
136
+ position_ids (`torch.Tensor`, *optional*):
137
+ Deprecated and unused.
138
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
139
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
140
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
141
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
142
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
143
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
144
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
145
+ Returns:
146
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
147
+ """
148
+ cos = cos.unsqueeze(unsqueeze_dim)
149
+ sin = sin.unsqueeze(unsqueeze_dim)
150
+ q_embed = (q * cos) + (rotate_half(q) * sin)
151
+ k_embed = (k * cos) + (rotate_half(k) * sin)
152
+ return q_embed, k_embed
153
+
154
+
155
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
156
+ """
157
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
158
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
159
+ """
160
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
161
+ if n_rep == 1:
162
+ return hidden_states
163
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
164
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
165
+
166
+
167
+ def eager_attention_forward(
168
+ module: nn.Module,
169
+ query: torch.Tensor,
170
+ key: torch.Tensor,
171
+ value: torch.Tensor,
172
+ attention_mask: Optional[torch.Tensor],
173
+ scaling: float,
174
+ dropout: float = 0.0,
175
+ **kwargs: Unpack[TransformersKwargs],
176
+ ):
177
+ key_states = repeat_kv(key, module.num_key_value_groups)
178
+ value_states = repeat_kv(value, module.num_key_value_groups)
179
+
180
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
181
+ if attention_mask is not None:
182
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
183
+ attn_weights = attn_weights + causal_mask
184
+
185
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
186
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
187
+ attn_output = torch.matmul(attn_weights, value_states)
188
+ attn_output = attn_output.transpose(1, 2).contiguous()
189
+
190
+ return attn_output, attn_weights
191
+
192
+
193
+ class ArceeAttention(nn.Module):
194
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
195
+
196
+ def __init__(self, config: ArceeConfig, layer_idx: int):
197
+ super().__init__()
198
+ self.config = config
199
+ self.layer_idx = layer_idx
200
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
201
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
202
+ self.scaling = self.head_dim**-0.5
203
+ self.attention_dropout = config.attention_dropout
204
+ self.is_causal = True
205
+
206
+ self.q_proj = nn.Linear(
207
+ config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
208
+ )
209
+ self.k_proj = nn.Linear(
210
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
211
+ )
212
+ self.v_proj = nn.Linear(
213
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
214
+ )
215
+ self.o_proj = nn.Linear(
216
+ config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
217
+ )
218
+
219
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
220
+ def forward(
221
+ self,
222
+ hidden_states: torch.Tensor,
223
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
224
+ attention_mask: Optional[torch.Tensor],
225
+ past_key_values: Optional[Cache] = None,
226
+ cache_position: Optional[torch.LongTensor] = None,
227
+ **kwargs: Unpack[TransformersKwargs],
228
+ ) -> tuple[torch.Tensor, torch.Tensor]:
229
+ input_shape = hidden_states.shape[:-1]
230
+ hidden_shape = (*input_shape, -1, self.head_dim)
231
+
232
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
233
+ key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
234
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
235
+
236
+ cos, sin = position_embeddings
237
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
238
+
239
+ if past_key_values is not None:
240
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
241
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
242
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
243
+
244
+ attention_interface: Callable = eager_attention_forward
245
+ if self.config._attn_implementation != "eager":
246
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
247
+
248
+ attn_output, attn_weights = attention_interface(
249
+ self,
250
+ query_states,
251
+ key_states,
252
+ value_states,
253
+ attention_mask,
254
+ dropout=0.0 if not self.training else self.attention_dropout,
255
+ scaling=self.scaling,
256
+ **kwargs,
257
+ )
258
+
259
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
260
+ attn_output = self.o_proj(attn_output)
261
+ return attn_output, attn_weights
262
+
263
+
264
+ class ArceeDecoderLayer(GradientCheckpointingLayer):
265
+ def __init__(self, config: ArceeConfig, layer_idx: int):
266
+ super().__init__()
267
+ self.hidden_size = config.hidden_size
268
+
269
+ self.self_attn = ArceeAttention(config=config, layer_idx=layer_idx)
270
+
271
+ self.mlp = ArceeMLP(config)
272
+ self.input_layernorm = ArceeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
273
+ self.post_attention_layernorm = ArceeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
274
+
275
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
276
+ def forward(
277
+ self,
278
+ hidden_states: torch.Tensor,
279
+ attention_mask: Optional[torch.Tensor] = None,
280
+ position_ids: Optional[torch.LongTensor] = None,
281
+ past_key_values: Optional[Cache] = None,
282
+ use_cache: Optional[bool] = False,
283
+ cache_position: Optional[torch.LongTensor] = None,
284
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
285
+ **kwargs: Unpack[TransformersKwargs],
286
+ ) -> torch.Tensor:
287
+ residual = hidden_states
288
+ hidden_states = self.input_layernorm(hidden_states)
289
+ # Self Attention
290
+ hidden_states, _ = self.self_attn(
291
+ hidden_states=hidden_states,
292
+ attention_mask=attention_mask,
293
+ position_ids=position_ids,
294
+ past_key_values=past_key_values,
295
+ use_cache=use_cache,
296
+ cache_position=cache_position,
297
+ position_embeddings=position_embeddings,
298
+ **kwargs,
299
+ )
300
+ hidden_states = residual + hidden_states
301
+
302
+ # Fully Connected
303
+ residual = hidden_states
304
+ hidden_states = self.post_attention_layernorm(hidden_states)
305
+ hidden_states = self.mlp(hidden_states)
306
+ hidden_states = residual + hidden_states
307
+ return hidden_states
308
+
309
+
310
+ @auto_docstring
311
+ class ArceePreTrainedModel(PreTrainedModel):
312
+ config: ArceeConfig
313
+ base_model_prefix = "model"
314
+ supports_gradient_checkpointing = True
315
+ _no_split_modules = ["ArceeDecoderLayer"]
316
+ _skip_keys_device_placement = ["past_key_values"]
317
+ _supports_flash_attn = True
318
+ _supports_sdpa = True
319
+ _supports_flex_attn = True
320
+
321
+ _can_compile_fullgraph = True
322
+ _supports_attention_backend = True
323
+ _can_record_outputs = {
324
+ "hidden_states": ArceeDecoderLayer,
325
+ "attentions": ArceeAttention,
326
+ }
327
+
328
+
329
+ @auto_docstring
330
+ class ArceeModel(ArceePreTrainedModel):
331
+ def __init__(self, config: ArceeConfig):
332
+ super().__init__(config)
333
+ self.padding_idx = config.pad_token_id
334
+ self.vocab_size = config.vocab_size
335
+
336
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
337
+ self.layers = nn.ModuleList(
338
+ [ArceeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
339
+ )
340
+ self.norm = ArceeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
341
+ self.rotary_emb = ArceeRotaryEmbedding(config=config)
342
+ self.gradient_checkpointing = False
343
+
344
+ # Initialize weights and apply final processing
345
+ self.post_init()
346
+
347
+ @check_model_inputs()
348
+ @auto_docstring
349
+ def forward(
350
+ self,
351
+ input_ids: Optional[torch.LongTensor] = None,
352
+ attention_mask: Optional[torch.Tensor] = None,
353
+ position_ids: Optional[torch.LongTensor] = None,
354
+ past_key_values: Optional[Cache] = None,
355
+ inputs_embeds: Optional[torch.FloatTensor] = None,
356
+ cache_position: Optional[torch.LongTensor] = None,
357
+ use_cache: Optional[bool] = None,
358
+ **kwargs: Unpack[TransformersKwargs],
359
+ ) -> BaseModelOutputWithPast:
360
+ if (input_ids is None) ^ (inputs_embeds is not None):
361
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
362
+
363
+ if inputs_embeds is None:
364
+ inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)
365
+
366
+ if use_cache and past_key_values is None:
367
+ past_key_values = DynamicCache(config=self.config)
368
+
369
+ if cache_position is None:
370
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
371
+ cache_position: torch.Tensor = torch.arange(
372
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
373
+ )
374
+
375
+ if position_ids is None:
376
+ position_ids = cache_position.unsqueeze(0)
377
+
378
+ causal_mask = create_causal_mask(
379
+ config=self.config,
380
+ input_embeds=inputs_embeds,
381
+ attention_mask=attention_mask,
382
+ cache_position=cache_position,
383
+ past_key_values=past_key_values,
384
+ position_ids=position_ids,
385
+ )
386
+
387
+ hidden_states = inputs_embeds
388
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
389
+
390
+ for decoder_layer in self.layers[: self.config.num_hidden_layers]:
391
+ hidden_states = decoder_layer(
392
+ hidden_states,
393
+ attention_mask=causal_mask,
394
+ position_ids=position_ids,
395
+ past_key_values=past_key_values,
396
+ cache_position=cache_position,
397
+ position_embeddings=position_embeddings,
398
+ **kwargs,
399
+ )
400
+
401
+ hidden_states = self.norm(hidden_states)
402
+ return BaseModelOutputWithPast(
403
+ last_hidden_state=hidden_states,
404
+ past_key_values=past_key_values,
405
+ )
406
+
407
+
408
+ @auto_docstring(checkpoint="arcee-ai/AFM-4.5B")
409
+ class ArceeForCausalLM(ArceePreTrainedModel, GenerationMixin):
410
+ _tied_weights_keys = ["lm_head.weight"]
411
+ _tp_plan = {"lm_head": "colwise_rep"}
412
+ _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
413
+
414
+ def __init__(self, config):
415
+ super().__init__(config)
416
+ self.model = ArceeModel(config)
417
+ self.vocab_size = config.vocab_size
418
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
419
+
420
+ # Initialize weights and apply final processing
421
+ self.post_init()
422
+
423
+ @can_return_tuple
424
+ @auto_docstring
425
+ def forward(
426
+ self,
427
+ input_ids: Optional[torch.LongTensor] = None,
428
+ attention_mask: Optional[torch.Tensor] = None,
429
+ position_ids: Optional[torch.LongTensor] = None,
430
+ past_key_values: Optional[Cache] = None,
431
+ inputs_embeds: Optional[torch.FloatTensor] = None,
432
+ labels: Optional[torch.LongTensor] = None,
433
+ use_cache: Optional[bool] = None,
434
+ cache_position: Optional[torch.LongTensor] = None,
435
+ logits_to_keep: Union[int, torch.Tensor] = 0,
436
+ **kwargs: Unpack[TransformersKwargs],
437
+ ) -> CausalLMOutputWithPast:
438
+ r"""
439
+ Example:
440
+
441
+ ```python
442
+ >>> from transformers import AutoTokenizer, ArceeForCausalLM
443
+
444
+ >>> model = ArceeForCausalLM.from_pretrained("arcee-ai/AFM-4.5B")
445
+ >>> tokenizer = AutoTokenizer.from_pretrained("arcee-ai/AFM-4.5B")
446
+
447
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
448
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
449
+
450
+ >>> # Generate
451
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
452
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
453
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
454
+ ```"""
455
+ outputs: BaseModelOutputWithPast = self.model(
456
+ input_ids=input_ids,
457
+ attention_mask=attention_mask,
458
+ position_ids=position_ids,
459
+ past_key_values=past_key_values,
460
+ inputs_embeds=inputs_embeds,
461
+ use_cache=use_cache,
462
+ cache_position=cache_position,
463
+ **kwargs,
464
+ )
465
+
466
+ hidden_states = outputs.last_hidden_state
467
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
468
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
469
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
470
+
471
+ loss = None
472
+ if labels is not None:
473
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
474
+
475
+ return CausalLMOutputWithPast(
476
+ loss=loss,
477
+ logits=logits,
478
+ past_key_values=outputs.past_key_values,
479
+ hidden_states=outputs.hidden_states,
480
+ attentions=outputs.attentions,
481
+ )
482
+
483
+
484
+ @auto_docstring(checkpoint="arcee-ai/AFM-4.5B")
485
+ class ArceeForSequenceClassification(GenericForSequenceClassification, ArceePreTrainedModel):
486
+ pass
487
+
488
+
489
+ @auto_docstring(checkpoint="arcee-ai/AFM-4.5B")
490
+ class ArceeForQuestionAnswering(GenericForQuestionAnswering, ArceePreTrainedModel):
491
+ base_model_prefix = "transformer" # For BC, where `transformer` was used instead of `model`
492
+
493
+
494
+ @auto_docstring(checkpoint="arcee-ai/AFM-4.5B")
495
+ class ArceeForTokenClassification(GenericForTokenClassification, ArceePreTrainedModel):
496
+ pass
497
+
498
+
499
+ __all__ = [
500
+ "ArceeForCausalLM",
501
+ "ArceeForQuestionAnswering",
502
+ "ArceeForSequenceClassification",
503
+ "ArceeForTokenClassification",
504
+ "ArceeModel",
505
+ "ArceePreTrainedModel",
506
+ ]
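The eager attention helper and `ArceeAttention` above implement grouped-query attention: `repeat_kv` expands the key/value heads so that each group of query heads shares one key/value head, an additive causal mask is applied before the softmax, and the weighted values are transposed back to `(batch, seq, heads, head_dim)`. The following is a minimal, self-contained sketch of that math with toy shapes; `repeat_kv` is re-implemented here for illustration and nothing is imported from the library.

```python
# Hedged sketch of the grouped-query eager attention math above (toy shapes, not library code).
import torch


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # (batch, num_kv_heads, seq, head_dim) -> (batch, num_kv_heads * n_rep, seq, head_dim)
    batch, num_kv_heads, seq, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, seq, head_dim)
    return hidden_states.reshape(batch, num_kv_heads * n_rep, seq, head_dim)


batch, q_heads, kv_heads, seq, head_dim = 1, 8, 2, 5, 16
query = torch.randn(batch, q_heads, seq, head_dim)
key = torch.randn(batch, kv_heads, seq, head_dim)
value = torch.randn(batch, kv_heads, seq, head_dim)

# Each key/value head serves q_heads // kv_heads query heads (the module's num_key_value_groups).
key_states = repeat_kv(key, q_heads // kv_heads)
value_states = repeat_kv(value, q_heads // kv_heads)

scaling = head_dim ** -0.5
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling

# Additive causal mask: 0 on allowed positions, -inf strictly above the diagonal.
causal_mask = torch.full((seq, seq), float("-inf")).triu(diagonal=1)
attn_weights = torch.softmax(attn_weights + causal_mask, dim=-1)

attn_output = torch.matmul(attn_weights, value_states)  # (batch, q_heads, seq, head_dim)
attn_output = attn_output.transpose(1, 2).contiguous()  # (batch, seq, q_heads, head_dim)
print(attn_output.shape)
```

With 8 query heads and 2 key/value heads, each key/value head serves a group of 4 query heads, which is what `num_key_value_groups` counts in the module above.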
venv/lib/python3.13/site-packages/transformers/models/arcee/modular_arcee.py ADDED
@@ -0,0 +1,225 @@
1
+ # coding=utf-8
2
+ # Copyright 2025 Arcee AI and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch Arcee model."""
16
+
17
+ from transformers.utils import auto_docstring, logging
18
+
19
+ from ..llama.configuration_llama import LlamaConfig
20
+ from ..llama.modeling_llama import (
21
+ LlamaForCausalLM,
22
+ LlamaForQuestionAnswering,
23
+ LlamaForSequenceClassification,
24
+ LlamaForTokenClassification,
25
+ )
26
+ from ..nemotron.modeling_nemotron import NemotronMLP
27
+
28
+
29
+ logger = logging.get_logger(__name__)
30
+
31
+
32
+ class ArceeConfig(LlamaConfig):
33
+ r"""
34
+ This is the configuration class to store the configuration of a [`ArceeModel`]. It is used to instantiate an Arcee
35
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
36
+ defaults will yield a similar configuration to that of the AFM-4.5B-Base.
37
+
38
+ Pre-trained weights are available at
39
+ [arcee-ai/AFM-4.5B](https://huggingface.co/arcee-ai/AFM-4.5B)
40
+ and were used to build the examples below.
41
+
42
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
43
+ documentation from [`PretrainedConfig`] for more information.
44
+
45
+ Args:
46
+ vocab_size (`int`, *optional*, defaults to 32000):
47
+ Vocabulary size of the Arcee model. Defines the number of different tokens that can be represented by the
48
+ `inputs_ids` passed when calling [`ArceeModel`]
49
+ hidden_size (`int`, *optional*, defaults to 2560):
50
+ Dimension of the hidden representations.
51
+ intermediate_size (`int`, *optional*, defaults to 18432):
52
+ Dimension of the MLP representations.
53
+ num_hidden_layers (`int`, *optional*, defaults to 32):
54
+ Number of hidden layers in the Transformer decoder.
55
+ num_attention_heads (`int`, *optional*, defaults to 32):
56
+ Number of attention heads for each attention layer in the Transformer decoder.
57
+ num_key_value_heads (`int`, *optional*):
58
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
59
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
60
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
61
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
62
+ by meanpooling all the original heads within that group. For more details checkout [this
63
+ paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
64
+ `num_attention_heads`.
65
+ hidden_act (`str` or `function`, *optional*, defaults to `"relu2"`):
66
+ The non-linear activation function (function or string) in the decoder.
67
+ max_position_embeddings (`int`, *optional*, defaults to 4096):
68
+ The maximum sequence length that this model might ever be used with. AFM-4.5B-Base supports up to 16384 tokens.
69
+ initializer_range (`float`, *optional*, defaults to 0.02):
70
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
71
+ rms_norm_eps (`float`, *optional*, defaults to 1e-05):
72
+ The epsilon used by the rms normalization layers.
73
+ use_cache (`bool`, *optional*, defaults to `True`):
74
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
75
+ relevant if `config.is_decoder=True`.
76
+ pad_token_id (`int`, *optional*):
77
+ Padding token id.
78
+ bos_token_id (`int`, *optional*, defaults to 128000):
79
+ Beginning of stream token id.
80
+ eos_token_id (`int`, *optional*, defaults to 128001):
81
+ End of stream token id.
82
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
83
+ Whether to tie weight embeddings
84
+ rope_theta (`float`, *optional*, defaults to 10000.0):
85
+ The base period of the RoPE embeddings.
86
+ rope_scaling (`Dict`, *optional*):
87
+ Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
88
+ and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
89
+ accordingly.
90
+ Expected contents:
91
+ `rope_type` (`str`):
92
+ The sub-variant of RoPE to use. Can be one of ['default', 'yarn'], with 'default' being the original RoPE implementation.
93
+ `factor` (`float`, *optional*):
94
+ Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
95
+ most scaling types, a `factor` of x will enable the model to handle sequences of length x *
96
+ original maximum pre-trained length.
97
+ `original_max_position_embeddings` (`int`, *optional*):
98
+ Used with 'yarn'. The original max position embeddings used during pretraining.
99
+ `attention_factor` (`float`, *optional*):
100
+ Used with 'yarn'. The scaling factor to be applied on the attention computation. If unspecified,
101
+ it defaults to value recommended by the implementation, using the `factor` field to infer the suggested value.
102
+ `beta_fast` (`float`, *optional*):
103
+ Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
104
+ ramp function. If unspecified, it defaults to 32.
105
+ `beta_slow` (`float`, *optional*):
106
+ Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
107
+ ramp function. If unspecified, it defaults to 1.
108
+ attention_bias (`bool`, *optional*, defaults to `False`):
109
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
110
+ attention_dropout (`float`, *optional*, defaults to 0.0):
111
+ The dropout ratio for the attention probabilities.
112
+ mlp_bias (`bool`, *optional*, defaults to `False`):
113
+ Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers.
114
+ head_dim (`int`, *optional*):
115
+ The attention head dimension. If None, it will default to hidden_size // num_attention_heads
116
+
117
+ ```python
118
+ >>> from transformers import ArceeModel, ArceeConfig
119
+
120
+ >>> # Initializing an Arcee AFM-4.5B-Base style configuration
121
+ >>> configuration = ArceeConfig()
122
+
123
+ >>> # Initializing a model from the AFM-4.5B-Base style configuration
124
+ >>> model = ArceeModel(configuration)
125
+
126
+ >>> # Accessing the model configuration
127
+ >>> configuration = model.config
128
+ ```"""
129
+
130
+ model_type = "arcee"
131
+ base_model_tp_plan = {
132
+ "layers.*.self_attn.q_proj": "colwise",
133
+ "layers.*.self_attn.k_proj": "colwise",
134
+ "layers.*.self_attn.v_proj": "colwise",
135
+ "layers.*.self_attn.o_proj": "rowwise",
136
+ "layers.*.mlp.up_proj": "colwise",
137
+ "layers.*.mlp.down_proj": "rowwise",
138
+ }
139
+
140
+ def __init__(
141
+ self,
142
+ vocab_size=32000,
143
+ hidden_size=2560,
144
+ intermediate_size=18432,
145
+ num_hidden_layers=32,
146
+ num_attention_heads=32,
147
+ num_key_value_heads=None,
148
+ hidden_act="relu2",
149
+ max_position_embeddings=4096,
150
+ initializer_range=0.02,
151
+ rms_norm_eps=1e-5,
152
+ use_cache=True,
153
+ pad_token_id=None,
154
+ bos_token_id=128000,
155
+ eos_token_id=128001,
156
+ tie_word_embeddings=False,
157
+ rope_theta=10000.0,
158
+ rope_scaling=None,
159
+ attention_bias=False,
160
+ attention_dropout=0.0,
161
+ mlp_bias=False,
162
+ head_dim=None,
163
+ **kwargs,
164
+ ):
165
+ super().__init__(
166
+ vocab_size=vocab_size,
167
+ hidden_size=hidden_size,
168
+ intermediate_size=intermediate_size,
169
+ num_hidden_layers=num_hidden_layers,
170
+ num_attention_heads=num_attention_heads,
171
+ num_key_value_heads=num_key_value_heads,
172
+ hidden_act=hidden_act,
173
+ max_position_embeddings=max_position_embeddings,
174
+ initializer_range=initializer_range,
175
+ rms_norm_eps=rms_norm_eps,
176
+ use_cache=use_cache,
177
+ pad_token_id=pad_token_id,
178
+ bos_token_id=bos_token_id,
179
+ eos_token_id=eos_token_id,
180
+ tie_word_embeddings=tie_word_embeddings,
181
+ rope_theta=rope_theta,
182
+ rope_scaling=rope_scaling,
183
+ attention_bias=attention_bias,
184
+ attention_dropout=attention_dropout,
185
+ mlp_bias=mlp_bias,
186
+ head_dim=head_dim,
187
+ **kwargs,
188
+ )
189
+
190
+ del self.pretraining_tp
191
+
192
+
193
+ class ArceeMLP(NemotronMLP):
194
+ pass
195
+
196
+
197
+ @auto_docstring(checkpoint="arcee-ai/AFM-4.5B")
198
+ class ArceeForCausalLM(LlamaForCausalLM):
199
+ pass
200
+
201
+
202
+ @auto_docstring(checkpoint="arcee-ai/AFM-4.5B")
203
+ class ArceeForSequenceClassification(LlamaForSequenceClassification):
204
+ pass
205
+
206
+
207
+ @auto_docstring(checkpoint="arcee-ai/AFM-4.5B")
208
+ class ArceeForQuestionAnswering(LlamaForQuestionAnswering):
209
+ pass
210
+
211
+
212
+ @auto_docstring(checkpoint="arcee-ai/AFM-4.5B")
213
+ class ArceeForTokenClassification(LlamaForTokenClassification):
214
+ pass
215
+
216
+
217
+ __all__ = [
218
+ "ArceeConfig",
219
+ "ArceeForCausalLM",
220
+ "ArceeForQuestionAnswering",
221
+ "ArceeForSequenceClassification",
222
+ "ArceeForTokenClassification",
223
+ "ArceeModel", # noqa: F822
224
+ "ArceePreTrainedModel", # noqa: F822
225
+ ]
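`ArceeMLP` is declared above only as a subclass of `NemotronMLP`, so the feed-forward layout is not visible in this file. As a hedged reference, the sketch below re-implements a squared-ReLU (`relu2`) MLP with the up/down projection layout the tensor-parallel plan implies (no `gate_proj`); it is an assumption-level illustration, not the library class.

```python
# Hedged sketch of a squared-ReLU ("relu2") MLP of the shape ArceeConfig describes.
import torch
from torch import nn


class Relu2MLP(nn.Module):
    def __init__(self, hidden_size: int = 2560, intermediate_size: int = 18432, bias: bool = False):
        super().__init__()
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=bias)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # relu2(x) = relu(x) ** 2, applied between the up and down projections.
        return self.down_proj(torch.relu(self.up_proj(x)) ** 2)


mlp = Relu2MLP(hidden_size=64, intermediate_size=256)
print(mlp(torch.randn(2, 5, 64)).shape)  # torch.Size([2, 5, 64])
```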
venv/lib/python3.13/site-packages/transformers/models/aria/__init__.py ADDED
@@ -0,0 +1,30 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import TYPE_CHECKING
15
+
16
+ from ...utils import _LazyModule
17
+ from ...utils.import_utils import define_import_structure
18
+
19
+
20
+ if TYPE_CHECKING:
21
+ from .configuration_aria import *
22
+ from .image_processing_aria import *
23
+ from .modeling_aria import *
24
+ from .processing_aria import *
25
+
26
+ else:
27
+ import sys
28
+
29
+ _file = globals()["__file__"]
30
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
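Because the package routes its submodules through `_LazyModule`, the real imports are deferred until an attribute is first accessed; from the user's side it behaves like a plain import. A small hedged sketch of that behavior:

```python
# Hedged sketch of the lazy-module behavior: the package object is a _LazyModule and
# the submodule is only imported when an attribute is first requested.
import transformers.models.aria as aria

print(type(aria).__name__)    # "_LazyModule"
config_cls = aria.AriaConfig  # triggers the actual import of configuration_aria
print(config_cls.model_type)  # "aria"
```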
venv/lib/python3.13/site-packages/transformers/models/aria/configuration_aria.py ADDED
@@ -0,0 +1,307 @@
1
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
2
+ # This file was automatically generated from src/transformers/models/aria/modular_aria.py.
3
+ # Do NOT edit this file manually as any edits will be overwritten by the generation of
4
+ # the file from the modular. If any change should be done, please apply the change to the
5
+ # modular_aria.py file directly. One of our CI enforces this.
6
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
7
+ # coding=utf-8
8
+ # Copyright 2024 The Rhymes-AI Teams Authors and The HuggingFace Inc. team. All rights reserved.
9
+ #
10
+ # Licensed under the Apache License, Version 2.0 (the "License");
11
+ # you may not use this file except in compliance with the License.
12
+ # You may obtain a copy of the License at
13
+ #
14
+ # http://www.apache.org/licenses/LICENSE-2.0
15
+ #
16
+ # Unless required by applicable law or agreed to in writing, software
17
+ # distributed under the License is distributed on an "AS IS" BASIS,
18
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19
+ # See the License for the specific language governing permissions and
20
+ # limitations under the License.
21
+ from typing import Optional
22
+
23
+ from ...configuration_utils import PretrainedConfig
24
+ from ...modeling_rope_utils import rope_config_validation
25
+ from ..auto import CONFIG_MAPPING, AutoConfig
26
+
27
+
28
+ class AriaTextConfig(PretrainedConfig):
29
+ r"""
30
+ This class handles the configuration for the text component of the Aria model.
31
+ Instantiating a configuration with the defaults will yield a similar configuration to that of the model of the Aria
32
+ [rhymes-ai/Aria](https://huggingface.co/rhymes-ai/Aria) architecture.
33
+ This class extends the LlamaConfig to include additional parameters specific to the Mixture of Experts (MoE) architecture.
34
+
35
+ Args:
36
+ vocab_size (`int`, *optional*, defaults to 32000):
37
+ Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the
38
+ `inputs_ids` passed when calling [`LlamaModel`]
39
+ hidden_size (`int`, *optional*, defaults to 4096):
40
+ Dimension of the hidden representations.
41
+ intermediate_size (`int`, *optional*, defaults to 4096):
42
+ The size of the MLP representations.
43
+ num_hidden_layers (`int`, *optional*, defaults to 32):
44
+ Number of hidden layers in the Transformer decoder.
45
+ num_attention_heads (`int`, *optional*, defaults to 32):
46
+ Number of attention heads for each attention layer in the Transformer decoder.
47
+ num_key_value_heads (`int`, *optional*):
48
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
49
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
50
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
51
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
52
+ by meanpooling all the original heads within that group. For more details, check out [this
53
+ paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
54
+ `num_attention_heads`.
55
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
56
+ The non-linear activation function (function or string) in the decoder.
57
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
58
+ The maximum sequence length that this model might ever be used with. Llama 1 supports up to 2048 tokens,
59
+ Llama 2 up to 4096, CodeLlama up to 16384.
60
+ initializer_range (`float`, *optional*, defaults to 0.02):
61
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
62
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
63
+ The epsilon used by the rms normalization layers.
64
+ use_cache (`bool`, *optional*, defaults to `True`):
65
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
66
+ relevant if `config.is_decoder=True`.
67
+ pad_token_id (`int`, *optional*, defaults to 2):
68
+ Padding token id.
69
+ bos_token_id (`int`, *optional*, defaults to 1):
70
+ Beginning of stream token id.
71
+ eos_token_id (`int`, *optional*, defaults to 2):
72
+ End of stream token id.
73
+ pretraining_tp (`int`, *optional*, defaults to 1):
74
+ Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
75
+ document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) to
76
+ understand more about it. This value is necessary to ensure exact reproducibility of the pretraining
77
+ results. Please refer to [this issue](https://github.com/pytorch/pytorch/issues/76232).
78
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
79
+ Whether to tie weight embeddings
80
+ rope_theta (`float`, *optional*, defaults to 10000.0):
81
+ The base period of the RoPE embeddings.
82
+ rope_scaling (`Dict`, *optional*):
83
+ Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
84
+ and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
85
+ accordingly.
86
+ Expected contents:
87
+ `rope_type` (`str`):
88
+ The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
89
+ 'llama3'], with 'default' being the original RoPE implementation.
90
+ `factor` (`float`, *optional*):
91
+ Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
92
+ most scaling types, a `factor` of x will enable the model to handle sequences of length x *
93
+ original maximum pre-trained length.
94
+ `original_max_position_embeddings` (`int`, *optional*):
95
+ Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
96
+ pretraining.
97
+ `attention_factor` (`float`, *optional*):
98
+ Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
99
+ computation. If unspecified, it defaults to value recommended by the implementation, using the
100
+ `factor` field to infer the suggested value.
101
+ `beta_fast` (`float`, *optional*):
102
+ Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
103
+ ramp function. If unspecified, it defaults to 32.
104
+ `beta_slow` (`float`, *optional*):
105
+ Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
106
+ ramp function. If unspecified, it defaults to 1.
107
+ `short_factor` (`list[float]`, *optional*):
108
+ Only used with 'longrope'. The scaling factor to be applied to short contexts (<
109
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
110
+ size divided by the number of attention heads divided by 2
111
+ `long_factor` (`list[float]`, *optional*):
112
+ Only used with 'longrope'. The scaling factor to be applied to long contexts (<
113
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
114
+ size divided by the number of attention heads divided by 2
115
+ `low_freq_factor` (`float`, *optional*):
116
+ Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
117
+ `high_freq_factor` (`float`, *optional*):
118
+ Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
119
+ attention_bias (`bool`, *optional*, defaults to `False`):
120
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
121
+ attention_dropout (`float`, *optional*, defaults to 0.0):
122
+ The dropout ratio for the attention probabilities.
123
+ mlp_bias (`bool`, *optional*, defaults to `False`):
124
+ Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers.
125
+ head_dim (`int`, *optional*):
126
+ The attention head dimension. If None, it will default to hidden_size // num_heads
127
+ moe_num_experts (`int`, *optional*, defaults to 8):
128
+ The number of experts in the MoE layer.
129
+ moe_topk (`int`, *optional*, defaults to 2):
130
+ The number of top experts to route to for each token.
131
+ moe_num_shared_experts (`int`, *optional*, defaults to 2):
132
+ The number of shared experts.
133
+ """
134
+
135
+ model_type = "aria_text"
136
+ keys_to_ignore_at_inference = ["past_key_values"]
137
+ # Default tensor parallel plan for base model `AriaTextModel`
138
+ base_model_tp_plan = {
139
+ "layers.*.self_attn.q_proj": "colwise",
140
+ "layers.*.self_attn.k_proj": "colwise",
141
+ "layers.*.self_attn.v_proj": "colwise",
142
+ "layers.*.self_attn.o_proj": "rowwise",
143
+ "layers.*.mlp.gate_proj": "colwise",
144
+ "layers.*.mlp.up_proj": "colwise",
145
+ "layers.*.mlp.down_proj": "rowwise",
146
+ }
147
+ base_model_pp_plan = {
148
+ "embed_tokens": (["input_ids"], ["inputs_embeds"]),
149
+ "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
150
+ "norm": (["hidden_states"], ["hidden_states"]),
151
+ }
152
+ base_config_key = "text_config"
153
+
154
+ def __init__(
155
+ self,
156
+ vocab_size=32000,
157
+ hidden_size=4096,
158
+ intermediate_size: int = 4096,
159
+ num_hidden_layers=32,
160
+ num_attention_heads=32,
161
+ num_key_value_heads=None,
162
+ hidden_act="silu",
163
+ max_position_embeddings=2048,
164
+ initializer_range=0.02,
165
+ rms_norm_eps=1e-6,
166
+ use_cache=True,
167
+ pad_token_id=2,
168
+ bos_token_id=1,
169
+ eos_token_id=2,
170
+ pretraining_tp=1,
171
+ tie_word_embeddings=False,
172
+ rope_theta=10000.0,
173
+ rope_scaling=None,
174
+ attention_bias=False,
175
+ attention_dropout=0.0,
176
+ mlp_bias=False,
177
+ head_dim=None,
178
+ moe_num_experts: int = 8,
179
+ moe_topk: int = 2,
180
+ moe_num_shared_experts: int = 2,
181
+ **kwargs,
182
+ ):
183
+ super().__init__(
184
+ pad_token_id=pad_token_id,
185
+ bos_token_id=bos_token_id,
186
+ eos_token_id=eos_token_id,
187
+ tie_word_embeddings=tie_word_embeddings,
188
+ **kwargs,
189
+ )
190
+ self.vocab_size = vocab_size
191
+ self.max_position_embeddings = max_position_embeddings
192
+ self.hidden_size = hidden_size
193
+ self.intermediate_size = intermediate_size
194
+ self.num_hidden_layers = num_hidden_layers
195
+ self.num_attention_heads = num_attention_heads
196
+
197
+ # for backward compatibility
198
+ if num_key_value_heads is None:
199
+ num_key_value_heads = num_attention_heads
200
+
201
+ self.num_key_value_heads = num_key_value_heads
202
+ self.hidden_act = hidden_act
203
+ self.initializer_range = initializer_range
204
+ self.rms_norm_eps = rms_norm_eps
205
+ self.pretraining_tp = pretraining_tp
206
+ self.use_cache = use_cache
207
+ self.rope_theta = rope_theta
208
+ self.rope_scaling = rope_scaling
209
+ self.attention_bias = attention_bias
210
+ self.attention_dropout = attention_dropout
211
+ self.mlp_bias = mlp_bias
212
+ self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads
213
+ # Validate the correctness of rotary position embeddings parameters
214
+ # BC: if there is a 'type' field, copy it to 'rope_type'.
215
+ if self.rope_scaling is not None and "type" in self.rope_scaling:
216
+ self.rope_scaling["rope_type"] = self.rope_scaling["type"]
217
+ rope_config_validation(self)
218
+ self.moe_num_experts = moe_num_experts
219
+ self.moe_topk = moe_topk
220
+ self.moe_num_shared_experts = moe_num_shared_experts
221
+
222
+
223
+ class AriaConfig(PretrainedConfig):
224
+ r"""
225
+ This class handles the configuration for both vision and text components of the Aria model,
226
+ as well as additional parameters for image token handling and projector mapping.
227
+ Instantiating a configuration with the defaults will yield a similar configuration to that of the model of the Aria
228
+ [rhymes-ai/Aria](https://huggingface.co/rhymes-ai/Aria) architecture.
229
+
230
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
231
+ documentation from [`PretrainedConfig`] for more information.
232
+
233
+ Args:
234
+ vision_config (`AriaVisionConfig` or `dict`, *optional*):
235
+ Configuration for the vision component.
236
+ vision_feature_layer (`int`, *optional*, defaults to -1):
237
+ The index of the layer to select the vision feature.
238
+ text_config (`AriaTextConfig` or `dict`, *optional*):
239
+ Configuration for the text component.
240
+ projector_patch_to_query_dict (`dict`, *optional*):
241
+ Mapping of patch sizes to query dimensions.
242
+ image_token_index (`int`, *optional*, defaults to 9):
243
+ Index used to represent image tokens.
244
+ initializer_range (`float`, *optional*, defaults to 0.02):
245
+ The standard deviation of the truncated normal initializer for initializing all weight matrices.
246
+
247
+ Attributes:
248
+ model_type (`str`):
249
+ Type of the model, set to `"aria"`.
250
+ image_token_index (`int`):
251
+ Index used to represent image tokens.
252
+ projector_patch_to_query_dict (`dict`):
253
+ Mapping of patch sizes to query dimensions.
254
+ vision_config (`AriaVisionConfig`):
255
+ Configuration for the vision component.
256
+ text_config (`AriaTextConfig`):
257
+ Configuration for the text component.
258
+ """
259
+
260
+ model_type = "aria"
261
+ attribute_map = {
262
+ "image_token_id": "image_token_index",
263
+ }
264
+ sub_configs = {"text_config": AriaTextConfig, "vision_config": AutoConfig}
265
+
266
+ def __init__(
267
+ self,
268
+ vision_config=None,
269
+ vision_feature_layer: int = -1,
270
+ text_config: AriaTextConfig = None,
271
+ projector_patch_to_query_dict: Optional[dict] = None,
272
+ image_token_index: int = 9,
273
+ initializer_range: float = 0.02,
274
+ **kwargs,
275
+ ):
276
+ self.image_token_index = image_token_index
277
+
278
+ # Convert the keys and values of projector_patch_to_query_dict to integers
279
+ # This ensures consistency even if they were provided as strings
280
+ if projector_patch_to_query_dict is None:
281
+ projector_patch_to_query_dict = {
282
+ 1225: 128,
283
+ 4900: 256,
284
+ }
285
+ self.projector_patch_to_query_dict = {int(k): int(v) for k, v in projector_patch_to_query_dict.items()}
286
+ self.max_value_projector_patch_to_query_dict = max(self.projector_patch_to_query_dict.values())
287
+ self.vision_feature_layer = vision_feature_layer
288
+ if isinstance(vision_config, dict):
289
+ vision_config["model_type"] = "idefics3_vision"
290
+ vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config)
291
+ elif vision_config is None:
292
+ vision_config = CONFIG_MAPPING["idefics3_vision"]()
293
+
294
+ self.vision_config = vision_config
295
+ self.initializer_range = initializer_range
296
+
297
+ if isinstance(text_config, dict) and "model_type" in text_config:
298
+ text_config = AriaTextConfig(**text_config)
299
+ elif text_config is None:
300
+ text_config = AriaTextConfig()
301
+
302
+ self.text_config = text_config
303
+
304
+ super().__init__(**kwargs)
305
+
306
+
307
+ __all__ = ["AriaConfig", "AriaTextConfig"]
venv/lib/python3.13/site-packages/transformers/models/aria/image_processing_aria.py ADDED
@@ -0,0 +1,527 @@
1
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
2
+ # This file was automatically generated from src/transformers/models/aria/modular_aria.py.
3
+ # Do NOT edit this file manually as any edits will be overwritten by the generation of
4
+ # the file from the modular. If any change should be done, please apply the change to the
5
+ # modular_aria.py file directly. One of our CI enforces this.
6
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
7
+ # coding=utf-8
8
+ # Copyright 2024 The Rhymes-AI Teams Authors and The HuggingFace Inc. team. All rights reserved.
9
+ #
10
+ # Licensed under the Apache License, Version 2.0 (the "License");
11
+ # you may not use this file except in compliance with the License.
12
+ # You may obtain a copy of the License at
13
+ #
14
+ # http://www.apache.org/licenses/LICENSE-2.0
15
+ #
16
+ # Unless required by applicable law or agreed to in writing, software
17
+ # distributed under the License is distributed on an "AS IS" BASIS,
18
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19
+ # See the License for the specific language governing permissions and
20
+ # limitations under the License.
21
+ from collections.abc import Iterable
22
+ from typing import Optional, Union
23
+
24
+ import numpy as np
25
+
26
+ from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_patch_output_size, select_best_resolution
27
+ from ...image_transforms import PaddingMode, convert_to_rgb, pad, resize, to_channel_dimension_format
28
+ from ...image_utils import (
29
+ ChannelDimension,
30
+ ImageInput,
31
+ PILImageResampling,
32
+ get_image_size,
33
+ infer_channel_dimension_format,
34
+ is_scaled_image,
35
+ make_flat_list_of_images,
36
+ to_numpy_array,
37
+ valid_images,
38
+ validate_preprocess_arguments,
39
+ )
40
+ from ...utils import TensorType, logging
41
+
42
+
43
+ logger = logging.get_logger(__name__)
44
+
45
+
46
+ def divide_to_patches(image: np.ndarray, patch_size: int, input_data_format) -> list[np.ndarray]:
47
+ """
48
+ Divides an image into patches of a specified size.
49
+
50
+ Args:
51
+ image (`np.ndarray`):
52
+ The input image.
53
+ patch_size (`int`):
54
+ The size of each patch.
55
+ input_data_format (`ChannelDimension` or `str`):
56
+ The channel dimension format of the input image.
57
+
58
+ Returns:
59
+ list: A list of np.ndarray representing the patches.
60
+ """
61
+ patches = []
62
+ height, width = get_image_size(image, channel_dim=input_data_format)
63
+ for i in range(0, height, patch_size):
64
+ for j in range(0, width, patch_size):
65
+ if input_data_format == ChannelDimension.LAST:
66
+ patch = image[i : i + patch_size, j : j + patch_size]
67
+ else:
68
+ patch = image[:, i : i + patch_size, j : j + patch_size]
69
+ patches.append(patch)
70
+
71
+ return patches
72
+
73
+
74
+ class AriaImageProcessor(BaseImageProcessor):
75
+ """
76
+ A vision processor for the Aria model that handles image preprocessing.
77
+ Initialize the AriaImageProcessor.
78
+
79
+ Args:
80
+ image_mean (`list`, *optional*, defaults to [0.5, 0.5, 0.5]):
81
+ Mean values for normalization.
82
+ image_std (`list`, *optional*, defaults to [0.5, 0.5, 0.5]):
83
+ Standard deviation values for normalization.
84
+ max_image_size (`int`, *optional*, defaults to 980):
85
+ Maximum image size.
86
+ min_image_size (`int`, *optional*, defaults to 336):
87
+ Minimum image size.
88
+ split_resolutions (`list`, *optional*, defaults to a list of optimal,resolutions as tuples):
89
+ The optimal resolutions for splitting the image.
90
+ split_image (`bool`, *optional*, defaults to `False`):
91
+ Whether to split the image.
92
+ do_convert_rgb (`bool`, *optional*, defaults to `True`):
93
+ Whether to convert the image to RGB.
94
+ do_rescale (`bool`, *optional*, defaults to `True`):
95
+ Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
96
+ the `preprocess` method.
97
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
98
+ Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
99
+ method.
100
+ do_normalize (`bool`, *optional*, defaults to `True`):
101
+ Whether to normalize the image.
102
+ resample (PILImageResampling, *optional*, defaults to `BICUBIC`):
103
+ The resampling filter to use if resizing the image.
104
+ """
105
+
106
+ model_input_names = ["pixel_values", "pixel_mask", "num_crops"]
107
+
108
+ def __init__(
109
+ self,
110
+ image_mean: Optional[list[float]] = None,
111
+ image_std: Optional[list[float]] = None,
112
+ max_image_size: int = 980,
113
+ min_image_size: int = 336,
114
+ split_resolutions: Optional[list[tuple[int, int]]] = None,
115
+ split_image: Optional[bool] = False,
116
+ do_convert_rgb: Optional[bool] = True,
117
+ do_rescale: bool = True,
118
+ rescale_factor: Union[int, float] = 1 / 255,
119
+ do_normalize: Optional[bool] = True,
120
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
121
+ **kwargs,
122
+ ):
123
+ super().__init__(**kwargs)
124
+
125
+ if image_mean is None:
126
+ image_mean = [0.5, 0.5, 0.5]
127
+ if image_std is None:
128
+ image_std = [0.5, 0.5, 0.5]
129
+ self.max_image_size = max_image_size
130
+ self.min_image_size = min_image_size
131
+ self.image_mean = image_mean
132
+ self.image_std = image_std
133
+ self.split_image = split_image
134
+ if split_resolutions is None:
135
+ split_resolutions = [(1, 2), (1, 3), (1, 4), (1, 5), (1, 6), (1, 7), (1, 8), (2, 4), (2, 3), (2, 2), (2, 1), (3, 1), (3, 2), (4, 1), (4, 2), (5, 1), (6, 1), (7, 1), (8, 1)] # fmt: skip
136
+ split_resolutions = [(el[0] * 490, el[1] * 490) for el in split_resolutions]
137
+ self.split_resolutions = split_resolutions
138
+ self.do_convert_rgb = do_convert_rgb
139
+ self.do_rescale = do_rescale
140
+ self.rescale_factor = rescale_factor
141
+ self.do_normalize = do_normalize
142
+ self.resample = resample
143
+
144
+ def preprocess(
145
+ self,
146
+ images: Union[ImageInput, list[ImageInput]],
147
+ image_mean: Optional[Union[float, list[float]]] = None,
148
+ image_std: Optional[Union[float, list[float]]] = None,
149
+ max_image_size: Optional[int] = None,
150
+ min_image_size: Optional[int] = None,
151
+ split_image: Optional[bool] = None,
152
+ do_convert_rgb: Optional[bool] = None,
153
+ do_rescale: Optional[bool] = None,
154
+ rescale_factor: Optional[float] = None,
155
+ do_normalize: Optional[bool] = None,
156
+ resample: Optional[PILImageResampling] = None,
157
+ return_tensors: Optional[Union[str, TensorType]] = "pt",
158
+ data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
159
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
160
+ ):
161
+ """
162
+ Process a list of images.
163
+
164
+ Args:
165
+ images (ImageInput or list of ImageInput):
166
+ The input image or a list of images.
167
+ image_mean (`list`, *optional*, defaults to [0.5, 0.5, 0.5]):
168
+ Mean values for normalization.
169
+ image_std (`list`, *optional*, defaults to [0.5, 0.5, 0.5]):
170
+ Standard deviation values for normalization.
171
+ max_image_size (`int`, *optional*, defaults to `self.max_image_size` (980)):
172
+ Maximum image size.
173
+ min_image_size (`int`, *optional*, defaults to `self.min_image_size` (336)):
174
+ Minimum image size.
175
+ split_image (`bool`, *optional*, defaults to `self.split_image` (False)):
176
+ Whether to split the image.
177
+ do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb` (True)):
178
+ Whether to convert the image to RGB.
179
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
180
+ Whether to rescale the image.
181
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
182
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
183
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize` (True)):
184
+ Whether to normalize the image.
185
+ resample (PILImageResampling, *optional*, defaults to `self.resample` (BICUBIC)):
186
+ The resampling filter to use if resizing the image.
187
+ return_tensors (`str` or `TensorType`, *optional*, defaults to "pt"):
188
+ The type of tensor to return.
189
+ data_format (`str` or `ChannelDimension`, *optional*):
190
+ The channel dimension format for the output image. Can be one of:
191
+ - `"channels_first"` or `ChannelDimension.FIRST`:
192
+ image in (num_channels, height, width) format.
193
+ - `"channels_last"` or `ChannelDimension.LAST`:
194
+ image in (height, width, num_channels) format.
195
+ If unset, will use same as the input image.
196
+ input_data_format (`str` or `ChannelDimension`, *optional*):
197
+ The channel dimension format for the input image. Can be one of:
198
+ - `"channels_first"` or `ChannelDimension.FIRST`:
199
+ image in (num_channels, height, width) format.
200
+ - `"channels_last"` or `ChannelDimension.LAST`:
201
+ image in (height, width, num_channels) format.
202
+ If unset, will use the inferred format of the input image.
203
+
204
+ Returns:
205
+ BatchFeature:
206
+ A BatchFeature object containing:
207
+ - 'pixel_values':
208
+ Tensor of processed image pixel values.
209
+ - 'pixel_mask':
210
+ Boolean pixel mask. This mask is a 2D tensor of shape (max_image_size, max_image_size) where:
211
+ - True (1) values indicate pixels that belong to the original resized image.
212
+ - False (0) values indicate pixels that are part of the padding.
213
+ The mask helps distinguish between actual image content and padded areas in subsequent processing steps.
214
+ - 'num_crops':
215
+ The maximum number of crops across all images.
216
+ """
217
+ image_mean = image_mean if image_mean is not None else self.image_mean
218
+ image_std = image_std if image_std is not None else self.image_std
219
+ max_image_size = max_image_size if max_image_size is not None else self.max_image_size
220
+ min_image_size = min_image_size if min_image_size is not None else self.min_image_size
221
+ split_image = split_image if split_image is not None else self.split_image
222
+ do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
223
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
224
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
225
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
226
+ resample = resample if resample is not None else self.resample
227
+
228
+ if max_image_size not in [490, 980]:
229
+ raise ValueError("max_image_size must be either 490 or 980")
230
+
231
+ images = self.fetch_images(images)
232
+ images = make_flat_list_of_images(images)
233
+
234
+ if not valid_images(images):
235
+ raise ValueError(
236
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
237
+ "torch.Tensor, tf.Tensor or jax.ndarray."
238
+ )
239
+
240
+ validate_preprocess_arguments(
241
+ do_normalize=do_normalize,
242
+ image_mean=image_mean,
243
+ image_std=image_std,
244
+ resample=resample,
245
+ do_rescale=do_rescale,
246
+ rescale_factor=rescale_factor,
247
+ )
248
+
249
+ if do_convert_rgb:
250
+ images = [convert_to_rgb(image) for image in images]
251
+
252
+ # All transformations expect numpy arrays.
253
+ images = [to_numpy_array(image) for image in images]
254
+
255
+ if do_rescale and is_scaled_image(images[0]):
256
+ logger.warning_once(
257
+ "It looks like you are trying to rescale already rescaled images. If the input"
258
+ " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
259
+ )
260
+
261
+ if input_data_format is None:
262
+ # We assume that all images have the same channel dimension format.
263
+ input_data_format = infer_channel_dimension_format(images[0])
264
+
265
+ pixel_values = []
266
+ pixel_masks = []
267
+ num_crops = None
268
+
269
+ for image in images:
270
+ if split_image:
271
+ crop_images = self.get_image_patches(
272
+ image,
273
+ self.split_resolutions,
274
+ max_image_size,
275
+ resample,
276
+ data_format=input_data_format,
277
+ input_data_format=input_data_format,
278
+ )
279
+ else:
280
+ crop_images = [image]
281
+ if num_crops is None or len(crop_images) > num_crops:
282
+ num_crops = len(crop_images)
283
+
284
+ for crop_image in crop_images:
285
+ # At this point the scale is the rescaling factor that would bring the image to max_size in its larger dimension
286
+ h, w = get_image_size(crop_image)
287
+ scale = max_image_size / max(h, w)
288
+ if w >= h:
289
+ new_size = (max(int(h * scale), min_image_size), max_image_size) # h, w
290
+ else:
291
+ new_size = (max_image_size, max(int(w * scale), min_image_size)) # h, w
292
+
293
+ crop_image_resized = resize(
294
+ crop_image,
295
+ new_size,
296
+ resample=resample,
297
+ data_format=input_data_format,
298
+ input_data_format=input_data_format,
299
+ )
300
+
301
+ padding_bottom, padding_right = max_image_size - new_size[0], max_image_size - new_size[1]
302
+ crop_image_padded = pad(
303
+ crop_image_resized,
304
+ ((0, padding_bottom), (0, padding_right)),
305
+ data_format=input_data_format,
306
+ input_data_format=input_data_format,
307
+ )
308
+
309
+ # Create a pixel mask
310
+ pixel_mask = np.zeros((max_image_size, max_image_size), dtype=bool)
311
+ pixel_mask[: new_size[0], : new_size[1]] = 1
312
+ pixel_masks.append(pixel_mask)
313
+
314
+ if do_rescale:
315
+ crop_image_padded = self.rescale(
316
+ image=crop_image_padded, scale=rescale_factor, input_data_format=input_data_format
317
+ )
318
+
319
+ if do_normalize:
320
+ crop_image_padded = self.normalize(
321
+ crop_image_padded,
322
+ self.image_mean,
323
+ self.image_std,
324
+ data_format=input_data_format,
325
+ input_data_format=input_data_format,
326
+ )
327
+ crop_image_padded = (
328
+ to_channel_dimension_format(crop_image_padded, data_format, input_data_format)
329
+ if data_format is not None
330
+ else crop_image_padded
331
+ )
332
+
333
+ pixel_values.append(crop_image_padded)
334
+ return BatchFeature(
335
+ data={
336
+ "pixel_values": np.stack(pixel_values, axis=0),
337
+ "pixel_mask": np.stack(pixel_masks, axis=0),
338
+ "num_crops": num_crops,
339
+ },
340
+ tensor_type=return_tensors,
341
+ )
342
+
343
+ def _resize_for_patching(
344
+ self, image: np.ndarray, target_resolution: tuple, resample, input_data_format: ChannelDimension
345
+ ) -> np.ndarray:
346
+ """
347
+ Resizes an image to a target resolution while maintaining aspect ratio.
348
+
349
+ Args:
350
+ image (np.ndarray):
351
+ The input image.
352
+ target_resolution (tuple):
353
+ The target resolution (height, width) of the image.
354
+ resample (`PILImageResampling`):
355
+ Resampling filter to use if resizing the image.
356
+ input_data_format (`ChannelDimension` or `str`):
357
+ The channel dimension format of the input image.
358
+
359
+ Returns:
360
+ np.ndarray: The resized and padded image.
361
+ """
362
+ new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)
363
+
364
+ # Resize the image
365
+ resized_image = resize(image, (new_height, new_width), resample=resample, input_data_format=input_data_format)
366
+
367
+ return resized_image
368
+
369
+ def _get_padding_size(self, original_resolution: tuple, target_resolution: tuple):
370
+ original_height, original_width = original_resolution
371
+ target_height, target_width = target_resolution
372
+ paste_x, r_x = divmod(target_width - original_width, 2)
373
+ paste_y, r_y = divmod(target_height - original_height, 2)
374
+ return (paste_y, paste_y + r_y), (paste_x, paste_x + r_x)
375
+
376
+ def _pad_for_patching(
377
+ self, image: np.ndarray, target_resolution: tuple, input_data_format: ChannelDimension
378
+ ) -> np.ndarray:
379
+ """
380
+ Pad an image to a target resolution while maintaining aspect ratio.
381
+ """
382
+ new_resolution = get_patch_output_size(image, target_resolution, input_data_format)
383
+ padding = self._get_padding_size(new_resolution, target_resolution)
384
+
385
+ padded_image = self.pad(image, padding=padding)
386
+
387
+ return padded_image
388
+
389
+ def pad(
390
+ self,
391
+ image: np.ndarray,
392
+ padding: Union[int, tuple[int, int], Iterable[tuple[int, int]]],
393
+ mode: PaddingMode = PaddingMode.CONSTANT,
394
+ constant_values: Union[float, Iterable[float]] = 0.0,
395
+ data_format: Optional[Union[str, ChannelDimension]] = None,
396
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
397
+ ) -> np.ndarray:
398
+ """
399
+ Pads the `image` with the specified `padding` and `mode`. Padding can be in the (`height`, `width`)
400
+ dimension of in the (`num_patches`) dimension. In the second case an iterable if tuples is expected
401
+ as input.
402
+
403
+ Args:
404
+ image (`np.ndarray`):
405
+ The image to pad.
406
+ padding (`int` or `tuple[int, int]` or `Iterable[tuple[int, int]]`):
407
+ Padding to apply to the edges of the height, width axes. Can be one of three formats:
408
+ - `((before_height, after_height), (before_width, after_width))` unique pad widths for each axis.
409
+ - `((before, after),)` yields same before and after pad for height and width.
410
+ - `(pad,)` or int is a shortcut for before = after = pad width for all axes.
411
+ mode (`PaddingMode`):
412
+ The padding mode to use. Can be one of:
413
+ - `"constant"`: pads with a constant value.
414
+ - `"reflect"`: pads with the reflection of the vector mirrored on the first and last values of the
415
+ vector along each axis.
416
+ - `"replicate"`: pads with the replication of the last value on the edge of the array along each axis.
417
+ - `"symmetric"`: pads with the reflection of the vector mirrored along the edge of the array.
418
+ constant_values (`float` or `Iterable[float]`, *optional*):
419
+ The value to use for the padding if `mode` is `"constant"`.
420
+ data_format (`str` or `ChannelDimension`, *optional*):
421
+ The channel dimension format for the output image. Can be one of:
422
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
423
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
424
+ If unset, will use same as the input image.
425
+ input_data_format (`str` or `ChannelDimension`, *optional*):
426
+ The channel dimension format for the input image. Can be one of:
427
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
428
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
429
+ If unset, will use the inferred format of the input image.
430
+
431
+ Returns:
432
+ `np.ndarray`: The padded image.
433
+
434
+ """
435
+
436
+ # call the general `pad` if padding on `height/width`, otherwise it's the `num_patches` dim
437
+ if isinstance(padding, int) or len(padding) != 4:
438
+ return pad(image, padding, mode, constant_values, data_format, input_data_format)
439
+
440
+ if input_data_format is None:
441
+ input_data_format = infer_channel_dimension_format(image)
442
+
443
+ padding_mode_mapping = {
444
+ PaddingMode.CONSTANT: "constant",
445
+ PaddingMode.REFLECT: "reflect",
446
+ PaddingMode.REPLICATE: "edge",
447
+ PaddingMode.SYMMETRIC: "symmetric",
448
+ }
449
+ image = np.pad(image, padding, mode=padding_mode_mapping[mode], constant_values=constant_values)
450
+ image = (
451
+ to_channel_dimension_format(image, data_format, input_data_format) if data_format is not None else image
452
+ )
453
+ return image
454
+
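A minimal, self-contained sketch (not the library's code path) of how the symmetric `((before_height, after_height), (before_width, after_width))` padding computed by `_get_padding_size` maps onto `np.pad` for a channels-last image; the array shape and target resolution below are illustrative, not values from any model config.

```python
# Minimal sketch: apply the symmetric padding computed by a _get_padding_size-style
# helper with np.pad on a channels-last (H, W, C) array.
import numpy as np

def get_padding_size(original_resolution, target_resolution):
    original_height, original_width = original_resolution
    target_height, target_width = target_resolution
    paste_x, r_x = divmod(target_width - original_width, 2)
    paste_y, r_y = divmod(target_height - original_height, 2)
    return (paste_y, paste_y + r_y), (paste_x, paste_x + r_x)

image = np.zeros((5, 7, 3), dtype=np.uint8)  # illustrative (height, width, channels)
(top, bottom), (left, right) = get_padding_size((5, 7), (8, 10))

# Only height and width are padded; the channel axis is left untouched.
padded = np.pad(image, ((top, bottom), (left, right), (0, 0)))
assert padded.shape == (8, 10, 3)
```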
455
+ def get_image_patches(
456
+ self,
457
+ image: np.ndarray,
458
+ grid_pinpoints: list[tuple[int, int]],
459
+ patch_size: int,
460
+ resample: PILImageResampling,
461
+ data_format: ChannelDimension,
462
+ input_data_format: ChannelDimension,
463
+ ) -> list[np.ndarray]:
464
+ """
465
+ Process an image with variable resolutions by dividing it into patches.
466
+
467
+ Args:
468
+ image (`np.ndarray`):
469
+ The input image to be processed.
470
+ grid_pinpoints (`list[tuple[int, int]]`):
471
+ A list of possible resolutions as tuples.
472
+ patch_size (`int`):
473
+ Size of the patches to divide the image into.
474
+ resample (`PILImageResampling`):
475
+ Resampling filter to use if resizing the image.
476
+ data_format (`ChannelDimension` or `str`):
477
+ The channel dimension format for the output image.
478
+ input_data_format (`ChannelDimension` or `str`):
479
+ The channel dimension format of the input image.
480
+
481
+ Returns:
482
+ `list[np.ndarray]`: A list of NumPy arrays containing the processed image patches.
483
+ """
484
+ if not isinstance(grid_pinpoints, list):
485
+ raise TypeError("grid_pinpoints must be a list of possible resolutions.")
486
+
487
+ possible_resolutions = grid_pinpoints
488
+
489
+ image_size = get_image_size(image, channel_dim=input_data_format)
490
+ best_resolution = select_best_resolution(image_size, possible_resolutions)
491
+ resized_image = self._resize_for_patching(
492
+ image, best_resolution, resample=resample, input_data_format=input_data_format
493
+ )
494
+ padded_image = self._pad_for_patching(resized_image, best_resolution, input_data_format=input_data_format)
495
+
496
+ patches = divide_to_patches(padded_image, patch_size=patch_size, input_data_format=input_data_format)
497
+
498
+ # make sure that all patches are in the input data format
499
+ patches = [
500
+ to_channel_dimension_format(patch, channel_dim=data_format, input_channel_dim=input_data_format)
501
+ for patch in patches
502
+ ]
503
+ return patches
504
+
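A rough sketch of the per-image arithmetic behind `get_image_patches`: resize the image to fit inside a target resolution while keeping its aspect ratio, pad up to that resolution, then count the `patch_size`-sized tiles. The target resolution is assumed to have been picked from `grid_pinpoints` already, and `fit_within` is a simplified stand-in for the library's resize-size computation.

```python
# Rough sketch of the per-image patching arithmetic: fit inside the chosen
# target resolution, pad up to it, then count patch_size x patch_size tiles.
import math

def fit_within(orig_height, orig_width, target_height, target_width):
    # Aspect-ratio-preserving resize so the image fits inside the target box
    # (stand-in for the library's own size computation).
    scale = min(target_height / orig_height, target_width / orig_width)
    return math.ceil(orig_height * scale), math.ceil(orig_width * scale)

orig_height, orig_width = 480, 640
target_height, target_width = 980, 980   # assumed to come from grid_pinpoints
patch_size = 490

new_height, new_width = fit_within(orig_height, orig_width, target_height, target_width)
num_patches = (target_height // patch_size) * (target_width // patch_size)
print(new_height, new_width, num_patches)  # 735 980 4
```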
505
+ def get_number_of_image_patches(self, height: int, width: int, images_kwargs=None):
506
+ """
507
+ A utility that returns the number of image patches for a given image size.
508
+
509
+ Args:
510
+ height (`int`):
511
+ Height of the input image.
512
+ width (`int`):
513
+ Width of the input image.
514
+ images_kwargs (`dict`, *optional*):
515
+ Any kwargs to override defaults of the image processor.
516
+ Returns:
517
+ `int`: Number of patches per image.
518
+ """
519
+ split_image = images_kwargs.get("split_image", self.split_image)
520
+ max_image_size = images_kwargs.get("max_image_size", self.max_image_size)
521
+
522
+ resized_height, resized_width = select_best_resolution((height, width), self.split_resolutions)
523
+ num_patches = 1 if not split_image else resized_height // max_image_size * resized_width // max_image_size
524
+ return num_patches
525
+
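A hedged usage sketch of `get_number_of_image_patches`, assuming the image processor can be loaded from the `Rhymes-AI/Aria` checkpoint referenced in the generation example later in this diff; this is handy when pre-computing how many image placeholder tokens a prompt will need.

```python
# Hedged usage sketch: ask the processor how many crops an image of a given
# size will produce. Assumes the Rhymes-AI/Aria checkpoint is reachable
# locally or via the Hub.
from transformers import AutoImageProcessor

image_processor = AutoImageProcessor.from_pretrained("Rhymes-AI/Aria")

# With split_image=True, the best-matching grid resolution is divided by
# max_image_size along each axis; with split_image=False the answer is always 1.
num_patches = image_processor.get_number_of_image_patches(
    height=768, width=1024, images_kwargs={"split_image": True}
)
print(num_patches)
```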
526
+
527
+ __all__ = ["AriaImageProcessor"]
venv/lib/python3.13/site-packages/transformers/models/aria/modeling_aria.py ADDED
@@ -0,0 +1,1275 @@
1
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
2
+ # This file was automatically generated from src/transformers/models/aria/modular_aria.py.
3
+ # Do NOT edit this file manually as any edits will be overwritten by the generation of
4
+ # the file from the modular. If any change should be done, please apply the change to the
5
+ # modular_aria.py file directly. One of our CI enforces this.
6
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
7
+ # coding=utf-8
8
+ # Copyright 2024 The Rhymes-AI Teams Authors and The HuggingFace Inc. team. All rights reserved.
9
+ #
10
+ # Licensed under the Apache License, Version 2.0 (the "License");
11
+ # you may not use this file except in compliance with the License.
12
+ # You may obtain a copy of the License at
13
+ #
14
+ # http://www.apache.org/licenses/LICENSE-2.0
15
+ #
16
+ # Unless required by applicable law or agreed to in writing, software
17
+ # distributed under the License is distributed on an "AS IS" BASIS,
18
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19
+ # See the License for the specific language governing permissions and
20
+ # limitations under the License.
21
+ from dataclasses import dataclass
22
+ from typing import Callable, Optional, Union
23
+
24
+ import torch
25
+ from torch import nn
26
+
27
+ from ...activations import ACT2FN
28
+ from ...cache_utils import Cache, DynamicCache
29
+ from ...generation import GenerationMixin
30
+ from ...integrations import use_kernel_forward_from_hub
31
+ from ...masking_utils import create_causal_mask
32
+ from ...modeling_flash_attention_utils import FlashAttentionKwargs
33
+ from ...modeling_layers import GradientCheckpointingLayer
34
+ from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, ModelOutput
35
+ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
36
+ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
37
+ from ...processing_utils import Unpack
38
+ from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
39
+ from ...utils.deprecation import deprecate_kwarg
40
+ from ...utils.generic import check_model_inputs
41
+ from ..auto import AutoModel
42
+ from .configuration_aria import AriaConfig, AriaTextConfig
43
+
44
+
45
+ @use_kernel_forward_from_hub("RMSNorm")
46
+ class AriaTextRMSNorm(nn.Module):
47
+ def __init__(self, hidden_size, eps=1e-6):
48
+ """
49
+ AriaTextRMSNorm is equivalent to T5LayerNorm
50
+ """
51
+ super().__init__()
52
+ self.weight = nn.Parameter(torch.ones(hidden_size))
53
+ self.variance_epsilon = eps
54
+
55
+ def forward(self, hidden_states):
56
+ input_dtype = hidden_states.dtype
57
+ hidden_states = hidden_states.to(torch.float32)
58
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
59
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
60
+ return self.weight * hidden_states.to(input_dtype)
61
+
62
+ def extra_repr(self):
63
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
64
+
65
+
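A minimal numerical check of the RMSNorm math above: the input is scaled by the reciprocal root mean square over its last dimension and then by the learned weight. Plain `torch`, no model weights involved.

```python
# Minimal numerical check of the RMSNorm computation.
import torch

x = torch.randn(2, 4, 8)   # (batch, seq, hidden)
weight = torch.ones(8)
eps = 1e-6

rms = torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps)
manual = weight * (x / rms)

# Same computation written with rsqrt, as in the module's forward.
rsqrt_version = weight * (x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps))
assert torch.allclose(manual, rsqrt_version, atol=1e-6)
```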
66
+ class AriaProjectorMLP(nn.Module):
67
+ """
68
+ Feed-Forward Network module for the Aria Projector.
69
+
70
+ Args:
71
+ in_features (`int`):
72
+ Input embedding dimension.
73
+ hidden_features (`int`):
74
+ Hidden dimension of the feed-forward network.
75
+ output_dim (`int`):
76
+ Output dimension.
77
+ """
78
+
79
+ def __init__(self, in_features, hidden_features, output_dim):
80
+ super().__init__()
81
+ self.linear_in = nn.Linear(in_features, hidden_features, bias=False)
82
+ self.linear_out = nn.Linear(hidden_features, output_dim, bias=False)
83
+ self.act = ACT2FN["gelu_new"]
84
+
85
+ def forward(self, hidden_states):
86
+ hidden_states = self.act(self.linear_in(hidden_states))
87
+ hidden_states = self.linear_out(hidden_states)
88
+ return hidden_states
89
+
90
+
91
+ class AriaCrossAttention(nn.Module):
92
+ """
93
+ Aria Cross-Attention module.
94
+
95
+ Args:
96
+ config (`AriaConfig`):
97
+ The configuration to use.
98
+ """
99
+
100
+ def __init__(self, config: AriaConfig, dropout_rate: float = 0):
101
+ super().__init__()
102
+ hidden_size = config.vision_config.hidden_size
103
+ num_heads = config.vision_config.num_attention_heads
104
+ self.num_heads = num_heads
105
+ self.q_proj = nn.Linear(hidden_size, hidden_size, bias=False)
106
+ self.k_proj = nn.Linear(hidden_size, hidden_size, bias=False)
107
+ self.v_proj = nn.Linear(hidden_size, hidden_size, bias=False)
108
+
109
+ # Original code here: https://github.com/rhymes-ai/Aria/blob/719ff4e52b727443cba3793b0e27fe64e0244fe1/aria/model/projector.py#L48
110
+ self.multihead_attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
111
+ self.linear = nn.Linear(hidden_size, hidden_size)
112
+ self.dropout = nn.Dropout(dropout_rate)
113
+
114
+ self.layer_norm = nn.LayerNorm(hidden_size)
115
+ self.layer_norm_kv = nn.LayerNorm(hidden_size)
116
+
117
+ def forward(self, key_value_states, hidden_states, attn_mask=None):
118
+ """
119
+ Forward pass of the AriaCrossAttention module.
120
+
121
+ Args:
122
+ key_value_states (`torch.Tensor`):
123
+ Input tensor for key and value.
124
+ hidden_states (`torch.Tensor`):
125
+ Input tensor for query.
126
+ attn_mask (`torch.Tensor`, *optional*, defaults to None):
127
+ Attention mask.
128
+
129
+ Returns:
130
+ torch.Tensor:
131
+ Output tensor after cross-attention.
132
+ """
133
+ query = self.q_proj(self.layer_norm(hidden_states))
134
+
135
+ key_value_states = self.layer_norm_kv(key_value_states)
136
+ key = self.k_proj(key_value_states)
137
+ value = self.v_proj(key_value_states)
138
+
139
+ attn_output, _ = self.multihead_attn(query, key, value, attn_mask=attn_mask)
140
+
141
+ attn_output = self.dropout(self.linear(attn_output))
142
+
143
+ return attn_output
144
+
145
+
146
+ class AriaProjector(nn.Module):
147
+ """
148
+ Aria Projector module.
149
+
150
+ This module projects vision features into the language model's embedding space, enabling interaction between vision and language components.
151
+
152
+ Args:
153
+ config (`AriaConfig`):
154
+ Configuration object for the model.
155
+ """
156
+
157
+ def __init__(
158
+ self,
159
+ config: AriaConfig,
160
+ ):
161
+ super().__init__()
162
+
163
+ self.patch_to_query_dict = config.projector_patch_to_query_dict
164
+ self.in_features = config.vision_config.hidden_size
165
+ self.num_heads = config.vision_config.num_attention_heads
166
+ self.kv_dim = config.vision_config.hidden_size
167
+ self.hidden_features = config.text_config.hidden_size
168
+ self.output_dim = config.text_config.hidden_size
169
+
170
+ self.query = nn.Parameter(torch.zeros(config.max_value_projector_patch_to_query_dict, self.in_features))
171
+
172
+ self.cross_attn = AriaCrossAttention(config)
173
+
174
+ self.layer_norm = nn.LayerNorm(self.in_features)
175
+ self.feed_forward = AriaProjectorMLP(self.in_features, self.hidden_features, self.output_dim)
176
+
177
+ def forward(self, key_value_states: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
178
+ """
179
+ Forward pass of the Projector module.
180
+
181
+ Args:
182
+ key_value_states (`torch.Tensor`):
183
+ Input tensor of shape (batch_size, num_patches, kv_dim).
184
+ attn_mask (`torch.Tensor`, *optional*, defaults to None):
185
+ Attention mask.
186
+
187
+ Returns:
188
+ `torch.Tensor`: Output tensor of shape (batch_size, query_number, output_dim).
189
+ """
190
+ batch_size, num_patches = key_value_states.shape[0], key_value_states.shape[1]
191
+
192
+ if num_patches not in self.patch_to_query_dict:
193
+ raise KeyError(
194
+ f"Number of patches {num_patches} not found in patch_to_query_dict amongst possible values {self.patch_to_query_dict.keys()}."
195
+ )
196
+ query_num = self.patch_to_query_dict[num_patches]
197
+
198
+ queries = self.query[:query_num].unsqueeze(0).repeat(batch_size, 1, 1)
199
+
200
+ if attn_mask is not None:
201
+ attn_mask = attn_mask.repeat_interleave(self.num_heads, 0)
202
+ attn_mask = attn_mask.unsqueeze(1).expand(-1, queries.size(1), -1)
203
+
204
+ attention_out = self.cross_attn(key_value_states, queries, attn_mask=attn_mask)
205
+
206
+ out = self.feed_forward(self.layer_norm(attention_out))
207
+
208
+ return out
209
+
210
+
211
+ class AriaSharedExpertsMLP(nn.Module):
212
+ """
213
+ MLP for the shared experts.
214
+
215
+ Unlike routed experts, shared experts process all tokens without routing.
216
+ This class reconfigures the intermediate size in comparison to the LlamaMLP.
217
+
218
+ Args:
219
+ config (`AriaTextConfig`): Configuration object for the Aria language model.
220
+ """
221
+
222
+ def __init__(self, config: AriaTextConfig):
223
+ super().__init__()
224
+ self.config = config
225
+ self.hidden_size = config.hidden_size
226
+ self.intermediate_size = config.intermediate_size * config.moe_num_shared_experts
227
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
228
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
229
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
230
+ self.act_fn = ACT2FN[config.hidden_act]
231
+
232
+ def forward(self, x):
233
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
234
+ return down_proj
235
+
236
+
237
+ def sequential_experts_gemm(token_states, expert_weights, tokens_per_expert):
238
+ """
239
+ Compute the matrix multiplication (GEMM) for each expert sequentially. This approach is computationally inefficient, especially when dealing with a large number of experts.
240
+
241
+ Args:
242
+ token_states (torch.Tensor): Input tensor of shape (num_tokens, in_features).
243
+ expert_weights (torch.Tensor): Weight tensor of shape (num_experts, in_features, out_features).
244
+ tokens_per_expert (torch.Tensor): Number of tokens assigned to each expert.
245
+
246
+ Returns:
247
+ torch.Tensor: Output tensor of shape (num_tokens, out_features).
248
+ """
249
+ num_tokens = token_states.shape[0]
250
+ out_features = expert_weights.shape[-1]
251
+ output = torch.zeros(num_tokens, out_features, dtype=token_states.dtype, device=token_states.device)
252
+
253
+ cumsum_num_tokens = torch.cumsum(tokens_per_expert, dim=0)
254
+ # Insert zero at the beginning for offset index's convenience
255
+ zero_tensor = torch.zeros(1, dtype=torch.long, device=cumsum_num_tokens.device)
256
+ cumsum_num_tokens = torch.cat((zero_tensor, cumsum_num_tokens))
257
+
258
+ for expert_num in range(expert_weights.shape[0]):
259
+ start = cumsum_num_tokens[expert_num]
260
+ end = cumsum_num_tokens[expert_num + 1]
261
+ tokens = token_states[start:end]
262
+
263
+ out = torch.matmul(tokens, expert_weights[expert_num])
264
+ output[start:end] = out
265
+ return output
266
+
267
+
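A small sketch of the bookkeeping in `sequential_experts_gemm`: tokens are assumed to be pre-sorted by expert, `tokens_per_expert` gives the contiguous block sizes, and each block is multiplied by its own expert weight matrix. Shapes are illustrative.

```python
# Small sketch of the grouped-GEMM bookkeeping with toy shapes.
import torch

num_experts, in_features, out_features = 3, 4, 5
expert_weights = torch.randn(num_experts, in_features, out_features)
tokens_per_expert = torch.tensor([2, 0, 3])   # expert 1 receives no tokens
token_states = torch.randn(int(tokens_per_expert.sum()), in_features)

offsets = torch.cat([torch.zeros(1, dtype=torch.long), tokens_per_expert.cumsum(0)])
output = torch.zeros(token_states.shape[0], out_features)
for expert_idx in range(num_experts):
    # Each expert multiplies only its contiguous slice of tokens.
    start, end = int(offsets[expert_idx]), int(offsets[expert_idx + 1])
    output[start:end] = token_states[start:end] @ expert_weights[expert_idx]

assert output.shape == (5, 5)
```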
268
+ class AriaGroupedExpertsGemm(nn.Module):
269
+ """
270
+ Grouped GEMM (General Matrix Multiplication) module for efficient expert computation.
271
+ This module utilizes the grouped_gemm library (https://github.com/fanshiqing/grouped_gemm)
272
+ for optimized performance. If the grouped_gemm library is not installed, it gracefully
273
+ falls back to a sequential GEMM implementation, which may be slower but ensures
274
+ functionality.
275
+
276
+ Args:
277
+ in_features (`int`):
278
+ Number of input features.
279
+ out_features (`int`):
280
+ Number of output features.
281
+ groups (`int`):
282
+ Number of expert groups.
283
+ """
284
+
285
+ def __init__(self, in_features, out_features, groups):
286
+ super().__init__()
287
+ self.in_features = in_features
288
+ self.out_features = out_features
289
+ self.groups = groups
290
+ self.weight = nn.Parameter(torch.empty(groups, in_features, out_features))
291
+
292
+ def forward(self, input, tokens_per_expert):
293
+ """
294
+ Perform grouped matrix multiplication.
295
+
296
+ Args:
297
+ input (`torch.Tensor`):
298
+ Input tensor of shape (num_tokens, in_features).
299
+ tokens_per_expert (`torch.Tensor`):
300
+ Number of tokens assigned to each expert.
301
+
302
+ Returns:
303
+ torch.Tensor: Output tensor of shape (num_tokens, out_features).
304
+ """
305
+ return sequential_experts_gemm(
306
+ input,
307
+ self.weight,
308
+ tokens_per_expert.cpu(),
309
+ )
310
+
311
+
312
+ class AriaGroupedExpertsMLP(nn.Module):
313
+ """
314
+ Grouped MLP module for Mixture of Experts.
315
+
316
+ Args:
317
+ config (`AriaTextConfig`):
318
+ Configuration object for the model.
319
+ """
320
+
321
+ def __init__(self, config: AriaTextConfig) -> None:
322
+ super().__init__()
323
+ self.config = config
324
+ self.fc1 = AriaGroupedExpertsGemm(config.hidden_size, config.intermediate_size * 2, config.moe_num_experts)
325
+ self.fc2 = AriaGroupedExpertsGemm(config.intermediate_size, config.hidden_size, config.moe_num_experts)
326
+
327
+ def forward(self, permuted_tokens, tokens_per_expert):
328
+ """
329
+ Forward pass of the Grouped MLP.
330
+
331
+ Args:
332
+ permuted_tokens (torch.Tensor): Permuted input tokens.
333
+ tokens_per_expert (torch.Tensor): Number of tokens assigned to each expert.
334
+
335
+ Returns:
336
+ torch.Tensor: Output tensor after passing through the MLP.
337
+ """
338
+ fc1_output = self.fc1(permuted_tokens, tokens_per_expert)
339
+ projection, gate = torch.chunk(fc1_output, 2, dim=-1)
340
+ fc1_output = nn.functional.silu(projection) * gate
341
+ fc2_output = self.fc2(fc1_output, tokens_per_expert)
342
+ return fc2_output
343
+
344
+
345
+ # Token permutation adapted from https://github.com/NVIDIA/Megatron-LM/blob/54f1f78529cbc2b9cddad313e7f9d96ac0420a27/megatron/core/transformer/moe/token_dispatcher.py#L291-L587
346
+ class AriaTextMoELayer(nn.Module):
347
+ """
348
+ Aria Text Mixture of Experts (MoE) Layer.
349
+
350
+ This layer applies a gating mechanism to route input tokens to different experts.
351
+
352
+ Args:
353
+ config (`AriaTextConfig`):
354
+ Configuration object for the text component of the model.
355
+ """
356
+
357
+ def __init__(self, config: AriaTextConfig):
358
+ super().__init__()
359
+
360
+ self.router = nn.Linear(config.hidden_size, config.moe_num_experts, bias=False)
361
+ self.experts = AriaGroupedExpertsMLP(config)
362
+ self.shared_experts = AriaSharedExpertsMLP(config)
363
+ self.config = config
364
+
365
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
366
+ """
367
+ Forward pass of the MoE Layer.
368
+
369
+ Args:
370
+ hidden_states (`torch.Tensor`):
371
+ Input tensor of shape (batch_size, sequence_length, hidden_size).
372
+
373
+ Returns:
374
+ torch.Tensor: Output tensor after passing through the MoE layer.
375
+
376
+ Process:
377
+ 1. Route tokens to experts using the router.
378
+ 2. Permute tokens based on routing decisions.
379
+ 3. Process tokens through experts.
380
+ 4. Unpermute and combine expert outputs.
381
+ 5. Add shared expert output to the final result.
382
+ """
383
+ original_shape = hidden_states.shape
384
+ hidden_states = hidden_states.view(-1, hidden_states.size(-1))
385
+
386
+ # Top K Routing
387
+ logits = self.router(hidden_states)
388
+ top_logits, top_indices = torch.topk(logits, k=self.config.moe_topk, dim=1)
389
+ scores = nn.functional.softmax(top_logits, dim=-1)
390
+
391
+ original_dtype = top_indices.dtype
392
+
393
+ tokens_per_expert = torch.histc(
394
+ top_indices.flatten().to(torch.float32),
395
+ bins=self.config.moe_num_experts,
396
+ min=0,
397
+ max=self.config.moe_num_experts - 1,
398
+ ).to(original_dtype)
399
+ indices = top_indices
400
+
401
+ # Token permutation
402
+ flatten_indices = indices.view(-1)
403
+ sorted_indices = torch.argsort(flatten_indices)
404
+ permuted_tokens = hidden_states.index_select(0, sorted_indices // self.config.moe_topk)
405
+
406
+ # Process through experts
407
+ expert_output = self.experts(permuted_tokens, tokens_per_expert)
408
+
409
+ # Token unpermutation
410
+ unpermuted_tokens = torch.zeros(
411
+ (scores.shape[0] * self.config.moe_topk, expert_output.size(1)),
412
+ dtype=expert_output.dtype,
413
+ device=expert_output.device,
414
+ )
415
+ unpermuted_tokens.index_copy_(0, sorted_indices, expert_output)
416
+ unpermuted_tokens = unpermuted_tokens.view(-1, self.config.moe_topk, expert_output.size(1))
417
+
418
+ output = (unpermuted_tokens * scores.unsqueeze(-1)).sum(dim=1).view(original_shape)
419
+
420
+ # Add shared expert output
421
+ shared_expert_output = self.shared_experts(hidden_states.view(original_shape))
422
+ return output + shared_expert_output
423
+
424
+
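A hedged sketch of the routing and permutation steps listed in the docstring above, with a simple per-expert bias standing in for the grouped expert MLP; it only illustrates the index bookkeeping, not the model's actual experts.

```python
# Hedged sketch of top-k routing, permutation, and unpermutation.
import torch

num_tokens, hidden, num_experts, topk = 6, 4, 3, 2
hidden_states = torch.randn(num_tokens, hidden)
router = torch.nn.Linear(hidden, num_experts, bias=False)

logits = router(hidden_states)
top_logits, top_indices = torch.topk(logits, k=topk, dim=1)   # (tokens, topk)
scores = torch.softmax(top_logits, dim=-1)

# Sort the flattened (token, expert) assignments so each expert sees a contiguous block.
flat_indices = top_indices.view(-1)
sorted_order = torch.argsort(flat_indices)
permuted_tokens = hidden_states.index_select(0, sorted_order // topk)

# Stand-in "expert" computation: add a different bias per expert.
expert_bias = torch.randn(num_experts, hidden)
expert_output = permuted_tokens + expert_bias[flat_indices[sorted_order]]

# Undo the permutation and combine the top-k outputs per token, weighted by the scores.
unpermuted = torch.zeros_like(expert_output)
unpermuted.index_copy_(0, sorted_order, expert_output)
combined = (unpermuted.view(num_tokens, topk, hidden) * scores.unsqueeze(-1)).sum(dim=1)
assert combined.shape == (num_tokens, hidden)
```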
425
+ def rotate_half(x):
426
+ """Rotates half the hidden dims of the input."""
427
+ x1 = x[..., : x.shape[-1] // 2]
428
+ x2 = x[..., x.shape[-1] // 2 :]
429
+ return torch.cat((-x2, x1), dim=-1)
430
+
431
+
432
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
433
+ """Applies Rotary Position Embedding to the query and key tensors.
434
+
435
+ Args:
436
+ q (`torch.Tensor`): The query tensor.
437
+ k (`torch.Tensor`): The key tensor.
438
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
439
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
440
+ position_ids (`torch.Tensor`, *optional*):
441
+ Deprecated and unused.
442
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
443
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
444
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
445
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
446
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
447
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
448
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
449
+ Returns:
450
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
451
+ """
452
+ cos = cos.unsqueeze(unsqueeze_dim)
453
+ sin = sin.unsqueeze(unsqueeze_dim)
454
+ q_embed = (q * cos) + (rotate_half(q) * sin)
455
+ k_embed = (k * cos) + (rotate_half(k) * sin)
456
+ return q_embed, k_embed
457
+
458
+
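A self-contained sketch of the RoPE application above: cos/sin are built from inverse frequencies, in the same spirit as the rotary-embedding module defined further down in this file, and applied to a toy query tensor with `rotate_half`. The head dimension and the base of 10000 are illustrative defaults, not values read from an Aria config.

```python
# Self-contained RoPE sketch with toy sizes.
import torch

def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

batch, heads, seq, head_dim, base = 1, 2, 5, 8, 10000.0
inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))  # (head_dim/2,)
positions = torch.arange(seq).float()

freqs = torch.outer(positions, inv_freq)          # (seq, head_dim/2)
emb = torch.cat((freqs, freqs), dim=-1)           # (seq, head_dim)
cos, sin = emb.cos()[None], emb.sin()[None]       # (1, seq, head_dim)

q = torch.randn(batch, heads, seq, head_dim)
# unsqueeze_dim=1 makes cos/sin broadcast against (batch, heads, seq, head_dim)
q_rot = q * cos.unsqueeze(1) + rotate_half(q) * sin.unsqueeze(1)
assert q_rot.shape == q.shape
```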
459
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
460
+ """
461
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
462
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
463
+ """
464
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
465
+ if n_rep == 1:
466
+ return hidden_states
467
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
468
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
469
+
470
+
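A quick check that the expand-and-reshape trick in `repeat_kv` matches `torch.repeat_interleave` along the key/value head dimension, as the docstring claims.

```python
# Verify the expand + reshape expansion equals repeat_interleave on dim=1.
import torch

batch, kv_heads, seq, head_dim, n_rep = 2, 2, 3, 4, 3
kv = torch.randn(batch, kv_heads, seq, head_dim)

expanded = kv[:, :, None, :, :].expand(batch, kv_heads, n_rep, seq, head_dim)
expanded = expanded.reshape(batch, kv_heads * n_rep, seq, head_dim)

assert torch.equal(expanded, torch.repeat_interleave(kv, repeats=n_rep, dim=1))
```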
471
+ def eager_attention_forward(
472
+ module: nn.Module,
473
+ query: torch.Tensor,
474
+ key: torch.Tensor,
475
+ value: torch.Tensor,
476
+ attention_mask: Optional[torch.Tensor],
477
+ scaling: float,
478
+ dropout: float = 0.0,
479
+ **kwargs: Unpack[TransformersKwargs],
480
+ ):
481
+ key_states = repeat_kv(key, module.num_key_value_groups)
482
+ value_states = repeat_kv(value, module.num_key_value_groups)
483
+
484
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
485
+ if attention_mask is not None:
486
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
487
+ attn_weights = attn_weights + causal_mask
488
+
489
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
490
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
491
+ attn_output = torch.matmul(attn_weights, value_states)
492
+ attn_output = attn_output.transpose(1, 2).contiguous()
493
+
494
+ return attn_output, attn_weights
495
+
496
+
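A minimal numerical sanity check that the eager attention math above (no mask, no dropout) agrees with `torch.nn.functional.scaled_dot_product_attention`; it says nothing about the model's attention-backend dispatch, only the arithmetic.

```python
# Compare plain softmax attention with torch's fused SDPA on toy tensors.
import torch
import torch.nn.functional as F

batch, heads, seq, head_dim = 1, 2, 5, 8
q, k, v = (torch.randn(batch, heads, seq, head_dim) for _ in range(3))
scaling = head_dim ** -0.5

weights = torch.softmax(torch.matmul(q, k.transpose(2, 3)) * scaling, dim=-1)
eager_out = torch.matmul(weights, v)

sdpa_out = F.scaled_dot_product_attention(q, k, v)  # same 1/sqrt(head_dim) scale by default
assert torch.allclose(eager_out, sdpa_out, atol=1e-5)
```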
497
+ class AriaTextAttention(nn.Module):
498
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
499
+
500
+ def __init__(self, config: AriaTextConfig, layer_idx: int):
501
+ super().__init__()
502
+ self.config = config
503
+ self.layer_idx = layer_idx
504
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
505
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
506
+ self.scaling = self.head_dim**-0.5
507
+ self.attention_dropout = config.attention_dropout
508
+ self.is_causal = True
509
+
510
+ self.q_proj = nn.Linear(
511
+ config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
512
+ )
513
+ self.k_proj = nn.Linear(
514
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
515
+ )
516
+ self.v_proj = nn.Linear(
517
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
518
+ )
519
+ self.o_proj = nn.Linear(
520
+ config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
521
+ )
522
+
523
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
524
+ def forward(
525
+ self,
526
+ hidden_states: torch.Tensor,
527
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
528
+ attention_mask: Optional[torch.Tensor],
529
+ past_key_values: Optional[Cache] = None,
530
+ cache_position: Optional[torch.LongTensor] = None,
531
+ **kwargs: Unpack[TransformersKwargs],
532
+ ) -> tuple[torch.Tensor, torch.Tensor]:
533
+ input_shape = hidden_states.shape[:-1]
534
+ hidden_shape = (*input_shape, -1, self.head_dim)
535
+
536
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
537
+ key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
538
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
539
+
540
+ cos, sin = position_embeddings
541
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
542
+
543
+ if past_key_values is not None:
544
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
545
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
546
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
547
+
548
+ attention_interface: Callable = eager_attention_forward
549
+ if self.config._attn_implementation != "eager":
550
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
551
+
552
+ attn_output, attn_weights = attention_interface(
553
+ self,
554
+ query_states,
555
+ key_states,
556
+ value_states,
557
+ attention_mask,
558
+ dropout=0.0 if not self.training else self.attention_dropout,
559
+ scaling=self.scaling,
560
+ **kwargs,
561
+ )
562
+
563
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
564
+ attn_output = self.o_proj(attn_output)
565
+ return attn_output, attn_weights
566
+
567
+
568
+ class AriaTextDecoderLayer(GradientCheckpointingLayer):
569
+ """
570
+ Aria Text Decoder Layer.
571
+
572
+ This class defines a single decoder layer in the language model, incorporating self-attention and Mixture of Experts (MoE) feed-forward network.
573
+
574
+ Args:
575
+ config (`AriaTextConfig`):
576
+ Configuration object for the text component of the model.
577
+ layer_idx (`int`):
578
+ Index of the layer.
579
+ """
580
+
581
+ def __init__(self, config: AriaTextConfig, layer_idx: int):
582
+ super().__init__()
583
+ self.hidden_size = config.hidden_size
584
+
585
+ self.self_attn = AriaTextAttention(config=config, layer_idx=layer_idx)
586
+ self.mlp = AriaTextMoELayer(config)
587
+ self.input_layernorm = AriaTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
588
+ self.post_attention_layernorm = AriaTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
589
+
590
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
591
+ def forward(
592
+ self,
593
+ hidden_states: torch.Tensor,
594
+ attention_mask: Optional[torch.Tensor] = None,
595
+ position_ids: Optional[torch.LongTensor] = None,
596
+ past_key_values: Optional[Cache] = None,
597
+ use_cache: Optional[bool] = False,
598
+ cache_position: Optional[torch.LongTensor] = None,
599
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
600
+ **kwargs: Unpack[TransformersKwargs],
601
+ ) -> torch.Tensor:
602
+ residual = hidden_states
603
+ hidden_states = self.input_layernorm(hidden_states)
604
+ # Self Attention
605
+ hidden_states, _ = self.self_attn(
606
+ hidden_states=hidden_states,
607
+ attention_mask=attention_mask,
608
+ position_ids=position_ids,
609
+ past_key_values=past_key_values,
610
+ use_cache=use_cache,
611
+ cache_position=cache_position,
612
+ position_embeddings=position_embeddings,
613
+ **kwargs,
614
+ )
615
+ hidden_states = residual + hidden_states
616
+
617
+ # Fully Connected
618
+ residual = hidden_states
619
+ hidden_states = self.post_attention_layernorm(hidden_states)
620
+ hidden_states = self.mlp(hidden_states)
621
+ hidden_states = residual + hidden_states
622
+ return hidden_states
623
+
624
+
625
+ @auto_docstring
626
+ class AriaTextPreTrainedModel(PreTrainedModel):
627
+ config: AriaTextConfig
628
+ base_model_prefix = "model"
629
+ _no_split_modules = ["AriaTextDecoderLayer", "AriaGroupedExpertsGemm"]
630
+ supports_gradient_checkpointing = True
631
+ _skip_keys_device_placement = "past_key_values"
632
+ _supports_flash_attn = True
633
+ _supports_sdpa = True
634
+
635
+ _supports_attention_backend = True
636
+ _can_record_outputs = {
637
+ "hidden_states": AriaTextDecoderLayer,
638
+ "attentions": AriaTextAttention,
639
+ }
640
+
641
+ def _init_weights(self, module):
642
+ super()._init_weights(module)
643
+ if isinstance(module, AriaGroupedExpertsGemm):
644
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
645
+
646
+
647
+ @auto_docstring
648
+ class AriaPreTrainedModel(PreTrainedModel):
649
+ config: AriaConfig
650
+ base_model_prefix = ""
651
+ supports_gradient_checkpointing = True
652
+ _no_split_modules = ["AriaDecoderLayer"]
653
+ _skip_keys_device_placement = ["past_key_values"]
654
+ _supports_flash_attn = True
655
+ _supports_sdpa = True
656
+ _supports_flex_attn = True
657
+ _can_compile_fullgraph = False # MoE models don't work with torch.compile (dynamic slicing)
658
+ _supports_attention_backend = True
659
+ _can_record_outputs = {
660
+ "hidden_states": AriaTextDecoderLayer,
661
+ "attentions": AriaTextAttention,
662
+ }
663
+
664
+ def _init_weights(self, module):
665
+ super()._init_weights(module)
666
+ if isinstance(module, AriaProjector):
667
+ nn.init.trunc_normal_(module.query, std=self.config.initializer_range)
668
+
669
+
670
+ class AriaTextRotaryEmbedding(nn.Module):
671
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
672
+
673
+ def __init__(self, config: AriaTextConfig, device=None):
674
+ super().__init__()
675
+ # BC: "rope_type" was originally "type"
676
+ if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
677
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
678
+ else:
679
+ self.rope_type = "default"
680
+ self.max_seq_len_cached = config.max_position_embeddings
681
+ self.original_max_seq_len = config.max_position_embeddings
682
+
683
+ self.config = config
684
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
685
+
686
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
687
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
688
+ self.original_inv_freq = self.inv_freq
689
+
690
+ @torch.no_grad()
691
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
692
+ def forward(self, x, position_ids):
693
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
694
+ position_ids_expanded = position_ids[:, None, :].float()
695
+
696
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
697
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
698
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
699
+ emb = torch.cat((freqs, freqs), dim=-1)
700
+ cos = emb.cos() * self.attention_scaling
701
+ sin = emb.sin() * self.attention_scaling
702
+
703
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
704
+
705
+
706
+ @auto_docstring
707
+ class AriaTextModel(AriaTextPreTrainedModel):
708
+ def __init__(self, config: AriaTextConfig):
709
+ super().__init__(config)
710
+ self.padding_idx = config.pad_token_id
711
+ self.vocab_size = config.vocab_size
712
+
713
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
714
+ self.layers = nn.ModuleList(
715
+ [AriaTextDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
716
+ )
717
+ self.norm = AriaTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
718
+ self.rotary_emb = AriaTextRotaryEmbedding(config=config)
719
+ self.gradient_checkpointing = False
720
+
721
+ # Initialize weights and apply final processing
722
+ self.post_init()
723
+
724
+ @check_model_inputs()
725
+ @auto_docstring
726
+ def forward(
727
+ self,
728
+ input_ids: Optional[torch.LongTensor] = None,
729
+ attention_mask: Optional[torch.Tensor] = None,
730
+ position_ids: Optional[torch.LongTensor] = None,
731
+ past_key_values: Optional[Cache] = None,
732
+ inputs_embeds: Optional[torch.FloatTensor] = None,
733
+ cache_position: Optional[torch.LongTensor] = None,
734
+ use_cache: Optional[bool] = None,
735
+ **kwargs: Unpack[TransformersKwargs],
736
+ ) -> BaseModelOutputWithPast:
737
+ if (input_ids is None) ^ (inputs_embeds is not None):
738
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
739
+
740
+ if inputs_embeds is None:
741
+ inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)
742
+
743
+ if use_cache and past_key_values is None:
744
+ past_key_values = DynamicCache(config=self.config)
745
+
746
+ if cache_position is None:
747
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
748
+ cache_position: torch.Tensor = torch.arange(
749
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
750
+ )
751
+
752
+ if position_ids is None:
753
+ position_ids = cache_position.unsqueeze(0)
754
+
755
+ causal_mask = create_causal_mask(
756
+ config=self.config,
757
+ input_embeds=inputs_embeds,
758
+ attention_mask=attention_mask,
759
+ cache_position=cache_position,
760
+ past_key_values=past_key_values,
761
+ position_ids=position_ids,
762
+ )
763
+
764
+ hidden_states = inputs_embeds
765
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
766
+
767
+ for decoder_layer in self.layers[: self.config.num_hidden_layers]:
768
+ hidden_states = decoder_layer(
769
+ hidden_states,
770
+ attention_mask=causal_mask,
771
+ position_ids=position_ids,
772
+ past_key_values=past_key_values,
773
+ cache_position=cache_position,
774
+ position_embeddings=position_embeddings,
775
+ **kwargs,
776
+ )
777
+
778
+ hidden_states = self.norm(hidden_states)
779
+ return BaseModelOutputWithPast(
780
+ last_hidden_state=hidden_states,
781
+ past_key_values=past_key_values,
782
+ )
783
+
784
+
785
+ @auto_docstring
786
+ class AriaTextForCausalLM(AriaTextPreTrainedModel, GenerationMixin):
787
+ _tied_weights_keys = ["lm_head.weight"]
788
+ _tp_plan = {"lm_head": "colwise_rep"}
789
+ _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
790
+
791
+ def __init__(self, config: AriaTextConfig):
792
+ super().__init__(config)
793
+ self.model = AriaTextModel(config)
794
+ self.vocab_size = config.vocab_size
795
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
796
+
797
+ # Initialize weights and apply final processing
798
+ self.post_init()
799
+
800
+ @auto_docstring
801
+ def forward(
802
+ self,
803
+ input_ids: Optional[torch.LongTensor] = None,
804
+ attention_mask: Optional[torch.Tensor] = None,
805
+ position_ids: Optional[torch.LongTensor] = None,
806
+ past_key_values: Optional[Cache] = None,
807
+ inputs_embeds: Optional[torch.FloatTensor] = None,
808
+ labels: Optional[torch.LongTensor] = None,
809
+ use_cache: Optional[bool] = None,
810
+ cache_position: Optional[torch.LongTensor] = None,
811
+ logits_to_keep: Union[int, torch.Tensor] = 0,
812
+ **kwargs: Unpack[TransformersKwargs],
813
+ ) -> CausalLMOutputWithPast:
814
+ r"""
815
+ Example:
816
+
817
+ ```python
818
+ >>> from transformers import AutoTokenizer, AriaTextForCausalLM
819
+
820
+ >>> model = AriaTextForCausalLM.from_pretrained("meta-aria_text/AriaText-2-7b-hf")
821
+ >>> tokenizer = AutoTokenizer.from_pretrained("meta-aria_text/AriaText-2-7b-hf")
822
+
823
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
824
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
825
+
826
+ >>> # Generate
827
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
828
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
829
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
830
+ ```"""
831
+ outputs: BaseModelOutputWithPast = self.model(
832
+ input_ids=input_ids,
833
+ attention_mask=attention_mask,
834
+ position_ids=position_ids,
835
+ past_key_values=past_key_values,
836
+ inputs_embeds=inputs_embeds,
837
+ use_cache=use_cache,
838
+ cache_position=cache_position,
839
+ **kwargs,
840
+ )
841
+
842
+ hidden_states = outputs.last_hidden_state
843
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
844
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
845
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
846
+
847
+ loss = None
848
+ if labels is not None:
849
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
850
+
851
+ return CausalLMOutputWithPast(
852
+ loss=loss,
853
+ logits=logits,
854
+ past_key_values=outputs.past_key_values,
855
+ hidden_states=outputs.hidden_states,
856
+ attentions=outputs.attentions,
857
+ )
858
+
859
+
860
+ @dataclass
861
+ @auto_docstring(
862
+ custom_intro="""
863
+ Base class for Aria causal language model (or autoregressive) outputs.
864
+ """
865
+ )
866
+ class AriaCausalLMOutputWithPast(ModelOutput):
867
+ r"""
868
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
869
+ Language modeling loss (for next-token prediction).
870
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
871
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
872
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
873
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
874
+
875
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
876
+ `past_key_values` input) to speed up sequential decoding.
877
+ image_hidden_states (`torch.FloatTensor`, *optional*):
878
+ A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
879
+ image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
880
+ """
881
+
882
+ loss: Optional[torch.FloatTensor] = None
883
+ logits: Optional[torch.FloatTensor] = None
884
+ past_key_values: Optional[Cache] = None
885
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
886
+ attentions: Optional[tuple[torch.FloatTensor]] = None
887
+ image_hidden_states: Optional[torch.FloatTensor] = None
888
+
889
+
890
+ @dataclass
891
+ @auto_docstring(
892
+ custom_intro="""
893
+ Base class for Aria outputs, with hidden states and attentions.
894
+ """
895
+ )
896
+ class AriaModelOutputWithPast(BaseModelOutputWithPast):
897
+ r"""
898
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
899
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
900
+
901
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
902
+ `past_key_values` input) to speed up sequential decoding.
903
+ image_hidden_states (`torch.FloatTensor`, *optional*):
904
+ A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
905
+ image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
906
+ """
907
+
908
+ image_hidden_states: Optional[torch.FloatTensor] = None
909
+
910
+
911
+ @auto_docstring(
912
+ custom_intro="""
913
+ The Aria model which consists of a vision backbone and a language model, without a language modeling head.
914
+ """
915
+ )
916
+ class AriaModel(AriaPreTrainedModel):
917
+ _checkpoint_conversion_mapping = {"language_model.model": "language_model"}
918
+
919
+ def __init__(self, config: AriaConfig):
920
+ super().__init__(config)
921
+ self.vision_tower = AutoModel.from_config(config.vision_config)
922
+ self.multi_modal_projector = AriaProjector(config)
923
+ self.language_model = AutoModel.from_config(config.text_config)
924
+ self.post_init()
925
+
926
+ def get_input_embeddings(self):
927
+ return self.language_model.get_input_embeddings()
928
+
929
+ def set_input_embeddings(self, value):
930
+ self.language_model.set_input_embeddings(value)
931
+
932
+ def set_decoder(self, decoder):
933
+ self.language_model = decoder
934
+
935
+ def get_decoder(self):
936
+ return self.language_model
937
+
938
+ def get_image_features(
939
+ self,
940
+ pixel_values: torch.FloatTensor,
941
+ pixel_mask: Optional[torch.FloatTensor] = None,
942
+ vision_feature_layer: int = -1,
943
+ ):
944
+ """
945
+ Obtains image last hidden states from the vision tower and applies multimodal projection.
946
+
947
+ Args:
948
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`):
949
+ The tensors corresponding to the input images.
950
+ pixel_mask (`torch.FloatTensor`, *optional*):
951
+ The tensors corresponding to the input image mask.
952
+ vision_feature_layer (`Union[int, list[int]]`, *optional*):
953
+ The index of the layer to select the vision feature. If multiple indices are provided,
954
+ the vision feature of the corresponding indices will be concatenated to form the
955
+ vision features.
956
+ Returns:
957
+ image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`.
958
+ """
959
+ vision_feature_layer = (
960
+ vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
961
+ )
962
+ patch_attention_mask = self._create_patch_attention_mask(pixel_mask)
963
+ image_outputs = self.vision_tower(
964
+ pixel_values, patch_attention_mask=patch_attention_mask, output_hidden_states=True
965
+ )
966
+ image_attn_mask = None
967
+ if patch_attention_mask is not None:
968
+ flattened_mask = patch_attention_mask.flatten(1)
969
+ image_attn_mask = torch.logical_not(flattened_mask)
970
+
971
+ selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
972
+ image_features = self.multi_modal_projector(selected_image_feature, attn_mask=image_attn_mask)
973
+ return image_features
974
+
975
+ def get_placeholder_mask(
976
+ self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor
977
+ ):
978
+ """
979
+ Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
980
+ equal to the length of multimodal features. If the lengths are different, an error is raised.
981
+ """
982
+ if input_ids is None:
983
+ special_image_mask = inputs_embeds == self.get_input_embeddings()(
984
+ torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
985
+ )
986
+ special_image_mask = special_image_mask.all(-1)
987
+ else:
988
+ special_image_mask = input_ids == self.config.image_token_id
989
+
990
+ n_image_tokens = special_image_mask.sum()
991
+ special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
992
+ n_image_features = image_features.shape[0] * image_features.shape[1]
993
+ if inputs_embeds[special_image_mask].numel() != image_features.numel():
994
+ raise ValueError(
995
+ f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
996
+ )
997
+ return special_image_mask
998
+
999
+ @can_return_tuple
1000
+ @auto_docstring
1001
+ def forward(
1002
+ self,
1003
+ input_ids: Optional[torch.LongTensor] = None,
1004
+ pixel_values: Optional[torch.FloatTensor] = None,
1005
+ pixel_mask: Optional[torch.LongTensor] = None,
1006
+ attention_mask: Optional[torch.Tensor] = None,
1007
+ position_ids: Optional[torch.LongTensor] = None,
1008
+ past_key_values: Optional[Cache] = None,
1009
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1010
+ use_cache: Optional[bool] = None,
1011
+ cache_position: Optional[torch.LongTensor] = None,
1012
+ **kwargs: Unpack[FlashAttentionKwargs],
1013
+ ) -> Union[tuple, AriaModelOutputWithPast]:
1014
+ if inputs_embeds is None:
1015
+ inputs_embeds = self.get_input_embeddings()(input_ids)
1016
+
1017
+ # 2. Merge text and images
1018
+ if pixel_values is not None and inputs_embeds.shape[1] != 1:
1019
+ image_features = self.get_image_features(
1020
+ pixel_values=pixel_values,
1021
+ pixel_mask=pixel_mask,
1022
+ vision_feature_layer=self.config.vision_feature_layer,
1023
+ )
1024
+ image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
1025
+ special_image_mask = self.get_placeholder_mask(
1026
+ input_ids, inputs_embeds=inputs_embeds, image_features=image_features
1027
+ )
1028
+ inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
1029
+
1030
+ outputs = self.language_model(
1031
+ attention_mask=attention_mask,
1032
+ position_ids=position_ids,
1033
+ past_key_values=past_key_values,
1034
+ inputs_embeds=inputs_embeds,
1035
+ use_cache=use_cache,
1036
+ cache_position=cache_position,
1037
+ **kwargs,
1038
+ )
1039
+
1040
+ return AriaModelOutputWithPast(
1041
+ last_hidden_state=outputs.last_hidden_state,
1042
+ past_key_values=outputs.past_key_values if use_cache else None,
1043
+ hidden_states=outputs.hidden_states,
1044
+ attentions=outputs.attentions,
1045
+ image_hidden_states=image_features if pixel_values is not None else None,
1046
+ )
1047
+
1048
+ def _create_patch_attention_mask(self, pixel_mask):
1049
+ if pixel_mask is None:
1050
+ return None
1051
+
1052
+ patches_subgrid = pixel_mask.unfold(
1053
+ dimension=1,
1054
+ size=self.vision_tower.config.patch_size,
1055
+ step=self.vision_tower.config.patch_size,
1056
+ )
1057
+ patches_subgrid = patches_subgrid.unfold(
1058
+ dimension=2,
1059
+ size=self.vision_tower.config.patch_size,
1060
+ step=self.vision_tower.config.patch_size,
1061
+ )
1062
+ return (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
1063
+
1064
+
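A small sketch of the pixel-mask to patch-mask reduction in `_create_patch_attention_mask`: two `unfold` calls tile the mask into `patch_size` x `patch_size` blocks, and a patch counts as valid if any pixel inside it is set. The patch size and mask shape below are illustrative.

```python
# Reduce a per-pixel mask to a per-patch mask with two unfold calls.
import torch

patch_size = 4
pixel_mask = torch.zeros(1, 8, 8, dtype=torch.long)
pixel_mask[:, :, :3] = 1   # only the left 3 pixel columns hold real content

subgrid = pixel_mask.unfold(1, patch_size, patch_size).unfold(2, patch_size, patch_size)
patch_mask = subgrid.sum(dim=(-1, -2)) > 0   # (batch, H // patch_size, W // patch_size)
print(patch_mask)
# -> [[[True, False],
#      [True, False]]]
```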
1065
+ @auto_docstring(
1066
+ custom_intro="""
1067
+ Aria model for conditional generation tasks.
1068
+
1069
+ This model combines a vision tower, a multi-modal projector, and a language model
1070
+ to perform tasks that involve both image and text inputs.
1071
+ """
1072
+ )
1073
+ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
1074
+ _checkpoint_conversion_mapping = {
1075
+ "^language_model.model": "model.language_model",
1076
+ "^vision_tower": "model.vision_tower",
1077
+ "^multi_modal_projector": "model.multi_modal_projector",
1078
+ "^language_model.lm_head": "lm_head",
1079
+ }
1080
+ _tied_weights_keys = ["lm_head.weight"]
1081
+
1082
+ def __init__(self, config: AriaConfig):
1083
+ super().__init__(config)
1084
+ self.model = AriaModel(config)
1085
+ self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
1086
+ self.post_init()
1087
+
1088
+ def get_input_embeddings(self):
1089
+ return self.model.get_input_embeddings()
1090
+
1091
+ def set_input_embeddings(self, value):
1092
+ self.model.set_input_embeddings(value)
1093
+
1094
+ def get_output_embeddings(self) -> nn.Module:
1095
+ return self.lm_head
1096
+
1097
+ def set_decoder(self, decoder):
1098
+ self.model.set_decoder(decoder)
1099
+
1100
+ def get_decoder(self):
1101
+ return self.model.get_decoder()
1102
+
1103
+ def get_image_features(
1104
+ self,
1105
+ pixel_values: torch.FloatTensor,
1106
+ pixel_mask: Optional[torch.FloatTensor] = None,
1107
+ vision_feature_layer: int = -1,
1108
+ ):
1109
+ return self.model.get_image_features(
1110
+ pixel_values=pixel_values,
1111
+ pixel_mask=pixel_mask,
1112
+ vision_feature_layer=vision_feature_layer,
1113
+ )
1114
+
1115
+ # Make modules available through conditional class for BC
1116
+ @property
1117
+ def language_model(self):
1118
+ return self.model.language_model
1119
+
1120
+ @property
1121
+ def vision_tower(self):
1122
+ return self.model.vision_tower
1123
+
1124
+ @property
1125
+ def multi_modal_projector(self):
1126
+ return self.model.multi_modal_projector
1127
+
1128
+ @can_return_tuple
1129
+ @auto_docstring
1130
+ def forward(
1131
+ self,
1132
+ input_ids: Optional[torch.LongTensor] = None,
1133
+ pixel_values: Optional[torch.FloatTensor] = None,
1134
+ pixel_mask: Optional[torch.LongTensor] = None,
1135
+ attention_mask: Optional[torch.Tensor] = None,
1136
+ position_ids: Optional[torch.LongTensor] = None,
1137
+ past_key_values: Optional[Cache] = None,
1138
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1139
+ labels: Optional[torch.LongTensor] = None,
1140
+ use_cache: Optional[bool] = None,
1141
+ logits_to_keep: Union[int, torch.Tensor] = 0,
1142
+ cache_position: Optional[torch.LongTensor] = None,
1143
+ **kwargs: Unpack[TransformersKwargs],
1144
+ ) -> Union[tuple, AriaCausalLMOutputWithPast]:
1145
+ r"""
1146
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1147
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1148
+ config.vocab_size]` or `model.image_token_id` (where `model` is your instance of `AriaForConditionalGeneration`).
1149
+ Tokens with indices set to `model.image_token_id` are ignored (masked), the loss is only
1150
+ computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1151
+
1152
+ Example:
1153
+
1154
+ ```python
1155
+ >>> import requests
1156
+ >>> import torch
1157
+ >>> from PIL import Image
1158
+ >>> from io import BytesIO
1159
+
1160
+ >>> from transformers import AutoProcessor, AutoModel
1161
+ >>> from transformers.image_utils import load_image
1162
+
1163
+ >>> # Note that passing the image urls (instead of the actual pil images) to the processor is also possible
1164
+ >>> image1 = load_image("https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg")
1165
+ >>> image2 = load_image("https://cdn.britannica.com/59/94459-050-DBA42467/Skyline-Chicago.jpg")
1166
+ >>> image3 = load_image("https://cdn.britannica.com/68/170868-050-8DDE8263/Golden-Gate-Bridge-San-Francisco.jpg")
1167
+
1168
+ >>> processor = AutoProcessor.from_pretrained("Rhymes-AI/Aria")
1169
+ >>> model = AutoModel.from_pretrained("Rhymes-AI/Aria", dtype=torch.bfloat16, device_map="auto")
1170
+
1171
+ >>> # Create inputs
1172
+ >>> messages = [
1173
+ ... {
1174
+ ... "role": "user",
1175
+ ... "content": [
1176
+ ... {"type": "image"},
1177
+ ... {"type": "text", "text": "In this image, we can see the city of New York, and more specifically the Statue of Liberty."},
1178
+ ... {"type": "image"},
1179
+ ... {"type": "text", "text": "What can we see in this image?"},
1180
+ ... ]
1181
+ ... },
1182
+ ... {
1183
+ ... "role": "user",
1184
+ ... "content": [
1185
+ ... {"type": "image"},
1186
+ ... {"type": "text", "text": "In which city is that bridge located?"},
1187
+ ... ]
1188
+ ... }
1189
+ ... ]
1190
+
1191
+ >>> prompts = [processor.apply_chat_template([message], add_generation_prompt=True) for message in messages]
1192
+ >>> images = [[image1, image2], [image3]]
1193
+ >>> inputs = processor(text=prompts, images=images, padding=True, return_tensors="pt").to(model.device)
1194
+
1195
+ >>> # Generate
1196
+ >>> generated_ids = model.generate(**inputs, max_new_tokens=256)
1197
+ >>> generated_texts = processor.batch_decode(generated_ids, skip_special_tokens=True)
1198
+
1199
+ >>> print(generated_texts[0])
1200
+ Assistant: There are buildings, trees, lights, and water visible in this image.
1201
+
1202
+ >>> print(generated_texts[1])
1203
+ Assistant: The bridge is in San Francisco.
1204
+ ```"""
1205
+ outputs = self.model(
1206
+ input_ids=input_ids,
1207
+ pixel_values=pixel_values,
1208
+ pixel_mask=pixel_mask,
1209
+ attention_mask=attention_mask,
1210
+ position_ids=position_ids,
1211
+ past_key_values=past_key_values,
1212
+ inputs_embeds=inputs_embeds,
1213
+ use_cache=use_cache,
1214
+ cache_position=cache_position,
1215
+ **kwargs,
1216
+ )
1217
+
1218
+ hidden_states = outputs[0]
1219
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
1220
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
1221
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
1222
+
1223
+ loss = None
1224
+ if labels is not None:
1225
+ loss = self.loss_function(
1226
+ logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
1227
+ )
1228
+
1229
+ return AriaCausalLMOutputWithPast(
1230
+ loss=loss,
1231
+ logits=logits,
1232
+ past_key_values=outputs.past_key_values,
1233
+ hidden_states=outputs.hidden_states,
1234
+ attentions=outputs.attentions,
1235
+ )
1236
+
1237
+ def prepare_inputs_for_generation(
1238
+ self,
1239
+ input_ids,
1240
+ past_key_values=None,
1241
+ inputs_embeds=None,
1242
+ pixel_values=None,
1243
+ pixel_mask=None,
1244
+ attention_mask=None,
1245
+ cache_position=None,
1246
+ logits_to_keep=None,
1247
+ **kwargs,
1248
+ ):
1249
+ model_inputs = super().prepare_inputs_for_generation(
1250
+ input_ids,
1251
+ past_key_values=past_key_values,
1252
+ inputs_embeds=inputs_embeds,
1253
+ attention_mask=attention_mask,
1254
+ cache_position=cache_position,
1255
+ logits_to_keep=logits_to_keep,
1256
+ **kwargs,
1257
+ )
1258
+
1259
+ if cache_position[0] == 0:
1260
+ # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
1261
+ # Otherwise we need pixel values to be passed to model
1262
+ model_inputs["pixel_values"] = pixel_values
1263
+ model_inputs["pixel_mask"] = pixel_mask
1264
+
1265
+ return model_inputs
1266
+
1267
+
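As a quick sanity check of the two small mechanisms above — the `logits_to_keep` slicing in `forward` and the prefill-only forwarding of `pixel_values` in `prepare_inputs_for_generation` — here is a hedged, standalone toy sketch (values are illustrative, not taken from the model):

```python
import torch

# logits_to_keep=1 keeps only the last position before the LM head is applied
hidden_states = torch.randn(1, 10, 8)            # (batch, seq_len, hidden) -- toy sizes
slice_indices = slice(-1, None)
print(hidden_states[:, slice_indices, :].shape)  # torch.Size([1, 1, 8])

# pixel_values are forwarded only on the prefill step (cache_position starts at 0)
print(bool(torch.arange(12)[0] == 0))            # True  -> prefill, pass pixel_values
print(bool(torch.tensor([12])[0] == 0))          # False -> cached decoding, keep them None
```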
1268
+ __all__ = [
1269
+ "AriaForConditionalGeneration",
1270
+ "AriaPreTrainedModel",
1271
+ "AriaTextPreTrainedModel",
1272
+ "AriaTextModel",
1273
+ "AriaModel",
1274
+ "AriaTextForCausalLM",
1275
+ ]
venv/lib/python3.13/site-packages/transformers/models/aria/modular_aria.py ADDED
@@ -0,0 +1,1610 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The Rhymes-AI Teams Authors and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ from collections.abc import Iterable
16
+ from typing import Optional, Union
17
+
18
+ import numpy as np
19
+ import torch
20
+ from torch import nn
21
+
22
+ from ...activations import ACT2FN
23
+ from ...cache_utils import Cache
24
+ from ...configuration_utils import PretrainedConfig
25
+ from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_patch_output_size, select_best_resolution
26
+ from ...image_transforms import PaddingMode, convert_to_rgb, pad, resize, to_channel_dimension_format
27
+ from ...image_utils import (
28
+ ChannelDimension,
29
+ ImageInput,
30
+ PILImageResampling,
31
+ get_image_size,
32
+ infer_channel_dimension_format,
33
+ is_scaled_image,
34
+ make_flat_list_of_images,
35
+ to_numpy_array,
36
+ valid_images,
37
+ validate_preprocess_arguments,
38
+ )
39
+ from ...modeling_flash_attention_utils import FlashAttentionKwargs
40
+ from ...modeling_utils import PreTrainedModel
41
+ from ...processing_utils import MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack
42
+ from ...tokenization_utils import PreTokenizedInput, TextInput
43
+ from ...utils import TensorType, TransformersKwargs, auto_docstring, can_return_tuple, logging
44
+ from ..auto import CONFIG_MAPPING, AutoConfig, AutoTokenizer
45
+ from ..llama.configuration_llama import LlamaConfig
46
+ from ..llama.modeling_llama import (
47
+ LlamaAttention,
48
+ LlamaDecoderLayer,
49
+ LlamaForCausalLM,
50
+ LlamaMLP,
51
+ LlamaModel,
52
+ LlamaPreTrainedModel,
53
+ LlamaRMSNorm,
54
+ )
55
+ from ..llava.modeling_llava import (
56
+ LlavaCausalLMOutputWithPast,
57
+ LlavaForConditionalGeneration,
58
+ LlavaModel,
59
+ LlavaModelOutputWithPast,
60
+ )
61
+ from ..llava_next.image_processing_llava_next import divide_to_patches
62
+
63
+
64
+ logger = logging.get_logger(__name__)
65
+
66
+
67
+ def sequential_experts_gemm(token_states, expert_weights, tokens_per_expert):
68
+ """
69
+ Compute the matrix multiplication (GEMM) for each expert sequentially. This approach is computationally inefficient, especially when dealing with a large number of experts.
70
+
71
+ Args:
72
+ token_states (torch.Tensor): Input tensor of shape (num_tokens, in_features).
73
+ expert_weights (torch.Tensor): Weight tensor of shape (num_experts, in_features, out_features).
74
+ tokens_per_expert (torch.Tensor): Number of tokens assigned to each expert.
75
+
76
+ Returns:
77
+ torch.Tensor: Output tensor of shape (num_tokens, out_features).
78
+ """
79
+ num_tokens = token_states.shape[0]
80
+ out_features = expert_weights.shape[-1]
81
+ output = torch.zeros(num_tokens, out_features, dtype=token_states.dtype, device=token_states.device)
82
+
83
+ cumsum_num_tokens = torch.cumsum(tokens_per_expert, dim=0)
84
+ # Insert zero at the beginning for offset index's convenience
85
+ zero_tensor = torch.zeros(1, dtype=torch.long, device=cumsum_num_tokens.device)
86
+ cumsum_num_tokens = torch.cat((zero_tensor, cumsum_num_tokens))
87
+
88
+ for expert_num in range(expert_weights.shape[0]):
89
+ start = cumsum_num_tokens[expert_num]
90
+ end = cumsum_num_tokens[expert_num + 1]
91
+ tokens = token_states[start:end]
92
+
93
+ out = torch.matmul(tokens, expert_weights[expert_num])
94
+ output[start:end] = out
95
+ return output
96
+
97
+
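A minimal toy call of the sequential fallback above, assuming `sequential_experts_gemm` is in scope; the shapes are illustrative only:

```python
import torch

token_states = torch.randn(6, 4)             # 6 tokens, 4 input features
expert_weights = torch.randn(3, 4, 5)        # 3 experts, each mapping 4 -> 5 features
tokens_per_expert = torch.tensor([2, 1, 3])  # counts must sum to the number of tokens

out = sequential_experts_gemm(token_states, expert_weights, tokens_per_expert)
print(out.shape)  # torch.Size([6, 5])
```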
98
+ class AriaTextConfig(LlamaConfig):
99
+ r"""
100
+ This class handles the configuration for the text component of the Aria model.
101
+ Instantiating a configuration with the defaults will yield a configuration similar to that of the Aria
102
+ [rhymes-ai/Aria](https://huggingface.co/rhymes-ai/Aria) architecture.
103
+ This class extends the LlamaConfig to include additional parameters specific to the Mixture of Experts (MoE) architecture.
104
+
105
+ Args:
106
+ vocab_size (`int`, *optional*, defaults to 32000):
107
+ Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the
108
+ `inputs_ids` passed when calling [`LlamaModel`]
109
+ hidden_size (`int`, *optional*, defaults to 4096):
110
+ Dimension of the hidden representations.
111
+ intermediate_size (`int`, *optional*, defaults to 4096):
112
+ The size of the MLP representations.
113
+ num_hidden_layers (`int`, *optional*, defaults to 32):
114
+ Number of hidden layers in the Transformer decoder.
115
+ num_attention_heads (`int`, *optional*, defaults to 32):
116
+ Number of attention heads for each attention layer in the Transformer decoder.
117
+ num_key_value_heads (`int`, *optional*):
118
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
119
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
120
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
121
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
122
+ by meanpooling all the original heads within that group. For more details, check out [this
123
+ paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
124
+ `num_attention_heads`.
125
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
126
+ The non-linear activation function (function or string) in the decoder.
127
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
128
+ The maximum sequence length that this model might ever be used with. Llama 1 supports up to 2048 tokens,
129
+ Llama 2 up to 4096, CodeLlama up to 16384.
130
+ initializer_range (`float`, *optional*, defaults to 0.02):
131
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
132
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
133
+ The epsilon used by the rms normalization layers.
134
+ use_cache (`bool`, *optional*, defaults to `True`):
135
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
136
+ relevant if `config.is_decoder=True`.
137
+ pad_token_id (`int`, *optional*, defaults to 2):
138
+ Padding token id.
139
+ bos_token_id (`int`, *optional*, defaults to 1):
140
+ Beginning of stream token id.
141
+ eos_token_id (`int`, *optional*, defaults to 2):
142
+ End of stream token id.
143
+ pretraining_tp (`int`, *optional*, defaults to 1):
144
+ Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
145
+ document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) to
146
+ understand more about it. This value is necessary to ensure exact reproducibility of the pretraining
147
+ results. Please refer to [this issue](https://github.com/pytorch/pytorch/issues/76232).
148
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
149
+ Whether to tie weight embeddings
150
+ rope_theta (`float`, *optional*, defaults to 10000.0):
151
+ The base period of the RoPE embeddings.
152
+ rope_scaling (`Dict`, *optional*):
153
+ Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
154
+ and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
155
+ accordingly.
156
+ Expected contents:
157
+ `rope_type` (`str`):
158
+ The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
159
+ 'llama3'], with 'default' being the original RoPE implementation.
160
+ `factor` (`float`, *optional*):
161
+ Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
162
+ most scaling types, a `factor` of x will enable the model to handle sequences of length x *
163
+ original maximum pre-trained length.
164
+ `original_max_position_embeddings` (`int`, *optional*):
165
+ Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
166
+ pretraining.
167
+ `attention_factor` (`float`, *optional*):
168
+ Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
169
+ computation. If unspecified, it defaults to value recommended by the implementation, using the
170
+ `factor` field to infer the suggested value.
171
+ `beta_fast` (`float`, *optional*):
172
+ Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
173
+ ramp function. If unspecified, it defaults to 32.
174
+ `beta_slow` (`float`, *optional*):
175
+ Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
176
+ ramp function. If unspecified, it defaults to 1.
177
+ `short_factor` (`list[float]`, *optional*):
178
+ Only used with 'longrope'. The scaling factor to be applied to short contexts (<
179
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
180
+ size divided by the number of attention heads divided by 2
181
+ `long_factor` (`list[float]`, *optional*):
182
+ Only used with 'longrope'. The scaling factor to be applied to long contexts (>
183
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
184
+ size divided by the number of attention heads divided by 2
185
+ `low_freq_factor` (`float`, *optional*):
186
+ Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
187
+ `high_freq_factor` (`float`, *optional*):
188
+ Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
189
+ attention_bias (`bool`, *optional*, defaults to `False`):
190
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
191
+ attention_dropout (`float`, *optional*, defaults to 0.0):
192
+ The dropout ratio for the attention probabilities.
193
+ mlp_bias (`bool`, *optional*, defaults to `False`):
194
+ Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers.
195
+ head_dim (`int`, *optional*):
196
+ The attention head dimension. If None, it will default to hidden_size // num_heads
197
+ moe_num_experts (`int`, *optional*, defaults to 8):
198
+ The number of experts in the MoE layer.
199
+ moe_topk (`int`, *optional*, defaults to 2):
200
+ The number of top experts to route to for each token.
201
+ moe_num_shared_experts (`int`, *optional*, defaults to 2):
202
+ The number of shared experts.
203
+ """
204
+
205
+ model_type = "aria_text"
206
+ base_config_key = "text_config"
207
+
208
+ def __init__(
209
+ self,
210
+ intermediate_size: int = 4096,
211
+ moe_num_experts: int = 8,
212
+ moe_topk: int = 2,
213
+ moe_num_shared_experts: int = 2,
214
+ pad_token_id=2,
215
+ **super_kwargs,
216
+ ):
217
+ super().__init__(pad_token_id=pad_token_id, **super_kwargs)
218
+ self.intermediate_size = intermediate_size
219
+ self.moe_num_experts = moe_num_experts
220
+ self.moe_topk = moe_topk
221
+ self.moe_num_shared_experts = moe_num_shared_experts
222
+
223
+
224
+ class AriaConfig(PretrainedConfig):
225
+ r"""
226
+ This class handles the configuration for both vision and text components of the Aria model,
227
+ as well as additional parameters for image token handling and projector mapping.
228
+ Instantiating a configuration with the defaults will yield a configuration similar to that of the Aria
229
+ [rhymes-ai/Aria](https://huggingface.co/rhymes-ai/Aria) architecture.
230
+
231
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
232
+ documentation from [`PretrainedConfig`] for more information.
233
+
234
+ Args:
235
+ vision_config (`AriaVisionConfig` or `dict`, *optional*):
236
+ Configuration for the vision component.
237
+ vision_feature_layer (`int`, *optional*, defaults to -1):
238
+ The index of the layer to select the vision feature.
239
+ text_config (`AriaTextConfig` or `dict`, *optional*):
240
+ Configuration for the text component.
241
+ projector_patch_to_query_dict (`dict`, *optional*):
242
+ Mapping of patch sizes to query dimensions.
243
+ image_token_index (`int`, *optional*, defaults to 9):
244
+ Index used to represent image tokens.
245
+ initializer_range (`float`, *optional*, defaults to 0.02):
246
+ The standard deviation of the truncated normal initializer for initializing all weight matrices.
247
+
248
+ Attributes:
249
+ model_type (`str`):
250
+ Type of the model, set to `"aria"`.
251
+ image_token_index (`int`):
252
+ Index used to represent image tokens.
253
+ projector_patch_to_query_dict (`dict`):
254
+ Mapping of patch sizes to query dimensions.
255
+ vision_config (`AriaVisionConfig`):
256
+ Configuration for the vision component.
257
+ text_config (`AriaTextConfig`):
258
+ Configuration for the text component.
259
+ """
260
+
261
+ model_type = "aria"
262
+ attribute_map = {
263
+ "image_token_id": "image_token_index",
264
+ }
265
+ sub_configs = {"text_config": AriaTextConfig, "vision_config": AutoConfig}
266
+
267
+ def __init__(
268
+ self,
269
+ vision_config=None,
270
+ vision_feature_layer: int = -1,
271
+ text_config: AriaTextConfig = None,
272
+ projector_patch_to_query_dict: Optional[dict] = None,
273
+ image_token_index: int = 9,
274
+ initializer_range: float = 0.02,
275
+ **kwargs,
276
+ ):
277
+ self.image_token_index = image_token_index
278
+
279
+ # Convert the keys and values of projector_patch_to_query_dict to integers
280
+ # This ensures consistency even if they were provided as strings
281
+ if projector_patch_to_query_dict is None:
282
+ projector_patch_to_query_dict = {
283
+ 1225: 128,
284
+ 4900: 256,
285
+ }
286
+ self.projector_patch_to_query_dict = {int(k): int(v) for k, v in projector_patch_to_query_dict.items()}
287
+ self.max_value_projector_patch_to_query_dict = max(self.projector_patch_to_query_dict.values())
288
+ self.vision_feature_layer = vision_feature_layer
289
+ if isinstance(vision_config, dict):
290
+ vision_config["model_type"] = "idefics3_vision"
291
+ vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config)
292
+ elif vision_config is None:
293
+ vision_config = CONFIG_MAPPING["idefics3_vision"]()
294
+
295
+ self.vision_config = vision_config
296
+ self.initializer_range = initializer_range
297
+
298
+ if isinstance(text_config, dict) and "model_type" in text_config:
299
+ text_config = AriaTextConfig(**text_config)
300
+ elif text_config is None:
301
+ text_config = AriaTextConfig()
302
+
303
+ self.text_config = text_config
304
+
305
+ super().__init__(**kwargs)
306
+
307
+
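A small sketch of the key/value coercion performed by `AriaConfig.__init__` above (assumes the classes defined in this module are importable; the printed values reflect the defaults set in the code):

```python
config = AriaConfig(projector_patch_to_query_dict={"1225": "128"})
print(config.projector_patch_to_query_dict)            # {1225: 128} -- keys/values coerced to int
print(config.max_value_projector_patch_to_query_dict)  # 128
print(config.image_token_index)                        # 9 (default)
```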
308
+ class AriaTextRMSNorm(LlamaRMSNorm):
309
+ pass
310
+
311
+
312
+ class AriaProjectorMLP(nn.Module):
313
+ """
314
+ Feed-Forward Network module for the Aria Projector.
315
+
316
+ Args:
317
+ in_features (`int`):
318
+ Input embedding dimension.
319
+ hidden_features (`int`):
320
+ Hidden dimension of the feed-forward network.
321
+ output_dim (`int`):
322
+ Output dimension.
323
+ """
324
+
325
+ def __init__(self, in_features, hidden_features, output_dim):
326
+ super().__init__()
327
+ self.linear_in = nn.Linear(in_features, hidden_features, bias=False)
328
+ self.linear_out = nn.Linear(hidden_features, output_dim, bias=False)
329
+ self.act = ACT2FN["gelu_new"]
330
+
331
+ def forward(self, hidden_states):
332
+ hidden_states = self.act(self.linear_in(hidden_states))
333
+ hidden_states = self.linear_out(hidden_states)
334
+ return hidden_states
335
+
336
+
337
+ class AriaCrossAttention(nn.Module):
338
+ """
339
+ Aria Cross-Attention module.
340
+
341
+ Args:
342
+ config (`AriaConfig`):
343
+ The configuration to use.
344
+ """
345
+
346
+ def __init__(self, config: AriaConfig, dropout_rate: float = 0):
347
+ super().__init__()
348
+ hidden_size = config.vision_config.hidden_size
349
+ num_heads = config.vision_config.num_attention_heads
350
+ self.num_heads = num_heads
351
+ self.q_proj = nn.Linear(hidden_size, hidden_size, bias=False)
352
+ self.k_proj = nn.Linear(hidden_size, hidden_size, bias=False)
353
+ self.v_proj = nn.Linear(hidden_size, hidden_size, bias=False)
354
+
355
+ # Original code here: https://github.com/rhymes-ai/Aria/blob/719ff4e52b727443cba3793b0e27fe64e0244fe1/aria/model/projector.py#L48
356
+ self.multihead_attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
357
+ self.linear = nn.Linear(hidden_size, hidden_size)
358
+ self.dropout = nn.Dropout(dropout_rate)
359
+
360
+ self.layer_norm = nn.LayerNorm(hidden_size)
361
+ self.layer_norm_kv = nn.LayerNorm(hidden_size)
362
+
363
+ def forward(self, key_value_states, hidden_states, attn_mask=None):
364
+ """
365
+ Forward pass of the AriaCrossAttention module.
366
+
367
+ Args:
368
+ key_value_states (`torch.Tensor`):
369
+ Input tensor for key and value.
370
+ hidden_states (`torch.Tensor`):
371
+ Input tensor for query.
372
+ attn_mask (`torch.Tensor`, *optional*, defaults to None):
373
+ Attention mask.
374
+
375
+ Returns:
376
+ torch.Tensor:
377
+ Output tensor after cross-attention.
378
+ """
379
+ query = self.q_proj(self.layer_norm(hidden_states))
380
+
381
+ key_value_states = self.layer_norm_kv(key_value_states)
382
+ key = self.k_proj(key_value_states)
383
+ value = self.v_proj(key_value_states)
384
+
385
+ attn_output, _ = self.multihead_attn(query, key, value, attn_mask=attn_mask)
386
+
387
+ attn_output = self.dropout(self.linear(attn_output))
388
+
389
+ return attn_output
390
+
391
+
392
+ class AriaProjector(nn.Module):
393
+ """
394
+ Aria Projector module.
395
+
396
+ This module projects vision features into the language model's embedding space, enabling interaction between vision and language components.
397
+
398
+ Args:
399
+ config (`AriaConfig`):
400
+ Configuration object for the model.
401
+ """
402
+
403
+ def __init__(
404
+ self,
405
+ config: AriaConfig,
406
+ ):
407
+ super().__init__()
408
+
409
+ self.patch_to_query_dict = config.projector_patch_to_query_dict
410
+ self.in_features = config.vision_config.hidden_size
411
+ self.num_heads = config.vision_config.num_attention_heads
412
+ self.kv_dim = config.vision_config.hidden_size
413
+ self.hidden_features = config.text_config.hidden_size
414
+ self.output_dim = config.text_config.hidden_size
415
+
416
+ self.query = nn.Parameter(torch.zeros(config.max_value_projector_patch_to_query_dict, self.in_features))
417
+
418
+ self.cross_attn = AriaCrossAttention(config)
419
+
420
+ self.layer_norm = nn.LayerNorm(self.in_features)
421
+ self.feed_forward = AriaProjectorMLP(self.in_features, self.hidden_features, self.output_dim)
422
+
423
+ def forward(self, key_value_states: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
424
+ """
425
+ Forward pass of the Projector module.
426
+
427
+ Args:
428
+ key_value_states (`torch.Tensor`):
429
+ Input tensor of shape (batch_size, num_patches, kv_dim).
430
+ attn_mask (`torch.Tensor`, *optional*, defaults to None):
431
+ Attention mask.
432
+
433
+ Returns:
434
+ `torch.Tensor`: Output tensor of shape (batch_size, query_number, output_dim).
435
+ """
436
+ batch_size, num_patches = key_value_states.shape[0], key_value_states.shape[1]
437
+
438
+ if num_patches not in self.patch_to_query_dict:
439
+ raise KeyError(
440
+ f"Number of patches {num_patches} not found in patch_to_query_dict amongst possible values {self.patch_to_query_dict.keys()}."
441
+ )
442
+ query_num = self.patch_to_query_dict[num_patches]
443
+
444
+ queries = self.query[:query_num].unsqueeze(0).repeat(batch_size, 1, 1)
445
+
446
+ if attn_mask is not None:
447
+ attn_mask = attn_mask.repeat_interleave(self.num_heads, 0)
448
+ attn_mask = attn_mask.unsqueeze(1).expand(-1, queries.size(1), -1)
449
+
450
+ attention_out = self.cross_attn(key_value_states, queries, attn_mask=attn_mask)
451
+
452
+ out = self.feed_forward(self.layer_norm(attention_out))
453
+
454
+ return out
455
+
456
+
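A hedged shape-check sketch for the projector above, using the default `AriaConfig` (so 1225 patches map to 128 learned queries per `projector_patch_to_query_dict`):

```python
import torch

config = AriaConfig()
projector = AriaProjector(config)
vision_features = torch.randn(2, 1225, config.vision_config.hidden_size)
out = projector(vision_features)
# expected: (2, 128, config.text_config.hidden_size) -- patch features in, text-sized embeddings out
print(out.shape)
```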
457
+ class AriaImageProcessor(BaseImageProcessor):
458
+ """
459
+ A vision processor for the Aria model that handles image preprocessing.
460
+ Initialize the AriaImageProcessor.
461
+
462
+ Args:
463
+ image_mean (`list`, *optional*, defaults to [0.5, 0.5, 0.5]):
464
+ Mean values for normalization.
465
+ image_std (`list`, *optional*, defaults to [0.5, 0.5, 0.5]):
466
+ Standard deviation values for normalization.
467
+ max_image_size (`int`, *optional*, defaults to 980):
468
+ Maximum image size.
469
+ min_image_size (`int`, *optional*, defaults to 336):
470
+ Minimum image size.
471
+ split_resolutions (`list`, *optional*, defaults to a list of optimal,resolutions as tuples):
472
+ The optimal resolutions for splitting the image.
473
+ split_image (`bool`, *optional*, defaults to `False`):
474
+ Whether to split the image.
475
+ do_convert_rgb (`bool`, *optional*, defaults to `True`):
476
+ Whether to convert the image to RGB.
477
+ do_rescale (`bool`, *optional*, defaults to `True`):
478
+ Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
479
+ the `preprocess` method.
480
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
481
+ Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
482
+ method.
483
+ do_normalize (`bool`, *optional*, defaults to `True`):
484
+ Whether to normalize the image.
485
+ resample (PILImageResampling, *optional*, defaults to `BICUBIC`):
486
+ The resampling filter to use if resizing the image.
487
+ """
488
+
489
+ model_input_names = ["pixel_values", "pixel_mask", "num_crops"]
490
+
491
+ def __init__(
492
+ self,
493
+ image_mean: Optional[list[float]] = None,
494
+ image_std: Optional[list[float]] = None,
495
+ max_image_size: int = 980,
496
+ min_image_size: int = 336,
497
+ split_resolutions: Optional[list[tuple[int, int]]] = None,
498
+ split_image: Optional[bool] = False,
499
+ do_convert_rgb: Optional[bool] = True,
500
+ do_rescale: bool = True,
501
+ rescale_factor: Union[int, float] = 1 / 255,
502
+ do_normalize: Optional[bool] = True,
503
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
504
+ **kwargs,
505
+ ):
506
+ super().__init__(**kwargs)
507
+
508
+ if image_mean is None:
509
+ image_mean = [0.5, 0.5, 0.5]
510
+ if image_std is None:
511
+ image_std = [0.5, 0.5, 0.5]
512
+ self.max_image_size = max_image_size
513
+ self.min_image_size = min_image_size
514
+ self.image_mean = image_mean
515
+ self.image_std = image_std
516
+ self.split_image = split_image
517
+ if split_resolutions is None:
518
+ split_resolutions = [(1, 2), (1, 3), (1, 4), (1, 5), (1, 6), (1, 7), (1, 8), (2, 4), (2, 3), (2, 2), (2, 1), (3, 1), (3, 2), (4, 1), (4, 2), (5, 1), (6, 1), (7, 1), (8, 1)] # fmt: skip
519
+ split_resolutions = [(el[0] * 490, el[1] * 490) for el in split_resolutions]
520
+ self.split_resolutions = split_resolutions
521
+ self.do_convert_rgb = do_convert_rgb
522
+ self.do_rescale = do_rescale
523
+ self.rescale_factor = rescale_factor
524
+ self.do_normalize = do_normalize
525
+ self.resample = resample
526
+
527
+ def preprocess(
528
+ self,
529
+ images: Union[ImageInput, list[ImageInput]],
530
+ image_mean: Optional[Union[float, list[float]]] = None,
531
+ image_std: Optional[Union[float, list[float]]] = None,
532
+ max_image_size: Optional[int] = None,
533
+ min_image_size: Optional[int] = None,
534
+ split_image: Optional[bool] = None,
535
+ do_convert_rgb: Optional[bool] = None,
536
+ do_rescale: Optional[bool] = None,
537
+ rescale_factor: Optional[float] = None,
538
+ do_normalize: Optional[bool] = None,
539
+ resample: Optional[PILImageResampling] = None,
540
+ return_tensors: Optional[Union[str, TensorType]] = "pt",
541
+ data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
542
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
543
+ ):
544
+ """
545
+ Process a list of images.
546
+
547
+ Args:
548
+ images (ImageInput or list of ImageInput):
549
+ The input image or a list of images.
550
+ image_mean (`list`, *optional*, defaults to [0.5, 0.5, 0.5]):
551
+ Mean values for normalization.
552
+ image_std (`list`, *optional*, defaults to [0.5, 0.5, 0.5]):
553
+ Standard deviation values for normalization.
554
+ max_image_size (`int`, *optional*, defaults to `self.max_image_size` (980)):
555
+ Maximum image size.
556
+ min_image_size (`int`, *optional*, defaults to `self.min_image_size` (336)):
557
+ Minimum image size.
558
+ split_image (`bool`, *optional*, defaults to `self.split_image` (False)):
559
+ Whether to split the image.
560
+ do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb` (True)):
561
+ Whether to convert the image to RGB.
562
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
563
+ Whether to rescale the image.
564
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
565
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
566
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize` (True)):
567
+ Whether to normalize the image.
568
+ resample (PILImageResampling, *optional*, defaults to `self.resample` (BICUBIC)):
569
+ The resampling filter to use if resizing the image.
570
+ return_tensors (`str` or `TensorType`, *optional*, defaults to "pt"):
571
+ The type of tensor to return.
572
+ data_format (`str` or `ChannelDimension`, *optional*):
573
+ The channel dimension format for the output image. Can be one of:
574
+ - `"channels_first"` or `ChannelDimension.FIRST`:
575
+ image in (num_channels, height, width) format.
576
+ - `"channels_last"` or `ChannelDimension.LAST`:
577
+ image in (height, width, num_channels) format.
578
+ If unset, will use same as the input image.
579
+ input_data_format (`str` or `ChannelDimension`, *optional*):
580
+ The channel dimension format for the input image. Can be one of:
581
+ - `"channels_first"` or `ChannelDimension.FIRST`:
582
+ image in (num_channels, height, width) format.
583
+ - `"channels_last"` or `ChannelDimension.LAST`:
584
+ image in (height, width, num_channels) format.
585
+ If unset, will use the inferred format of the input image.
586
+
587
+ Returns:
588
+ BatchFeature:
589
+ A BatchFeature object containing:
590
+ - 'pixel_values':
591
+ Tensor of processed image pixel values.
592
+ - 'pixel_mask':
593
+ Boolean pixel mask. This mask is a 2D tensor of shape (max_image_size, max_image_size) where:
594
+ - True (1) values indicate pixels that belong to the original resized image.
595
+ - False (0) values indicate pixels that are part of the padding.
596
+ The mask helps distinguish between actual image content and padded areas in subsequent processing steps.
597
+ - 'num_crops':
598
+ The maximum number of crops across all images.
599
+ """
600
+ image_mean = image_mean if image_mean is not None else self.image_mean
601
+ image_std = image_std if image_std is not None else self.image_std
602
+ max_image_size = max_image_size if max_image_size is not None else self.max_image_size
603
+ min_image_size = min_image_size if min_image_size is not None else self.min_image_size
604
+ split_image = split_image if split_image is not None else self.split_image
605
+ do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
606
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
607
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
608
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
609
+ resample = resample if resample is not None else self.resample
610
+
611
+ if max_image_size not in [490, 980]:
612
+ raise ValueError("max_image_size must be either 490 or 980")
613
+
614
+ images = self.fetch_images(images)
615
+ images = make_flat_list_of_images(images)
616
+
617
+ if not valid_images(images):
618
+ raise ValueError(
619
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
620
+ "torch.Tensor, tf.Tensor or jax.ndarray."
621
+ )
622
+
623
+ validate_preprocess_arguments(
624
+ do_normalize=do_normalize,
625
+ image_mean=image_mean,
626
+ image_std=image_std,
627
+ resample=resample,
628
+ do_rescale=do_rescale,
629
+ rescale_factor=rescale_factor,
630
+ )
631
+
632
+ if do_convert_rgb:
633
+ images = [convert_to_rgb(image) for image in images]
634
+
635
+ # All transformations expect numpy arrays.
636
+ images = [to_numpy_array(image) for image in images]
637
+
638
+ if do_rescale and is_scaled_image(images[0]):
639
+ logger.warning_once(
640
+ "It looks like you are trying to rescale already rescaled images. If the input"
641
+ " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
642
+ )
643
+
644
+ if input_data_format is None:
645
+ # We assume that all images have the same channel dimension format.
646
+ input_data_format = infer_channel_dimension_format(images[0])
647
+
648
+ pixel_values = []
649
+ pixel_masks = []
650
+ num_crops = None
651
+
652
+ for image in images:
653
+ if split_image:
654
+ crop_images = self.get_image_patches(
655
+ image,
656
+ self.split_resolutions,
657
+ max_image_size,
658
+ resample,
659
+ data_format=input_data_format,
660
+ input_data_format=input_data_format,
661
+ )
662
+ else:
663
+ crop_images = [image]
664
+ if num_crops is None or len(crop_images) > num_crops:
665
+ num_crops = len(crop_images)
666
+
667
+ for crop_image in crop_images:
668
+ # At this point the scale is the rescaling factor that would bring the image to max_size in its larger dimension
669
+ h, w = get_image_size(crop_image)
670
+ scale = max_image_size / max(h, w)
671
+ if w >= h:
672
+ new_size = (max(int(h * scale), min_image_size), max_image_size) # h, w
673
+ else:
674
+ new_size = (max_image_size, max(int(w * scale), min_image_size)) # h, w
675
+
676
+ crop_image_resized = resize(
677
+ crop_image,
678
+ new_size,
679
+ resample=resample,
680
+ data_format=input_data_format,
681
+ input_data_format=input_data_format,
682
+ )
683
+
684
+ padding_bottom, padding_right = max_image_size - new_size[0], max_image_size - new_size[1]
685
+ crop_image_padded = pad(
686
+ crop_image_resized,
687
+ ((0, padding_bottom), (0, padding_right)),
688
+ data_format=input_data_format,
689
+ input_data_format=input_data_format,
690
+ )
691
+
692
+ # Create a pixel mask
693
+ pixel_mask = np.zeros((max_image_size, max_image_size), dtype=bool)
694
+ pixel_mask[: new_size[0], : new_size[1]] = 1
695
+ pixel_masks.append(pixel_mask)
696
+
697
+ if do_rescale:
698
+ crop_image_padded = self.rescale(
699
+ image=crop_image_padded, scale=rescale_factor, input_data_format=input_data_format
700
+ )
701
+
702
+ if do_normalize:
703
+ crop_image_padded = self.normalize(
704
+ crop_image_padded,
705
+ self.image_mean,
706
+ self.image_std,
707
+ data_format=input_data_format,
708
+ input_data_format=input_data_format,
709
+ )
710
+ crop_image_padded = (
711
+ to_channel_dimension_format(crop_image_padded, data_format, input_data_format)
712
+ if data_format is not None
713
+ else crop_image_padded
714
+ )
715
+
716
+ pixel_values.append(crop_image_padded)
717
+ return BatchFeature(
718
+ data={
719
+ "pixel_values": np.stack(pixel_values, axis=0),
720
+ "pixel_mask": np.stack(pixel_masks, axis=0),
721
+ "num_crops": num_crops,
722
+ },
723
+ tensor_type=return_tensors,
724
+ )
725
+
726
+ def _resize_for_patching(
727
+ self, image: np.ndarray, target_resolution: tuple, resample, input_data_format: ChannelDimension
728
+ ) -> np.ndarray:
729
+ """
730
+ Resizes an image to a target resolution while maintaining aspect ratio.
731
+
732
+ Args:
733
+ image (np.ndarray):
734
+ The input image.
735
+ target_resolution (tuple):
736
+ The target resolution (height, width) of the image.
737
+ resample (`PILImageResampling`):
738
+ Resampling filter to use if resizing the image.
739
+ input_data_format (`ChannelDimension` or `str`):
740
+ The channel dimension format of the input image.
741
+
742
+ Returns:
743
+ np.ndarray: The resized image.
744
+ """
745
+ new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)
746
+
747
+ # Resize the image
748
+ resized_image = resize(image, (new_height, new_width), resample=resample, input_data_format=input_data_format)
749
+
750
+ return resized_image
751
+
752
+ def _get_padding_size(self, original_resolution: tuple, target_resolution: tuple):
753
+ original_height, original_width = original_resolution
754
+ target_height, target_width = target_resolution
755
+ paste_x, r_x = divmod(target_width - original_width, 2)
756
+ paste_y, r_y = divmod(target_height - original_height, 2)
757
+ return (paste_y, paste_y + r_y), (paste_x, paste_x + r_x)
758
+
759
+ def _pad_for_patching(
760
+ self, image: np.ndarray, target_resolution: tuple, input_data_format: ChannelDimension
761
+ ) -> np.ndarray:
762
+ """
763
+ Pad an image to a target resolution while maintaining aspect ratio.
764
+ """
765
+ new_resolution = get_patch_output_size(image, target_resolution, input_data_format)
766
+ padding = self._get_padding_size(new_resolution, target_resolution)
767
+
768
+ padded_image = self.pad(image, padding=padding)
769
+
770
+ return padded_image
771
+
772
+ def pad(
773
+ self,
774
+ image: np.ndarray,
775
+ padding: Union[int, tuple[int, int], Iterable[tuple[int, int]]],
776
+ mode: PaddingMode = PaddingMode.CONSTANT,
777
+ constant_values: Union[float, Iterable[float]] = 0.0,
778
+ data_format: Optional[Union[str, ChannelDimension]] = None,
779
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
780
+ ) -> np.ndarray:
781
+ """
782
+ Pads the `image` with the specified `padding` and `mode`. Padding can be in the (`height`, `width`)
783
+ dimension or in the (`num_patches`) dimension. In the second case an iterable of tuples is expected
784
+ as input.
785
+
786
+ Args:
787
+ image (`np.ndarray`):
788
+ The image to pad.
789
+ padding (`int` or `tuple[int, int]` or `Iterable[tuple[int, int]]`):
790
+ Padding to apply to the edges of the height, width axes. Can be one of three formats:
791
+ - `((before_height, after_height), (before_width, after_width))` unique pad widths for each axis.
792
+ - `((before, after),)` yields same before and after pad for height and width.
793
+ - `(pad,)` or int is a shortcut for before = after = pad width for all axes.
794
+ mode (`PaddingMode`):
795
+ The padding mode to use. Can be one of:
796
+ - `"constant"`: pads with a constant value.
797
+ - `"reflect"`: pads with the reflection of the vector mirrored on the first and last values of the
798
+ vector along each axis.
799
+ - `"replicate"`: pads with the replication of the last value on the edge of the array along each axis.
800
+ - `"symmetric"`: pads with the reflection of the vector mirrored along the edge of the array.
801
+ constant_values (`float` or `Iterable[float]`, *optional*):
802
+ The value to use for the padding if `mode` is `"constant"`.
803
+ data_format (`str` or `ChannelDimension`, *optional*):
804
+ The channel dimension format for the output image. Can be one of:
805
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
806
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
807
+ If unset, will use same as the input image.
808
+ input_data_format (`str` or `ChannelDimension`, *optional*):
809
+ The channel dimension format for the input image. Can be one of:
810
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
811
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
812
+ If unset, will use the inferred format of the input image.
813
+
814
+ Returns:
815
+ `np.ndarray`: The padded image.
816
+
817
+ """
818
+
819
+ # call the general `pad` if padding on `height/width`, otherwise it's the `num_patches` dim
820
+ if isinstance(padding, int) or len(padding) != 4:
821
+ return pad(image, padding, mode, constant_values, data_format, input_data_format)
822
+
823
+ if input_data_format is None:
824
+ input_data_format = infer_channel_dimension_format(image)
825
+
826
+ padding_mode_mapping = {
827
+ PaddingMode.CONSTANT: "constant",
828
+ PaddingMode.REFLECT: "reflect",
829
+ PaddingMode.REPLICATE: "edge",
830
+ PaddingMode.SYMMETRIC: "symmetric",
831
+ }
832
+ image = np.pad(image, padding, mode=padding_mode_mapping[mode], constant_values=constant_values)
833
+ image = (
834
+ to_channel_dimension_format(image, data_format, input_data_format) if data_format is not None else image
835
+ )
836
+ return image
837
+
838
+ def get_image_patches(
839
+ self,
840
+ image: np.ndarray,
841
+ grid_pinpoints: list[tuple[int, int]],
842
+ patch_size: int,
843
+ resample: PILImageResampling,
844
+ data_format: ChannelDimension,
845
+ input_data_format: ChannelDimension,
846
+ ) -> list[np.ndarray]:
847
+ """
848
+ Process an image with variable resolutions by dividing it into patches.
849
+
850
+ Args:
851
+ image (`np.ndarray`):
852
+ The input image to be processed.
853
+ grid_pinpoints (list[tuple[int, int]]):
854
+ A list of possible resolutions as tuples.
855
+ patch_size (`int`):
856
+ Size of the patches to divide the image into.
857
+ resample (`PILImageResampling`):
858
+ Resampling filter to use if resizing the image.
859
+ data_format (`ChannelDimension` or `str`):
860
+ The channel dimension format for the output image.
861
+ input_data_format (`ChannelDimension` or `str`):
862
+ The channel dimension format of the input image.
863
+
864
+ Returns:
865
+ `list[np.ndarray]`: A list of NumPy arrays containing the processed image patches.
866
+ """
867
+ if not isinstance(grid_pinpoints, list):
868
+ raise TypeError("grid_pinpoints must be a list of possible resolutions.")
869
+
870
+ possible_resolutions = grid_pinpoints
871
+
872
+ image_size = get_image_size(image, channel_dim=input_data_format)
873
+ best_resolution = select_best_resolution(image_size, possible_resolutions)
874
+ resized_image = self._resize_for_patching(
875
+ image, best_resolution, resample=resample, input_data_format=input_data_format
876
+ )
877
+ padded_image = self._pad_for_patching(resized_image, best_resolution, input_data_format=input_data_format)
878
+
879
+ patches = divide_to_patches(padded_image, patch_size=patch_size, input_data_format=input_data_format)
880
+
881
+ # make sure that all patches are in the input data format
882
+ patches = [
883
+ to_channel_dimension_format(patch, channel_dim=data_format, input_channel_dim=input_data_format)
884
+ for patch in patches
885
+ ]
886
+ return patches
887
+
888
+ def get_number_of_image_patches(self, height: int, width: int, images_kwargs=None):
889
+ """
890
+ A utility that returns number of image patches for a given image size.
891
+
892
+ Args:
893
+ height (`int`):
894
+ Height of the input image.
895
+ width (`int`):
896
+ Width of the input image.
897
+ images_kwargs (`dict`, *optional*):
898
+ Any kwargs to override defaults of the image processor.
899
+ Returns:
900
+ `int`: Number of patches per image.
901
+ """
902
+ split_image = images_kwargs.get("split_image", self.split_image)
903
+ max_image_size = images_kwargs.get("max_image_size", self.max_image_size)
904
+
905
+ resized_height, resized_width = select_best_resolution((height, width), self.split_resolutions)
906
+ num_patches = 1 if not split_image else resized_height // max_image_size * resized_width // max_image_size
907
+ return num_patches
908
+
909
+
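A hedged usage sketch for the image processor above; the random PIL image and the printed shapes are illustrative (a single unsplit image is resized and padded to `max_image_size`):

```python
import numpy as np
from PIL import Image

processor = AriaImageProcessor()
image = Image.fromarray(np.random.randint(0, 256, (600, 800, 3), dtype=np.uint8))
batch = processor.preprocess([image], max_image_size=980, split_image=False, return_tensors="pt")
print(batch["pixel_values"].shape)  # expected: torch.Size([1, 3, 980, 980])
print(batch["pixel_mask"].shape)    # expected: torch.Size([1, 980, 980])
```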
910
+ class AriaProcessorKwargs(ProcessingKwargs, total=False):
911
+ _defaults = {
912
+ "text_kwargs": {
913
+ "padding": False,
914
+ "return_mm_token_type_ids": False,
915
+ },
916
+ "images_kwargs": {
917
+ "max_image_size": 980,
918
+ "split_image": False,
919
+ },
920
+ "return_tensors": TensorType.PYTORCH,
921
+ }
922
+
923
+
924
+ class AriaProcessor(ProcessorMixin):
925
+ """
926
+ AriaProcessor is a processor for the Aria model which wraps the Aria image preprocessor and the LLaMA slow tokenizer.
927
+
928
+ Args:
929
+ image_processor (`AriaImageProcessor`, *optional*):
930
+ The AriaImageProcessor to use for image preprocessing.
931
+ tokenizer (`PreTrainedTokenizerBase`, *optional*):
932
+ An instance of [`PreTrainedTokenizerBase`]. This should correspond with the model's text model. The tokenizer is a required input.
933
+ chat_template (`str`, *optional*):
934
+ A Jinja template which will be used to convert lists of messages in a chat into a tokenizable string.
935
+ size_conversion (`Dict`, *optional*):
936
+ A dictionary indicating size conversions for images.
937
+ """
938
+
939
+ attributes = ["image_processor", "tokenizer"]
940
+ image_processor_class = "AriaImageProcessor"
941
+ tokenizer_class = "AutoTokenizer"
942
+
943
+ def __init__(
944
+ self,
945
+ image_processor=None,
946
+ tokenizer: Union[AutoTokenizer, str] = None,
947
+ chat_template: Optional[str] = None,
948
+ size_conversion: Optional[dict[Union[float, int], int]] = None,
949
+ ):
950
+ if size_conversion is None:
951
+ size_conversion = {490: 128, 980: 256}
952
+ self.size_conversion = {int(k): v for k, v in size_conversion.items()}
953
+
954
+ self.image_token = tokenizer.image_token
955
+ self.image_token_id = tokenizer.image_token_id
956
+ if tokenizer is not None and tokenizer.pad_token is None:
957
+ tokenizer.pad_token = tokenizer.unk_token
958
+
959
+ super().__init__(image_processor, tokenizer, chat_template=chat_template)
960
+
961
+ def __call__(
962
+ self,
963
+ text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]],
964
+ images: Optional[ImageInput] = None,
965
+ audio=None,
966
+ videos=None,
967
+ **kwargs: Unpack[AriaProcessorKwargs],
968
+ ) -> BatchFeature:
969
+ """
970
+ Main method to prepare one or several sequence(s) and image(s) for the model.
971
+
972
+ Args:
973
+ text (`TextInput`, `PreTokenizedInput`, `list[TextInput]`, `list[PreTokenizedInput]`):
974
+ The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
975
+ (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
976
+ `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
977
+ images (`ImageInput`):
978
+ The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
979
+ tensor. Both channels-first and channels-last formats are supported.
980
+
981
+
982
+ Returns:
983
+ [`BatchFeature`]: A [`BatchFeature`] with the following fields:
984
+ - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
985
+ - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
986
+ `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
987
+ `None`).
988
+ - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
989
+ - **pixel_mask** -- Pixel mask to be fed to a model. Returned when `images` is not `None`.
990
+ """
991
+ output_kwargs = self._merge_kwargs(
992
+ AriaProcessorKwargs,
993
+ tokenizer_init_kwargs=self.tokenizer.init_kwargs,
994
+ **kwargs,
995
+ )
996
+
997
+ if isinstance(text, str):
998
+ text = [text]
999
+ elif not isinstance(text, list) and not isinstance(text[0], str):
1000
+ raise TypeError("Invalid input text. Please provide a string, or a list of strings")
1001
+
1002
+ if images is not None:
1003
+ image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
1004
+ # expand the image_token according to the num_crops and tokens per image
1005
+ tokens_per_image = self.size_conversion[image_inputs.pixel_values.shape[2]]
1006
+ prompt_strings = []
1007
+ num_crops = image_inputs.pop("num_crops") * tokens_per_image
1008
+ for sample in text:
1009
+ sample = sample.replace(self.tokenizer.image_token, self.tokenizer.image_token * num_crops)
1010
+ prompt_strings.append(sample)
1011
+
1012
+ else:
1013
+ image_inputs = {}
1014
+ prompt_strings = text
1015
+
1016
+ return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
1017
+ return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False)
1018
+ text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"], return_tensors=None)
1019
+ self._check_special_mm_tokens(prompt_strings, text_inputs, modalities=["image"])
1020
+
1021
+ if return_mm_token_type_ids:
1022
+ array_ids = np.array(text_inputs["input_ids"])
1023
+ mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
1024
+ mm_token_type_ids[array_ids == self.image_token_id] = 1
1025
+ text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()
1026
+
1027
+ return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors)
1028
+
1029
+ def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
1030
+ """
1031
+ Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
1032
+ Args:
1033
+ image_sizes (`list[list[int]]`, *optional*):
1034
+ The input sizes formatted as (height, width) per each image.
1035
+ Returns:
1036
+ `MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
1037
+ input modalities, along with other useful data.
1038
+ """
1039
+
1040
+ vision_data = {}
1041
+ if image_sizes is not None:
1042
+ images_kwargs = AriaProcessorKwargs._defaults.get("images_kwargs", {})
1043
+ images_kwargs.update(kwargs)
1044
+
1045
+ max_size = images_kwargs.get("max_image_size", None) or self.image_processor.max_image_size
1046
+ num_image_patches = [
1047
+ self.image_processor.get_number_of_image_patches(*image_size, images_kwargs)
1048
+ for image_size in image_sizes
1049
+ ]
1050
+ num_image_tokens = [self.size_conversion[max_size] * num_patches for num_patches in num_image_patches]
1051
+ vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
1052
+
1053
+ return MultiModalData(**vision_data)
1054
+
1055
+ @property
1056
+ def model_input_names(self):
1057
+ tokenizer_input_names = self.tokenizer.model_input_names
1058
+ image_processor_input_names = self.image_processor.model_input_names
1059
+
1060
+ # Remove `num_crops`, it is popped and used only when processing. Make a copy of list when removing
1061
+ # otherwise `self.image_processor.model_input_names` is also modified
1062
+ image_processor_input_names = [name for name in image_processor_input_names if name != "num_crops"]
1063
+ return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
1064
+
1065
+
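The non-obvious part of `__call__` above is the placeholder expansion: each image token in the prompt is repeated `num_crops * tokens_per_image` times, where `tokens_per_image` comes from `size_conversion` keyed by the processed image size. A toy illustration of that arithmetic, using the default `size_conversion` shown above:

```python
size_conversion = {490: 128, 980: 256}
num_crops, tokens_per_image = 4, size_conversion[980]
# an image split into 4 crops expands one placeholder token into 4 * 256 = 1024 tokens
print(num_crops * tokens_per_image)  # 1024
```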
1066
+ class AriaSharedExpertsMLP(LlamaMLP):
1067
+ """
1068
+ MLP module for the shared experts.
1069
+
1070
+ Unlike routed experts, shared experts process all tokens without routing.
1071
+ This class reconfigures the intermediate size in comparison to the LlamaMLP.
1072
+
1073
+ Args:
1074
+ config (`AriaTextConfig`): Configuration object for the Aria language model.
1075
+ """
1076
+
1077
+ def __init__(self, config: AriaTextConfig):
1078
+ super().__init__(config)
1079
+ self.intermediate_size = config.intermediate_size * config.moe_num_shared_experts
1080
+
1081
+
1082
+ class AriaGroupedExpertsGemm(nn.Module):
1083
+ """
1084
+ Grouped GEMM (General Matrix Multiplication) module for efficient expert computation.
1085
+ This module utilizes the grouped_gemm library (https://github.com/fanshiqing/grouped_gemm)
1086
+ for optimized performance. If the grouped_gemm library is not installed, it gracefully
1087
+ falls back to a sequential GEMM implementation, which may be slower but ensures
1088
+ functionality.
1089
+
1090
+ Args:
1091
+ in_features (`int`):
1092
+ Number of input features.
1093
+ out_features (`int`):
1094
+ Number of output features.
1095
+ groups (`int`):
1096
+ Number of expert groups.
1097
+ """
1098
+
1099
+ def __init__(self, in_features, out_features, groups):
1100
+ super().__init__()
1101
+ self.in_features = in_features
1102
+ self.out_features = out_features
1103
+ self.groups = groups
1104
+ self.weight = nn.Parameter(torch.empty(groups, in_features, out_features))
1105
+
1106
+ def forward(self, input, tokens_per_expert):
1107
+ """
1108
+ Perform grouped matrix multiplication.
1109
+
1110
+ Args:
1111
+ input (`torch.Tensor`):
1112
+ Input tensor of shape (num_tokens, in_features).
1113
+ tokens_per_expert (`torch.Tensor`):
1114
+ Number of tokens assigned to each expert.
1115
+
1116
+ Returns:
1117
+ torch.Tensor: Output tensor of shape (num_tokens, out_features).
1118
+ """
1119
+ return sequential_experts_gemm(
1120
+ input,
1121
+ self.weight,
1122
+ tokens_per_expert.cpu(),
1123
+ )
1124
+
1125
+
1126
+ class AriaGroupedExpertsMLP(nn.Module):
1127
+ """
1128
+ Grouped MLP module for Mixture of Experts.
1129
+
1130
+ Args:
1131
+ config (`AriaTextConfig`):
1132
+ Configuration object for the model.
1133
+ """
1134
+
1135
+ def __init__(self, config: AriaTextConfig) -> None:
1136
+ super().__init__()
1137
+ self.config = config
1138
+ self.fc1 = AriaGroupedExpertsGemm(config.hidden_size, config.intermediate_size * 2, config.moe_num_experts)
1139
+ self.fc2 = AriaGroupedExpertsGemm(config.intermediate_size, config.hidden_size, config.moe_num_experts)
1140
+
1141
+ def forward(self, permuted_tokens, tokens_per_expert):
1142
+ """
1143
+ Forward pass of the Grouped MLP.
1144
+
1145
+ Args:
1146
+ permuted_tokens (torch.Tensor): Permuted input tokens.
1147
+ tokens_per_expert (torch.Tensor): Number of tokens assigned to each expert.
1148
+
1149
+ Returns:
1150
+ torch.Tensor: Output tensor after passing through the MLP.
1151
+ """
1152
+ fc1_output = self.fc1(permuted_tokens, tokens_per_expert)
1153
+ projection, gate = torch.chunk(fc1_output, 2, dim=-1)
1154
+ fc1_output = nn.functional.silu(projection) * gate
1155
+ fc2_output = self.fc2(fc1_output, tokens_per_expert)
1156
+ return fc2_output
1157
+
1158
+
1159
+ # Token permutation adapted from https://github.com/NVIDIA/Megatron-LM/blob/54f1f78529cbc2b9cddad313e7f9d96ac0420a27/megatron/core/transformer/moe/token_dispatcher.py#L291-L587
1160
+ class AriaTextMoELayer(nn.Module):
1161
+ """
1162
+ Aria Text Mixture of Experts (MoE) Layer.
1163
+
1164
+ This layer applies a gating mechanism to route input tokens to different experts.
1165
+
1166
+ Args:
1167
+ config (`AriaTextConfig`):
1168
+ Configuration object for the text component of the model.
1169
+ """
1170
+
1171
+ def __init__(self, config: AriaTextConfig):
1172
+ super().__init__()
1173
+
1174
+ self.router = nn.Linear(config.hidden_size, config.moe_num_experts, bias=False)
1175
+ self.experts = AriaGroupedExpertsMLP(config)
1176
+ self.shared_experts = AriaSharedExpertsMLP(config)
1177
+ self.config = config
1178
+
1179
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
1180
+ """
1181
+ Forward pass of the MoE Layer.
1182
+
1183
+ Args:
1184
+ hidden_states (`torch.Tensor`):
1185
+ Input tensor of shape (batch_size, sequence_length, hidden_size).
1186
+
1187
+ Returns:
1188
+ torch.Tensor: Output tensor after passing through the MoE layer.
1189
+
1190
+ Process:
1191
+ 1. Route tokens to experts using the router.
1192
+ 2. Permute tokens based on routing decisions.
1193
+ 3. Process tokens through experts.
1194
+ 4. Unpermute and combine expert outputs.
1195
+ 5. Add shared expert output to the final result.
1196
+ """
1197
+ original_shape = hidden_states.shape
1198
+ hidden_states = hidden_states.view(-1, hidden_states.size(-1))
1199
+
1200
+ # Top K Routing
1201
+ logits = self.router(hidden_states)
1202
+ top_logits, top_indices = torch.topk(logits, k=self.config.moe_topk, dim=1)
1203
+ scores = nn.functional.softmax(top_logits, dim=-1)
1204
+
1205
+ original_dtype = top_indices.dtype
1206
+
1207
+ tokens_per_expert = torch.histc(
1208
+ top_indices.flatten().to(torch.float32),
1209
+ bins=self.config.moe_num_experts,
1210
+ min=0,
1211
+ max=self.config.moe_num_experts - 1,
1212
+ ).to(original_dtype)
1213
+ indices = top_indices
1214
+
1215
+ # Token permutation
1216
+ flatten_indices = indices.view(-1)
1217
+ sorted_indices = torch.argsort(flatten_indices)
1218
+ permuted_tokens = hidden_states.index_select(0, sorted_indices // self.config.moe_topk)
1219
+
1220
+ # Process through experts
1221
+ expert_output = self.experts(permuted_tokens, tokens_per_expert)
1222
+
1223
+ # Token unpermutation
1224
+ unpermuted_tokens = torch.zeros(
1225
+ (scores.shape[0] * self.config.moe_topk, expert_output.size(1)),
1226
+ dtype=expert_output.dtype,
1227
+ device=expert_output.device,
1228
+ )
1229
+ unpermuted_tokens.index_copy_(0, sorted_indices, expert_output)
1230
+ unpermuted_tokens = unpermuted_tokens.view(-1, self.config.moe_topk, expert_output.size(1))
1231
+
1232
+ output = (unpermuted_tokens * scores.unsqueeze(-1)).sum(dim=1).view(original_shape)
1233
+
1234
+ # Add shared expert output
1235
+ shared_expert_output = self.shared_experts(hidden_states.view(original_shape))
1236
+ return output + shared_expert_output
1237
+
1238
+
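The routing, permutation, and unpermutation steps documented in `AriaTextMoELayer.forward` can be exercised on their own. Below is a minimal sketch with toy dimensions and an identity "expert" in place of the grouped MLP; it is illustrative only, not the library's code path:

```python
import torch
from torch import nn

num_tokens, hidden, num_experts, top_k = 6, 4, 3, 2
hidden_states = torch.randn(num_tokens, hidden)
router = nn.Linear(hidden, num_experts, bias=False)

# 1. Top-k routing: each token picks its best experts and softmax-normalizes their scores.
logits = router(hidden_states)
top_logits, top_indices = torch.topk(logits, k=top_k, dim=1)
scores = torch.softmax(top_logits, dim=-1)

# 2. Count tokens per expert and sort the token copies so same-expert tokens are contiguous.
tokens_per_expert = torch.histc(top_indices.flatten().float(), bins=num_experts, min=0, max=num_experts - 1)
sorted_indices = torch.argsort(top_indices.view(-1))
permuted_tokens = hidden_states.index_select(0, sorted_indices // top_k)

# 3. The grouped expert MLP would run here; identity keeps the sketch short.
expert_output = permuted_tokens

# 4. Scatter the expert outputs back to the original order and combine the top-k copies.
unpermuted = torch.zeros_like(expert_output)
unpermuted.index_copy_(0, sorted_indices, expert_output)
output = (unpermuted.view(num_tokens, top_k, hidden) * scores.unsqueeze(-1)).sum(dim=1)
print(output.shape, tokens_per_expert)  # torch.Size([6, 4]), counts summing to 12
```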
1239
+ class AriaTextAttention(LlamaAttention):
1240
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
1241
+
1242
+ pass
1243
+
1244
+
1245
+ class AriaTextDecoderLayer(LlamaDecoderLayer):
1246
+ """
1247
+ Aria Text Decoder Layer.
1248
+
1249
+ This class defines a single decoder layer in the language model, combining self-attention with a Mixture of Experts (MoE) feed-forward network.
1250
+
1251
+ Args:
1252
+ config (`AriaTextConfig`):
1253
+ Configuration object for the text component of the model.
1254
+ layer_idx (`int`):
1255
+ Index of the layer.
1256
+ """
1257
+
1258
+ def __init__(self, config: AriaTextConfig, layer_idx: int):
1259
+ super().__init__(config, layer_idx)
1260
+ self.mlp = AriaTextMoELayer(config)
1261
+
1262
+
1263
+ @auto_docstring
1264
+ class AriaTextPreTrainedModel(PreTrainedModel):
1265
+ config: AriaTextConfig
1266
+ base_model_prefix = "model"
1267
+ _no_split_modules = ["AriaTextDecoderLayer", "AriaGroupedExpertsGemm"]
1268
+ supports_gradient_checkpointing = True
1269
+ _skip_keys_device_placement = "past_key_values"
1270
+ _supports_flash_attn = True
1271
+ _supports_sdpa = True
1272
+
1273
+ _supports_attention_backend = True
1274
+ _can_record_outputs = {
1275
+ "hidden_states": AriaTextDecoderLayer,
1276
+ "attentions": AriaTextAttention,
1277
+ }
1278
+
1279
+ def _init_weights(self, module):
1280
+ super()._init_weights(module)
1281
+ if isinstance(module, AriaGroupedExpertsGemm):
1282
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
1283
+
1284
+
1285
+ class AriaPreTrainedModel(LlamaPreTrainedModel):
1286
+ config: AriaConfig
1287
+ base_model_prefix = ""
1288
+ _can_compile_fullgraph = False # MoE models don't work with torch.compile (dynamic slicing)
1289
+ _supports_attention_backend = True
1290
+
1291
+ def _init_weights(self, module):
1292
+ PreTrainedModel._init_weights(self, module)
1293
+ if isinstance(module, AriaProjector):
1294
+ nn.init.trunc_normal_(module.query, std=self.config.initializer_range)
1295
+
1296
+
1297
+ class AriaTextModel(LlamaModel):
1298
+ def __init__(self, config: AriaTextConfig):
1299
+ super().__init__(config)
1300
+ self.layers = nn.ModuleList(
1301
+ [AriaTextDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
1302
+ )
1303
+ self.gradient_checkpointing = False
1304
+ self.post_init()
1305
+
1306
+
1307
+ class AriaTextForCausalLM(AriaTextPreTrainedModel, LlamaForCausalLM):
1308
+ _tied_weights_keys = ["lm_head.weight"]
1309
+
1310
+ def __init__(self, config: AriaTextConfig):
1311
+ super().__init__(config)
1312
+ self.model = AriaTextModel(config)
1313
+ self.vocab_size = config.vocab_size
1314
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1315
+
1316
+ # Initialize weights and apply final processing
1317
+ self.post_init()
1318
+
1319
+ @auto_docstring
1320
+ def forward(self, **super_kwargs):
1321
+ return super().forward(**super_kwargs)
1322
+
1323
+
1324
+ class AriaCausalLMOutputWithPast(LlavaCausalLMOutputWithPast):
1325
+ pass
1326
+
1327
+
1328
+ class AriaModelOutputWithPast(LlavaModelOutputWithPast):
1329
+ pass
1330
+
1331
+
1332
+ class AriaModel(LlavaModel):
1333
+ def __init__(self, config: AriaConfig):
1334
+ super().__init__(config)
1335
+ self.multi_modal_projector = AriaProjector(config)
1336
+
1337
+ def _create_patch_attention_mask(self, pixel_mask):
1338
+ if pixel_mask is None:
1339
+ return None
1340
+
1341
+ patches_subgrid = pixel_mask.unfold(
1342
+ dimension=1,
1343
+ size=self.vision_tower.config.patch_size,
1344
+ step=self.vision_tower.config.patch_size,
1345
+ )
1346
+ patches_subgrid = patches_subgrid.unfold(
1347
+ dimension=2,
1348
+ size=self.vision_tower.config.patch_size,
1349
+ step=self.vision_tower.config.patch_size,
1350
+ )
1351
+ return (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
1352
+
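`_create_patch_attention_mask` above folds a pixel-level mask into one boolean per vision patch: `unfold` cuts the mask into non-overlapping `patch_size x patch_size` tiles, and a patch is kept if any pixel inside it is valid. A small self-contained sketch of that operation with toy sizes:

```python
import torch

patch_size = 2
# One 4x4 pixel mask; the right half of the image is padding (0).
pixel_mask = torch.tensor(
    [[1, 1, 0, 0],
     [1, 1, 0, 0],
     [1, 1, 0, 0],
     [1, 1, 0, 0]]
).unsqueeze(0)  # (batch=1, height=4, width=4)

patches = pixel_mask.unfold(1, patch_size, patch_size)  # (1, 2, 4, 2): tile the height dimension
patches = patches.unfold(2, patch_size, patch_size)     # (1, 2, 2, 2, 2): tile the width dimension
patch_mask = patches.sum(dim=(-1, -2)) > 0              # (1, 2, 2): True where a patch has any valid pixel
print(patch_mask)
# tensor([[[ True, False],
#          [ True, False]]])
```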
1353
+ def get_image_features(
1354
+ self,
1355
+ pixel_values: torch.FloatTensor,
1356
+ pixel_mask: Optional[torch.FloatTensor] = None,
1357
+ vision_feature_layer: int = -1,
1358
+ ):
1359
+ """
1360
+ Obtains the image last hidden states from the vision tower and applies the multimodal projection.
1361
+
1362
+ Args:
1363
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`):
1364
+ The tensors corresponding to the input images.
1365
+ pixel_mask (`torch.FloatTensor`, *optional*):
1366
+ The tensors corresponding to the input image mask.
1367
+ vision_feature_layer (`Union[int, list[int]]`, *optional*):
1368
+ The index of the layer to select the vision feature. If multiple indices are provided,
1369
+ the vision feature of the corresponding indices will be concatenated to form the
1370
+ vision features.
1371
+ Returns:
1372
+ image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`.
1373
+ """
1374
+ vision_feature_layer = (
1375
+ vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
1376
+ )
1377
+ patch_attention_mask = self._create_patch_attention_mask(pixel_mask)
1378
+ image_outputs = self.vision_tower(
1379
+ pixel_values, patch_attention_mask=patch_attention_mask, output_hidden_states=True
1380
+ )
1381
+ image_attn_mask = None
1382
+ if patch_attention_mask is not None:
1383
+ flattened_mask = patch_attention_mask.flatten(1)
1384
+ image_attn_mask = torch.logical_not(flattened_mask)
1385
+
1386
+ selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
1387
+ image_features = self.multi_modal_projector(selected_image_feature, attn_mask=image_attn_mask)
1388
+ return image_features
1389
+
1390
+ def forward(
1391
+ self,
1392
+ input_ids: Optional[torch.LongTensor] = None,
1393
+ pixel_values: Optional[torch.FloatTensor] = None,
1394
+ pixel_mask: Optional[torch.LongTensor] = None,
1395
+ attention_mask: Optional[torch.Tensor] = None,
1396
+ position_ids: Optional[torch.LongTensor] = None,
1397
+ past_key_values: Optional[Cache] = None,
1398
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1399
+ use_cache: Optional[bool] = None,
1400
+ cache_position: Optional[torch.LongTensor] = None,
1401
+ **kwargs: Unpack[FlashAttentionKwargs],
1402
+ ) -> Union[tuple, AriaModelOutputWithPast]:
1403
+ if inputs_embeds is None:
1404
+ inputs_embeds = self.get_input_embeddings()(input_ids)
1405
+
1406
+ # 2. Merge text and images
1407
+ if pixel_values is not None and inputs_embeds.shape[1] != 1:
1408
+ image_features = self.get_image_features(
1409
+ pixel_values=pixel_values,
1410
+ pixel_mask=pixel_mask,
1411
+ vision_feature_layer=self.config.vision_feature_layer,
1412
+ )
1413
+ image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
1414
+ special_image_mask = self.get_placeholder_mask(
1415
+ input_ids, inputs_embeds=inputs_embeds, image_features=image_features
1416
+ )
1417
+ inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
1418
+
1419
+ outputs = self.language_model(
1420
+ attention_mask=attention_mask,
1421
+ position_ids=position_ids,
1422
+ past_key_values=past_key_values,
1423
+ inputs_embeds=inputs_embeds,
1424
+ use_cache=use_cache,
1425
+ cache_position=cache_position,
1426
+ **kwargs,
1427
+ )
1428
+
1429
+ return AriaModelOutputWithPast(
1430
+ last_hidden_state=outputs.last_hidden_state,
1431
+ past_key_values=outputs.past_key_values if use_cache else None,
1432
+ hidden_states=outputs.hidden_states,
1433
+ attentions=outputs.attentions,
1434
+ image_hidden_states=image_features if pixel_values is not None else None,
1435
+ )
1436
+
1437
+
1438
+ @auto_docstring(
1439
+ custom_intro="""
1440
+ Aria model for conditional generation tasks.
1441
+
1442
+ This model combines a vision tower, a multi-modal projector, and a language model
1443
+ to perform tasks that involve both image and text inputs.
1444
+ """
1445
+ )
1446
+ class AriaForConditionalGeneration(LlavaForConditionalGeneration):
1447
+ def get_image_features(
1448
+ self,
1449
+ pixel_values: torch.FloatTensor,
1450
+ pixel_mask: Optional[torch.FloatTensor] = None,
1451
+ vision_feature_layer: int = -1,
1452
+ ):
1453
+ return self.model.get_image_features(
1454
+ pixel_values=pixel_values,
1455
+ pixel_mask=pixel_mask,
1456
+ vision_feature_layer=vision_feature_layer,
1457
+ )
1458
+
1459
+ @can_return_tuple
1460
+ @auto_docstring
1461
+ def forward(
1462
+ self,
1463
+ input_ids: Optional[torch.LongTensor] = None,
1464
+ pixel_values: Optional[torch.FloatTensor] = None,
1465
+ pixel_mask: Optional[torch.LongTensor] = None,
1466
+ attention_mask: Optional[torch.Tensor] = None,
1467
+ position_ids: Optional[torch.LongTensor] = None,
1468
+ past_key_values: Optional[Cache] = None,
1469
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1470
+ labels: Optional[torch.LongTensor] = None,
1471
+ use_cache: Optional[bool] = None,
1472
+ logits_to_keep: Union[int, torch.Tensor] = 0,
1473
+ cache_position: Optional[torch.LongTensor] = None,
1474
+ **kwargs: Unpack[TransformersKwargs],
1475
+ ) -> Union[tuple, AriaCausalLMOutputWithPast]:
1476
+ r"""
1477
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1478
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1479
+ config.vocab_size]` or `model.image_token_id` (where `model` is your instance of `AriaForConditionalGeneration`).
1480
+ Tokens with indices set to `model.image_token_id` are ignored (masked), the loss is only
1481
+ computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1482
+
1483
+ Example:
1484
+
1485
+ ```python
1486
+ >>> import requests
1487
+ >>> import torch
1488
+ >>> from PIL import Image
1489
+ >>> from io import BytesIO
1490
+
1491
+ >>> from transformers import AutoProcessor, AutoModel
1492
+ >>> from transformers.image_utils import load_image
1493
+
1494
+ >>> # Note that passing the image urls (instead of the actual pil images) to the processor is also possible
1495
+ >>> image1 = load_image("https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg")
1496
+ >>> image2 = load_image("https://cdn.britannica.com/59/94459-050-DBA42467/Skyline-Chicago.jpg")
1497
+ >>> image3 = load_image("https://cdn.britannica.com/68/170868-050-8DDE8263/Golden-Gate-Bridge-San-Francisco.jpg")
1498
+
1499
+ >>> processor = AutoProcessor.from_pretrained("Rhymes-AI/Aria")
1500
+ >>> model = AutoModel.from_pretrained("Rhymes-AI/Aria", dtype=torch.bfloat16, device_map="auto")
1501
+
1502
+ >>> # Create inputs
1503
+ >>> messages = [
1504
+ ... {
1505
+ ... "role": "user",
1506
+ ... "content": [
1507
+ ... {"type": "image"},
1508
+ ... {"type": "text", "text": "In this image, we can see the city of New York, and more specifically the Statue of Liberty."},
1509
+ ... {"type": "image"},
1510
+ ... {"type": "text", "text": "What can we see in this image?"},
1511
+ ... ]
1512
+ ... },
1513
+ ... {
1514
+ ... "role": "user",
1515
+ ... "content": [
1516
+ ... {"type": "image"},
1517
+ ... {"type": "text", "text": "In which city is that bridge located?"},
1518
+ ... ]
1519
+ ... }
1520
+ ... ]
1521
+
1522
+ >>> prompts = [processor.apply_chat_template([message], add_generation_prompt=True) for message in messages]
1523
+ >>> images = [[image1, image2], [image3]]
1524
+ >>> inputs = processor(text=prompts, images=images, padding=True, return_tensors="pt").to(model.device)
1525
+
1526
+ >>> # Generate
1527
+ >>> generated_ids = model.generate(**inputs, max_new_tokens=256)
1528
+ >>> generated_texts = processor.batch_decode(generated_ids, skip_special_tokens=True)
1529
+
1530
+ >>> print(generated_texts[0])
1531
+ Assistant: There are buildings, trees, lights, and water visible in this image.
1532
+
1533
+ >>> print(generated_texts[1])
1534
+ Assistant: The bridge is in San Francisco.
1535
+ ```"""
1536
+ outputs = self.model(
1537
+ input_ids=input_ids,
1538
+ pixel_values=pixel_values,
1539
+ pixel_mask=pixel_mask,
1540
+ attention_mask=attention_mask,
1541
+ position_ids=position_ids,
1542
+ past_key_values=past_key_values,
1543
+ inputs_embeds=inputs_embeds,
1544
+ use_cache=use_cache,
1545
+ cache_position=cache_position,
1546
+ **kwargs,
1547
+ )
1548
+
1549
+ hidden_states = outputs[0]
1550
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
1551
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
1552
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
1553
+
1554
+ loss = None
1555
+ if labels is not None:
1556
+ loss = self.loss_function(
1557
+ logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
1558
+ )
1559
+
1560
+ return AriaCausalLMOutputWithPast(
1561
+ loss=loss,
1562
+ logits=logits,
1563
+ past_key_values=outputs.past_key_values,
1564
+ hidden_states=outputs.hidden_states,
1565
+ attentions=outputs.attentions,
1566
+ )
1567
+
1568
+ def prepare_inputs_for_generation(
1569
+ self,
1570
+ input_ids,
1571
+ past_key_values=None,
1572
+ inputs_embeds=None,
1573
+ pixel_values=None,
1574
+ pixel_mask=None,
1575
+ attention_mask=None,
1576
+ cache_position=None,
1577
+ logits_to_keep=None,
1578
+ **kwargs,
1579
+ ):
1580
+ model_inputs = super().prepare_inputs_for_generation(
1581
+ input_ids,
1582
+ past_key_values=past_key_values,
1583
+ inputs_embeds=inputs_embeds,
1584
+ attention_mask=attention_mask,
1585
+ cache_position=cache_position,
1586
+ logits_to_keep=logits_to_keep,
1587
+ **kwargs,
1588
+ )
1589
+
1590
+ if cache_position[0] == 0:
1591
+ # Pixel values are only passed during the prefill stage (cache_position[0] == 0); in the cached decoding stage
1592
+ # they should be None, because the input ids no longer contain the special image token.
1593
+ model_inputs["pixel_values"] = pixel_values
1594
+ model_inputs["pixel_mask"] = pixel_mask
1595
+
1596
+ return model_inputs
1597
+
1598
+
1599
+ __all__ = [
1600
+ "AriaConfig",
1601
+ "AriaTextConfig",
1602
+ "AriaImageProcessor",
1603
+ "AriaProcessor",
1604
+ "AriaForConditionalGeneration",
1605
+ "AriaPreTrainedModel",
1606
+ "AriaTextPreTrainedModel",
1607
+ "AriaTextModel",
1608
+ "AriaModel",
1609
+ "AriaTextForCausalLM",
1610
+ ]
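For reference, the image/text merge performed in `AriaModel.forward` above boils down to a `masked_scatter` over the positions of the image placeholder token. A standalone sketch with toy tensors; the placeholder-mask helper is replaced by a direct comparison against a hypothetical `image_token_id`:

```python
import torch

image_token_id = 9  # hypothetical placeholder id, for illustration only
embed_dim = 4

input_ids = torch.tensor([[1, 9, 9, 2]])      # two image placeholder positions
inputs_embeds = torch.zeros(1, 4, embed_dim)  # text embeddings (zeros for visibility)
image_features = torch.ones(2, embed_dim)     # one projected feature vector per placeholder

# Expand the token-level mask to the embedding dimension, then scatter the image features in.
special_image_mask = (input_ids == image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
print(inputs_embeds[0, :, 0])  # tensor([0., 1., 1., 0.])
```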
venv/lib/python3.13/site-packages/transformers/models/aria/processing_aria.py ADDED
@@ -0,0 +1,189 @@
1
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
2
+ # This file was automatically generated from src/transformers/models/aria/modular_aria.py.
3
+ # Do NOT edit this file manually as any edits will be overwritten by the generation of
4
+ # the file from the modular. If any change should be done, please apply the change to the
5
+ # modular_aria.py file directly. One of our CI enforces this.
6
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
7
+ # coding=utf-8
8
+ # Copyright 2024 The Rhymes-AI Teams Authors and The HuggingFace Inc. team. All rights reserved.
9
+ #
10
+ # Licensed under the Apache License, Version 2.0 (the "License");
11
+ # you may not use this file except in compliance with the License.
12
+ # You may obtain a copy of the License at
13
+ #
14
+ # http://www.apache.org/licenses/LICENSE-2.0
15
+ #
16
+ # Unless required by applicable law or agreed to in writing, software
17
+ # distributed under the License is distributed on an "AS IS" BASIS,
18
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19
+ # See the License for the specific language governing permissions and
20
+ # limitations under the License.
21
+ from typing import Optional, Union
22
+
23
+ import numpy as np
24
+
25
+ from ...image_processing_utils import BatchFeature
26
+ from ...image_utils import ImageInput
27
+ from ...processing_utils import MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack
28
+ from ...tokenization_utils import PreTokenizedInput, TextInput
29
+ from ...utils import TensorType
30
+ from ..auto import AutoTokenizer
31
+
32
+
33
+ class AriaProcessorKwargs(ProcessingKwargs, total=False):
34
+ _defaults = {
35
+ "text_kwargs": {
36
+ "padding": False,
37
+ "return_mm_token_type_ids": False,
38
+ },
39
+ "images_kwargs": {
40
+ "max_image_size": 980,
41
+ "split_image": False,
42
+ },
43
+ "return_tensors": TensorType.PYTORCH,
44
+ }
45
+
46
+
47
+ class AriaProcessor(ProcessorMixin):
48
+ """
49
+ AriaProcessor is a processor for the Aria model which wraps the Aria image preprocessor and the Llama slow tokenizer.
50
+
51
+ Args:
52
+ image_processor (`AriaImageProcessor`, *optional*):
53
+ The AriaImageProcessor to use for image preprocessing.
54
+ tokenizer (`PreTrainedTokenizerBase`, *optional*):
55
+ An instance of [`PreTrainedTokenizerBase`]. This should correspond with the model's text model. The tokenizer is a required input.
56
+ chat_template (`str`, *optional*):
57
+ A Jinja template which will be used to convert lists of messages in a chat into a tokenizable string.
58
+ size_conversion (`Dict`, *optional*):
59
+ A dictionary indicating size conversions for images.
60
+ """
61
+
62
+ attributes = ["image_processor", "tokenizer"]
63
+ image_processor_class = "AriaImageProcessor"
64
+ tokenizer_class = "AutoTokenizer"
65
+
66
+ def __init__(
67
+ self,
68
+ image_processor=None,
69
+ tokenizer: Union[AutoTokenizer, str] = None,
70
+ chat_template: Optional[str] = None,
71
+ size_conversion: Optional[dict[Union[float, int], int]] = None,
72
+ ):
73
+ if size_conversion is None:
74
+ size_conversion = {490: 128, 980: 256}
75
+ self.size_conversion = {int(k): v for k, v in size_conversion.items()}
76
+
77
+ self.image_token = tokenizer.image_token
78
+ self.image_token_id = tokenizer.image_token_id
79
+ if tokenizer is not None and tokenizer.pad_token is None:
80
+ tokenizer.pad_token = tokenizer.unk_token
81
+
82
+ super().__init__(image_processor, tokenizer, chat_template=chat_template)
83
+
84
+ def __call__(
85
+ self,
86
+ text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]],
87
+ images: Optional[ImageInput] = None,
88
+ audio=None,
89
+ videos=None,
90
+ **kwargs: Unpack[AriaProcessorKwargs],
91
+ ) -> BatchFeature:
92
+ """
93
+ Main method to prepare one or several sequence(s) and image(s) for the model.
94
+
95
+ Args:
96
+ text (`TextInput`, `PreTokenizedInput`, `list[TextInput]`, `list[PreTokenizedInput]`):
97
+ The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
98
+ (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
99
+ `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
100
+ images (`ImageInput`):
101
+ The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
102
+ tensor. Both channels-first and channels-last formats are supported.
103
+
104
+
105
+ Returns:
106
+ [`BatchFeature`]: A [`BatchFeature`] with the following fields:
107
+ - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
108
+ - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
109
+ `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
110
+ `None`).
111
+ - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
112
+ - **pixel_mask** -- Pixel mask to be fed to a model. Returned when `images` is not `None`.
113
+ """
114
+ output_kwargs = self._merge_kwargs(
115
+ AriaProcessorKwargs,
116
+ tokenizer_init_kwargs=self.tokenizer.init_kwargs,
117
+ **kwargs,
118
+ )
119
+
120
+ if isinstance(text, str):
121
+ text = [text]
122
+ elif not isinstance(text, list) or not isinstance(text[0], str):
123
+ raise TypeError("Invalid input text. Please provide a string, or a list of strings")
124
+
125
+ if images is not None:
126
+ image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
127
+ # expand the image_token according to the num_crops and tokens per image
128
+ tokens_per_image = self.size_conversion[image_inputs.pixel_values.shape[2]]
129
+ prompt_strings = []
130
+ num_crops = image_inputs.pop("num_crops") * tokens_per_image
131
+ for sample in text:
132
+ sample = sample.replace(self.tokenizer.image_token, self.tokenizer.image_token * num_crops)
133
+ prompt_strings.append(sample)
134
+
135
+ else:
136
+ image_inputs = {}
137
+ prompt_strings = text
138
+
139
+ return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
140
+ return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False)
141
+ text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"], return_tensors=None)
142
+ self._check_special_mm_tokens(prompt_strings, text_inputs, modalities=["image"])
143
+
144
+ if return_mm_token_type_ids:
145
+ array_ids = np.array(text_inputs["input_ids"])
146
+ mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
147
+ mm_token_type_ids[array_ids == self.image_token_id] = 1
148
+ text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()
149
+
150
+ return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors)
151
+
152
+ def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
153
+ """
154
+ Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
155
+ Args:
156
+ image_sizes (`list[list[int]]`, *optional*):
157
+ The input sizes formatted as (height, width) per each image.
158
+ Returns:
159
+ `MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
160
+ input modalities, along with other useful data.
161
+ """
162
+
163
+ vision_data = {}
164
+ if image_sizes is not None:
165
+ images_kwargs = AriaProcessorKwargs._defaults.get("images_kwargs", {})
166
+ images_kwargs.update(kwargs)
167
+
168
+ max_size = images_kwargs.get("max_image_size", None) or self.image_processor.max_image_size
169
+ num_image_patches = [
170
+ self.image_processor.get_number_of_image_patches(*image_size, images_kwargs)
171
+ for image_size in image_sizes
172
+ ]
173
+ num_image_tokens = [self.size_conversion[max_size] * num_patches for num_patches in num_image_patches]
174
+ vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
175
+
176
+ return MultiModalData(**vision_data)
177
+
178
+ @property
179
+ def model_input_names(self):
180
+ tokenizer_input_names = self.tokenizer.model_input_names
181
+ image_processor_input_names = self.image_processor.model_input_names
182
+
183
+ # Remove `num_crops`, it is popped and used only when processing. Make a copy of list when removing
184
+ # otherwise `self.image_processor.model_input_names` is also modified
185
+ image_processor_input_names = [name for name in image_processor_input_names if name != "num_crops"]
186
+ return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
187
+
188
+
189
+ __all__ = ["AriaProcessor"]
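The image-token expansion in `AriaProcessor.__call__` is a simple string multiplication: each placeholder in the prompt becomes `num_crops * tokens_per_image` copies, where `tokens_per_image` comes from `size_conversion` (128 tokens at a 490 px edge, 256 at 980 px). A sketch of just that arithmetic, with made-up values and a hypothetical placeholder string:

```python
size_conversion = {490: 128, 980: 256}
image_token = "<image>"   # hypothetical placeholder string for this sketch

max_image_size = 980      # edge length produced by the image processor
num_crops = 2             # e.g. the image was split into two crops
tokens_per_image = size_conversion[max_image_size]  # 256

prompt = f"Describe this: {image_token}"
expanded = prompt.replace(image_token, image_token * (num_crops * tokens_per_image))
print(expanded.count(image_token))  # 512 placeholder tokens handed to the tokenizer
```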
venv/lib/python3.13/site-packages/transformers/models/auto/__init__.py ADDED
@@ -0,0 +1,35 @@
1
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import TYPE_CHECKING
15
+
16
+ from ...utils import _LazyModule
17
+ from ...utils.import_utils import define_import_structure
18
+
19
+
20
+ if TYPE_CHECKING:
21
+ from .auto_factory import *
22
+ from .configuration_auto import *
23
+ from .feature_extraction_auto import *
24
+ from .image_processing_auto import *
25
+ from .modeling_auto import *
26
+ from .modeling_flax_auto import *
27
+ from .modeling_tf_auto import *
28
+ from .processing_auto import *
29
+ from .tokenization_auto import *
30
+ from .video_processing_auto import *
31
+ else:
32
+ import sys
33
+
34
+ _file = globals()["__file__"]
35
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
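The `__init__.py` above registers the package as a `_LazyModule`, so heavy submodules such as `modeling_auto` are only imported when one of their attributes is first accessed. A generic sketch of the same deferred-import idea (this is not the actual `_LazyModule` implementation, just the pattern):

```python
import importlib
import types


class LazyModule(types.ModuleType):
    """Minimal lazy loader: maps attribute names to submodules and imports them on first access."""

    def __init__(self, name, attr_to_submodule):
        super().__init__(name)
        self._attr_to_submodule = attr_to_submodule

    def __getattr__(self, attr):
        submodule = self._attr_to_submodule.get(attr)
        if submodule is None:
            raise AttributeError(f"module {self.__name__!r} has no attribute {attr!r}")
        value = getattr(importlib.import_module(submodule), attr)
        setattr(self, attr, value)  # cache so later lookups are plain attribute access
        return value


# Example: the json module is only imported when `dumps` is first touched.
lazy = LazyModule("demo", {"dumps": "json"})
print(lazy.dumps({"a": 1}))  # {"a": 1}
```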
venv/lib/python3.13/site-packages/transformers/models/auto/auto_factory.py ADDED
@@ -0,0 +1,882 @@
1
+ # coding=utf-8
2
+ # Copyright 2021 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Factory function to build auto-model classes."""
16
+
17
+ import copy
18
+ import importlib
19
+ import json
20
+ import os
21
+ import warnings
22
+ from collections import OrderedDict
23
+ from collections.abc import Iterator
24
+ from typing import Any, TypeVar, Union
25
+
26
+ from ...configuration_utils import PretrainedConfig
27
+ from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
28
+ from ...utils import (
29
+ CONFIG_NAME,
30
+ cached_file,
31
+ copy_func,
32
+ extract_commit_hash,
33
+ find_adapter_config_file,
34
+ is_peft_available,
35
+ is_torch_available,
36
+ logging,
37
+ requires_backends,
38
+ )
39
+ from .configuration_auto import AutoConfig, model_type_to_module_name, replace_list_option_in_docstrings
40
+
41
+
42
+ if is_torch_available():
43
+ from ...generation import GenerationMixin
44
+
45
+
46
+ logger = logging.get_logger(__name__)
47
+
48
+ _T = TypeVar("_T")
49
+ # Tokenizers depend on the packages installed; there is too much variance and no common base class or Protocol
50
+ _LazyAutoMappingValue = tuple[Union[type[Any], None], Union[type[Any], None]]
51
+
52
+ CLASS_DOCSTRING = """
53
+ This is a generic model class that will be instantiated as one of the model classes of the library when created
54
+ with the [`~BaseAutoModelClass.from_pretrained`] class method or the [`~BaseAutoModelClass.from_config`] class
55
+ method.
56
+
57
+ This class cannot be instantiated directly using `__init__()` (throws an error).
58
+ """
59
+
60
+ FROM_CONFIG_DOCSTRING = """
61
+ Instantiates one of the model classes of the library from a configuration.
62
+
63
+ Note:
64
+ Loading a model from its configuration file does **not** load the model weights. It only affects the
65
+ model's configuration. Use [`~BaseAutoModelClass.from_pretrained`] to load the model weights.
66
+
67
+ Args:
68
+ config ([`PretrainedConfig`]):
69
+ The model class to instantiate is selected based on the configuration class:
70
+
71
+ List options
72
+ attn_implementation (`str`, *optional*):
73
+ The attention implementation to use in the model (if relevant). Can be any of `"eager"` (manual implementation of the attention), `"sdpa"` (using [`F.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html)), or `"flash_attention_2"` (using [Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention)). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual `"eager"` implementation.
74
+
75
+ Examples:
76
+
77
+ ```python
78
+ >>> from transformers import AutoConfig, BaseAutoModelClass
79
+
80
+ >>> # Download configuration from huggingface.co and cache.
81
+ >>> config = AutoConfig.from_pretrained("checkpoint_placeholder")
82
+ >>> model = BaseAutoModelClass.from_config(config)
83
+ ```
84
+ """
85
+
86
+ FROM_PRETRAINED_TORCH_DOCSTRING = """
87
+ Instantiate one of the model classes of the library from a pretrained model.
88
+
89
+ The model class to instantiate is selected based on the `model_type` property of the config object (either
90
+ passed as an argument or loaded from `pretrained_model_name_or_path` if possible), or when it's missing, by
91
+ falling back to using pattern matching on `pretrained_model_name_or_path`:
92
+
93
+ List options
94
+
95
+ The model is set in evaluation mode by default using `model.eval()` (so for instance, dropout modules are
96
+ deactivated). To train the model, you should first set it back in training mode with `model.train()`
97
+
98
+ Args:
99
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
100
+ Can be either:
101
+
102
+ - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
103
+ - A path to a *directory* containing model weights saved using
104
+ [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
105
+ - A path or url to a *tensorflow index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In
106
+ this case, `from_tf` should be set to `True` and a configuration object should be provided as
107
+ `config` argument. This loading path is slower than converting the TensorFlow checkpoint in a
108
+ PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
109
+ model_args (additional positional arguments, *optional*):
110
+ Will be passed along to the underlying model `__init__()` method.
111
+ config ([`PretrainedConfig`], *optional*):
112
+ Configuration for the model to use instead of an automatically loaded configuration. Configuration can
113
+ be automatically loaded when:
114
+
115
+ - The model is a model provided by the library (loaded with the *model id* string of a pretrained
116
+ model).
117
+ - The model was saved using [`~PreTrainedModel.save_pretrained`] and is reloaded by supplying the
118
+ save directory.
119
+ - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a
120
+ configuration JSON file named *config.json* is found in the directory.
121
+ state_dict (*dict[str, torch.Tensor]*, *optional*):
122
+ A state dictionary to use instead of a state dictionary loaded from saved weights file.
123
+
124
+ This option can be used if you want to create a model from a pretrained configuration but load your own
125
+ weights. In this case though, you should check if using [`~PreTrainedModel.save_pretrained`] and
126
+ [`~PreTrainedModel.from_pretrained`] is not a simpler option.
127
+ cache_dir (`str` or `os.PathLike`, *optional*):
128
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the
129
+ standard cache should not be used.
130
+ from_tf (`bool`, *optional*, defaults to `False`):
131
+ Load the model weights from a TensorFlow checkpoint save file (see docstring of
132
+ `pretrained_model_name_or_path` argument).
133
+ force_download (`bool`, *optional*, defaults to `False`):
134
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
135
+ cached versions if they exist.
136
+ resume_download:
137
+ Deprecated and ignored. All downloads are now resumed by default when possible.
138
+ Will be removed in v5 of Transformers.
139
+ proxies (`dict[str, str]`, *optional*):
140
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
141
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
142
+ output_loading_info(`bool`, *optional*, defaults to `False`):
143
+ Whether ot not to also return a dictionary containing missing keys, unexpected keys and error messages.
144
+ local_files_only(`bool`, *optional*, defaults to `False`):
145
+ Whether or not to only look at local files (e.g., not try downloading the model).
146
+ revision (`str`, *optional*, defaults to `"main"`):
147
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
148
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
149
+ identifier allowed by git.
150
+ trust_remote_code (`bool`, *optional*, defaults to `False`):
151
+ Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
152
+ should only be set to `True` for repositories you trust and in which you have read the code, as it will
153
+ execute code present on the Hub on your local machine.
154
+ code_revision (`str`, *optional*, defaults to `"main"`):
155
+ The specific revision to use for the code on the Hub, if the code leaves in a different repository than
156
+ the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based
157
+ system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier
158
+ allowed by git.
159
+ kwargs (additional keyword arguments, *optional*):
160
+ Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
161
+ `output_attentions=True`). Behaves differently depending on whether a `config` is provided or
162
+ automatically loaded:
163
+
164
+ - If a configuration is provided with `config`, `**kwargs` will be directly passed to the
165
+ underlying model's `__init__` method (we assume all relevant updates to the configuration have
166
+ already been done)
167
+ - If a configuration is not provided, `kwargs` will be first passed to the configuration class
168
+ initialization function ([`~PretrainedConfig.from_pretrained`]). Each key of `kwargs` that
169
+ corresponds to a configuration attribute will be used to override said attribute with the
170
+ supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute
171
+ will be passed to the underlying model's `__init__` function.
172
+
173
+ Examples:
174
+
175
+ ```python
176
+ >>> from transformers import AutoConfig, BaseAutoModelClass
177
+
178
+ >>> # Download model and configuration from huggingface.co and cache.
179
+ >>> model = BaseAutoModelClass.from_pretrained("checkpoint_placeholder")
180
+
181
+ >>> # Update configuration during loading
182
+ >>> model = BaseAutoModelClass.from_pretrained("checkpoint_placeholder", output_attentions=True)
183
+ >>> model.config.output_attentions
184
+ True
185
+
186
+ >>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
187
+ >>> config = AutoConfig.from_pretrained("./tf_model/shortcut_placeholder_tf_model_config.json")
188
+ >>> model = BaseAutoModelClass.from_pretrained(
189
+ ... "./tf_model/shortcut_placeholder_tf_checkpoint.ckpt.index", from_tf=True, config=config
190
+ ... )
191
+ ```
192
+ """
193
+
194
+ FROM_PRETRAINED_TF_DOCSTRING = """
195
+ Instantiate one of the model classes of the library from a pretrained model.
196
+
197
+ The model class to instantiate is selected based on the `model_type` property of the config object (either
198
+ passed as an argument or loaded from `pretrained_model_name_or_path` if possible), or when it's missing, by
199
+ falling back to using pattern matching on `pretrained_model_name_or_path`:
200
+
201
+ List options
202
+
203
+ Args:
204
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
205
+ Can be either:
206
+
207
+ - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
208
+ - A path to a *directory* containing model weights saved using
209
+ [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
210
+ - A path or url to a *PyTorch state_dict save file* (e.g, `./pt_model/pytorch_model.bin`). In this
211
+ case, `from_pt` should be set to `True` and a configuration object should be provided as `config`
212
+ argument. This loading path is slower than converting the PyTorch model in a TensorFlow model
213
+ using the provided conversion scripts and loading the TensorFlow model afterwards.
214
+ model_args (additional positional arguments, *optional*):
215
+ Will be passed along to the underlying model `__init__()` method.
216
+ config ([`PretrainedConfig`], *optional*):
217
+ Configuration for the model to use instead of an automatically loaded configuration. Configuration can
218
+ be automatically loaded when:
219
+
220
+ - The model is a model provided by the library (loaded with the *model id* string of a pretrained
221
+ model).
222
+ - The model was saved using [`~PreTrainedModel.save_pretrained`] and is reloaded by supplying the
223
+ save directory.
224
+ - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a
225
+ configuration JSON file named *config.json* is found in the directory.
226
+ cache_dir (`str` or `os.PathLike`, *optional*):
227
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the
228
+ standard cache should not be used.
229
+ from_pt (`bool`, *optional*, defaults to `False`):
230
+ Load the model weights from a PyTorch checkpoint save file (see docstring of
231
+ `pretrained_model_name_or_path` argument).
232
+ force_download (`bool`, *optional*, defaults to `False`):
233
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
234
+ cached versions if they exist.
235
+ resume_download:
236
+ Deprecated and ignored. All downloads are now resumed by default when possible.
237
+ Will be removed in v5 of Transformers.
238
+ proxies (`dict[str, str]`, *optional*):
239
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
240
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
241
+ output_loading_info(`bool`, *optional*, defaults to `False`):
242
+ Whether ot not to also return a dictionary containing missing keys, unexpected keys and error messages.
243
+ local_files_only(`bool`, *optional*, defaults to `False`):
244
+ Whether or not to only look at local files (e.g., not try downloading the model).
245
+ revision (`str`, *optional*, defaults to `"main"`):
246
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
247
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
248
+ identifier allowed by git.
249
+ trust_remote_code (`bool`, *optional*, defaults to `False`):
250
+ Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
251
+ should only be set to `True` for repositories you trust and in which you have read the code, as it will
252
+ execute code present on the Hub on your local machine.
253
+ code_revision (`str`, *optional*, defaults to `"main"`):
254
+ The specific revision to use for the code on the Hub, if the code leaves in a different repository than
255
+ the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based
256
+ system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier
257
+ allowed by git.
258
+ kwargs (additional keyword arguments, *optional*):
259
+ Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
260
+ `output_attentions=True`). Behaves differently depending on whether a `config` is provided or
261
+ automatically loaded:
262
+
263
+ - If a configuration is provided with `config`, `**kwargs` will be directly passed to the
264
+ underlying model's `__init__` method (we assume all relevant updates to the configuration have
265
+ already been done)
266
+ - If a configuration is not provided, `kwargs` will be first passed to the configuration class
267
+ initialization function ([`~PretrainedConfig.from_pretrained`]). Each key of `kwargs` that
268
+ corresponds to a configuration attribute will be used to override said attribute with the
269
+ supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute
270
+ will be passed to the underlying model's `__init__` function.
271
+
272
+ Examples:
273
+
274
+ ```python
275
+ >>> from transformers import AutoConfig, BaseAutoModelClass
276
+
277
+ >>> # Download model and configuration from huggingface.co and cache.
278
+ >>> model = BaseAutoModelClass.from_pretrained("checkpoint_placeholder")
279
+
280
+ >>> # Update configuration during loading
281
+ >>> model = BaseAutoModelClass.from_pretrained("checkpoint_placeholder", output_attentions=True)
282
+ >>> model.config.output_attentions
283
+ True
284
+
285
+ >>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
286
+ >>> config = AutoConfig.from_pretrained("./pt_model/shortcut_placeholder_pt_model_config.json")
287
+ >>> model = BaseAutoModelClass.from_pretrained(
288
+ ... "./pt_model/shortcut_placeholder_pytorch_model.bin", from_pt=True, config=config
289
+ ... )
290
+ ```
291
+ """
292
+
293
+ FROM_PRETRAINED_FLAX_DOCSTRING = """
294
+ Instantiate one of the model classes of the library from a pretrained model.
295
+
296
+ The model class to instantiate is selected based on the `model_type` property of the config object (either
297
+ passed as an argument or loaded from `pretrained_model_name_or_path` if possible), or when it's missing, by
298
+ falling back to using pattern matching on `pretrained_model_name_or_path`:
299
+
300
+ List options
301
+
302
+ Args:
303
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
304
+ Can be either:
305
+
306
+ - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
307
+ - A path to a *directory* containing model weights saved using
308
+ [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
309
+ - A path or url to a *PyTorch state_dict save file* (e.g, `./pt_model/pytorch_model.bin`). In this
310
+ case, `from_pt` should be set to `True` and a configuration object should be provided as `config`
311
+ argument. This loading path is slower than converting the PyTorch model in a TensorFlow model
312
+ using the provided conversion scripts and loading the TensorFlow model afterwards.
313
+ model_args (additional positional arguments, *optional*):
314
+ Will be passed along to the underlying model `__init__()` method.
315
+ config ([`PretrainedConfig`], *optional*):
316
+ Configuration for the model to use instead of an automatically loaded configuration. Configuration can
317
+ be automatically loaded when:
318
+
319
+ - The model is a model provided by the library (loaded with the *model id* string of a pretrained
320
+ model).
321
+ - The model was saved using [`~PreTrainedModel.save_pretrained`] and is reloaded by supplying the
322
+ save directory.
323
+ - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a
324
+ configuration JSON file named *config.json* is found in the directory.
325
+ cache_dir (`str` or `os.PathLike`, *optional*):
326
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the
327
+ standard cache should not be used.
328
+ from_pt (`bool`, *optional*, defaults to `False`):
329
+ Load the model weights from a PyTorch checkpoint save file (see docstring of
330
+ `pretrained_model_name_or_path` argument).
331
+ force_download (`bool`, *optional*, defaults to `False`):
332
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
333
+ cached versions if they exist.
334
+ resume_download:
335
+ Deprecated and ignored. All downloads are now resumed by default when possible.
336
+ Will be removed in v5 of Transformers.
337
+ proxies (`dict[str, str]`, *optional*):
338
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
339
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
340
+ output_loading_info(`bool`, *optional*, defaults to `False`):
341
+ Whether ot not to also return a dictionary containing missing keys, unexpected keys and error messages.
342
+ local_files_only(`bool`, *optional*, defaults to `False`):
343
+ Whether or not to only look at local files (e.g., not try downloading the model).
344
+ revision (`str`, *optional*, defaults to `"main"`):
345
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
346
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
347
+ identifier allowed by git.
348
+ trust_remote_code (`bool`, *optional*, defaults to `False`):
349
+ Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
350
+ should only be set to `True` for repositories you trust and in which you have read the code, as it will
351
+ execute code present on the Hub on your local machine.
352
+ code_revision (`str`, *optional*, defaults to `"main"`):
353
+ The specific revision to use for the code on the Hub, if the code leaves in a different repository than
354
+ the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based
355
+ system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier
356
+ allowed by git.
357
+ kwargs (additional keyword arguments, *optional*):
358
+ Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
359
+ `output_attentions=True`). Behaves differently depending on whether a `config` is provided or
360
+ automatically loaded:
361
+
362
+ - If a configuration is provided with `config`, `**kwargs` will be directly passed to the
363
+ underlying model's `__init__` method (we assume all relevant updates to the configuration have
364
+ already been done)
365
+ - If a configuration is not provided, `kwargs` will be first passed to the configuration class
366
+ initialization function ([`~PretrainedConfig.from_pretrained`]). Each key of `kwargs` that
367
+ corresponds to a configuration attribute will be used to override said attribute with the
368
+ supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute
369
+ will be passed to the underlying model's `__init__` function.
370
+
371
+ Examples:
372
+
373
+ ```python
374
+ >>> from transformers import AutoConfig, BaseAutoModelClass
375
+
376
+ >>> # Download model and configuration from huggingface.co and cache.
377
+ >>> model = BaseAutoModelClass.from_pretrained("checkpoint_placeholder")
378
+
379
+ >>> # Update configuration during loading
380
+ >>> model = BaseAutoModelClass.from_pretrained("checkpoint_placeholder", output_attentions=True)
381
+ >>> model.config.output_attentions
382
+ True
383
+
384
+ >>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
385
+ >>> config = AutoConfig.from_pretrained("./pt_model/shortcut_placeholder_pt_model_config.json")
386
+ >>> model = BaseAutoModelClass.from_pretrained(
387
+ ... "./pt_model/shortcut_placeholder_pytorch_model.bin", from_pt=True, config=config
388
+ ... )
389
+ ```
390
+ """
391
+
392
+
393
+ def _get_model_class(config, model_mapping):
394
+ supported_models = model_mapping[type(config)]
395
+ if not isinstance(supported_models, (list, tuple)):
396
+ return supported_models
397
+
398
+ name_to_model = {model.__name__: model for model in supported_models}
399
+ architectures = getattr(config, "architectures", [])
400
+ for arch in architectures:
401
+ if arch in name_to_model:
402
+ return name_to_model[arch]
403
+ elif f"TF{arch}" in name_to_model:
404
+ return name_to_model[f"TF{arch}"]
405
+ elif f"Flax{arch}" in name_to_model:
406
+ return name_to_model[f"Flax{arch}"]
407
+
408
+ # If not architecture is set in the config or match the supported models, the first element of the tuple is the
409
+ # defaults.
410
+ return supported_models[0]
411
+
412
+
413
+ class _BaseAutoModelClass:
414
+ # Base class for auto models.
415
+ _model_mapping = None
416
+
417
+ def __init__(self, *args, **kwargs) -> None:
418
+ raise OSError(
419
+ f"{self.__class__.__name__} is designed to be instantiated "
420
+ f"using the `{self.__class__.__name__}.from_pretrained(pretrained_model_name_or_path)` or "
421
+ f"`{self.__class__.__name__}.from_config(config)` methods."
422
+ )
423
+
424
+ @classmethod
425
+ def from_config(cls, config, **kwargs):
426
+ trust_remote_code = kwargs.pop("trust_remote_code", None)
427
+ has_remote_code = hasattr(config, "auto_map") and cls.__name__ in config.auto_map
428
+ has_local_code = type(config) in cls._model_mapping
429
+ if has_remote_code:
430
+ class_ref = config.auto_map[cls.__name__]
431
+ if "--" in class_ref:
432
+ upstream_repo = class_ref.split("--")[0]
433
+ else:
434
+ upstream_repo = None
435
+ trust_remote_code = resolve_trust_remote_code(
436
+ trust_remote_code, config._name_or_path, has_local_code, has_remote_code, upstream_repo=upstream_repo
437
+ )
438
+
439
+ if has_remote_code and trust_remote_code:
440
+ if "--" in class_ref:
441
+ repo_id, class_ref = class_ref.split("--")
442
+ else:
443
+ repo_id = config.name_or_path
444
+ model_class = get_class_from_dynamic_module(class_ref, repo_id, **kwargs)
445
+ # This block handles the case where the user is loading a model with `trust_remote_code=True`
446
+ # but a library model exists with the same name. We don't want to override the autoclass
447
+ # mappings in this case, or all future loads of that model will be the remote code model.
448
+ if not has_local_code:
449
+ cls.register(config.__class__, model_class, exist_ok=True)
450
+ model_class.register_for_auto_class(auto_class=cls)
451
+ _ = kwargs.pop("code_revision", None)
452
+ model_class = add_generation_mixin_to_remote_model(model_class)
453
+ return model_class._from_config(config, **kwargs)
454
+ elif type(config) in cls._model_mapping:
455
+ model_class = _get_model_class(config, cls._model_mapping)
456
+ return model_class._from_config(config, **kwargs)
457
+
458
+ raise ValueError(
459
+ f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
460
+ f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping)}."
461
+ )
462
+
463
+ @classmethod
464
+ def _prepare_config_for_auto_class(cls, config: PretrainedConfig) -> PretrainedConfig:
465
+ """Additional autoclass-specific config post-loading manipulation. May be overridden in subclasses."""
466
+ return config
467
+
468
+ @classmethod
469
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike[str]], *model_args, **kwargs):
470
+ config = kwargs.pop("config", None)
471
+ trust_remote_code = kwargs.get("trust_remote_code")
472
+ kwargs["_from_auto"] = True
473
+ hub_kwargs_names = [
474
+ "cache_dir",
475
+ "force_download",
476
+ "local_files_only",
477
+ "proxies",
478
+ "resume_download",
479
+ "revision",
480
+ "subfolder",
481
+ "use_auth_token",
482
+ "token",
483
+ ]
484
+ hub_kwargs = {name: kwargs.pop(name) for name in hub_kwargs_names if name in kwargs}
485
+ code_revision = kwargs.pop("code_revision", None)
486
+ commit_hash = kwargs.pop("_commit_hash", None)
487
+ adapter_kwargs = kwargs.pop("adapter_kwargs", None)
488
+
489
+ token = hub_kwargs.pop("token", None)
490
+ use_auth_token = hub_kwargs.pop("use_auth_token", None)
491
+ if use_auth_token is not None:
492
+ warnings.warn(
493
+ "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
494
+ FutureWarning,
495
+ )
496
+ if token is not None:
497
+ raise ValueError(
498
+ "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
499
+ )
500
+ token = use_auth_token
501
+
502
+ if token is not None:
503
+ hub_kwargs["token"] = token
504
+
505
+ if commit_hash is None:
506
+ if not isinstance(config, PretrainedConfig):
507
+ # We make a call to the config file first (which may be absent) to get the commit hash as soon as possible
508
+ resolved_config_file = cached_file(
509
+ pretrained_model_name_or_path,
510
+ CONFIG_NAME,
511
+ _raise_exceptions_for_gated_repo=False,
512
+ _raise_exceptions_for_missing_entries=False,
513
+ _raise_exceptions_for_connection_errors=False,
514
+ **hub_kwargs,
515
+ )
516
+ commit_hash = extract_commit_hash(resolved_config_file, commit_hash)
517
+ else:
518
+ commit_hash = getattr(config, "_commit_hash", None)
519
+
520
+ if is_peft_available():
521
+ if adapter_kwargs is None:
522
+ adapter_kwargs = {}
523
+ if token is not None:
524
+ adapter_kwargs["token"] = token
525
+
526
+ maybe_adapter_path = find_adapter_config_file(
527
+ pretrained_model_name_or_path, _commit_hash=commit_hash, **adapter_kwargs
528
+ )
529
+
530
+ if maybe_adapter_path is not None:
531
+ with open(maybe_adapter_path, "r", encoding="utf-8") as f:
532
+ adapter_config = json.load(f)
533
+
534
+ adapter_kwargs["_adapter_model_path"] = pretrained_model_name_or_path
535
+ pretrained_model_name_or_path = adapter_config["base_model_name_or_path"]
536
+
537
+ if not isinstance(config, PretrainedConfig):
538
+ kwargs_orig = copy.deepcopy(kwargs)
539
+ # make sure not to pollute the config object with dtype="auto", since it is
540
+ # meaningless there: only torch.dtype values are acceptable on a config object
541
+ if kwargs.get("torch_dtype") == "auto":
542
+ _ = kwargs.pop("torch_dtype")
543
+ if kwargs.get("dtype") == "auto":
544
+ _ = kwargs.pop("dtype")
545
+ # avoid overwriting the quantization_config if the config already has one
546
+ if kwargs.get("quantization_config") is not None:
547
+ _ = kwargs.pop("quantization_config")
548
+
549
+ config, kwargs = AutoConfig.from_pretrained(
550
+ pretrained_model_name_or_path,
551
+ return_unused_kwargs=True,
552
+ code_revision=code_revision,
553
+ _commit_hash=commit_hash,
554
+ **hub_kwargs,
555
+ **kwargs,
556
+ )
557
+
558
+ # if torch_dtype="auto" was passed here, make sure to pass it on
559
+ if kwargs_orig.get("torch_dtype", None) == "auto":
560
+ kwargs["torch_dtype"] = "auto"
561
+ if kwargs_orig.get("dtype", None) == "auto":
562
+ kwargs["dtype"] = "auto"
563
+ if kwargs_orig.get("quantization_config", None) is not None:
564
+ kwargs["quantization_config"] = kwargs_orig["quantization_config"]
565
+
566
+ has_remote_code = hasattr(config, "auto_map") and cls.__name__ in config.auto_map
567
+ has_local_code = type(config) in cls._model_mapping
568
+ upstream_repo = None
569
+ if has_remote_code:
570
+ class_ref = config.auto_map[cls.__name__]
571
+ if "--" in class_ref:
572
+ upstream_repo = class_ref.split("--")[0]
573
+ trust_remote_code = resolve_trust_remote_code(
574
+ trust_remote_code,
575
+ pretrained_model_name_or_path,
576
+ has_local_code,
577
+ has_remote_code,
578
+ upstream_repo=upstream_repo,
579
+ )
580
+ kwargs["trust_remote_code"] = trust_remote_code
581
+
582
+ # Set the adapter kwargs
583
+ kwargs["adapter_kwargs"] = adapter_kwargs
584
+
585
+ if has_remote_code and trust_remote_code:
586
+ model_class = get_class_from_dynamic_module(
587
+ class_ref, pretrained_model_name_or_path, code_revision=code_revision, **hub_kwargs, **kwargs
588
+ )
589
+ _ = hub_kwargs.pop("code_revision", None)
590
+ # This block handles the case where the user is loading a model with `trust_remote_code=True`
591
+ # but a library model exists with the same name. We don't want to override the autoclass
592
+ # mappings in this case, or all future loads of that model will be the remote code model.
593
+ if not has_local_code:
594
+ cls.register(config.__class__, model_class, exist_ok=True)
595
+ model_class.register_for_auto_class(auto_class=cls)
596
+ model_class = add_generation_mixin_to_remote_model(model_class)
597
+ return model_class.from_pretrained(
598
+ pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs
599
+ )
600
+ elif type(config) in cls._model_mapping:
601
+ model_class = _get_model_class(config, cls._model_mapping)
602
+ if model_class.config_class == config.sub_configs.get("text_config", None):
603
+ config = config.get_text_config()
604
+ return model_class.from_pretrained(
605
+ pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs
606
+ )
607
+ raise ValueError(
608
+ f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
609
+ f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping)}."
610
+ )
611
+
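
For reference, a minimal sketch of how the remote-code path above is exercised from user code; the repository name below is hypothetical, and `trust_remote_code=True` is what opts in to executing the modeling code referenced by the `auto_map` entry of the checkpoint's config.json:

from transformers import AutoModelForCausalLM

# Hypothetical repo whose config.json carries an `auto_map` entry such as
# {"AutoModelForCausalLM": "modeling_custom.CustomForCausalLM"}; passing
# trust_remote_code=True lets the loader fetch and run that class.
model = AutoModelForCausalLM.from_pretrained(
    "some-org/model-with-custom-code",
    trust_remote_code=True,
)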
612
+ @classmethod
613
+ def register(cls, config_class, model_class, exist_ok=False) -> None:
614
+ """
615
+ Register a new model for this class.
616
+
617
+ Args:
618
+ config_class ([`PretrainedConfig`]):
619
+ The configuration corresponding to the model to register.
620
+ model_class ([`PreTrainedModel`]):
621
+ The model to register.
622
+ """
623
+ if hasattr(model_class, "config_class") and model_class.config_class.__name__ != config_class.__name__:
624
+ raise ValueError(
625
+ "The model class you are passing has a `config_class` attribute that is not consistent with the "
626
+ f"config class you passed (model has {model_class.config_class} and you passed {config_class}. Fix "
627
+ "one of those so they match!"
628
+ )
629
+ cls._model_mapping.register(config_class, model_class, exist_ok=exist_ok)
630
+
631
+
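
A minimal sketch of the registration hook defined above, assuming a user-defined config/model pair (`MyConfig` and `MyModel` are made-up names); once registered, the generic Auto entry points dispatch to the custom classes:

import torch.nn as nn
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel

class MyConfig(PretrainedConfig):
    model_type = "my-model"

class MyModel(PreTrainedModel):
    config_class = MyConfig

    def __init__(self, config):
        super().__init__(config)
        self.linear = nn.Linear(4, 4)

AutoConfig.register("my-model", MyConfig)  # adds to the lazy config mapping
AutoModel.register(MyConfig, MyModel)      # adds to this class's _model_mapping

model = AutoModel.from_config(MyConfig())  # resolves through the new entries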
632
+ class _BaseAutoBackboneClass(_BaseAutoModelClass):
633
+ # Base class for auto backbone models.
634
+ _model_mapping = None
635
+
636
+ @classmethod
637
+ def _load_timm_backbone_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
638
+ requires_backends(cls, ["vision", "timm"])
639
+ from ...models.timm_backbone import TimmBackboneConfig
640
+
641
+ config = kwargs.pop("config", TimmBackboneConfig())
642
+
643
+ if kwargs.get("out_features") is not None:
644
+ raise ValueError("Cannot specify `out_features` for timm backbones")
645
+
646
+ if kwargs.get("output_loading_info", False):
647
+ raise ValueError("Cannot specify `output_loading_info=True` when loading from timm")
648
+
649
+ num_channels = kwargs.pop("num_channels", config.num_channels)
650
+ features_only = kwargs.pop("features_only", config.features_only)
651
+ use_pretrained_backbone = kwargs.pop("use_pretrained_backbone", config.use_pretrained_backbone)
652
+ out_indices = kwargs.pop("out_indices", config.out_indices)
653
+ config = TimmBackboneConfig(
654
+ backbone=pretrained_model_name_or_path,
655
+ num_channels=num_channels,
656
+ features_only=features_only,
657
+ use_pretrained_backbone=use_pretrained_backbone,
658
+ out_indices=out_indices,
659
+ )
660
+ return super().from_config(config, **kwargs)
661
+
662
+ @classmethod
663
+ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
664
+ use_timm_backbone = kwargs.pop("use_timm_backbone", False)
665
+ if use_timm_backbone:
666
+ return cls._load_timm_backbone_from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
667
+
668
+ return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
669
+
670
+
671
+ def insert_head_doc(docstring, head_doc: str = ""):
672
+ if len(head_doc) > 0:
673
+ return docstring.replace(
674
+ "one of the model classes of the library ",
675
+ f"one of the model classes of the library (with a {head_doc} head) ",
676
+ )
677
+ return docstring.replace(
678
+ "one of the model classes of the library ", "one of the base model classes of the library "
679
+ )
680
+
681
+
682
+ def auto_class_update(cls, checkpoint_for_example: str = "google-bert/bert-base-cased", head_doc: str = ""):
683
+ # Update the class in place with a docstring and methods derived from its name and the base class
684
+ model_mapping = cls._model_mapping
685
+ name = cls.__name__
686
+ class_docstring = insert_head_doc(CLASS_DOCSTRING, head_doc=head_doc)
687
+ cls.__doc__ = class_docstring.replace("BaseAutoModelClass", name)
688
+
689
+ # Now we need to copy and re-register `from_config` and `from_pretrained` as class methods, otherwise we can't
690
+ # give them specific docstrings.
691
+ from_config = copy_func(_BaseAutoModelClass.from_config)
692
+ from_config_docstring = insert_head_doc(FROM_CONFIG_DOCSTRING, head_doc=head_doc)
693
+ from_config_docstring = from_config_docstring.replace("BaseAutoModelClass", name)
694
+ from_config_docstring = from_config_docstring.replace("checkpoint_placeholder", checkpoint_for_example)
695
+ from_config.__doc__ = from_config_docstring
696
+ from_config = replace_list_option_in_docstrings(model_mapping._model_mapping, use_model_types=False)(from_config)
697
+ cls.from_config = classmethod(from_config)
698
+
699
+ if name.startswith("TF"):
700
+ from_pretrained_docstring = FROM_PRETRAINED_TF_DOCSTRING
701
+ elif name.startswith("Flax"):
702
+ from_pretrained_docstring = FROM_PRETRAINED_FLAX_DOCSTRING
703
+ else:
704
+ from_pretrained_docstring = FROM_PRETRAINED_TORCH_DOCSTRING
705
+ from_pretrained = copy_func(_BaseAutoModelClass.from_pretrained)
706
+ from_pretrained_docstring = insert_head_doc(from_pretrained_docstring, head_doc=head_doc)
707
+ from_pretrained_docstring = from_pretrained_docstring.replace("BaseAutoModelClass", name)
708
+ from_pretrained_docstring = from_pretrained_docstring.replace("checkpoint_placeholder", checkpoint_for_example)
709
+ shortcut = checkpoint_for_example.split("/")[-1].split("-")[0]
710
+ from_pretrained_docstring = from_pretrained_docstring.replace("shortcut_placeholder", shortcut)
711
+ from_pretrained.__doc__ = from_pretrained_docstring
712
+ from_pretrained = replace_list_option_in_docstrings(model_mapping._model_mapping)(from_pretrained)
713
+ cls.from_pretrained = classmethod(from_pretrained)
714
+ return cls
715
+
716
+
717
+ def get_values(model_mapping):
718
+ result = []
719
+ for model in model_mapping.values():
720
+ if isinstance(model, (list, tuple)):
721
+ result += list(model)
722
+ else:
723
+ result.append(model)
724
+
725
+ return result
726
+
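
A quick illustration of the flattening performed by `get_values`, on a plain dict standing in for the lazy model mappings (some of whose entries are tuples of classes):

assert get_values({"a": (1, 2), "b": 3}) == [1, 2, 3]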
727
+
728
+ def getattribute_from_module(module, attr):
729
+ if attr is None:
730
+ return None
731
+ if isinstance(attr, tuple):
732
+ return tuple(getattribute_from_module(module, a) for a in attr)
733
+ if hasattr(module, attr):
734
+ return getattr(module, attr)
735
+ # Some of the mappings have entries model_type -> object of another model type. In that case we try to grab the
736
+ # object at the top level.
737
+ transformers_module = importlib.import_module("transformers")
738
+
739
+ if module != transformers_module:
740
+ try:
741
+ return getattribute_from_module(transformers_module, attr)
742
+ except ValueError:
743
+ raise ValueError(f"Could not find {attr} neither in {module} nor in {transformers_module}!")
744
+ else:
745
+ raise ValueError(f"Could not find {attr} in {transformers_module}!")
746
+
747
+
748
+ def add_generation_mixin_to_remote_model(model_class):
749
+ """
750
+ Adds `GenerationMixin` to the inheritance of `model_class`, if `model_class` is a PyTorch model.
751
+
752
+ This function is used for backwards compatibility purposes: in v4.45, we've started a deprecation cycle to make
753
+ `PreTrainedModel` stop inheriting from `GenerationMixin`. Without this function, older models dynamically loaded
754
+ from the Hub may not have the `generate` method after we remove the inheritance.
755
+ """
756
+ # 1. If it is not a PT model (i.e. doesn't inherit Module), do nothing
757
+ if "torch.nn.modules.module.Module" not in str(model_class.__mro__):
758
+ return model_class
759
+
760
+ # 2. If it already **directly** inherits from GenerationMixin, do nothing
761
+ if "GenerationMixin" in str(model_class.__bases__):
762
+ return model_class
763
+
764
+ # 3. Prior to v4.45, a model could be detected as `generate`-compatible by checking whether it had its own `generate` and/or
765
+ # `prepare_inputs_for_generation` method.
766
+ has_custom_generate_in_class = hasattr(model_class, "generate") and "GenerationMixin" not in str(
767
+ getattr(model_class, "generate")
768
+ )
769
+ has_custom_prepare_inputs = hasattr(model_class, "prepare_inputs_for_generation") and "GenerationMixin" not in str(
770
+ getattr(model_class, "prepare_inputs_for_generation")
771
+ )
772
+ if has_custom_generate_in_class or has_custom_prepare_inputs:
773
+ model_class_with_generation_mixin = type(
774
+ model_class.__name__, (model_class, GenerationMixin), {**model_class.__dict__}
775
+ )
776
+ return model_class_with_generation_mixin
777
+ return model_class
778
+
779
+
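
A toy illustration of the dynamic-subclassing trick used by `add_generation_mixin_to_remote_model`; all names below are invented for the example, and unlike the real helper the toy skips copying the original class namespace:

class GenerationMixinDemo:
    def generate(self):
        return "generated"

class OldRemoteModel:
    pass

# type(name, bases, namespace) builds a class with the same name that also
# inherits from the mixin, so instances gain `generate` without upstream edits.
Patched = type(OldRemoteModel.__name__, (OldRemoteModel, GenerationMixinDemo), {})
assert Patched.__name__ == "OldRemoteModel"
assert Patched().generate() == "generated"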
780
+ class _LazyAutoMapping(OrderedDict[type[PretrainedConfig], _LazyAutoMappingValue]):
781
+ """
782
+ " A mapping config to object (model or tokenizer for instance) that will load keys and values when it is accessed.
783
+
784
+ Args:
785
+ - config_mapping: The mapping from model type to config class
786
+ - model_mapping: The mapping from model type to model (or tokenizer) class
787
+ """
788
+
789
+ def __init__(self, config_mapping, model_mapping) -> None:
790
+ self._config_mapping = config_mapping
791
+ self._reverse_config_mapping = {v: k for k, v in config_mapping.items()}
792
+ self._model_mapping = model_mapping
793
+ self._model_mapping._model_mapping = self
794
+ self._extra_content = {}
795
+ self._modules = {}
796
+
797
+ def __len__(self) -> int:
798
+ common_keys = set(self._config_mapping.keys()).intersection(self._model_mapping.keys())
799
+ return len(common_keys) + len(self._extra_content)
800
+
801
+ def __getitem__(self, key: type[PretrainedConfig]) -> _LazyAutoMappingValue:
802
+ if key in self._extra_content:
803
+ return self._extra_content[key]
804
+ model_type = self._reverse_config_mapping[key.__name__]
805
+ if model_type in self._model_mapping:
806
+ model_name = self._model_mapping[model_type]
807
+ return self._load_attr_from_module(model_type, model_name)
808
+
809
+ # Maybe there were several model types associated with this config.
810
+ model_types = [k for k, v in self._config_mapping.items() if v == key.__name__]
811
+ for mtype in model_types:
812
+ if mtype in self._model_mapping:
813
+ model_name = self._model_mapping[mtype]
814
+ return self._load_attr_from_module(mtype, model_name)
815
+ raise KeyError(key)
816
+
817
+ def _load_attr_from_module(self, model_type, attr):
818
+ module_name = model_type_to_module_name(model_type)
819
+ if module_name not in self._modules:
820
+ self._modules[module_name] = importlib.import_module(f".{module_name}", "transformers.models")
821
+ return getattribute_from_module(self._modules[module_name], attr)
822
+
823
+ def keys(self) -> list[type[PretrainedConfig]]:
824
+ mapping_keys = [
825
+ self._load_attr_from_module(key, name)
826
+ for key, name in self._config_mapping.items()
827
+ if key in self._model_mapping
828
+ ]
829
+ return mapping_keys + list(self._extra_content.keys())
830
+
831
+ def get(self, key: type[PretrainedConfig], default: _T) -> Union[_LazyAutoMappingValue, _T]:
832
+ try:
833
+ return self.__getitem__(key)
834
+ except KeyError:
835
+ return default
836
+
837
+ def __bool__(self) -> bool:
838
+ return bool(self.keys())
839
+
840
+ def values(self) -> list[_LazyAutoMappingValue]:
841
+ mapping_values = [
842
+ self._load_attr_from_module(key, name)
843
+ for key, name in self._model_mapping.items()
844
+ if key in self._config_mapping
845
+ ]
846
+ return mapping_values + list(self._extra_content.values())
847
+
848
+ def items(self) -> list[tuple[type[PretrainedConfig], _LazyAutoMappingValue]]:
849
+ mapping_items = [
850
+ (
851
+ self._load_attr_from_module(key, self._config_mapping[key]),
852
+ self._load_attr_from_module(key, self._model_mapping[key]),
853
+ )
854
+ for key in self._model_mapping
855
+ if key in self._config_mapping
856
+ ]
857
+ return mapping_items + list(self._extra_content.items())
858
+
859
+ def __iter__(self) -> Iterator[type[PretrainedConfig]]:
860
+ return iter(self.keys())
861
+
862
+ def __contains__(self, item: type) -> bool:
863
+ if item in self._extra_content:
864
+ return True
865
+ if not hasattr(item, "__name__") or item.__name__ not in self._reverse_config_mapping:
866
+ return False
867
+ model_type = self._reverse_config_mapping[item.__name__]
868
+ return model_type in self._model_mapping
869
+
870
+ def register(self, key: type[PretrainedConfig], value: _LazyAutoMappingValue, exist_ok=False) -> None:
871
+ """
872
+ Register a new model in this mapping.
873
+ """
874
+ if hasattr(key, "__name__") and key.__name__ in self._reverse_config_mapping:
875
+ model_type = self._reverse_config_mapping[key.__name__]
876
+ if model_type in self._model_mapping and not exist_ok:
877
+ raise ValueError(f"'{key}' is already used by a Transformers model.")
878
+
879
+ self._extra_content[key] = value
880
+
881
+
882
+ __all__ = ["get_values"]
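
To see the laziness of `_LazyAutoMapping` in action, one can index a concrete mapping such as `MODEL_MAPPING` (defined in `modeling_auto.py`) by a config class; nothing under the model's subpackage is imported until the lookup actually happens:

from transformers import BertConfig
from transformers.models.auto.modeling_auto import MODEL_MAPPING

model_cls = MODEL_MAPPING[BertConfig]  # first access imports transformers.models.bert
print(model_cls.__name__)              # BertModel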
venv/lib/python3.13/site-packages/transformers/models/auto/configuration_auto.py ADDED
@@ -0,0 +1,1404 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Auto Config class."""
16
+
17
+ import importlib
18
+ import os
19
+ import re
20
+ import warnings
21
+ from collections import OrderedDict
22
+ from collections.abc import Callable, ItemsView, Iterator, KeysView, ValuesView
23
+ from typing import Any, TypeVar, Union
24
+
25
+ from ...configuration_utils import PretrainedConfig
26
+ from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
27
+ from ...utils import CONFIG_NAME, logging
28
+
29
+
30
+ logger = logging.get_logger(__name__)
31
+
32
+
33
+ _CallableT = TypeVar("_CallableT", bound=Callable[..., Any])
34
+
35
+
36
+ CONFIG_MAPPING_NAMES = OrderedDict[str, str](
37
+ [
38
+ # Add configs here
39
+ ("aimv2", "Aimv2Config"),
40
+ ("aimv2_vision_model", "Aimv2VisionConfig"),
41
+ ("albert", "AlbertConfig"),
42
+ ("align", "AlignConfig"),
43
+ ("altclip", "AltCLIPConfig"),
44
+ ("apertus", "ApertusConfig"),
45
+ ("arcee", "ArceeConfig"),
46
+ ("aria", "AriaConfig"),
47
+ ("aria_text", "AriaTextConfig"),
48
+ ("audio-spectrogram-transformer", "ASTConfig"),
49
+ ("autoformer", "AutoformerConfig"),
50
+ ("aya_vision", "AyaVisionConfig"),
51
+ ("bamba", "BambaConfig"),
52
+ ("bark", "BarkConfig"),
53
+ ("bart", "BartConfig"),
54
+ ("beit", "BeitConfig"),
55
+ ("bert", "BertConfig"),
56
+ ("bert-generation", "BertGenerationConfig"),
57
+ ("big_bird", "BigBirdConfig"),
58
+ ("bigbird_pegasus", "BigBirdPegasusConfig"),
59
+ ("biogpt", "BioGptConfig"),
60
+ ("bit", "BitConfig"),
61
+ ("bitnet", "BitNetConfig"),
62
+ ("blenderbot", "BlenderbotConfig"),
63
+ ("blenderbot-small", "BlenderbotSmallConfig"),
64
+ ("blip", "BlipConfig"),
65
+ ("blip-2", "Blip2Config"),
66
+ ("blip_2_qformer", "Blip2QFormerConfig"),
67
+ ("bloom", "BloomConfig"),
68
+ ("blt", "BltConfig"),
69
+ ("bridgetower", "BridgeTowerConfig"),
70
+ ("bros", "BrosConfig"),
71
+ ("camembert", "CamembertConfig"),
72
+ ("canine", "CanineConfig"),
73
+ ("chameleon", "ChameleonConfig"),
74
+ ("chinese_clip", "ChineseCLIPConfig"),
75
+ ("chinese_clip_vision_model", "ChineseCLIPVisionConfig"),
76
+ ("clap", "ClapConfig"),
77
+ ("clip", "CLIPConfig"),
78
+ ("clip_text_model", "CLIPTextConfig"),
79
+ ("clip_vision_model", "CLIPVisionConfig"),
80
+ ("clipseg", "CLIPSegConfig"),
81
+ ("clvp", "ClvpConfig"),
82
+ ("code_llama", "LlamaConfig"),
83
+ ("codegen", "CodeGenConfig"),
84
+ ("cohere", "CohereConfig"),
85
+ ("cohere2", "Cohere2Config"),
86
+ ("cohere2_vision", "Cohere2VisionConfig"),
87
+ ("colpali", "ColPaliConfig"),
88
+ ("colqwen2", "ColQwen2Config"),
89
+ ("conditional_detr", "ConditionalDetrConfig"),
90
+ ("convbert", "ConvBertConfig"),
91
+ ("convnext", "ConvNextConfig"),
92
+ ("convnextv2", "ConvNextV2Config"),
93
+ ("cpmant", "CpmAntConfig"),
94
+ ("csm", "CsmConfig"),
95
+ ("ctrl", "CTRLConfig"),
96
+ ("cvt", "CvtConfig"),
97
+ ("d_fine", "DFineConfig"),
98
+ ("dab-detr", "DabDetrConfig"),
99
+ ("dac", "DacConfig"),
100
+ ("data2vec-audio", "Data2VecAudioConfig"),
101
+ ("data2vec-text", "Data2VecTextConfig"),
102
+ ("data2vec-vision", "Data2VecVisionConfig"),
103
+ ("dbrx", "DbrxConfig"),
104
+ ("deberta", "DebertaConfig"),
105
+ ("deberta-v2", "DebertaV2Config"),
106
+ ("decision_transformer", "DecisionTransformerConfig"),
107
+ ("deepseek_v2", "DeepseekV2Config"),
108
+ ("deepseek_v3", "DeepseekV3Config"),
109
+ ("deepseek_vl", "DeepseekVLConfig"),
110
+ ("deepseek_vl_hybrid", "DeepseekVLHybridConfig"),
111
+ ("deformable_detr", "DeformableDetrConfig"),
112
+ ("deit", "DeiTConfig"),
113
+ ("depth_anything", "DepthAnythingConfig"),
114
+ ("depth_pro", "DepthProConfig"),
115
+ ("deta", "DetaConfig"),
116
+ ("detr", "DetrConfig"),
117
+ ("dia", "DiaConfig"),
118
+ ("diffllama", "DiffLlamaConfig"),
119
+ ("dinat", "DinatConfig"),
120
+ ("dinov2", "Dinov2Config"),
121
+ ("dinov2_with_registers", "Dinov2WithRegistersConfig"),
122
+ ("dinov3_convnext", "DINOv3ConvNextConfig"),
123
+ ("dinov3_vit", "DINOv3ViTConfig"),
124
+ ("distilbert", "DistilBertConfig"),
125
+ ("doge", "DogeConfig"),
126
+ ("donut-swin", "DonutSwinConfig"),
127
+ ("dots1", "Dots1Config"),
128
+ ("dpr", "DPRConfig"),
129
+ ("dpt", "DPTConfig"),
130
+ ("edgetam", "EdgeTamConfig"),
131
+ ("edgetam_video", "EdgeTamVideoConfig"),
132
+ ("edgetam_vision_model", "EdgeTamVisionConfig"),
133
+ ("efficientformer", "EfficientFormerConfig"),
134
+ ("efficientloftr", "EfficientLoFTRConfig"),
135
+ ("efficientnet", "EfficientNetConfig"),
136
+ ("electra", "ElectraConfig"),
137
+ ("emu3", "Emu3Config"),
138
+ ("encodec", "EncodecConfig"),
139
+ ("encoder-decoder", "EncoderDecoderConfig"),
140
+ ("eomt", "EomtConfig"),
141
+ ("ernie", "ErnieConfig"),
142
+ ("ernie4_5", "Ernie4_5Config"),
143
+ ("ernie4_5_moe", "Ernie4_5_MoeConfig"),
144
+ ("ernie_m", "ErnieMConfig"),
145
+ ("esm", "EsmConfig"),
146
+ ("evolla", "EvollaConfig"),
147
+ ("exaone4", "Exaone4Config"),
148
+ ("falcon", "FalconConfig"),
149
+ ("falcon_h1", "FalconH1Config"),
150
+ ("falcon_mamba", "FalconMambaConfig"),
151
+ ("fastspeech2_conformer", "FastSpeech2ConformerConfig"),
152
+ ("fastspeech2_conformer_with_hifigan", "FastSpeech2ConformerWithHifiGanConfig"),
153
+ ("flaubert", "FlaubertConfig"),
154
+ ("flava", "FlavaConfig"),
155
+ ("flex_olmo", "FlexOlmoConfig"),
156
+ ("florence2", "Florence2Config"),
157
+ ("fnet", "FNetConfig"),
158
+ ("focalnet", "FocalNetConfig"),
159
+ ("fsmt", "FSMTConfig"),
160
+ ("funnel", "FunnelConfig"),
161
+ ("fuyu", "FuyuConfig"),
162
+ ("gemma", "GemmaConfig"),
163
+ ("gemma2", "Gemma2Config"),
164
+ ("gemma3", "Gemma3Config"),
165
+ ("gemma3_text", "Gemma3TextConfig"),
166
+ ("gemma3n", "Gemma3nConfig"),
167
+ ("gemma3n_audio", "Gemma3nAudioConfig"),
168
+ ("gemma3n_text", "Gemma3nTextConfig"),
169
+ ("gemma3n_vision", "Gemma3nVisionConfig"),
170
+ ("git", "GitConfig"),
171
+ ("glm", "GlmConfig"),
172
+ ("glm4", "Glm4Config"),
173
+ ("glm4_moe", "Glm4MoeConfig"),
174
+ ("glm4v", "Glm4vConfig"),
175
+ ("glm4v_moe", "Glm4vMoeConfig"),
176
+ ("glm4v_moe_text", "Glm4vMoeTextConfig"),
177
+ ("glm4v_text", "Glm4vTextConfig"),
178
+ ("glpn", "GLPNConfig"),
179
+ ("got_ocr2", "GotOcr2Config"),
180
+ ("gpt-sw3", "GPT2Config"),
181
+ ("gpt2", "GPT2Config"),
182
+ ("gpt_bigcode", "GPTBigCodeConfig"),
183
+ ("gpt_neo", "GPTNeoConfig"),
184
+ ("gpt_neox", "GPTNeoXConfig"),
185
+ ("gpt_neox_japanese", "GPTNeoXJapaneseConfig"),
186
+ ("gpt_oss", "GptOssConfig"),
187
+ ("gptj", "GPTJConfig"),
188
+ ("gptsan-japanese", "GPTSanJapaneseConfig"),
189
+ ("granite", "GraniteConfig"),
190
+ ("granite_speech", "GraniteSpeechConfig"),
191
+ ("granitemoe", "GraniteMoeConfig"),
192
+ ("granitemoehybrid", "GraniteMoeHybridConfig"),
193
+ ("granitemoeshared", "GraniteMoeSharedConfig"),
194
+ ("granitevision", "LlavaNextConfig"),
195
+ ("graphormer", "GraphormerConfig"),
196
+ ("grounding-dino", "GroundingDinoConfig"),
197
+ ("groupvit", "GroupViTConfig"),
198
+ ("helium", "HeliumConfig"),
199
+ ("hgnet_v2", "HGNetV2Config"),
200
+ ("hiera", "HieraConfig"),
201
+ ("hubert", "HubertConfig"),
202
+ ("hunyuan_v1_dense", "HunYuanDenseV1Config"),
203
+ ("hunyuan_v1_moe", "HunYuanMoEV1Config"),
204
+ ("ibert", "IBertConfig"),
205
+ ("idefics", "IdeficsConfig"),
206
+ ("idefics2", "Idefics2Config"),
207
+ ("idefics3", "Idefics3Config"),
208
+ ("idefics3_vision", "Idefics3VisionConfig"),
209
+ ("ijepa", "IJepaConfig"),
210
+ ("imagegpt", "ImageGPTConfig"),
211
+ ("informer", "InformerConfig"),
212
+ ("instructblip", "InstructBlipConfig"),
213
+ ("instructblipvideo", "InstructBlipVideoConfig"),
214
+ ("internvl", "InternVLConfig"),
215
+ ("internvl_vision", "InternVLVisionConfig"),
216
+ ("jamba", "JambaConfig"),
217
+ ("janus", "JanusConfig"),
218
+ ("jetmoe", "JetMoeConfig"),
219
+ ("jukebox", "JukeboxConfig"),
220
+ ("kosmos-2", "Kosmos2Config"),
221
+ ("kosmos-2.5", "Kosmos2_5Config"),
222
+ ("kyutai_speech_to_text", "KyutaiSpeechToTextConfig"),
223
+ ("layoutlm", "LayoutLMConfig"),
224
+ ("layoutlmv2", "LayoutLMv2Config"),
225
+ ("layoutlmv3", "LayoutLMv3Config"),
226
+ ("led", "LEDConfig"),
227
+ ("levit", "LevitConfig"),
228
+ ("lfm2", "Lfm2Config"),
229
+ ("lfm2_vl", "Lfm2VlConfig"),
230
+ ("lightglue", "LightGlueConfig"),
231
+ ("lilt", "LiltConfig"),
232
+ ("llama", "LlamaConfig"),
233
+ ("llama4", "Llama4Config"),
234
+ ("llama4_text", "Llama4TextConfig"),
235
+ ("llava", "LlavaConfig"),
236
+ ("llava_next", "LlavaNextConfig"),
237
+ ("llava_next_video", "LlavaNextVideoConfig"),
238
+ ("llava_onevision", "LlavaOnevisionConfig"),
239
+ ("longcat_flash", "LongcatFlashConfig"),
240
+ ("longformer", "LongformerConfig"),
241
+ ("longt5", "LongT5Config"),
242
+ ("luke", "LukeConfig"),
243
+ ("lxmert", "LxmertConfig"),
244
+ ("m2m_100", "M2M100Config"),
245
+ ("mamba", "MambaConfig"),
246
+ ("mamba2", "Mamba2Config"),
247
+ ("marian", "MarianConfig"),
248
+ ("markuplm", "MarkupLMConfig"),
249
+ ("mask2former", "Mask2FormerConfig"),
250
+ ("maskformer", "MaskFormerConfig"),
251
+ ("maskformer-swin", "MaskFormerSwinConfig"),
252
+ ("mbart", "MBartConfig"),
253
+ ("mctct", "MCTCTConfig"),
254
+ ("mega", "MegaConfig"),
255
+ ("megatron-bert", "MegatronBertConfig"),
256
+ ("metaclip_2", "MetaClip2Config"),
257
+ ("mgp-str", "MgpstrConfig"),
258
+ ("mimi", "MimiConfig"),
259
+ ("minimax", "MiniMaxConfig"),
260
+ ("ministral", "MinistralConfig"),
261
+ ("mistral", "MistralConfig"),
262
+ ("mistral3", "Mistral3Config"),
263
+ ("mixtral", "MixtralConfig"),
264
+ ("mlcd", "MLCDVisionConfig"),
265
+ ("mllama", "MllamaConfig"),
266
+ ("mm-grounding-dino", "MMGroundingDinoConfig"),
267
+ ("mobilebert", "MobileBertConfig"),
268
+ ("mobilenet_v1", "MobileNetV1Config"),
269
+ ("mobilenet_v2", "MobileNetV2Config"),
270
+ ("mobilevit", "MobileViTConfig"),
271
+ ("mobilevitv2", "MobileViTV2Config"),
272
+ ("modernbert", "ModernBertConfig"),
273
+ ("modernbert-decoder", "ModernBertDecoderConfig"),
274
+ ("moonshine", "MoonshineConfig"),
275
+ ("moshi", "MoshiConfig"),
276
+ ("mpnet", "MPNetConfig"),
277
+ ("mpt", "MptConfig"),
278
+ ("mra", "MraConfig"),
279
+ ("mt5", "MT5Config"),
280
+ ("musicgen", "MusicgenConfig"),
281
+ ("musicgen_melody", "MusicgenMelodyConfig"),
282
+ ("mvp", "MvpConfig"),
283
+ ("nat", "NatConfig"),
284
+ ("nemotron", "NemotronConfig"),
285
+ ("nezha", "NezhaConfig"),
286
+ ("nllb-moe", "NllbMoeConfig"),
287
+ ("nougat", "VisionEncoderDecoderConfig"),
288
+ ("nystromformer", "NystromformerConfig"),
289
+ ("olmo", "OlmoConfig"),
290
+ ("olmo2", "Olmo2Config"),
291
+ ("olmo3", "Olmo3Config"),
292
+ ("olmoe", "OlmoeConfig"),
293
+ ("omdet-turbo", "OmDetTurboConfig"),
294
+ ("oneformer", "OneFormerConfig"),
295
+ ("open-llama", "OpenLlamaConfig"),
296
+ ("openai-gpt", "OpenAIGPTConfig"),
297
+ ("opt", "OPTConfig"),
298
+ ("ovis2", "Ovis2Config"),
299
+ ("owlv2", "Owlv2Config"),
300
+ ("owlvit", "OwlViTConfig"),
301
+ ("paligemma", "PaliGemmaConfig"),
302
+ ("parakeet_ctc", "ParakeetCTCConfig"),
303
+ ("parakeet_encoder", "ParakeetEncoderConfig"),
304
+ ("patchtsmixer", "PatchTSMixerConfig"),
305
+ ("patchtst", "PatchTSTConfig"),
306
+ ("pegasus", "PegasusConfig"),
307
+ ("pegasus_x", "PegasusXConfig"),
308
+ ("perceiver", "PerceiverConfig"),
309
+ ("perception_encoder", "TimmWrapperConfig"),
310
+ ("perception_lm", "PerceptionLMConfig"),
311
+ ("persimmon", "PersimmonConfig"),
312
+ ("phi", "PhiConfig"),
313
+ ("phi3", "Phi3Config"),
314
+ ("phi4_multimodal", "Phi4MultimodalConfig"),
315
+ ("phimoe", "PhimoeConfig"),
316
+ ("pix2struct", "Pix2StructConfig"),
317
+ ("pixtral", "PixtralVisionConfig"),
318
+ ("plbart", "PLBartConfig"),
319
+ ("poolformer", "PoolFormerConfig"),
320
+ ("pop2piano", "Pop2PianoConfig"),
321
+ ("prompt_depth_anything", "PromptDepthAnythingConfig"),
322
+ ("prophetnet", "ProphetNetConfig"),
323
+ ("pvt", "PvtConfig"),
324
+ ("pvt_v2", "PvtV2Config"),
325
+ ("qdqbert", "QDQBertConfig"),
326
+ ("qwen2", "Qwen2Config"),
327
+ ("qwen2_5_omni", "Qwen2_5OmniConfig"),
328
+ ("qwen2_5_vl", "Qwen2_5_VLConfig"),
329
+ ("qwen2_5_vl_text", "Qwen2_5_VLTextConfig"),
330
+ ("qwen2_audio", "Qwen2AudioConfig"),
331
+ ("qwen2_audio_encoder", "Qwen2AudioEncoderConfig"),
332
+ ("qwen2_moe", "Qwen2MoeConfig"),
333
+ ("qwen2_vl", "Qwen2VLConfig"),
334
+ ("qwen2_vl_text", "Qwen2VLTextConfig"),
335
+ ("qwen3", "Qwen3Config"),
336
+ ("qwen3_moe", "Qwen3MoeConfig"),
337
+ ("qwen3_next", "Qwen3NextConfig"),
338
+ ("qwen3_omni_moe", "Qwen3OmniMoeConfig"),
339
+ ("qwen3_vl", "Qwen3VLConfig"),
340
+ ("qwen3_vl_moe", "Qwen3VLMoeConfig"),
341
+ ("qwen3_vl_moe_text", "Qwen3VLMoeTextConfig"),
342
+ ("qwen3_vl_text", "Qwen3VLTextConfig"),
343
+ ("rag", "RagConfig"),
344
+ ("realm", "RealmConfig"),
345
+ ("recurrent_gemma", "RecurrentGemmaConfig"),
346
+ ("reformer", "ReformerConfig"),
347
+ ("regnet", "RegNetConfig"),
348
+ ("rembert", "RemBertConfig"),
349
+ ("resnet", "ResNetConfig"),
350
+ ("retribert", "RetriBertConfig"),
351
+ ("roberta", "RobertaConfig"),
352
+ ("roberta-prelayernorm", "RobertaPreLayerNormConfig"),
353
+ ("roc_bert", "RoCBertConfig"),
354
+ ("roformer", "RoFormerConfig"),
355
+ ("rt_detr", "RTDetrConfig"),
356
+ ("rt_detr_resnet", "RTDetrResNetConfig"),
357
+ ("rt_detr_v2", "RTDetrV2Config"),
358
+ ("rwkv", "RwkvConfig"),
359
+ ("sam", "SamConfig"),
360
+ ("sam2", "Sam2Config"),
361
+ ("sam2_hiera_det_model", "Sam2HieraDetConfig"),
362
+ ("sam2_video", "Sam2VideoConfig"),
363
+ ("sam2_vision_model", "Sam2VisionConfig"),
364
+ ("sam_hq", "SamHQConfig"),
365
+ ("sam_hq_vision_model", "SamHQVisionConfig"),
366
+ ("sam_vision_model", "SamVisionConfig"),
367
+ ("seamless_m4t", "SeamlessM4TConfig"),
368
+ ("seamless_m4t_v2", "SeamlessM4Tv2Config"),
369
+ ("seed_oss", "SeedOssConfig"),
370
+ ("segformer", "SegformerConfig"),
371
+ ("seggpt", "SegGptConfig"),
372
+ ("sew", "SEWConfig"),
373
+ ("sew-d", "SEWDConfig"),
374
+ ("shieldgemma2", "ShieldGemma2Config"),
375
+ ("siglip", "SiglipConfig"),
376
+ ("siglip2", "Siglip2Config"),
377
+ ("siglip2_vision_model", "Siglip2VisionConfig"),
378
+ ("siglip_vision_model", "SiglipVisionConfig"),
379
+ ("smollm3", "SmolLM3Config"),
380
+ ("smolvlm", "SmolVLMConfig"),
381
+ ("smolvlm_vision", "SmolVLMVisionConfig"),
382
+ ("speech-encoder-decoder", "SpeechEncoderDecoderConfig"),
383
+ ("speech_to_text", "Speech2TextConfig"),
384
+ ("speech_to_text_2", "Speech2Text2Config"),
385
+ ("speecht5", "SpeechT5Config"),
386
+ ("splinter", "SplinterConfig"),
387
+ ("squeezebert", "SqueezeBertConfig"),
388
+ ("stablelm", "StableLmConfig"),
389
+ ("starcoder2", "Starcoder2Config"),
390
+ ("superglue", "SuperGlueConfig"),
391
+ ("superpoint", "SuperPointConfig"),
392
+ ("swiftformer", "SwiftFormerConfig"),
393
+ ("swin", "SwinConfig"),
394
+ ("swin2sr", "Swin2SRConfig"),
395
+ ("swinv2", "Swinv2Config"),
396
+ ("switch_transformers", "SwitchTransformersConfig"),
397
+ ("t5", "T5Config"),
398
+ ("t5gemma", "T5GemmaConfig"),
399
+ ("table-transformer", "TableTransformerConfig"),
400
+ ("tapas", "TapasConfig"),
401
+ ("textnet", "TextNetConfig"),
402
+ ("time_series_transformer", "TimeSeriesTransformerConfig"),
403
+ ("timesfm", "TimesFmConfig"),
404
+ ("timesformer", "TimesformerConfig"),
405
+ ("timm_backbone", "TimmBackboneConfig"),
406
+ ("timm_wrapper", "TimmWrapperConfig"),
407
+ ("trajectory_transformer", "TrajectoryTransformerConfig"),
408
+ ("transfo-xl", "TransfoXLConfig"),
409
+ ("trocr", "TrOCRConfig"),
410
+ ("tvlt", "TvltConfig"),
411
+ ("tvp", "TvpConfig"),
412
+ ("udop", "UdopConfig"),
413
+ ("umt5", "UMT5Config"),
414
+ ("unispeech", "UniSpeechConfig"),
415
+ ("unispeech-sat", "UniSpeechSatConfig"),
416
+ ("univnet", "UnivNetConfig"),
417
+ ("upernet", "UperNetConfig"),
418
+ ("van", "VanConfig"),
419
+ ("vaultgemma", "VaultGemmaConfig"),
420
+ ("video_llava", "VideoLlavaConfig"),
421
+ ("videomae", "VideoMAEConfig"),
422
+ ("vilt", "ViltConfig"),
423
+ ("vipllava", "VipLlavaConfig"),
424
+ ("vision-encoder-decoder", "VisionEncoderDecoderConfig"),
425
+ ("vision-text-dual-encoder", "VisionTextDualEncoderConfig"),
426
+ ("visual_bert", "VisualBertConfig"),
427
+ ("vit", "ViTConfig"),
428
+ ("vit_hybrid", "ViTHybridConfig"),
429
+ ("vit_mae", "ViTMAEConfig"),
430
+ ("vit_msn", "ViTMSNConfig"),
431
+ ("vitdet", "VitDetConfig"),
432
+ ("vitmatte", "VitMatteConfig"),
433
+ ("vitpose", "VitPoseConfig"),
434
+ ("vitpose_backbone", "VitPoseBackboneConfig"),
435
+ ("vits", "VitsConfig"),
436
+ ("vivit", "VivitConfig"),
437
+ ("vjepa2", "VJEPA2Config"),
438
+ ("voxtral", "VoxtralConfig"),
439
+ ("voxtral_encoder", "VoxtralEncoderConfig"),
440
+ ("wav2vec2", "Wav2Vec2Config"),
441
+ ("wav2vec2-bert", "Wav2Vec2BertConfig"),
442
+ ("wav2vec2-conformer", "Wav2Vec2ConformerConfig"),
443
+ ("wavlm", "WavLMConfig"),
444
+ ("whisper", "WhisperConfig"),
445
+ ("xclip", "XCLIPConfig"),
446
+ ("xcodec", "XcodecConfig"),
447
+ ("xglm", "XGLMConfig"),
448
+ ("xlm", "XLMConfig"),
449
+ ("xlm-prophetnet", "XLMProphetNetConfig"),
450
+ ("xlm-roberta", "XLMRobertaConfig"),
451
+ ("xlm-roberta-xl", "XLMRobertaXLConfig"),
452
+ ("xlnet", "XLNetConfig"),
453
+ ("xlstm", "xLSTMConfig"),
454
+ ("xmod", "XmodConfig"),
455
+ ("yolos", "YolosConfig"),
456
+ ("yoso", "YosoConfig"),
457
+ ("zamba", "ZambaConfig"),
458
+ ("zamba2", "Zamba2Config"),
459
+ ("zoedepth", "ZoeDepthConfig"),
460
+ ]
461
+ )
462
+
463
+
464
+ MODEL_NAMES_MAPPING = OrderedDict[str, str](
465
+ [
466
+ # Add full (and cased) model names here
467
+ ("aimv2", "AIMv2"),
468
+ ("aimv2_vision_model", "Aimv2VisionModel"),
469
+ ("albert", "ALBERT"),
470
+ ("align", "ALIGN"),
471
+ ("altclip", "AltCLIP"),
472
+ ("apertus", "Apertus"),
473
+ ("arcee", "Arcee"),
474
+ ("aria", "Aria"),
475
+ ("aria_text", "AriaText"),
476
+ ("audio-spectrogram-transformer", "Audio Spectrogram Transformer"),
477
+ ("autoformer", "Autoformer"),
478
+ ("aya_vision", "AyaVision"),
479
+ ("bamba", "Bamba"),
480
+ ("bark", "Bark"),
481
+ ("bart", "BART"),
482
+ ("barthez", "BARThez"),
483
+ ("bartpho", "BARTpho"),
484
+ ("beit", "BEiT"),
485
+ ("bert", "BERT"),
486
+ ("bert-generation", "Bert Generation"),
487
+ ("bert-japanese", "BertJapanese"),
488
+ ("bertweet", "BERTweet"),
489
+ ("big_bird", "BigBird"),
490
+ ("bigbird_pegasus", "BigBird-Pegasus"),
491
+ ("biogpt", "BioGpt"),
492
+ ("bit", "BiT"),
493
+ ("bitnet", "BitNet"),
494
+ ("blenderbot", "Blenderbot"),
495
+ ("blenderbot-small", "BlenderbotSmall"),
496
+ ("blip", "BLIP"),
497
+ ("blip-2", "BLIP-2"),
498
+ ("blip_2_qformer", "BLIP-2 QFormer"),
499
+ ("bloom", "BLOOM"),
500
+ ("blt", "Blt"),
501
+ ("bort", "BORT"),
502
+ ("bridgetower", "BridgeTower"),
503
+ ("bros", "BROS"),
504
+ ("byt5", "ByT5"),
505
+ ("camembert", "CamemBERT"),
506
+ ("canine", "CANINE"),
507
+ ("chameleon", "Chameleon"),
508
+ ("chinese_clip", "Chinese-CLIP"),
509
+ ("chinese_clip_vision_model", "ChineseCLIPVisionModel"),
510
+ ("clap", "CLAP"),
511
+ ("clip", "CLIP"),
512
+ ("clip_text_model", "CLIPTextModel"),
513
+ ("clip_vision_model", "CLIPVisionModel"),
514
+ ("clipseg", "CLIPSeg"),
515
+ ("clvp", "CLVP"),
516
+ ("code_llama", "CodeLlama"),
517
+ ("codegen", "CodeGen"),
518
+ ("cohere", "Cohere"),
519
+ ("cohere2", "Cohere2"),
520
+ ("cohere2_vision", "Cohere2Vision"),
521
+ ("colpali", "ColPali"),
522
+ ("colqwen2", "ColQwen2"),
523
+ ("conditional_detr", "Conditional DETR"),
524
+ ("convbert", "ConvBERT"),
525
+ ("convnext", "ConvNeXT"),
526
+ ("convnextv2", "ConvNeXTV2"),
527
+ ("cpm", "CPM"),
528
+ ("cpmant", "CPM-Ant"),
529
+ ("csm", "CSM"),
530
+ ("ctrl", "CTRL"),
531
+ ("cvt", "CvT"),
532
+ ("d_fine", "D-FINE"),
533
+ ("dab-detr", "DAB-DETR"),
534
+ ("dac", "DAC"),
535
+ ("data2vec-audio", "Data2VecAudio"),
536
+ ("data2vec-text", "Data2VecText"),
537
+ ("data2vec-vision", "Data2VecVision"),
538
+ ("dbrx", "DBRX"),
539
+ ("deberta", "DeBERTa"),
540
+ ("deberta-v2", "DeBERTa-v2"),
541
+ ("decision_transformer", "Decision Transformer"),
542
+ ("deepseek_v2", "DeepSeek-V2"),
543
+ ("deepseek_v3", "DeepSeek-V3"),
544
+ ("deepseek_vl", "DeepseekVL"),
545
+ ("deepseek_vl_hybrid", "DeepseekVLHybrid"),
546
+ ("deformable_detr", "Deformable DETR"),
547
+ ("deit", "DeiT"),
548
+ ("deplot", "DePlot"),
549
+ ("depth_anything", "Depth Anything"),
550
+ ("depth_anything_v2", "Depth Anything V2"),
551
+ ("depth_pro", "DepthPro"),
552
+ ("deta", "DETA"),
553
+ ("detr", "DETR"),
554
+ ("dia", "Dia"),
555
+ ("dialogpt", "DialoGPT"),
556
+ ("diffllama", "DiffLlama"),
557
+ ("dinat", "DiNAT"),
558
+ ("dinov2", "DINOv2"),
559
+ ("dinov2_with_registers", "DINOv2 with Registers"),
560
+ ("dinov3_convnext", "DINOv3 ConvNext"),
561
+ ("dinov3_vit", "DINOv3 ViT"),
562
+ ("distilbert", "DistilBERT"),
563
+ ("dit", "DiT"),
564
+ ("doge", "Doge"),
565
+ ("donut-swin", "DonutSwin"),
566
+ ("dots1", "dots1"),
567
+ ("dpr", "DPR"),
568
+ ("dpt", "DPT"),
569
+ ("edgetam", "EdgeTAM"),
570
+ ("edgetam_video", "EdgeTamVideo"),
571
+ ("edgetam_vision_model", "EdgeTamVisionModel"),
572
+ ("efficientformer", "EfficientFormer"),
573
+ ("efficientloftr", "EfficientLoFTR"),
574
+ ("efficientnet", "EfficientNet"),
575
+ ("electra", "ELECTRA"),
576
+ ("emu3", "Emu3"),
577
+ ("encodec", "EnCodec"),
578
+ ("encoder-decoder", "Encoder decoder"),
579
+ ("eomt", "EoMT"),
580
+ ("ernie", "ERNIE"),
581
+ ("ernie4_5", "Ernie4_5"),
582
+ ("ernie4_5_moe", "Ernie4_5_MoE"),
583
+ ("ernie_m", "ErnieM"),
584
+ ("esm", "ESM"),
585
+ ("evolla", "Evolla"),
586
+ ("exaone4", "EXAONE-4.0"),
587
+ ("falcon", "Falcon"),
588
+ ("falcon3", "Falcon3"),
589
+ ("falcon_h1", "FalconH1"),
590
+ ("falcon_mamba", "FalconMamba"),
591
+ ("fastspeech2_conformer", "FastSpeech2Conformer"),
592
+ ("fastspeech2_conformer_with_hifigan", "FastSpeech2ConformerWithHifiGan"),
593
+ ("flan-t5", "FLAN-T5"),
594
+ ("flan-ul2", "FLAN-UL2"),
595
+ ("flaubert", "FlauBERT"),
596
+ ("flava", "FLAVA"),
597
+ ("flex_olmo", "FlexOlmo"),
598
+ ("florence2", "Florence2"),
599
+ ("fnet", "FNet"),
600
+ ("focalnet", "FocalNet"),
601
+ ("fsmt", "FairSeq Machine-Translation"),
602
+ ("funnel", "Funnel Transformer"),
603
+ ("fuyu", "Fuyu"),
604
+ ("gemma", "Gemma"),
605
+ ("gemma2", "Gemma2"),
606
+ ("gemma3", "Gemma3ForConditionalGeneration"),
607
+ ("gemma3_text", "Gemma3ForCausalLM"),
608
+ ("gemma3n", "Gemma3nForConditionalGeneration"),
609
+ ("gemma3n_audio", "Gemma3nAudioEncoder"),
610
+ ("gemma3n_text", "Gemma3nForCausalLM"),
611
+ ("gemma3n_vision", "TimmWrapperModel"),
612
+ ("git", "GIT"),
613
+ ("glm", "GLM"),
614
+ ("glm4", "GLM4"),
615
+ ("glm4_moe", "Glm4MoE"),
616
+ ("glm4v", "GLM4V"),
617
+ ("glm4v_moe", "GLM4VMOE"),
618
+ ("glm4v_moe_text", "GLM4VMOE"),
619
+ ("glm4v_text", "GLM4V"),
620
+ ("glpn", "GLPN"),
621
+ ("got_ocr2", "GOT-OCR2"),
622
+ ("gpt-sw3", "GPT-Sw3"),
623
+ ("gpt2", "OpenAI GPT-2"),
624
+ ("gpt_bigcode", "GPTBigCode"),
625
+ ("gpt_neo", "GPT Neo"),
626
+ ("gpt_neox", "GPT NeoX"),
627
+ ("gpt_neox_japanese", "GPT NeoX Japanese"),
628
+ ("gpt_oss", "GptOss"),
629
+ ("gptj", "GPT-J"),
630
+ ("gptsan-japanese", "GPTSAN-japanese"),
631
+ ("granite", "Granite"),
632
+ ("granite_speech", "GraniteSpeech"),
633
+ ("granitemoe", "GraniteMoeMoe"),
634
+ ("granitemoehybrid", "GraniteMoeHybrid"),
635
+ ("granitemoeshared", "GraniteMoeSharedMoe"),
636
+ ("granitevision", "LLaVA-NeXT"),
637
+ ("graphormer", "Graphormer"),
638
+ ("grounding-dino", "Grounding DINO"),
639
+ ("groupvit", "GroupViT"),
640
+ ("helium", "Helium"),
641
+ ("herbert", "HerBERT"),
642
+ ("hgnet_v2", "HGNet-V2"),
643
+ ("hiera", "Hiera"),
644
+ ("hubert", "Hubert"),
645
+ ("hunyuan_v1_dense", "HunYuanDenseV1"),
646
+ ("hunyuan_v1_moe", "HunYuanMoeV1"),
647
+ ("ibert", "I-BERT"),
648
+ ("idefics", "IDEFICS"),
649
+ ("idefics2", "Idefics2"),
650
+ ("idefics3", "Idefics3"),
651
+ ("idefics3_vision", "Idefics3VisionTransformer"),
652
+ ("ijepa", "I-JEPA"),
653
+ ("imagegpt", "ImageGPT"),
654
+ ("informer", "Informer"),
655
+ ("instructblip", "InstructBLIP"),
656
+ ("instructblipvideo", "InstructBlipVideo"),
657
+ ("internvl", "InternVL"),
658
+ ("internvl_vision", "InternVLVision"),
659
+ ("jamba", "Jamba"),
660
+ ("janus", "Janus"),
661
+ ("jetmoe", "JetMoe"),
662
+ ("jukebox", "Jukebox"),
663
+ ("kosmos-2", "KOSMOS-2"),
664
+ ("kosmos-2.5", "KOSMOS-2.5"),
665
+ ("kyutai_speech_to_text", "KyutaiSpeechToText"),
666
+ ("layoutlm", "LayoutLM"),
667
+ ("layoutlmv2", "LayoutLMv2"),
668
+ ("layoutlmv3", "LayoutLMv3"),
669
+ ("layoutxlm", "LayoutXLM"),
670
+ ("led", "LED"),
671
+ ("levit", "LeViT"),
672
+ ("lfm2", "Lfm2"),
673
+ ("lfm2_vl", "Lfm2Vl"),
674
+ ("lightglue", "LightGlue"),
675
+ ("lilt", "LiLT"),
676
+ ("llama", "LLaMA"),
677
+ ("llama2", "Llama2"),
678
+ ("llama3", "Llama3"),
679
+ ("llama4", "Llama4"),
680
+ ("llama4_text", "Llama4ForCausalLM"),
681
+ ("llava", "LLaVa"),
682
+ ("llava_next", "LLaVA-NeXT"),
683
+ ("llava_next_video", "LLaVa-NeXT-Video"),
684
+ ("llava_onevision", "LLaVA-Onevision"),
685
+ ("longcat_flash", "LongCatFlash"),
686
+ ("longformer", "Longformer"),
687
+ ("longt5", "LongT5"),
688
+ ("luke", "LUKE"),
689
+ ("lxmert", "LXMERT"),
690
+ ("m2m_100", "M2M100"),
691
+ ("madlad-400", "MADLAD-400"),
692
+ ("mamba", "Mamba"),
693
+ ("mamba2", "mamba2"),
694
+ ("marian", "Marian"),
695
+ ("markuplm", "MarkupLM"),
696
+ ("mask2former", "Mask2Former"),
697
+ ("maskformer", "MaskFormer"),
698
+ ("maskformer-swin", "MaskFormerSwin"),
699
+ ("matcha", "MatCha"),
700
+ ("mbart", "mBART"),
701
+ ("mbart50", "mBART-50"),
702
+ ("mctct", "M-CTC-T"),
703
+ ("mega", "MEGA"),
704
+ ("megatron-bert", "Megatron-BERT"),
705
+ ("megatron_gpt2", "Megatron-GPT2"),
706
+ ("metaclip_2", "MetaCLIP 2"),
707
+ ("mgp-str", "MGP-STR"),
708
+ ("mimi", "Mimi"),
709
+ ("minimax", "MiniMax"),
710
+ ("ministral", "Ministral"),
711
+ ("mistral", "Mistral"),
712
+ ("mistral3", "Mistral3"),
713
+ ("mixtral", "Mixtral"),
714
+ ("mlcd", "MLCD"),
715
+ ("mllama", "Mllama"),
716
+ ("mluke", "mLUKE"),
717
+ ("mm-grounding-dino", "MM Grounding DINO"),
718
+ ("mms", "MMS"),
719
+ ("mobilebert", "MobileBERT"),
720
+ ("mobilenet_v1", "MobileNetV1"),
721
+ ("mobilenet_v2", "MobileNetV2"),
722
+ ("mobilevit", "MobileViT"),
723
+ ("mobilevitv2", "MobileViTV2"),
724
+ ("modernbert", "ModernBERT"),
725
+ ("modernbert-decoder", "ModernBertDecoder"),
726
+ ("moonshine", "Moonshine"),
727
+ ("moshi", "Moshi"),
728
+ ("mpnet", "MPNet"),
729
+ ("mpt", "MPT"),
730
+ ("mra", "MRA"),
731
+ ("mt5", "MT5"),
732
+ ("musicgen", "MusicGen"),
733
+ ("musicgen_melody", "MusicGen Melody"),
734
+ ("mvp", "MVP"),
735
+ ("myt5", "myt5"),
736
+ ("nat", "NAT"),
737
+ ("nemotron", "Nemotron"),
738
+ ("nezha", "Nezha"),
739
+ ("nllb", "NLLB"),
740
+ ("nllb-moe", "NLLB-MOE"),
741
+ ("nougat", "Nougat"),
742
+ ("nystromformer", "Nyströmformer"),
743
+ ("olmo", "OLMo"),
744
+ ("olmo2", "OLMo2"),
745
+ ("olmo3", "Olmo3"),
746
+ ("olmoe", "OLMoE"),
747
+ ("omdet-turbo", "OmDet-Turbo"),
748
+ ("oneformer", "OneFormer"),
749
+ ("open-llama", "OpenLlama"),
750
+ ("openai-gpt", "OpenAI GPT"),
751
+ ("opt", "OPT"),
752
+ ("ovis2", "Ovis2"),
753
+ ("owlv2", "OWLv2"),
754
+ ("owlvit", "OWL-ViT"),
755
+ ("paligemma", "PaliGemma"),
756
+ ("parakeet", "Parakeet"),
757
+ ("parakeet_ctc", "Parakeet"),
758
+ ("parakeet_encoder", "ParakeetEncoder"),
759
+ ("patchtsmixer", "PatchTSMixer"),
760
+ ("patchtst", "PatchTST"),
761
+ ("pegasus", "Pegasus"),
762
+ ("pegasus_x", "PEGASUS-X"),
763
+ ("perceiver", "Perceiver"),
764
+ ("perception_encoder", "PerceptionEncoder"),
765
+ ("perception_lm", "PerceptionLM"),
766
+ ("persimmon", "Persimmon"),
767
+ ("phi", "Phi"),
768
+ ("phi3", "Phi3"),
769
+ ("phi4_multimodal", "Phi4Multimodal"),
770
+ ("phimoe", "Phimoe"),
771
+ ("phobert", "PhoBERT"),
772
+ ("pix2struct", "Pix2Struct"),
773
+ ("pixtral", "Pixtral"),
774
+ ("plbart", "PLBart"),
775
+ ("poolformer", "PoolFormer"),
776
+ ("pop2piano", "Pop2Piano"),
777
+ ("prompt_depth_anything", "PromptDepthAnything"),
778
+ ("prophetnet", "ProphetNet"),
779
+ ("pvt", "PVT"),
780
+ ("pvt_v2", "PVTv2"),
781
+ ("qdqbert", "QDQBert"),
782
+ ("qwen2", "Qwen2"),
783
+ ("qwen2_5_omni", "Qwen2_5Omni"),
784
+ ("qwen2_5_vl", "Qwen2_5_VL"),
785
+ ("qwen2_5_vl_text", "Qwen2_5_VL"),
786
+ ("qwen2_audio", "Qwen2Audio"),
787
+ ("qwen2_audio_encoder", "Qwen2AudioEncoder"),
788
+ ("qwen2_moe", "Qwen2MoE"),
789
+ ("qwen2_vl", "Qwen2VL"),
790
+ ("qwen2_vl_text", "Qwen2VL"),
791
+ ("qwen3", "Qwen3"),
792
+ ("qwen3_moe", "Qwen3MoE"),
793
+ ("qwen3_next", "Qwen3Next"),
794
+ ("qwen3_omni_moe", "Qwen3OmniMoE"),
795
+ ("qwen3_vl", "Qwen3VL"),
796
+ ("qwen3_vl_moe", "Qwen3VLMoe"),
797
+ ("qwen3_vl_moe_text", "Qwen3VLMoe"),
798
+ ("qwen3_vl_text", "Qwen3VL"),
799
+ ("rag", "RAG"),
800
+ ("realm", "REALM"),
801
+ ("recurrent_gemma", "RecurrentGemma"),
802
+ ("reformer", "Reformer"),
803
+ ("regnet", "RegNet"),
804
+ ("rembert", "RemBERT"),
805
+ ("resnet", "ResNet"),
806
+ ("retribert", "RetriBERT"),
807
+ ("roberta", "RoBERTa"),
808
+ ("roberta-prelayernorm", "RoBERTa-PreLayerNorm"),
809
+ ("roc_bert", "RoCBert"),
810
+ ("roformer", "RoFormer"),
811
+ ("rt_detr", "RT-DETR"),
812
+ ("rt_detr_resnet", "RT-DETR-ResNet"),
813
+ ("rt_detr_v2", "RT-DETRv2"),
814
+ ("rwkv", "RWKV"),
815
+ ("sam", "SAM"),
816
+ ("sam2", "SAM2"),
817
+ ("sam2_hiera_det_model", "Sam2HieraDetModel"),
818
+ ("sam2_video", "Sam2VideoModel"),
819
+ ("sam2_vision_model", "Sam2VisionModel"),
820
+ ("sam_hq", "SAM-HQ"),
821
+ ("sam_hq_vision_model", "SamHQVisionModel"),
822
+ ("sam_vision_model", "SamVisionModel"),
823
+ ("seamless_m4t", "SeamlessM4T"),
824
+ ("seamless_m4t_v2", "SeamlessM4Tv2"),
825
+ ("seed_oss", "SeedOss"),
826
+ ("segformer", "SegFormer"),
827
+ ("seggpt", "SegGPT"),
828
+ ("sew", "SEW"),
829
+ ("sew-d", "SEW-D"),
830
+ ("shieldgemma2", "Shieldgemma2"),
831
+ ("siglip", "SigLIP"),
832
+ ("siglip2", "SigLIP2"),
833
+ ("siglip2_vision_model", "Siglip2VisionModel"),
834
+ ("siglip_vision_model", "SiglipVisionModel"),
835
+ ("smollm3", "SmolLM3"),
836
+ ("smolvlm", "SmolVLM"),
837
+ ("smolvlm_vision", "SmolVLMVisionTransformer"),
838
+ ("speech-encoder-decoder", "Speech Encoder decoder"),
839
+ ("speech_to_text", "Speech2Text"),
840
+ ("speech_to_text_2", "Speech2Text2"),
841
+ ("speecht5", "SpeechT5"),
842
+ ("splinter", "Splinter"),
843
+ ("squeezebert", "SqueezeBERT"),
844
+ ("stablelm", "StableLm"),
845
+ ("starcoder2", "Starcoder2"),
846
+ ("superglue", "SuperGlue"),
847
+ ("superpoint", "SuperPoint"),
848
+ ("swiftformer", "SwiftFormer"),
849
+ ("swin", "Swin Transformer"),
850
+ ("swin2sr", "Swin2SR"),
851
+ ("swinv2", "Swin Transformer V2"),
852
+ ("switch_transformers", "SwitchTransformers"),
853
+ ("t5", "T5"),
854
+ ("t5gemma", "T5Gemma"),
855
+ ("t5v1.1", "T5v1.1"),
856
+ ("table-transformer", "Table Transformer"),
857
+ ("tapas", "TAPAS"),
858
+ ("tapex", "TAPEX"),
859
+ ("textnet", "TextNet"),
860
+ ("time_series_transformer", "Time Series Transformer"),
861
+ ("timesfm", "TimesFm"),
862
+ ("timesformer", "TimeSformer"),
863
+ ("timm_backbone", "TimmBackbone"),
864
+ ("timm_wrapper", "TimmWrapperModel"),
865
+ ("trajectory_transformer", "Trajectory Transformer"),
866
+ ("transfo-xl", "Transformer-XL"),
867
+ ("trocr", "TrOCR"),
868
+ ("tvlt", "TVLT"),
869
+ ("tvp", "TVP"),
870
+ ("udop", "UDOP"),
871
+ ("ul2", "UL2"),
872
+ ("umt5", "UMT5"),
873
+ ("unispeech", "UniSpeech"),
874
+ ("unispeech-sat", "UniSpeechSat"),
875
+ ("univnet", "UnivNet"),
876
+ ("upernet", "UPerNet"),
877
+ ("van", "VAN"),
878
+ ("vaultgemma", "VaultGemma"),
879
+ ("video_llava", "VideoLlava"),
880
+ ("videomae", "VideoMAE"),
881
+ ("vilt", "ViLT"),
882
+ ("vipllava", "VipLlava"),
883
+ ("vision-encoder-decoder", "Vision Encoder decoder"),
884
+ ("vision-text-dual-encoder", "VisionTextDualEncoder"),
885
+ ("visual_bert", "VisualBERT"),
886
+ ("vit", "ViT"),
887
+ ("vit_hybrid", "ViT Hybrid"),
888
+ ("vit_mae", "ViTMAE"),
889
+ ("vit_msn", "ViTMSN"),
890
+ ("vitdet", "VitDet"),
891
+ ("vitmatte", "ViTMatte"),
892
+ ("vitpose", "ViTPose"),
893
+ ("vitpose_backbone", "ViTPoseBackbone"),
894
+ ("vits", "VITS"),
895
+ ("vivit", "ViViT"),
896
+ ("vjepa2", "VJEPA2Model"),
897
+ ("voxtral", "Voxtral"),
898
+ ("voxtral_encoder", "Voxtral Encoder"),
899
+ ("wav2vec2", "Wav2Vec2"),
900
+ ("wav2vec2-bert", "Wav2Vec2-BERT"),
901
+ ("wav2vec2-conformer", "Wav2Vec2-Conformer"),
902
+ ("wav2vec2_phoneme", "Wav2Vec2Phoneme"),
903
+ ("wavlm", "WavLM"),
904
+ ("whisper", "Whisper"),
905
+ ("xclip", "X-CLIP"),
906
+ ("xcodec", "X-CODEC"),
907
+ ("xglm", "XGLM"),
908
+ ("xlm", "XLM"),
909
+ ("xlm-prophetnet", "XLM-ProphetNet"),
910
+ ("xlm-roberta", "XLM-RoBERTa"),
911
+ ("xlm-roberta-xl", "XLM-RoBERTa-XL"),
912
+ ("xlm-v", "XLM-V"),
913
+ ("xlnet", "XLNet"),
914
+ ("xls_r", "XLS-R"),
915
+ ("xlsr_wav2vec2", "XLSR-Wav2Vec2"),
916
+ ("xlstm", "xLSTM"),
917
+ ("xmod", "X-MOD"),
918
+ ("yolos", "YOLOS"),
919
+ ("yoso", "YOSO"),
920
+ ("zamba", "Zamba"),
921
+ ("zamba2", "Zamba2"),
922
+ ("zoedepth", "ZoeDepth"),
923
+ ]
924
+ )
925
+
926
+ # This is tied to the `-` -> `_` replacement done in `model_type_to_module_name`. For example, instead of putting
927
+ # `transfo-xl` (as in `CONFIG_MAPPING_NAMES`), we should use `transfo_xl`.
928
+ DEPRECATED_MODELS = [
929
+ "bort",
930
+ "deta",
931
+ "efficientformer",
932
+ "ernie_m",
933
+ "gptsan_japanese",
934
+ "graphormer",
935
+ "jukebox",
936
+ "mctct",
937
+ "mega",
938
+ "mmbt",
939
+ "nat",
940
+ "nezha",
941
+ "open_llama",
942
+ "qdqbert",
943
+ "realm",
944
+ "retribert",
945
+ "speech_to_text_2",
946
+ "tapex",
947
+ "trajectory_transformer",
948
+ "transfo_xl",
949
+ "tvlt",
950
+ "van",
951
+ "vit_hybrid",
952
+ "xlm_prophetnet",
953
+ ]
954
+
955
+ SPECIAL_MODEL_TYPE_TO_MODULE_NAME = OrderedDict[str, str](
956
+ [
957
+ ("openai-gpt", "openai"),
958
+ ("data2vec-audio", "data2vec"),
959
+ ("data2vec-text", "data2vec"),
960
+ ("data2vec-vision", "data2vec"),
961
+ ("donut-swin", "donut"),
962
+ ("kosmos-2", "kosmos2"),
963
+ ("kosmos-2.5", "kosmos2_5"),
964
+ ("maskformer-swin", "maskformer"),
965
+ ("xclip", "x_clip"),
966
+ ("clip_vision_model", "clip"),
967
+ ("qwen2_audio_encoder", "qwen2_audio"),
968
+ ("voxtral_encoder", "voxtral"),
969
+ ("clip_text_model", "clip"),
970
+ ("aria_text", "aria"),
971
+ ("gemma3_text", "gemma3"),
972
+ ("gemma3n_audio", "gemma3n"),
973
+ ("gemma3n_text", "gemma3n"),
974
+ ("gemma3n_vision", "gemma3n"),
975
+ ("glm4v_text", "glm4v"),
976
+ ("glm4v_moe_text", "glm4v_moe"),
977
+ ("idefics3_vision", "idefics3"),
978
+ ("siglip_vision_model", "siglip"),
979
+ ("siglip2_vision_model", "siglip2"),
980
+ ("aimv2_vision_model", "aimv2"),
981
+ ("smolvlm_vision", "smolvlm"),
982
+ ("chinese_clip_vision_model", "chinese_clip"),
983
+ ("rt_detr_resnet", "rt_detr"),
984
+ ("granitevision", "llava_next"),
985
+ ("internvl_vision", "internvl"),
986
+ ("qwen2_5_vl_text", "qwen2_5_vl"),
987
+ ("qwen2_vl_text", "qwen2_vl"),
988
+ ("qwen3_vl_text", "qwen3_vl"),
989
+ ("qwen3_vl_moe_text", "qwen3_vl_moe"),
990
+ ("sam_vision_model", "sam"),
991
+ ("sam2_vision_model", "sam2"),
992
+ ("edgetam_vision_model", "edgetam"),
993
+ ("sam2_hiera_det_model", "sam2"),
994
+ ("sam_hq_vision_model", "sam_hq"),
995
+ ("llama4_text", "llama4"),
996
+ ("blip_2_qformer", "blip_2"),
997
+ ("fastspeech2_conformer_with_hifigan", "fastspeech2_conformer"),
998
+ ("perception_encoder", "perception_lm"),
999
+ ("parakeet_encoder", "parakeet"),
1000
+ ("parakeet_ctc", "parakeet"),
1001
+ ]
1002
+ )
1003
+
1004
+
1005
+ def model_type_to_module_name(key) -> str:
1006
+ """Converts a config key to the corresponding module."""
1007
+ # Special treatment
1008
+ if key in SPECIAL_MODEL_TYPE_TO_MODULE_NAME:
1009
+ key = SPECIAL_MODEL_TYPE_TO_MODULE_NAME[key]
1010
+
1011
+ if key in DEPRECATED_MODELS:
1012
+ key = f"deprecated.{key}"
1013
+ return key
1014
+
1015
+ key = key.replace("-", "_")
1016
+ if key in DEPRECATED_MODELS:
1017
+ key = f"deprecated.{key}"
1018
+
1019
+ return key
1020
+
1021
+
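
A few spot checks of the rules above (special-case table first, then `-` -> `_`, then the deprecated prefix), using entries that appear in the tables defined earlier in this file:

assert model_type_to_module_name("kosmos-2") == "kosmos2"                  # special-cased
assert model_type_to_module_name("xlm-roberta") == "xlm_roberta"           # "-" -> "_"
assert model_type_to_module_name("transfo-xl") == "deprecated.transfo_xl"  # deprecated
assert model_type_to_module_name("van") == "deprecated.van"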
1022
+ def config_class_to_model_type(config) -> Union[str, None]:
1023
+ """Converts a config class name to the corresponding model type"""
1024
+ for key, cls in CONFIG_MAPPING_NAMES.items():
1025
+ if cls == config:
1026
+ return key
1027
+ # if the key was not found, check the extra content
1028
+ for key, cls in CONFIG_MAPPING._extra_content.items():
1029
+ if cls.__name__ == config:
1030
+ return key
1031
+ return None
1032
+
1033
+
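
For example, the reverse lookup above resolves a config class name back to its model type, and returns `None` when nothing matches:

assert config_class_to_model_type("BertConfig") == "bert"
assert config_class_to_model_type("NoSuchConfig") is None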
1034
+ class _LazyConfigMapping(OrderedDict[str, type[PretrainedConfig]]):
1035
+ """
1036
+ A dictionary that lazily loads its values when they are requested.
1037
+ """
1038
+
1039
+ def __init__(self, mapping) -> None:
1040
+ self._mapping = mapping
1041
+ self._extra_content = {}
1042
+ self._modules = {}
1043
+
1044
+ def __getitem__(self, key: str) -> type[PretrainedConfig]:
1045
+ if key in self._extra_content:
1046
+ return self._extra_content[key]
1047
+ if key not in self._mapping:
1048
+ raise KeyError(key)
1049
+ value = self._mapping[key]
1050
+ module_name = model_type_to_module_name(key)
1051
+ if module_name not in self._modules:
1052
+ self._modules[module_name] = importlib.import_module(f".{module_name}", "transformers.models")
1053
+ if hasattr(self._modules[module_name], value):
1054
+ return getattr(self._modules[module_name], value)
1055
+
1056
+ # Some of the mappings have entries model_type -> config of another model type. In that case we try to grab the
1057
+ # object at the top level.
1058
+ transformers_module = importlib.import_module("transformers")
1059
+ return getattr(transformers_module, value)
1060
+
1061
+ def keys(self) -> list[str]:
1062
+ return list(self._mapping.keys()) + list(self._extra_content.keys())
1063
+
1064
+ def values(self) -> list[type[PretrainedConfig]]:
1065
+ return [self[k] for k in self._mapping] + list(self._extra_content.values())
1066
+
1067
+ def items(self) -> list[tuple[str, type[PretrainedConfig]]]:
1068
+ return [(k, self[k]) for k in self._mapping] + list(self._extra_content.items())
1069
+
1070
+ def __iter__(self) -> Iterator[str]:
1071
+ return iter(list(self._mapping.keys()) + list(self._extra_content.keys()))
1072
+
1073
+ def __contains__(self, item: object) -> bool:
1074
+ return item in self._mapping or item in self._extra_content
1075
+
1076
+ def register(self, key: str, value: type[PretrainedConfig], exist_ok=False) -> None:
1077
+ """
1078
+ Register a new configuration in this mapping.
1079
+ """
1080
+ if key in self._mapping and not exist_ok:
1081
+ raise ValueError(f"'{key}' is already used by a Transformers config, pick another name.")
1082
+ self._extra_content[key] = value
1083
+
1084
+
1085
+ CONFIG_MAPPING = _LazyConfigMapping(CONFIG_MAPPING_NAMES)
1086
+
1087
+
1088
+ class _LazyLoadAllMappings(OrderedDict[str, str]):
1089
+ """
1090
+ A mapping that will load all key/value pairs at the first access (either by indexing, requesting keys, values,
1091
+ etc.)
1092
+
1093
+ Args:
1094
+ mapping: The mapping to load.
1095
+ """
1096
+
1097
+ def __init__(self, mapping):
1098
+ self._mapping = mapping
1099
+ self._initialized = False
1100
+ self._data = {}
1101
+
1102
+ def _initialize(self):
1103
+ if self._initialized:
1104
+ return
1105
+
1106
+ for model_type, map_name in self._mapping.items():
1107
+ module_name = model_type_to_module_name(model_type)
1108
+ module = importlib.import_module(f".{module_name}", "transformers.models")
1109
+ mapping = getattr(module, map_name)
1110
+ self._data.update(mapping)
1111
+
1112
+ self._initialized = True
1113
+
1114
+ def __getitem__(self, key):
1115
+ self._initialize()
1116
+ return self._data[key]
1117
+
1118
+ def keys(self) -> KeysView[str]:
1119
+ self._initialize()
1120
+ return self._data.keys()
1121
+
1122
+ def values(self) -> ValuesView[str]:
1123
+ self._initialize()
1124
+ return self._data.values()
1125
+
1126
+ def items(self) -> KeysView[str]:
1127
+ self._initialize()
1128
+ return self._data.keys()
1129
+
1130
+ def __iter__(self) -> Iterator[str]:
1131
+ self._initialize()
1132
+ return iter(self._data)
1133
+
1134
+ def __contains__(self, item: object) -> bool:
1135
+ self._initialize()
1136
+ return item in self._data
1137
+
1138
+
1139
+ def _get_class_name(model_class: Union[str, list[str]]):
1140
+ if isinstance(model_class, (list, tuple)):
1141
+ return " or ".join([f"[`{c}`]" for c in model_class if c is not None])
1142
+ return f"[`{model_class}`]"
1143
+
1144
+
1145
+ def _list_model_options(indent, config_to_class=None, use_model_types=True):
1146
+ if config_to_class is None and not use_model_types:
1147
+ raise ValueError("Using `use_model_types=False` requires a `config_to_class` dictionary.")
1148
+ if use_model_types:
1149
+ if config_to_class is None:
1150
+ model_type_to_name = {model_type: f"[`{config}`]" for model_type, config in CONFIG_MAPPING_NAMES.items()}
1151
+ else:
1152
+ model_type_to_name = {
1153
+ model_type: _get_class_name(model_class)
1154
+ for model_type, model_class in config_to_class.items()
1155
+ if model_type in MODEL_NAMES_MAPPING
1156
+ }
1157
+ lines = [
1158
+ f"{indent}- **{model_type}** -- {model_type_to_name[model_type]} ({MODEL_NAMES_MAPPING[model_type]} model)"
1159
+ for model_type in sorted(model_type_to_name.keys())
1160
+ ]
1161
+ else:
1162
+ config_to_name = {
1163
+ CONFIG_MAPPING_NAMES[config]: _get_class_name(clas)
1164
+ for config, clas in config_to_class.items()
1165
+ if config in CONFIG_MAPPING_NAMES
1166
+ }
1167
+ config_to_model_name = {
1168
+ config: MODEL_NAMES_MAPPING[model_type] for model_type, config in CONFIG_MAPPING_NAMES.items()
1169
+ }
1170
+ lines = [
1171
+ f"{indent}- [`{config_name}`] configuration class:"
1172
+ f" {config_to_name[config_name]} ({config_to_model_name[config_name]} model)"
1173
+ for config_name in sorted(config_to_name.keys())
1174
+ ]
1175
+ return "\n".join(lines)
1176
+
1177
+
1178
+ def replace_list_option_in_docstrings(
1179
+ config_to_class=None, use_model_types: bool = True
1180
+ ) -> Callable[[_CallableT], _CallableT]:
1181
+ def docstring_decorator(fn):
1182
+ docstrings = fn.__doc__
1183
+ if docstrings is None:
1184
+ # Example: when Python is run with -OO, docstrings are stripped and fn.__doc__ is None
1185
+ return fn
1186
+ lines = docstrings.split("\n")
1187
+ i = 0
1188
+ while i < len(lines) and re.search(r"^(\s*)List options\s*$", lines[i]) is None:
1189
+ i += 1
1190
+ if i < len(lines):
1191
+ indent = re.search(r"^(\s*)List options\s*$", lines[i]).groups()[0]
1192
+ if use_model_types:
1193
+ indent = f"{indent} "
1194
+ lines[i] = _list_model_options(indent, config_to_class=config_to_class, use_model_types=use_model_types)
1195
+ docstrings = "\n".join(lines)
1196
+ else:
1197
+ raise ValueError(
1198
+ f"The function {fn} should have an empty 'List options' in its docstring as placeholder, current"
1199
+ f" docstring is:\n{docstrings}"
1200
+ )
1201
+ fn.__doc__ = docstrings
1202
+ return fn
1203
+
1204
+ return docstring_decorator
1205
+
1206
+
1207
+ class AutoConfig:
1208
+ r"""
1209
+ This is a generic configuration class that will be instantiated as one of the configuration classes of the library
1210
+ when created with the [`~AutoConfig.from_pretrained`] class method.
1211
+
1212
+ This class cannot be instantiated directly using `__init__()` (throws an error).
1213
+ """
1214
+
1215
+ def __init__(self) -> None:
1216
+ raise OSError(
1217
+ "AutoConfig is designed to be instantiated "
1218
+ "using the `AutoConfig.from_pretrained(pretrained_model_name_or_path)` method."
1219
+ )
1220
+
1221
+ @classmethod
1222
+ def for_model(cls, model_type: str, *args, **kwargs) -> PretrainedConfig:
1223
+ if model_type in CONFIG_MAPPING:
1224
+ config_class = CONFIG_MAPPING[model_type]
1225
+ return config_class(*args, **kwargs)
1226
+ raise ValueError(
1227
+ f"Unrecognized model identifier: {model_type}. Should contain one of {', '.join(CONFIG_MAPPING.keys())}"
1228
+ )
1229
+
1230
+ @classmethod
1231
+ @replace_list_option_in_docstrings()
1232
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike[str]], **kwargs):
1233
+ r"""
1234
+ Instantiate one of the configuration classes of the library from a pretrained model configuration.
1235
+
1236
+ The configuration class to instantiate is selected based on the `model_type` property of the config object that
1237
+ is loaded, or when it's missing, by falling back to using pattern matching on `pretrained_model_name_or_path`:
1238
+
1239
+ List options
1240
+
1241
+ Args:
1242
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
1243
+ Can be either:
1244
+
1245
+ - A string, the *model id* of a pretrained model configuration hosted inside a model repo on
1246
+ huggingface.co.
1247
+ - A path to a *directory* containing a configuration file saved using the
1248
+ [`~PretrainedConfig.save_pretrained`] method, or the [`~PreTrainedModel.save_pretrained`] method,
1249
+ e.g., `./my_model_directory/`.
1250
+ - A path or url to a saved configuration JSON *file*, e.g.,
1251
+ `./my_model_directory/configuration.json`.
1252
+ cache_dir (`str` or `os.PathLike`, *optional*):
1253
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the
1254
+ standard cache should not be used.
1255
+ force_download (`bool`, *optional*, defaults to `False`):
1256
+ Whether or not to force a (re-)download of the model weights and configuration files and override the
1257
+ cached versions if they exist.
1258
+ resume_download:
1259
+ Deprecated and ignored. All downloads are now resumed by default when possible.
1260
+ Will be removed in v5 of Transformers.
1261
+ proxies (`dict[str, str]`, *optional*):
1262
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
1263
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
1264
+ revision (`str`, *optional*, defaults to `"main"`):
1265
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
1266
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
1267
+ identifier allowed by git.
1268
+ return_unused_kwargs (`bool`, *optional*, defaults to `False`):
1269
+ If `False`, then this function returns just the final configuration object.
1270
+
1271
+ If `True`, then this functions returns a `Tuple(config, unused_kwargs)` where *unused_kwargs* is a
1272
+ dictionary consisting of the key/value pairs whose keys are not configuration attributes: i.e., the
1273
+ part of `kwargs` which has not been used to update `config` and is otherwise ignored.
1274
+ trust_remote_code (`bool`, *optional*, defaults to `False`):
1275
+ Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
1276
+ should only be set to `True` for repositories you trust and in which you have read the code, as it will
1277
+ execute code present on the Hub on your local machine.
1278
+ kwargs (additional keyword arguments, *optional*):
1279
+ The values in kwargs of any keys which are configuration attributes will be used to override the loaded
1280
+ values. Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled
1281
+ by the `return_unused_kwargs` keyword parameter.
1282
+
1283
+ Examples:
1284
+
1285
+ ```python
1286
+ >>> from transformers import AutoConfig
1287
+
1288
+ >>> # Download configuration from huggingface.co and cache.
1289
+ >>> config = AutoConfig.from_pretrained("google-bert/bert-base-uncased")
1290
+
1291
+ >>> # Download configuration from huggingface.co (user-uploaded) and cache.
1292
+ >>> config = AutoConfig.from_pretrained("dbmdz/bert-base-german-cased")
1293
+
1294
+ >>> # If configuration file is in a directory (e.g., was saved using *save_pretrained('./test/saved_model/')*).
1295
+ >>> config = AutoConfig.from_pretrained("./test/bert_saved_model/")
1296
+
1297
+ >>> # Load a specific configuration file.
1298
+ >>> config = AutoConfig.from_pretrained("./test/bert_saved_model/my_configuration.json")
1299
+
1300
+ >>> # Change some config attributes when loading a pretrained config.
1301
+ >>> config = AutoConfig.from_pretrained("google-bert/bert-base-uncased", output_attentions=True, foo=False)
1302
+ >>> config.output_attentions
1303
+ True
1304
+
1305
+ >>> config, unused_kwargs = AutoConfig.from_pretrained(
1306
+ ... "google-bert/bert-base-uncased", output_attentions=True, foo=False, return_unused_kwargs=True
1307
+ ... )
1308
+ >>> config.output_attentions
1309
+ True
1310
+
1311
+ >>> unused_kwargs
1312
+ {'foo': False}
1313
+ ```
1314
+ """
1315
+ use_auth_token = kwargs.pop("use_auth_token", None)
1316
+ if use_auth_token is not None:
1317
+ warnings.warn(
1318
+ "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
1319
+ FutureWarning,
1320
+ )
1321
+ if kwargs.get("token") is not None:
1322
+ raise ValueError(
1323
+ "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
1324
+ )
1325
+ kwargs["token"] = use_auth_token
1326
+
1327
+ kwargs["_from_auto"] = True
1328
+ kwargs["name_or_path"] = pretrained_model_name_or_path
1329
+ trust_remote_code = kwargs.pop("trust_remote_code", None)
1330
+ code_revision = kwargs.pop("code_revision", None)
1331
+
1332
+ config_dict, unused_kwargs = PretrainedConfig.get_config_dict(pretrained_model_name_or_path, **kwargs)
1333
+ has_remote_code = "auto_map" in config_dict and "AutoConfig" in config_dict["auto_map"]
1334
+ has_local_code = "model_type" in config_dict and config_dict["model_type"] in CONFIG_MAPPING
1335
+ if has_remote_code:
1336
+ class_ref = config_dict["auto_map"]["AutoConfig"]
1337
+ if "--" in class_ref:
1338
+ upstream_repo = class_ref.split("--")[0]
1339
+ else:
1340
+ upstream_repo = None
1341
+ trust_remote_code = resolve_trust_remote_code(
1342
+ trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code, upstream_repo
1343
+ )
1344
+
1345
+ if has_remote_code and trust_remote_code:
1346
+ config_class = get_class_from_dynamic_module(
1347
+ class_ref, pretrained_model_name_or_path, code_revision=code_revision, **kwargs
1348
+ )
1349
+ config_class.register_for_auto_class()
1350
+ return config_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
1351
+ elif "model_type" in config_dict:
1352
+ # Apply heuristic: if model_type is mistral but layer_types is present, treat as ministral
1353
+ if config_dict["model_type"] == "mistral" and "layer_types" in config_dict:
1354
+ logger.info(
1355
+ "Detected mistral model with layer_types, treating as ministral for alternating attention compatibility. "
1356
+ )
1357
+ config_dict["model_type"] = "ministral"
1358
+
1359
+ try:
1360
+ config_class = CONFIG_MAPPING[config_dict["model_type"]]
1361
+ except KeyError:
1362
+ raise ValueError(
1363
+ f"The checkpoint you are trying to load has model type `{config_dict['model_type']}` "
1364
+ "but Transformers does not recognize this architecture. This could be because of an "
1365
+ "issue with the checkpoint, or because your version of Transformers is out of date.\n\n"
1366
+ "You can update Transformers with the command `pip install --upgrade transformers`. If this "
1367
+ "does not work, and the checkpoint is very new, then there may not be a release version "
1368
+ "that supports this model yet. In this case, you can get the most up-to-date code by installing "
1369
+ "Transformers from source with the command "
1370
+ "`pip install git+https://github.com/huggingface/transformers.git`"
1371
+ )
1372
+ return config_class.from_dict(config_dict, **unused_kwargs)
1373
+ else:
1374
+ # Fallback: use pattern matching on the string.
1375
+ # We go from longer names to shorter names to catch roberta before bert (for instance)
1376
+ for pattern in sorted(CONFIG_MAPPING.keys(), key=len, reverse=True):
1377
+ if pattern in str(pretrained_model_name_or_path):
1378
+ return CONFIG_MAPPING[pattern].from_dict(config_dict, **unused_kwargs)
1379
+
1380
+ raise ValueError(
1381
+ f"Unrecognized model in {pretrained_model_name_or_path}. "
1382
+ f"Should have a `model_type` key in its {CONFIG_NAME}, or contain one of the following strings "
1383
+ f"in its name: {', '.join(CONFIG_MAPPING.keys())}"
1384
+ )
1385
+
1386
+ @staticmethod
1387
+ def register(model_type, config, exist_ok=False) -> None:
1388
+ """
1389
+ Register a new configuration for this class.
1390
+
1391
+ Args:
1392
+ model_type (`str`): The model type like "bert" or "gpt".
1393
+ config ([`PretrainedConfig`]): The config to register.
1394
+ """
1395
+ if issubclass(config, PretrainedConfig) and config.model_type != model_type:
1396
+ raise ValueError(
1397
+ "The config you are passing has a `model_type` attribute that is not consistent with the model type "
1398
+ f"you passed (config has {config.model_type} and you passed {model_type}). Fix one of those so they "
1399
+ "match!"
1400
+ )
1401
+ CONFIG_MAPPING.register(model_type, config, exist_ok=exist_ok)
1402
+
1403
+
1404
+ __all__ = ["CONFIG_MAPPING", "MODEL_NAMES_MAPPING", "AutoConfig"]
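`AutoConfig.register` and `_LazyConfigMapping.register` are the extension point this module exposes: registration checks that the class's `model_type` matches the key, then stores the class in `CONFIG_MAPPING._extra_content` so later lookups resolve it without touching the lazy import machinery. A minimal sketch of that flow (the `"my-new-model"` type and `MyNewModelConfig` class are hypothetical, for illustration only):

```python
from transformers import AutoConfig, PretrainedConfig


class MyNewModelConfig(PretrainedConfig):
    # model_type must match the key passed to AutoConfig.register below,
    # otherwise register() raises a ValueError.
    model_type = "my-new-model"

    def __init__(self, hidden_size=64, **kwargs):
        super().__init__(**kwargs)
        self.hidden_size = hidden_size


# Stores the class in CONFIG_MAPPING._extra_content via _LazyConfigMapping.register.
AutoConfig.register("my-new-model", MyNewModelConfig)

# for_model looks the key up in CONFIG_MAPPING and instantiates the class directly.
config = AutoConfig.for_model("my-new-model", hidden_size=128)
print(type(config).__name__, config.hidden_size)  # MyNewModelConfig 128
```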
venv/lib/python3.13/site-packages/transformers/models/auto/feature_extraction_auto.py ADDED
@@ -0,0 +1,422 @@
1
+ # coding=utf-8
2
+ # Copyright 2021 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """AutoFeatureExtractor class."""
16
+
17
+ import importlib
18
+ import json
19
+ import os
20
+ import warnings
21
+ from collections import OrderedDict
22
+ from typing import Optional, Union
23
+
24
+ # Build the list of all feature extractors
25
+ from ...configuration_utils import PretrainedConfig
26
+ from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
27
+ from ...feature_extraction_utils import FeatureExtractionMixin
28
+ from ...utils import CONFIG_NAME, FEATURE_EXTRACTOR_NAME, cached_file, logging
29
+ from .auto_factory import _LazyAutoMapping
30
+ from .configuration_auto import (
31
+ CONFIG_MAPPING_NAMES,
32
+ AutoConfig,
33
+ model_type_to_module_name,
34
+ replace_list_option_in_docstrings,
35
+ )
36
+
37
+
38
+ logger = logging.get_logger(__name__)
39
+
40
+ FEATURE_EXTRACTOR_MAPPING_NAMES = OrderedDict(
41
+ [
42
+ ("audio-spectrogram-transformer", "ASTFeatureExtractor"),
43
+ ("beit", "BeitFeatureExtractor"),
44
+ ("chinese_clip", "ChineseCLIPFeatureExtractor"),
45
+ ("clap", "ClapFeatureExtractor"),
46
+ ("clip", "CLIPFeatureExtractor"),
47
+ ("clipseg", "ViTFeatureExtractor"),
48
+ ("clvp", "ClvpFeatureExtractor"),
49
+ ("conditional_detr", "ConditionalDetrFeatureExtractor"),
50
+ ("convnext", "ConvNextFeatureExtractor"),
51
+ ("cvt", "ConvNextFeatureExtractor"),
52
+ ("dac", "DacFeatureExtractor"),
53
+ ("data2vec-audio", "Wav2Vec2FeatureExtractor"),
54
+ ("data2vec-vision", "BeitFeatureExtractor"),
55
+ ("deformable_detr", "DeformableDetrFeatureExtractor"),
56
+ ("deit", "DeiTFeatureExtractor"),
57
+ ("detr", "DetrFeatureExtractor"),
58
+ ("dia", "DiaFeatureExtractor"),
59
+ ("dinat", "ViTFeatureExtractor"),
60
+ ("donut-swin", "DonutFeatureExtractor"),
61
+ ("dpt", "DPTFeatureExtractor"),
62
+ ("encodec", "EncodecFeatureExtractor"),
63
+ ("flava", "FlavaFeatureExtractor"),
64
+ ("gemma3n", "Gemma3nAudioFeatureExtractor"),
65
+ ("glpn", "GLPNFeatureExtractor"),
66
+ ("granite_speech", "GraniteSpeechFeatureExtractor"),
67
+ ("groupvit", "CLIPFeatureExtractor"),
68
+ ("hubert", "Wav2Vec2FeatureExtractor"),
69
+ ("imagegpt", "ImageGPTFeatureExtractor"),
70
+ ("kyutai_speech_to_text", "KyutaiSpeechToTextFeatureExtractor"),
71
+ ("layoutlmv2", "LayoutLMv2FeatureExtractor"),
72
+ ("layoutlmv3", "LayoutLMv3FeatureExtractor"),
73
+ ("levit", "LevitFeatureExtractor"),
74
+ ("maskformer", "MaskFormerFeatureExtractor"),
75
+ ("mctct", "MCTCTFeatureExtractor"),
76
+ ("mimi", "EncodecFeatureExtractor"),
77
+ ("mobilenet_v1", "MobileNetV1FeatureExtractor"),
78
+ ("mobilenet_v2", "MobileNetV2FeatureExtractor"),
79
+ ("mobilevit", "MobileViTFeatureExtractor"),
80
+ ("moonshine", "Wav2Vec2FeatureExtractor"),
81
+ ("moshi", "EncodecFeatureExtractor"),
82
+ ("nat", "ViTFeatureExtractor"),
83
+ ("owlvit", "OwlViTFeatureExtractor"),
84
+ ("parakeet_ctc", "ParakeetFeatureExtractor"),
85
+ ("parakeet_encoder", "ParakeetFeatureExtractor"),
86
+ ("perceiver", "PerceiverFeatureExtractor"),
87
+ ("phi4_multimodal", "Phi4MultimodalFeatureExtractor"),
88
+ ("poolformer", "PoolFormerFeatureExtractor"),
89
+ ("pop2piano", "Pop2PianoFeatureExtractor"),
90
+ ("regnet", "ConvNextFeatureExtractor"),
91
+ ("resnet", "ConvNextFeatureExtractor"),
92
+ ("seamless_m4t", "SeamlessM4TFeatureExtractor"),
93
+ ("seamless_m4t_v2", "SeamlessM4TFeatureExtractor"),
94
+ ("segformer", "SegformerFeatureExtractor"),
95
+ ("sew", "Wav2Vec2FeatureExtractor"),
96
+ ("sew-d", "Wav2Vec2FeatureExtractor"),
97
+ ("speech_to_text", "Speech2TextFeatureExtractor"),
98
+ ("speecht5", "SpeechT5FeatureExtractor"),
99
+ ("swiftformer", "ViTFeatureExtractor"),
100
+ ("swin", "ViTFeatureExtractor"),
101
+ ("swinv2", "ViTFeatureExtractor"),
102
+ ("table-transformer", "DetrFeatureExtractor"),
103
+ ("timesformer", "VideoMAEFeatureExtractor"),
104
+ ("tvlt", "TvltFeatureExtractor"),
105
+ ("unispeech", "Wav2Vec2FeatureExtractor"),
106
+ ("unispeech-sat", "Wav2Vec2FeatureExtractor"),
107
+ ("univnet", "UnivNetFeatureExtractor"),
108
+ ("van", "ConvNextFeatureExtractor"),
109
+ ("videomae", "VideoMAEFeatureExtractor"),
110
+ ("vilt", "ViltFeatureExtractor"),
111
+ ("vit", "ViTFeatureExtractor"),
112
+ ("vit_mae", "ViTFeatureExtractor"),
113
+ ("vit_msn", "ViTFeatureExtractor"),
114
+ ("wav2vec2", "Wav2Vec2FeatureExtractor"),
115
+ ("wav2vec2-bert", "Wav2Vec2FeatureExtractor"),
116
+ ("wav2vec2-conformer", "Wav2Vec2FeatureExtractor"),
117
+ ("wavlm", "Wav2Vec2FeatureExtractor"),
118
+ ("whisper", "WhisperFeatureExtractor"),
119
+ ("xclip", "CLIPFeatureExtractor"),
120
+ ("xcodec", "DacFeatureExtractor"),
121
+ ("yolos", "YolosFeatureExtractor"),
122
+ ]
123
+ )
124
+
125
+ FEATURE_EXTRACTOR_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FEATURE_EXTRACTOR_MAPPING_NAMES)
126
+
127
+
128
+ def feature_extractor_class_from_name(class_name: str):
129
+ for module_name, extractors in FEATURE_EXTRACTOR_MAPPING_NAMES.items():
130
+ if class_name in extractors:
131
+ module_name = model_type_to_module_name(module_name)
132
+
133
+ module = importlib.import_module(f".{module_name}", "transformers.models")
134
+ try:
135
+ return getattr(module, class_name)
136
+ except AttributeError:
137
+ continue
138
+
139
+ for extractor in FEATURE_EXTRACTOR_MAPPING._extra_content.values():
140
+ if getattr(extractor, "__name__", None) == class_name:
141
+ return extractor
142
+
143
+ # We did not find the class, but maybe it's because a dep is missing. In that case, the class will be in the main
144
+ # init and we return the proper dummy to get an appropriate error message.
145
+ main_module = importlib.import_module("transformers")
146
+ if hasattr(main_module, class_name):
147
+ return getattr(main_module, class_name)
148
+
149
+ return None
150
+
151
+
152
+ def get_feature_extractor_config(
153
+ pretrained_model_name_or_path: Union[str, os.PathLike],
154
+ cache_dir: Optional[Union[str, os.PathLike]] = None,
155
+ force_download: bool = False,
156
+ resume_download: Optional[bool] = None,
157
+ proxies: Optional[dict[str, str]] = None,
158
+ token: Optional[Union[bool, str]] = None,
159
+ revision: Optional[str] = None,
160
+ local_files_only: bool = False,
161
+ **kwargs,
162
+ ):
163
+ """
164
+ Loads the feature extractor configuration from a pretrained model feature extractor configuration.
165
+
166
+ Args:
167
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
168
+ This can be either:
169
+
170
+ - a string, the *model id* of a pretrained model configuration hosted inside a model repo on
171
+ huggingface.co.
172
+ - a path to a *directory* containing a configuration file saved using the
173
+ [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.
174
+
175
+ cache_dir (`str` or `os.PathLike`, *optional*):
176
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
177
+ cache should not be used.
178
+ force_download (`bool`, *optional*, defaults to `False`):
179
+ Whether or not to force a (re-)download of the configuration files and override the cached versions if they
180
+ exist.
181
+ resume_download:
182
+ Deprecated and ignored. All downloads are now resumed by default when possible.
183
+ Will be removed in v5 of Transformers.
184
+ proxies (`dict[str, str]`, *optional*):
185
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
186
+ 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
187
+ token (`str` or *bool*, *optional*):
188
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
189
+ when running `hf auth login` (stored in `~/.huggingface`).
190
+ revision (`str`, *optional*, defaults to `"main"`):
191
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
192
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
193
+ identifier allowed by git.
194
+ local_files_only (`bool`, *optional*, defaults to `False`):
195
+ If `True`, will only try to load the feature extractor configuration from local files.
196
+
197
+ <Tip>
198
+
199
+ Passing `token=True` is required when you want to use a private model.
200
+
201
+ </Tip>
202
+
203
+ Returns:
204
+ `Dict`: The configuration of the feature extractor.
205
+
206
+ Examples:
207
+
208
+ ```python
209
+ # Download configuration from huggingface.co and cache.
210
+ feature_extractor_config = get_feature_extractor_config("facebook/wav2vec2-base-960h")
211
+ # This model does not have a feature extractor config so the result will be an empty dict.
212
+ feature_extractor_config = get_feature_extractor_config("FacebookAI/xlm-roberta-base")
213
+
214
+ # Save a pretrained feature extractor locally and you can reload its config
215
+ from transformers import AutoFeatureExtractor
216
+
217
+ feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")
218
+ feature_extractor.save_pretrained("feature-extractor-test")
219
+ feature_extractor_config = get_feature_extractor_config("feature-extractor-test")
220
+ ```"""
221
+ use_auth_token = kwargs.pop("use_auth_token", None)
222
+ if use_auth_token is not None:
223
+ warnings.warn(
224
+ "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
225
+ FutureWarning,
226
+ )
227
+ if token is not None:
228
+ raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
229
+ token = use_auth_token
230
+
231
+ resolved_config_file = cached_file(
232
+ pretrained_model_name_or_path,
233
+ FEATURE_EXTRACTOR_NAME,
234
+ cache_dir=cache_dir,
235
+ force_download=force_download,
236
+ resume_download=resume_download,
237
+ proxies=proxies,
238
+ token=token,
239
+ revision=revision,
240
+ local_files_only=local_files_only,
241
+ _raise_exceptions_for_gated_repo=False,
242
+ _raise_exceptions_for_missing_entries=False,
243
+ _raise_exceptions_for_connection_errors=False,
244
+ )
245
+ if resolved_config_file is None:
246
+ logger.info(
247
+ "Could not locate the feature extractor configuration file, will try to use the model config instead."
248
+ )
249
+ return {}
250
+
251
+ with open(resolved_config_file, encoding="utf-8") as reader:
252
+ return json.load(reader)
253
+
254
+
255
+ class AutoFeatureExtractor:
256
+ r"""
257
+ This is a generic feature extractor class that will be instantiated as one of the feature extractor classes of the
258
+ library when created with the [`AutoFeatureExtractor.from_pretrained`] class method.
259
+
260
+ This class cannot be instantiated directly using `__init__()` (throws an error).
261
+ """
262
+
263
+ def __init__(self):
264
+ raise OSError(
265
+ "AutoFeatureExtractor is designed to be instantiated "
266
+ "using the `AutoFeatureExtractor.from_pretrained(pretrained_model_name_or_path)` method."
267
+ )
268
+
269
+ @classmethod
270
+ @replace_list_option_in_docstrings(FEATURE_EXTRACTOR_MAPPING_NAMES)
271
+ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
272
+ r"""
273
+ Instantiate one of the feature extractor classes of the library from a pretrained model vocabulary.
274
+
275
+ The feature extractor class to instantiate is selected based on the `model_type` property of the config object
276
+ (either passed as an argument or loaded from `pretrained_model_name_or_path` if possible), or when it's
277
+ missing, by falling back to using pattern matching on `pretrained_model_name_or_path`:
278
+
279
+ List options
280
+
281
+ Params:
282
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
283
+ This can be either:
284
+
285
+ - a string, the *model id* of a pretrained feature_extractor hosted inside a model repo on
286
+ huggingface.co.
287
+ - a path to a *directory* containing a feature extractor file saved using the
288
+ [`~feature_extraction_utils.FeatureExtractionMixin.save_pretrained`] method, e.g.,
289
+ `./my_model_directory/`.
290
+ - a path or url to a saved feature extractor JSON *file*, e.g.,
291
+ `./my_model_directory/preprocessor_config.json`.
292
+ cache_dir (`str` or `os.PathLike`, *optional*):
293
+ Path to a directory in which a downloaded pretrained model feature extractor should be cached if the
294
+ standard cache should not be used.
295
+ force_download (`bool`, *optional*, defaults to `False`):
296
+ Whether or not to force a (re-)download of the feature extractor files and override the cached versions
297
+ if they exist.
298
+ resume_download:
299
+ Deprecated and ignored. All downloads are now resumed by default when possible.
300
+ Will be removed in v5 of Transformers.
301
+ proxies (`dict[str, str]`, *optional*):
302
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
303
+ 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
304
+ token (`str` or *bool*, *optional*):
305
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
306
+ when running `hf auth login` (stored in `~/.huggingface`).
307
+ revision (`str`, *optional*, defaults to `"main"`):
308
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
309
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
310
+ identifier allowed by git.
311
+ return_unused_kwargs (`bool`, *optional*, defaults to `False`):
312
+ If `False`, then this function returns just the final feature extractor object. If `True`, then this
313
+ functions returns a `Tuple(feature_extractor, unused_kwargs)` where *unused_kwargs* is a dictionary
314
+ consisting of the key/value pairs whose keys are not feature extractor attributes: i.e., the part of
315
+ `kwargs` which has not been used to update `feature_extractor` and is otherwise ignored.
316
+ trust_remote_code (`bool`, *optional*, defaults to `False`):
317
+ Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
318
+ should only be set to `True` for repositories you trust and in which you have read the code, as it will
319
+ execute code present on the Hub on your local machine.
320
+ kwargs (`dict[str, Any]`, *optional*):
321
+ The values in kwargs of any keys which are feature extractor attributes will be used to override the
322
+ loaded values. Behavior concerning key/value pairs whose keys are *not* feature extractor attributes is
323
+ controlled by the `return_unused_kwargs` keyword parameter.
324
+
325
+ <Tip>
326
+
327
+ Passing `token=True` is required when you want to use a private model.
328
+
329
+ </Tip>
330
+
331
+ Examples:
332
+
333
+ ```python
334
+ >>> from transformers import AutoFeatureExtractor
335
+
336
+ >>> # Download feature extractor from huggingface.co and cache.
337
+ >>> feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")
338
+
339
+ >>> # If feature extractor files are in a directory (e.g. feature extractor was saved using *save_pretrained('./test/saved_model/')*)
340
+ >>> # feature_extractor = AutoFeatureExtractor.from_pretrained("./test/saved_model/")
341
+ ```"""
342
+ use_auth_token = kwargs.pop("use_auth_token", None)
343
+ if use_auth_token is not None:
344
+ warnings.warn(
345
+ "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
346
+ FutureWarning,
347
+ )
348
+ if kwargs.get("token") is not None:
349
+ raise ValueError(
350
+ "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
351
+ )
352
+ kwargs["token"] = use_auth_token
353
+
354
+ config = kwargs.pop("config", None)
355
+ trust_remote_code = kwargs.pop("trust_remote_code", None)
356
+ kwargs["_from_auto"] = True
357
+
358
+ config_dict, _ = FeatureExtractionMixin.get_feature_extractor_dict(pretrained_model_name_or_path, **kwargs)
359
+ feature_extractor_class = config_dict.get("feature_extractor_type", None)
360
+ feature_extractor_auto_map = None
361
+ if "AutoFeatureExtractor" in config_dict.get("auto_map", {}):
362
+ feature_extractor_auto_map = config_dict["auto_map"]["AutoFeatureExtractor"]
363
+
364
+ # If we don't find the feature extractor class in the feature extractor config, let's try the model config.
365
+ if feature_extractor_class is None and feature_extractor_auto_map is None:
366
+ if not isinstance(config, PretrainedConfig):
367
+ config = AutoConfig.from_pretrained(
368
+ pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
369
+ )
370
+ # It could be in `config.feature_extractor_type`
371
+ feature_extractor_class = getattr(config, "feature_extractor_type", None)
372
+ if hasattr(config, "auto_map") and "AutoFeatureExtractor" in config.auto_map:
373
+ feature_extractor_auto_map = config.auto_map["AutoFeatureExtractor"]
374
+
375
+ if feature_extractor_class is not None:
376
+ feature_extractor_class = feature_extractor_class_from_name(feature_extractor_class)
377
+
378
+ has_remote_code = feature_extractor_auto_map is not None
379
+ has_local_code = feature_extractor_class is not None or type(config) in FEATURE_EXTRACTOR_MAPPING
380
+ if has_remote_code:
381
+ if "--" in feature_extractor_auto_map:
382
+ upstream_repo = feature_extractor_auto_map.split("--")[0]
383
+ else:
384
+ upstream_repo = None
385
+ trust_remote_code = resolve_trust_remote_code(
386
+ trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code, upstream_repo
387
+ )
388
+
389
+ if has_remote_code and trust_remote_code:
390
+ feature_extractor_class = get_class_from_dynamic_module(
391
+ feature_extractor_auto_map, pretrained_model_name_or_path, **kwargs
392
+ )
393
+ _ = kwargs.pop("code_revision", None)
394
+ feature_extractor_class.register_for_auto_class()
395
+ return feature_extractor_class.from_dict(config_dict, **kwargs)
396
+ elif feature_extractor_class is not None:
397
+ return feature_extractor_class.from_dict(config_dict, **kwargs)
398
+ # Last try: we use the FEATURE_EXTRACTOR_MAPPING.
399
+ elif type(config) in FEATURE_EXTRACTOR_MAPPING:
400
+ feature_extractor_class = FEATURE_EXTRACTOR_MAPPING[type(config)]
401
+ return feature_extractor_class.from_dict(config_dict, **kwargs)
402
+
403
+ raise ValueError(
404
+ f"Unrecognized feature extractor in {pretrained_model_name_or_path}. Should have a "
405
+ f"`feature_extractor_type` key in its {FEATURE_EXTRACTOR_NAME} of {CONFIG_NAME}, or one of the following "
406
+ f"`model_type` keys in its {CONFIG_NAME}: {', '.join(c for c in FEATURE_EXTRACTOR_MAPPING_NAMES)}"
407
+ )
408
+
409
+ @staticmethod
410
+ def register(config_class, feature_extractor_class, exist_ok=False):
411
+ """
412
+ Register a new feature extractor for this class.
413
+
414
+ Args:
415
+ config_class ([`PretrainedConfig`]):
416
+ The configuration corresponding to the model to register.
417
+ feature_extractor_class ([`FeatureExtractorMixin`]): The feature extractor to register.
418
+ """
419
+ FEATURE_EXTRACTOR_MAPPING.register(config_class, feature_extractor_class, exist_ok=exist_ok)
420
+
421
+
422
+ __all__ = ["FEATURE_EXTRACTOR_MAPPING", "AutoFeatureExtractor"]
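As a quick orientation to the helpers defined above, the sketch below strings them together; it assumes only the `facebook/wav2vec2-base-960h` checkpoint already used in the docstrings, which ships a `preprocessor_config.json`:

```python
from transformers import AutoFeatureExtractor
from transformers.models.auto.feature_extraction_auto import (
    feature_extractor_class_from_name,
    get_feature_extractor_config,
)

# Resolve a class object from its name via FEATURE_EXTRACTOR_MAPPING_NAMES.
extractor_cls = feature_extractor_class_from_name("Wav2Vec2FeatureExtractor")

# Load preprocessor_config.json as a plain dict ({} if the repo has none).
raw_config = get_feature_extractor_config("facebook/wav2vec2-base-960h")

# The usual entry point: pick and instantiate the right extractor automatically,
# falling back to the model config's model_type when needed.
feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")
print(type(feature_extractor).__name__)  # Wav2Vec2FeatureExtractor
```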
venv/lib/python3.13/site-packages/transformers/models/auto/image_processing_auto.py ADDED
@@ -0,0 +1,688 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """AutoImageProcessor class."""
16
+
17
+ import importlib
18
+ import json
19
+ import os
20
+ import warnings
21
+ from collections import OrderedDict
22
+ from typing import TYPE_CHECKING, Optional, Union
23
+
24
+ # Build the list of all image processors
25
+ from ...configuration_utils import PretrainedConfig
26
+ from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
27
+ from ...image_processing_utils import ImageProcessingMixin
28
+ from ...image_processing_utils_fast import BaseImageProcessorFast
29
+ from ...utils import (
30
+ CONFIG_NAME,
31
+ IMAGE_PROCESSOR_NAME,
32
+ cached_file,
33
+ is_timm_config_dict,
34
+ is_timm_local_checkpoint,
35
+ is_torchvision_available,
36
+ is_vision_available,
37
+ logging,
38
+ )
39
+ from ...utils.import_utils import requires
40
+ from .auto_factory import _LazyAutoMapping
41
+ from .configuration_auto import (
42
+ CONFIG_MAPPING_NAMES,
43
+ AutoConfig,
44
+ model_type_to_module_name,
45
+ replace_list_option_in_docstrings,
46
+ )
47
+
48
+
49
+ logger = logging.get_logger(__name__)
50
+
51
+
52
+ FORCE_FAST_IMAGE_PROCESSOR = ["Qwen2VLImageProcessor"]
53
+
54
+
55
+ if TYPE_CHECKING:
56
+ # This significantly improves completion suggestion performance when
57
+ # the transformers package is used with Microsoft's Pylance language server.
58
+ IMAGE_PROCESSOR_MAPPING_NAMES: OrderedDict[str, tuple[Optional[str], Optional[str]]] = OrderedDict()
59
+ else:
60
+ IMAGE_PROCESSOR_MAPPING_NAMES = OrderedDict(
61
+ [
62
+ ("aimv2", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
63
+ ("aimv2_vision_model", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
64
+ ("align", ("EfficientNetImageProcessor", "EfficientNetImageProcessorFast")),
65
+ ("aria", ("AriaImageProcessor", None)),
66
+ ("beit", ("BeitImageProcessor", "BeitImageProcessorFast")),
67
+ ("bit", ("BitImageProcessor", "BitImageProcessorFast")),
68
+ ("blip", ("BlipImageProcessor", "BlipImageProcessorFast")),
69
+ ("blip-2", ("BlipImageProcessor", "BlipImageProcessorFast")),
70
+ ("bridgetower", ("BridgeTowerImageProcessor", "BridgeTowerImageProcessorFast")),
71
+ ("chameleon", ("ChameleonImageProcessor", "ChameleonImageProcessorFast")),
72
+ ("chinese_clip", ("ChineseCLIPImageProcessor", "ChineseCLIPImageProcessorFast")),
73
+ ("clip", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
74
+ ("clipseg", ("ViTImageProcessor", "ViTImageProcessorFast")),
75
+ ("cohere2_vision", (None, "Cohere2VisionImageProcessorFast")),
76
+ ("conditional_detr", ("ConditionalDetrImageProcessor", "ConditionalDetrImageProcessorFast")),
77
+ ("convnext", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")),
78
+ ("convnextv2", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")),
79
+ ("cvt", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")),
80
+ ("data2vec-vision", ("BeitImageProcessor", "BeitImageProcessorFast")),
81
+ ("deepseek_vl", ("DeepseekVLImageProcessor", "DeepseekVLImageProcessorFast")),
82
+ ("deepseek_vl_hybrid", ("DeepseekVLHybridImageProcessor", "DeepseekVLHybridImageProcessorFast")),
83
+ ("deformable_detr", ("DeformableDetrImageProcessor", "DeformableDetrImageProcessorFast")),
84
+ ("deit", ("DeiTImageProcessor", "DeiTImageProcessorFast")),
85
+ ("depth_anything", ("DPTImageProcessor", "DPTImageProcessorFast")),
86
+ ("depth_pro", ("DepthProImageProcessor", "DepthProImageProcessorFast")),
87
+ ("deta", ("DetaImageProcessor", None)),
88
+ ("detr", ("DetrImageProcessor", "DetrImageProcessorFast")),
89
+ ("dinat", ("ViTImageProcessor", "ViTImageProcessorFast")),
90
+ ("dinov2", ("BitImageProcessor", "BitImageProcessorFast")),
91
+ ("dinov3_vit", (None, "DINOv3ViTImageProcessorFast")),
92
+ ("donut-swin", ("DonutImageProcessor", "DonutImageProcessorFast")),
93
+ ("dpt", ("DPTImageProcessor", "DPTImageProcessorFast")),
94
+ ("edgetam", (None, "Sam2ImageProcessorFast")),
95
+ ("efficientformer", ("EfficientFormerImageProcessor", None)),
96
+ ("efficientloftr", ("EfficientLoFTRImageProcessor", "EfficientLoFTRImageProcessorFast")),
97
+ ("efficientnet", ("EfficientNetImageProcessor", "EfficientNetImageProcessorFast")),
98
+ ("eomt", ("EomtImageProcessor", "EomtImageProcessorFast")),
99
+ ("flava", ("FlavaImageProcessor", "FlavaImageProcessorFast")),
100
+ ("focalnet", ("BitImageProcessor", "BitImageProcessorFast")),
101
+ ("fuyu", ("FuyuImageProcessor", None)),
102
+ ("gemma3", ("Gemma3ImageProcessor", "Gemma3ImageProcessorFast")),
103
+ ("gemma3n", ("SiglipImageProcessor", "SiglipImageProcessorFast")),
104
+ ("git", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
105
+ ("glm4v", ("Glm4vImageProcessor", "Glm4vImageProcessorFast")),
106
+ ("glpn", ("GLPNImageProcessor", None)),
107
+ ("got_ocr2", ("GotOcr2ImageProcessor", "GotOcr2ImageProcessorFast")),
108
+ ("grounding-dino", ("GroundingDinoImageProcessor", "GroundingDinoImageProcessorFast")),
109
+ ("groupvit", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
110
+ ("hiera", ("BitImageProcessor", "BitImageProcessorFast")),
111
+ ("idefics", ("IdeficsImageProcessor", None)),
112
+ ("idefics2", ("Idefics2ImageProcessor", "Idefics2ImageProcessorFast")),
113
+ ("idefics3", ("Idefics3ImageProcessor", "Idefics3ImageProcessorFast")),
114
+ ("ijepa", ("ViTImageProcessor", "ViTImageProcessorFast")),
115
+ ("imagegpt", ("ImageGPTImageProcessor", "ImageGPTImageProcessorFast")),
116
+ ("instructblip", ("BlipImageProcessor", "BlipImageProcessorFast")),
117
+ ("instructblipvideo", ("InstructBlipVideoImageProcessor", None)),
118
+ ("janus", ("JanusImageProcessor", "JanusImageProcessorFast")),
119
+ ("kosmos-2", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
120
+ ("kosmos-2.5", ("Kosmos2_5ImageProcessor", "Kosmos2_5ImageProcessorFast")),
121
+ ("layoutlmv2", ("LayoutLMv2ImageProcessor", "LayoutLMv2ImageProcessorFast")),
122
+ ("layoutlmv3", ("LayoutLMv3ImageProcessor", "LayoutLMv3ImageProcessorFast")),
123
+ ("levit", ("LevitImageProcessor", "LevitImageProcessorFast")),
124
+ ("lfm2_vl", (None, "Lfm2VlImageProcessorFast")),
125
+ ("lightglue", ("LightGlueImageProcessor", None)),
126
+ ("llama4", ("Llama4ImageProcessor", "Llama4ImageProcessorFast")),
127
+ ("llava", ("LlavaImageProcessor", "LlavaImageProcessorFast")),
128
+ ("llava_next", ("LlavaNextImageProcessor", "LlavaNextImageProcessorFast")),
129
+ ("llava_next_video", ("LlavaNextVideoImageProcessor", None)),
130
+ ("llava_onevision", ("LlavaOnevisionImageProcessor", "LlavaOnevisionImageProcessorFast")),
131
+ ("mask2former", ("Mask2FormerImageProcessor", "Mask2FormerImageProcessorFast")),
132
+ ("maskformer", ("MaskFormerImageProcessor", "MaskFormerImageProcessorFast")),
133
+ ("metaclip_2", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
134
+ ("mgp-str", ("ViTImageProcessor", "ViTImageProcessorFast")),
135
+ ("mistral3", ("PixtralImageProcessor", "PixtralImageProcessorFast")),
136
+ ("mlcd", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
137
+ ("mllama", ("MllamaImageProcessor", None)),
138
+ ("mm-grounding-dino", ("GroundingDinoImageProcessor", "GroundingDinoImageProcessorFast")),
139
+ ("mobilenet_v1", ("MobileNetV1ImageProcessor", "MobileNetV1ImageProcessorFast")),
140
+ ("mobilenet_v2", ("MobileNetV2ImageProcessor", "MobileNetV2ImageProcessorFast")),
141
+ ("mobilevit", ("MobileViTImageProcessor", "MobileViTImageProcessorFast")),
142
+ ("mobilevitv2", ("MobileViTImageProcessor", "MobileViTImageProcessorFast")),
143
+ ("nat", ("ViTImageProcessor", "ViTImageProcessorFast")),
144
+ ("nougat", ("NougatImageProcessor", "NougatImageProcessorFast")),
145
+ ("oneformer", ("OneFormerImageProcessor", "OneFormerImageProcessorFast")),
146
+ ("ovis2", ("Ovis2ImageProcessor", "Ovis2ImageProcessorFast")),
147
+ ("owlv2", ("Owlv2ImageProcessor", "Owlv2ImageProcessorFast")),
148
+ ("owlvit", ("OwlViTImageProcessor", "OwlViTImageProcessorFast")),
149
+ ("paligemma", ("SiglipImageProcessor", "SiglipImageProcessorFast")),
150
+ ("perceiver", ("PerceiverImageProcessor", "PerceiverImageProcessorFast")),
151
+ ("perception_lm", (None, "PerceptionLMImageProcessorFast")),
152
+ ("phi4_multimodal", (None, "Phi4MultimodalImageProcessorFast")),
153
+ ("pix2struct", ("Pix2StructImageProcessor", None)),
154
+ ("pixtral", ("PixtralImageProcessor", "PixtralImageProcessorFast")),
155
+ ("poolformer", ("PoolFormerImageProcessor", "PoolFormerImageProcessorFast")),
156
+ ("prompt_depth_anything", ("PromptDepthAnythingImageProcessor", "PromptDepthAnythingImageProcessorFast")),
157
+ ("pvt", ("PvtImageProcessor", "PvtImageProcessorFast")),
158
+ ("pvt_v2", ("PvtImageProcessor", "PvtImageProcessorFast")),
159
+ ("qwen2_5_vl", ("Qwen2VLImageProcessor", "Qwen2VLImageProcessorFast")),
160
+ ("qwen2_vl", ("Qwen2VLImageProcessor", "Qwen2VLImageProcessorFast")),
161
+ ("qwen3_vl", ("Qwen2VLImageProcessor", "Qwen2VLImageProcessorFast")),
162
+ ("regnet", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")),
163
+ ("resnet", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")),
164
+ ("rt_detr", ("RTDetrImageProcessor", "RTDetrImageProcessorFast")),
165
+ ("sam", ("SamImageProcessor", "SamImageProcessorFast")),
166
+ ("sam2", (None, "Sam2ImageProcessorFast")),
167
+ ("sam_hq", ("SamImageProcessor", "SamImageProcessorFast")),
168
+ ("segformer", ("SegformerImageProcessor", "SegformerImageProcessorFast")),
169
+ ("seggpt", ("SegGptImageProcessor", None)),
170
+ ("shieldgemma2", ("Gemma3ImageProcessor", "Gemma3ImageProcessorFast")),
171
+ ("siglip", ("SiglipImageProcessor", "SiglipImageProcessorFast")),
172
+ ("siglip2", ("Siglip2ImageProcessor", "Siglip2ImageProcessorFast")),
173
+ ("smolvlm", ("SmolVLMImageProcessor", "SmolVLMImageProcessorFast")),
174
+ ("superglue", ("SuperGlueImageProcessor", None)),
175
+ ("superpoint", ("SuperPointImageProcessor", "SuperPointImageProcessorFast")),
176
+ ("swiftformer", ("ViTImageProcessor", "ViTImageProcessorFast")),
177
+ ("swin", ("ViTImageProcessor", "ViTImageProcessorFast")),
178
+ ("swin2sr", ("Swin2SRImageProcessor", "Swin2SRImageProcessorFast")),
179
+ ("swinv2", ("ViTImageProcessor", "ViTImageProcessorFast")),
180
+ ("table-transformer", ("DetrImageProcessor", "DetrImageProcessorFast")),
181
+ ("textnet", ("TextNetImageProcessor", "TextNetImageProcessorFast")),
182
+ ("timesformer", ("VideoMAEImageProcessor", None)),
183
+ ("timm_wrapper", ("TimmWrapperImageProcessor", None)),
184
+ ("tvlt", ("TvltImageProcessor", None)),
185
+ ("tvp", ("TvpImageProcessor", "TvpImageProcessorFast")),
186
+ ("udop", ("LayoutLMv3ImageProcessor", "LayoutLMv3ImageProcessorFast")),
187
+ ("upernet", ("SegformerImageProcessor", "SegformerImageProcessorFast")),
188
+ ("van", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")),
189
+ ("videomae", ("VideoMAEImageProcessor", None)),
190
+ ("vilt", ("ViltImageProcessor", "ViltImageProcessorFast")),
191
+ ("vipllava", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
192
+ ("vit", ("ViTImageProcessor", "ViTImageProcessorFast")),
193
+ ("vit_hybrid", ("ViTHybridImageProcessor", None)),
194
+ ("vit_mae", ("ViTImageProcessor", "ViTImageProcessorFast")),
195
+ ("vit_msn", ("ViTImageProcessor", "ViTImageProcessorFast")),
196
+ ("vitmatte", ("VitMatteImageProcessor", "VitMatteImageProcessorFast")),
197
+ ("xclip", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
198
+ ("yolos", ("YolosImageProcessor", "YolosImageProcessorFast")),
199
+ ("zoedepth", ("ZoeDepthImageProcessor", "ZoeDepthImageProcessorFast")),
200
+ ]
201
+ )
202
+
203
+ # Override to None if the packages are not available
204
+ for model_type, (slow_class, fast_class) in IMAGE_PROCESSOR_MAPPING_NAMES.items():
205
+ if not is_vision_available():
206
+ slow_class = None
207
+ if not is_torchvision_available():
208
+ fast_class = None
209
+
210
+ IMAGE_PROCESSOR_MAPPING_NAMES[model_type] = (slow_class, fast_class)
211
+
212
+ IMAGE_PROCESSOR_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, IMAGE_PROCESSOR_MAPPING_NAMES)
213
+
214
+
215
+ def get_image_processor_class_from_name(class_name: str):
216
+ if class_name == "BaseImageProcessorFast":
217
+ return BaseImageProcessorFast
218
+
219
+ for module_name, extractors in IMAGE_PROCESSOR_MAPPING_NAMES.items():
220
+ if class_name in extractors:
221
+ module_name = model_type_to_module_name(module_name)
222
+
223
+ module = importlib.import_module(f".{module_name}", "transformers.models")
224
+ try:
225
+ return getattr(module, class_name)
226
+ except AttributeError:
227
+ continue
228
+
229
+ for extractors in IMAGE_PROCESSOR_MAPPING._extra_content.values():
230
+ for extractor in extractors:
231
+ if getattr(extractor, "__name__", None) == class_name:
232
+ return extractor
233
+
234
+ # We did not find the class, but maybe it's because a dep is missing. In that case, the class will be in the main
235
+ # init and we return the proper dummy to get an appropriate error message.
236
+ main_module = importlib.import_module("transformers")
237
+ if hasattr(main_module, class_name):
238
+ return getattr(main_module, class_name)
239
+
240
+ return None
241
+
242
+
243
+ def get_image_processor_config(
244
+ pretrained_model_name_or_path: Union[str, os.PathLike],
245
+ cache_dir: Optional[Union[str, os.PathLike]] = None,
246
+ force_download: bool = False,
247
+ resume_download: Optional[bool] = None,
248
+ proxies: Optional[dict[str, str]] = None,
249
+ token: Optional[Union[bool, str]] = None,
250
+ revision: Optional[str] = None,
251
+ local_files_only: bool = False,
252
+ **kwargs,
253
+ ):
254
+ """
255
+ Loads the image processor configuration from a pretrained model image processor configuration.
256
+
257
+ Args:
258
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
259
+ This can be either:
260
+
261
+ - a string, the *model id* of a pretrained model configuration hosted inside a model repo on
262
+ huggingface.co.
263
+ - a path to a *directory* containing a configuration file saved using the
264
+ [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.
265
+
266
+ cache_dir (`str` or `os.PathLike`, *optional*):
267
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
268
+ cache should not be used.
269
+ force_download (`bool`, *optional*, defaults to `False`):
270
+ Whether or not to force a (re-)download of the configuration files and override the cached versions if they
271
+ exist.
272
+ resume_download:
273
+ Deprecated and ignored. All downloads are now resumed by default when possible.
274
+ Will be removed in v5 of Transformers.
275
+ proxies (`dict[str, str]`, *optional*):
276
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
277
+ 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
278
+ token (`str` or *bool*, *optional*):
279
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
280
+ when running `hf auth login` (stored in `~/.huggingface`).
281
+ revision (`str`, *optional*, defaults to `"main"`):
282
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
283
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
284
+ identifier allowed by git.
285
+ local_files_only (`bool`, *optional*, defaults to `False`):
286
+ If `True`, will only try to load the image processor configuration from local files.
287
+
288
+ <Tip>
289
+
290
+ Passing `token=True` is required when you want to use a private model.
291
+
292
+ </Tip>
293
+
294
+ Returns:
295
+ `Dict`: The configuration of the image processor.
296
+
297
+ Examples:
298
+
299
+ ```python
300
+ # Download configuration from huggingface.co and cache.
301
+ image_processor_config = get_image_processor_config("google-bert/bert-base-uncased")
302
+ # This model does not have an image processor config so the result will be an empty dict.
303
+ image_processor_config = get_image_processor_config("FacebookAI/xlm-roberta-base")
304
+
305
+ # Save a pretrained image processor locally and you can reload its config
306
+ from transformers import AutoImageProcessor
307
+
308
+ image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
309
+ image_processor.save_pretrained("image-processor-test")
310
+ image_processor_config = get_image_processor_config("image-processor-test")
311
+ ```"""
312
+ use_auth_token = kwargs.pop("use_auth_token", None)
313
+ if use_auth_token is not None:
314
+ warnings.warn(
315
+ "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
316
+ FutureWarning,
317
+ )
318
+ if token is not None:
319
+ raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
320
+ token = use_auth_token
321
+
322
+ resolved_config_file = cached_file(
323
+ pretrained_model_name_or_path,
324
+ IMAGE_PROCESSOR_NAME,
325
+ cache_dir=cache_dir,
326
+ force_download=force_download,
327
+ resume_download=resume_download,
328
+ proxies=proxies,
329
+ token=token,
330
+ revision=revision,
331
+ local_files_only=local_files_only,
332
+ _raise_exceptions_for_gated_repo=False,
333
+ _raise_exceptions_for_missing_entries=False,
334
+ _raise_exceptions_for_connection_errors=False,
335
+ )
336
+ if resolved_config_file is None:
337
+ logger.info(
338
+ "Could not locate the image processor configuration file, will try to use the model config instead."
339
+ )
340
+ return {}
341
+
342
+ with open(resolved_config_file, encoding="utf-8") as reader:
343
+ return json.load(reader)
344
+
345
+
346
+ def _warning_fast_image_processor_available(fast_class):
347
+ logger.warning(
348
+ f"Fast image processor class {fast_class} is available for this model. "
349
+ "Using slow image processor class. To use the fast image processor class set `use_fast=True`."
350
+ )
351
+
352
+
353
+ @requires(backends=("vision",))
354
+ class AutoImageProcessor:
355
+ r"""
356
+ This is a generic image processor class that will be instantiated as one of the image processor classes of the
357
+ library when created with the [`AutoImageProcessor.from_pretrained`] class method.
358
+
359
+ This class cannot be instantiated directly using `__init__()` (throws an error).
360
+ """
361
+
362
+ def __init__(self):
363
+ raise OSError(
364
+ "AutoImageProcessor is designed to be instantiated "
365
+ "using the `AutoImageProcessor.from_pretrained(pretrained_model_name_or_path)` method."
366
+ )
367
+
368
+ @classmethod
369
+ @replace_list_option_in_docstrings(IMAGE_PROCESSOR_MAPPING_NAMES)
370
+ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
371
+ r"""
372
+ Instantiate one of the image processor classes of the library from a pretrained model vocabulary.
373
+
374
+ The image processor class to instantiate is selected based on the `model_type` property of the config object
375
+ (either passed as an argument or loaded from `pretrained_model_name_or_path` if possible), or when it's
376
+ missing, by falling back to using pattern matching on `pretrained_model_name_or_path`:
377
+
378
+ List options
379
+
380
+ Params:
381
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
382
+ This can be either:
383
+
384
+ - a string, the *model id* of a pretrained image_processor hosted inside a model repo on
385
+ huggingface.co.
386
+ - a path to a *directory* containing an image processor file saved using the
387
+ [`~image_processing_utils.ImageProcessingMixin.save_pretrained`] method, e.g.,
388
+ `./my_model_directory/`.
389
+ - a path or url to a saved image processor JSON *file*, e.g.,
390
+ `./my_model_directory/preprocessor_config.json`.
391
+ cache_dir (`str` or `os.PathLike`, *optional*):
392
+ Path to a directory in which a downloaded pretrained model image processor should be cached if the
393
+ standard cache should not be used.
394
+ force_download (`bool`, *optional*, defaults to `False`):
395
+ Whether or not to force to (re-)download the image processor files and override the cached versions if
396
+ they exist.
397
+ resume_download:
398
+ Deprecated and ignored. All downloads are now resumed by default when possible.
399
+ Will be removed in v5 of Transformers.
400
+ proxies (`dict[str, str]`, *optional*):
401
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
402
+ 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
403
+ token (`str` or *bool*, *optional*):
404
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
405
+ when running `hf auth login` (stored in `~/.huggingface`).
406
+ revision (`str`, *optional*, defaults to `"main"`):
407
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
408
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
409
+ identifier allowed by git.
410
+ use_fast (`bool`, *optional*, defaults to `False`):
411
+ Use a fast torchvision-based image processor if it is supported for a given model.
412
+ If a fast image processor is not available for a given model, a normal numpy-based image processor
413
+ is returned instead.
414
+ return_unused_kwargs (`bool`, *optional*, defaults to `False`):
415
+ If `False`, then this function returns just the final image processor object. If `True`, then this
416
+ function returns a `Tuple(image_processor, unused_kwargs)` where *unused_kwargs* is a dictionary
417
+ consisting of the key/value pairs whose keys are not image processor attributes: i.e., the part of
418
+ `kwargs` which has not been used to update `image_processor` and is otherwise ignored.
419
+ trust_remote_code (`bool`, *optional*, defaults to `False`):
420
+ Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
421
+ should only be set to `True` for repositories you trust and in which you have read the code, as it will
422
+ execute code present on the Hub on your local machine.
423
+ image_processor_filename (`str`, *optional*, defaults to `"preprocessor_config.json"`):
424
+ The name of the file in the model directory to use for the image processor config.
425
+ kwargs (`dict[str, Any]`, *optional*):
426
+ The values in kwargs of any keys which are image processor attributes will be used to override the
427
+ loaded values. Behavior concerning key/value pairs whose keys are *not* image processor attributes is
428
+ controlled by the `return_unused_kwargs` keyword parameter.
429
+
430
+ <Tip>
431
+
432
+ Passing `token=True` is required when you want to use a private model.
433
+
434
+ </Tip>
435
+
436
+ Examples:
437
+
438
+ ```python
439
+ >>> from transformers import AutoImageProcessor
440
+
441
+ >>> # Download image processor from huggingface.co and cache.
442
+ >>> image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
443
+
444
+ >>> # If image processor files are in a directory (e.g. image processor was saved using *save_pretrained('./test/saved_model/')*)
445
+ >>> # image_processor = AutoImageProcessor.from_pretrained("./test/saved_model/")
446
+ ```"""
447
+ use_auth_token = kwargs.pop("use_auth_token", None)
448
+ if use_auth_token is not None:
449
+ warnings.warn(
450
+ "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
451
+ FutureWarning,
452
+ )
453
+ if kwargs.get("token") is not None:
454
+ raise ValueError(
455
+ "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
456
+ )
457
+ kwargs["token"] = use_auth_token
458
+
459
+ config = kwargs.pop("config", None)
460
+ # TODO: @yoni, change in v4.48 (use_fast set to True by default)
461
+ use_fast = kwargs.pop("use_fast", None)
462
+ trust_remote_code = kwargs.pop("trust_remote_code", None)
463
+ kwargs["_from_auto"] = True
464
+
465
+ # Resolve the image processor config filename
466
+ if "image_processor_filename" in kwargs:
467
+ image_processor_filename = kwargs.pop("image_processor_filename")
468
+ elif is_timm_local_checkpoint(pretrained_model_name_or_path):
469
+ image_processor_filename = CONFIG_NAME
470
+ else:
471
+ image_processor_filename = IMAGE_PROCESSOR_NAME
472
+
473
+ # Load the image processor config
474
+ try:
475
+ # Main path for all transformers models and local TimmWrapper checkpoints
476
+ config_dict, _ = ImageProcessingMixin.get_image_processor_dict(
477
+ pretrained_model_name_or_path, image_processor_filename=image_processor_filename, **kwargs
478
+ )
479
+ except Exception as initial_exception:
480
+ # Fallback path for Hub TimmWrapper checkpoints. Timm models' image processing is saved in `config.json`
481
+ # instead of `preprocessor_config.json`. Because this is an Auto class and we don't have any information
482
+ # except the model name, the only way to check if a remote checkpoint is a timm model is to try to
483
+ # load `config.json` and if it fails with some error, we raise the initial exception.
484
+ try:
485
+ config_dict, _ = ImageProcessingMixin.get_image_processor_dict(
486
+ pretrained_model_name_or_path, image_processor_filename=CONFIG_NAME, **kwargs
487
+ )
488
+ except Exception:
489
+ raise initial_exception
490
+
491
+ # In case we have a config_dict, but it's not a timm config dict, we raise the initial exception,
492
+ # because only timm models have image processing in `config.json`.
493
+ if not is_timm_config_dict(config_dict):
494
+ raise initial_exception
495
+
496
+ image_processor_type = config_dict.get("image_processor_type", None)
497
+ image_processor_auto_map = None
498
+ if "AutoImageProcessor" in config_dict.get("auto_map", {}):
499
+ image_processor_auto_map = config_dict["auto_map"]["AutoImageProcessor"]
500
+
501
+ # If we still don't have the image processor class, check if we're loading from a previous feature extractor config
502
+ # and if so, infer the image processor class from there.
503
+ if image_processor_type is None and image_processor_auto_map is None:
504
+ feature_extractor_class = config_dict.pop("feature_extractor_type", None)
505
+ if feature_extractor_class is not None:
506
+ image_processor_type = feature_extractor_class.replace("FeatureExtractor", "ImageProcessor")
507
+ if "AutoFeatureExtractor" in config_dict.get("auto_map", {}):
508
+ feature_extractor_auto_map = config_dict["auto_map"]["AutoFeatureExtractor"]
509
+ image_processor_auto_map = feature_extractor_auto_map.replace("FeatureExtractor", "ImageProcessor")
510
+
511
+ # If we don't find the image processor class in the image processor config, let's try the model config.
512
+ if image_processor_type is None and image_processor_auto_map is None:
513
+ if not isinstance(config, PretrainedConfig):
514
+ config = AutoConfig.from_pretrained(
515
+ pretrained_model_name_or_path,
516
+ trust_remote_code=trust_remote_code,
517
+ **kwargs,
518
+ )
519
+ # It could be in `config.image_processor_type`
520
+ image_processor_type = getattr(config, "image_processor_type", None)
521
+ if hasattr(config, "auto_map") and "AutoImageProcessor" in config.auto_map:
522
+ image_processor_auto_map = config.auto_map["AutoImageProcessor"]
523
+
524
+ image_processor_class = None
525
+ # TODO: @yoni, change logic in v4.52 (when use_fast set to True by default)
526
+ if image_processor_type is not None:
527
+ # if use_fast is not set and the processor was saved with a fast processor, we use it, otherwise we use the slow processor.
528
+ if use_fast is None:
529
+ use_fast = image_processor_type.endswith("Fast")
530
+ if not use_fast and image_processor_type in FORCE_FAST_IMAGE_PROCESSOR and is_torchvision_available():
531
+ use_fast = True
532
+ logger.warning_once(
533
+ f"The image processor of type `{image_processor_type}` is now loaded as a fast processor by default, even if the model checkpoint was saved with a slow processor. "
534
+ "This is a breaking change and may produce slightly different outputs. To continue using the slow processor, instantiate this class with `use_fast=False`. "
535
+ "Note that this behavior will be extended to all models in a future release."
536
+ )
537
+ if not use_fast:
538
+ logger.warning_once(
539
+ "Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. "
540
+ "`use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. "
541
+ "This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`."
542
+ )
543
+ if use_fast and not image_processor_type.endswith("Fast"):
544
+ image_processor_type += "Fast"
545
+ if use_fast and not is_torchvision_available():
546
+ # check if there is a slow image processor class to fallback to
547
+ image_processor_class = get_image_processor_class_from_name(image_processor_type[:-4])
548
+ if image_processor_class is None:
549
+ raise ValueError(
550
+ f"`{image_processor_type}` requires `torchvision` to be installed. Please install `torchvision` and try again."
551
+ )
552
+ logger.warning_once(
553
+ "Using `use_fast=True` but `torchvision` is not available. Falling back to the slow image processor."
554
+ )
555
+ use_fast = False
556
+ if use_fast:
557
+ for image_processors in IMAGE_PROCESSOR_MAPPING_NAMES.values():
558
+ if image_processor_type in image_processors:
559
+ break
560
+ else:
561
+ image_processor_type = image_processor_type[:-4]
562
+ use_fast = False
563
+ logger.warning_once(
564
+ "`use_fast` is set to `True` but the image processor class does not have a fast version. "
565
+ " Falling back to the slow version."
566
+ )
567
+ image_processor_class = get_image_processor_class_from_name(image_processor_type)
568
+ else:
569
+ image_processor_type_slow = image_processor_type.removesuffix("Fast")
570
+ image_processor_class = get_image_processor_class_from_name(image_processor_type_slow)
571
+ if image_processor_class is None and image_processor_type.endswith("Fast"):
572
+ raise ValueError(
573
+ f"`{image_processor_type}` does not have a slow version. Please set `use_fast=True` when instantiating the processor."
574
+ )
575
+
576
+ has_remote_code = image_processor_auto_map is not None
577
+ has_local_code = image_processor_class is not None or type(config) in IMAGE_PROCESSOR_MAPPING
578
+ if has_remote_code:
579
+ if image_processor_auto_map is not None and not isinstance(image_processor_auto_map, tuple):
580
+ # In some configs, only the slow image processor class is stored
581
+ image_processor_auto_map = (image_processor_auto_map, None)
582
+ if use_fast and image_processor_auto_map[1] is not None:
583
+ class_ref = image_processor_auto_map[1]
584
+ else:
585
+ class_ref = image_processor_auto_map[0]
586
+ if "--" in class_ref:
587
+ upstream_repo = class_ref.split("--")[0]
588
+ else:
589
+ upstream_repo = None
590
+ trust_remote_code = resolve_trust_remote_code(
591
+ trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code, upstream_repo
592
+ )
593
+
594
+ if has_remote_code and trust_remote_code:
595
+ if not use_fast and image_processor_auto_map[1] is not None:
596
+ _warning_fast_image_processor_available(image_processor_auto_map[1])
597
+
598
+ image_processor_class = get_class_from_dynamic_module(class_ref, pretrained_model_name_or_path, **kwargs)
599
+ _ = kwargs.pop("code_revision", None)
600
+ image_processor_class.register_for_auto_class()
601
+ return image_processor_class.from_dict(config_dict, **kwargs)
602
+ elif image_processor_class is not None:
603
+ return image_processor_class.from_dict(config_dict, **kwargs)
604
+ # Last try: we use the IMAGE_PROCESSOR_MAPPING.
605
+ elif type(config) in IMAGE_PROCESSOR_MAPPING:
606
+ image_processor_tuple = IMAGE_PROCESSOR_MAPPING[type(config)]
607
+
608
+ image_processor_class_py, image_processor_class_fast = image_processor_tuple
609
+
610
+ if not use_fast and image_processor_class_fast is not None:
611
+ _warning_fast_image_processor_available(image_processor_class_fast)
612
+
613
+ if image_processor_class_fast and (use_fast or image_processor_class_py is None):
614
+ return image_processor_class_fast.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
615
+ else:
616
+ if image_processor_class_py is not None:
617
+ return image_processor_class_py.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
618
+ else:
619
+ raise ValueError(
620
+ "This image processor cannot be instantiated. Please make sure you have `Pillow` installed."
621
+ )
622
+ raise ValueError(
623
+ f"Unrecognized image processor in {pretrained_model_name_or_path}. Should have a "
624
+ f"`image_processor_type` key in its {IMAGE_PROCESSOR_NAME} of {CONFIG_NAME}, or one of the following "
625
+ f"`model_type` keys in its {CONFIG_NAME}: {', '.join(c for c in IMAGE_PROCESSOR_MAPPING_NAMES)}"
626
+ )
627
+
628
+ @staticmethod
629
+ def register(
630
+ config_class,
631
+ image_processor_class=None,
632
+ slow_image_processor_class=None,
633
+ fast_image_processor_class=None,
634
+ exist_ok=False,
635
+ ):
636
+ """
637
+ Register a new image processor for this class.
638
+
639
+ Args:
640
+ config_class ([`PretrainedConfig`]):
641
+ The configuration corresponding to the model to register.
642
+ image_processor_class ([`ImageProcessingMixin`]): The image processor to register (deprecated; use
+ `slow_image_processor_class` or `fast_image_processor_class` instead).
+ slow_image_processor_class ([`BaseImageProcessor`], *optional*): The slow image processor to register.
+ fast_image_processor_class ([`BaseImageProcessorFast`], *optional*): The fast image processor to register.
643
+ """
644
+ if image_processor_class is not None:
645
+ if slow_image_processor_class is not None:
646
+ raise ValueError("Cannot specify both image_processor_class and slow_image_processor_class")
647
+ warnings.warn(
648
+ "The image_processor_class argument is deprecated and will be removed in v4.42. Please use `slow_image_processor_class`, or `fast_image_processor_class` instead",
649
+ FutureWarning,
650
+ )
651
+ slow_image_processor_class = image_processor_class
652
+
653
+ if slow_image_processor_class is None and fast_image_processor_class is None:
654
+ raise ValueError("You need to specify either slow_image_processor_class or fast_image_processor_class")
655
+ if slow_image_processor_class is not None and issubclass(slow_image_processor_class, BaseImageProcessorFast):
656
+ raise ValueError("You passed a fast image processor in as the `slow_image_processor_class`.")
657
+ if fast_image_processor_class is not None and not issubclass(
658
+ fast_image_processor_class, BaseImageProcessorFast
659
+ ):
660
+ raise ValueError("The `fast_image_processor_class` should inherit from `BaseImageProcessorFast`.")
661
+
662
+ if (
663
+ slow_image_processor_class is not None
664
+ and fast_image_processor_class is not None
665
+ and issubclass(fast_image_processor_class, BaseImageProcessorFast)
666
+ and fast_image_processor_class.slow_image_processor_class != slow_image_processor_class
667
+ ):
668
+ raise ValueError(
669
+ "The fast processor class you are passing has a `slow_image_processor_class` attribute that is not "
670
+ "consistent with the slow processor class you passed (fast tokenizer has "
671
+ f"{fast_image_processor_class.slow_image_processor_class} and you passed {slow_image_processor_class}. Fix one of those "
672
+ "so they match!"
673
+ )
674
+
675
+ # Avoid resetting a set slow/fast image processor if we are passing just the other ones.
676
+ if config_class in IMAGE_PROCESSOR_MAPPING._extra_content:
677
+ existing_slow, existing_fast = IMAGE_PROCESSOR_MAPPING[config_class]
678
+ if slow_image_processor_class is None:
679
+ slow_image_processor_class = existing_slow
680
+ if fast_image_processor_class is None:
681
+ fast_image_processor_class = existing_fast
682
+
683
+ IMAGE_PROCESSOR_MAPPING.register(
684
+ config_class, (slow_image_processor_class, fast_image_processor_class), exist_ok=exist_ok
685
+ )
686
+
687
+
688
+ __all__ = ["IMAGE_PROCESSOR_MAPPING", "AutoImageProcessor"]
venv/lib/python3.13/site-packages/transformers/models/auto/modeling_auto.py ADDED
The diff for this file is too large to render. See raw diff
 
venv/lib/python3.13/site-packages/transformers/models/auto/modeling_flax_auto.py ADDED
@@ -0,0 +1,413 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google Flax Team Authors and The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Auto Model class."""
16
+
17
+ from collections import OrderedDict
18
+
19
+ from ...utils import logging
20
+ from .auto_factory import _BaseAutoModelClass, _LazyAutoMapping, auto_class_update
21
+ from .configuration_auto import CONFIG_MAPPING_NAMES
22
+
23
+
24
+ logger = logging.get_logger(__name__)
25
+
26
+
27
+ FLAX_MODEL_MAPPING_NAMES = OrderedDict(
28
+ [
29
+ # Base model mapping
30
+ ("albert", "FlaxAlbertModel"),
31
+ ("bart", "FlaxBartModel"),
32
+ ("beit", "FlaxBeitModel"),
33
+ ("bert", "FlaxBertModel"),
34
+ ("big_bird", "FlaxBigBirdModel"),
35
+ ("blenderbot", "FlaxBlenderbotModel"),
36
+ ("blenderbot-small", "FlaxBlenderbotSmallModel"),
37
+ ("bloom", "FlaxBloomModel"),
38
+ ("clip", "FlaxCLIPModel"),
39
+ ("dinov2", "FlaxDinov2Model"),
40
+ ("distilbert", "FlaxDistilBertModel"),
41
+ ("electra", "FlaxElectraModel"),
42
+ ("gemma", "FlaxGemmaModel"),
43
+ ("gpt-sw3", "FlaxGPT2Model"),
44
+ ("gpt2", "FlaxGPT2Model"),
45
+ ("gpt_neo", "FlaxGPTNeoModel"),
46
+ ("gptj", "FlaxGPTJModel"),
47
+ ("llama", "FlaxLlamaModel"),
48
+ ("longt5", "FlaxLongT5Model"),
49
+ ("marian", "FlaxMarianModel"),
50
+ ("mbart", "FlaxMBartModel"),
51
+ ("mistral", "FlaxMistralModel"),
52
+ ("mt5", "FlaxMT5Model"),
53
+ ("opt", "FlaxOPTModel"),
54
+ ("pegasus", "FlaxPegasusModel"),
55
+ ("regnet", "FlaxRegNetModel"),
56
+ ("resnet", "FlaxResNetModel"),
57
+ ("roberta", "FlaxRobertaModel"),
58
+ ("roberta-prelayernorm", "FlaxRobertaPreLayerNormModel"),
59
+ ("roformer", "FlaxRoFormerModel"),
60
+ ("t5", "FlaxT5Model"),
61
+ ("vision-text-dual-encoder", "FlaxVisionTextDualEncoderModel"),
62
+ ("vit", "FlaxViTModel"),
63
+ ("wav2vec2", "FlaxWav2Vec2Model"),
64
+ ("whisper", "FlaxWhisperModel"),
65
+ ("xglm", "FlaxXGLMModel"),
66
+ ("xlm-roberta", "FlaxXLMRobertaModel"),
67
+ ]
68
+ )
69
+
70
+ FLAX_MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
71
+ [
72
+ # Model for pre-training mapping
73
+ ("albert", "FlaxAlbertForPreTraining"),
74
+ ("bart", "FlaxBartForConditionalGeneration"),
75
+ ("bert", "FlaxBertForPreTraining"),
76
+ ("big_bird", "FlaxBigBirdForPreTraining"),
77
+ ("electra", "FlaxElectraForPreTraining"),
78
+ ("longt5", "FlaxLongT5ForConditionalGeneration"),
79
+ ("mbart", "FlaxMBartForConditionalGeneration"),
80
+ ("mt5", "FlaxMT5ForConditionalGeneration"),
81
+ ("roberta", "FlaxRobertaForMaskedLM"),
82
+ ("roberta-prelayernorm", "FlaxRobertaPreLayerNormForMaskedLM"),
83
+ ("roformer", "FlaxRoFormerForMaskedLM"),
84
+ ("t5", "FlaxT5ForConditionalGeneration"),
85
+ ("wav2vec2", "FlaxWav2Vec2ForPreTraining"),
86
+ ("whisper", "FlaxWhisperForConditionalGeneration"),
87
+ ("xlm-roberta", "FlaxXLMRobertaForMaskedLM"),
88
+ ]
89
+ )
90
+
91
+ FLAX_MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
92
+ [
93
+ # Model for Masked LM mapping
94
+ ("albert", "FlaxAlbertForMaskedLM"),
95
+ ("bart", "FlaxBartForConditionalGeneration"),
96
+ ("bert", "FlaxBertForMaskedLM"),
97
+ ("big_bird", "FlaxBigBirdForMaskedLM"),
98
+ ("distilbert", "FlaxDistilBertForMaskedLM"),
99
+ ("electra", "FlaxElectraForMaskedLM"),
100
+ ("mbart", "FlaxMBartForConditionalGeneration"),
101
+ ("roberta", "FlaxRobertaForMaskedLM"),
102
+ ("roberta-prelayernorm", "FlaxRobertaPreLayerNormForMaskedLM"),
103
+ ("roformer", "FlaxRoFormerForMaskedLM"),
104
+ ("xlm-roberta", "FlaxXLMRobertaForMaskedLM"),
105
+ ]
106
+ )
107
+
108
+ FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
109
+ [
110
+ # Model for Seq2Seq Causal LM mapping
111
+ ("bart", "FlaxBartForConditionalGeneration"),
112
+ ("blenderbot", "FlaxBlenderbotForConditionalGeneration"),
113
+ ("blenderbot-small", "FlaxBlenderbotSmallForConditionalGeneration"),
114
+ ("encoder-decoder", "FlaxEncoderDecoderModel"),
115
+ ("longt5", "FlaxLongT5ForConditionalGeneration"),
116
+ ("marian", "FlaxMarianMTModel"),
117
+ ("mbart", "FlaxMBartForConditionalGeneration"),
118
+ ("mt5", "FlaxMT5ForConditionalGeneration"),
119
+ ("pegasus", "FlaxPegasusForConditionalGeneration"),
120
+ ("t5", "FlaxT5ForConditionalGeneration"),
121
+ ]
122
+ )
123
+
124
+ FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
125
+ [
126
+ # Model for Image-classification
127
+ ("beit", "FlaxBeitForImageClassification"),
128
+ ("dinov2", "FlaxDinov2ForImageClassification"),
129
+ ("regnet", "FlaxRegNetForImageClassification"),
130
+ ("resnet", "FlaxResNetForImageClassification"),
131
+ ("vit", "FlaxViTForImageClassification"),
132
+ ]
133
+ )
134
+
135
+ FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict(
136
+ [
137
+ ("vision-encoder-decoder", "FlaxVisionEncoderDecoderModel"),
138
+ ]
139
+ )
140
+
141
+ FLAX_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
142
+ [
143
+ # Model for Causal LM mapping
144
+ ("bart", "FlaxBartForCausalLM"),
145
+ ("bert", "FlaxBertForCausalLM"),
146
+ ("big_bird", "FlaxBigBirdForCausalLM"),
147
+ ("bloom", "FlaxBloomForCausalLM"),
148
+ ("electra", "FlaxElectraForCausalLM"),
149
+ ("gemma", "FlaxGemmaForCausalLM"),
150
+ ("gpt-sw3", "FlaxGPT2LMHeadModel"),
151
+ ("gpt2", "FlaxGPT2LMHeadModel"),
152
+ ("gpt_neo", "FlaxGPTNeoForCausalLM"),
153
+ ("gptj", "FlaxGPTJForCausalLM"),
154
+ ("llama", "FlaxLlamaForCausalLM"),
155
+ ("mistral", "FlaxMistralForCausalLM"),
156
+ ("opt", "FlaxOPTForCausalLM"),
157
+ ("roberta", "FlaxRobertaForCausalLM"),
158
+ ("roberta-prelayernorm", "FlaxRobertaPreLayerNormForCausalLM"),
159
+ ("xglm", "FlaxXGLMForCausalLM"),
160
+ ("xlm-roberta", "FlaxXLMRobertaForCausalLM"),
161
+ ]
162
+ )
163
+
164
+ FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
165
+ [
166
+ # Model for Sequence Classification mapping
167
+ ("albert", "FlaxAlbertForSequenceClassification"),
168
+ ("bart", "FlaxBartForSequenceClassification"),
169
+ ("bert", "FlaxBertForSequenceClassification"),
170
+ ("big_bird", "FlaxBigBirdForSequenceClassification"),
171
+ ("distilbert", "FlaxDistilBertForSequenceClassification"),
172
+ ("electra", "FlaxElectraForSequenceClassification"),
173
+ ("mbart", "FlaxMBartForSequenceClassification"),
174
+ ("roberta", "FlaxRobertaForSequenceClassification"),
175
+ ("roberta-prelayernorm", "FlaxRobertaPreLayerNormForSequenceClassification"),
176
+ ("roformer", "FlaxRoFormerForSequenceClassification"),
177
+ ("xlm-roberta", "FlaxXLMRobertaForSequenceClassification"),
178
+ ]
179
+ )
180
+
181
+ FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
182
+ [
183
+ # Model for Question Answering mapping
184
+ ("albert", "FlaxAlbertForQuestionAnswering"),
185
+ ("bart", "FlaxBartForQuestionAnswering"),
186
+ ("bert", "FlaxBertForQuestionAnswering"),
187
+ ("big_bird", "FlaxBigBirdForQuestionAnswering"),
188
+ ("distilbert", "FlaxDistilBertForQuestionAnswering"),
189
+ ("electra", "FlaxElectraForQuestionAnswering"),
190
+ ("mbart", "FlaxMBartForQuestionAnswering"),
191
+ ("roberta", "FlaxRobertaForQuestionAnswering"),
192
+ ("roberta-prelayernorm", "FlaxRobertaPreLayerNormForQuestionAnswering"),
193
+ ("roformer", "FlaxRoFormerForQuestionAnswering"),
194
+ ("xlm-roberta", "FlaxXLMRobertaForQuestionAnswering"),
195
+ ]
196
+ )
197
+
198
+ FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
199
+ [
200
+ # Model for Token Classification mapping
201
+ ("albert", "FlaxAlbertForTokenClassification"),
202
+ ("bert", "FlaxBertForTokenClassification"),
203
+ ("big_bird", "FlaxBigBirdForTokenClassification"),
204
+ ("distilbert", "FlaxDistilBertForTokenClassification"),
205
+ ("electra", "FlaxElectraForTokenClassification"),
206
+ ("roberta", "FlaxRobertaForTokenClassification"),
207
+ ("roberta-prelayernorm", "FlaxRobertaPreLayerNormForTokenClassification"),
208
+ ("roformer", "FlaxRoFormerForTokenClassification"),
209
+ ("xlm-roberta", "FlaxXLMRobertaForTokenClassification"),
210
+ ]
211
+ )
212
+
213
+ FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict(
214
+ [
215
+ # Model for Multiple Choice mapping
216
+ ("albert", "FlaxAlbertForMultipleChoice"),
217
+ ("bert", "FlaxBertForMultipleChoice"),
218
+ ("big_bird", "FlaxBigBirdForMultipleChoice"),
219
+ ("distilbert", "FlaxDistilBertForMultipleChoice"),
220
+ ("electra", "FlaxElectraForMultipleChoice"),
221
+ ("roberta", "FlaxRobertaForMultipleChoice"),
222
+ ("roberta-prelayernorm", "FlaxRobertaPreLayerNormForMultipleChoice"),
223
+ ("roformer", "FlaxRoFormerForMultipleChoice"),
224
+ ("xlm-roberta", "FlaxXLMRobertaForMultipleChoice"),
225
+ ]
226
+ )
227
+
228
+ FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES = OrderedDict(
229
+ [
230
+ ("bert", "FlaxBertForNextSentencePrediction"),
231
+ ]
232
+ )
233
+
234
+ FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict(
235
+ [
236
+ ("speech-encoder-decoder", "FlaxSpeechEncoderDecoderModel"),
237
+ ("whisper", "FlaxWhisperForConditionalGeneration"),
238
+ ]
239
+ )
240
+
241
+ FLAX_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
242
+ [
243
+ ("whisper", "FlaxWhisperForAudioClassification"),
244
+ ]
245
+ )
246
+
247
+ FLAX_MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_MAPPING_NAMES)
248
+ FLAX_MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_PRETRAINING_MAPPING_NAMES)
249
+ FLAX_MODEL_FOR_MASKED_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_MASKED_LM_MAPPING_NAMES)
250
+ FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _LazyAutoMapping(
251
+ CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
252
+ )
253
+ FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
254
+ CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
255
+ )
256
+ FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES)
257
+ FLAX_MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES)
258
+ FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
259
+ CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
260
+ )
261
+ FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
262
+ CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
263
+ )
264
+ FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = _LazyAutoMapping(
265
+ CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES
266
+ )
267
+ FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING = _LazyAutoMapping(
268
+ CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES
269
+ )
270
+ FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = _LazyAutoMapping(
271
+ CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES
272
+ )
273
+ FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING = _LazyAutoMapping(
274
+ CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES
275
+ )
276
+ FLAX_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING = _LazyAutoMapping(
277
+ CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES
278
+ )
279
+
280
+
281
+ class FlaxAutoModel(_BaseAutoModelClass):
282
+ _model_mapping = FLAX_MODEL_MAPPING
283
+
284
+
285
+ FlaxAutoModel = auto_class_update(FlaxAutoModel)
286
+
287
+
288
+ class FlaxAutoModelForPreTraining(_BaseAutoModelClass):
289
+ _model_mapping = FLAX_MODEL_FOR_PRETRAINING_MAPPING
290
+
291
+
292
+ FlaxAutoModelForPreTraining = auto_class_update(FlaxAutoModelForPreTraining, head_doc="pretraining")
293
+
294
+
295
+ class FlaxAutoModelForCausalLM(_BaseAutoModelClass):
296
+ _model_mapping = FLAX_MODEL_FOR_CAUSAL_LM_MAPPING
297
+
298
+
299
+ FlaxAutoModelForCausalLM = auto_class_update(FlaxAutoModelForCausalLM, head_doc="causal language modeling")
300
+
301
+
302
+ class FlaxAutoModelForMaskedLM(_BaseAutoModelClass):
303
+ _model_mapping = FLAX_MODEL_FOR_MASKED_LM_MAPPING
304
+
305
+
306
+ FlaxAutoModelForMaskedLM = auto_class_update(FlaxAutoModelForMaskedLM, head_doc="masked language modeling")
307
+
308
+
309
+ class FlaxAutoModelForSeq2SeqLM(_BaseAutoModelClass):
310
+ _model_mapping = FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
311
+
312
+
313
+ FlaxAutoModelForSeq2SeqLM = auto_class_update(
314
+ FlaxAutoModelForSeq2SeqLM,
315
+ head_doc="sequence-to-sequence language modeling",
316
+ checkpoint_for_example="google-t5/t5-base",
317
+ )
318
+
319
+
320
+ class FlaxAutoModelForSequenceClassification(_BaseAutoModelClass):
321
+ _model_mapping = FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
322
+
323
+
324
+ FlaxAutoModelForSequenceClassification = auto_class_update(
325
+ FlaxAutoModelForSequenceClassification, head_doc="sequence classification"
326
+ )
327
+
328
+
329
+ class FlaxAutoModelForQuestionAnswering(_BaseAutoModelClass):
330
+ _model_mapping = FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING
331
+
332
+
333
+ FlaxAutoModelForQuestionAnswering = auto_class_update(FlaxAutoModelForQuestionAnswering, head_doc="question answering")
334
+
335
+
336
+ class FlaxAutoModelForTokenClassification(_BaseAutoModelClass):
337
+ _model_mapping = FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
338
+
339
+
340
+ FlaxAutoModelForTokenClassification = auto_class_update(
341
+ FlaxAutoModelForTokenClassification, head_doc="token classification"
342
+ )
343
+
344
+
345
+ class FlaxAutoModelForMultipleChoice(_BaseAutoModelClass):
346
+ _model_mapping = FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING
347
+
348
+
349
+ FlaxAutoModelForMultipleChoice = auto_class_update(FlaxAutoModelForMultipleChoice, head_doc="multiple choice")
350
+
351
+
352
+ class FlaxAutoModelForNextSentencePrediction(_BaseAutoModelClass):
353
+ _model_mapping = FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING
354
+
355
+
356
+ FlaxAutoModelForNextSentencePrediction = auto_class_update(
357
+ FlaxAutoModelForNextSentencePrediction, head_doc="next sentence prediction"
358
+ )
359
+
360
+
361
+ class FlaxAutoModelForImageClassification(_BaseAutoModelClass):
362
+ _model_mapping = FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
363
+
364
+
365
+ FlaxAutoModelForImageClassification = auto_class_update(
366
+ FlaxAutoModelForImageClassification, head_doc="image classification"
367
+ )
368
+
369
+
370
+ class FlaxAutoModelForVision2Seq(_BaseAutoModelClass):
371
+ _model_mapping = FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING
372
+
373
+
374
+ FlaxAutoModelForVision2Seq = auto_class_update(FlaxAutoModelForVision2Seq, head_doc="vision-to-text modeling")
375
+
376
+
377
+ class FlaxAutoModelForSpeechSeq2Seq(_BaseAutoModelClass):
378
+ _model_mapping = FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING
379
+
380
+
381
+ FlaxAutoModelForSpeechSeq2Seq = auto_class_update(
382
+ FlaxAutoModelForSpeechSeq2Seq, head_doc="sequence-to-sequence speech-to-text modeling"
383
+ )
384
+
385
+ __all__ = [
386
+ "FLAX_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING",
387
+ "FLAX_MODEL_FOR_CAUSAL_LM_MAPPING",
388
+ "FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING",
389
+ "FLAX_MODEL_FOR_MASKED_LM_MAPPING",
390
+ "FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING",
391
+ "FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING",
392
+ "FLAX_MODEL_FOR_PRETRAINING_MAPPING",
393
+ "FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING",
394
+ "FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING",
395
+ "FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING",
396
+ "FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING",
397
+ "FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
398
+ "FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING",
399
+ "FLAX_MODEL_MAPPING",
400
+ "FlaxAutoModel",
401
+ "FlaxAutoModelForCausalLM",
402
+ "FlaxAutoModelForImageClassification",
403
+ "FlaxAutoModelForMaskedLM",
404
+ "FlaxAutoModelForMultipleChoice",
405
+ "FlaxAutoModelForNextSentencePrediction",
406
+ "FlaxAutoModelForPreTraining",
407
+ "FlaxAutoModelForQuestionAnswering",
408
+ "FlaxAutoModelForSeq2SeqLM",
409
+ "FlaxAutoModelForSequenceClassification",
410
+ "FlaxAutoModelForSpeechSeq2Seq",
411
+ "FlaxAutoModelForTokenClassification",
412
+ "FlaxAutoModelForVision2Seq",
413
+ ]
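Editor's note — a short hedged usage sketch (not part of the diff): each `FlaxAuto*` class above resolves the concrete Flax architecture from the checkpoint config's `model_type` through the corresponding lazy mapping, so loading looks identical across model families.

```python
# Minimal sketch, assuming the checkpoint below ships Flax weights on the Hub.
from transformers import FlaxAutoModel

model = FlaxAutoModel.from_pretrained("google-bert/bert-base-uncased")
print(type(model).__name__)  # expected to resolve to FlaxBertModel via FLAX_MODEL_MAPPING
```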
venv/lib/python3.13/site-packages/transformers/models/auto/modeling_tf_auto.py ADDED
@@ -0,0 +1,776 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Auto Model class."""
16
+
17
+ import warnings
18
+ from collections import OrderedDict
19
+
20
+ from ...utils import logging
21
+ from .auto_factory import _BaseAutoModelClass, _LazyAutoMapping, auto_class_update
22
+ from .configuration_auto import CONFIG_MAPPING_NAMES
23
+
24
+
25
+ logger = logging.get_logger(__name__)
26
+
27
+
28
+ TF_MODEL_MAPPING_NAMES = OrderedDict(
29
+ [
30
+ # Base model mapping
31
+ ("albert", "TFAlbertModel"),
32
+ ("bart", "TFBartModel"),
33
+ ("bert", "TFBertModel"),
34
+ ("blenderbot", "TFBlenderbotModel"),
35
+ ("blenderbot-small", "TFBlenderbotSmallModel"),
36
+ ("blip", "TFBlipModel"),
37
+ ("camembert", "TFCamembertModel"),
38
+ ("clip", "TFCLIPModel"),
39
+ ("convbert", "TFConvBertModel"),
40
+ ("convnext", "TFConvNextModel"),
41
+ ("convnextv2", "TFConvNextV2Model"),
42
+ ("ctrl", "TFCTRLModel"),
43
+ ("cvt", "TFCvtModel"),
44
+ ("data2vec-vision", "TFData2VecVisionModel"),
45
+ ("deberta", "TFDebertaModel"),
46
+ ("deberta-v2", "TFDebertaV2Model"),
47
+ ("deit", "TFDeiTModel"),
48
+ ("distilbert", "TFDistilBertModel"),
49
+ ("dpr", "TFDPRQuestionEncoder"),
50
+ ("efficientformer", "TFEfficientFormerModel"),
51
+ ("electra", "TFElectraModel"),
52
+ ("esm", "TFEsmModel"),
53
+ ("flaubert", "TFFlaubertModel"),
54
+ ("funnel", ("TFFunnelModel", "TFFunnelBaseModel")),
55
+ ("gpt-sw3", "TFGPT2Model"),
56
+ ("gpt2", "TFGPT2Model"),
57
+ ("gptj", "TFGPTJModel"),
58
+ ("groupvit", "TFGroupViTModel"),
59
+ ("hubert", "TFHubertModel"),
60
+ ("idefics", "TFIdeficsModel"),
61
+ ("layoutlm", "TFLayoutLMModel"),
62
+ ("layoutlmv3", "TFLayoutLMv3Model"),
63
+ ("led", "TFLEDModel"),
64
+ ("longformer", "TFLongformerModel"),
65
+ ("lxmert", "TFLxmertModel"),
66
+ ("marian", "TFMarianModel"),
67
+ ("mbart", "TFMBartModel"),
68
+ ("mistral", "TFMistralModel"),
69
+ ("mobilebert", "TFMobileBertModel"),
70
+ ("mobilevit", "TFMobileViTModel"),
71
+ ("mpnet", "TFMPNetModel"),
72
+ ("mt5", "TFMT5Model"),
73
+ ("openai-gpt", "TFOpenAIGPTModel"),
74
+ ("opt", "TFOPTModel"),
75
+ ("pegasus", "TFPegasusModel"),
76
+ ("regnet", "TFRegNetModel"),
77
+ ("rembert", "TFRemBertModel"),
78
+ ("resnet", "TFResNetModel"),
79
+ ("roberta", "TFRobertaModel"),
80
+ ("roberta-prelayernorm", "TFRobertaPreLayerNormModel"),
81
+ ("roformer", "TFRoFormerModel"),
82
+ ("sam", "TFSamModel"),
83
+ ("sam_vision_model", "TFSamVisionModel"),
84
+ ("segformer", "TFSegformerModel"),
85
+ ("speech_to_text", "TFSpeech2TextModel"),
86
+ ("swiftformer", "TFSwiftFormerModel"),
87
+ ("swin", "TFSwinModel"),
88
+ ("t5", "TFT5Model"),
89
+ ("tapas", "TFTapasModel"),
90
+ ("transfo-xl", "TFTransfoXLModel"),
91
+ ("vision-text-dual-encoder", "TFVisionTextDualEncoderModel"),
92
+ ("vit", "TFViTModel"),
93
+ ("vit_mae", "TFViTMAEModel"),
94
+ ("wav2vec2", "TFWav2Vec2Model"),
95
+ ("whisper", "TFWhisperModel"),
96
+ ("xglm", "TFXGLMModel"),
97
+ ("xlm", "TFXLMModel"),
98
+ ("xlm-roberta", "TFXLMRobertaModel"),
99
+ ("xlnet", "TFXLNetModel"),
100
+ ]
101
+ )
102
+
103
+ TF_MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
104
+ [
105
+ # Model for pre-training mapping
106
+ ("albert", "TFAlbertForPreTraining"),
107
+ ("bart", "TFBartForConditionalGeneration"),
108
+ ("bert", "TFBertForPreTraining"),
109
+ ("camembert", "TFCamembertForMaskedLM"),
110
+ ("ctrl", "TFCTRLLMHeadModel"),
111
+ ("distilbert", "TFDistilBertForMaskedLM"),
112
+ ("electra", "TFElectraForPreTraining"),
113
+ ("flaubert", "TFFlaubertWithLMHeadModel"),
114
+ ("funnel", "TFFunnelForPreTraining"),
115
+ ("gpt-sw3", "TFGPT2LMHeadModel"),
116
+ ("gpt2", "TFGPT2LMHeadModel"),
117
+ ("idefics", "TFIdeficsForVisionText2Text"),
118
+ ("layoutlm", "TFLayoutLMForMaskedLM"),
119
+ ("lxmert", "TFLxmertForPreTraining"),
120
+ ("mobilebert", "TFMobileBertForPreTraining"),
121
+ ("mpnet", "TFMPNetForMaskedLM"),
122
+ ("openai-gpt", "TFOpenAIGPTLMHeadModel"),
123
+ ("roberta", "TFRobertaForMaskedLM"),
124
+ ("roberta-prelayernorm", "TFRobertaPreLayerNormForMaskedLM"),
125
+ ("t5", "TFT5ForConditionalGeneration"),
126
+ ("tapas", "TFTapasForMaskedLM"),
127
+ ("transfo-xl", "TFTransfoXLLMHeadModel"),
128
+ ("vit_mae", "TFViTMAEForPreTraining"),
129
+ ("xlm", "TFXLMWithLMHeadModel"),
130
+ ("xlm-roberta", "TFXLMRobertaForMaskedLM"),
131
+ ("xlnet", "TFXLNetLMHeadModel"),
132
+ ]
133
+ )
134
+
135
+ TF_MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
136
+ [
137
+ # Model with LM heads mapping
138
+ ("albert", "TFAlbertForMaskedLM"),
139
+ ("bart", "TFBartForConditionalGeneration"),
140
+ ("bert", "TFBertForMaskedLM"),
141
+ ("camembert", "TFCamembertForMaskedLM"),
142
+ ("convbert", "TFConvBertForMaskedLM"),
143
+ ("ctrl", "TFCTRLLMHeadModel"),
144
+ ("distilbert", "TFDistilBertForMaskedLM"),
145
+ ("electra", "TFElectraForMaskedLM"),
146
+ ("esm", "TFEsmForMaskedLM"),
147
+ ("flaubert", "TFFlaubertWithLMHeadModel"),
148
+ ("funnel", "TFFunnelForMaskedLM"),
149
+ ("gpt-sw3", "TFGPT2LMHeadModel"),
150
+ ("gpt2", "TFGPT2LMHeadModel"),
151
+ ("gptj", "TFGPTJForCausalLM"),
152
+ ("layoutlm", "TFLayoutLMForMaskedLM"),
153
+ ("led", "TFLEDForConditionalGeneration"),
154
+ ("longformer", "TFLongformerForMaskedLM"),
155
+ ("marian", "TFMarianMTModel"),
156
+ ("mobilebert", "TFMobileBertForMaskedLM"),
157
+ ("mpnet", "TFMPNetForMaskedLM"),
158
+ ("openai-gpt", "TFOpenAIGPTLMHeadModel"),
159
+ ("rembert", "TFRemBertForMaskedLM"),
160
+ ("roberta", "TFRobertaForMaskedLM"),
161
+ ("roberta-prelayernorm", "TFRobertaPreLayerNormForMaskedLM"),
162
+ ("roformer", "TFRoFormerForMaskedLM"),
163
+ ("speech_to_text", "TFSpeech2TextForConditionalGeneration"),
164
+ ("t5", "TFT5ForConditionalGeneration"),
165
+ ("tapas", "TFTapasForMaskedLM"),
166
+ ("transfo-xl", "TFTransfoXLLMHeadModel"),
167
+ ("whisper", "TFWhisperForConditionalGeneration"),
168
+ ("xlm", "TFXLMWithLMHeadModel"),
169
+ ("xlm-roberta", "TFXLMRobertaForMaskedLM"),
170
+ ("xlnet", "TFXLNetLMHeadModel"),
171
+ ]
172
+ )
173
+
174
+ TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
175
+ [
176
+ # Model for Causal LM mapping
177
+ ("bert", "TFBertLMHeadModel"),
178
+ ("camembert", "TFCamembertForCausalLM"),
179
+ ("ctrl", "TFCTRLLMHeadModel"),
180
+ ("gpt-sw3", "TFGPT2LMHeadModel"),
181
+ ("gpt2", "TFGPT2LMHeadModel"),
182
+ ("gptj", "TFGPTJForCausalLM"),
183
+ ("mistral", "TFMistralForCausalLM"),
184
+ ("openai-gpt", "TFOpenAIGPTLMHeadModel"),
185
+ ("opt", "TFOPTForCausalLM"),
186
+ ("rembert", "TFRemBertForCausalLM"),
187
+ ("roberta", "TFRobertaForCausalLM"),
188
+ ("roberta-prelayernorm", "TFRobertaPreLayerNormForCausalLM"),
189
+ ("roformer", "TFRoFormerForCausalLM"),
190
+ ("transfo-xl", "TFTransfoXLLMHeadModel"),
191
+ ("xglm", "TFXGLMForCausalLM"),
192
+ ("xlm", "TFXLMWithLMHeadModel"),
193
+ ("xlm-roberta", "TFXLMRobertaForCausalLM"),
194
+ ("xlnet", "TFXLNetLMHeadModel"),
195
+ ]
196
+ )
197
+
198
+ TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES = OrderedDict(
199
+ [
200
+ ("deit", "TFDeiTForMaskedImageModeling"),
201
+ ("swin", "TFSwinForMaskedImageModeling"),
202
+ ]
203
+ )
204
+
205
+ TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
206
+ [
207
+ # Model for Image-classification
208
+ ("convnext", "TFConvNextForImageClassification"),
209
+ ("convnextv2", "TFConvNextV2ForImageClassification"),
210
+ ("cvt", "TFCvtForImageClassification"),
211
+ ("data2vec-vision", "TFData2VecVisionForImageClassification"),
212
+ ("deit", ("TFDeiTForImageClassification", "TFDeiTForImageClassificationWithTeacher")),
213
+ (
214
+ "efficientformer",
215
+ ("TFEfficientFormerForImageClassification", "TFEfficientFormerForImageClassificationWithTeacher"),
216
+ ),
217
+ ("mobilevit", "TFMobileViTForImageClassification"),
218
+ ("regnet", "TFRegNetForImageClassification"),
219
+ ("resnet", "TFResNetForImageClassification"),
220
+ ("segformer", "TFSegformerForImageClassification"),
221
+ ("swiftformer", "TFSwiftFormerForImageClassification"),
222
+ ("swin", "TFSwinForImageClassification"),
223
+ ("vit", "TFViTForImageClassification"),
224
+ ]
225
+ )
226
+
227
+
228
+ TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
229
+ [
230
+ # Model for Zero Shot Image Classification mapping
231
+ ("blip", "TFBlipModel"),
232
+ ("clip", "TFCLIPModel"),
233
+ ]
234
+ )
235
+
236
+
237
+ TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES = OrderedDict(
238
+ [
239
+ # Model for Semantic Segmentation mapping
240
+ ("data2vec-vision", "TFData2VecVisionForSemanticSegmentation"),
241
+ ("mobilevit", "TFMobileViTForSemanticSegmentation"),
242
+ ("segformer", "TFSegformerForSemanticSegmentation"),
243
+ ]
244
+ )
245
+
246
+ TF_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict(
247
+ [
248
+ ("blip", "TFBlipForConditionalGeneration"),
249
+ ("vision-encoder-decoder", "TFVisionEncoderDecoderModel"),
250
+ ]
251
+ )
252
+
253
+ TF_MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
254
+ [
255
+ # Model for Masked LM mapping
256
+ ("albert", "TFAlbertForMaskedLM"),
257
+ ("bert", "TFBertForMaskedLM"),
258
+ ("camembert", "TFCamembertForMaskedLM"),
259
+ ("convbert", "TFConvBertForMaskedLM"),
260
+ ("deberta", "TFDebertaForMaskedLM"),
261
+ ("deberta-v2", "TFDebertaV2ForMaskedLM"),
262
+ ("distilbert", "TFDistilBertForMaskedLM"),
263
+ ("electra", "TFElectraForMaskedLM"),
264
+ ("esm", "TFEsmForMaskedLM"),
265
+ ("flaubert", "TFFlaubertWithLMHeadModel"),
266
+ ("funnel", "TFFunnelForMaskedLM"),
267
+ ("layoutlm", "TFLayoutLMForMaskedLM"),
268
+ ("longformer", "TFLongformerForMaskedLM"),
269
+ ("mobilebert", "TFMobileBertForMaskedLM"),
270
+ ("mpnet", "TFMPNetForMaskedLM"),
271
+ ("rembert", "TFRemBertForMaskedLM"),
272
+ ("roberta", "TFRobertaForMaskedLM"),
273
+ ("roberta-prelayernorm", "TFRobertaPreLayerNormForMaskedLM"),
274
+ ("roformer", "TFRoFormerForMaskedLM"),
275
+ ("tapas", "TFTapasForMaskedLM"),
276
+ ("xlm", "TFXLMWithLMHeadModel"),
277
+ ("xlm-roberta", "TFXLMRobertaForMaskedLM"),
278
+ ]
279
+ )
280
+
281
+ TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
282
+ [
283
+ # Model for Seq2Seq Causal LM mapping
284
+ ("bart", "TFBartForConditionalGeneration"),
285
+ ("blenderbot", "TFBlenderbotForConditionalGeneration"),
286
+ ("blenderbot-small", "TFBlenderbotSmallForConditionalGeneration"),
287
+ ("encoder-decoder", "TFEncoderDecoderModel"),
288
+ ("led", "TFLEDForConditionalGeneration"),
289
+ ("marian", "TFMarianMTModel"),
290
+ ("mbart", "TFMBartForConditionalGeneration"),
291
+ ("mt5", "TFMT5ForConditionalGeneration"),
292
+ ("pegasus", "TFPegasusForConditionalGeneration"),
293
+ ("t5", "TFT5ForConditionalGeneration"),
294
+ ]
295
+ )
296
+
297
+ TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict(
298
+ [
299
+ ("speech_to_text", "TFSpeech2TextForConditionalGeneration"),
300
+ ("whisper", "TFWhisperForConditionalGeneration"),
301
+ ]
302
+ )
303
+
304
+ TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
305
+ [
306
+ # Model for Sequence Classification mapping
307
+ ("albert", "TFAlbertForSequenceClassification"),
308
+ ("bart", "TFBartForSequenceClassification"),
309
+ ("bert", "TFBertForSequenceClassification"),
310
+ ("camembert", "TFCamembertForSequenceClassification"),
311
+ ("convbert", "TFConvBertForSequenceClassification"),
312
+ ("ctrl", "TFCTRLForSequenceClassification"),
313
+ ("deberta", "TFDebertaForSequenceClassification"),
314
+ ("deberta-v2", "TFDebertaV2ForSequenceClassification"),
315
+ ("distilbert", "TFDistilBertForSequenceClassification"),
316
+ ("electra", "TFElectraForSequenceClassification"),
317
+ ("esm", "TFEsmForSequenceClassification"),
318
+ ("flaubert", "TFFlaubertForSequenceClassification"),
319
+ ("funnel", "TFFunnelForSequenceClassification"),
320
+ ("gpt-sw3", "TFGPT2ForSequenceClassification"),
321
+ ("gpt2", "TFGPT2ForSequenceClassification"),
322
+ ("gptj", "TFGPTJForSequenceClassification"),
323
+ ("layoutlm", "TFLayoutLMForSequenceClassification"),
324
+ ("layoutlmv3", "TFLayoutLMv3ForSequenceClassification"),
325
+ ("longformer", "TFLongformerForSequenceClassification"),
326
+ ("mistral", "TFMistralForSequenceClassification"),
327
+ ("mobilebert", "TFMobileBertForSequenceClassification"),
328
+ ("mpnet", "TFMPNetForSequenceClassification"),
329
+ ("openai-gpt", "TFOpenAIGPTForSequenceClassification"),
330
+ ("rembert", "TFRemBertForSequenceClassification"),
331
+ ("roberta", "TFRobertaForSequenceClassification"),
332
+ ("roberta-prelayernorm", "TFRobertaPreLayerNormForSequenceClassification"),
333
+ ("roformer", "TFRoFormerForSequenceClassification"),
334
+ ("tapas", "TFTapasForSequenceClassification"),
335
+ ("transfo-xl", "TFTransfoXLForSequenceClassification"),
336
+ ("xlm", "TFXLMForSequenceClassification"),
337
+ ("xlm-roberta", "TFXLMRobertaForSequenceClassification"),
338
+ ("xlnet", "TFXLNetForSequenceClassification"),
339
+ ]
340
+ )
341
+
342
+ TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
343
+ [
344
+ # Model for Question Answering mapping
345
+ ("albert", "TFAlbertForQuestionAnswering"),
346
+ ("bert", "TFBertForQuestionAnswering"),
347
+ ("camembert", "TFCamembertForQuestionAnswering"),
348
+ ("convbert", "TFConvBertForQuestionAnswering"),
349
+ ("deberta", "TFDebertaForQuestionAnswering"),
350
+ ("deberta-v2", "TFDebertaV2ForQuestionAnswering"),
351
+ ("distilbert", "TFDistilBertForQuestionAnswering"),
352
+ ("electra", "TFElectraForQuestionAnswering"),
353
+ ("flaubert", "TFFlaubertForQuestionAnsweringSimple"),
354
+ ("funnel", "TFFunnelForQuestionAnswering"),
355
+ ("gptj", "TFGPTJForQuestionAnswering"),
356
+ ("layoutlmv3", "TFLayoutLMv3ForQuestionAnswering"),
357
+ ("longformer", "TFLongformerForQuestionAnswering"),
358
+ ("mobilebert", "TFMobileBertForQuestionAnswering"),
359
+ ("mpnet", "TFMPNetForQuestionAnswering"),
360
+ ("rembert", "TFRemBertForQuestionAnswering"),
361
+ ("roberta", "TFRobertaForQuestionAnswering"),
362
+ ("roberta-prelayernorm", "TFRobertaPreLayerNormForQuestionAnswering"),
363
+ ("roformer", "TFRoFormerForQuestionAnswering"),
364
+ ("xlm", "TFXLMForQuestionAnsweringSimple"),
365
+ ("xlm-roberta", "TFXLMRobertaForQuestionAnswering"),
366
+ ("xlnet", "TFXLNetForQuestionAnsweringSimple"),
367
+ ]
368
+ )
369
+ TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = OrderedDict([("wav2vec2", "TFWav2Vec2ForSequenceClassification")])
370
+
371
+ TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
372
+ [
373
+ ("layoutlm", "TFLayoutLMForQuestionAnswering"),
374
+ ("layoutlmv3", "TFLayoutLMv3ForQuestionAnswering"),
375
+ ]
376
+ )
377
+
378
+
379
+ TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
380
+ [
381
+ # Model for Table Question Answering mapping
382
+ ("tapas", "TFTapasForQuestionAnswering"),
383
+ ]
384
+ )
385
+
386
+ TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
387
+ [
388
+ # Model for Token Classification mapping
389
+ ("albert", "TFAlbertForTokenClassification"),
390
+ ("bert", "TFBertForTokenClassification"),
391
+ ("camembert", "TFCamembertForTokenClassification"),
392
+ ("convbert", "TFConvBertForTokenClassification"),
393
+ ("deberta", "TFDebertaForTokenClassification"),
394
+ ("deberta-v2", "TFDebertaV2ForTokenClassification"),
395
+ ("distilbert", "TFDistilBertForTokenClassification"),
396
+ ("electra", "TFElectraForTokenClassification"),
397
+ ("esm", "TFEsmForTokenClassification"),
398
+ ("flaubert", "TFFlaubertForTokenClassification"),
399
+ ("funnel", "TFFunnelForTokenClassification"),
400
+ ("layoutlm", "TFLayoutLMForTokenClassification"),
401
+ ("layoutlmv3", "TFLayoutLMv3ForTokenClassification"),
402
+ ("longformer", "TFLongformerForTokenClassification"),
403
+ ("mobilebert", "TFMobileBertForTokenClassification"),
404
+ ("mpnet", "TFMPNetForTokenClassification"),
405
+ ("rembert", "TFRemBertForTokenClassification"),
406
+ ("roberta", "TFRobertaForTokenClassification"),
407
+ ("roberta-prelayernorm", "TFRobertaPreLayerNormForTokenClassification"),
408
+ ("roformer", "TFRoFormerForTokenClassification"),
409
+ ("xlm", "TFXLMForTokenClassification"),
410
+ ("xlm-roberta", "TFXLMRobertaForTokenClassification"),
411
+ ("xlnet", "TFXLNetForTokenClassification"),
412
+ ]
413
+ )
414
+
415
+ TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict(
416
+ [
417
+ # Model for Multiple Choice mapping
418
+ ("albert", "TFAlbertForMultipleChoice"),
419
+ ("bert", "TFBertForMultipleChoice"),
420
+ ("camembert", "TFCamembertForMultipleChoice"),
421
+ ("convbert", "TFConvBertForMultipleChoice"),
422
+ ("deberta-v2", "TFDebertaV2ForMultipleChoice"),
423
+ ("distilbert", "TFDistilBertForMultipleChoice"),
424
+ ("electra", "TFElectraForMultipleChoice"),
425
+ ("flaubert", "TFFlaubertForMultipleChoice"),
426
+ ("funnel", "TFFunnelForMultipleChoice"),
427
+ ("longformer", "TFLongformerForMultipleChoice"),
428
+ ("mobilebert", "TFMobileBertForMultipleChoice"),
429
+ ("mpnet", "TFMPNetForMultipleChoice"),
430
+ ("rembert", "TFRemBertForMultipleChoice"),
431
+ ("roberta", "TFRobertaForMultipleChoice"),
432
+ ("roberta-prelayernorm", "TFRobertaPreLayerNormForMultipleChoice"),
433
+ ("roformer", "TFRoFormerForMultipleChoice"),
434
+ ("xlm", "TFXLMForMultipleChoice"),
435
+ ("xlm-roberta", "TFXLMRobertaForMultipleChoice"),
436
+ ("xlnet", "TFXLNetForMultipleChoice"),
437
+ ]
438
+ )
439
+
440
+ TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES = OrderedDict(
441
+ [
442
+ ("bert", "TFBertForNextSentencePrediction"),
443
+ ("mobilebert", "TFMobileBertForNextSentencePrediction"),
444
+ ]
445
+ )
446
+ TF_MODEL_FOR_MASK_GENERATION_MAPPING_NAMES = OrderedDict(
447
+ [
448
+ ("sam", "TFSamModel"),
449
+ ]
450
+ )
451
+ TF_MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES = OrderedDict(
452
+ [
453
+ ("albert", "TFAlbertModel"),
454
+ ("bert", "TFBertModel"),
455
+ ("convbert", "TFConvBertModel"),
456
+ ("deberta", "TFDebertaModel"),
457
+ ("deberta-v2", "TFDebertaV2Model"),
458
+ ("distilbert", "TFDistilBertModel"),
459
+ ("electra", "TFElectraModel"),
460
+ ("flaubert", "TFFlaubertModel"),
461
+ ("longformer", "TFLongformerModel"),
462
+ ("mobilebert", "TFMobileBertModel"),
463
+ ("mt5", "TFMT5EncoderModel"),
464
+ ("rembert", "TFRemBertModel"),
465
+ ("roberta", "TFRobertaModel"),
466
+ ("roberta-prelayernorm", "TFRobertaPreLayerNormModel"),
467
+ ("roformer", "TFRoFormerModel"),
468
+ ("t5", "TFT5EncoderModel"),
469
+ ("xlm", "TFXLMModel"),
470
+ ("xlm-roberta", "TFXLMRobertaModel"),
471
+ ]
472
+ )
473
+
474
+ TF_MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_MAPPING_NAMES)
475
+ TF_MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_PRETRAINING_MAPPING_NAMES)
476
+ TF_MODEL_WITH_LM_HEAD_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_WITH_LM_HEAD_MAPPING_NAMES)
477
+ TF_MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES)
478
+ TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING = _LazyAutoMapping(
479
+ CONFIG_MAPPING_NAMES, TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES
480
+ )
481
+ TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
482
+ CONFIG_MAPPING_NAMES, TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
483
+ )
484
+ TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
485
+ CONFIG_MAPPING_NAMES, TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES
486
+ )
487
+ TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING = _LazyAutoMapping(
488
+ CONFIG_MAPPING_NAMES, TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES
489
+ )
490
+ TF_MODEL_FOR_VISION_2_SEQ_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES)
491
+ TF_MODEL_FOR_MASKED_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_MASKED_LM_MAPPING_NAMES)
492
+ TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _LazyAutoMapping(
493
+ CONFIG_MAPPING_NAMES, TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
494
+ )
495
+ TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
496
+ CONFIG_MAPPING_NAMES, TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
497
+ )
498
+ TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING = _LazyAutoMapping(
499
+ CONFIG_MAPPING_NAMES, TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES
500
+ )
501
+ TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
502
+ CONFIG_MAPPING_NAMES, TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
503
+ )
504
+ TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
505
+ CONFIG_MAPPING_NAMES, TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES
506
+ )
507
+ TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
508
+ CONFIG_MAPPING_NAMES, TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES
509
+ )
510
+ TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = _LazyAutoMapping(
511
+ CONFIG_MAPPING_NAMES, TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES
512
+ )
513
+ TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING = _LazyAutoMapping(
514
+ CONFIG_MAPPING_NAMES, TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES
515
+ )
516
+ TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = _LazyAutoMapping(
517
+ CONFIG_MAPPING_NAMES, TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES
518
+ )
519
+ TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING = _LazyAutoMapping(
520
+ CONFIG_MAPPING_NAMES, TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES
521
+ )
522
+
523
+ TF_MODEL_FOR_MASK_GENERATION_MAPPING = _LazyAutoMapping(
524
+ CONFIG_MAPPING_NAMES, TF_MODEL_FOR_MASK_GENERATION_MAPPING_NAMES
525
+ )
526
+
527
+ TF_MODEL_FOR_TEXT_ENCODING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES)
528
+
529
+
530
+ class TFAutoModelForMaskGeneration(_BaseAutoModelClass):
531
+ _model_mapping = TF_MODEL_FOR_MASK_GENERATION_MAPPING
532
+
533
+
534
+ class TFAutoModelForTextEncoding(_BaseAutoModelClass):
535
+ _model_mapping = TF_MODEL_FOR_TEXT_ENCODING_MAPPING
536
+
537
+
538
+ class TFAutoModel(_BaseAutoModelClass):
539
+ _model_mapping = TF_MODEL_MAPPING
540
+
541
+
542
+ TFAutoModel = auto_class_update(TFAutoModel)
543
+
544
+
545
+ class TFAutoModelForAudioClassification(_BaseAutoModelClass):
546
+ _model_mapping = TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING
547
+
548
+
549
+ TFAutoModelForAudioClassification = auto_class_update(
550
+ TFAutoModelForAudioClassification, head_doc="audio classification"
551
+ )
552
+
553
+
554
+ class TFAutoModelForPreTraining(_BaseAutoModelClass):
555
+ _model_mapping = TF_MODEL_FOR_PRETRAINING_MAPPING
556
+
557
+
558
+ TFAutoModelForPreTraining = auto_class_update(TFAutoModelForPreTraining, head_doc="pretraining")
559
+
560
+
561
+ # Private on purpose, the public class will add the deprecation warnings.
562
+ class _TFAutoModelWithLMHead(_BaseAutoModelClass):
563
+ _model_mapping = TF_MODEL_WITH_LM_HEAD_MAPPING
564
+
565
+
566
+ _TFAutoModelWithLMHead = auto_class_update(_TFAutoModelWithLMHead, head_doc="language modeling")
567
+
568
+
569
+ class TFAutoModelForCausalLM(_BaseAutoModelClass):
570
+ _model_mapping = TF_MODEL_FOR_CAUSAL_LM_MAPPING
571
+
572
+
573
+ TFAutoModelForCausalLM = auto_class_update(TFAutoModelForCausalLM, head_doc="causal language modeling")
574
+
575
+
576
+ class TFAutoModelForMaskedImageModeling(_BaseAutoModelClass):
577
+ _model_mapping = TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING
578
+
579
+
580
+ TFAutoModelForMaskedImageModeling = auto_class_update(
581
+ TFAutoModelForMaskedImageModeling, head_doc="masked image modeling"
582
+ )
583
+
584
+
585
+ class TFAutoModelForImageClassification(_BaseAutoModelClass):
586
+ _model_mapping = TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
587
+
588
+
589
+ TFAutoModelForImageClassification = auto_class_update(
590
+ TFAutoModelForImageClassification, head_doc="image classification"
591
+ )
592
+
593
+
594
+ class TFAutoModelForZeroShotImageClassification(_BaseAutoModelClass):
595
+ _model_mapping = TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING
596
+
597
+
598
+ TFAutoModelForZeroShotImageClassification = auto_class_update(
599
+ TFAutoModelForZeroShotImageClassification, head_doc="zero-shot image classification"
600
+ )
601
+
602
+
603
+ class TFAutoModelForSemanticSegmentation(_BaseAutoModelClass):
604
+ _model_mapping = TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING
605
+
606
+
607
+ TFAutoModelForSemanticSegmentation = auto_class_update(
608
+ TFAutoModelForSemanticSegmentation, head_doc="semantic segmentation"
609
+ )
610
+
611
+
612
+ class TFAutoModelForVision2Seq(_BaseAutoModelClass):
613
+ _model_mapping = TF_MODEL_FOR_VISION_2_SEQ_MAPPING
614
+
615
+
616
+ TFAutoModelForVision2Seq = auto_class_update(TFAutoModelForVision2Seq, head_doc="vision-to-text modeling")
617
+
618
+
619
+ class TFAutoModelForMaskedLM(_BaseAutoModelClass):
620
+ _model_mapping = TF_MODEL_FOR_MASKED_LM_MAPPING
621
+
622
+
623
+ TFAutoModelForMaskedLM = auto_class_update(TFAutoModelForMaskedLM, head_doc="masked language modeling")
624
+
625
+
626
+ class TFAutoModelForSeq2SeqLM(_BaseAutoModelClass):
627
+ _model_mapping = TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
628
+
629
+
630
+ TFAutoModelForSeq2SeqLM = auto_class_update(
631
+ TFAutoModelForSeq2SeqLM,
632
+ head_doc="sequence-to-sequence language modeling",
633
+ checkpoint_for_example="google-t5/t5-base",
634
+ )
635
+
636
+
637
+ class TFAutoModelForSequenceClassification(_BaseAutoModelClass):
638
+ _model_mapping = TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
639
+
640
+
641
+ TFAutoModelForSequenceClassification = auto_class_update(
642
+ TFAutoModelForSequenceClassification, head_doc="sequence classification"
643
+ )
644
+
645
+
646
+ class TFAutoModelForQuestionAnswering(_BaseAutoModelClass):
647
+ _model_mapping = TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING
648
+
649
+
650
+ TFAutoModelForQuestionAnswering = auto_class_update(TFAutoModelForQuestionAnswering, head_doc="question answering")
651
+
652
+
653
+ class TFAutoModelForDocumentQuestionAnswering(_BaseAutoModelClass):
654
+ _model_mapping = TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING
655
+
656
+
657
+ TFAutoModelForDocumentQuestionAnswering = auto_class_update(
658
+ TFAutoModelForDocumentQuestionAnswering,
659
+ head_doc="document question answering",
660
+ checkpoint_for_example='impira/layoutlm-document-qa", revision="52e01b3',
661
+ )
662
+
663
+
664
+ class TFAutoModelForTableQuestionAnswering(_BaseAutoModelClass):
665
+ _model_mapping = TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING
666
+
667
+
668
+ TFAutoModelForTableQuestionAnswering = auto_class_update(
669
+ TFAutoModelForTableQuestionAnswering,
670
+ head_doc="table question answering",
671
+ checkpoint_for_example="google/tapas-base-finetuned-wtq",
672
+ )
673
+
674
+
675
+ class TFAutoModelForTokenClassification(_BaseAutoModelClass):
676
+ _model_mapping = TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
677
+
678
+
679
+ TFAutoModelForTokenClassification = auto_class_update(
680
+ TFAutoModelForTokenClassification, head_doc="token classification"
681
+ )
682
+
683
+
684
+ class TFAutoModelForMultipleChoice(_BaseAutoModelClass):
685
+ _model_mapping = TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING
686
+
687
+
688
+ TFAutoModelForMultipleChoice = auto_class_update(TFAutoModelForMultipleChoice, head_doc="multiple choice")
689
+
690
+
691
+ class TFAutoModelForNextSentencePrediction(_BaseAutoModelClass):
692
+ _model_mapping = TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING
693
+
694
+
695
+ TFAutoModelForNextSentencePrediction = auto_class_update(
696
+ TFAutoModelForNextSentencePrediction, head_doc="next sentence prediction"
697
+ )
698
+
699
+
700
+ class TFAutoModelForSpeechSeq2Seq(_BaseAutoModelClass):
701
+ _model_mapping = TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING
702
+
703
+
704
+ TFAutoModelForSpeechSeq2Seq = auto_class_update(
705
+ TFAutoModelForSpeechSeq2Seq, head_doc="sequence-to-sequence speech-to-text modeling"
706
+ )
707
+
708
+
709
+ class TFAutoModelWithLMHead(_TFAutoModelWithLMHead):
710
+ @classmethod
711
+ def from_config(cls, config):
712
+ warnings.warn(
713
+ "The class `TFAutoModelWithLMHead` is deprecated and will be removed in a future version. Please use"
714
+ " `TFAutoModelForCausalLM` for causal language models, `TFAutoModelForMaskedLM` for masked language models"
715
+ " and `TFAutoModelForSeq2SeqLM` for encoder-decoder models.",
716
+ FutureWarning,
717
+ )
718
+ return super().from_config(config)
719
+
720
+ @classmethod
721
+ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
722
+ warnings.warn(
723
+ "The class `TFAutoModelWithLMHead` is deprecated and will be removed in a future version. Please use"
724
+ " `TFAutoModelForCausalLM` for causal language models, `TFAutoModelForMaskedLM` for masked language models"
725
+ " and `TFAutoModelForSeq2SeqLM` for encoder-decoder models.",
726
+ FutureWarning,
727
+ )
728
+ return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
729
+
730
+
731
+ __all__ = [
732
+ "TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING",
733
+ "TF_MODEL_FOR_CAUSAL_LM_MAPPING",
734
+ "TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING",
735
+ "TF_MODEL_FOR_MASK_GENERATION_MAPPING",
736
+ "TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING",
737
+ "TF_MODEL_FOR_MASKED_LM_MAPPING",
738
+ "TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING",
739
+ "TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING",
740
+ "TF_MODEL_FOR_PRETRAINING_MAPPING",
741
+ "TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING",
742
+ "TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING",
743
+ "TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING",
744
+ "TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING",
745
+ "TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING",
746
+ "TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING",
747
+ "TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING",
748
+ "TF_MODEL_FOR_TEXT_ENCODING_MAPPING",
749
+ "TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
750
+ "TF_MODEL_FOR_VISION_2_SEQ_MAPPING",
751
+ "TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING",
752
+ "TF_MODEL_MAPPING",
753
+ "TF_MODEL_WITH_LM_HEAD_MAPPING",
754
+ "TFAutoModel",
755
+ "TFAutoModelForAudioClassification",
756
+ "TFAutoModelForCausalLM",
757
+ "TFAutoModelForImageClassification",
758
+ "TFAutoModelForMaskedImageModeling",
759
+ "TFAutoModelForMaskedLM",
760
+ "TFAutoModelForMaskGeneration",
761
+ "TFAutoModelForMultipleChoice",
762
+ "TFAutoModelForNextSentencePrediction",
763
+ "TFAutoModelForPreTraining",
764
+ "TFAutoModelForDocumentQuestionAnswering",
765
+ "TFAutoModelForQuestionAnswering",
766
+ "TFAutoModelForSemanticSegmentation",
767
+ "TFAutoModelForSeq2SeqLM",
768
+ "TFAutoModelForSequenceClassification",
769
+ "TFAutoModelForSpeechSeq2Seq",
770
+ "TFAutoModelForTableQuestionAnswering",
771
+ "TFAutoModelForTextEncoding",
772
+ "TFAutoModelForTokenClassification",
773
+ "TFAutoModelForVision2Seq",
774
+ "TFAutoModelForZeroShotImageClassification",
775
+ "TFAutoModelWithLMHead",
776
+ ]
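The lazy mappings and thin auto classes above only resolve a config type to a concrete TensorFlow head class. A minimal usage sketch follows, assuming TensorFlow is installed and the config for `distilbert-base-uncased` can be fetched; the checkpoint name is illustrative and not taken from this diff:

```python
# Minimal sketch: the auto class looks up the config type in
# TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING and instantiates the matching head.
from transformers import AutoConfig, TFAutoModelForSequenceClassification

config = AutoConfig.from_pretrained("distilbert-base-uncased")  # illustrative checkpoint
model = TFAutoModelForSequenceClassification.from_config(config)  # randomly initialized weights
print(type(model).__name__)  # -> TFDistilBertForSequenceClassification
```

The same lookup happens in `from_pretrained`, which additionally loads the matching pretrained weights.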
venv/lib/python3.13/site-packages/transformers/models/auto/processing_auto.py ADDED
@@ -0,0 +1,443 @@
1
+ # coding=utf-8
2
+ # Copyright 2021 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """AutoProcessor class."""
16
+
17
+ import importlib
18
+ import inspect
19
+ import json
20
+ import warnings
21
+ from collections import OrderedDict
22
+
23
+ # Build the list of all feature extractors
24
+ from ...configuration_utils import PretrainedConfig
25
+ from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
26
+ from ...feature_extraction_utils import FeatureExtractionMixin
27
+ from ...image_processing_utils import ImageProcessingMixin
28
+ from ...processing_utils import ProcessorMixin
29
+ from ...tokenization_utils import TOKENIZER_CONFIG_FILE
30
+ from ...utils import FEATURE_EXTRACTOR_NAME, PROCESSOR_NAME, VIDEO_PROCESSOR_NAME, cached_file, logging
31
+ from ...video_processing_utils import BaseVideoProcessor
32
+ from .auto_factory import _LazyAutoMapping
33
+ from .configuration_auto import (
34
+ CONFIG_MAPPING_NAMES,
35
+ AutoConfig,
36
+ model_type_to_module_name,
37
+ replace_list_option_in_docstrings,
38
+ )
39
+ from .feature_extraction_auto import AutoFeatureExtractor
40
+ from .image_processing_auto import AutoImageProcessor
41
+ from .tokenization_auto import AutoTokenizer
42
+
43
+
44
+ logger = logging.get_logger(__name__)
45
+
46
+ PROCESSOR_MAPPING_NAMES = OrderedDict(
47
+ [
48
+ ("aimv2", "CLIPProcessor"),
49
+ ("align", "AlignProcessor"),
50
+ ("altclip", "AltCLIPProcessor"),
51
+ ("aria", "AriaProcessor"),
52
+ ("aya_vision", "AyaVisionProcessor"),
53
+ ("bark", "BarkProcessor"),
54
+ ("blip", "BlipProcessor"),
55
+ ("blip-2", "Blip2Processor"),
56
+ ("bridgetower", "BridgeTowerProcessor"),
57
+ ("chameleon", "ChameleonProcessor"),
58
+ ("chinese_clip", "ChineseCLIPProcessor"),
59
+ ("clap", "ClapProcessor"),
60
+ ("clip", "CLIPProcessor"),
61
+ ("clipseg", "CLIPSegProcessor"),
62
+ ("clvp", "ClvpProcessor"),
63
+ ("cohere2_vision", "Cohere2VisionProcessor"),
64
+ ("colpali", "ColPaliProcessor"),
65
+ ("colqwen2", "ColQwen2Processor"),
66
+ ("deepseek_vl", "DeepseekVLProcessor"),
67
+ ("deepseek_vl_hybrid", "DeepseekVLHybridProcessor"),
68
+ ("dia", "DiaProcessor"),
69
+ ("edgetam", "Sam2Processor"),
70
+ ("emu3", "Emu3Processor"),
71
+ ("evolla", "EvollaProcessor"),
72
+ ("flava", "FlavaProcessor"),
73
+ ("florence2", "Florence2Processor"),
74
+ ("fuyu", "FuyuProcessor"),
75
+ ("gemma3", "Gemma3Processor"),
76
+ ("gemma3n", "Gemma3nProcessor"),
77
+ ("git", "GitProcessor"),
78
+ ("glm4v", "Glm4vProcessor"),
79
+ ("glm4v_moe", "Glm4vProcessor"),
80
+ ("got_ocr2", "GotOcr2Processor"),
81
+ ("granite_speech", "GraniteSpeechProcessor"),
82
+ ("grounding-dino", "GroundingDinoProcessor"),
83
+ ("groupvit", "CLIPProcessor"),
84
+ ("hubert", "Wav2Vec2Processor"),
85
+ ("idefics", "IdeficsProcessor"),
86
+ ("idefics2", "Idefics2Processor"),
87
+ ("idefics3", "Idefics3Processor"),
88
+ ("instructblip", "InstructBlipProcessor"),
89
+ ("instructblipvideo", "InstructBlipVideoProcessor"),
90
+ ("internvl", "InternVLProcessor"),
91
+ ("janus", "JanusProcessor"),
92
+ ("kosmos-2", "Kosmos2Processor"),
93
+ ("kosmos-2.5", "Kosmos2_5Processor"),
94
+ ("kyutai_speech_to_text", "KyutaiSpeechToTextProcessor"),
95
+ ("layoutlmv2", "LayoutLMv2Processor"),
96
+ ("layoutlmv3", "LayoutLMv3Processor"),
97
+ ("lfm2_vl", "Lfm2VlProcessor"),
98
+ ("llama4", "Llama4Processor"),
99
+ ("llava", "LlavaProcessor"),
100
+ ("llava_next", "LlavaNextProcessor"),
101
+ ("llava_next_video", "LlavaNextVideoProcessor"),
102
+ ("llava_onevision", "LlavaOnevisionProcessor"),
103
+ ("markuplm", "MarkupLMProcessor"),
104
+ ("mctct", "MCTCTProcessor"),
105
+ ("metaclip_2", "CLIPProcessor"),
106
+ ("mgp-str", "MgpstrProcessor"),
107
+ ("mistral3", "PixtralProcessor"),
108
+ ("mllama", "MllamaProcessor"),
109
+ ("mm-grounding-dino", "GroundingDinoProcessor"),
110
+ ("moonshine", "Wav2Vec2Processor"),
111
+ ("oneformer", "OneFormerProcessor"),
112
+ ("ovis2", "Ovis2Processor"),
113
+ ("owlv2", "Owlv2Processor"),
114
+ ("owlvit", "OwlViTProcessor"),
115
+ ("paligemma", "PaliGemmaProcessor"),
116
+ ("perception_lm", "PerceptionLMProcessor"),
117
+ ("phi4_multimodal", "Phi4MultimodalProcessor"),
118
+ ("pix2struct", "Pix2StructProcessor"),
119
+ ("pixtral", "PixtralProcessor"),
120
+ ("pop2piano", "Pop2PianoProcessor"),
121
+ ("qwen2_5_omni", "Qwen2_5OmniProcessor"),
122
+ ("qwen2_5_vl", "Qwen2_5_VLProcessor"),
123
+ ("qwen2_audio", "Qwen2AudioProcessor"),
124
+ ("qwen2_vl", "Qwen2VLProcessor"),
125
+ ("qwen3_omni_moe", "Qwen3OmniMoeProcessor"),
126
+ ("qwen3_vl", "Qwen3VLProcessor"),
127
+ ("qwen3_vl_moe", "Qwen3VLProcessor"),
128
+ ("sam", "SamProcessor"),
129
+ ("sam2", "Sam2Processor"),
130
+ ("sam_hq", "SamHQProcessor"),
131
+ ("seamless_m4t", "SeamlessM4TProcessor"),
132
+ ("sew", "Wav2Vec2Processor"),
133
+ ("sew-d", "Wav2Vec2Processor"),
134
+ ("shieldgemma2", "ShieldGemma2Processor"),
135
+ ("siglip", "SiglipProcessor"),
136
+ ("siglip2", "Siglip2Processor"),
137
+ ("smolvlm", "SmolVLMProcessor"),
138
+ ("speech_to_text", "Speech2TextProcessor"),
139
+ ("speech_to_text_2", "Speech2Text2Processor"),
140
+ ("speecht5", "SpeechT5Processor"),
141
+ ("trocr", "TrOCRProcessor"),
142
+ ("tvlt", "TvltProcessor"),
143
+ ("tvp", "TvpProcessor"),
144
+ ("udop", "UdopProcessor"),
145
+ ("unispeech", "Wav2Vec2Processor"),
146
+ ("unispeech-sat", "Wav2Vec2Processor"),
147
+ ("video_llava", "VideoLlavaProcessor"),
148
+ ("vilt", "ViltProcessor"),
149
+ ("vipllava", "LlavaProcessor"),
150
+ ("vision-text-dual-encoder", "VisionTextDualEncoderProcessor"),
151
+ ("voxtral", "VoxtralProcessor"),
152
+ ("wav2vec2", "Wav2Vec2Processor"),
153
+ ("wav2vec2-bert", "Wav2Vec2Processor"),
154
+ ("wav2vec2-conformer", "Wav2Vec2Processor"),
155
+ ("wavlm", "Wav2Vec2Processor"),
156
+ ("whisper", "WhisperProcessor"),
157
+ ("xclip", "XCLIPProcessor"),
158
+ ]
159
+ )
160
+
161
+ PROCESSOR_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, PROCESSOR_MAPPING_NAMES)
162
+
163
+
164
+ def processor_class_from_name(class_name: str):
165
+ for module_name, processors in PROCESSOR_MAPPING_NAMES.items():
166
+ if class_name in processors:
167
+ module_name = model_type_to_module_name(module_name)
168
+
169
+ module = importlib.import_module(f".{module_name}", "transformers.models")
170
+ try:
171
+ return getattr(module, class_name)
172
+ except AttributeError:
173
+ continue
174
+
175
+ for processor in PROCESSOR_MAPPING._extra_content.values():
176
+ if getattr(processor, "__name__", None) == class_name:
177
+ return processor
178
+
179
+ # We did not find the class, but maybe it's because a dep is missing. In that case, the class will be in the main
180
+ # init and we return the proper dummy to get an appropriate error message.
181
+ main_module = importlib.import_module("transformers")
182
+ if hasattr(main_module, class_name):
183
+ return getattr(main_module, class_name)
184
+
185
+ return None
186
+
187
+
188
+ class AutoProcessor:
189
+ r"""
190
+ This is a generic processor class that will be instantiated as one of the processor classes of the library when
191
+ created with the [`AutoProcessor.from_pretrained`] class method.
192
+
193
+ This class cannot be instantiated directly using `__init__()` (throws an error).
194
+ """
195
+
196
+ def __init__(self):
197
+ raise OSError(
198
+ "AutoProcessor is designed to be instantiated "
199
+ "using the `AutoProcessor.from_pretrained(pretrained_model_name_or_path)` method."
200
+ )
201
+
202
+ @classmethod
203
+ @replace_list_option_in_docstrings(PROCESSOR_MAPPING_NAMES)
204
+ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
205
+ r"""
206
+ Instantiate one of the processor classes of the library from a pretrained checkpoint.
207
+
208
+ The processor class to instantiate is selected based on the `model_type` property of the config object (either
209
+ passed as an argument or loaded from `pretrained_model_name_or_path` if possible):
210
+
211
+ List options
212
+
213
+ Params:
214
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
215
+ This can be either:
216
+
217
+ - a string, the *model id* of a pretrained processor hosted inside a model repo on
218
+ huggingface.co.
219
+ - a path to a *directory* containing processor files saved using the `save_pretrained()` method,
220
+ e.g., `./my_model_directory/`.
221
+ cache_dir (`str` or `os.PathLike`, *optional*):
222
+ Path to a directory in which a downloaded pretrained processor should be cached if the
223
+ standard cache should not be used.
224
+ force_download (`bool`, *optional*, defaults to `False`):
225
+ Whether or not to force a (re-)download of the processor files and override the cached versions
226
+ if they exist.
227
+ resume_download:
228
+ Deprecated and ignored. All downloads are now resumed by default when possible.
229
+ Will be removed in v5 of Transformers.
230
+ proxies (`dict[str, str]`, *optional*):
231
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
232
+ 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
233
+ token (`str` or *bool*, *optional*):
234
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
235
+ when running `hf auth login` (stored in `~/.huggingface`).
236
+ revision (`str`, *optional*, defaults to `"main"`):
237
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
238
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
239
+ identifier allowed by git.
240
+ return_unused_kwargs (`bool`, *optional*, defaults to `False`):
241
+ If `False`, then this function returns just the final processor object. If `True`, then this
242
+ function returns a `Tuple(processor, unused_kwargs)` where *unused_kwargs* is a dictionary
243
+ consisting of the key/value pairs whose keys are not processor attributes: i.e., the part of
244
+ `kwargs` which has not been used to update `processor` and is otherwise ignored.
245
+ trust_remote_code (`bool`, *optional*, defaults to `False`):
246
+ Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
247
+ should only be set to `True` for repositories you trust and in which you have read the code, as it will
248
+ execute code present on the Hub on your local machine.
249
+ kwargs (`dict[str, Any]`, *optional*):
250
+ The values in kwargs of any keys which are processor attributes will be used to override the
251
+ loaded values. Behavior concerning key/value pairs whose keys are *not* processor attributes is
252
+ controlled by the `return_unused_kwargs` keyword parameter.
253
+
254
+ <Tip>
255
+
256
+ Passing `token=True` is required when you want to use a private model.
257
+
258
+ </Tip>
259
+
260
+ Examples:
261
+
262
+ ```python
263
+ >>> from transformers import AutoProcessor
264
+
265
+ >>> # Download processor from huggingface.co and cache.
266
+ >>> processor = AutoProcessor.from_pretrained("facebook/wav2vec2-base-960h")
267
+
268
+ >>> # If processor files are in a directory (e.g. processor was saved using *save_pretrained('./test/saved_model/')*)
269
+ >>> # processor = AutoProcessor.from_pretrained("./test/saved_model/")
270
+ ```"""
271
+ use_auth_token = kwargs.pop("use_auth_token", None)
272
+ if use_auth_token is not None:
273
+ warnings.warn(
274
+ "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
275
+ FutureWarning,
276
+ )
277
+ if kwargs.get("token") is not None:
278
+ raise ValueError(
279
+ "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
280
+ )
281
+ kwargs["token"] = use_auth_token
282
+
283
+ config = kwargs.pop("config", None)
284
+ trust_remote_code = kwargs.pop("trust_remote_code", None)
285
+ kwargs["_from_auto"] = True
286
+
287
+ processor_class = None
288
+ processor_auto_map = None
289
+
290
+ # First, let's see if we have a processor or preprocessor config.
291
+ # Filter the kwargs for `cached_file`.
292
+ cached_file_kwargs = {key: kwargs[key] for key in inspect.signature(cached_file).parameters if key in kwargs}
293
+ # We don't want to raise
294
+ cached_file_kwargs.update(
295
+ {
296
+ "_raise_exceptions_for_gated_repo": False,
297
+ "_raise_exceptions_for_missing_entries": False,
298
+ "_raise_exceptions_for_connection_errors": False,
299
+ }
300
+ )
301
+
302
+ # Let's start by checking whether the processor class is saved in a processor config
303
+ processor_config_file = cached_file(pretrained_model_name_or_path, PROCESSOR_NAME, **cached_file_kwargs)
304
+ if processor_config_file is not None:
305
+ config_dict, _ = ProcessorMixin.get_processor_dict(pretrained_model_name_or_path, **kwargs)
306
+ processor_class = config_dict.get("processor_class", None)
307
+ if "AutoProcessor" in config_dict.get("auto_map", {}):
308
+ processor_auto_map = config_dict["auto_map"]["AutoProcessor"]
309
+
310
+ if processor_class is None:
311
+ # If not found, let's check whether the processor class is saved in an image processor config
312
+ preprocessor_config_file = cached_file(
313
+ pretrained_model_name_or_path, FEATURE_EXTRACTOR_NAME, **cached_file_kwargs
314
+ )
315
+ if preprocessor_config_file is not None:
316
+ config_dict, _ = ImageProcessingMixin.get_image_processor_dict(pretrained_model_name_or_path, **kwargs)
317
+ processor_class = config_dict.get("processor_class", None)
318
+ if "AutoProcessor" in config_dict.get("auto_map", {}):
319
+ processor_auto_map = config_dict["auto_map"]["AutoProcessor"]
320
+
321
+ # Saved as video processor
322
+ if preprocessor_config_file is None:
323
+ preprocessor_config_file = cached_file(
324
+ pretrained_model_name_or_path, VIDEO_PROCESSOR_NAME, **cached_file_kwargs
325
+ )
326
+ if preprocessor_config_file is not None:
327
+ config_dict, _ = BaseVideoProcessor.get_video_processor_dict(
328
+ pretrained_model_name_or_path, **kwargs
329
+ )
330
+ processor_class = config_dict.get("processor_class", None)
331
+ if "AutoProcessor" in config_dict.get("auto_map", {}):
332
+ processor_auto_map = config_dict["auto_map"]["AutoProcessor"]
333
+
334
+ # Saved as feature extractor
335
+ if preprocessor_config_file is None:
336
+ preprocessor_config_file = cached_file(
337
+ pretrained_model_name_or_path, FEATURE_EXTRACTOR_NAME, **cached_file_kwargs
338
+ )
339
+ if preprocessor_config_file is not None and processor_class is None:
340
+ config_dict, _ = FeatureExtractionMixin.get_feature_extractor_dict(
341
+ pretrained_model_name_or_path, **kwargs
342
+ )
343
+ processor_class = config_dict.get("processor_class", None)
344
+ if "AutoProcessor" in config_dict.get("auto_map", {}):
345
+ processor_auto_map = config_dict["auto_map"]["AutoProcessor"]
346
+
347
+ if processor_class is None:
348
+ # Next, let's check whether the processor class is saved in a tokenizer
349
+ tokenizer_config_file = cached_file(
350
+ pretrained_model_name_or_path, TOKENIZER_CONFIG_FILE, **cached_file_kwargs
351
+ )
352
+ if tokenizer_config_file is not None:
353
+ with open(tokenizer_config_file, encoding="utf-8") as reader:
354
+ config_dict = json.load(reader)
355
+
356
+ processor_class = config_dict.get("processor_class", None)
357
+ if "AutoProcessor" in config_dict.get("auto_map", {}):
358
+ processor_auto_map = config_dict["auto_map"]["AutoProcessor"]
359
+
360
+ if processor_class is None:
361
+ # Otherwise, load config, if it can be loaded.
362
+ if not isinstance(config, PretrainedConfig):
363
+ config = AutoConfig.from_pretrained(
364
+ pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
365
+ )
366
+
367
+ # And check if the config contains the processor class.
368
+ processor_class = getattr(config, "processor_class", None)
369
+ if hasattr(config, "auto_map") and "AutoProcessor" in config.auto_map:
370
+ processor_auto_map = config.auto_map["AutoProcessor"]
371
+
372
+ if processor_class is not None:
373
+ processor_class = processor_class_from_name(processor_class)
374
+
375
+ has_remote_code = processor_auto_map is not None
376
+ has_local_code = processor_class is not None or type(config) in PROCESSOR_MAPPING
377
+ if has_remote_code:
378
+ if "--" in processor_auto_map:
379
+ upstream_repo = processor_auto_map.split("--")[0]
380
+ else:
381
+ upstream_repo = None
382
+ trust_remote_code = resolve_trust_remote_code(
383
+ trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code, upstream_repo
384
+ )
385
+
386
+ if has_remote_code and trust_remote_code:
387
+ processor_class = get_class_from_dynamic_module(
388
+ processor_auto_map, pretrained_model_name_or_path, **kwargs
389
+ )
390
+ _ = kwargs.pop("code_revision", None)
391
+ processor_class.register_for_auto_class()
392
+ return processor_class.from_pretrained(
393
+ pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
394
+ )
395
+ elif processor_class is not None:
396
+ return processor_class.from_pretrained(
397
+ pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
398
+ )
399
+ # Last try: we use the PROCESSOR_MAPPING.
400
+ elif type(config) in PROCESSOR_MAPPING:
401
+ return PROCESSOR_MAPPING[type(config)].from_pretrained(pretrained_model_name_or_path, **kwargs)
402
+
403
+ # At this stage, there doesn't seem to be a `Processor` class available for this model, so let's try a
404
+ # tokenizer.
405
+ try:
406
+ return AutoTokenizer.from_pretrained(
407
+ pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
408
+ )
409
+ except Exception:
410
+ try:
411
+ return AutoImageProcessor.from_pretrained(
412
+ pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
413
+ )
414
+ except Exception:
415
+ pass
416
+
417
+ try:
418
+ return AutoFeatureExtractor.from_pretrained(
419
+ pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
420
+ )
421
+ except Exception:
422
+ pass
423
+
424
+ raise ValueError(
425
+ f"Unrecognized processing class in {pretrained_model_name_or_path}. Can't instantiate a processor, a "
426
+ "tokenizer, an image processor or a feature extractor for this model. Make sure the repository contains "
427
+ "the files of at least one of those processing classes."
428
+ )
429
+
430
+ @staticmethod
431
+ def register(config_class, processor_class, exist_ok=False):
432
+ """
433
+ Register a new processor for this class.
434
+
435
+ Args:
436
+ config_class ([`PretrainedConfig`]):
437
+ The configuration corresponding to the model to register.
438
+ processor_class ([`ProcessorMixin`]): The processor to register.
439
+ """
440
+ PROCESSOR_MAPPING.register(config_class, processor_class, exist_ok=exist_ok)
441
+
442
+
443
+ __all__ = ["PROCESSOR_MAPPING", "AutoProcessor"]
venv/lib/python3.13/site-packages/transformers/models/auto/tokenization_auto.py ADDED
@@ -0,0 +1,1235 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Auto Tokenizer class."""
16
+
17
+ import importlib
18
+ import json
19
+ import os
20
+ import warnings
21
+ from collections import OrderedDict
22
+ from typing import Any, Optional, Union
23
+
24
+ from transformers.utils.import_utils import is_mistral_common_available
25
+
26
+ from ...configuration_utils import PretrainedConfig
27
+ from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
28
+ from ...modeling_gguf_pytorch_utils import load_gguf_checkpoint
29
+ from ...tokenization_utils import PreTrainedTokenizer
30
+ from ...tokenization_utils_base import TOKENIZER_CONFIG_FILE
31
+ from ...utils import (
32
+ cached_file,
33
+ extract_commit_hash,
34
+ is_g2p_en_available,
35
+ is_sentencepiece_available,
36
+ is_tokenizers_available,
37
+ logging,
38
+ )
39
+ from ..encoder_decoder import EncoderDecoderConfig
40
+ from .auto_factory import _LazyAutoMapping
41
+ from .configuration_auto import (
42
+ CONFIG_MAPPING_NAMES,
43
+ AutoConfig,
44
+ config_class_to_model_type,
45
+ model_type_to_module_name,
46
+ replace_list_option_in_docstrings,
47
+ )
48
+
49
+
50
+ if is_tokenizers_available():
51
+ from ...tokenization_utils_fast import PreTrainedTokenizerFast
52
+ else:
53
+ PreTrainedTokenizerFast = None
54
+
55
+
56
+ logger = logging.get_logger(__name__)
57
+
58
+ # Explicit rather than inferred generics significantly improve completion suggestion performance for language servers.
59
+ TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]](
60
+ [
61
+ (
62
+ "aimv2",
63
+ (
64
+ "CLIPTokenizer",
65
+ "CLIPTokenizerFast" if is_tokenizers_available() else None,
66
+ ),
67
+ ),
68
+ (
69
+ "albert",
70
+ (
71
+ "AlbertTokenizer" if is_sentencepiece_available() else None,
72
+ "AlbertTokenizerFast" if is_tokenizers_available() else None,
73
+ ),
74
+ ),
75
+ ("align", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
76
+ ("arcee", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
77
+ ("aria", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
78
+ ("aya_vision", (None, "CohereTokenizerFast" if is_tokenizers_available() else None)),
79
+ ("bark", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
80
+ ("bart", ("BartTokenizer", "BartTokenizerFast")),
81
+ (
82
+ "barthez",
83
+ (
84
+ "BarthezTokenizer" if is_sentencepiece_available() else None,
85
+ "BarthezTokenizerFast" if is_tokenizers_available() else None,
86
+ ),
87
+ ),
88
+ ("bartpho", ("BartphoTokenizer", None)),
89
+ ("bert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
90
+ ("bert-generation", ("BertGenerationTokenizer" if is_sentencepiece_available() else None, None)),
91
+ ("bert-japanese", ("BertJapaneseTokenizer", None)),
92
+ ("bertweet", ("BertweetTokenizer", None)),
93
+ (
94
+ "big_bird",
95
+ (
96
+ "BigBirdTokenizer" if is_sentencepiece_available() else None,
97
+ "BigBirdTokenizerFast" if is_tokenizers_available() else None,
98
+ ),
99
+ ),
100
+ ("bigbird_pegasus", ("PegasusTokenizer", "PegasusTokenizerFast" if is_tokenizers_available() else None)),
101
+ ("biogpt", ("BioGptTokenizer", None)),
102
+ ("bitnet", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
103
+ ("blenderbot", ("BlenderbotTokenizer", "BlenderbotTokenizerFast")),
104
+ ("blenderbot-small", ("BlenderbotSmallTokenizer", None)),
105
+ ("blip", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
106
+ ("blip-2", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
107
+ ("bloom", (None, "BloomTokenizerFast" if is_tokenizers_available() else None)),
108
+ ("blt", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
109
+ ("bridgetower", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
110
+ ("bros", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
111
+ ("byt5", ("ByT5Tokenizer", None)),
112
+ (
113
+ "camembert",
114
+ (
115
+ "CamembertTokenizer" if is_sentencepiece_available() else None,
116
+ "CamembertTokenizerFast" if is_tokenizers_available() else None,
117
+ ),
118
+ ),
119
+ ("canine", ("CanineTokenizer", None)),
120
+ (
121
+ "chameleon",
122
+ (
123
+ "LlamaTokenizer" if is_sentencepiece_available() else None,
124
+ "LlamaTokenizerFast" if is_tokenizers_available() else None,
125
+ ),
126
+ ),
127
+ ("chinese_clip", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
128
+ (
129
+ "clap",
130
+ (
131
+ "RobertaTokenizer",
132
+ "RobertaTokenizerFast" if is_tokenizers_available() else None,
133
+ ),
134
+ ),
135
+ (
136
+ "clip",
137
+ (
138
+ "CLIPTokenizer",
139
+ "CLIPTokenizerFast" if is_tokenizers_available() else None,
140
+ ),
141
+ ),
142
+ (
143
+ "clipseg",
144
+ (
145
+ "CLIPTokenizer",
146
+ "CLIPTokenizerFast" if is_tokenizers_available() else None,
147
+ ),
148
+ ),
149
+ ("clvp", ("ClvpTokenizer", None)),
150
+ (
151
+ "code_llama",
152
+ (
153
+ "CodeLlamaTokenizer" if is_sentencepiece_available() else None,
154
+ "CodeLlamaTokenizerFast" if is_tokenizers_available() else None,
155
+ ),
156
+ ),
157
+ ("codegen", ("CodeGenTokenizer", "CodeGenTokenizerFast" if is_tokenizers_available() else None)),
158
+ ("cohere", (None, "CohereTokenizerFast" if is_tokenizers_available() else None)),
159
+ ("cohere2", (None, "CohereTokenizerFast" if is_tokenizers_available() else None)),
160
+ ("colpali", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
161
+ ("colqwen2", ("Qwen2Tokenizer", "Qwen2TokenizerFast" if is_tokenizers_available() else None)),
162
+ ("convbert", ("ConvBertTokenizer", "ConvBertTokenizerFast" if is_tokenizers_available() else None)),
163
+ (
164
+ "cpm",
165
+ (
166
+ "CpmTokenizer" if is_sentencepiece_available() else None,
167
+ "CpmTokenizerFast" if is_tokenizers_available() else None,
168
+ ),
169
+ ),
170
+ ("cpmant", ("CpmAntTokenizer", None)),
171
+ ("csm", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
172
+ ("ctrl", ("CTRLTokenizer", None)),
173
+ ("data2vec-audio", ("Wav2Vec2CTCTokenizer", None)),
174
+ ("data2vec-text", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
175
+ ("dbrx", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
176
+ ("deberta", ("DebertaTokenizer", "DebertaTokenizerFast" if is_tokenizers_available() else None)),
177
+ (
178
+ "deberta-v2",
179
+ (
180
+ "DebertaV2Tokenizer" if is_sentencepiece_available() else None,
181
+ "DebertaV2TokenizerFast" if is_tokenizers_available() else None,
182
+ ),
183
+ ),
184
+ (
185
+ "deepseek_v2",
186
+ (
187
+ "LlamaTokenizer" if is_sentencepiece_available() else None,
188
+ "LlamaTokenizerFast" if is_tokenizers_available() else None,
189
+ ),
190
+ ),
191
+ (
192
+ "deepseek_v3",
193
+ (
194
+ "LlamaTokenizer" if is_sentencepiece_available() else None,
195
+ "LlamaTokenizerFast" if is_tokenizers_available() else None,
196
+ ),
197
+ ),
198
+ (
199
+ "deepseek_vl",
200
+ (
201
+ "LlamaTokenizer" if is_sentencepiece_available() else None,
202
+ "LlamaTokenizerFast" if is_tokenizers_available() else None,
203
+ ),
204
+ ),
205
+ (
206
+ "deepseek_vl_hybrid",
207
+ (
208
+ "LlamaTokenizer" if is_sentencepiece_available() else None,
209
+ "LlamaTokenizerFast" if is_tokenizers_available() else None,
210
+ ),
211
+ ),
212
+ ("dia", ("DiaTokenizer", None)),
213
+ (
214
+ "diffllama",
215
+ (
216
+ "LlamaTokenizer" if is_sentencepiece_available() else None,
217
+ "LlamaTokenizerFast" if is_tokenizers_available() else None,
218
+ ),
219
+ ),
220
+ ("distilbert", ("DistilBertTokenizer", "DistilBertTokenizerFast" if is_tokenizers_available() else None)),
221
+ (
222
+ "dpr",
223
+ (
224
+ "DPRQuestionEncoderTokenizer",
225
+ "DPRQuestionEncoderTokenizerFast" if is_tokenizers_available() else None,
226
+ ),
227
+ ),
228
+ ("electra", ("ElectraTokenizer", "ElectraTokenizerFast" if is_tokenizers_available() else None)),
229
+ ("emu3", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
230
+ ("ernie", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
231
+ ("ernie4_5", (None, "LlamaTokenizerFast" if is_tokenizers_available() else None)),
232
+ ("ernie4_5_moe", (None, "LlamaTokenizerFast" if is_tokenizers_available() else None)),
233
+ ("ernie_m", ("ErnieMTokenizer" if is_sentencepiece_available() else None, None)),
234
+ ("esm", ("EsmTokenizer", None)),
235
+ (
236
+ "exaone4",
237
+ (
238
+ "GPT2Tokenizer" if is_tokenizers_available() else None,
239
+ "GPT2TokenizerFast" if is_tokenizers_available() else None,
240
+ ),
241
+ ),
242
+ ("falcon", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
243
+ ("falcon_mamba", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
244
+ (
245
+ "fastspeech2_conformer",
246
+ ("FastSpeech2ConformerTokenizer" if is_g2p_en_available() else None, None),
247
+ ),
248
+ ("flaubert", ("FlaubertTokenizer", None)),
249
+ ("flex_olmo", (None, "GPT2TokenizerFast" if is_tokenizers_available() else None)),
250
+ ("fnet", ("FNetTokenizer", "FNetTokenizerFast" if is_tokenizers_available() else None)),
251
+ ("fsmt", ("FSMTTokenizer", None)),
252
+ ("funnel", ("FunnelTokenizer", "FunnelTokenizerFast" if is_tokenizers_available() else None)),
253
+ (
254
+ "gemma",
255
+ (
256
+ "GemmaTokenizer" if is_sentencepiece_available() else None,
257
+ "GemmaTokenizerFast" if is_tokenizers_available() else None,
258
+ ),
259
+ ),
260
+ (
261
+ "gemma2",
262
+ (
263
+ "GemmaTokenizer" if is_sentencepiece_available() else None,
264
+ "GemmaTokenizerFast" if is_tokenizers_available() else None,
265
+ ),
266
+ ),
267
+ (
268
+ "gemma3",
269
+ (
270
+ "GemmaTokenizer" if is_sentencepiece_available() else None,
271
+ "GemmaTokenizerFast" if is_tokenizers_available() else None,
272
+ ),
273
+ ),
274
+ (
275
+ "gemma3_text",
276
+ (
277
+ "GemmaTokenizer" if is_sentencepiece_available() else None,
278
+ "GemmaTokenizerFast" if is_tokenizers_available() else None,
279
+ ),
280
+ ),
281
+ (
282
+ "gemma3n",
283
+ (
284
+ "GemmaTokenizer" if is_sentencepiece_available() else None,
285
+ "GemmaTokenizerFast" if is_tokenizers_available() else None,
286
+ ),
287
+ ),
288
+ (
289
+ "gemma3n_text",
290
+ (
291
+ "GemmaTokenizer" if is_sentencepiece_available() else None,
292
+ "GemmaTokenizerFast" if is_tokenizers_available() else None,
293
+ ),
294
+ ),
295
+ ("git", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
296
+ ("glm", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
297
+ ("glm4", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
298
+ ("glm4_moe", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
299
+ ("glm4v", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
300
+ ("glm4v_moe", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
301
+ ("gpt-sw3", ("GPTSw3Tokenizer" if is_sentencepiece_available() else None, None)),
302
+ ("gpt2", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
303
+ ("gpt_bigcode", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
304
+ ("gpt_neo", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
305
+ ("gpt_neox", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
306
+ ("gpt_neox_japanese", ("GPTNeoXJapaneseTokenizer", None)),
307
+ ("gpt_oss", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
308
+ ("gptj", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
309
+ ("gptsan-japanese", ("GPTSanJapaneseTokenizer", None)),
310
+ ("granite", ("GPT2Tokenizer", None)),
311
+ ("granitemoe", ("GPT2Tokenizer", None)),
312
+ ("granitemoehybrid", ("GPT2Tokenizer", None)),
313
+ ("granitemoeshared", ("GPT2Tokenizer", None)),
314
+ ("grounding-dino", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
315
+ ("groupvit", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)),
316
+ ("helium", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
317
+ ("herbert", ("HerbertTokenizer", "HerbertTokenizerFast" if is_tokenizers_available() else None)),
318
+ ("hubert", ("Wav2Vec2CTCTokenizer", None)),
319
+ ("ibert", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
320
+ ("idefics", (None, "LlamaTokenizerFast" if is_tokenizers_available() else None)),
321
+ ("idefics2", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
322
+ ("idefics3", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
323
+ ("instructblip", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
324
+ ("instructblipvideo", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
325
+ ("internvl", ("Qwen2Tokenizer", "Qwen2TokenizerFast" if is_tokenizers_available() else None)),
326
+ (
327
+ "jamba",
328
+ (
329
+ "LlamaTokenizer" if is_sentencepiece_available() else None,
330
+ "LlamaTokenizerFast" if is_tokenizers_available() else None,
331
+ ),
332
+ ),
333
+ ("janus", (None, "LlamaTokenizerFast" if is_tokenizers_available() else None)),
334
+ (
335
+ "jetmoe",
336
+ (
337
+ "LlamaTokenizer" if is_sentencepiece_available() else None,
338
+ "LlamaTokenizerFast" if is_tokenizers_available() else None,
339
+ ),
340
+ ),
341
+ ("jukebox", ("JukeboxTokenizer", None)),
342
+ (
343
+ "kosmos-2",
344
+ (
345
+ "XLMRobertaTokenizer" if is_sentencepiece_available() else None,
346
+ "XLMRobertaTokenizerFast" if is_tokenizers_available() else None,
347
+ ),
348
+ ),
349
+ ("kosmos-2.5", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
350
+ ("layoutlm", ("LayoutLMTokenizer", "LayoutLMTokenizerFast" if is_tokenizers_available() else None)),
351
+ ("layoutlmv2", ("LayoutLMv2Tokenizer", "LayoutLMv2TokenizerFast" if is_tokenizers_available() else None)),
352
+ ("layoutlmv3", ("LayoutLMv3Tokenizer", "LayoutLMv3TokenizerFast" if is_tokenizers_available() else None)),
353
+ ("layoutxlm", ("LayoutXLMTokenizer", "LayoutXLMTokenizerFast" if is_tokenizers_available() else None)),
354
+ ("led", ("LEDTokenizer", "LEDTokenizerFast" if is_tokenizers_available() else None)),
355
+ ("lilt", ("LayoutLMv3Tokenizer", "LayoutLMv3TokenizerFast" if is_tokenizers_available() else None)),
356
+ (
357
+ "llama",
358
+ (
359
+ "LlamaTokenizer" if is_sentencepiece_available() else None,
360
+ "LlamaTokenizerFast" if is_tokenizers_available() else None,
361
+ ),
362
+ ),
363
+ (
364
+ "llama4",
365
+ (
366
+ "LlamaTokenizer" if is_sentencepiece_available() else None,
367
+ "LlamaTokenizerFast" if is_tokenizers_available() else None,
368
+ ),
369
+ ),
370
+ (
371
+ "llama4_text",
372
+ (
373
+ "LlamaTokenizer" if is_sentencepiece_available() else None,
374
+ "LlamaTokenizerFast" if is_tokenizers_available() else None,
375
+ ),
376
+ ),
377
+ ("llava", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
378
+ ("llava_next", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
379
+ ("llava_next_video", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
380
+ ("llava_onevision", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
381
+ ("longformer", ("LongformerTokenizer", "LongformerTokenizerFast" if is_tokenizers_available() else None)),
382
+ (
383
+ "longt5",
384
+ (
385
+ "T5Tokenizer" if is_sentencepiece_available() else None,
386
+ "T5TokenizerFast" if is_tokenizers_available() else None,
387
+ ),
388
+ ),
389
+ ("luke", ("LukeTokenizer", None)),
390
+ ("lxmert", ("LxmertTokenizer", "LxmertTokenizerFast" if is_tokenizers_available() else None)),
391
+ ("m2m_100", ("M2M100Tokenizer" if is_sentencepiece_available() else None, None)),
392
+ ("mamba", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
393
+ ("mamba2", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
394
+ ("marian", ("MarianTokenizer" if is_sentencepiece_available() else None, None)),
395
+ (
396
+ "mbart",
397
+ (
398
+ "MBartTokenizer" if is_sentencepiece_available() else None,
399
+ "MBartTokenizerFast" if is_tokenizers_available() else None,
400
+ ),
401
+ ),
402
+ (
403
+ "mbart50",
404
+ (
405
+ "MBart50Tokenizer" if is_sentencepiece_available() else None,
406
+ "MBart50TokenizerFast" if is_tokenizers_available() else None,
407
+ ),
408
+ ),
409
+ ("mega", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
410
+ ("megatron-bert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
411
+ (
412
+ "metaclip_2",
413
+ (
414
+ "XLMRobertaTokenizer",
415
+ "XLMRobertaTokenizerFast" if is_tokenizers_available() else None,
416
+ ),
417
+ ),
418
+ ("mgp-str", ("MgpstrTokenizer", None)),
419
+ (
420
+ "minimax",
421
+ (
422
+ "GPT2Tokenizer" if is_sentencepiece_available() else None,
423
+ "GPT2TokenizerFast" if is_tokenizers_available() else None,
424
+ ),
425
+ ),
426
+ (
427
+ "ministral",
428
+ (
429
+ "MistralCommonTokenizer"
430
+ if is_mistral_common_available()
431
+ else ("LlamaTokenizer" if is_sentencepiece_available() else None),
432
+ "LlamaTokenizerFast" if is_tokenizers_available() and not is_mistral_common_available() else None,
433
+ ),
434
+ ),
435
+ (
436
+ "mistral",
437
+ (
438
+ "MistralCommonTokenizer"
439
+ if is_mistral_common_available()
440
+ else ("LlamaTokenizer" if is_sentencepiece_available() else None),
441
+ "LlamaTokenizerFast" if is_tokenizers_available() and not is_mistral_common_available() else None,
442
+ ),
443
+ ),
444
+ (
445
+ "mistral3",
446
+ (
447
+ "MistralCommonTokenizer"
448
+ if is_mistral_common_available()
449
+ else ("LlamaTokenizer" if is_sentencepiece_available() else None),
450
+ "LlamaTokenizerFast" if is_tokenizers_available() and not is_mistral_common_available() else None,
451
+ ),
452
+ ),
453
+ (
454
+ "mixtral",
455
+ (
456
+ "MistralCommonTokenizer"
457
+ if is_mistral_common_available()
458
+ else ("LlamaTokenizer" if is_sentencepiece_available() else None),
459
+ "LlamaTokenizerFast" if is_tokenizers_available() and not is_mistral_common_available() else None,
460
+ ),
461
+ ),
462
+ ("mllama", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
463
+ ("mluke", ("MLukeTokenizer" if is_sentencepiece_available() else None, None)),
464
+ ("mm-grounding-dino", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
465
+ ("mobilebert", ("MobileBertTokenizer", "MobileBertTokenizerFast" if is_tokenizers_available() else None)),
466
+ ("modernbert", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
467
+ ("moonshine", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
468
+ ("moshi", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
469
+ ("mpnet", ("MPNetTokenizer", "MPNetTokenizerFast" if is_tokenizers_available() else None)),
470
+ ("mpt", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
471
+ ("mra", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
472
+ (
473
+ "mt5",
474
+ (
475
+ "MT5Tokenizer" if is_sentencepiece_available() else None,
476
+ "MT5TokenizerFast" if is_tokenizers_available() else None,
477
+ ),
478
+ ),
479
+ ("musicgen", ("T5Tokenizer", "T5TokenizerFast" if is_tokenizers_available() else None)),
480
+ ("musicgen_melody", ("T5Tokenizer", "T5TokenizerFast" if is_tokenizers_available() else None)),
481
+ ("mvp", ("MvpTokenizer", "MvpTokenizerFast" if is_tokenizers_available() else None)),
482
+ ("myt5", ("MyT5Tokenizer", None)),
483
+ ("nemotron", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
484
+ ("nezha", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
485
+ (
486
+ "nllb",
487
+ (
488
+ "NllbTokenizer" if is_sentencepiece_available() else None,
489
+ "NllbTokenizerFast" if is_tokenizers_available() else None,
490
+ ),
491
+ ),
492
+ (
493
+ "nllb-moe",
494
+ (
495
+ "NllbTokenizer" if is_sentencepiece_available() else None,
496
+ "NllbTokenizerFast" if is_tokenizers_available() else None,
497
+ ),
498
+ ),
499
+ (
500
+ "nystromformer",
501
+ (
502
+ "AlbertTokenizer" if is_sentencepiece_available() else None,
503
+ "AlbertTokenizerFast" if is_tokenizers_available() else None,
504
+ ),
505
+ ),
506
+ ("olmo", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
507
+ ("olmo2", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
508
+ ("olmo3", (None, "GPT2TokenizerFast" if is_tokenizers_available() else None)),
509
+ ("olmoe", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
510
+ (
511
+ "omdet-turbo",
512
+ ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None),
513
+ ),
514
+ ("oneformer", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)),
515
+ (
516
+ "openai-gpt",
517
+ ("OpenAIGPTTokenizer", "OpenAIGPTTokenizerFast" if is_tokenizers_available() else None),
518
+ ),
519
+ ("opt", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
520
+ ("owlv2", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)),
521
+ ("owlvit", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)),
522
+ ("paligemma", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
523
+ ("parakeet", ("ParakeetCTCTokenizer", None)),
524
+ (
525
+ "pegasus",
526
+ (
527
+ "PegasusTokenizer" if is_sentencepiece_available() else None,
528
+ "PegasusTokenizerFast" if is_tokenizers_available() else None,
529
+ ),
530
+ ),
531
+ (
532
+ "pegasus_x",
533
+ (
534
+ "PegasusTokenizer" if is_sentencepiece_available() else None,
535
+ "PegasusTokenizerFast" if is_tokenizers_available() else None,
536
+ ),
537
+ ),
538
+ (
539
+ "perceiver",
540
+ (
541
+ "PerceiverTokenizer",
542
+ None,
543
+ ),
544
+ ),
545
+ (
546
+ "persimmon",
547
+ (
548
+ "LlamaTokenizer" if is_sentencepiece_available() else None,
549
+ "LlamaTokenizerFast" if is_tokenizers_available() else None,
550
+ ),
551
+ ),
552
+ ("phi", ("CodeGenTokenizer", "CodeGenTokenizerFast" if is_tokenizers_available() else None)),
553
+ ("phi3", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
554
+ ("phimoe", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
555
+ ("phobert", ("PhobertTokenizer", None)),
556
+ ("pix2struct", ("T5Tokenizer", "T5TokenizerFast" if is_tokenizers_available() else None)),
557
+ (
558
+ "pixtral",
559
+ (
560
+ None,
561
+ "MistralCommonTokenizer"
562
+ if is_mistral_common_available()
563
+ else ("PreTrainedTokenizerFast" if is_tokenizers_available() else None),
564
+ ),
565
+ ),
566
+ ("plbart", ("PLBartTokenizer" if is_sentencepiece_available() else None, None)),
567
+ ("prophetnet", ("ProphetNetTokenizer", None)),
568
+ ("qdqbert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
569
+ (
570
+ "qwen2",
571
+ (
572
+ "Qwen2Tokenizer",
573
+ "Qwen2TokenizerFast" if is_tokenizers_available() else None,
574
+ ),
575
+ ),
576
+ ("qwen2_5_omni", ("Qwen2Tokenizer", "Qwen2TokenizerFast" if is_tokenizers_available() else None)),
577
+ ("qwen2_5_vl", ("Qwen2Tokenizer", "Qwen2TokenizerFast" if is_tokenizers_available() else None)),
578
+ ("qwen2_audio", ("Qwen2Tokenizer", "Qwen2TokenizerFast" if is_tokenizers_available() else None)),
579
+ (
580
+ "qwen2_moe",
581
+ (
582
+ "Qwen2Tokenizer",
583
+ "Qwen2TokenizerFast" if is_tokenizers_available() else None,
584
+ ),
585
+ ),
586
+ ("qwen2_vl", ("Qwen2Tokenizer", "Qwen2TokenizerFast" if is_tokenizers_available() else None)),
587
+ (
588
+ "qwen3",
589
+ (
590
+ "Qwen2Tokenizer",
591
+ "Qwen2TokenizerFast" if is_tokenizers_available() else None,
592
+ ),
593
+ ),
594
+ (
595
+ "qwen3_moe",
596
+ (
597
+ "Qwen2Tokenizer",
598
+ "Qwen2TokenizerFast" if is_tokenizers_available() else None,
599
+ ),
600
+ ),
601
+ (
602
+ "qwen3_next",
603
+ (
604
+ "Qwen2Tokenizer",
605
+ "Qwen2TokenizerFast" if is_tokenizers_available() else None,
606
+ ),
607
+ ),
608
+ ("qwen3_omni_moe", ("Qwen2Tokenizer", "Qwen2TokenizerFast" if is_tokenizers_available() else None)),
609
+ ("qwen3_vl", ("Qwen2Tokenizer", "Qwen2TokenizerFast" if is_tokenizers_available() else None)),
610
+ ("qwen3_vl_moe", ("Qwen2Tokenizer", "Qwen2TokenizerFast" if is_tokenizers_available() else None)),
611
+ ("rag", ("RagTokenizer", None)),
612
+ ("realm", ("RealmTokenizer", "RealmTokenizerFast" if is_tokenizers_available() else None)),
613
+ (
614
+ "recurrent_gemma",
615
+ (
616
+ "GemmaTokenizer" if is_sentencepiece_available() else None,
617
+ "GemmaTokenizerFast" if is_tokenizers_available() else None,
618
+ ),
619
+ ),
620
+ (
621
+ "reformer",
622
+ (
623
+ "ReformerTokenizer" if is_sentencepiece_available() else None,
624
+ "ReformerTokenizerFast" if is_tokenizers_available() else None,
625
+ ),
626
+ ),
627
+ (
628
+ "rembert",
629
+ (
630
+ "RemBertTokenizer" if is_sentencepiece_available() else None,
631
+ "RemBertTokenizerFast" if is_tokenizers_available() else None,
632
+ ),
633
+ ),
634
+ ("retribert", ("RetriBertTokenizer", "RetriBertTokenizerFast" if is_tokenizers_available() else None)),
635
+ ("roberta", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
636
+ (
637
+ "roberta-prelayernorm",
638
+ ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None),
639
+ ),
640
+ ("roc_bert", ("RoCBertTokenizer", None)),
641
+ ("roformer", ("RoFormerTokenizer", "RoFormerTokenizerFast" if is_tokenizers_available() else None)),
642
+ ("rwkv", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
643
+ (
644
+ "seamless_m4t",
645
+ (
646
+ "SeamlessM4TTokenizer" if is_sentencepiece_available() else None,
647
+ "SeamlessM4TTokenizerFast" if is_tokenizers_available() else None,
648
+ ),
649
+ ),
650
+ (
651
+ "seamless_m4t_v2",
652
+ (
653
+ "SeamlessM4TTokenizer" if is_sentencepiece_available() else None,
654
+ "SeamlessM4TTokenizerFast" if is_tokenizers_available() else None,
655
+ ),
656
+ ),
657
+ (
658
+ "shieldgemma2",
659
+ (
660
+ "GemmaTokenizer" if is_sentencepiece_available() else None,
661
+ "GemmaTokenizerFast" if is_tokenizers_available() else None,
662
+ ),
663
+ ),
664
+ ("siglip", ("SiglipTokenizer" if is_sentencepiece_available() else None, None)),
665
+ (
666
+ "siglip2",
667
+ (
668
+ "GemmaTokenizer" if is_sentencepiece_available() else None,
669
+ "GemmaTokenizerFast" if is_tokenizers_available() else None,
670
+ ),
671
+ ),
672
+ ("smollm3", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
673
+ ("speech_to_text", ("Speech2TextTokenizer" if is_sentencepiece_available() else None, None)),
674
+ ("speech_to_text_2", ("Speech2Text2Tokenizer", None)),
675
+ ("speecht5", ("SpeechT5Tokenizer" if is_sentencepiece_available() else None, None)),
676
+ ("splinter", ("SplinterTokenizer", "SplinterTokenizerFast")),
677
+ (
678
+ "squeezebert",
679
+ ("SqueezeBertTokenizer", "SqueezeBertTokenizerFast" if is_tokenizers_available() else None),
680
+ ),
681
+ ("stablelm", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
682
+ ("starcoder2", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
683
+ (
684
+ "switch_transformers",
685
+ (
686
+ "T5Tokenizer" if is_sentencepiece_available() else None,
687
+ "T5TokenizerFast" if is_tokenizers_available() else None,
688
+ ),
689
+ ),
690
+ (
691
+ "t5",
692
+ (
693
+ "T5Tokenizer" if is_sentencepiece_available() else None,
694
+ "T5TokenizerFast" if is_tokenizers_available() else None,
695
+ ),
696
+ ),
697
+ (
698
+ "t5gemma",
699
+ (
700
+ "GemmaTokenizer" if is_sentencepiece_available() else None,
701
+ "GemmaTokenizerFast" if is_tokenizers_available() else None,
702
+ ),
703
+ ),
704
+ ("tapas", ("TapasTokenizer", None)),
705
+ ("tapex", ("TapexTokenizer", None)),
706
+ ("transfo-xl", ("TransfoXLTokenizer", None)),
707
+ ("tvp", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
708
+ (
709
+ "udop",
710
+ (
711
+ "UdopTokenizer" if is_sentencepiece_available() else None,
712
+ "UdopTokenizerFast" if is_tokenizers_available() else None,
713
+ ),
714
+ ),
715
+ (
716
+ "umt5",
717
+ (
718
+ "T5Tokenizer" if is_sentencepiece_available() else None,
719
+ "T5TokenizerFast" if is_tokenizers_available() else None,
720
+ ),
721
+ ),
722
+ ("video_llava", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
723
+ ("vilt", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
724
+ ("vipllava", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
725
+ ("visual_bert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
726
+ ("vits", ("VitsTokenizer", None)),
727
+ (
728
+ "voxtral",
729
+ (
730
+ "MistralCommonTokenizer" if is_mistral_common_available() else None,
731
+ "PreTrainedTokenizerFast" if is_tokenizers_available() and not is_mistral_common_available() else None,
732
+ ),
733
+ ),
734
+ ("wav2vec2", ("Wav2Vec2CTCTokenizer", None)),
735
+ ("wav2vec2-bert", ("Wav2Vec2CTCTokenizer", None)),
736
+ ("wav2vec2-conformer", ("Wav2Vec2CTCTokenizer", None)),
737
+ ("wav2vec2_phoneme", ("Wav2Vec2PhonemeCTCTokenizer", None)),
738
+ ("whisper", ("WhisperTokenizer", "WhisperTokenizerFast" if is_tokenizers_available() else None)),
739
+ ("xclip", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)),
740
+ (
741
+ "xglm",
742
+ (
743
+ "XGLMTokenizer" if is_sentencepiece_available() else None,
744
+ "XGLMTokenizerFast" if is_tokenizers_available() else None,
745
+ ),
746
+ ),
747
+ ("xlm", ("XLMTokenizer", None)),
748
+ ("xlm-prophetnet", ("XLMProphetNetTokenizer" if is_sentencepiece_available() else None, None)),
749
+ (
750
+ "xlm-roberta",
751
+ (
752
+ "XLMRobertaTokenizer" if is_sentencepiece_available() else None,
753
+ "XLMRobertaTokenizerFast" if is_tokenizers_available() else None,
754
+ ),
755
+ ),
756
+ (
757
+ "xlm-roberta-xl",
758
+ (
759
+ "XLMRobertaTokenizer" if is_sentencepiece_available() else None,
760
+ "XLMRobertaTokenizerFast" if is_tokenizers_available() else None,
761
+ ),
762
+ ),
763
+ (
764
+ "xlnet",
765
+ (
766
+ "XLNetTokenizer" if is_sentencepiece_available() else None,
767
+ "XLNetTokenizerFast" if is_tokenizers_available() else None,
768
+ ),
769
+ ),
770
+ ("xlstm", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
771
+ (
772
+ "xmod",
773
+ (
774
+ "XLMRobertaTokenizer" if is_sentencepiece_available() else None,
775
+ "XLMRobertaTokenizerFast" if is_tokenizers_available() else None,
776
+ ),
777
+ ),
778
+ (
779
+ "yoso",
780
+ (
781
+ "AlbertTokenizer" if is_sentencepiece_available() else None,
782
+ "AlbertTokenizerFast" if is_tokenizers_available() else None,
783
+ ),
784
+ ),
785
+ (
786
+ "zamba",
787
+ (
788
+ "LlamaTokenizer" if is_sentencepiece_available() else None,
789
+ "LlamaTokenizerFast" if is_tokenizers_available() else None,
790
+ ),
791
+ ),
792
+ (
793
+ "zamba2",
794
+ (
795
+ "LlamaTokenizer" if is_sentencepiece_available() else None,
796
+ "LlamaTokenizerFast" if is_tokenizers_available() else None,
797
+ ),
798
+ ),
799
+ ]
800
+ )
801
+
802
+ TOKENIZER_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TOKENIZER_MAPPING_NAMES)
803
+
804
+ CONFIG_TO_TYPE = {v: k for k, v in CONFIG_MAPPING_NAMES.items()}
805
+
806
+
807
+ def tokenizer_class_from_name(class_name: str) -> Union[type[Any], None]:
808
+ if class_name == "PreTrainedTokenizerFast":
809
+ return PreTrainedTokenizerFast
810
+
811
+ for module_name, tokenizers in TOKENIZER_MAPPING_NAMES.items():
812
+ if class_name in tokenizers:
813
+ module_name = model_type_to_module_name(module_name)
814
+ if module_name in ["mistral", "mixtral", "ministral"] and class_name == "MistralCommonTokenizer":
815
+ module = importlib.import_module(".tokenization_mistral_common", "transformers")
816
+ else:
817
+ module = importlib.import_module(f".{module_name}", "transformers.models")
818
+ try:
819
+ return getattr(module, class_name)
820
+ except AttributeError:
821
+ continue
822
+
823
+ for tokenizers in TOKENIZER_MAPPING._extra_content.values():
824
+ for tokenizer in tokenizers:
825
+ if getattr(tokenizer, "__name__", None) == class_name:
826
+ return tokenizer
827
+
828
+ # We did not find the class, but maybe it's because a dep is missing. In that case, the class will be in the main
829
+ # init and we return the proper dummy to get an appropriate error message.
830
+ main_module = importlib.import_module("transformers")
831
+ if hasattr(main_module, class_name):
832
+ return getattr(main_module, class_name)
833
+
834
+ return None
835
+
836
+
837
+ def get_tokenizer_config(
838
+ pretrained_model_name_or_path: Union[str, os.PathLike[str]],
839
+ cache_dir: Optional[Union[str, os.PathLike[str]]] = None,
840
+ force_download: bool = False,
841
+ resume_download: Optional[bool] = None,
842
+ proxies: Optional[dict[str, str]] = None,
843
+ token: Optional[Union[bool, str]] = None,
844
+ revision: Optional[str] = None,
845
+ local_files_only: bool = False,
846
+ subfolder: str = "",
847
+ **kwargs,
848
+ ) -> dict[str, Any]:
849
+ """
850
+ Loads the tokenizer configuration from a pretrained model tokenizer configuration.
851
+
852
+ Args:
853
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
854
+ This can be either:
855
+
856
+ - a string, the *model id* of a pretrained model configuration hosted inside a model repo on
857
+ huggingface.co.
858
+ - a path to a *directory* containing a configuration file saved using the
859
+ [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.
860
+
861
+ cache_dir (`str` or `os.PathLike`, *optional*):
862
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
863
+ cache should not be used.
864
+ force_download (`bool`, *optional*, defaults to `False`):
865
+ Whether or not to force (re-)downloading the configuration files and override the cached versions if they
866
+ exist.
867
+ resume_download:
868
+ Deprecated and ignored. All downloads are now resumed by default when possible.
869
+ Will be removed in v5 of Transformers.
870
+ proxies (`dict[str, str]`, *optional*):
871
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
872
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
873
+ token (`str` or *bool*, *optional*):
874
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
875
+ when running `hf auth login` (stored in `~/.huggingface`).
876
+ revision (`str`, *optional*, defaults to `"main"`):
877
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
878
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
879
+ identifier allowed by git.
880
+ local_files_only (`bool`, *optional*, defaults to `False`):
881
+ If `True`, will only try to load the tokenizer configuration from local files.
882
+ subfolder (`str`, *optional*, defaults to `""`):
883
+ In case the tokenizer config is located inside a subfolder of the model repo on huggingface.co, you can
884
+ specify the folder name here.
885
+
886
+ <Tip>
887
+
888
+ Passing `token=True` is required when you want to use a private model.
889
+
890
+ </Tip>
891
+
892
+ Returns:
893
+ `dict`: The configuration of the tokenizer.
894
+
895
+ Examples:
896
+
897
+ ```python
898
+ # Download configuration from huggingface.co and cache.
899
+ tokenizer_config = get_tokenizer_config("google-bert/bert-base-uncased")
900
+ # This model does not have a tokenizer config so the result will be an empty dict.
901
+ tokenizer_config = get_tokenizer_config("FacebookAI/xlm-roberta-base")
902
+
903
+ # Save a pretrained tokenizer locally and you can reload its config
904
+ from transformers import AutoTokenizer
905
+
906
+ tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-cased")
907
+ tokenizer.save_pretrained("tokenizer-test")
908
+ tokenizer_config = get_tokenizer_config("tokenizer-test")
909
+ ```"""
910
+ use_auth_token = kwargs.pop("use_auth_token", None)
911
+ if use_auth_token is not None:
912
+ warnings.warn(
913
+ "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
914
+ FutureWarning,
915
+ )
916
+ if token is not None:
917
+ raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
918
+ token = use_auth_token
919
+
920
+ commit_hash = kwargs.get("_commit_hash")
921
+ resolved_config_file = cached_file(
922
+ pretrained_model_name_or_path,
923
+ TOKENIZER_CONFIG_FILE,
924
+ cache_dir=cache_dir,
925
+ force_download=force_download,
926
+ resume_download=resume_download,
927
+ proxies=proxies,
928
+ token=token,
929
+ revision=revision,
930
+ local_files_only=local_files_only,
931
+ subfolder=subfolder,
932
+ _raise_exceptions_for_gated_repo=False,
933
+ _raise_exceptions_for_missing_entries=False,
934
+ _raise_exceptions_for_connection_errors=False,
935
+ _commit_hash=commit_hash,
936
+ )
937
+ if resolved_config_file is None:
938
+ logger.info("Could not locate the tokenizer configuration file, will try to use the model config instead.")
939
+ return {}
940
+ commit_hash = extract_commit_hash(resolved_config_file, commit_hash)
941
+
942
+ with open(resolved_config_file, encoding="utf-8") as reader:
943
+ result = json.load(reader)
944
+ result["_commit_hash"] = commit_hash
945
+ return result
946
+
947
+
948
+ class AutoTokenizer:
949
+ r"""
950
+ This is a generic tokenizer class that will be instantiated as one of the tokenizer classes of the library when
951
+ created with the [`AutoTokenizer.from_pretrained`] class method.
952
+
953
+ This class cannot be instantiated directly using `__init__()` (throws an error).
954
+ """
955
+
956
+ def __init__(self):
957
+ raise OSError(
958
+ "AutoTokenizer is designed to be instantiated "
959
+ "using the `AutoTokenizer.from_pretrained(pretrained_model_name_or_path)` method."
960
+ )
961
+
962
+ @classmethod
963
+ @replace_list_option_in_docstrings(TOKENIZER_MAPPING_NAMES)
964
+ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
965
+ r"""
966
+ Instantiate one of the tokenizer classes of the library from a pretrained model vocabulary.
967
+
968
+ The tokenizer class to instantiate is selected based on the `model_type` property of the config object (either
969
+ passed as an argument or loaded from `pretrained_model_name_or_path` if possible), or when it's missing, by
970
+ falling back to using pattern matching on `pretrained_model_name_or_path`:
971
+
972
+ List options
973
+
974
+ Params:
975
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
976
+ Can be either:
977
+
978
+ - A string, the *model id* of a predefined tokenizer hosted inside a model repo on huggingface.co.
979
+ - A path to a *directory* containing vocabulary files required by the tokenizer, for instance saved
980
+ using the [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.
981
+ - A path or url to a single saved vocabulary file if and only if the tokenizer only requires a
982
+ single vocabulary file (like Bert or XLNet), e.g.: `./my_model_directory/vocab.txt`. (Not
983
+ applicable to all derived classes)
984
+ inputs (additional positional arguments, *optional*):
985
+ Will be passed along to the Tokenizer `__init__()` method.
986
+ config ([`PretrainedConfig`], *optional*):
987
+ The configuration object used to determine the tokenizer class to instantiate.
988
+ cache_dir (`str` or `os.PathLike`, *optional*):
989
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the
990
+ standard cache should not be used.
991
+ force_download (`bool`, *optional*, defaults to `False`):
992
+ Whether or not to force the (re-)download of the model weights and configuration files and override the
993
+ cached versions if they exist.
994
+ resume_download:
995
+ Deprecated and ignored. All downloads are now resumed by default when possible.
996
+ Will be removed in v5 of Transformers.
997
+ proxies (`dict[str, str]`, *optional*):
998
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
999
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
1000
+ revision (`str`, *optional*, defaults to `"main"`):
1001
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
1002
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
1003
+ identifier allowed by git.
1004
+ subfolder (`str`, *optional*):
1005
+ In case the relevant files are located inside a subfolder of the model repo on huggingface.co (e.g. for
1006
+ facebook/rag-token-base), specify it here.
1007
+ use_fast (`bool`, *optional*, defaults to `True`):
1008
+ Use a [fast Rust-based tokenizer](https://huggingface.co/docs/tokenizers/index) if it is supported for
1009
+ a given model. If a fast tokenizer is not available for a given model, a normal Python-based tokenizer
1010
+ is returned instead.
1011
+ tokenizer_type (`str`, *optional*):
1012
+ Tokenizer type to be loaded.
1013
+ trust_remote_code (`bool`, *optional*, defaults to `False`):
1014
+ Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
1015
+ should only be set to `True` for repositories you trust and in which you have read the code, as it will
1016
+ execute code present on the Hub on your local machine.
1017
+ kwargs (additional keyword arguments, *optional*):
1018
+ Will be passed to the Tokenizer `__init__()` method. Can be used to set special tokens like
1019
+ `bos_token`, `eos_token`, `unk_token`, `sep_token`, `pad_token`, `cls_token`, `mask_token`,
1020
+ `additional_special_tokens`. See parameters in the `__init__()` for more details.
1021
+
1022
+ Examples:
1023
+
1024
+ ```python
1025
+ >>> from transformers import AutoTokenizer
1026
+
1027
+ >>> # Download vocabulary from huggingface.co and cache.
1028
+ >>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
1029
+
1030
+ >>> # Download vocabulary from huggingface.co (user-uploaded) and cache.
1031
+ >>> tokenizer = AutoTokenizer.from_pretrained("dbmdz/bert-base-german-cased")
1032
+
1033
+ >>> # If vocabulary files are in a directory (e.g. tokenizer was saved using *save_pretrained('./test/saved_model/')*)
1034
+ >>> # tokenizer = AutoTokenizer.from_pretrained("./test/bert_saved_model/")
1035
+
1036
+ >>> # Download vocabulary from huggingface.co and define model-specific arguments
1037
+ >>> tokenizer = AutoTokenizer.from_pretrained("FacebookAI/roberta-base", add_prefix_space=True)
1038
+ ```"""
1039
+ use_auth_token = kwargs.pop("use_auth_token", None)
1040
+ if use_auth_token is not None:
1041
+ warnings.warn(
1042
+ "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
1043
+ FutureWarning,
1044
+ )
1045
+ if kwargs.get("token") is not None:
1046
+ raise ValueError(
1047
+ "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
1048
+ )
1049
+ kwargs["token"] = use_auth_token
1050
+
1051
+ config = kwargs.pop("config", None)
1052
+ kwargs["_from_auto"] = True
1053
+
1054
+ use_fast = kwargs.pop("use_fast", True)
1055
+ tokenizer_type = kwargs.pop("tokenizer_type", None)
1056
+ trust_remote_code = kwargs.pop("trust_remote_code", None)
1057
+ gguf_file = kwargs.get("gguf_file")
1058
+
1059
+ # First, let's see whether the tokenizer_type is passed so that we can leverage it
1060
+ if tokenizer_type is not None:
1061
+ tokenizer_class = None
1062
+ tokenizer_class_tuple = TOKENIZER_MAPPING_NAMES.get(tokenizer_type, None)
1063
+
1064
+ if tokenizer_class_tuple is None:
1065
+ raise ValueError(
1066
+ f"Passed `tokenizer_type` {tokenizer_type} does not exist. `tokenizer_type` should be one of "
1067
+ f"{', '.join(c for c in TOKENIZER_MAPPING_NAMES)}."
1068
+ )
1069
+
1070
+ tokenizer_class_name, tokenizer_fast_class_name = tokenizer_class_tuple
1071
+
1072
+ if use_fast:
1073
+ if tokenizer_fast_class_name is not None:
1074
+ tokenizer_class = tokenizer_class_from_name(tokenizer_fast_class_name)
1075
+ else:
1076
+ logger.warning(
1077
+ "`use_fast` is set to `True` but the tokenizer class does not have a fast version. "
1078
+ "Falling back to the slow version."
1079
+ )
1080
+ if tokenizer_class is None:
1081
+ tokenizer_class = tokenizer_class_from_name(tokenizer_class_name)
1082
+
1083
+ if tokenizer_class is None:
1084
+ raise ValueError(f"Tokenizer class {tokenizer_class_name} is not currently imported.")
1085
+
1086
+ return tokenizer_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
1087
+
1088
+ # Next, let's try to use the tokenizer_config file to get the tokenizer class.
1089
+ tokenizer_config = get_tokenizer_config(pretrained_model_name_or_path, **kwargs)
1090
+ if "_commit_hash" in tokenizer_config:
1091
+ kwargs["_commit_hash"] = tokenizer_config["_commit_hash"]
1092
+ config_tokenizer_class = tokenizer_config.get("tokenizer_class")
1093
+ tokenizer_auto_map = None
1094
+ if "auto_map" in tokenizer_config:
1095
+ if isinstance(tokenizer_config["auto_map"], (tuple, list)):
1096
+ # Legacy format for dynamic tokenizers
1097
+ tokenizer_auto_map = tokenizer_config["auto_map"]
1098
+ else:
1099
+ tokenizer_auto_map = tokenizer_config["auto_map"].get("AutoTokenizer", None)
1100
+
1101
+ # If that did not work, let's try to use the config.
1102
+ if config_tokenizer_class is None:
1103
+ if not isinstance(config, PretrainedConfig):
1104
+ if gguf_file:
1105
+ gguf_path = cached_file(pretrained_model_name_or_path, gguf_file, **kwargs)
1106
+ config_dict = load_gguf_checkpoint(gguf_path, return_tensors=False)["config"]
1107
+ config = AutoConfig.for_model(**config_dict)
1108
+ else:
1109
+ config = AutoConfig.from_pretrained(
1110
+ pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
1111
+ )
1112
+ config_tokenizer_class = config.tokenizer_class
1113
+ if hasattr(config, "auto_map") and "AutoTokenizer" in config.auto_map:
1114
+ tokenizer_auto_map = config.auto_map["AutoTokenizer"]
1115
+
1116
+ has_remote_code = tokenizer_auto_map is not None
1117
+ has_local_code = type(config) in TOKENIZER_MAPPING or (
1118
+ config_tokenizer_class is not None
1119
+ and (
1120
+ tokenizer_class_from_name(config_tokenizer_class) is not None
1121
+ or tokenizer_class_from_name(config_tokenizer_class + "Fast") is not None
1122
+ )
1123
+ )
1124
+ if has_remote_code:
1125
+ if use_fast and tokenizer_auto_map[1] is not None:
1126
+ class_ref = tokenizer_auto_map[1]
1127
+ else:
1128
+ class_ref = tokenizer_auto_map[0]
1129
+ if "--" in class_ref:
1130
+ upstream_repo = class_ref.split("--")[0]
1131
+ else:
1132
+ upstream_repo = None
1133
+ trust_remote_code = resolve_trust_remote_code(
1134
+ trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code, upstream_repo
1135
+ )
1136
+
1137
+ if has_remote_code and trust_remote_code:
1138
+ tokenizer_class = get_class_from_dynamic_module(class_ref, pretrained_model_name_or_path, **kwargs)
1139
+ _ = kwargs.pop("code_revision", None)
1140
+ tokenizer_class.register_for_auto_class()
1141
+ return tokenizer_class.from_pretrained(
1142
+ pretrained_model_name_or_path, *inputs, trust_remote_code=trust_remote_code, **kwargs
1143
+ )
1144
+ elif config_tokenizer_class is not None:
1145
+ tokenizer_class = None
1146
+ if use_fast and not config_tokenizer_class.endswith("Fast"):
1147
+ tokenizer_class_candidate = f"{config_tokenizer_class}Fast"
1148
+ tokenizer_class = tokenizer_class_from_name(tokenizer_class_candidate)
1149
+ if tokenizer_class is None:
1150
+ tokenizer_class_candidate = config_tokenizer_class
1151
+ tokenizer_class = tokenizer_class_from_name(tokenizer_class_candidate)
1152
+ if tokenizer_class is None:
1153
+ raise ValueError(
1154
+ f"Tokenizer class {tokenizer_class_candidate} does not exist or is not currently imported."
1155
+ )
1156
+ return tokenizer_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
1157
+
1158
+ # Otherwise we have to be creative.
1159
+ # if model is an encoder decoder, the encoder tokenizer class is used by default
1160
+ if isinstance(config, EncoderDecoderConfig):
1161
+ if type(config.decoder) is not type(config.encoder):
1162
+ logger.warning(
1163
+ f"The encoder model config class: {config.encoder.__class__} is different from the decoder model "
1164
+ f"config class: {config.decoder.__class__}. It is not recommended to use the "
1165
+ "`AutoTokenizer.from_pretrained()` method in this case. Please use the encoder and decoder "
1166
+ "specific tokenizer classes."
1167
+ )
1168
+ config = config.encoder
1169
+
1170
+ model_type = config_class_to_model_type(type(config).__name__)
1171
+ if model_type is not None:
1172
+ tokenizer_class_py, tokenizer_class_fast = TOKENIZER_MAPPING[type(config)]
1173
+
1174
+ if tokenizer_class_fast and (use_fast or tokenizer_class_py is None):
1175
+ return tokenizer_class_fast.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
1176
+ else:
1177
+ if tokenizer_class_py is not None:
1178
+ return tokenizer_class_py.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
1179
+ else:
1180
+ raise ValueError(
1181
+ "This tokenizer cannot be instantiated. Please make sure you have `sentencepiece` installed "
1182
+ "in order to use this tokenizer."
1183
+ )
1184
+
1185
+ raise ValueError(
1186
+ f"Unrecognized configuration class {config.__class__} to build an AutoTokenizer.\n"
1187
+ f"Model type should be one of {', '.join(c.__name__ for c in TOKENIZER_MAPPING)}."
1188
+ )
1189
+
1190
+ @staticmethod
1191
+ def register(config_class, slow_tokenizer_class=None, fast_tokenizer_class=None, exist_ok=False):
1192
+ """
1193
+ Register a new tokenizer in this mapping.
1194
+
1195
+
1196
+ Args:
1197
+ config_class ([`PretrainedConfig`]):
1198
+ The configuration corresponding to the model to register.
1199
+ slow_tokenizer_class ([`PreTrainedTokenizer`], *optional*):
1200
+ The slow tokenizer to register.
1201
+ fast_tokenizer_class ([`PreTrainedTokenizerFast`], *optional*):
1202
+ The fast tokenizer to register.
1203
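+
+ Example (a minimal sketch of registering a custom config/tokenizer pair; `NewConfig` and
+ `NewTokenizer` are hypothetical user-defined classes, not part of the library):
+
+ ```python
+ from transformers import AutoConfig, AutoTokenizer, PretrainedConfig, PreTrainedTokenizer
+
+ # Hypothetical config and (slow) tokenizer for a custom model type.
+ class NewConfig(PretrainedConfig):
+     model_type = "new-model"
+
+ class NewTokenizer(PreTrainedTokenizer):
+     pass
+
+ # Register the config type first, then attach the tokenizer to it so that
+ # `AutoTokenizer.from_pretrained` can resolve it for "new-model" checkpoints.
+ AutoConfig.register("new-model", NewConfig)
+ AutoTokenizer.register(NewConfig, slow_tokenizer_class=NewTokenizer)
+ ```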
+ """
1204
+ if slow_tokenizer_class is None and fast_tokenizer_class is None:
1205
+ raise ValueError("You need to pass either a `slow_tokenizer_class` or a `fast_tokenizer_class`")
1206
+ if slow_tokenizer_class is not None and issubclass(slow_tokenizer_class, PreTrainedTokenizerFast):
1207
+ raise ValueError("You passed a fast tokenizer in the `slow_tokenizer_class`.")
1208
+ if fast_tokenizer_class is not None and issubclass(fast_tokenizer_class, PreTrainedTokenizer):
1209
+ raise ValueError("You passed a slow tokenizer in the `fast_tokenizer_class`.")
1210
+
1211
+ if (
1212
+ slow_tokenizer_class is not None
1213
+ and fast_tokenizer_class is not None
1214
+ and issubclass(fast_tokenizer_class, PreTrainedTokenizerFast)
1215
+ and fast_tokenizer_class.slow_tokenizer_class != slow_tokenizer_class
1216
+ ):
1217
+ raise ValueError(
1218
+ "The fast tokenizer class you are passing has a `slow_tokenizer_class` attribute that is not "
1219
+ "consistent with the slow tokenizer class you passed (fast tokenizer has "
1220
+ f"{fast_tokenizer_class.slow_tokenizer_class} and you passed {slow_tokenizer_class}). Fix one of those "
1221
+ "so they match!"
1222
+ )
1223
+
1224
+ # Avoid resetting an already-registered slow/fast tokenizer when only the other one is passed.
1225
+ if config_class in TOKENIZER_MAPPING._extra_content:
1226
+ existing_slow, existing_fast = TOKENIZER_MAPPING[config_class]
1227
+ if slow_tokenizer_class is None:
1228
+ slow_tokenizer_class = existing_slow
1229
+ if fast_tokenizer_class is None:
1230
+ fast_tokenizer_class = existing_fast
1231
+
1232
+ TOKENIZER_MAPPING.register(config_class, (slow_tokenizer_class, fast_tokenizer_class), exist_ok=exist_ok)
1233
+
1234
+
1235
+ __all__ = ["TOKENIZER_MAPPING", "AutoTokenizer"]
venv/lib/python3.13/site-packages/transformers/models/auto/video_processing_auto.py ADDED
@@ -0,0 +1,393 @@
1
+ # coding=utf-8
2
+ # Copyright 2025 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """AutoVideoProcessor class."""
16
+
17
+ import importlib
18
+ import json
19
+ import os
20
+ import warnings
21
+ from collections import OrderedDict
22
+ from typing import TYPE_CHECKING, Optional, Union
23
+
24
+ # Build the list of all video processors
25
+ from ...configuration_utils import PretrainedConfig
26
+ from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
27
+ from ...utils import CONFIG_NAME, VIDEO_PROCESSOR_NAME, cached_file, is_torchvision_available, logging
28
+ from ...utils.import_utils import requires
29
+ from ...video_processing_utils import BaseVideoProcessor
30
+ from .auto_factory import _LazyAutoMapping
31
+ from .configuration_auto import (
32
+ CONFIG_MAPPING_NAMES,
33
+ AutoConfig,
34
+ model_type_to_module_name,
35
+ replace_list_option_in_docstrings,
36
+ )
37
+
38
+
39
+ logger = logging.get_logger(__name__)
40
+
41
+
42
+ if TYPE_CHECKING:
43
+ # This significantly improves completion suggestion performance when
44
+ # the transformers package is used with Microsoft's Pylance language server.
45
+ VIDEO_PROCESSOR_MAPPING_NAMES: OrderedDict[str, tuple[Optional[str], Optional[str]]] = OrderedDict()
46
+ else:
47
+ VIDEO_PROCESSOR_MAPPING_NAMES = OrderedDict(
48
+ [
49
+ ("glm4v", "Glm4vVideoProcessor"),
50
+ ("instructblip", "InstructBlipVideoVideoProcessor"),
51
+ ("instructblipvideo", "InstructBlipVideoVideoProcessor"),
52
+ ("internvl", "InternVLVideoProcessor"),
53
+ ("llava_next_video", "LlavaNextVideoVideoProcessor"),
54
+ ("llava_onevision", "LlavaOnevisionVideoProcessor"),
55
+ ("perception_lm", "PerceptionLMVideoProcessor"),
56
+ ("qwen2_5_omni", "Qwen2VLVideoProcessor"),
57
+ ("qwen2_5_vl", "Qwen2VLVideoProcessor"),
58
+ ("qwen2_vl", "Qwen2VLVideoProcessor"),
59
+ ("qwen3_omni_moe", "Qwen2VLVideoProcessor"),
60
+ ("qwen3_vl", "Qwen3VLVideoProcessor"),
61
+ ("qwen3_vl_moe", "Qwen3VLVideoProcessor"),
62
+ ("sam2_video", "Sam2VideoVideoProcessor"),
63
+ ("smolvlm", "SmolVLMVideoProcessor"),
64
+ ("video_llava", "VideoLlavaVideoProcessor"),
65
+ ("vjepa2", "VJEPA2VideoProcessor"),
66
+ ]
67
+ )
68
+
69
+ for model_type, video_processors in VIDEO_PROCESSOR_MAPPING_NAMES.items():
70
+ fast_video_processor_class = video_processors
71
+
72
+ # If torchvision is not available, set the video processor class to None
73
+ if not is_torchvision_available():
74
+ fast_video_processor_class = None
75
+
76
+ VIDEO_PROCESSOR_MAPPING_NAMES[model_type] = fast_video_processor_class
77
+
78
+ VIDEO_PROCESSOR_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, VIDEO_PROCESSOR_MAPPING_NAMES)
79
+
80
+
81
+ def video_processor_class_from_name(class_name: str):
82
+ for module_name, extractors in VIDEO_PROCESSOR_MAPPING_NAMES.items():
83
+ if class_name in extractors:
84
+ module_name = model_type_to_module_name(module_name)
85
+
86
+ module = importlib.import_module(f".{module_name}", "transformers.models")
87
+ try:
88
+ return getattr(module, class_name)
89
+ except AttributeError:
90
+ continue
91
+
92
+ for extractor in VIDEO_PROCESSOR_MAPPING._extra_content.values():
93
+ if getattr(extractor, "__name__", None) == class_name:
94
+ return extractor
95
+
96
+ # We did not find the class, but maybe it's because a dep is missing. In that case, the class will be in the main
97
+ # init and we return the proper dummy to get an appropriate error message.
98
+ main_module = importlib.import_module("transformers")
99
+ if hasattr(main_module, class_name):
100
+ return getattr(main_module, class_name)
101
+
102
+ return None
103
+
104
+
105
+ def get_video_processor_config(
106
+ pretrained_model_name_or_path: Union[str, os.PathLike],
107
+ cache_dir: Optional[Union[str, os.PathLike]] = None,
108
+ force_download: bool = False,
109
+ resume_download: Optional[bool] = None,
110
+ proxies: Optional[dict[str, str]] = None,
111
+ token: Optional[Union[bool, str]] = None,
112
+ revision: Optional[str] = None,
113
+ local_files_only: bool = False,
114
+ **kwargs,
115
+ ):
116
+ """
117
+ Loads the video processor configuration from a pretrained model video processor configuration.
118
+
119
+ Args:
120
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
121
+ This can be either:
122
+
123
+ - a string, the *model id* of a pretrained model configuration hosted inside a model repo on
124
+ huggingface.co.
125
+ - a path to a *directory* containing a configuration file saved using the
126
+ [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.
127
+
128
+ cache_dir (`str` or `os.PathLike`, *optional*):
129
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
130
+ cache should not be used.
131
+ force_download (`bool`, *optional*, defaults to `False`):
132
+ Whether or not to force (re-)downloading the configuration files and override the cached versions if they
133
+ exist.
134
+ resume_download:
135
+ Deprecated and ignored. All downloads are now resumed by default when possible.
136
+ Will be removed in v5 of Transformers.
137
+ proxies (`dict[str, str]`, *optional*):
138
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
139
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
140
+ token (`str` or *bool*, *optional*):
141
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
142
+ when running `hf auth login` (stored in `~/.huggingface`).
143
+ revision (`str`, *optional*, defaults to `"main"`):
144
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
145
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
146
+ identifier allowed by git.
147
+ local_files_only (`bool`, *optional*, defaults to `False`):
148
+ If `True`, will only try to load the video processor configuration from local files.
149
+
150
+ <Tip>
151
+
152
+ Passing `token=True` is required when you want to use a private model.
153
+
154
+ </Tip>
155
+
156
+ Returns:
157
+ `Dict`: The configuration of the video processor.
158
+
159
+ Examples:
160
+
161
+ ```python
162
+ # Download configuration from huggingface.co and cache.
163
+ video_processor_config = get_video_processor_config("llava-hf/llava-onevision-qwen2-0.5b-ov-hf")
164
+ # This model does not have a video processor config so the result will be an empty dict.
165
+ video_processor_config = get_video_processor_config("FacebookAI/xlm-roberta-base")
166
+
167
+ # Save a pretrained video processor locally and you can reload its config
168
+ from transformers import AutoVideoProcessor
169
+
170
+ video_processor = AutoVideoProcessor.from_pretrained("llava-hf/llava-onevision-qwen2-0.5b-ov-hf")
171
+ video_processor.save_pretrained("video-processor-test")
172
+ video_processor_config = get_video_processor_config("video-processor-test")
173
+ ```"""
174
+ use_auth_token = kwargs.pop("use_auth_token", None)
175
+ if use_auth_token is not None:
176
+ warnings.warn(
177
+ "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
178
+ FutureWarning,
179
+ )
180
+ if token is not None:
181
+ raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
182
+ token = use_auth_token
183
+
184
+ resolved_config_file = cached_file(
185
+ pretrained_model_name_or_path,
186
+ VIDEO_PROCESSOR_NAME,
187
+ cache_dir=cache_dir,
188
+ force_download=force_download,
189
+ resume_download=resume_download,
190
+ proxies=proxies,
191
+ token=token,
192
+ revision=revision,
193
+ local_files_only=local_files_only,
194
+ )
195
+ if resolved_config_file is None:
196
+ logger.info(
197
+ "Could not locate the video processor configuration file, will try to use the model config instead."
198
+ )
199
+ return {}
200
+
201
+ with open(resolved_config_file, encoding="utf-8") as reader:
202
+ return json.load(reader)
203
+
204
+
205
+ @requires(backends=("vision", "torchvision"))
206
+ class AutoVideoProcessor:
207
+ r"""
208
+ This is a generic video processor class that will be instantiated as one of the video processor classes of the
209
+ library when created with the [`AutoVideoProcessor.from_pretrained`] class method.
210
+
211
+ This class cannot be instantiated directly using `__init__()` (throws an error).
212
+ """
213
+
214
+ def __init__(self):
215
+ raise OSError(
216
+ "AutoVideoProcessor is designed to be instantiated "
217
+ "using the `AutoVideoProcessor.from_pretrained(pretrained_model_name_or_path)` method."
218
+ )
219
+
220
+ @classmethod
221
+ @replace_list_option_in_docstrings(VIDEO_PROCESSOR_MAPPING_NAMES)
222
+ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
223
+ r"""
224
+ Instantiate one of the video processor classes of the library from a pretrained model vocabulary.
225
+
226
+ The video processor class to instantiate is selected based on the `model_type` property of the config object
227
+ (either passed as an argument or loaded from `pretrained_model_name_or_path` if possible), or when it's
228
+ missing, by falling back to using pattern matching on `pretrained_model_name_or_path`:
229
+
230
+ List options
231
+
232
+ Params:
233
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
234
+ This can be either:
235
+
236
+ - a string, the *model id* of a pretrained video_processor hosted inside a model repo on
237
+ huggingface.co.
238
+ - a path to a *directory* containing a video processor file saved using the
239
+ [`~video_processing_utils.BaseVideoProcessor.save_pretrained`] method, e.g.,
240
+ `./my_model_directory/`.
241
+ - a path or url to a saved video processor JSON *file*, e.g.,
242
+ `./my_model_directory/preprocessor_config.json`.
243
+ cache_dir (`str` or `os.PathLike`, *optional*):
244
+ Path to a directory in which a downloaded pretrained model video processor should be cached if the
245
+ standard cache should not be used.
246
+ force_download (`bool`, *optional*, defaults to `False`):
247
+ Whether or not to force (re-)downloading the video processor files and override the cached versions if
248
+ they exist.
249
+ resume_download:
250
+ Deprecated and ignored. All downloads are now resumed by default when possible.
251
+ Will be removed in v5 of Transformers.
252
+ proxies (`dict[str, str]`, *optional*):
253
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
254
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
255
+ token (`str` or *bool*, *optional*):
256
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
257
+ when running `hf auth login` (stored in `~/.huggingface`).
258
+ revision (`str`, *optional*, defaults to `"main"`):
259
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
260
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
261
+ identifier allowed by git.
262
+ return_unused_kwargs (`bool`, *optional*, defaults to `False`):
263
+ If `False`, then this function returns just the final video processor object. If `True`, then this
264
+ function returns a `Tuple(video_processor, unused_kwargs)` where *unused_kwargs* is a dictionary
265
+ consisting of the key/value pairs whose keys are not video processor attributes: i.e., the part of
266
+ `kwargs` which has not been used to update `video_processor` and is otherwise ignored.
267
+ trust_remote_code (`bool`, *optional*, defaults to `False`):
268
+ Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
269
+ should only be set to `True` for repositories you trust and in which you have read the code, as it will
270
+ execute code present on the Hub on your local machine.
271
+ kwargs (`dict[str, Any]`, *optional*):
272
+ The values in kwargs of any keys which are video processor attributes will be used to override the
273
+ loaded values. Behavior concerning key/value pairs whose keys are *not* video processor attributes is
274
+ controlled by the `return_unused_kwargs` keyword parameter.
275
+
276
+ <Tip>
277
+
278
+ Passing `token=True` is required when you want to use a private model.
279
+
280
+ </Tip>
281
+
282
+ Examples:
283
+
284
+ ```python
285
+ >>> from transformers import AutoVideoProcessor
286
+
287
+ >>> # Download video processor from huggingface.co and cache.
288
+ >>> video_processor = AutoVideoProcessor.from_pretrained("llava-hf/llava-onevision-qwen2-0.5b-ov-hf")
289
+
290
+ >>> # If video processor files are in a directory (e.g. video processor was saved using *save_pretrained('./test/saved_model/')*)
291
+ >>> # video_processor = AutoVideoProcessor.from_pretrained("./test/saved_model/")
292
+ ```"""
293
+ use_auth_token = kwargs.pop("use_auth_token", None)
294
+ if use_auth_token is not None:
295
+ warnings.warn(
296
+ "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
297
+ FutureWarning,
298
+ )
299
+ if kwargs.get("token") is not None:
300
+ raise ValueError(
301
+ "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
302
+ )
303
+ kwargs["token"] = use_auth_token
304
+
305
+ config = kwargs.pop("config", None)
306
+ trust_remote_code = kwargs.pop("trust_remote_code", None)
307
+ kwargs["_from_auto"] = True
308
+
309
+ config_dict, _ = BaseVideoProcessor.get_video_processor_dict(pretrained_model_name_or_path, **kwargs)
310
+ video_processor_class = config_dict.get("video_processor_type", None)
311
+ video_processor_auto_map = None
312
+ if "AutoVideoProcessor" in config_dict.get("auto_map", {}):
313
+ video_processor_auto_map = config_dict["auto_map"]["AutoVideoProcessor"]
314
+
315
+ # If we still don't have the video processor class, check if we're loading from a previous image processor config
316
+ # and if so, infer the video processor class from there.
317
+ if video_processor_class is None and video_processor_auto_map is None:
318
+ image_processor_class = config_dict.pop("image_processor_type", None)
319
+ if image_processor_class is not None:
320
+ video_processor_class_inferred = image_processor_class.replace("ImageProcessor", "VideoProcessor")
321
+
322
+ # Some models have different image processors, e.g. InternVL uses GotOCRImageProcessor
323
+ # We cannot use GotOCRVideoProcessor when falling back for BC and should try to infer from config later on
324
+ if video_processor_class_inferred in VIDEO_PROCESSOR_MAPPING_NAMES.values():
325
+ video_processor_class = video_processor_class_inferred
326
+ if "AutoImageProcessor" in config_dict.get("auto_map", {}):
327
+ image_processor_auto_map = config_dict["auto_map"]["AutoImageProcessor"]
328
+ video_processor_auto_map = image_processor_auto_map.replace("ImageProcessor", "VideoProcessor")
329
+
330
+ # If we don't find the video processor class in the video processor config, let's try the model config.
331
+ if video_processor_class is None and video_processor_auto_map is None:
332
+ if not isinstance(config, PretrainedConfig):
333
+ config = AutoConfig.from_pretrained(
334
+ pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
335
+ )
336
+ # It could be in `config.video_processor_type`
337
+ video_processor_class = getattr(config, "video_processor_type", None)
338
+ if hasattr(config, "auto_map") and "AutoVideoProcessor" in config.auto_map:
339
+ video_processor_auto_map = config.auto_map["AutoVideoProcessor"]
340
+
341
+ if video_processor_class is not None:
342
+ video_processor_class = video_processor_class_from_name(video_processor_class)
343
+
344
+ has_remote_code = video_processor_auto_map is not None
345
+ has_local_code = video_processor_class is not None or type(config) in VIDEO_PROCESSOR_MAPPING
346
+ trust_remote_code = resolve_trust_remote_code(
347
+ trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code
348
+ )
349
+
350
+ if has_remote_code and trust_remote_code:
351
+ class_ref = video_processor_auto_map
352
+ video_processor_class = get_class_from_dynamic_module(class_ref, pretrained_model_name_or_path, **kwargs)
353
+ _ = kwargs.pop("code_revision", None)
354
+ video_processor_class.register_for_auto_class()
355
+ return video_processor_class.from_dict(config_dict, **kwargs)
356
+ elif video_processor_class is not None:
357
+ return video_processor_class.from_dict(config_dict, **kwargs)
358
+ # Last try: we use the VIDEO_PROCESSOR_MAPPING.
359
+ elif type(config) in VIDEO_PROCESSOR_MAPPING:
360
+ video_processor_class = VIDEO_PROCESSOR_MAPPING[type(config)]
361
+
362
+ if video_processor_class is not None:
363
+ return video_processor_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
364
+ else:
365
+ raise ValueError(
366
+ "This video processor cannot be instantiated. Please make sure you have `torchvision` installed."
367
+ )
368
+
369
+ raise ValueError(
370
+ f"Unrecognized video processor in {pretrained_model_name_or_path}. Should have a "
371
+ f"`video_processor_type` key in its {VIDEO_PROCESSOR_NAME} of {CONFIG_NAME}, or one of the following "
372
+ f"`model_type` keys in its {CONFIG_NAME}: {', '.join(c for c in VIDEO_PROCESSOR_MAPPING_NAMES)}"
373
+ )
374
+
375
+ @staticmethod
376
+ def register(
377
+ config_class,
378
+ video_processor_class,
379
+ exist_ok=False,
380
+ ):
381
+ """
382
+ Register a new video processor for this class.
383
+
384
+ Args:
385
+ config_class ([`PretrainedConfig`]):
386
+ The configuration corresponding to the model to register.
387
+ video_processor_class ([`BaseVideoProcessor`]):
388
+ The video processor to register.
389
+ """
390
+ VIDEO_PROCESSOR_MAPPING.register(config_class, video_processor_class, exist_ok=exist_ok)
391
+
392
+
393
+ __all__ = ["VIDEO_PROCESSOR_MAPPING", "AutoVideoProcessor"]
venv/lib/python3.13/site-packages/transformers/models/aya_vision/__init__.py ADDED
@@ -0,0 +1,28 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import TYPE_CHECKING
15
+
16
+ from ...utils import _LazyModule
17
+ from ...utils.import_utils import define_import_structure
18
+
19
+
20
+ if TYPE_CHECKING:
21
+ from .configuration_aya_vision import *
22
+ from .modeling_aya_vision import *
23
+ from .processing_aya_vision import *
24
+ else:
25
+ import sys
26
+
27
+ _file = globals()["__file__"]
28
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
venv/lib/python3.13/site-packages/transformers/models/aya_vision/configuration_aya_vision.py ADDED
@@ -0,0 +1,110 @@
1
+ # coding=utf-8
2
+ # Copyright 2025 Cohere team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """AyaVision model configuration"""
16
+
17
+ from ...configuration_utils import PretrainedConfig
18
+ from ...utils import logging
19
+ from ..auto import CONFIG_MAPPING, AutoConfig
20
+
21
+
22
+ logger = logging.get_logger(__name__)
23
+
24
+
25
+ class AyaVisionConfig(PretrainedConfig):
26
+ r"""
27
+ This is the configuration class to store the configuration of an [`AyaVisionForConditionalGeneration`]. It is used to instantiate an
28
+ AyaVision model according to the specified arguments, defining the model architecture. Instantiating a configuration
29
+ with the defaults will yield a similar configuration to that of AyaVision.
30
+ e.g. [CohereForAI/aya-vision-8b](https://huggingface.co/CohereForAI/aya-vision-8b)
31
+
32
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
33
+ documentation from [`PretrainedConfig`] for more information.
34
+
35
+ Args:
36
+ vision_config (`Union[AutoConfig, dict]`, *optional*, defaults to `SiglipVisionConfig`):
37
+ The config object or dictionary of the vision backbone.
38
+ text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `Cohere2Config`):
39
+ The config object or dictionary of the text backbone.
40
+ vision_feature_select_strategy (`str`, *optional*, defaults to `"full"`):
41
+ The feature selection strategy used to select the vision feature from the vision backbone.
42
+ Can be one of `"default"` or `"full"`. If `"default"`, the CLS token is removed from the vision features.
43
+ If `"full"`, the full vision features are used.
44
+ vision_feature_layer (`int`, *optional*, defaults to -1):
45
+ The index of the layer to select the vision feature.
46
+ downsample_factor (`int`, *optional*, defaults to 2):
47
+ The downsample factor to apply to the vision features.
48
+ adapter_layer_norm_eps (`float`, *optional*, defaults to 1e-06):
49
+ The epsilon value used for layer normalization in the adapter.
50
+ image_token_index (`int`, *optional*, defaults to 255036):
51
+ The image token index to encode the image prompt.
52
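+
+ Example (a minimal usage sketch; building the model requires `torch` and allocates randomly
+ initialized weights):
+
+ ```python
+ >>> from transformers import AyaVisionForConditionalGeneration, AyaVisionConfig
+
+ >>> # Initializing a config with the default vision (SigLIP) and text (Cohere2) sub-configs
+ >>> configuration = AyaVisionConfig()
+
+ >>> # Initializing a model (with random weights) from that configuration
+ >>> model = AyaVisionForConditionalGeneration(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```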
+ """
53
+
54
+ model_type = "aya_vision"
55
+ attribute_map = {
56
+ "image_token_id": "image_token_index",
57
+ }
58
+ sub_configs = {"text_config": AutoConfig, "vision_config": AutoConfig}
59
+
60
+ def __init__(
61
+ self,
62
+ vision_config=None,
63
+ text_config=None,
64
+ vision_feature_select_strategy="full",
65
+ vision_feature_layer=-1,
66
+ downsample_factor=2,
67
+ adapter_layer_norm_eps=1e-6,
68
+ image_token_index=255036,
69
+ **kwargs,
70
+ ):
71
+ self.image_token_index = image_token_index
72
+ self.downsample_factor = downsample_factor
73
+ self.adapter_layer_norm_eps = adapter_layer_norm_eps
74
+ if vision_feature_select_strategy not in ["default", "full"]:
75
+ raise ValueError(
76
+ "vision_feature_select_strategy should be one of 'default', 'full'. "
77
+ f"Got: {vision_feature_select_strategy}"
78
+ )
79
+
80
+ self.vision_feature_select_strategy = vision_feature_select_strategy
81
+ self.vision_feature_layer = vision_feature_layer
82
+
83
+ if isinstance(vision_config, dict):
84
+ vision_config["model_type"] = vision_config.get("model_type", "siglip_vision_model")
85
+ vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config)
86
+ elif vision_config is None:
87
+ vision_config = CONFIG_MAPPING["siglip_vision_model"](
88
+ hidden_size=1152,
89
+ intermediate_size=4304,
90
+ patch_size=14,
91
+ image_size=384,
92
+ num_hidden_layers=26,
93
+ num_attention_heads=14,
94
+ vision_use_head=False,
95
+ )
96
+
97
+ self.vision_config = vision_config
98
+
99
+ if isinstance(text_config, dict):
100
+ text_config["model_type"] = text_config.get("model_type", "cohere2")
101
+ text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
102
+ elif text_config is None:
103
+ text_config = CONFIG_MAPPING["cohere2"]()
104
+
105
+ self.text_config = text_config
106
+
107
+ super().__init__(**kwargs)
108
+
109
+
110
+ __all__ = ["AyaVisionConfig"]
venv/lib/python3.13/site-packages/transformers/models/aya_vision/modeling_aya_vision.py ADDED
@@ -0,0 +1,518 @@
1
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
2
+ # This file was automatically generated from src/transformers/models/aya_vision/modular_aya_vision.py.
3
+ # Do NOT edit this file manually as any edits will be overwritten by the generation of
4
+ # the file from the modular. If any change should be done, please apply the change to the
5
+ # modular_aya_vision.py file directly. One of our CI enforces this.
6
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
7
+ # coding=utf-8
8
+ # Copyright 2025 the Cohere Inc. team. All rights reserved.
9
+ #
10
+ # Licensed under the Apache License, Version 2.0 (the "License");
11
+ # you may not use this file except in compliance with the License.
12
+ # You may obtain a copy of the License at
13
+ #
14
+ # http://www.apache.org/licenses/LICENSE-2.0
15
+ #
16
+ # Unless required by applicable law or agreed to in writing, software
17
+ # distributed under the License is distributed on an "AS IS" BASIS,
18
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19
+ # See the License for the specific language governing permissions and
20
+ # limitations under the License.
21
+
22
+ from dataclasses import dataclass
23
+ from typing import Optional, Union
24
+
25
+ import torch
26
+ from torch import nn
27
+
28
+ from ...activations import ACT2FN
29
+ from ...cache_utils import Cache
30
+ from ...generation import GenerationMixin
31
+ from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput
32
+ from ...modeling_utils import PreTrainedModel
33
+ from ...processing_utils import Unpack
34
+ from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
35
+ from ...utils.generic import check_model_inputs
36
+ from ..auto import AutoModel
37
+ from .configuration_aya_vision import AyaVisionConfig
38
+
39
+
40
+ class AyaVisionMultiModalProjector(nn.Module):
41
+ def __init__(self, config: AyaVisionConfig):
42
+ super().__init__()
43
+ self.config = config
44
+ self.downsample_factor = config.downsample_factor
45
+ self.alignment_intermediate_size = getattr(
46
+ config, "alignment_intermediate_size", config.text_config.hidden_size
47
+ )
48
+ self.layernorm = nn.LayerNorm(
49
+ config.vision_config.hidden_size * (config.downsample_factor**2), eps=config.adapter_layer_norm_eps
50
+ )
51
+
52
+ self.linear_1 = nn.Linear(
53
+ config.vision_config.hidden_size * (config.downsample_factor**2),
54
+ self.alignment_intermediate_size,
55
+ bias=True,
56
+ )
57
+
58
+ self.act = ACT2FN["silu"] # SwiGLU uses SiLU activation
59
+ # For SwiGLU, project down to half size since we split intermediate dim
60
+ self.linear_2 = nn.Linear(self.alignment_intermediate_size // 2, config.text_config.hidden_size, bias=True)
61
+
62
+ def forward(self, image_features):
63
+ image_features = self.pixel_shuffle(image_features)
64
+ image_features = self.layernorm(image_features)
65
+ hidden_states = self.linear_1(image_features)
66
+
67
+ # Split along last dimension and apply SwiGLU
68
+ x, gate = hidden_states.chunk(2, dim=-1)
69
+ hidden_states = self.act(gate) * x
70
+
71
+ hidden_states = self.linear_2(hidden_states)
72
+ return hidden_states
73
+
74
+ def pixel_shuffle(self, image_features): # B, S, D
75
+ batch_size, seq_length, feature_dim = image_features.shape
76
+ height = width = int(seq_length**0.5)
77
+ image_features = image_features.reshape(image_features.shape[0], width, height, -1)
78
+ channels = image_features.shape[-1]
79
+ image_features = image_features.reshape(
80
+ batch_size, width, int(height / self.downsample_factor), int(channels * self.downsample_factor)
81
+ )
82
+ image_features = image_features.permute(0, 2, 1, 3)
83
+ image_features = image_features.reshape(
84
+ batch_size, int(height / self.downsample_factor), int(width / self.downsample_factor), -1
85
+ )
86
+ image_features = image_features.permute(0, 2, 1, 3)
87
+ return image_features
88
+
89
+
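For reference, a small sketch of the shape bookkeeping that `pixel_shuffle` above performs, with assumed shapes taken from the SigLIP defaults (hidden size 1152, a 26x26 patch grid) and `downsample_factor=2`; only the tensor sizes matter here.

```python
import torch

batch, grid, hidden, factor = 1, 26, 1152, 2
x = torch.randn(batch, grid * grid, hidden)   # (B, 676, 1152) coming out of the vision tower
x = x.reshape(batch, grid, grid, hidden)      # back onto a 2D patch grid
x = x.reshape(batch, grid, grid // factor, hidden * factor).permute(0, 2, 1, 3)
x = x.reshape(batch, grid // factor, grid // factor, -1).permute(0, 2, 1, 3)
print(x.shape)  # torch.Size([1, 13, 13, 4608]) -> matches the LayerNorm / linear_1 width above
```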
90
+ @auto_docstring
91
+ class AyaVisionPreTrainedModel(PreTrainedModel):
92
+ config: AyaVisionConfig
93
+ base_model_prefix = ""
94
+ supports_gradient_checkpointing = True
95
+ _skip_keys_device_placement = "past_key_values"
96
+
97
+ _supports_flash_attn = True
98
+ _supports_sdpa = True
99
+ _can_compile_fullgraph = False
100
+ _supports_flex_attn = True
101
+ _supports_attention_backend = True
102
+ _can_record_outputs = {
103
+ "hidden_states": "DecoderLayer",
104
+ "attentions": "Attention",
105
+ }
106
+
107
+
108
+ @dataclass
109
+ @auto_docstring(
110
+ custom_intro="""
111
+ Base class for AyaVision causal language model (or autoregressive) outputs.
112
+ """
113
+ )
114
+ class AyaVisionCausalLMOutputWithPast(ModelOutput):
115
+ r"""
116
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
117
+ Language modeling loss (for next-token prediction).
118
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
119
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
120
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
121
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
122
+
123
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
124
+ `past_key_values` input) to speed up sequential decoding.
125
+ image_hidden_states (`torch.FloatTensor`, *optional*):
126
+ A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
127
+ image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
128
+ """
129
+
130
+ loss: Optional[torch.FloatTensor] = None
131
+ logits: Optional[torch.FloatTensor] = None
132
+ past_key_values: Optional[Cache] = None
133
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
134
+ attentions: Optional[tuple[torch.FloatTensor]] = None
135
+ image_hidden_states: Optional[torch.FloatTensor] = None
136
+
137
+
138
+ @dataclass
139
+ @auto_docstring(
140
+ custom_intro="""
141
+ Base class for AyaVision outputs, with hidden states and attentions.
142
+ """
143
+ )
144
+ class AyaVisionModelOutputWithPast(BaseModelOutputWithPast):
145
+ r"""
146
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
147
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
148
+
149
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
150
+ `past_key_values` input) to speed up sequential decoding.
151
+ image_hidden_states (`torch.FloatTensor`, *optional*):
152
+ A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
153
+ image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
154
+ """
155
+
156
+ image_hidden_states: Optional[torch.FloatTensor] = None
157
+
158
+
159
+ @auto_docstring(
160
+ custom_intro="""
161
+ The AyaVision model which consists of a vision backbone and a language model, without a language modeling head.
162
+ """
163
+ )
164
+ class AyaVisionModel(AyaVisionPreTrainedModel):
165
+ _checkpoint_conversion_mapping = {"language_model.model": "language_model"}
166
+
167
+ def __init__(self, config: AyaVisionConfig):
168
+ super().__init__(config)
169
+ self.vision_tower = AutoModel.from_config(config.vision_config)
170
+
171
+ self.multi_modal_projector = AyaVisionMultiModalProjector(config)
172
+ self.language_model = AutoModel.from_config(config.text_config)
173
+ self.post_init()
174
+
175
+ def get_input_embeddings(self):
176
+ return self.language_model.get_input_embeddings()
177
+
178
+ def set_input_embeddings(self, value):
179
+ self.language_model.set_input_embeddings(value)
180
+
181
+ def set_decoder(self, decoder):
182
+ self.language_model = decoder
183
+
184
+ def get_decoder(self):
185
+ return self.language_model
186
+
187
+ def get_image_features(
188
+ self,
189
+ pixel_values: torch.FloatTensor,
190
+ vision_feature_layer: Optional[Union[int, list[int]]] = None,
191
+ vision_feature_select_strategy: Optional[str] = None,
192
+ **kwargs,
193
+ ):
194
+ """
195
+ Obtains the image's last hidden states from the vision tower and applies the multimodal projection.
196
+
197
+ Args:
198
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`):
199
+ The tensors corresponding to the input images.
200
+ vision_feature_layer (`Union[int, list[int]]`, *optional*):
201
+ The index of the layer to select the vision feature. If multiple indices are provided,
202
+ the vision feature of the corresponding indices will be concatenated to form the
203
+ vision features.
204
+ vision_feature_select_strategy (`str`, *optional*):
205
+ The feature selection strategy used to select the vision feature from the vision backbone.
206
+ Can be one of `"default"` or `"full"`.
207
+ Returns:
208
+ image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`.
209
+ """
210
+ vision_feature_layer = (
211
+ vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
212
+ )
213
+ vision_feature_select_strategy = (
214
+ vision_feature_select_strategy
215
+ if vision_feature_select_strategy is not None
216
+ else self.config.vision_feature_select_strategy
217
+ )
218
+
219
+ if vision_feature_select_strategy not in ["default", "full"]:
220
+ raise ValueError(f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}")
221
+
222
+ kwargs = {k: v for k, v in kwargs.items() if v is not None}
223
+ # This is not memory efficient: output_hidden_states=True keeps every intermediate hidden state.
224
+ image_outputs = self.vision_tower(pixel_values, output_hidden_states=True, **kwargs)
225
+
226
+ # If we have one vision feature layer, return the corresponding hidden states,
227
+ # otherwise, select the hidden states of each feature layer and concatenate them
228
+ if isinstance(vision_feature_layer, int):
229
+ selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
230
+ if vision_feature_select_strategy == "default":
231
+ selected_image_feature = selected_image_feature[:, 1:]
232
+ else:
233
+ hs_pool = [image_outputs.hidden_states[layer_idx] for layer_idx in vision_feature_layer]
234
+ # For "default", crop the CLS token from each hidden state in the pool
235
+ if vision_feature_select_strategy == "default":
236
+ hs_pool = [hs[:, 1:] for hs in hs_pool]
237
+ selected_image_feature = torch.cat(hs_pool, dim=-1)
238
+
239
+ image_features = self.multi_modal_projector(selected_image_feature)
240
+ return image_features
241
+
242
+ def get_placeholder_mask(
243
+ self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor
244
+ ):
245
+ """
246
+ Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
247
+ equal to the length of multimodal features. If the lengths are different, an error is raised.
248
+ """
249
+ if input_ids is None:
250
+ special_image_mask = inputs_embeds == self.get_input_embeddings()(
251
+ torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
252
+ )
253
+ special_image_mask = special_image_mask.all(-1)
254
+ else:
255
+ special_image_mask = input_ids == self.config.image_token_id
256
+
257
+ n_image_tokens = special_image_mask.sum()
258
+ special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
259
+ n_image_features = image_features.shape[0] * image_features.shape[1]
260
+ if inputs_embeds[special_image_mask].numel() != image_features.numel():
261
+ raise ValueError(
262
+ f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
263
+ )
264
+ return special_image_mask
265
+
266
+ @check_model_inputs()
267
+ @auto_docstring
268
+ def forward(
269
+ self,
270
+ input_ids: Optional[torch.LongTensor] = None,
271
+ pixel_values: Optional[torch.FloatTensor] = None,
272
+ attention_mask: Optional[torch.Tensor] = None,
273
+ position_ids: Optional[torch.LongTensor] = None,
274
+ past_key_values: Optional[Cache] = None,
275
+ inputs_embeds: Optional[torch.FloatTensor] = None,
276
+ vision_feature_layer: Optional[Union[int, list[int]]] = None,
277
+ vision_feature_select_strategy: Optional[str] = None,
278
+ use_cache: Optional[bool] = None,
279
+ cache_position: Optional[torch.LongTensor] = None,
280
+ **kwargs: Unpack[TransformersKwargs],
281
+ ) -> Union[tuple, AyaVisionModelOutputWithPast]:
282
+ vision_feature_layer = (
283
+ vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
284
+ )
285
+ vision_feature_select_strategy = (
286
+ vision_feature_select_strategy
287
+ if vision_feature_select_strategy is not None
288
+ else self.config.vision_feature_select_strategy
289
+ )
290
+
291
+ if (input_ids is None) ^ (inputs_embeds is not None):
292
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
293
+
294
+ if inputs_embeds is None:
295
+ inputs_embeds = self.get_input_embeddings()(input_ids)
296
+
297
+ if pixel_values is not None:
298
+ image_features = self.get_image_features(
299
+ pixel_values=pixel_values,
300
+ vision_feature_layer=vision_feature_layer,
301
+ vision_feature_select_strategy=vision_feature_select_strategy,
302
+ )
303
+ image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
304
+ special_image_mask = self.get_placeholder_mask(
305
+ input_ids, inputs_embeds=inputs_embeds, image_features=image_features
306
+ )
307
+ inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
308
+
309
+ outputs = self.language_model(
310
+ attention_mask=attention_mask,
311
+ position_ids=position_ids,
312
+ past_key_values=past_key_values,
313
+ inputs_embeds=inputs_embeds,
314
+ use_cache=use_cache,
315
+ cache_position=cache_position,
316
+ **kwargs,
317
+ )
318
+
319
+ return AyaVisionModelOutputWithPast(
320
+ last_hidden_state=outputs.last_hidden_state,
321
+ past_key_values=outputs.past_key_values,
322
+ hidden_states=outputs.hidden_states,
323
+ attentions=outputs.attentions,
324
+ image_hidden_states=image_features if pixel_values is not None else None,
325
+ )
326
+
327
+
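A toy illustration (hypothetical tensors, not the real embedding sizes) of the placeholder merge performed in `AyaVisionModel.forward` above: positions flagged by the special-image mask are overwritten with the projected image features via `masked_scatter`.

```python
import torch

inputs_embeds = torch.zeros(1, 5, 4)              # 5 text positions, toy hidden size 4
image_features = torch.ones(2, 4)                 # 2 projected image tokens
mask = torch.tensor([[False, True, True, False, False]])  # <image> placeholder positions
mask = mask.unsqueeze(-1).expand_as(inputs_embeds)
merged = inputs_embeds.masked_scatter(mask, image_features)
print(merged[0, 1:3])  # the two placeholder slots now hold the image features
```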
328
+ @auto_docstring(
329
+ custom_intro="""
330
+ The AyaVision model which consists of a vision backbone and a language model.
331
+ """
332
+ )
333
+ class AyaVisionForConditionalGeneration(AyaVisionPreTrainedModel, GenerationMixin):
334
+ _checkpoint_conversion_mapping = {
335
+ "^language_model.model": "model.language_model",
336
+ "^vision_tower": "model.vision_tower",
337
+ "^multi_modal_projector": "model.multi_modal_projector",
338
+ "^language_model.lm_head": "lm_head",
339
+ }
340
+ _tied_weights_keys = ["lm_head.weight"]
341
+
342
+ def __init__(self, config: AyaVisionConfig):
343
+ super().__init__(config)
344
+ self.model = AyaVisionModel(config)
345
+ self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
346
+ self.post_init()
347
+
348
+ def get_input_embeddings(self):
349
+ return self.model.get_input_embeddings()
350
+
351
+ def set_input_embeddings(self, value):
352
+ self.model.set_input_embeddings(value)
353
+
354
+ def get_output_embeddings(self) -> nn.Module:
355
+ return self.lm_head
356
+
357
+ def set_decoder(self, decoder):
358
+ self.model.set_decoder(decoder)
359
+
360
+ def get_decoder(self):
361
+ return self.model.get_decoder()
362
+
363
+ def get_image_features(
364
+ self,
365
+ pixel_values: torch.FloatTensor,
366
+ vision_feature_layer: Optional[Union[int, list[int]]] = None,
367
+ vision_feature_select_strategy: Optional[str] = None,
368
+ **kwargs,
369
+ ):
370
+ return self.model.get_image_features(
371
+ pixel_values=pixel_values,
372
+ vision_feature_layer=vision_feature_layer,
373
+ vision_feature_select_strategy=vision_feature_select_strategy,
374
+ **kwargs,
375
+ )
376
+
377
+ # Make modules available through conditional class for BC
378
+ @property
379
+ def language_model(self):
380
+ return self.model.language_model
381
+
382
+ @property
383
+ def vision_tower(self):
384
+ return self.model.vision_tower
385
+
386
+ @property
387
+ def multi_modal_projector(self):
388
+ return self.model.multi_modal_projector
389
+
390
+ @can_return_tuple
391
+ @auto_docstring
392
+ def forward(
393
+ self,
394
+ input_ids: Optional[torch.LongTensor] = None,
395
+ pixel_values: Optional[torch.FloatTensor] = None,
396
+ attention_mask: Optional[torch.Tensor] = None,
397
+ position_ids: Optional[torch.LongTensor] = None,
398
+ past_key_values: Optional[Cache] = None,
399
+ inputs_embeds: Optional[torch.FloatTensor] = None,
400
+ vision_feature_layer: Optional[Union[int, list[int]]] = None,
401
+ vision_feature_select_strategy: Optional[str] = None,
402
+ labels: Optional[torch.LongTensor] = None,
403
+ cache_position: Optional[torch.LongTensor] = None,
404
+ logits_to_keep: Union[int, torch.Tensor] = 0,
405
+ image_sizes: Optional[torch.Tensor] = None,
406
+ **kwargs: Unpack[TransformersKwargs],
407
+ ) -> Union[tuple, AyaVisionCausalLMOutputWithPast]:
408
+ r"""
409
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
410
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
411
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
412
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
413
+
414
+ Example:
415
+
416
+ ```python
417
+ >>> from transformers import AutoProcessor, AyaVisionForConditionalGeneration
418
+ >>> import torch
419
+
420
+ >>> torch_device = "cuda:0"
421
+ >>> processor = AutoProcessor.from_pretrained("CohereForAI/aya-vision-8b", use_fast=True)
422
+ >>> model = AyaVisionForConditionalGeneration.from_pretrained("CohereForAI/aya-vision-8b", device_map=torch_device)
423
+
424
+ >>> messages = [
425
+ ... {
426
+ ... "role": "user",
427
+ ... "content": [
428
+ ... {
429
+ ... "type": "image",
430
+ ... "url": "https://pbs.twimg.com/media/Fx7YvfQWYAIp6rZ?format=jpg&name=medium",
431
+ ... },
432
+ ... {"type": "text", "text": "चित्र में लिखा पाठ क्या कहता है?"},
433
+ ... ],
434
+ ... }
435
+ ... ]
436
+
437
+ >>> inputs = processor.apply_chat_template(
438
+ ... messages, padding=True, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt", device=torch_device
439
+ ... ).to(model.device)
440
+
441
+ >>> gen_tokens = model.generate(**inputs, max_new_tokens=300, do_sample=True, temperature=0.3)
442
+ >>> processor.tokenizer.decode(gen_tokens[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
443
+ ```"""
444
+ vision_feature_layer = (
445
+ vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
446
+ )
447
+ vision_feature_select_strategy = (
448
+ vision_feature_select_strategy
449
+ if vision_feature_select_strategy is not None
450
+ else self.config.vision_feature_select_strategy
451
+ )
452
+
453
+ outputs = self.model(
454
+ input_ids=input_ids,
455
+ pixel_values=pixel_values,
456
+ attention_mask=attention_mask,
457
+ position_ids=position_ids,
458
+ past_key_values=past_key_values,
459
+ inputs_embeds=inputs_embeds,
460
+ vision_feature_layer=vision_feature_layer,
461
+ vision_feature_select_strategy=vision_feature_select_strategy,
462
+ cache_position=cache_position,
463
+ image_sizes=image_sizes,
464
+ **kwargs,
465
+ )
466
+
467
+ hidden_states = outputs[0]
468
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
469
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
470
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
471
+
472
+ loss = None
473
+ if labels is not None:
474
+ loss = self.loss_function(
475
+ logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
476
+ )
477
+
478
+ return AyaVisionCausalLMOutputWithPast(
479
+ loss=loss,
480
+ logits=logits,
481
+ past_key_values=outputs.past_key_values,
482
+ hidden_states=outputs.hidden_states,
483
+ attentions=outputs.attentions,
484
+ image_hidden_states=outputs.image_hidden_states,
485
+ )
486
+
487
+ def prepare_inputs_for_generation(
488
+ self,
489
+ input_ids,
490
+ past_key_values=None,
491
+ inputs_embeds=None,
492
+ pixel_values=None,
493
+ attention_mask=None,
494
+ cache_position=None,
495
+ logits_to_keep=None,
496
+ **kwargs,
497
+ ):
498
+ # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
499
+
500
+ model_inputs = super().prepare_inputs_for_generation(
501
+ input_ids,
502
+ past_key_values=past_key_values,
503
+ inputs_embeds=inputs_embeds,
504
+ attention_mask=attention_mask,
505
+ cache_position=cache_position,
506
+ logits_to_keep=logits_to_keep,
507
+ **kwargs,
508
+ )
509
+
510
+ if cache_position[0] == 0:
511
+ # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
512
+ # Otherwise we need pixel values to be passed to model
513
+ model_inputs["pixel_values"] = pixel_values
514
+
515
+ return model_inputs
516
+
517
+
518
+ __all__ = ["AyaVisionForConditionalGeneration", "AyaVisionPreTrainedModel", "AyaVisionModel"]
venv/lib/python3.13/site-packages/transformers/models/aya_vision/modular_aya_vision.py ADDED
@@ -0,0 +1,297 @@
1
+ # coding=utf-8
2
+ # Copyright 2025 the Cohere Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch AyaVision model."""
16
+
17
+ from typing import Optional, Union
18
+
19
+ import torch
20
+ from torch import nn
21
+
22
+ from transformers.models.llava.modeling_llava import (
23
+ LlavaCausalLMOutputWithPast,
24
+ LlavaForConditionalGeneration,
25
+ LlavaModel,
26
+ LlavaModelOutputWithPast,
27
+ LlavaPreTrainedModel,
28
+ TransformersKwargs,
29
+ )
30
+
31
+ from ...activations import ACT2FN
32
+ from ...cache_utils import Cache
33
+ from ...processing_utils import Unpack
34
+ from ...utils import auto_docstring, logging
35
+ from ...utils.generic import check_model_inputs
36
+ from .configuration_aya_vision import AyaVisionConfig
37
+
38
+
39
+ logger = logging.get_logger(__name__)
40
+
41
+
42
+ class AyaVisionMultiModalProjector(nn.Module):
43
+ def __init__(self, config: AyaVisionConfig):
44
+ super().__init__()
45
+ self.config = config
46
+ self.downsample_factor = config.downsample_factor
47
+ self.alignment_intermediate_size = getattr(
48
+ config, "alignment_intermediate_size", config.text_config.hidden_size
49
+ )
50
+ self.layernorm = nn.LayerNorm(
51
+ config.vision_config.hidden_size * (config.downsample_factor**2), eps=config.adapter_layer_norm_eps
52
+ )
53
+
54
+ self.linear_1 = nn.Linear(
55
+ config.vision_config.hidden_size * (config.downsample_factor**2),
56
+ self.alignment_intermediate_size,
57
+ bias=True,
58
+ )
59
+
60
+ self.act = ACT2FN["silu"] # SwiGLU uses SiLU activation
61
+ # For SwiGLU, project down to half size since we split intermediate dim
62
+ self.linear_2 = nn.Linear(self.alignment_intermediate_size // 2, config.text_config.hidden_size, bias=True)
63
+
64
+ def forward(self, image_features):
65
+ image_features = self.pixel_shuffle(image_features)
66
+ image_features = self.layernorm(image_features)
67
+ hidden_states = self.linear_1(image_features)
68
+
69
+ # Split along last dimension and apply SwiGLU
70
+ x, gate = hidden_states.chunk(2, dim=-1)
71
+ hidden_states = self.act(gate) * x
72
+
73
+ hidden_states = self.linear_2(hidden_states)
74
+ return hidden_states
75
+
76
+ def pixel_shuffle(self, image_features): # B, S, D
77
+ batch_size, seq_length, feature_dim = image_features.shape
78
+ height = width = int(seq_length**0.5)
79
+ image_features = image_features.reshape(image_features.shape[0], width, height, -1)
80
+ channels = image_features.shape[-1]
81
+ image_features = image_features.reshape(
82
+ batch_size, width, int(height / self.downsample_factor), int(channels * self.downsample_factor)
83
+ )
84
+ image_features = image_features.permute(0, 2, 1, 3)
85
+ image_features = image_features.reshape(
86
+ batch_size, int(height / self.downsample_factor), int(width / self.downsample_factor), -1
87
+ )
88
+ image_features = image_features.permute(0, 2, 1, 3)
89
+ return image_features
90
+
91
+
92
+ class AyaVisionPreTrainedModel(LlavaPreTrainedModel):
93
+ _can_compile_fullgraph = False
94
+ _can_record_outputs = {
95
+ "hidden_states": "DecoderLayer",
96
+ "attentions": "Attention",
97
+ }
98
+
99
+
100
+ class AyaVisionCausalLMOutputWithPast(LlavaCausalLMOutputWithPast):
101
+ pass
102
+
103
+
104
+ class AyaVisionModelOutputWithPast(LlavaModelOutputWithPast):
105
+ pass
106
+
107
+
108
+ class AyaVisionModel(LlavaModel):
109
+ # Unlike LLaVA, the model doesn't have to deal with Pixtral-style image states
110
+ def get_image_features(
111
+ self,
112
+ pixel_values: torch.FloatTensor,
113
+ vision_feature_layer: Optional[Union[int, list[int]]] = None,
114
+ vision_feature_select_strategy: Optional[str] = None,
115
+ **kwargs,
116
+ ):
117
+ """
118
+ Obtains the image's last hidden states from the vision tower and applies the multimodal projection.
119
+
120
+ Args:
121
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`):
122
+ The tensors corresponding to the input images.
123
+ vision_feature_layer (`Union[int, list[int]]`, *optional*):
124
+ The index of the layer to select the vision feature. If multiple indices are provided,
125
+ the vision feature of the corresponding indices will be concatenated to form the
126
+ vision features.
127
+ vision_feature_select_strategy (`str`, *optional*):
128
+ The feature selection strategy used to select the vision feature from the vision backbone.
129
+ Can be one of `"default"` or `"full"`.
130
+ Returns:
131
+ image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`.
132
+ """
133
+ vision_feature_layer = (
134
+ vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
135
+ )
136
+ vision_feature_select_strategy = (
137
+ vision_feature_select_strategy
138
+ if vision_feature_select_strategy is not None
139
+ else self.config.vision_feature_select_strategy
140
+ )
141
+
142
+ if vision_feature_select_strategy not in ["default", "full"]:
143
+ raise ValueError(f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}")
144
+
145
+ kwargs = {k: v for k, v in kwargs.items() if v is not None}
146
+ # This is not memory efficient: output_hidden_states=True keeps every intermediate hidden state.
147
+ image_outputs = self.vision_tower(pixel_values, output_hidden_states=True, **kwargs)
148
+
149
+ # If we have one vision feature layer, return the corresponding hidden states,
150
+ # otherwise, select the hidden states of each feature layer and concatenate them
151
+ if isinstance(vision_feature_layer, int):
152
+ selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
153
+ if vision_feature_select_strategy == "default":
154
+ selected_image_feature = selected_image_feature[:, 1:]
155
+ else:
156
+ hs_pool = [image_outputs.hidden_states[layer_idx] for layer_idx in vision_feature_layer]
157
+ # For "default", crop the CLS token from each hidden state in the pool
158
+ if vision_feature_select_strategy == "default":
159
+ hs_pool = [hs[:, 1:] for hs in hs_pool]
160
+ selected_image_feature = torch.cat(hs_pool, dim=-1)
161
+
162
+ image_features = self.multi_modal_projector(selected_image_feature)
163
+ return image_features
164
+
165
+ @check_model_inputs()
166
+ @auto_docstring
167
+ def forward(
168
+ self,
169
+ input_ids: Optional[torch.LongTensor] = None,
170
+ pixel_values: Optional[torch.FloatTensor] = None,
171
+ attention_mask: Optional[torch.Tensor] = None,
172
+ position_ids: Optional[torch.LongTensor] = None,
173
+ past_key_values: Optional[Cache] = None,
174
+ inputs_embeds: Optional[torch.FloatTensor] = None,
175
+ vision_feature_layer: Optional[Union[int, list[int]]] = None,
176
+ vision_feature_select_strategy: Optional[str] = None,
177
+ use_cache: Optional[bool] = None,
178
+ cache_position: Optional[torch.LongTensor] = None,
179
+ **kwargs: Unpack[TransformersKwargs],
180
+ ) -> Union[tuple, AyaVisionModelOutputWithPast]:
181
+ vision_feature_layer = (
182
+ vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
183
+ )
184
+ vision_feature_select_strategy = (
185
+ vision_feature_select_strategy
186
+ if vision_feature_select_strategy is not None
187
+ else self.config.vision_feature_select_strategy
188
+ )
189
+
190
+ if (input_ids is None) ^ (inputs_embeds is not None):
191
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
192
+
193
+ if inputs_embeds is None:
194
+ inputs_embeds = self.get_input_embeddings()(input_ids)
195
+
196
+ if pixel_values is not None:
197
+ image_features = self.get_image_features(
198
+ pixel_values=pixel_values,
199
+ vision_feature_layer=vision_feature_layer,
200
+ vision_feature_select_strategy=vision_feature_select_strategy,
201
+ )
202
+ image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
203
+ special_image_mask = self.get_placeholder_mask(
204
+ input_ids, inputs_embeds=inputs_embeds, image_features=image_features
205
+ )
206
+ inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
207
+
208
+ outputs = self.language_model(
209
+ attention_mask=attention_mask,
210
+ position_ids=position_ids,
211
+ past_key_values=past_key_values,
212
+ inputs_embeds=inputs_embeds,
213
+ use_cache=use_cache,
214
+ cache_position=cache_position,
215
+ **kwargs,
216
+ )
217
+
218
+ return AyaVisionModelOutputWithPast(
219
+ last_hidden_state=outputs.last_hidden_state,
220
+ past_key_values=outputs.past_key_values,
221
+ hidden_states=outputs.hidden_states,
222
+ attentions=outputs.attentions,
223
+ image_hidden_states=image_features if pixel_values is not None else None,
224
+ )
225
+
226
+
227
+ class AyaVisionForConditionalGeneration(LlavaForConditionalGeneration):
228
+ def forward(
229
+ self,
230
+ input_ids: Optional[torch.LongTensor] = None,
231
+ pixel_values: Optional[torch.FloatTensor] = None,
232
+ attention_mask: Optional[torch.Tensor] = None,
233
+ position_ids: Optional[torch.LongTensor] = None,
234
+ past_key_values: Optional[Cache] = None,
235
+ inputs_embeds: Optional[torch.FloatTensor] = None,
236
+ vision_feature_layer: Optional[Union[int, list[int]]] = None,
237
+ vision_feature_select_strategy: Optional[str] = None,
238
+ labels: Optional[torch.LongTensor] = None,
239
+ cache_position: Optional[torch.LongTensor] = None,
240
+ logits_to_keep: Union[int, torch.Tensor] = 0,
241
+ image_sizes: Optional[torch.Tensor] = None,
242
+ **kwargs: Unpack[TransformersKwargs],
243
+ ) -> Union[tuple, AyaVisionCausalLMOutputWithPast]:
244
+ r"""
245
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
246
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
247
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
248
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
249
+
250
+ Example:
251
+
252
+ ```python
253
+ >>> from transformers import AutoProcessor, AyaVisionForConditionalGeneration
254
+ >>> import torch
255
+
256
+ >>> torch_device = "cuda:0"
257
+ >>> processor = AutoProcessor.from_pretrained("CohereForAI/aya-vision-8b", use_fast=True)
258
+ >>> model = AyaVisionForConditionalGeneration.from_pretrained("CohereForAI/aya-vision-8b", device_map=torch_device)
259
+
260
+ >>> messages = [
261
+ ... {
262
+ ... "role": "user",
263
+ ... "content": [
264
+ ... {
265
+ ... "type": "image",
266
+ ... "url": "https://pbs.twimg.com/media/Fx7YvfQWYAIp6rZ?format=jpg&name=medium",
267
+ ... },
268
+ ... {"type": "text", "text": "चित्र में लिखा पाठ क्या कहता है?"},
269
+ ... ],
270
+ ... }
271
+ ... ]
272
+
273
+ >>> inputs = processor.apply_chat_template(
274
+ ... messages, padding=True, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt", device=torch_device
275
+ ... ).to(model.device)
276
+
277
+ >>> gen_tokens = model.generate(**inputs, max_new_tokens=300, do_sample=True, temperature=0.3)
278
+ >>> processor.tokenizer.decode(gen_tokens[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
279
+ ```"""
280
+ return super().forward(
281
+ input_ids=input_ids,
282
+ pixel_values=pixel_values,
283
+ attention_mask=attention_mask,
284
+ position_ids=position_ids,
285
+ past_key_values=past_key_values,
286
+ inputs_embeds=inputs_embeds,
287
+ vision_feature_layer=vision_feature_layer,
288
+ vision_feature_select_strategy=vision_feature_select_strategy,
289
+ labels=labels,
290
+ cache_position=cache_position,
291
+ logits_to_keep=logits_to_keep,
292
+ image_sizes=image_sizes,
293
+ **kwargs,
294
+ )
295
+
296
+
297
+ __all__ = ["AyaVisionForConditionalGeneration", "AyaVisionPreTrainedModel", "AyaVisionModel"]
venv/lib/python3.13/site-packages/transformers/models/aya_vision/processing_aya_vision.py ADDED
@@ -0,0 +1,257 @@
1
+ # coding=utf-8
2
+ # Copyright 2025 HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import Optional, Union
17
+
18
+ import numpy as np
19
+
20
+ from ...image_processing_utils import BatchFeature
21
+ from ...image_utils import ImageInput, make_flat_list_of_images
22
+ from ...processing_utils import ImagesKwargs, MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack
23
+ from ...tokenization_utils_base import PreTokenizedInput, TextInput
24
+
25
+
26
+ class AyaVisionImagesKwargs(ImagesKwargs, total=False):
27
+ crop_to_patches: Optional[bool]
28
+ min_patches: Optional[int]
29
+ max_patches: Optional[int]
30
+
31
+
32
+ class AyaVisionProcessorKwargs(ProcessingKwargs, total=False):
33
+ images_kwargs: AyaVisionImagesKwargs
34
+ _defaults = {
35
+ "text_kwargs": {
36
+ "padding_side": "left",
37
+ "padding": True,
38
+ "return_mm_token_type_ids": False,
39
+ },
40
+ "images_kwargs": {
41
+ "crop_to_patches": True,
42
+ },
43
+ }
44
+
45
+
46
+ class AyaVisionProcessor(ProcessorMixin):
47
+ r"""
48
+ Constructs a AyaVision processor which wraps a [`AutoImageProcessor`] and
49
+ [`PretrainedTokenizerFast`] tokenizer into a single processor that inherits both the image processor and
50
+ tokenizer functionalities. See the [`~AyaVisionProcessor.__call__`] and [`~AyaVisionProcessor.decode`] for more information.
51
+ Args:
52
+ image_processor ([`AutoImageProcessor`], *optional*):
53
+ The image processor is a required input.
54
+ tokenizer ([`PreTrainedTokenizer`, `PreTrainedTokenizerFast`], *optional*):
55
+ The tokenizer is a required input.
56
+ patch_size (`int`, *optional*, defaults to 28):
57
+ The size of image patches for tokenization.
58
+ img_size (`int`, *optional*, defaults to 364):
59
+ The size of the image to be tokenized. This should correspond to the size given to the image processor.
60
+ image_token (`str`, *optional*, defaults to `"<image>"`):
61
+ The token to be used to represent an image in the text.
62
+ downsample_factor (`int`, *optional*, defaults to 1):
63
+ The factor by which to scale the patch size.
64
+ start_of_img_token (`str`, *optional*, defaults to `"<|START_OF_IMG|>"`):
65
+ The token to be used to represent the start of an image in the text.
66
+ end_of_img_token (`str`, *optional*, defaults to `"<|END_OF_IMG|>"`):
67
+ The token to be used to represent the end of an image in the text.
68
+ img_patch_token (`str`, *optional*, defaults to `"<|IMG_PATCH|>"`):
69
+ The token to be used to represent an image patch in the text.
70
+ img_line_break_token (`str`, *optional*, defaults to `"<|IMG_LINE_BREAK|>"`):
71
+ The token to be used to represent a line break in the text.
72
+ tile_token (`str`, *optional*, defaults to `"TILE"`):
73
+ The token to be used to represent an image patch in the text.
74
+ tile_global_token (`str`, *optional*, defaults to `"TILE_GLOBAL"`):
75
+ The token to be used to represent the cover image in the text.
76
+ chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
77
+ in a chat into a tokenizable string.
78
+ """
79
+
80
+ attributes = ["image_processor", "tokenizer"]
81
+ image_processor_class = "AutoImageProcessor"
82
+ tokenizer_class = "AutoTokenizer"
83
+
84
+ def __init__(
85
+ self,
86
+ image_processor=None,
87
+ tokenizer=None,
88
+ patch_size: int = 28,
89
+ img_size: int = 364,
90
+ image_token="<image>", # set the default and let users change if they have peculiar special tokens in rare cases
91
+ downsample_factor: int = 1,
92
+ start_of_img_token="<|START_OF_IMG|>",
93
+ end_of_img_token="<|END_OF_IMG|>",
94
+ img_patch_token="<|IMG_PATCH|>",
95
+ img_line_break_token="<|IMG_LINE_BREAK|>",
96
+ tile_token="TILE",
97
+ tile_global_token="TILE_GLOBAL",
98
+ chat_template=None,
99
+ **kwargs,
100
+ ):
101
+ super().__init__(image_processor, tokenizer, chat_template=chat_template)
102
+
103
+ self.image_token = image_token
104
+ self.patch_size = patch_size * downsample_factor
105
+ self.img_size = img_size
106
+
107
+ self.start_of_img_token = start_of_img_token
108
+ self.end_of_img_token = end_of_img_token
109
+ self.img_patch_token = img_patch_token
110
+ self.img_line_break_token = img_line_break_token
111
+ self.tile_token = tile_token
112
+ self.tile_global_token = tile_global_token
113
+ self.image_token_id = tokenizer.convert_tokens_to_ids(self.img_patch_token)
114
+ self.image_ids = tokenizer.convert_tokens_to_ids(
115
+ [img_patch_token, tile_token, tile_global_token, start_of_img_token, end_of_img_token]
116
+ )
117
+
118
+ def _prompt_split_image(self, num_patches):
119
+ """
120
+ Create a structured string representation of image tokens
121
+
122
+ Args:
123
+ num_patches: Number of patches in the image
124
+
125
+ Returns:
126
+ String with appropriate image tokens
127
+ """
128
+
129
+ img_patches_per_tile = (self.img_size // self.patch_size) ** 2
130
+ img_string = f"{self.start_of_img_token}"
131
+ if num_patches > 1:
132
+ for idx in range(1, num_patches):
133
+ img_string += f"{self.tile_token}_{idx}" + f"{self.img_patch_token}" * img_patches_per_tile
134
+
135
+ img_string += f"{self.tile_global_token}" + f"{self.img_patch_token}" * img_patches_per_tile
136
+ img_string += f"{self.end_of_img_token}"
137
+ return img_string
138
+
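As a rough illustration of the string `_prompt_split_image` builds with the processor defaults above (`img_size=364`, effective `patch_size=28`, hence 169 `<|IMG_PATCH|>` tokens per tile), here is the expansion for an assumed `num_patches=3`:

```python
tokens_per_tile = (364 // 28) ** 2        # 169, i.e. img_patches_per_tile above
num_patches = 3                           # two local tiles plus the global tile
s = "<|START_OF_IMG|>"
for idx in range(1, num_patches):
    s += f"TILE_{idx}" + "<|IMG_PATCH|>" * tokens_per_tile
s += "TILE_GLOBAL" + "<|IMG_PATCH|>" * tokens_per_tile + "<|END_OF_IMG|>"
print(s.count("<|IMG_PATCH|>"))           # 507 patch placeholders across the 3 tiles
```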
139
+ def __call__(
140
+ self,
141
+ images: Optional[ImageInput] = None,
142
+ text: Optional[Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]]] = None,
143
+ audio=None,
144
+ videos=None,
145
+ **kwargs: Unpack[AyaVisionProcessorKwargs],
146
+ ) -> BatchFeature:
147
+ """
148
+ Main method to prepare for the model one or several sequence(s) and image(s). This method forwards the `text`
149
+ and `kwargs` arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizerFast.__call__`] to encode the text.
150
+ To prepare the vision inputs, this method forwards the `images` and `kwargs` arguments to
151
+ GotOcr2ImageProcessor's [`~GotOcr2ImageProcessor.__call__`] if `images` is not `None`.
152
+
153
+ Args:
154
+ images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`, `list[torch.Tensor]`):
155
+ The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
156
+ tensor. Both channels-first and channels-last formats are supported.
157
+ text (`str`, `list[str]`, `list[list[str]]`):
158
+ The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
159
+ (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
160
+ `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
161
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
162
+ If set, will return tensors of a particular framework. Acceptable values are:
163
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
164
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
165
+ - `'np'`: Return NumPy `np.ndarray` objects.
166
+ - `'jax'`: Return JAX `jnp.ndarray` objects.
167
+
168
+ Returns:
169
+ [`BatchFeature`]: A [`BatchFeature`] with the following fields:
170
+
171
+ - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
172
+ - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
173
+ `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
174
+ `None`).
175
+ - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
176
+ """
177
+ if text is None:
178
+ raise ValueError("You have to specify text.")
179
+
180
+ output_kwargs = self._merge_kwargs(
181
+ AyaVisionProcessorKwargs,
182
+ tokenizer_init_kwargs=self.tokenizer.init_kwargs,
183
+ **kwargs,
184
+ )
185
+
186
+ if not isinstance(text, (list, tuple)):
187
+ text = [text]
188
+
189
+ # Process images
190
+ image_inputs = {}
191
+ if images is not None:
192
+ images = self.image_processor.fetch_images(images)
193
+ images = make_flat_list_of_images(images)
194
+ image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"])
195
+ num_patches = image_inputs.pop("num_patches")
196
+ image_index = 0
197
+ processed_text = []
198
+ for prompt in text:
199
+ new_prompt = prompt
200
+ while "<image>" in new_prompt:
201
+ # Replace the image placeholder with structured image tokens
202
+ image_tokens = self._prompt_split_image(num_patches[image_index])
203
+ new_prompt = new_prompt.replace("<image>", image_tokens, 1)
204
+ image_index += 1
205
+ processed_text.append(new_prompt)
206
+
207
+ if image_index != len(images):
208
+ raise ValueError("Number of image placeholders in the prompt does not match the number of images.")
209
+
210
+ text = processed_text
211
+
212
+ return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
213
+ return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False)
214
+ text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"], return_tensors=None)
215
+
216
+ if return_mm_token_type_ids:
217
+ array_ids = np.array(text_inputs["input_ids"])
218
+ mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
219
+ mm_token_type_ids[np.isin(array_ids, self.image_ids)] = 1
220
+ text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()
221
+
222
+ return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors)
223
+
224
+ def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
225
+ """
226
+ Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
227
+
228
+ Args:
229
+ image_sizes (`list[list[int]]`, *optional*):
230
+ The input sizes formatted as (height, width) per each image.
231
+
232
+ Returns:
233
+ `MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
234
+ input modalities, along with other useful data.
235
+ """
236
+
237
+ vision_data = {}
238
+ if image_sizes is not None:
239
+ images_kwargs = AyaVisionProcessorKwargs._defaults.get("images_kwargs", {})
240
+ images_kwargs.update(kwargs)
241
+
242
+ num_image_patches = [
243
+ self.image_processor.get_number_of_image_patches(*image_size, images_kwargs)
244
+ for image_size in image_sizes
245
+ ]
246
+
247
+ token_per_patch = (self.img_size // self.patch_size) ** 2
248
+ num_image_tokens = [
249
+ token_per_patch + 3 + sum(token_per_patch + 1 for _ in range(1, num_patches))
250
+ for num_patches in num_image_patches
251
+ ] # Add +3 and +1 for BOI/EOI and image tile tokens
252
+ vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
253
+
254
+ return MultiModalData(**vision_data)
255
+
256
+
257
+ __all__ = ["AyaVisionProcessor"]
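A quick arithmetic cross-check of the counting formula in `_get_num_multimodal_tokens` against the prompt expansion sketched earlier, under the same assumed defaults (169 tokens per tile, 3 patches):

```python
token_per_patch = (364 // 28) ** 2    # 169
num_patches = 3
# Formula above: global tile + 3 marker tokens (BOI, EOI, TILE_GLOBAL), then +1 TILE_i marker per extra tile.
num_image_tokens = token_per_patch + 3 + sum(token_per_patch + 1 for _ in range(1, num_patches))
print(num_image_tokens)               # 512 = 3 * 169 patch tokens + 5 marker tokens
```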
venv/lib/python3.13/site-packages/transformers/models/barthez/__init__.py ADDED
@@ -0,0 +1,27 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import TYPE_CHECKING
15
+
16
+ from ...utils import _LazyModule
17
+ from ...utils.import_utils import define_import_structure
18
+
19
+
20
+ if TYPE_CHECKING:
21
+ from .tokenization_barthez import *
22
+ from .tokenization_barthez_fast import *
23
+ else:
24
+ import sys
25
+
26
+ _file = globals()["__file__"]
27
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
venv/lib/python3.13/site-packages/transformers/models/barthez/tokenization_barthez.py ADDED
@@ -0,0 +1,291 @@
1
+ # coding=utf-8
2
+ # Copyright 2020 Ecole Polytechnique and the HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License
15
+ """Tokenization classes for the BARThez model."""
16
+
17
+ import os
18
+ from shutil import copyfile
19
+ from typing import Any, Optional
20
+
21
+ import sentencepiece as spm
22
+
23
+ from ...tokenization_utils import AddedToken, PreTrainedTokenizer
24
+ from ...utils import logging
25
+ from ...utils.import_utils import requires
26
+
27
+
28
+ logger = logging.get_logger(__name__)
29
+
30
+ VOCAB_FILES_NAMES = {"vocab_file": "sentencepiece.bpe.model"}
31
+
32
+
33
+ SPIECE_UNDERLINE = "▁"
34
+
35
+ # TODO this class is useless. This is the most standard sentencepiece model. Let's find which one is closest and nuke this.
36
+
37
+
38
+ @requires(backends=("sentencepiece",))
39
+ class BarthezTokenizer(PreTrainedTokenizer):
40
+ """
41
+ Adapted from [`CamembertTokenizer`] and [`BartTokenizer`]. Construct a BARThez tokenizer. Based on
42
+ [SentencePiece](https://github.com/google/sentencepiece).
43
+
44
+ This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
45
+ this superclass for more information regarding those methods.
46
+
47
+ Args:
48
+ vocab_file (`str`):
49
+ [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that
50
+ contains the vocabulary necessary to instantiate a tokenizer.
51
+ bos_token (`str`, *optional*, defaults to `"<s>"`):
52
+ The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier token.
53
+
54
+ <Tip>
55
+
56
+ When building a sequence using special tokens, this is not the token that is used for the beginning of
57
+ sequence. The token used is the `cls_token`.
58
+
59
+ </Tip>
60
+
61
+ eos_token (`str`, *optional*, defaults to `"</s>"`):
62
+ The end of sequence token.
63
+
64
+ <Tip>
65
+
66
+ When building a sequence using special tokens, this is not the token that is used for the end of sequence.
67
+ The token used is the `sep_token`.
68
+
69
+ </Tip>
70
+
71
+ sep_token (`str`, *optional*, defaults to `"</s>"`):
72
+ The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
73
+ sequence classification or for a text and a question for question answering. It is also used as the last
74
+ token of a sequence built with special tokens.
75
+ cls_token (`str`, *optional*, defaults to `"<s>"`):
76
+ The classifier token which is used when doing sequence classification (classification of the whole sequence
77
+ instead of per-token classification). It is the first token of the sequence when built with special tokens.
78
+ unk_token (`str`, *optional*, defaults to `"<unk>"`):
79
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
80
+ token instead.
81
+ pad_token (`str`, *optional*, defaults to `"<pad>"`):
82
+ The token used for padding, for example when batching sequences of different lengths.
83
+ mask_token (`str`, *optional*, defaults to `"<mask>"`):
84
+ The token used for masking values. This is the token used when training this model with masked language
85
+ modeling. This is the token which the model will try to predict.
86
+ sp_model_kwargs (`dict`, *optional*):
87
+ Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
88
+ SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
89
+ to set:
90
+
91
+ - `enable_sampling`: Enable subword regularization.
92
+ - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.
93
+
94
+ - `nbest_size = {0,1}`: No sampling is performed.
95
+ - `nbest_size > 1`: samples from the nbest_size results.
96
+ - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice)
97
+ using forward-filtering-and-backward-sampling algorithm.
98
+
99
+ - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
100
+ BPE-dropout.
101
+
102
+ Attributes:
103
+ sp_model (`SentencePieceProcessor`):
104
+ The *SentencePiece* processor that is used for every conversion (string, tokens and IDs).
105
+ """
106
+
107
+ vocab_files_names = VOCAB_FILES_NAMES
108
+ model_input_names = ["input_ids", "attention_mask"]
109
+
110
+ def __init__(
111
+ self,
112
+ vocab_file,
113
+ bos_token="<s>",
114
+ eos_token="</s>",
115
+ sep_token="</s>",
116
+ cls_token="<s>",
117
+ unk_token="<unk>",
118
+ pad_token="<pad>",
119
+ mask_token="<mask>",
120
+ sp_model_kwargs: Optional[dict[str, Any]] = None,
121
+ **kwargs,
122
+ ) -> None:
123
+ # Mask token behaves like a normal word, i.e. includes the space before it. Will have normalized=False by default this way
124
+ mask_token = AddedToken(mask_token, lstrip=True, special=True) if isinstance(mask_token, str) else mask_token
125
+
126
+ self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
127
+
128
+ self.vocab_file = vocab_file
129
+ self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
130
+ self.sp_model.Load(str(vocab_file))
131
+ super().__init__(
132
+ bos_token=bos_token,
133
+ eos_token=eos_token,
134
+ unk_token=unk_token,
135
+ sep_token=sep_token,
136
+ cls_token=cls_token,
137
+ pad_token=pad_token,
138
+ mask_token=mask_token,
139
+ sp_model_kwargs=self.sp_model_kwargs,
140
+ **kwargs,
141
+ )
142
+
143
+ def build_inputs_with_special_tokens(
144
+ self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
145
+ ) -> list[int]:
146
+ """
147
+ Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
148
+ adding special tokens. A BARThez sequence has the following format:
149
+
150
+ - single sequence: `<s> X </s>`
151
+ - pair of sequences: `<s> A </s></s> B </s>`
152
+
153
+ Args:
154
+ token_ids_0 (`list[int]`):
155
+ List of IDs to which the special tokens will be added.
156
+ token_ids_1 (`list[int]`, *optional*):
157
+ Optional second list of IDs for sequence pairs.
158
+
159
+ Returns:
160
+ `list[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
161
+ """
162
+
163
+ if token_ids_1 is None:
164
+ return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
165
+ cls = [self.cls_token_id]
166
+ sep = [self.sep_token_id]
167
+ return cls + token_ids_0 + sep + sep + token_ids_1 + sep
168
+
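A small illustration of the layouts produced above, using hypothetical token ids (10, 11, 20 stand in for real vocabulary ids):

```python
# Single sequence  <s> X </s>:
#   build_inputs_with_special_tokens([10, 11])        -> [cls_id, 10, 11, sep_id]
# Pair of sequences <s> A </s></s> B </s>:
#   build_inputs_with_special_tokens([10, 11], [20])  -> [cls_id, 10, 11, sep_id, sep_id, 20, sep_id]
```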
169
+ def get_special_tokens_mask(
170
+ self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None, already_has_special_tokens: bool = False
171
+ ) -> list[int]:
172
+ """
173
+ Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
174
+ special tokens using the tokenizer `prepare_for_model` method.
175
+
176
+ Args:
177
+ token_ids_0 (`list[int]`):
178
+ List of IDs.
179
+ token_ids_1 (`list[int]`, *optional*):
180
+ Optional second list of IDs for sequence pairs.
181
+ already_has_special_tokens (`bool`, *optional*, defaults to `False`):
182
+ Whether or not the token list is already formatted with special tokens for the model.
183
+
184
+ Returns:
185
+ `list[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
186
+ """
187
+ if already_has_special_tokens:
188
+ return super().get_special_tokens_mask(
189
+ token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
190
+ )
191
+
192
+ if token_ids_1 is None:
193
+ return [1] + ([0] * len(token_ids_0)) + [1]
194
+ return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1]
195
+
196
+ def create_token_type_ids_from_sequences(
197
+ self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
198
+ ) -> list[int]:
199
+ """
200
+ Create a mask from the two sequences passed to be used in a sequence-pair classification task.
201
+
202
+ Args:
203
+ token_ids_0 (`list[int]`):
204
+ List of IDs.
205
+ token_ids_1 (`list[int]`, *optional*):
206
+ Optional second list of IDs for sequence pairs.
207
+
208
+ Returns:
209
+ `list[int]`: List of zeros.
210
+ """
211
+ sep = [self.sep_token_id]
212
+ cls = [self.cls_token_id]
213
+
214
+ if token_ids_1 is None:
215
+ return len(cls + token_ids_0 + sep) * [0]
216
+ return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]
217
+
218
+ @property
219
+ def vocab_size(self):
220
+ return len(self.sp_model)
221
+
222
+ def get_vocab(self):
223
+ vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
224
+ vocab.update(self.added_tokens_encoder)
225
+ return vocab
226
+
227
+ def _tokenize(self, text: str) -> list[str]:
228
+ return self.sp_model.encode(text, out_type=str)
229
+
230
+ def _convert_token_to_id(self, token):
231
+ """Converts a token (str) in an id using the vocab."""
232
+ return self.sp_model.PieceToId(token)
233
+
234
+ def _convert_id_to_token(self, index):
235
+ """Converts an index (integer) in a token (str) using the vocab."""
236
+ return self.sp_model.IdToPiece(index)
237
+
238
+ # Copied from transformers.models.albert.tokenization_albert.AlbertTokenizer.convert_tokens_to_string
239
+ def convert_tokens_to_string(self, tokens):
240
+ """Converts a sequence of tokens (string) in a single string."""
241
+ current_sub_tokens = []
242
+ out_string = ""
243
+ prev_is_special = False
244
+ for token in tokens:
245
+ # make sure that special tokens are not decoded using sentencepiece model
246
+ if token in self.all_special_tokens:
247
+ if not prev_is_special:
248
+ out_string += " "
249
+ out_string += self.sp_model.decode(current_sub_tokens) + token
250
+ prev_is_special = True
251
+ current_sub_tokens = []
252
+ else:
253
+ current_sub_tokens.append(token)
254
+ prev_is_special = False
255
+ out_string += self.sp_model.decode(current_sub_tokens)
256
+ return out_string.strip()
257
+
258
+ def __getstate__(self):
259
+ state = self.__dict__.copy()
260
+ state["sp_model"] = None
261
+ return state
262
+
263
+ def __setstate__(self, d):
264
+ self.__dict__ = d
265
+
266
+ # for backward compatibility
267
+ if not hasattr(self, "sp_model_kwargs"):
268
+ self.sp_model_kwargs = {}
269
+
270
+ self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
271
+ self.sp_model.Load(self.vocab_file)
272
+
273
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
274
+ if not os.path.isdir(save_directory):
275
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
276
+ return
277
+ out_vocab_file = os.path.join(
278
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
279
+ )
280
+
281
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
282
+ copyfile(self.vocab_file, out_vocab_file)
283
+ elif not os.path.isfile(self.vocab_file):
284
+ with open(out_vocab_file, "wb") as fi:
285
+ content_spiece_model = self.sp_model.serialized_model_proto()
286
+ fi.write(content_spiece_model)
287
+
288
+ return (out_vocab_file,)
289
+
290
+
291
+ __all__ = ["BarthezTokenizer"]
venv/lib/python3.13/site-packages/transformers/models/barthez/tokenization_barthez_fast.py ADDED
@@ -0,0 +1,193 @@
1
+ # coding=utf-8
2
+ # Copyright 2020 Ecole Polytechnique and the HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License
15
+ """Tokenization classes for the BARThez model."""
16
+
17
+ import os
18
+ from shutil import copyfile
19
+ from typing import Optional
20
+
21
+ from ...tokenization_utils import AddedToken
22
+ from ...tokenization_utils_fast import PreTrainedTokenizerFast
23
+ from ...utils import is_sentencepiece_available, logging
24
+
25
+
26
+ if is_sentencepiece_available():
27
+ from .tokenization_barthez import BarthezTokenizer
28
+ else:
29
+ BarthezTokenizer = None
30
+
31
+ logger = logging.get_logger(__name__)
32
+
33
+ VOCAB_FILES_NAMES = {"vocab_file": "sentencepiece.bpe.model", "tokenizer_file": "tokenizer.json"}
34
+
35
+
36
+ SPIECE_UNDERLINE = "▁"
37
+
38
+
39
+ class BarthezTokenizerFast(PreTrainedTokenizerFast):
40
+ """
41
+ Adapted from [`CamembertTokenizer`] and [`BartTokenizer`]. Construct a "fast" BARThez tokenizer. Based on
42
+ [SentencePiece](https://github.com/google/sentencepiece).
43
+
44
+ This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
45
+ refer to this superclass for more information regarding those methods.
46
+
47
+ Args:
48
+ vocab_file (`str`):
49
+ [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that
50
+ contains the vocabulary necessary to instantiate a tokenizer.
51
+ bos_token (`str`, *optional*, defaults to `"<s>"`):
52
+ The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier token.
53
+
54
+ <Tip>
55
+
56
+ When building a sequence using special tokens, this is not the token that is used for the beginning of
57
+ sequence. The token used is the `cls_token`.
58
+
59
+ </Tip>
60
+
61
+ eos_token (`str`, *optional*, defaults to `"</s>"`):
62
+ The end of sequence token.
63
+
64
+ <Tip>
65
+
66
+ When building a sequence using special tokens, this is not the token that is used for the end of sequence.
67
+ The token used is the `sep_token`.
68
+
69
+ </Tip>
70
+
71
+ sep_token (`str`, *optional*, defaults to `"</s>"`):
72
+ The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
73
+ sequence classification or for a text and a question for question answering. It is also used as the last
74
+ token of a sequence built with special tokens.
75
+ cls_token (`str`, *optional*, defaults to `"<s>"`):
76
+ The classifier token which is used when doing sequence classification (classification of the whole sequence
77
+ instead of per-token classification). It is the first token of the sequence when built with special tokens.
78
+ unk_token (`str`, *optional*, defaults to `"<unk>"`):
79
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
80
+ token instead.
81
+ pad_token (`str`, *optional*, defaults to `"<pad>"`):
82
+ The token used for padding, for example when batching sequences of different lengths.
83
+ mask_token (`str`, *optional*, defaults to `"<mask>"`):
84
+ The token used for masking values. This is the token used when training this model with masked language
85
+ modeling. This is the token which the model will try to predict.
86
+ additional_special_tokens (`list[str]`, *optional*, defaults to `["<s>NOTUSED", "</s>NOTUSED"]`):
87
+ Additional special tokens used by the tokenizer.
88
+ """
89
+
90
+ vocab_files_names = VOCAB_FILES_NAMES
91
+ model_input_names = ["input_ids", "attention_mask"]
92
+ slow_tokenizer_class = BarthezTokenizer
93
+
94
+ def __init__(
95
+ self,
96
+ vocab_file=None,
97
+ tokenizer_file=None,
98
+ bos_token="<s>",
99
+ eos_token="</s>",
100
+ sep_token="</s>",
101
+ cls_token="<s>",
102
+ unk_token="<unk>",
103
+ pad_token="<pad>",
104
+ mask_token="<mask>",
105
+ **kwargs,
106
+ ):
107
+ # Mask token behaves like a normal word, i.e. includes the space before it
108
+ mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token
109
+
110
+ super().__init__(
111
+ vocab_file,
112
+ tokenizer_file=tokenizer_file,
113
+ bos_token=bos_token,
114
+ eos_token=eos_token,
115
+ unk_token=unk_token,
116
+ sep_token=sep_token,
117
+ cls_token=cls_token,
118
+ pad_token=pad_token,
119
+ mask_token=mask_token,
120
+ **kwargs,
121
+ )
122
+
123
+ self.vocab_file = vocab_file
124
+
125
+ def build_inputs_with_special_tokens(
126
+ self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
127
+ ) -> list[int]:
128
+ """
129
+ Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating and
130
+ adding special tokens. A BARThez sequence has the following format:
131
+
132
+ - single sequence: `<s> X </s>`
133
+ - pair of sequences: `<s> A </s></s> B </s>`
134
+
135
+ Args:
136
+ token_ids_0 (`list[int]`):
137
+ List of IDs to which the special tokens will be added.
138
+ token_ids_1 (`list[int]`, *optional*):
139
+ Optional second list of IDs for sequence pairs.
140
+
141
+ Returns:
142
+ `list[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
143
+ """
144
+
145
+ if token_ids_1 is None:
146
+ return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
147
+ cls = [self.cls_token_id]
148
+ sep = [self.sep_token_id]
149
+ return cls + token_ids_0 + sep + sep + token_ids_1 + sep
150
+
151
+ def create_token_type_ids_from_sequences(
152
+ self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
153
+ ) -> list[int]:
154
+ """
155
+ Create a mask from the two sequences passed to be used in a sequence-pair classification task.
156
+
157
+ Args:
158
+ token_ids_0 (`list[int]`):
159
+ List of IDs.
160
+ token_ids_1 (`list[int]`, *optional*):
161
+ Optional second list of IDs for sequence pairs.
162
+
163
+ Returns:
164
+ `list[int]`: List of zeros.
165
+ """
166
+ sep = [self.sep_token_id]
167
+ cls = [self.cls_token_id]
168
+
169
+ if token_ids_1 is None:
170
+ return len(cls + token_ids_0 + sep) * [0]
171
+ return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]
172
+
173
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
174
+ if not self.can_save_slow_tokenizer:
175
+ raise ValueError(
176
+ "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
177
+ "tokenizer."
178
+ )
179
+
180
+ if not os.path.isdir(save_directory):
181
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
182
+ return
183
+ out_vocab_file = os.path.join(
184
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
185
+ )
186
+
187
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
188
+ copyfile(self.vocab_file, out_vocab_file)
189
+
190
+ return (out_vocab_file,)
191
+
192
+
193
+ __all__ = ["BarthezTokenizerFast"]
venv/lib/python3.13/site-packages/transformers/models/bert_japanese/__init__.py ADDED
@@ -0,0 +1,26 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import TYPE_CHECKING
15
+
16
+ from ...utils import _LazyModule
17
+ from ...utils.import_utils import define_import_structure
18
+
19
+
20
+ if TYPE_CHECKING:
21
+ from .tokenization_bert_japanese import *
22
+ else:
23
+ import sys
24
+
25
+ _file = globals()["__file__"]
26
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
venv/lib/python3.13/site-packages/transformers/models/bert_japanese/tokenization_bert_japanese.py ADDED
@@ -0,0 +1,952 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Tokenization classes."""
16
+
17
+ import collections
18
+ import copy
19
+ import os
20
+ import unicodedata
21
+ from typing import Any, Optional
22
+
23
+ from ...tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace
24
+ from ...utils import is_sentencepiece_available, is_sudachi_projection_available, logging
25
+
26
+
27
+ if is_sentencepiece_available():
28
+ import sentencepiece as spm
29
+ else:
30
+ spm = None
31
+
32
+ logger = logging.get_logger(__name__)
33
+
34
+ VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "spm_file": "spiece.model"}
35
+
36
+ SPIECE_UNDERLINE = "▁"
37
+
38
+
39
+ # Copied from transformers.models.bert.tokenization_bert.load_vocab
40
+ def load_vocab(vocab_file):
41
+ """Loads a vocabulary file into a dictionary."""
42
+ vocab = collections.OrderedDict()
43
+ with open(vocab_file, "r", encoding="utf-8") as reader:
44
+ tokens = reader.readlines()
45
+ for index, token in enumerate(tokens):
46
+ token = token.rstrip("\n")
47
+ vocab[token] = index
48
+ return vocab
49
+
50
+
51
+ # Copied from transformers.models.bert.tokenization_bert.whitespace_tokenize
52
+ def whitespace_tokenize(text):
53
+ """Runs basic whitespace cleaning and splitting on a piece of text."""
54
+ text = text.strip()
55
+ if not text:
56
+ return []
57
+ tokens = text.split()
58
+ return tokens
59
+
60
+
61
+ class BertJapaneseTokenizer(PreTrainedTokenizer):
62
+ r"""
63
+ Construct a BERT tokenizer for Japanese text.
64
+
65
+ This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer
66
+ to this superclass for more information regarding those methods.
67
+
68
+ Args:
69
+ vocab_file (`str`):
70
+ Path to a one-wordpiece-per-line vocabulary file.
71
+ spm_file (`str`, *optional*):
72
+ Path to [SentencePiece](https://github.com/google/sentencepiece) file (generally has a .spm or .model
73
+ extension) that contains the vocabulary.
74
+ do_lower_case (`bool`, *optional*, defaults to `False`):
75
+ Whether to lowercase the input. Only has an effect when `do_word_tokenize=True`.
76
+ do_word_tokenize (`bool`, *optional*, defaults to `True`):
77
+ Whether to do word tokenization.
78
+ do_subword_tokenize (`bool`, *optional*, defaults to `True`):
79
+ Whether to do subword tokenization.
80
+ word_tokenizer_type (`str`, *optional*, defaults to `"basic"`):
81
+ Type of word tokenizer. Choose from ["basic", "mecab", "sudachi", "jumanpp"].
82
+ subword_tokenizer_type (`str`, *optional*, defaults to `"wordpiece"`):
83
+ Type of subword tokenizer. Choose from ["wordpiece", "character", "sentencepiece",].
84
+ mecab_kwargs (`dict`, *optional*):
85
+ Dictionary passed to the `MecabTokenizer` constructor.
86
+ sudachi_kwargs (`dict`, *optional*):
87
+ Dictionary passed to the `SudachiTokenizer` constructor.
88
+ jumanpp_kwargs (`dict`, *optional*):
89
+ Dictionary passed to the `JumanppTokenizer` constructor.
90
+ """
91
+
92
+ vocab_files_names = VOCAB_FILES_NAMES
93
+
94
+ def __init__(
95
+ self,
96
+ vocab_file,
97
+ spm_file=None,
98
+ do_lower_case=False,
99
+ do_word_tokenize=True,
100
+ do_subword_tokenize=True,
101
+ word_tokenizer_type="basic",
102
+ subword_tokenizer_type="wordpiece",
103
+ never_split=None,
104
+ unk_token="[UNK]",
105
+ sep_token="[SEP]",
106
+ pad_token="[PAD]",
107
+ cls_token="[CLS]",
108
+ mask_token="[MASK]",
109
+ mecab_kwargs=None,
110
+ sudachi_kwargs=None,
111
+ jumanpp_kwargs=None,
112
+ **kwargs,
113
+ ):
114
+ if subword_tokenizer_type == "sentencepiece":
115
+ if not os.path.isfile(spm_file):
116
+ raise ValueError(
117
+ f"Can't find a vocabulary file at path '{spm_file}'. To load the vocabulary from a Google"
118
+ " pretrained model use `tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
119
+ )
120
+ self.spm_file = spm_file
121
+ else:
122
+ if not os.path.isfile(vocab_file):
123
+ raise ValueError(
124
+ f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google"
125
+ " pretrained model use `tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
126
+ )
127
+ self.vocab = load_vocab(vocab_file)
128
+ self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
129
+
130
+ self.do_word_tokenize = do_word_tokenize
131
+ self.word_tokenizer_type = word_tokenizer_type
132
+ self.lower_case = do_lower_case
133
+ self.never_split = never_split
134
+ self.mecab_kwargs = copy.deepcopy(mecab_kwargs)
135
+ self.sudachi_kwargs = copy.deepcopy(sudachi_kwargs)
136
+ self.jumanpp_kwargs = copy.deepcopy(jumanpp_kwargs)
137
+ if do_word_tokenize:
138
+ if word_tokenizer_type == "basic":
139
+ self.word_tokenizer = BasicTokenizer(
140
+ do_lower_case=do_lower_case, never_split=never_split, tokenize_chinese_chars=False
141
+ )
142
+ elif word_tokenizer_type == "mecab":
143
+ self.word_tokenizer = MecabTokenizer(
144
+ do_lower_case=do_lower_case, never_split=never_split, **(mecab_kwargs or {})
145
+ )
146
+ elif word_tokenizer_type == "sudachi":
147
+ self.word_tokenizer = SudachiTokenizer(
148
+ do_lower_case=do_lower_case, never_split=never_split, **(sudachi_kwargs or {})
149
+ )
150
+ elif word_tokenizer_type == "jumanpp":
151
+ self.word_tokenizer = JumanppTokenizer(
152
+ do_lower_case=do_lower_case, never_split=never_split, **(jumanpp_kwargs or {})
153
+ )
154
+ else:
155
+ raise ValueError(f"Invalid word_tokenizer_type '{word_tokenizer_type}' is specified.")
156
+
157
+ self.do_subword_tokenize = do_subword_tokenize
158
+ self.subword_tokenizer_type = subword_tokenizer_type
159
+ if do_subword_tokenize:
160
+ if subword_tokenizer_type == "wordpiece":
161
+ self.subword_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=str(unk_token))
162
+ elif subword_tokenizer_type == "character":
163
+ self.subword_tokenizer = CharacterTokenizer(vocab=self.vocab, unk_token=str(unk_token))
164
+ elif subword_tokenizer_type == "sentencepiece":
165
+ self.subword_tokenizer = SentencepieceTokenizer(vocab=self.spm_file, unk_token=str(unk_token))
166
+ else:
167
+ raise ValueError(f"Invalid subword_tokenizer_type '{subword_tokenizer_type}' is specified.")
168
+ super().__init__(
169
+ spm_file=spm_file,
170
+ unk_token=unk_token,
171
+ sep_token=sep_token,
172
+ pad_token=pad_token,
173
+ cls_token=cls_token,
174
+ mask_token=mask_token,
175
+ do_lower_case=do_lower_case,
176
+ do_word_tokenize=do_word_tokenize,
177
+ do_subword_tokenize=do_subword_tokenize,
178
+ word_tokenizer_type=word_tokenizer_type,
179
+ subword_tokenizer_type=subword_tokenizer_type,
180
+ never_split=never_split,
181
+ mecab_kwargs=mecab_kwargs,
182
+ sudachi_kwargs=sudachi_kwargs,
183
+ jumanpp_kwargs=jumanpp_kwargs,
184
+ **kwargs,
185
+ )
186
+
187
+ @property
188
+ def do_lower_case(self):
189
+ return self.lower_case
190
+
191
+ def __getstate__(self):
192
+ state = dict(self.__dict__)
193
+ if self.word_tokenizer_type in ["mecab", "sudachi", "jumanpp"]:
194
+ del state["word_tokenizer"]
195
+ return state
196
+
197
+ def __setstate__(self, state):
198
+ self.__dict__ = state
199
+ if self.word_tokenizer_type == "mecab":
200
+ self.word_tokenizer = MecabTokenizer(
201
+ do_lower_case=self.do_lower_case, never_split=self.never_split, **(self.mecab_kwargs or {})
202
+ )
203
+ elif self.word_tokenizer_type == "sudachi":
204
+ self.word_tokenizer = SudachiTokenizer(
205
+ do_lower_case=self.do_lower_case, never_split=self.never_split, **(self.sudachi_kwargs or {})
206
+ )
207
+ elif self.word_tokenizer_type == "jumanpp":
208
+ self.word_tokenizer = JumanppTokenizer(
209
+ do_lower_case=self.do_lower_case, never_split=self.never_split, **(self.jumanpp_kwargs or {})
210
+ )
211
+
212
+ def _tokenize(self, text):
213
+ if self.do_word_tokenize:
214
+ tokens = self.word_tokenizer.tokenize(text, never_split=self.all_special_tokens)
215
+ else:
216
+ tokens = [text]
217
+
218
+ if self.do_subword_tokenize:
219
+ split_tokens = [sub_token for token in tokens for sub_token in self.subword_tokenizer.tokenize(token)]
220
+ else:
221
+ split_tokens = tokens
222
+
223
+ return split_tokens
224
+
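+ # Hedged sketch of the two-stage pipeline above, assuming `tokenizer` was built with both
+ # stages enabled (e.g. MeCab words followed by WordPiece subwords):
+ #
+ #     words = tokenizer.word_tokenizer.tokenize("吾輩は猫である")
+ #     pieces = [p for w in words for p in tokenizer.subword_tokenizer.tokenize(w)]
+ #     # close to tokenizer._tokenize(...), which additionally protects special tokens via `never_split`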
225
+ @property
226
+ def vocab_size(self):
227
+ if self.subword_tokenizer_type == "sentencepiece":
228
+ return len(self.subword_tokenizer.sp_model)
229
+ return len(self.vocab)
230
+
231
+ def get_vocab(self):
232
+ if self.subword_tokenizer_type == "sentencepiece":
233
+ vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
234
+ vocab.update(self.added_tokens_encoder)
235
+ return vocab
236
+ return dict(self.vocab, **self.added_tokens_encoder)
237
+
238
+ def _convert_token_to_id(self, token):
239
+ """Converts a token (str) in an id using the vocab."""
240
+ if self.subword_tokenizer_type == "sentencepiece":
241
+ return self.subword_tokenizer.sp_model.PieceToId(token)
242
+ return self.vocab.get(token, self.vocab.get(self.unk_token))
243
+
244
+ def _convert_id_to_token(self, index):
245
+ """Converts an index (integer) in a token (str) using the vocab."""
246
+ if self.subword_tokenizer_type == "sentencepiece":
247
+ return self.subword_tokenizer.sp_model.IdToPiece(index)
248
+ return self.ids_to_tokens.get(index, self.unk_token)
249
+
250
+ def convert_tokens_to_string(self, tokens):
251
+ """Converts a sequence of tokens (string) in a single string."""
252
+ if self.subword_tokenizer_type == "sentencepiece":
253
+ return self.subword_tokenizer.sp_model.decode(tokens)
254
+ out_string = " ".join(tokens).replace(" ##", "").strip()
255
+ return out_string
256
+
257
+ # Copied from transformers.models.bert.tokenization_bert.BertTokenizer.build_inputs_with_special_tokens
258
+ def build_inputs_with_special_tokens(
259
+ self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
260
+ ) -> list[int]:
261
+ """
262
+ Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating and
263
+ adding special tokens. A BERT sequence has the following format:
264
+
265
+ - single sequence: `[CLS] X [SEP]`
266
+ - pair of sequences: `[CLS] A [SEP] B [SEP]`
267
+
268
+ Args:
269
+ token_ids_0 (`List[int]`):
270
+ List of IDs to which the special tokens will be added.
271
+ token_ids_1 (`List[int]`, *optional*):
272
+ Optional second list of IDs for sequence pairs.
273
+
274
+ Returns:
275
+ `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
276
+ """
277
+ if token_ids_1 is None:
278
+ return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
279
+ cls = [self.cls_token_id]
280
+ sep = [self.sep_token_id]
281
+ return cls + token_ids_0 + sep + token_ids_1 + sep
282
+
283
+ # Copied from transformers.models.bert.tokenization_bert.BertTokenizer.get_special_tokens_mask
284
+ def get_special_tokens_mask(
285
+ self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None, already_has_special_tokens: bool = False
286
+ ) -> list[int]:
287
+ """
288
+ Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
289
+ special tokens using the tokenizer `prepare_for_model` method.
290
+
291
+ Args:
292
+ token_ids_0 (`List[int]`):
293
+ List of IDs.
294
+ token_ids_1 (`List[int]`, *optional*):
295
+ Optional second list of IDs for sequence pairs.
296
+ already_has_special_tokens (`bool`, *optional*, defaults to `False`):
297
+ Whether or not the token list is already formatted with special tokens for the model.
298
+
299
+ Returns:
300
+ `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
301
+ """
302
+
303
+ if already_has_special_tokens:
304
+ return super().get_special_tokens_mask(
305
+ token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
306
+ )
307
+
308
+ if token_ids_1 is not None:
309
+ return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
310
+ return [1] + ([0] * len(token_ids_0)) + [1]
311
+
312
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
313
+ if os.path.isdir(save_directory):
314
+ if self.subword_tokenizer_type == "sentencepiece":
315
+ vocab_file = os.path.join(
316
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["spm_file"]
317
+ )
318
+ else:
319
+ vocab_file = os.path.join(
320
+ save_directory,
321
+ (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"],
322
+ )
323
+ else:
324
+ vocab_file = (filename_prefix + "-" if filename_prefix else "") + save_directory
325
+
326
+ if self.subword_tokenizer_type == "sentencepiece":
327
+ with open(vocab_file, "wb") as writer:
328
+ content_spiece_model = self.subword_tokenizer.sp_model.serialized_model_proto()
329
+ writer.write(content_spiece_model)
330
+ else:
331
+ with open(vocab_file, "w", encoding="utf-8") as writer:
332
+ index = 0
333
+ for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
334
+ if index != token_index:
335
+ logger.warning(
336
+ f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive."
337
+ " Please check that the vocabulary is not corrupted!"
338
+ )
339
+ index = token_index
340
+ writer.write(token + "\n")
341
+ index += 1
342
+ return (vocab_file,)
343
+
344
+
345
+ class MecabTokenizer:
346
+ """Runs basic tokenization with MeCab morphological parser."""
347
+
348
+ def __init__(
349
+ self,
350
+ do_lower_case=False,
351
+ never_split=None,
352
+ normalize_text=True,
353
+ mecab_dic: Optional[str] = "unidic_lite",
354
+ mecab_option: Optional[str] = None,
355
+ ):
356
+ """
357
+ Constructs a MecabTokenizer.
358
+
359
+ Args:
360
+ **do_lower_case**: (*optional*) boolean (default False)
361
+ Whether to lowercase the input.
362
+ **never_split**: (*optional*) list of str
363
+ Kept for backward compatibility purposes. Now implemented directly at the base class level (see
364
+ [`PreTrainedTokenizer.tokenize`]) List of tokens not to split.
365
+ **normalize_text**: (*optional*) boolean (default True)
366
+ Whether to apply unicode normalization to text before tokenization.
367
+ **mecab_dic**: (*optional*) string (default "unidic_lite")
368
+ Name of dictionary to be used for MeCab initialization. If you are using a system-installed dictionary,
369
+ set this option to `None` and modify *mecab_option*.
370
+ **mecab_option**: (*optional*) string
371
+ String passed to MeCab constructor.
372
+ """
373
+ self.do_lower_case = do_lower_case
374
+ self.never_split = never_split if never_split is not None else []
375
+ self.normalize_text = normalize_text
376
+
377
+ try:
378
+ import fugashi
379
+ except ModuleNotFoundError as error:
380
+ raise error.__class__(
381
+ "You need to install fugashi to use MecabTokenizer. "
382
+ "See https://pypi.org/project/fugashi/ for installation."
383
+ )
384
+
385
+ mecab_option = mecab_option or ""
386
+
387
+ if mecab_dic is not None:
388
+ if mecab_dic == "ipadic":
389
+ try:
390
+ import ipadic
391
+ except ModuleNotFoundError as error:
392
+ raise error.__class__(
393
+ "The ipadic dictionary is not installed. "
394
+ "See https://github.com/polm/ipadic-py for installation."
395
+ )
396
+
397
+ dic_dir = ipadic.DICDIR
398
+
399
+ elif mecab_dic == "unidic_lite":
400
+ try:
401
+ import unidic_lite
402
+ except ModuleNotFoundError as error:
403
+ raise error.__class__(
404
+ "The unidic_lite dictionary is not installed. "
405
+ "See https://github.com/polm/unidic-lite for installation."
406
+ )
407
+
408
+ dic_dir = unidic_lite.DICDIR
409
+
410
+ elif mecab_dic == "unidic":
411
+ try:
412
+ import unidic
413
+ except ModuleNotFoundError as error:
414
+ raise error.__class__(
415
+ "The unidic dictionary is not installed. "
416
+ "See https://github.com/polm/unidic-py for installation."
417
+ )
418
+
419
+ dic_dir = unidic.DICDIR
420
+ if not os.path.isdir(dic_dir):
421
+ raise RuntimeError(
422
+ "The unidic dictionary itself is not found. "
423
+ "See https://github.com/polm/unidic-py for installation."
424
+ )
425
+
426
+ else:
427
+ raise ValueError("Invalid mecab_dic is specified.")
428
+
429
+ mecabrc = os.path.join(dic_dir, "mecabrc")
430
+ mecab_option = f'-d "{dic_dir}" -r "{mecabrc}" ' + mecab_option
431
+
432
+ self.mecab = fugashi.GenericTagger(mecab_option)
433
+
434
+ def tokenize(self, text, never_split=None, **kwargs):
435
+ """Tokenizes a piece of text."""
436
+ if self.normalize_text:
437
+ text = unicodedata.normalize("NFKC", text)
438
+
439
+ never_split = self.never_split + (never_split if never_split is not None else [])
440
+ tokens = []
441
+
442
+ for word in self.mecab(text):
443
+ token = word.surface
444
+
445
+ if self.do_lower_case and token not in never_split:
446
+ token = token.lower()
447
+
448
+ tokens.append(token)
449
+
450
+ return tokens
451
+
452
+
453
+ class SudachiTokenizer:
454
+ """Runs basic tokenization with Sudachi morphological parser."""
455
+
456
+ def __init__(
457
+ self,
458
+ do_lower_case=False,
459
+ never_split=None,
460
+ normalize_text=True,
461
+ trim_whitespace=False,
462
+ sudachi_split_mode="A",
463
+ sudachi_config_path=None,
464
+ sudachi_resource_dir=None,
465
+ sudachi_dict_type="core",
466
+ sudachi_projection=None,
467
+ ):
468
+ """
469
+ Constructs a SudachiTokenizer.
470
+
471
+ Args:
472
+ **do_lower_case**: (*optional*) boolean (default False)
473
+ Whether to lowercase the input.
474
+ **never_split**: (*optional*) list of str
475
+ Kept for backward compatibility purposes. Now implemented directly at the base class level (see
476
+ [`PreTrainedTokenizer.tokenize`]) List of tokens not to split.
477
+ **normalize_text**: (*optional*) boolean (default True)
478
+ Whether to apply unicode normalization to text before tokenization.
479
+ **trim_whitespace**: (*optional*) boolean (default False)
480
+ Whether to trim all whitespace, tab, newline from tokens.
481
+ **sudachi_split_mode**: (*optional*) string
482
+ Split mode of sudachi, choose from `["A", "B", "C"]`.
483
+ **sudachi_config_path**: (*optional*) string
484
+ **sudachi_resource_dir**: (*optional*) string
485
+ **sudachi_dict_type**: (*optional*) string
486
+ dict type of sudachi, choose from `["small", "core", "full"]`.
487
+ **sudachi_projection**: (*optional*) string
488
+ Word projection mode of sudachi, choose from `["surface", "normalized", "reading", "dictionary", "dictionary_and_surface", "normalized_and_surface", "normalized_nouns"]`.
489
+ """
490
+
491
+ self.do_lower_case = do_lower_case
492
+ self.never_split = never_split if never_split is not None else []
493
+ self.normalize_text = normalize_text
494
+ self.trim_whitespace = trim_whitespace
495
+
496
+ try:
497
+ from sudachipy import dictionary, tokenizer
498
+ except ImportError:
499
+ raise ImportError(
500
+ "You need to install sudachipy to use SudachiTokenizer. "
501
+ "See https://github.com/WorksApplications/SudachiPy for installation."
502
+ )
503
+
504
+ if sudachi_split_mode == "A":
505
+ self.split_mode = tokenizer.Tokenizer.SplitMode.A
506
+ elif sudachi_split_mode == "B":
507
+ self.split_mode = tokenizer.Tokenizer.SplitMode.B
508
+ elif sudachi_split_mode == "C":
509
+ self.split_mode = tokenizer.Tokenizer.SplitMode.C
510
+ else:
511
+ raise ValueError("Invalid sudachi_split_mode is specified.")
512
+
513
+ self.projection = sudachi_projection
514
+
515
+ sudachi_dictionary = dictionary.Dictionary(
516
+ config_path=sudachi_config_path, resource_dir=sudachi_resource_dir, dict=sudachi_dict_type
517
+ )
518
+ if is_sudachi_projection_available():
519
+ self.sudachi = sudachi_dictionary.create(self.split_mode, projection=self.projection)
520
+ elif self.projection is not None:
521
+ raise ImportError("You need to install sudachipy>=0.6.8 to specify `projection` field in sudachi_kwargs.")
522
+ else:
523
+ self.sudachi = sudachi_dictionary.create(self.split_mode)
524
+
525
+ def tokenize(self, text, never_split=None, **kwargs):
526
+ """Tokenizes a piece of text."""
527
+ if self.normalize_text:
528
+ text = unicodedata.normalize("NFKC", text)
529
+
530
+ never_split = self.never_split + (never_split if never_split is not None else [])
531
+ tokens = []
532
+
533
+ for word in self.sudachi.tokenize(text):
534
+ token = word.surface()
535
+
536
+ if self.do_lower_case and token not in never_split:
537
+ token = token.lower()
538
+
539
+ if self.trim_whitespace:
540
+ if token.strip() == "":
541
+ continue
542
+ else:
543
+ token = token.strip()
544
+
545
+ tokens.append(token)
546
+
547
+ return tokens
548
+
549
+
550
+ class JumanppTokenizer:
551
+ """Runs basic tokenization with jumanpp morphological parser."""
552
+
553
+ def __init__(
554
+ self,
555
+ do_lower_case=False,
556
+ never_split=None,
557
+ normalize_text=True,
558
+ trim_whitespace=False,
559
+ ):
560
+ """
561
+ Constructs a JumanppTokenizer.
562
+
563
+ Args:
564
+ **do_lower_case**: (*optional*) boolean (default False)
565
+ Whether to lowercase the input.
566
+ **never_split**: (*optional*) list of str
567
+ Kept for backward compatibility purposes. Now implemented directly at the base class level (see
568
+ [`PreTrainedTokenizer.tokenize`]) List of tokens not to split.
569
+ **normalize_text**: (*optional*) boolean (default True)
570
+ Whether to apply unicode normalization to text before tokenization.
571
+ **trim_whitespace**: (*optional*) boolean (default False)
572
+ Whether to trim all whitespace, tab, newline from tokens.
573
+ """
574
+
575
+ self.do_lower_case = do_lower_case
576
+ self.never_split = never_split if never_split is not None else []
577
+ self.normalize_text = normalize_text
578
+ self.trim_whitespace = trim_whitespace
579
+
580
+ try:
581
+ import rhoknp
582
+ except ImportError:
583
+ raise ImportError(
584
+ "You need to install rhoknp to use JumanppTokenizer. "
585
+ "See https://github.com/ku-nlp/rhoknp for installation."
586
+ )
587
+
588
+ self.juman = rhoknp.Jumanpp()
589
+
590
+ def tokenize(self, text, never_split=None, **kwargs):
591
+ """Tokenizes a piece of text."""
592
+ if self.normalize_text:
593
+ text = unicodedata.normalize("NFKC", text)
594
+
595
+ text = text.strip()
596
+
597
+ never_split = self.never_split + (never_split if never_split is not None else [])
598
+ tokens = []
599
+
600
+ for mrph in self.juman.apply_to_sentence(text).morphemes:
601
+ token = mrph.text
602
+
603
+ if self.do_lower_case and token not in never_split:
604
+ token = token.lower()
605
+
606
+ if self.trim_whitespace:
607
+ if token.strip() == "":
608
+ continue
609
+ else:
610
+ token = token.strip()
611
+
612
+ tokens.append(token)
613
+
614
+ return tokens
615
+
616
+
617
+ class CharacterTokenizer:
618
+ """Runs Character tokenization."""
619
+
620
+ def __init__(self, vocab, unk_token, normalize_text=True):
621
+ """
622
+ Constructs a CharacterTokenizer.
623
+
624
+ Args:
625
+ **vocab**:
626
+ Vocabulary object.
627
+ **unk_token**: str
628
+ A special symbol for out-of-vocabulary token.
629
+ **normalize_text**: (`optional`) boolean (default True)
630
+ Whether to apply unicode normalization to text before tokenization.
631
+ """
632
+ self.vocab = vocab
633
+ self.unk_token = unk_token
634
+ self.normalize_text = normalize_text
635
+
636
+ def tokenize(self, text):
637
+ """
638
+ Tokenizes a piece of text into characters.
639
+
640
+ For example, `input = "apple""` will return as output `["a", "p", "p", "l", "e"]`.
641
+
642
+ Args:
643
+ text: A single token or whitespace separated tokens.
644
+ This should have already been passed through *BasicTokenizer*.
645
+
646
+ Returns:
647
+ A list of characters.
648
+ """
649
+ if self.normalize_text:
650
+ text = unicodedata.normalize("NFKC", text)
651
+
652
+ output_tokens = []
653
+ for char in text:
654
+ if char not in self.vocab:
655
+ output_tokens.append(self.unk_token)
656
+ continue
657
+
658
+ output_tokens.append(char)
659
+
660
+ return output_tokens
661
+
662
+
663
+ # Copied from transformers.models.bert.tokenization_bert.BasicTokenizer
664
+ class BasicTokenizer:
665
+ """
666
+ Constructs a BasicTokenizer that will run basic tokenization (punctuation splitting, lower casing, etc.).
667
+
668
+ Args:
669
+ do_lower_case (`bool`, *optional*, defaults to `True`):
670
+ Whether or not to lowercase the input when tokenizing.
671
+ never_split (`Iterable`, *optional*):
672
+ Collection of tokens which will never be split during tokenization. Only has an effect when
673
+ `do_basic_tokenize=True`
674
+ tokenize_chinese_chars (`bool`, *optional*, defaults to `True`):
675
+ Whether or not to tokenize Chinese characters.
676
+
677
+ This should likely be deactivated for Japanese (see this
678
+ [issue](https://github.com/huggingface/transformers/issues/328)).
679
+ strip_accents (`bool`, *optional*):
680
+ Whether or not to strip all accents. If this option is not specified, then it will be determined by the
681
+ value for `lowercase` (as in the original BERT).
682
+ do_split_on_punc (`bool`, *optional*, defaults to `True`):
683
+ In some instances we want to skip the basic punctuation splitting so that later tokenization can capture
684
+ the full context of the words, such as contractions.
685
+ """
686
+
687
+ def __init__(
688
+ self,
689
+ do_lower_case=True,
690
+ never_split=None,
691
+ tokenize_chinese_chars=True,
692
+ strip_accents=None,
693
+ do_split_on_punc=True,
694
+ ):
695
+ if never_split is None:
696
+ never_split = []
697
+ self.do_lower_case = do_lower_case
698
+ self.never_split = set(never_split)
699
+ self.tokenize_chinese_chars = tokenize_chinese_chars
700
+ self.strip_accents = strip_accents
701
+ self.do_split_on_punc = do_split_on_punc
702
+
703
+ def tokenize(self, text, never_split=None):
704
+ """
705
+ Basic Tokenization of a piece of text. For sub-word tokenization, see WordPieceTokenizer.
706
+
707
+ Args:
708
+ never_split (`List[str]`, *optional*)
709
+ Kept for backward compatibility purposes. Now implemented directly at the base class level (see
710
+ [`PreTrainedTokenizer.tokenize`]) List of token not to split.
711
+ """
712
+ # union() returns a new set by concatenating the two sets.
713
+ never_split = self.never_split.union(set(never_split)) if never_split else self.never_split
714
+ text = self._clean_text(text)
715
+
716
+ # This was added on November 1st, 2018 for the multilingual and Chinese
717
+ # models. This is also applied to the English models now, but it doesn't
718
+ # matter since the English models were not trained on any Chinese data
719
+ # and generally don't have any Chinese data in them (there are Chinese
720
+ # characters in the vocabulary because Wikipedia does have some Chinese
721
+ # words in the English Wikipedia.).
722
+ if self.tokenize_chinese_chars:
723
+ text = self._tokenize_chinese_chars(text)
724
+ # prevents treating the same character with different unicode codepoints as different characters
725
+ unicode_normalized_text = unicodedata.normalize("NFC", text)
726
+ orig_tokens = whitespace_tokenize(unicode_normalized_text)
727
+ split_tokens = []
728
+ for token in orig_tokens:
729
+ if token not in never_split:
730
+ if self.do_lower_case:
731
+ token = token.lower()
732
+ if self.strip_accents is not False:
733
+ token = self._run_strip_accents(token)
734
+ elif self.strip_accents:
735
+ token = self._run_strip_accents(token)
736
+ split_tokens.extend(self._run_split_on_punc(token, never_split))
737
+
738
+ output_tokens = whitespace_tokenize(" ".join(split_tokens))
739
+ return output_tokens
740
+
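+ # Hedged example of the combined lowercasing / accent stripping / punctuation splitting above:
+ #
+ #     BasicTokenizer(strip_accents=True).tokenize("Héllo, Wörld!")
+ #     # -> ["hello", ",", "world", "!"]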
741
+ def _run_strip_accents(self, text):
742
+ """Strips accents from a piece of text."""
743
+ text = unicodedata.normalize("NFD", text)
744
+ output = []
745
+ for char in text:
746
+ cat = unicodedata.category(char)
747
+ if cat == "Mn":
748
+ continue
749
+ output.append(char)
750
+ return "".join(output)
751
+
752
+ def _run_split_on_punc(self, text, never_split=None):
753
+ """Splits punctuation on a piece of text."""
754
+ if not self.do_split_on_punc or (never_split is not None and text in never_split):
755
+ return [text]
756
+ chars = list(text)
757
+ i = 0
758
+ start_new_word = True
759
+ output = []
760
+ while i < len(chars):
761
+ char = chars[i]
762
+ if _is_punctuation(char):
763
+ output.append([char])
764
+ start_new_word = True
765
+ else:
766
+ if start_new_word:
767
+ output.append([])
768
+ start_new_word = False
769
+ output[-1].append(char)
770
+ i += 1
771
+
772
+ return ["".join(x) for x in output]
773
+
774
+ def _tokenize_chinese_chars(self, text):
775
+ """Adds whitespace around any CJK character."""
776
+ output = []
777
+ for char in text:
778
+ cp = ord(char)
779
+ if self._is_chinese_char(cp):
780
+ output.append(" ")
781
+ output.append(char)
782
+ output.append(" ")
783
+ else:
784
+ output.append(char)
785
+ return "".join(output)
786
+
787
+ def _is_chinese_char(self, cp):
788
+ """Checks whether CP is the codepoint of a CJK character."""
789
+ # This defines a "chinese character" as anything in the CJK Unicode block:
790
+ # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
791
+ #
792
+ # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
793
+ # despite its name. The modern Korean Hangul alphabet is a different block,
794
+ # as is Japanese Hiragana and Katakana. Those alphabets are used to write
795
+ # space-separated words, so they are not treated specially and handled
796
+ # like all of the other languages.
797
+ if (
798
+ (cp >= 0x4E00 and cp <= 0x9FFF)
799
+ or (cp >= 0x3400 and cp <= 0x4DBF)
800
+ or (cp >= 0x20000 and cp <= 0x2A6DF)
801
+ or (cp >= 0x2A700 and cp <= 0x2B73F)
802
+ or (cp >= 0x2B740 and cp <= 0x2B81F)
803
+ or (cp >= 0x2B820 and cp <= 0x2CEAF)
804
+ or (cp >= 0xF900 and cp <= 0xFAFF)
805
+ or (cp >= 0x2F800 and cp <= 0x2FA1F)
806
+ ):
807
+ return True
808
+
809
+ return False
810
+
811
+ def _clean_text(self, text):
812
+ """Performs invalid character removal and whitespace cleanup on text."""
813
+ output = []
814
+ for char in text:
815
+ cp = ord(char)
816
+ if cp == 0 or cp == 0xFFFD or _is_control(char):
817
+ continue
818
+ if _is_whitespace(char):
819
+ output.append(" ")
820
+ else:
821
+ output.append(char)
822
+ return "".join(output)
823
+
824
+
825
+ # Copied from transformers.models.bert.tokenization_bert.WordpieceTokenizer
826
+ class WordpieceTokenizer:
827
+ """Runs WordPiece tokenization."""
828
+
829
+ def __init__(self, vocab, unk_token, max_input_chars_per_word=100):
830
+ self.vocab = vocab
831
+ self.unk_token = unk_token
832
+ self.max_input_chars_per_word = max_input_chars_per_word
833
+
834
+ def tokenize(self, text):
835
+ """
836
+ Tokenizes a piece of text into its word pieces. This uses a greedy longest-match-first algorithm to perform
837
+ tokenization using the given vocabulary.
838
+
839
+ For example, `input = "unaffable"` will return as output `["un", "##aff", "##able"]`.
840
+
841
+ Args:
842
+ text: A single token or whitespace separated tokens. This should have
843
+ already been passed through *BasicTokenizer*.
844
+
845
+ Returns:
846
+ A list of wordpiece tokens.
847
+ """
848
+
849
+ output_tokens = []
850
+ for token in whitespace_tokenize(text):
851
+ chars = list(token)
852
+ if len(chars) > self.max_input_chars_per_word:
853
+ output_tokens.append(self.unk_token)
854
+ continue
855
+
856
+ is_bad = False
857
+ start = 0
858
+ sub_tokens = []
859
+ while start < len(chars):
860
+ end = len(chars)
861
+ cur_substr = None
862
+ while start < end:
863
+ substr = "".join(chars[start:end])
864
+ if start > 0:
865
+ substr = "##" + substr
866
+ if substr in self.vocab:
867
+ cur_substr = substr
868
+ break
869
+ end -= 1
870
+ if cur_substr is None:
871
+ is_bad = True
872
+ break
873
+ sub_tokens.append(cur_substr)
874
+ start = end
875
+
876
+ if is_bad:
877
+ output_tokens.append(self.unk_token)
878
+ else:
879
+ output_tokens.extend(sub_tokens)
880
+ return output_tokens
881
+
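+ # Hedged toy example of the greedy longest-match-first behaviour documented above; the
+ # vocabulary is made up purely for illustration:
+ #
+ #     toy_vocab = {"un", "##aff", "##able"}
+ #     WordpieceTokenizer(vocab=toy_vocab, unk_token="[UNK]").tokenize("unaffable")
+ #     # -> ["un", "##aff", "##able"]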
882
+
883
+ class SentencepieceTokenizer:
884
+ """
885
+ Runs sentencepiece tokenization. Based on transformers.models.albert.tokenization_albert.AlbertTokenizer.
886
+ """
887
+
888
+ def __init__(
889
+ self,
890
+ vocab,
891
+ unk_token,
892
+ do_lower_case=False,
893
+ remove_space=True,
894
+ keep_accents=True,
895
+ sp_model_kwargs: Optional[dict[str, Any]] = None,
896
+ ):
897
+ self.vocab = vocab
898
+ self.unk_token = unk_token
899
+ self.do_lower_case = do_lower_case
900
+ self.remove_space = remove_space
901
+ self.keep_accents = keep_accents
902
+
903
+ self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
904
+ self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
905
+ self.sp_model.Load(self.vocab)
906
+
907
+ def preprocess_text(self, inputs):
908
+ if self.remove_space:
909
+ outputs = " ".join(inputs.strip().split())
910
+ else:
911
+ outputs = inputs
912
+ outputs = outputs.replace("``", '"').replace("''", '"')
913
+
914
+ if not self.keep_accents:
915
+ outputs = unicodedata.normalize("NFKD", outputs)
916
+ outputs = "".join([c for c in outputs if not unicodedata.combining(c)])
917
+ if self.do_lower_case:
918
+ outputs = outputs.lower()
919
+
920
+ return outputs
921
+
922
+ def tokenize(self, text):
923
+ """
924
+ Tokenizes text by sentencepiece. Based on [SentencePiece](https://github.com/google/sentencepiece).
925
+ Tokenization needs the given vocabulary.
926
+
927
+ Args:
928
+ text: A string to be tokenized.
929
+
930
+ Returns:
931
+ A list of sentencepiece tokens.
932
+ """
933
+ text = self.preprocess_text(text)
934
+ pieces = self.sp_model.encode(text, out_type=str)
935
+ new_pieces = []
936
+ for piece in pieces:
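+ # A piece such as "▁2020," (digit followed by a trailing comma) is re-segmented so that the
+ # comma becomes its own piece, mirroring AlbertTokenizer's behaviour.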
937
+ if len(piece) > 1 and piece[-1] == "," and piece[-2].isdigit():
938
+ cur_pieces = self.sp_model.EncodeAsPieces(piece[:-1].replace(SPIECE_UNDERLINE, ""))
939
+ if piece[0] != SPIECE_UNDERLINE and cur_pieces[0][0] == SPIECE_UNDERLINE:
940
+ if len(cur_pieces[0]) == 1:
941
+ cur_pieces = cur_pieces[1:]
942
+ else:
943
+ cur_pieces[0] = cur_pieces[0][1:]
944
+ cur_pieces.append(piece[-1])
945
+ new_pieces.extend(cur_pieces)
946
+ else:
947
+ new_pieces.append(piece)
948
+
949
+ return new_pieces
950
+
951
+
952
+ __all__ = ["BertJapaneseTokenizer", "CharacterTokenizer", "MecabTokenizer"]
venv/lib/python3.13/site-packages/transformers/models/bertweet/__init__.py ADDED
@@ -0,0 +1,26 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import TYPE_CHECKING
15
+
16
+ from ...utils import _LazyModule
17
+ from ...utils.import_utils import define_import_structure
18
+
19
+
20
+ if TYPE_CHECKING:
21
+ from .tokenization_bertweet import *
22
+ else:
23
+ import sys
24
+
25
+ _file = globals()["__file__"]
26
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
venv/lib/python3.13/site-packages/transformers/models/bertweet/tokenization_bertweet.py ADDED
@@ -0,0 +1,769 @@
1
+ # coding=utf-8
2
+ # Copyright (c) 2020, VinAI Research and the HuggingFace Inc. team.
3
+ # Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """Tokenization classes for BERTweet"""
17
+
18
+ import html
19
+ import os
20
+ import re
21
+ from shutil import copyfile
22
+ from typing import Optional
23
+
24
+ import regex
25
+
26
+ from ...tokenization_utils import PreTrainedTokenizer
27
+ from ...utils import logging
28
+
29
+
30
+ logger = logging.get_logger(__name__)
31
+
32
+ VOCAB_FILES_NAMES = {
33
+ "vocab_file": "vocab.txt",
34
+ "merges_file": "bpe.codes",
35
+ }
36
+
37
+
38
+ def get_pairs(word):
39
+ """
40
+ Return set of symbol pairs in a word.
41
+
42
+ Word is represented as tuple of symbols (symbols being variable-length strings).
43
+ """
44
+ pairs = set()
45
+ prev_char = word[0]
46
+ for char in word[1:]:
47
+ pairs.add((prev_char, char))
48
+ prev_char = char
49
+
50
+ pairs = set(pairs)
51
+ return pairs
52
+
53
+
54
+ class BertweetTokenizer(PreTrainedTokenizer):
55
+ """
56
+ Constructs a BERTweet tokenizer, using Byte-Pair-Encoding.
57
+
58
+ This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
59
+ this superclass for more information regarding those methods.
60
+
61
+ Args:
62
+ vocab_file (`str`):
63
+ Path to the vocabulary file.
64
+ merges_file (`str`):
65
+ Path to the merges file.
66
+ normalization (`bool`, *optional*, defaults to `False`):
67
+ Whether or not to apply a normalization preprocess.
68
+ bos_token (`str`, *optional*, defaults to `"<s>"`):
69
+ The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token.
70
+
71
+ <Tip>
72
+
73
+ When building a sequence using special tokens, this is not the token that is used for the beginning of
74
+ sequence. The token used is the `cls_token`.
75
+
76
+ </Tip>
77
+
78
+ eos_token (`str`, *optional*, defaults to `"</s>"`):
79
+ The end of sequence token.
80
+
81
+ <Tip>
82
+
83
+ When building a sequence using special tokens, this is not the token that is used for the end of sequence.
84
+ The token used is the `sep_token`.
85
+
86
+ </Tip>
87
+
88
+ sep_token (`str`, *optional*, defaults to `"</s>"`):
89
+ The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
90
+ sequence classification or for a text and a question for question answering. It is also used as the last
91
+ token of a sequence built with special tokens.
92
+ cls_token (`str`, *optional*, defaults to `"<s>"`):
93
+ The classifier token which is used when doing sequence classification (classification of the whole sequence
94
+ instead of per-token classification). It is the first token of the sequence when built with special tokens.
95
+ unk_token (`str`, *optional*, defaults to `"<unk>"`):
96
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
97
+ token instead.
98
+ pad_token (`str`, *optional*, defaults to `"<pad>"`):
99
+ The token used for padding, for example when batching sequences of different lengths.
100
+ mask_token (`str`, *optional*, defaults to `"<mask>"`):
101
+ The token used for masking values. This is the token used when training this model with masked language
102
+ modeling. This is the token which the model will try to predict.
103
+ """
104
+
105
+ vocab_files_names = VOCAB_FILES_NAMES
106
+
107
+ def __init__(
108
+ self,
109
+ vocab_file,
110
+ merges_file,
111
+ normalization=False,
112
+ bos_token="<s>",
113
+ eos_token="</s>",
114
+ sep_token="</s>",
115
+ cls_token="<s>",
116
+ unk_token="<unk>",
117
+ pad_token="<pad>",
118
+ mask_token="<mask>",
119
+ **kwargs,
120
+ ):
121
+ try:
122
+ from emoji import demojize
123
+
124
+ self.demojizer = demojize
125
+ except ImportError:
126
+ logger.warning(
127
+ "emoji is not installed, thus not converting emoticons or emojis into text. Install emoji: pip3"
128
+ " install emoji==0.6.0"
129
+ )
130
+ self.demojizer = None
131
+
132
+ self.vocab_file = vocab_file
133
+ self.merges_file = merges_file
134
+
135
+ self.encoder = {}
136
+ self.encoder[str(bos_token)] = 0
137
+ self.encoder[str(pad_token)] = 1
138
+ self.encoder[str(eos_token)] = 2
139
+ self.encoder[str(unk_token)] = 3
140
+
141
+ self.add_from_file(vocab_file)
142
+
143
+ self.decoder = {v: k for k, v in self.encoder.items()}
144
+
145
+ with open(merges_file, encoding="utf-8") as merges_handle:
146
+ merges = merges_handle.read().split("\n")[:-1]
147
+ merges = [tuple(merge.split()[:-1]) for merge in merges]
148
+ self.bpe_ranks = dict(zip(merges, range(len(merges))))
149
+ self.cache = {}
150
+
151
+ self.normalization = normalization
152
+ self.tweetPreprocessor = TweetTokenizer()
153
+ self.special_puncts = {"’": "'", "…": "..."}
154
+
155
+ super().__init__(
156
+ normalization=normalization,
157
+ bos_token=bos_token,
158
+ eos_token=eos_token,
159
+ sep_token=sep_token,
160
+ cls_token=cls_token,
161
+ unk_token=unk_token,
162
+ pad_token=pad_token,
163
+ mask_token=mask_token,
164
+ **kwargs,
165
+ )
166
+
167
+ def build_inputs_with_special_tokens(
168
+ self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
169
+ ) -> list[int]:
170
+ """
171
+ Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
172
+ adding special tokens. A BERTweet sequence has the following format:
173
+
174
+ - single sequence: `<s> X </s>`
175
+ - pair of sequences: `<s> A </s></s> B </s>`
176
+
177
+ Args:
178
+ token_ids_0 (`list[int]`):
179
+ List of IDs to which the special tokens will be added.
180
+ token_ids_1 (`list[int]`, *optional*):
181
+ Optional second list of IDs for sequence pairs.
182
+
183
+ Returns:
184
+ `list[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
185
+ """
186
+
187
+ if token_ids_1 is None:
188
+ return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
189
+ cls = [self.cls_token_id]
190
+ sep = [self.sep_token_id]
191
+ return cls + token_ids_0 + sep + sep + token_ids_1 + sep
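A minimal sketch of the id layout this method produces, assuming the ids 0 and 2 that `__init__` assigns to `<s>` and `</s>`; the ids for the A and B sequences are made-up placeholders:

```python
# Sketch of build_inputs_with_special_tokens output; 0/2 are the <s>/</s> ids
# wired up in __init__, the other ids are placeholders for illustration.
cls_id, sep_id = 0, 2
token_ids_0 = [10, 11, 12]  # sequence A
token_ids_1 = [20, 21]      # sequence B

single = [cls_id] + token_ids_0 + [sep_id]                                 # <s> A </s>
pair = [cls_id] + token_ids_0 + [sep_id, sep_id] + token_ids_1 + [sep_id]  # <s> A </s></s> B </s>

print(single)  # [0, 10, 11, 12, 2]
print(pair)    # [0, 10, 11, 12, 2, 2, 20, 21, 2]
```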
192
+
193
+ def get_special_tokens_mask(
194
+ self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None, already_has_special_tokens: bool = False
195
+ ) -> list[int]:
196
+ """
197
+ Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
198
+ special tokens using the tokenizer `prepare_for_model` method.
199
+
200
+ Args:
201
+ token_ids_0 (`list[int]`):
202
+ List of IDs.
203
+ token_ids_1 (`list[int]`, *optional*):
204
+ Optional second list of IDs for sequence pairs.
205
+ already_has_special_tokens (`bool`, *optional*, defaults to `False`):
206
+ Whether or not the token list is already formatted with special tokens for the model.
207
+
208
+ Returns:
209
+ `list[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
210
+ """
211
+
212
+ if already_has_special_tokens:
213
+ return super().get_special_tokens_mask(
214
+ token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
215
+ )
216
+
217
+ if token_ids_1 is None:
218
+ return [1] + ([0] * len(token_ids_0)) + [1]
219
+ return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1]
220
+
221
+ def create_token_type_ids_from_sequences(
222
+ self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
223
+ ) -> list[int]:
224
+ """
225
+ Create a mask from the two sequences passed to be used in a sequence-pair classification task. BERTweet does
226
+ not make use of token type ids, therefore a list of zeros is returned.
227
+
228
+ Args:
229
+ token_ids_0 (`list[int]`):
230
+ List of IDs.
231
+ token_ids_1 (`list[int]`, *optional*):
232
+ Optional second list of IDs for sequence pairs.
233
+
234
+ Returns:
235
+ `list[int]`: List of zeros.
236
+ """
237
+
238
+ sep = [self.sep_token_id]
239
+ cls = [self.cls_token_id]
240
+
241
+ if token_ids_1 is None:
242
+ return len(cls + token_ids_0 + sep) * [0]
243
+ return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]
244
+
245
+ @property
246
+ def vocab_size(self):
247
+ return len(self.encoder)
248
+
249
+ def get_vocab(self):
250
+ return dict(self.encoder, **self.added_tokens_encoder)
251
+
252
+ def bpe(self, token):
253
+ if token in self.cache:
254
+ return self.cache[token]
255
+ word = tuple(token)
256
+ word = tuple(list(word[:-1]) + [word[-1] + "</w>"])
257
+ pairs = get_pairs(word)
258
+
259
+ if not pairs:
260
+ return token
261
+
262
+ while True:
263
+ bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
264
+ if bigram not in self.bpe_ranks:
265
+ break
266
+ first, second = bigram
267
+ new_word = []
268
+ i = 0
269
+ while i < len(word):
270
+ try:
271
+ j = word.index(first, i)
272
+ except ValueError:
273
+ new_word.extend(word[i:])
274
+ break
275
+ else:
276
+ new_word.extend(word[i:j])
277
+ i = j
278
+
279
+ if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
280
+ new_word.append(first + second)
281
+ i += 2
282
+ else:
283
+ new_word.append(word[i])
284
+ i += 1
285
+ new_word = tuple(new_word)
286
+ word = new_word
287
+ if len(word) == 1:
288
+ break
289
+ else:
290
+ pairs = get_pairs(word)
291
+ word = "@@ ".join(word)
292
+ word = word[:-4]
293
+ self.cache[token] = word
294
+ return word
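A small standalone sketch of the final join step above: the sub-word pieces left after the merge loop are rendered with `@@` continuation markers and the trailing `</w>` marker is stripped (the symbols below are assumptions, not entries from a real `bpe.codes`):

```python
# Assumed merge-loop result for the token "tokenizer"; not taken from a real BPE model.
word = ("to", "ken", "izer</w>")
joined = "@@ ".join(word)     # "to@@ ken@@ izer</w>"
bpe_output = joined[:-4]      # drop the trailing "</w>"
print(bpe_output)             # "to@@ ken@@ izer"
print(bpe_output.split(" "))  # ['to@@', 'ken@@', 'izer'] -> what _tokenize extends split_tokens with
```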
295
+
296
+ def _tokenize(self, text):
297
+ """Tokenize a string."""
298
+ if self.normalization: # Perform Tweet normalization before performing BPE
299
+ text = self.normalizeTweet(text)
300
+
301
+ split_tokens = []
302
+ words = re.findall(r"\S+\n?", text)
303
+ for token in words:
304
+ split_tokens.extend(list(self.bpe(token).split(" ")))
305
+ return split_tokens
306
+
307
+ def normalizeTweet(self, tweet):
308
+ """
309
+ Normalize a raw Tweet
310
+ """
311
+ for punct in self.special_puncts:
312
+ tweet = tweet.replace(punct, self.special_puncts[punct])
313
+
314
+ tokens = self.tweetPreprocessor.tokenize(tweet)
315
+ normTweet = " ".join([self.normalizeToken(token) for token in tokens])
316
+
317
+ normTweet = (
318
+ normTweet.replace("cannot ", "can not ")
319
+ .replace("n't ", " n't ")
320
+ .replace("n 't ", " n't ")
321
+ .replace("ca n't", "can't")
322
+ .replace("ai n't", "ain't")
323
+ )
324
+ normTweet = (
325
+ normTweet.replace("'m ", " 'm ")
326
+ .replace("'re ", " 're ")
327
+ .replace("'s ", " 's ")
328
+ .replace("'ll ", " 'll ")
329
+ .replace("'d ", " 'd ")
330
+ .replace("'ve ", " 've ")
331
+ )
332
+ normTweet = (
333
+ normTweet.replace(" p . m .", " p.m.")
334
+ .replace(" p . m ", " p.m ")
335
+ .replace(" a . m .", " a.m.")
336
+ .replace(" a . m ", " a.m ")
337
+ )
338
+
339
+ return " ".join(normTweet.split())
340
+
341
+ def normalizeToken(self, token):
342
+ """
343
+ Normalize tokens in a Tweet
344
+ """
345
+ lowercased_token = token.lower()
346
+ if token.startswith("@"):
347
+ return "@USER"
348
+ elif lowercased_token.startswith("http") or lowercased_token.startswith("www"):
349
+ return "HTTPURL"
350
+ elif len(token) == 1:
351
+ if token in self.special_puncts:
352
+ return self.special_puncts[token]
353
+ if self.demojizer is not None:
354
+ return self.demojizer(token)
355
+ else:
356
+ return token
357
+ else:
358
+ return token
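A hedged usage sketch of the normalization path, assuming `tokenizer` is a `BertweetTokenizer` instance constructed with `normalization=True`; the exact output may vary with the installed `emoji` version:

```python
# Hypothetical instance; constructing one requires real vocab/merges files.
normalized = tokenizer.normalizeTweet("@remy I cannot wait, see https://example.com !!")
print(normalized)
# expected along the lines of:
# "@USER I can not wait , see HTTPURL ! !"
```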
359
+
360
+ def _convert_token_to_id(self, token):
361
+ """Converts a token (str) in an id using the vocab."""
362
+ return self.encoder.get(token, self.encoder.get(self.unk_token))
363
+
364
+ def _convert_id_to_token(self, index):
365
+ """Converts an index (integer) in a token (str) using the vocab."""
366
+ return self.decoder.get(index, self.unk_token)
367
+
368
+ def convert_tokens_to_string(self, tokens):
369
+ """Converts a sequence of tokens (string) in a single string."""
370
+ out_string = " ".join(tokens).replace("@@ ", "").strip()
371
+ return out_string
372
+
373
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
374
+ if not os.path.isdir(save_directory):
375
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
376
+ return
377
+ out_vocab_file = os.path.join(
378
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
379
+ )
380
+ out_merge_file = os.path.join(
381
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"]
382
+ )
383
+
384
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
385
+ copyfile(self.vocab_file, out_vocab_file)
386
+ elif not os.path.isfile(self.vocab_file):
387
+ with open(out_vocab_file, "wb") as fi:
388
+ content_spiece_model = self.sp_model.serialized_model_proto()
389
+ fi.write(content_spiece_model)
390
+
391
+ if os.path.abspath(self.merges_file) != os.path.abspath(out_merge_file):
392
+ copyfile(self.merges_file, out_merge_file)
393
+
394
+ return out_vocab_file, out_merge_file
395
+
396
+ # def decode(self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True):
397
+ # filtered_tokens = ' '.join(self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens))
398
+ # tokens_generated_so_far = re.sub('(@@ )', '', string=filtered_tokens)
399
+ # tokens_generated_so_far = re.sub('(@@ ?$)', '', string=tokens_generated_so_far)
400
+ # return ''.join(tokens_generated_so_far)
401
+
402
+ def add_from_file(self, f):
403
+ """
404
+ Loads a pre-existing dictionary from a text file and adds its symbols to this instance.
405
+ """
406
+ if isinstance(f, str):
407
+ try:
408
+ with open(f, "r", encoding="utf-8") as fd:
409
+ self.add_from_file(fd)
410
+ except FileNotFoundError as fnfe:
411
+ raise fnfe
412
+ except UnicodeError:
413
+ raise Exception(f"Incorrect encoding detected in {f}, please rebuild the dataset")
414
+ return
415
+
416
+ lines = f.readlines()
417
+ for lineTmp in lines:
418
+ line = lineTmp.strip()
419
+ idx = line.rfind(" ")
420
+ if idx == -1:
421
+ raise ValueError("Incorrect dictionary format, expected '<token> <cnt>'")
422
+ word = line[:idx]
423
+ self.encoder[word] = len(self.encoder)
424
+
425
+
426
+ # Natural Language Toolkit: Twitter Tokenizer
427
+ #
428
+ # Copyright (C) 2001-2020 NLTK Project
429
+ # Author: Christopher Potts <cgpotts@stanford.edu>
430
+ # Ewan Klein <ewan@inf.ed.ac.uk> (modifications)
431
+ # Pierpaolo Pantone <> (modifications)
432
+ # URL: http://nltk.org/
433
+ # For license information, see LICENSE.TXT
434
+ #
435
+
436
+
437
+ """
438
+ Twitter-aware tokenizer, designed to be flexible and easy to adapt to new domains and tasks. The basic logic is this:
439
+
440
+ 1. The tuple regex_strings defines a list of regular expression strings.
441
+
442
+ 2. The regex_strings strings are put, in order, into a compiled regular expression object called word_re.
443
+
444
+ 3. The tokenization is done by word_re.findall(s), where s is the user-supplied string, inside the tokenize() method of
445
+ the class Tokenizer.
446
+
447
+ 4. When instantiating Tokenizer objects, there are three options: preserve_case, reduce_len and strip_handles. By default, preserve_case is set to True. If it
448
+ is set to False, then the tokenizer will lowercase everything except for emoticons.
449
+
450
+ """
451
+
452
+
453
+ ######################################################################
454
+ #
455
+ # import regex # https://github.com/nltk/nltk/issues/2409
456
+ # import html
457
+ #
458
+ ######################################################################
459
+ # The following strings are components in the regular expression
460
+ # that is used for tokenizing. It's important that phone_number
461
+ # appears first in the final regex (since it can contain whitespace).
462
+ # It also could matter that tags comes after emoticons, due to the
463
+ # possibility of having text like
464
+ #
465
+ # <:| and some text >:)
466
+ #
467
+ # Most importantly, the final element should always be last, since it
468
+ # does a last ditch whitespace-based tokenization of whatever is left.
469
+
470
+ # ToDo: Update with http://en.wikipedia.org/wiki/List_of_emoticons ?
471
+
472
+ # This particular element is used in a couple ways, so we define it
473
+ # with a name:
474
+ # docstyle-ignore
475
+ EMOTICONS = r"""
476
+ (?:
477
+ [<>]?
478
+ [:;=8] # eyes
479
+ [\-o\*\']? # optional nose
480
+ [\)\]\(\[dDpP/\:\}\{@\|\\] # mouth
481
+ |
482
+ [\)\]\(\[dDpP/\:\}\{@\|\\] # mouth
483
+ [\-o\*\']? # optional nose
484
+ [:;=8] # eyes
485
+ [<>]?
486
+ |
487
+ <3 # heart
488
+ )"""
489
+
490
+ # URL pattern due to John Gruber, modified by Tom Winzig. See
491
+ # https://gist.github.com/winzig/8894715
492
+ # docstyle-ignore
493
+ URLS = r""" # Capture 1: entire matched URL
494
+ (?:
495
+ https?: # URL protocol and colon
496
+ (?:
497
+ /{1,3} # 1-3 slashes
498
+ | # or
499
+ [a-z0-9%] # Single letter or digit or '%'
500
+ # (Trying not to match e.g. "URI::Escape")
501
+ )
502
+ | # or
503
+ # looks like domain name followed by a slash:
504
+ [a-z0-9.\-]+[.]
505
+ (?:[a-z]{2,13})
506
+ /
507
+ )
508
+ (?: # One or more:
509
+ [^\s()<>{}\[\]]+ # Run of non-space, non-()<>{}[]
510
+ | # or
511
+ \([^\s()]*?\([^\s()]+\)[^\s()]*?\) # balanced parens, one level deep: (...(...)...)
512
+ |
513
+ \([^\s]+?\) # balanced parens, non-recursive: (...)
514
+ )+
515
+ (?: # End with:
516
+ \([^\s()]*?\([^\s()]+\)[^\s()]*?\) # balanced parens, one level deep: (...(...)...)
517
+ |
518
+ \([^\s]+?\) # balanced parens, non-recursive: (...)
519
+ | # or
520
+ [^\s`!()\[\]{};:'".,<>?«»“”‘’] # not a space or one of these punct chars
521
+ )
522
+ | # OR, the following to match naked domains:
523
+ (?:
524
+ (?<!@) # not preceded by a @, avoid matching foo@_gmail.com_
525
+ [a-z0-9]+
526
+ (?:[.\-][a-z0-9]+)*
527
+ [.]
528
+ (?:[a-z]{2,13})
529
+ \b
530
+ /?
531
+ (?!@) # not succeeded by a @,
532
+ # avoid matching "foo.na" in "foo.na@example.com"
533
+ )
534
+ """
535
+
536
+ # docstyle-ignore
537
+ # The components of the tokenizer:
538
+ REGEXPS = (
539
+ URLS,
540
+ # Phone numbers:
541
+ r"""
542
+ (?:
543
+ (?: # (international)
544
+ \+?[01]
545
+ [ *\-.\)]*
546
+ )?
547
+ (?: # (area code)
548
+ [\(]?
549
+ \d{3}
550
+ [ *\-.\)]*
551
+ )?
552
+ \d{3} # exchange
553
+ [ *\-.\)]*
554
+ \d{4} # base
555
+ )""",
556
+ # ASCII Emoticons
557
+ EMOTICONS,
558
+ # HTML tags:
559
+ r"""<[^>\s]+>""",
560
+ # ASCII Arrows
561
+ r"""[\-]+>|<[\-]+""",
562
+ # Twitter username:
563
+ r"""(?:@[\w_]+)""",
564
+ # Twitter hashtags:
565
+ r"""(?:\#+[\w_]+[\w\'_\-]*[\w_]+)""",
566
+ # email addresses
567
+ r"""[\w.+-]+@[\w-]+\.(?:[\w-]\.?)+[\w-]""",
568
+ # docstyle-ignore
569
+ # Remaining word types:
570
+ r"""
571
+ (?:[^\W\d_](?:[^\W\d_]|['\-_])+[^\W\d_]) # Words with apostrophes or dashes.
572
+ |
573
+ (?:[+\-]?\d+[,/.:-]\d+[+\-]?) # Numbers, including fractions, decimals.
574
+ |
575
+ (?:[\w_]+) # Words without apostrophes or dashes.
576
+ |
577
+ (?:\.(?:\s*\.){1,}) # Ellipsis dots.
578
+ |
579
+ (?:\S) # Everything else that isn't whitespace.
580
+ """,
581
+ )
582
+
583
+ ######################################################################
584
+ # This is the core tokenizing regex:
585
+
586
+ WORD_RE = regex.compile(r"""(%s)""" % "|".join(REGEXPS), regex.VERBOSE | regex.I | regex.UNICODE)
587
+
588
+ # WORD_RE performs poorly on these patterns:
589
+ HANG_RE = regex.compile(r"([^a-zA-Z0-9])\1{3,}")
590
+
591
+ # The emoticon string gets its own regex so that we can preserve case for
592
+ # them as needed:
593
+ EMOTICON_RE = regex.compile(EMOTICONS, regex.VERBOSE | regex.I | regex.UNICODE)
594
+
595
+ # These are for regularizing HTML entities to Unicode:
596
+ ENT_RE = regex.compile(r"&(#?(x?))([^&;\s]+);")
597
+
598
+
599
+ ######################################################################
600
+ # Functions for converting html entities
601
+ ######################################################################
602
+
603
+
604
+ def _str_to_unicode(text, encoding=None, errors="strict"):
605
+ if encoding is None:
606
+ encoding = "utf-8"
607
+ if isinstance(text, bytes):
608
+ return text.decode(encoding, errors)
609
+ return text
610
+
611
+
612
+ def _replace_html_entities(text, keep=(), remove_illegal=True, encoding="utf-8"):
613
+ """
614
+ Remove entities from text by converting them to their corresponding unicode character.
615
+
616
+ Args:
617
+ text:
618
+ A unicode string or a byte string encoded in the given *encoding* (which defaults to 'utf-8').
619
+ keep (list):
620
+ List of entity names which should not be replaced. This supports both numeric entities (`&#nnnn;` and
621
+ `&#hhhh;`) and named entities (such as `&nbsp;` or `&gt;`).
622
+ remove_illegal (bool):
623
+ If `True`, entities that can't be converted are removed. Otherwise, entities that can't be converted are
624
+ kept "as is".
625
+
626
+ Returns: A unicode string with the entities removed.
627
+
628
+ See https://github.com/scrapy/w3lib/blob/master/w3lib/html.py
629
+
630
+ Examples:
631
+
632
+ ```python
633
+ >>> from nltk.tokenize.casual import _replace_html_entities
634
+
635
+ >>> _replace_html_entities(b"Price: &pound;100")
636
+ 'Price: \\xa3100'
637
+
638
+ >>> print(_replace_html_entities(b"Price: &pound;100"))
639
+ Price: £100
640
+ ```"""
641
+
642
+ def _convert_entity(match):
643
+ entity_body = match.group(3)
644
+ if match.group(1):
645
+ try:
646
+ if match.group(2):
647
+ number = int(entity_body, 16)
648
+ else:
649
+ number = int(entity_body, 10)
650
+ # Numeric character references in the 80-9F range are typically
651
+ # interpreted by browsers as representing the characters mapped
652
+ # to bytes 80-9F in the Windows-1252 encoding. For more info
653
+ # see: https://en.wikipedia.org/wiki/ISO/IEC_8859-1#Similar_character_sets
654
+ if 0x80 <= number <= 0x9F:
655
+ return bytes((number,)).decode("cp1252")
656
+ except ValueError:
657
+ number = None
658
+ else:
659
+ if entity_body in keep:
660
+ return match.group(0)
661
+ else:
662
+ number = html.entities.name2codepoint.get(entity_body)
663
+ if number is not None:
664
+ try:
665
+ return chr(number)
666
+ except (ValueError, OverflowError):
667
+ pass
668
+
669
+ return "" if remove_illegal else match.group(0)
670
+
671
+ return ENT_RE.sub(_convert_entity, _str_to_unicode(text, encoding))
672
+
673
+
674
+ ######################################################################
675
+
676
+
677
+ class TweetTokenizer:
678
+ r"""
679
+ Examples:
680
+
681
+ ```python
682
+ >>> # Tokenizer for tweets.
683
+ >>> from nltk.tokenize import TweetTokenizer
684
+
685
+ >>> tknzr = TweetTokenizer()
686
+ >>> s0 = "This is a cooool #dummysmiley: :-) :-P <3 and some arrows < > -> <--"
687
+ >>> tknzr.tokenize(s0)
688
+ ['This', 'is', 'a', 'cooool', '#dummysmiley', ':', ':-)', ':-P', '<3', 'and', 'some', 'arrows', '<', '>', '->', '<--']
689
+
690
+ >>> # Examples using *strip_handles* and *reduce_len parameters*:
691
+ >>> tknzr = TweetTokenizer(strip_handles=True, reduce_len=True)
692
+ >>> s1 = "@remy: This is waaaaayyyy too much for you!!!!!!"
693
+ >>> tknzr.tokenize(s1)
694
+ [':', 'This', 'is', 'waaayyy', 'too', 'much', 'for', 'you', '!', '!', '!']
695
+ ```"""
696
+
697
+ def __init__(self, preserve_case=True, reduce_len=False, strip_handles=False):
698
+ self.preserve_case = preserve_case
699
+ self.reduce_len = reduce_len
700
+ self.strip_handles = strip_handles
701
+
702
+ def tokenize(self, text):
703
+ """
704
+ Args:
705
+ text: str
706
+
707
+ Returns: list(str) A tokenized list of strings; concatenating this list returns the original string if
708
+ `preserve_case=False`
709
+ """
710
+ # Fix HTML character entities:
711
+ text = _replace_html_entities(text)
712
+ # Remove username handles
713
+ if self.strip_handles:
714
+ text = remove_handles(text)
715
+ # Normalize word lengthening
716
+ if self.reduce_len:
717
+ text = reduce_lengthening(text)
718
+ # Shorten problematic sequences of characters
719
+ safe_text = HANG_RE.sub(r"\1\1\1", text)
720
+ # Tokenize:
721
+ words = WORD_RE.findall(safe_text)
722
+ # Possibly alter the case, but avoid changing emoticons like :D into :d:
723
+ if not self.preserve_case:
724
+ words = [x if EMOTICON_RE.search(x) else x.lower() for x in words]
725
+ return words
726
+
727
+
728
+ ######################################################################
729
+ # Normalization Functions
730
+ ######################################################################
731
+
732
+
733
+ def reduce_lengthening(text):
734
+ """
735
+ Replace repeated character sequences of length 3 or greater with sequences of length 3.
736
+ """
737
+ pattern = regex.compile(r"(.)\1{2,}")
738
+ return pattern.sub(r"\1\1\1", text)
739
+
740
+
741
+ def remove_handles(text):
742
+ """
743
+ Remove Twitter username handles from text.
744
+ """
745
+ pattern = regex.compile(
746
+ r"(?<![A-Za-z0-9_!@#\$%&*])@(([A-Za-z0-9_]){20}(?!@))|(?<![A-Za-z0-9_!@#\$%&*])@(([A-Za-z0-9_]){1,19})(?![A-Za-z0-9_]*@)"
747
+ )
748
+ # Substitute handles with ' ' to ensure that text on either side of removed handles are tokenized correctly
749
+ return pattern.sub(" ", text)
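A short sketch of the two normalization helpers above, assuming they are imported into the current session:

```python
print(reduce_lengthening("waaaaayyyy too coooool"))
# "waaayyy too coool"  -- runs of 3+ identical characters collapse to 3

print(remove_handles("@remy thanks @some_user !"))
# handles are replaced by a single space, leaving the surrounding text intact
```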
750
+
751
+
752
+ ######################################################################
753
+ # Tokenization Function
754
+ ######################################################################
755
+
756
+
757
+ def casual_tokenize(text, preserve_case=True, reduce_len=False, strip_handles=False):
758
+ """
759
+ Convenience function for wrapping the tokenizer.
760
+ """
761
+ return TweetTokenizer(preserve_case=preserve_case, reduce_len=reduce_len, strip_handles=strip_handles).tokenize(
762
+ text
763
+ )
764
+
765
+
766
+ ###############################################################################
767
+
768
+
769
+ __all__ = ["BertweetTokenizer"]
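An end-to-end usage sketch for this tokenizer, assuming network access and the public `vinai/bertweet-base` checkpoint (not part of this diff):

```python
from transformers import AutoTokenizer

# BERTweet only ships a slow (Python) tokenizer, so this resolves to BertweetTokenizer.
tokenizer = AutoTokenizer.from_pretrained("vinai/bertweet-base", normalization=True)

enc = tokenizer("@remy BERTweet is sooo cool!! https://example.com")
print(tokenizer.convert_ids_to_tokens(enc["input_ids"]))
print(enc["input_ids"])
```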
venv/lib/python3.13/site-packages/transformers/models/biogpt/__init__.py ADDED
@@ -0,0 +1,28 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import TYPE_CHECKING
15
+
16
+ from ...utils import _LazyModule
17
+ from ...utils.import_utils import define_import_structure
18
+
19
+
20
+ if TYPE_CHECKING:
21
+ from .configuration_biogpt import *
22
+ from .modeling_biogpt import *
23
+ from .tokenization_biogpt import *
24
+ else:
25
+ import sys
26
+
27
+ _file = globals()["__file__"]
28
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
venv/lib/python3.13/site-packages/transformers/models/biogpt/configuration_biogpt.py ADDED
@@ -0,0 +1,134 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 The HuggingFace Team and Microsoft Research AI4Science All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """BioGPT model configuration"""
16
+
17
+ from ...configuration_utils import PretrainedConfig
18
+ from ...utils import logging
19
+
20
+
21
+ logger = logging.get_logger(__name__)
22
+
23
+
24
+ class BioGptConfig(PretrainedConfig):
25
+ r"""
26
+ This is the configuration class to store the configuration of a [`BioGptModel`]. It is used to instantiate a
27
+ BioGPT model according to the specified arguments, defining the model architecture. Instantiating a configuration
28
+ with the defaults will yield a similar configuration to that of the BioGPT
29
+ [microsoft/biogpt](https://huggingface.co/microsoft/biogpt) architecture.
30
+
31
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
32
+ documentation from [`PretrainedConfig`] for more information.
33
+
34
+
35
+ Args:
36
+ vocab_size (`int`, *optional*, defaults to 42384):
37
+ Vocabulary size of the BioGPT model. Defines the number of different tokens that can be represented by the
38
+ `inputs_ids` passed when calling [`BioGptModel`].
39
+ hidden_size (`int`, *optional*, defaults to 1024):
40
+ Dimension of the encoder layers and the pooler layer.
41
+ num_hidden_layers (`int`, *optional*, defaults to 24):
42
+ Number of hidden layers in the Transformer encoder.
43
+ num_attention_heads (`int`, *optional*, defaults to 16):
44
+ Number of attention heads for each attention layer in the Transformer encoder.
45
+ intermediate_size (`int`, *optional*, defaults to 4096):
46
+ Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
47
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
48
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
49
+ `"relu"`, `"selu"` and `"gelu_new"` are supported.
50
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
51
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
52
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
53
+ The dropout ratio for the attention probabilities.
54
+ max_position_embeddings (`int`, *optional*, defaults to 1024):
55
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
56
+ just in case (e.g., 512 or 1024 or 2048).
57
+ initializer_range (`float`, *optional*, defaults to 0.02):
58
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
59
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12):
60
+ The epsilon used by the layer normalization layers.
61
+ scale_embedding (`bool`, *optional*, defaults to `True`):
62
+ Scale embeddings by dividing by sqrt(d_model).
63
+ use_cache (`bool`, *optional*, defaults to `True`):
64
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
65
+ relevant if `config.is_decoder=True`.
66
+ layerdrop (`float`, *optional*, defaults to 0.0):
67
+ Please refer to the paper about LayerDrop: https://huggingface.co/papers/1909.11556 for further details
68
+ activation_dropout (`float`, *optional*, defaults to 0.0):
69
+ The dropout ratio for activations inside the fully connected layer.
70
+ pad_token_id (`int`, *optional*, defaults to 1):
71
+ Padding token id.
72
+ bos_token_id (`int`, *optional*, defaults to 0):
73
+ Beginning of stream token id.
74
+ eos_token_id (`int`, *optional*, defaults to 2):
75
+ End of stream token id.
76
+
77
+ Example:
78
+
79
+ ```python
80
+ >>> from transformers import BioGptModel, BioGptConfig
81
+
82
+ >>> # Initializing a BioGPT microsoft/biogpt style configuration
83
+ >>> configuration = BioGptConfig()
84
+
85
+ >>> # Initializing a model from the microsoft/biogpt style configuration
86
+ >>> model = BioGptModel(configuration)
87
+
88
+ >>> # Accessing the model configuration
89
+ >>> configuration = model.config
90
+ ```"""
91
+
92
+ model_type = "biogpt"
93
+
94
+ def __init__(
95
+ self,
96
+ vocab_size=42384,
97
+ hidden_size=1024,
98
+ num_hidden_layers=24,
99
+ num_attention_heads=16,
100
+ intermediate_size=4096,
101
+ hidden_act="gelu",
102
+ hidden_dropout_prob=0.1,
103
+ attention_probs_dropout_prob=0.1,
104
+ max_position_embeddings=1024,
105
+ initializer_range=0.02,
106
+ layer_norm_eps=1e-12,
107
+ scale_embedding=True,
108
+ use_cache=True,
109
+ layerdrop=0.0,
110
+ activation_dropout=0.0,
111
+ pad_token_id=1,
112
+ bos_token_id=0,
113
+ eos_token_id=2,
114
+ **kwargs,
115
+ ):
116
+ self.vocab_size = vocab_size
117
+ self.max_position_embeddings = max_position_embeddings
118
+ self.hidden_size = hidden_size
119
+ self.num_hidden_layers = num_hidden_layers
120
+ self.num_attention_heads = num_attention_heads
121
+ self.intermediate_size = intermediate_size
122
+ self.hidden_act = hidden_act
123
+ self.hidden_dropout_prob = hidden_dropout_prob
124
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
125
+ self.initializer_range = initializer_range
126
+ self.layer_norm_eps = layer_norm_eps
127
+ self.scale_embedding = scale_embedding
128
+ self.use_cache = use_cache
129
+ self.layerdrop = layerdrop
130
+ self.activation_dropout = activation_dropout
131
+ super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
132
+
133
+
134
+ __all__ = ["BioGptConfig"]
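A hedged sketch of constructing a deliberately tiny configuration (all sizes below are arbitrary choices, e.g. for unit tests), in addition to the default-config example shown in the docstring above:

```python
from transformers import BioGptConfig, BioGptModel

tiny_config = BioGptConfig(
    vocab_size=1024,
    hidden_size=64,
    num_hidden_layers=2,
    num_attention_heads=4,
    intermediate_size=128,
    max_position_embeddings=128,
)
model = BioGptModel(tiny_config)
print(sum(p.numel() for p in model.parameters()))  # parameter count of the tiny model
```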
venv/lib/python3.13/site-packages/transformers/models/biogpt/modeling_biogpt.py ADDED
@@ -0,0 +1,967 @@
1
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
2
+ # This file was automatically generated from src/transformers/models/biogpt/modular_biogpt.py.
3
+ # Do NOT edit this file manually as any edits will be overwritten by the generation of
4
+ # the file from the modular. If any change should be done, please apply the change to the
5
+ # modular_biogpt.py file directly. One of our CI enforces this.
6
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
7
+ # coding=utf-8
8
+ # Copyright 2022 The HuggingFace Team and Microsoft Research AI4Science All rights reserved.
9
+ #
10
+ # Licensed under the Apache License, Version 2.0 (the "License");
11
+ # you may not use this file except in compliance with the License.
12
+ # You may obtain a copy of the License at
13
+ #
14
+ # http://www.apache.org/licenses/LICENSE-2.0
15
+ #
16
+ # Unless required by applicable law or agreed to in writing, software
17
+ # distributed under the License is distributed on an "AS IS" BASIS,
18
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19
+ # See the License for the specific language governing permissions and
20
+ # limitations under the License.
21
+
22
+ import math
23
+ from typing import Callable, Optional, Union
24
+
25
+ import torch
26
+ import torch.nn as nn
27
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
28
+
29
+ from ...activations import ACT2FN
30
+ from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
31
+ from ...generation import GenerationMixin
32
+ from ...modeling_attn_mask_utils import AttentionMaskConverter
33
+ from ...modeling_flash_attention_utils import FlashAttentionKwargs
34
+ from ...modeling_layers import GradientCheckpointingLayer
35
+ from ...modeling_outputs import (
36
+ BaseModelOutputWithPastAndCrossAttentions,
37
+ CausalLMOutputWithCrossAttentions,
38
+ SequenceClassifierOutputWithPast,
39
+ TokenClassifierOutput,
40
+ )
41
+ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
42
+ from ...processing_utils import Unpack
43
+ from ...utils import TransformersKwargs, auto_docstring, is_torch_flex_attn_available, logging
44
+ from ...utils.deprecation import deprecate_kwarg
45
+ from .configuration_biogpt import BioGptConfig
46
+
47
+
48
+ if is_torch_flex_attn_available():
49
+ from ...integrations.flex_attention import BlockMask, make_flex_block_causal_mask
50
+
51
+
52
+ logger = logging.get_logger(__name__)
53
+
54
+
55
+ class BioGptLearnedPositionalEmbedding(nn.Embedding):
56
+ """
57
+ This module learns positional embeddings up to a fixed maximum size.
58
+ """
59
+
60
+ def __init__(self, num_embeddings: int, embedding_dim: int):
61
+ # BIOGPT is set up so that if padding_idx is specified then offset the embedding ids by 2
62
+ # and adjust num_embeddings appropriately. Other models don't have this hack
63
+ self.offset = 2
64
+ super().__init__(num_embeddings + self.offset, embedding_dim)
65
+
66
+ def forward(
67
+ self,
68
+ attention_mask: torch.LongTensor,
69
+ past_key_values_length: int = 0,
70
+ position_ids: Optional[torch.LongTensor] = None,
71
+ ):
72
+ """`input_ids_shape` is expected to be [bsz x seqlen]."""
73
+
74
+ if position_ids is None:
75
+ position_ids = torch.cumsum(attention_mask, dim=1)
76
+ position_ids = (position_ids * attention_mask - 1).long()
77
+ # cut positions if `past_key_values_length` is > 0
78
+ position_ids = position_ids[:, past_key_values_length:]
79
+
80
+ return super().forward(position_ids + self.offset)
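A minimal sketch of the position-id computation above: positions follow the cumulative sum of the attention mask (so left padding does not shift real tokens), and the embedding lookup then adds `self.offset` (2):

```python
import torch

attention_mask = torch.tensor([[0, 0, 1, 1, 1]])  # two left-padding positions
position_ids = torch.cumsum(attention_mask, dim=1)
position_ids = (position_ids * attention_mask - 1).long()
print(position_ids)
# tensor([[-1, -1,  0,  1,  2]]) -> padding maps to -1, i.e. row 1 after the +2 offset
```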
81
+
82
+
83
+ class BioGptScaledWordEmbedding(nn.Embedding):
84
+ """
85
+ This module overrides nn.Embeddings' forward by multiplying with embeddings scale.
86
+ """
87
+
88
+ def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0):
89
+ super().__init__(num_embeddings, embedding_dim, padding_idx)
90
+ self.embed_scale = embed_scale
91
+
92
+ def forward(self, input_ids: torch.Tensor):
93
+ return super().forward(input_ids) * self.embed_scale
94
+
95
+
96
+ def eager_attention_forward(
97
+ module: nn.Module,
98
+ query: torch.Tensor,
99
+ key: torch.Tensor,
100
+ value: torch.Tensor,
101
+ attention_mask: Optional[torch.Tensor],
102
+ scaling: Optional[float] = None,
103
+ dropout: float = 0.0,
104
+ head_mask: Optional[torch.Tensor] = None,
105
+ **kwargs,
106
+ ):
107
+ if scaling is None:
108
+ scaling = query.size(-1) ** -0.5
109
+
110
+ attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
111
+ if attention_mask is not None:
112
+ attn_weights = attn_weights + attention_mask
113
+
114
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
115
+
116
+ if head_mask is not None:
117
+ attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)
118
+
119
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
120
+ attn_output = torch.matmul(attn_weights, value)
121
+ attn_output = attn_output.transpose(1, 2).contiguous()
122
+
123
+ return attn_output, attn_weights
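A quick shape-check sketch of `eager_attention_forward` with random tensors (batch=1, heads=2, seq=4, head_dim=8 are arbitrary assumptions; the function is assumed to be in scope):

```python
import torch
import torch.nn as nn

module = nn.Module()  # only .training is read by the function
module.eval()
q = torch.randn(1, 2, 4, 8)
k = torch.randn(1, 2, 4, 8)
v = torch.randn(1, 2, 4, 8)

out, weights = eager_attention_forward(module, q, k, v, attention_mask=None)
print(out.shape)      # torch.Size([1, 4, 2, 8]) -- transposed back to (batch, seq, heads, head_dim)
print(weights.shape)  # torch.Size([1, 2, 4, 4])
```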
124
+
125
+
126
+ class BioGptAttention(nn.Module):
127
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
128
+
129
+ def __init__(
130
+ self,
131
+ embed_dim: int,
132
+ num_heads: int,
133
+ dropout: float = 0.0,
134
+ is_decoder: bool = False,
135
+ bias: bool = True,
136
+ is_causal: bool = False,
137
+ config: Optional[BioGptConfig] = None,
138
+ layer_idx: Optional[int] = None,
139
+ ):
140
+ super().__init__()
141
+ self.embed_dim = embed_dim
142
+ self.num_heads = num_heads
143
+ self.dropout = dropout
144
+ self.head_dim = embed_dim // num_heads
145
+ self.config = config
146
+
147
+ if (self.head_dim * num_heads) != self.embed_dim:
148
+ raise ValueError(
149
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
150
+ f" and `num_heads`: {num_heads})."
151
+ )
152
+ self.scaling = self.head_dim**-0.5
153
+ self.is_decoder = is_decoder
154
+ self.is_causal = is_causal
155
+ self.layer_idx = layer_idx
156
+ if layer_idx is None and self.is_decoder:
157
+ logger.warning_once(
158
+ f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and "
159
+ "will lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
160
+ "when creating this class."
161
+ )
162
+
163
+ self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
164
+ self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
165
+ self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
166
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
167
+
168
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
169
+ def forward(
170
+ self,
171
+ hidden_states: torch.Tensor,
172
+ key_value_states: Optional[torch.Tensor] = None,
173
+ past_key_values: Optional[Cache] = None,
174
+ attention_mask: Optional[torch.Tensor] = None,
175
+ layer_head_mask: Optional[torch.Tensor] = None,
176
+ output_attentions: bool = False,
177
+ cache_position: Optional[torch.Tensor] = None,
178
+ # TODO: we need a refactor so that the different attention modules can get their specific kwargs
179
+ # ATM, we have mixed things encoder, decoder, and encoder-decoder attn
180
+ **kwargs: Unpack[FlashAttentionKwargs],
181
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
182
+ """Input shape: Batch x Time x Channel"""
183
+
184
+ # if key_value_states are provided this layer is used as a cross-attention layer
185
+ # for the decoder
186
+ is_cross_attention = key_value_states is not None
187
+
188
+ # determine input shapes
189
+ bsz, tgt_len = hidden_states.shape[:-1]
190
+ src_len = key_value_states.shape[1] if is_cross_attention else tgt_len
191
+
192
+ q_input_shape = (bsz, tgt_len, -1, self.head_dim)
193
+ kv_input_shape = (bsz, src_len, -1, self.head_dim)
194
+
195
+ # get query proj
196
+ query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2)
197
+
198
+ is_updated = False
199
+ if past_key_values is not None:
200
+ if isinstance(past_key_values, EncoderDecoderCache):
201
+ is_updated = past_key_values.is_updated.get(self.layer_idx)
202
+ if is_cross_attention:
203
+ # after the first generated id, we can subsequently re-use all key/value_states from cache
204
+ curr_past_key_value = past_key_values.cross_attention_cache
205
+ else:
206
+ curr_past_key_value = past_key_values.self_attention_cache
207
+ else:
208
+ curr_past_key_value = past_key_values
209
+
210
+ current_states = key_value_states if is_cross_attention else hidden_states
211
+ if is_cross_attention and past_key_values is not None and is_updated:
212
+ # reuse k,v, cross_attentions
213
+ key_states = curr_past_key_value.layers[self.layer_idx].keys
214
+ value_states = curr_past_key_value.layers[self.layer_idx].values
215
+ else:
216
+ key_states = self.k_proj(current_states)
217
+ value_states = self.v_proj(current_states)
218
+ key_states = key_states.view(*kv_input_shape).transpose(1, 2)
219
+ value_states = value_states.view(*kv_input_shape).transpose(1, 2)
220
+
221
+ if past_key_values is not None:
222
+ # save all key/value_states to cache to be re-used for fast auto-regressive generation
223
+ cache_position = cache_position if not is_cross_attention else None
224
+ key_states, value_states = curr_past_key_value.update(
225
+ key_states, value_states, self.layer_idx, {"cache_position": cache_position}
226
+ )
227
+ # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
228
+ if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache):
229
+ past_key_values.is_updated[self.layer_idx] = True
230
+
231
+ attention_interface: Callable = eager_attention_forward
232
+ if self.config._attn_implementation != "eager":
233
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
234
+
235
+ attn_output, attn_weights = attention_interface(
236
+ self,
237
+ query_states,
238
+ key_states,
239
+ value_states,
240
+ attention_mask,
241
+ dropout=0.0 if not self.training else self.dropout,
242
+ scaling=self.scaling,
243
+ output_attentions=output_attentions,
244
+ head_mask=layer_head_mask,
245
+ **kwargs,
246
+ )
247
+
248
+ attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
249
+ attn_output = self.out_proj(attn_output)
250
+
251
+ return attn_output, attn_weights
252
+
253
+
254
+ class BioGptDecoderLayer(GradientCheckpointingLayer):
255
+ def __init__(self, config: BioGptConfig, layer_idx: Optional[int] = None):
256
+ super().__init__()
257
+ self.embed_dim = config.hidden_size
258
+
259
+ self.self_attn = BioGptAttention(
260
+ embed_dim=self.embed_dim,
261
+ num_heads=config.num_attention_heads,
262
+ dropout=config.attention_probs_dropout_prob,
263
+ is_decoder=True,
264
+ is_causal=True,
265
+ config=config,
266
+ layer_idx=layer_idx,
267
+ )
268
+ self.dropout = config.hidden_dropout_prob
269
+ self.activation_fn = ACT2FN[config.hidden_act]
270
+ self.activation_dropout = config.activation_dropout
271
+
272
+ self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
273
+
274
+ self.fc1 = nn.Linear(self.embed_dim, config.intermediate_size)
275
+ self.fc2 = nn.Linear(config.intermediate_size, self.embed_dim)
276
+ self.final_layer_norm = nn.LayerNorm(self.embed_dim)
277
+
278
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
279
+ def forward(
280
+ self,
281
+ hidden_states: torch.Tensor,
282
+ attention_mask: Optional[torch.Tensor] = None,
283
+ layer_head_mask: Optional[torch.Tensor] = None,
284
+ past_key_values: Optional[Cache] = None,
285
+ output_attentions: Optional[bool] = False,
286
+ use_cache: Optional[bool] = True,
287
+ position_ids: Optional[torch.LongTensor] = None,
288
+ cache_position: Optional[torch.Tensor] = None,
289
+ **kwargs: Unpack[TransformersKwargs],
290
+ ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
291
+ """
292
+ Args:
293
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
294
+ attention_mask (`torch.FloatTensor`): attention mask of size
295
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
296
+ layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
297
+ `(encoder_attention_heads,)`.
298
+ past_key_values (`Cache`): cached past key and value projection states
299
+ output_attentions (`bool`, *optional*):
300
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
301
+ returned tensors for more detail.
302
+ use_cache (`bool`, *optional*):
303
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
304
+ (see `past_key_values`).
305
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
306
+ Indices depicting the position of the input sequence tokens in the sequence. It is used to update the
307
+ cache in the correct position and to infer the complete sequence length.
308
+ """
309
+ residual = hidden_states
310
+
311
+ hidden_states = self.self_attn_layer_norm(hidden_states)
312
+
313
+ # Self Attention
314
+ hidden_states, self_attn_weights = self.self_attn(
315
+ hidden_states=hidden_states,
316
+ past_key_values=past_key_values,
317
+ attention_mask=attention_mask,
318
+ layer_head_mask=layer_head_mask,
319
+ output_attentions=output_attentions,
320
+ position_ids=position_ids,
321
+ cache_position=cache_position,
322
+ **kwargs,
323
+ )
324
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
325
+ hidden_states = residual + hidden_states
326
+
327
+ # Fully Connected
328
+ residual = hidden_states
329
+ hidden_states = self.final_layer_norm(hidden_states)
330
+ hidden_states = self.fc1(hidden_states)
331
+ hidden_states = self.activation_fn(hidden_states)
332
+ hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
333
+ hidden_states = self.fc2(hidden_states)
334
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
335
+ hidden_states = residual + hidden_states
336
+
337
+ outputs = (hidden_states,)
338
+
339
+ if output_attentions:
340
+ outputs += (self_attn_weights,)
341
+
342
+ return outputs
343
+
344
+
345
+ @auto_docstring
346
+ class BioGptPreTrainedModel(PreTrainedModel):
347
+ config: BioGptConfig
348
+ base_model_prefix = "biogpt"
349
+ supports_gradient_checkpointing = True
350
+ _supports_flash_attn = True
351
+ _supports_sdpa = True
352
+ _supports_flex_attn = True
353
+
354
+ _can_compile_fullgraph = True
355
+
356
+ # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_causal_mask
357
+ def _update_causal_mask(
358
+ self,
359
+ attention_mask: Optional[Union[torch.Tensor, "BlockMask"]],
360
+ input_tensor: torch.Tensor,
361
+ cache_position: torch.Tensor,
362
+ past_key_values: Cache,
363
+ ):
364
+ if self.config._attn_implementation == "flex_attention":
365
+ if isinstance(attention_mask, torch.Tensor):
366
+ attention_mask = make_flex_block_causal_mask(attention_mask)
367
+ # Other attention flavors support in-built causal (when `mask is None`)
368
+ # while we need to create our specific block mask regardless
369
+ elif attention_mask is None:
370
+ attention_mask = make_flex_block_causal_mask(
371
+ torch.ones(
372
+ size=(input_tensor.shape[0], input_tensor.shape[1]),
373
+ device=attention_mask.device,
374
+ )
375
+ )
376
+ return attention_mask
377
+
378
+ if self.config._attn_implementation == "flash_attention_2":
379
+ if attention_mask is not None and (attention_mask == 0.0).any():
380
+ return attention_mask
381
+ return None
382
+
383
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
384
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
385
+ # to infer the attention mask.
386
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
387
+ using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
388
+
389
+ # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
390
+ if self.config._attn_implementation == "sdpa" and not using_compilable_cache:
391
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
392
+ attention_mask,
393
+ inputs_embeds=input_tensor,
394
+ past_key_values_length=past_seen_tokens,
395
+ is_training=self.training,
396
+ ):
397
+ return None
398
+
399
+ dtype = input_tensor.dtype
400
+ sequence_length = input_tensor.shape[1]
401
+ if using_compilable_cache:
402
+ target_length = past_key_values.get_max_cache_shape()
403
+ else:
404
+ target_length = (
405
+ attention_mask.shape[-1]
406
+ if isinstance(attention_mask, torch.Tensor)
407
+ else past_seen_tokens + sequence_length + 1
408
+ )
409
+
410
+ # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
411
+ causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
412
+ attention_mask,
413
+ sequence_length=sequence_length,
414
+ target_length=target_length,
415
+ dtype=dtype,
416
+ cache_position=cache_position,
417
+ batch_size=input_tensor.shape[0],
418
+ )
419
+
420
+ if (
421
+ self.config._attn_implementation == "sdpa"
422
+ and attention_mask is not None
423
+ and attention_mask.device.type in ["cuda", "xpu", "npu"]
424
+ ):
425
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
426
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
427
+ # Details: https://github.com/pytorch/pytorch/issues/110213
428
+ min_dtype = torch.finfo(dtype).min
429
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
430
+
431
+ return causal_mask
432
+
433
+ @staticmethod
434
+ # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
435
+ def _prepare_4d_causal_attention_mask_with_cache_position(
436
+ attention_mask: torch.Tensor,
437
+ sequence_length: int,
438
+ target_length: int,
439
+ dtype: torch.dtype,
440
+ cache_position: torch.Tensor,
441
+ batch_size: int,
442
+ **kwargs,
443
+ ):
444
+ """
445
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
446
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
447
+
448
+ Args:
449
+ attention_mask (`torch.Tensor`):
450
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
451
+ `(batch_size, 1, query_length, key_value_length)`.
452
+ sequence_length (`int`):
453
+ The sequence length being processed.
454
+ target_length (`int`):
455
+ The target length: when generating with static cache, the mask should be as long as the static cache,
456
+ to account for the 0 padding, the part of the cache that is not filled yet.
457
+ dtype (`torch.dtype`):
458
+ The dtype to use for the 4D attention mask.
459
+ cache_position (`torch.Tensor`):
460
+ Indices depicting the position of the input sequence tokens in the sequence.
461
+ batch_size (`torch.Tensor`):
462
+ Batch size.
463
+ """
464
+ if attention_mask is not None and attention_mask.dim() == 4:
465
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
466
+ causal_mask = attention_mask
467
+ else:
468
+ min_dtype = torch.finfo(dtype).min
469
+ causal_mask = torch.full(
470
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
471
+ )
472
+ if sequence_length != 1:
473
+ causal_mask = torch.triu(causal_mask, diagonal=1)
474
+ causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
475
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
476
+ if attention_mask is not None:
477
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
478
+ mask_length = attention_mask.shape[-1]
479
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
480
+ causal_mask.device
481
+ )
482
+ padding_mask = padding_mask == 0
483
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
484
+ padding_mask, min_dtype
485
+ )
486
+
487
+ return causal_mask
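A tiny standalone sketch of the mask this helper builds in the simplest case (sequence_length == target_length == 3, no padding mask): positions a query may not attend to are filled with the dtype's minimum value, allowed positions stay 0.0:

```python
import torch

dtype = torch.float32
min_dtype = torch.finfo(dtype).min
cache_position = torch.arange(3)

causal_mask = torch.full((3, 3), fill_value=min_dtype, dtype=dtype)
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(3) > cache_position.reshape(-1, 1)

print(causal_mask[None, None, :, :].shape)  # torch.Size([1, 1, 3, 3])
print(causal_mask)  # min_dtype above the diagonal, 0.0 on and below it
```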
488
+
489
+
490
+ @auto_docstring
491
+ class BioGptModel(BioGptPreTrainedModel):
492
+ def __init__(self, config: BioGptConfig):
493
+ super().__init__(config)
494
+ self.config = config
495
+ self.layerdrop = config.layerdrop
496
+ self.dropout = config.hidden_dropout_prob
497
+ self.embed_dim = config.hidden_size
498
+ self.padding_idx = config.pad_token_id
499
+ embed_scale = math.sqrt(config.hidden_size) if config.scale_embedding else 1.0
500
+
501
+ self.embed_tokens = BioGptScaledWordEmbedding(
502
+ config.vocab_size, self.embed_dim, self.padding_idx, embed_scale=embed_scale
503
+ )
504
+ self.embed_positions = BioGptLearnedPositionalEmbedding(config.max_position_embeddings, self.embed_dim)
505
+
506
+ self.layers = nn.ModuleList([BioGptDecoderLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)])
507
+ self.layer_norm = nn.LayerNorm(self.embed_dim)
508
+
509
+ self.gradient_checkpointing = False
510
+ # Initialize weights and apply final processing
511
+ self.post_init()
512
+
513
+ @auto_docstring
514
+ def forward(
515
+ self,
516
+ input_ids: Optional[torch.LongTensor] = None,
517
+ attention_mask: Optional[torch.FloatTensor] = None,
518
+ head_mask: Optional[torch.FloatTensor] = None,
519
+ inputs_embeds: Optional[torch.FloatTensor] = None,
520
+ past_key_values: Optional[Cache] = None,
521
+ use_cache: Optional[bool] = None,
522
+ position_ids: Optional[torch.LongTensor] = None,
523
+ output_attentions: Optional[bool] = None,
524
+ output_hidden_states: Optional[bool] = None,
525
+ return_dict: Optional[bool] = None,
526
+ cache_position: Optional[torch.Tensor] = None,
527
+ **kwargs: Unpack[TransformersKwargs],
528
+ ) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]:
529
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
530
+ output_hidden_states = (
531
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
532
+ )
533
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
534
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
535
+
536
+ # retrieve input_ids and inputs_embeds
537
+ if (input_ids is None) ^ (inputs_embeds is not None):
538
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
539
+ elif input_ids is not None:
540
+ input = input_ids
541
+ input_shape = input.shape
542
+ input_ids = input_ids.view(-1, input_shape[-1])
543
+ elif inputs_embeds is not None:
544
+ input_shape = inputs_embeds.size()[:-1]
545
+ input = inputs_embeds[:, :, -1]
546
+ else:
547
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
548
+
549
+ if inputs_embeds is None:
550
+ inputs_embeds = self.embed_tokens(input)
551
+
552
+ if self.gradient_checkpointing and self.training:
553
+ if use_cache:
554
+ logger.warning_once(
555
+ "`use_cache=True` is incompatible with gradient checkpointing`. Setting `use_cache=False`..."
556
+ )
557
+ use_cache = False
558
+
559
+ # initialize past_key_values
560
+ if use_cache and past_key_values is None:
561
+ past_key_values = DynamicCache(config=self.config)
562
+ if use_cache and isinstance(past_key_values, tuple):
563
+ logger.warning_once(
564
+ "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
565
+ "You should pass an instance of `DynamicCache` instead, e.g. "
566
+ "`past_key_values=DynamicCache.from_legacy_cache(past_key_values)`."
567
+ )
568
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
569
+
570
+ batch_size, seq_length = inputs_embeds.size()[:-1]
571
+ past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
572
+ if cache_position is None:
573
+ cache_position = torch.arange(
574
+ past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device
575
+ )
576
+
577
+ if attention_mask is None:
578
+ # required mask seq length can be calculated via length of past cache
579
+ mask_seq_length = past_key_values_length + seq_length
580
+ attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
581
+
582
+ self_attn_cache = past_key_values
583
+
584
+ causal_mask = self._update_causal_mask(
585
+ attention_mask,
586
+ inputs_embeds,
587
+ cache_position,
588
+ self_attn_cache,
589
+ )
590
+
591
+ # embed positions
592
+ if position_ids is None:
593
+ # position_ids = cache_position.unsqueeze(0)
594
+ position_ids = torch.cumsum(attention_mask, dim=1)
595
+ position_ids = (position_ids * attention_mask - 1).long()
596
+ # cut positions if `past_seen_tokens` is > 0
597
+ position_ids = position_ids[:, past_key_values_length:]
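# Note (not part of the diff): with attention_mask [[1, 1, 1, 0, 0]] the two lines above give
# cumsum [[1, 2, 3, 3, 3]] -> position_ids [[0, 1, 2, -1, -1]], i.e. real tokens get
# padding-aware positions and padded slots map to -1 before the slice on past length.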
598
+
599
+ positions = self.embed_positions(attention_mask, past_key_values_length, position_ids=position_ids)
600
+ hidden_states = inputs_embeds + positions
601
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
602
+
603
+ if self.gradient_checkpointing and self.training:
604
+ if use_cache:
605
+ logger.warning_once(
606
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
607
+ )
608
+ use_cache = False
609
+
610
+ all_hidden_states = () if output_hidden_states else None
611
+ all_self_attns = () if output_attentions else None
612
+ all_cross_attentions = None
613
+
614
+ for idx, decoder_layer in enumerate(self.layers):
615
+ # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
616
+ if output_hidden_states:
617
+ all_hidden_states += (hidden_states,)
618
+ if self.training:
619
+ dropout_probability = torch.rand([])
620
+ if dropout_probability < self.layerdrop:
621
+ continue
622
+
623
+ layer_outputs = decoder_layer(
624
+ hidden_states,
625
+ attention_mask=causal_mask,
626
+ layer_head_mask=(head_mask[idx] if head_mask is not None else None),
627
+ past_key_values=past_key_values,
628
+ output_attentions=output_attentions,
629
+ use_cache=use_cache,
630
+ position_ids=position_ids,
631
+ cache_position=cache_position,
632
+ **kwargs,
633
+ )
634
+
635
+ hidden_states = layer_outputs[0]
636
+
637
+ if output_attentions:
638
+ all_self_attns += (layer_outputs[1],)
639
+
640
+ # add hidden states from the last decoder layer
641
+ if output_hidden_states:
642
+ all_hidden_states += (hidden_states,)
643
+
644
+ hidden_states = self.layer_norm(hidden_states)
645
+
646
+ if not return_dict:
647
+ return tuple(
648
+ v
649
+ for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns, all_cross_attentions]
650
+ if v is not None
651
+ )
652
+ return BaseModelOutputWithPastAndCrossAttentions(
653
+ last_hidden_state=hidden_states,
654
+ past_key_values=past_key_values,
655
+ hidden_states=all_hidden_states,
656
+ attentions=all_self_attns,
657
+ cross_attentions=all_cross_attentions,
658
+ )
659
+
660
+
661
+ @auto_docstring(
662
+ custom_intro="""
663
+ BioGPT Model with a `language modeling` head on top for CLM fine-tuning.
664
+ """
665
+ )
666
+ class BioGptForCausalLM(BioGptPreTrainedModel, GenerationMixin):
667
+ _tied_weights_keys = ["output_projection.weight"]
668
+
669
+ def __init__(self, config):
670
+ super().__init__(config)
671
+
672
+ self.biogpt = BioGptModel(config)
673
+ self.output_projection = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
674
+
675
+ # Initialize weights and apply final processing
676
+ self.post_init()
677
+
678
+ def get_output_embeddings(self):
679
+ return self.output_projection
680
+
681
+ def set_output_embeddings(self, new_embeddings):
682
+ self.output_projection = new_embeddings
683
+
684
+ @auto_docstring
685
+ def forward(
686
+ self,
687
+ input_ids: Optional[torch.LongTensor] = None,
688
+ attention_mask: Optional[torch.FloatTensor] = None,
689
+ head_mask: Optional[torch.FloatTensor] = None,
690
+ inputs_embeds: Optional[torch.FloatTensor] = None,
691
+ past_key_values: Optional[Cache] = None,
692
+ labels: Optional[torch.LongTensor] = None,
693
+ use_cache: Optional[bool] = None,
694
+ position_ids: Optional[torch.LongTensor] = None,
695
+ output_attentions: Optional[bool] = None,
696
+ output_hidden_states: Optional[bool] = None,
697
+ return_dict: Optional[bool] = None,
698
+ cache_position: Optional[torch.Tensor] = None,
699
+ **kwargs: Unpack[TransformersKwargs],
700
+ ) -> Union[tuple, CausalLMOutputWithCrossAttentions]:
701
+ r"""
702
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
703
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
704
+ `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100`
705
+ are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`.
706
+ """
707
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
708
+
709
+ outputs = self.biogpt(
710
+ input_ids,
711
+ attention_mask=attention_mask,
712
+ head_mask=head_mask,
713
+ inputs_embeds=inputs_embeds,
714
+ past_key_values=past_key_values,
715
+ use_cache=use_cache,
716
+ position_ids=position_ids,
717
+ output_attentions=output_attentions,
718
+ output_hidden_states=output_hidden_states,
719
+ return_dict=return_dict,
720
+ cache_position=cache_position,
721
+ **kwargs,
722
+ )
723
+
724
+ sequence_output = outputs[0]
725
+ prediction_scores = self.output_projection(sequence_output)
726
+
727
+ lm_loss = None
728
+ if labels is not None:
729
+ lm_loss = self.loss_function(
730
+ prediction_scores,
731
+ labels,
732
+ vocab_size=self.config.vocab_size,
733
+ **kwargs,
734
+ )
735
+
736
+ if not return_dict:
737
+ output = (prediction_scores,) + outputs[1:]
738
+ return ((lm_loss,) + output) if lm_loss is not None else output
739
+
740
+ return CausalLMOutputWithCrossAttentions(
741
+ loss=lm_loss,
742
+ logits=prediction_scores,
743
+ past_key_values=outputs.past_key_values,
744
+ hidden_states=outputs.hidden_states,
745
+ attentions=outputs.attentions,
746
+ cross_attentions=outputs.cross_attentions,
747
+ )
748
+
749
+
750
+ @auto_docstring
751
+ class BioGptForTokenClassification(BioGptPreTrainedModel):
752
+ def __init__(self, config):
753
+ super().__init__(config)
754
+ self.num_labels = config.num_labels
755
+
756
+ self.biogpt = BioGptModel(config)
757
+ if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None:
758
+ classifier_dropout = config.classifier_dropout
759
+ else:
760
+ classifier_dropout = config.hidden_dropout_prob
761
+ self.dropout = nn.Dropout(classifier_dropout)
762
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
763
+
764
+ self.post_init()
765
+
766
+ @auto_docstring
767
+ def forward(
768
+ self,
769
+ input_ids: Optional[torch.LongTensor] = None,
770
+ token_type_ids: Optional[torch.LongTensor] = None,
771
+ attention_mask: Optional[torch.FloatTensor] = None,
772
+ head_mask: Optional[torch.FloatTensor] = None,
773
+ past_key_values: Optional[Cache] = None,
774
+ inputs_embeds: Optional[torch.FloatTensor] = None,
775
+ labels: Optional[torch.LongTensor] = None,
776
+ use_cache: Optional[bool] = None,
777
+ position_ids: Optional[torch.LongTensor] = None,
778
+ output_attentions: Optional[bool] = None,
779
+ output_hidden_states: Optional[bool] = None,
780
+ return_dict: Optional[bool] = None,
781
+ cache_position: Optional[torch.Tensor] = None,
782
+ ) -> Union[tuple, TokenClassifierOutput]:
783
+ r"""
784
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
785
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
786
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
787
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
788
+ """
789
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
790
+
791
+ transformer_outputs = self.biogpt(
792
+ input_ids,
793
+ past_key_values=past_key_values,
794
+ attention_mask=attention_mask,
795
+ head_mask=head_mask,
796
+ inputs_embeds=inputs_embeds,
797
+ use_cache=use_cache,
798
+ position_ids=position_ids,
799
+ output_attentions=output_attentions,
800
+ output_hidden_states=output_hidden_states,
801
+ return_dict=return_dict,
802
+ cache_position=cache_position,
803
+ )
804
+
805
+ hidden_states = transformer_outputs[0]
806
+ hidden_states = self.dropout(hidden_states)
807
+ logits = self.classifier(hidden_states)
808
+
809
+ loss = None
810
+ if labels is not None:
811
+ loss_fct = CrossEntropyLoss()
812
+ # Only keep active parts of the loss
813
+ if attention_mask is not None:
814
+ active_loss = attention_mask.view(-1) == 1
815
+ active_logits = logits.view(-1, self.num_labels)
816
+ active_labels = torch.where(
817
+ active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
818
+ )
819
+ loss = loss_fct(active_logits, active_labels)
820
+ else:
821
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
822
+
823
+ if not return_dict:
824
+ output = (logits,) + transformer_outputs[2:]
825
+ return ((loss,) + output) if loss is not None else output
826
+
827
+ return TokenClassifierOutput(
828
+ loss=loss,
829
+ logits=logits,
830
+ hidden_states=transformer_outputs.hidden_states,
831
+ attentions=transformer_outputs.attentions,
832
+ )
833
+
834
+
835
+ @auto_docstring(
836
+ custom_intro="""
837
+ The BioGpt Model transformer with a sequence classification head on top (linear layer).
838
+
839
+ [`BioGptForSequenceClassification`] uses the last token in order to do the classification, as other causal models
840
+ (e.g. GPT-2) do.
841
+
842
+ Since it does classification on the last token, it is required to know the position of the last token. If a
843
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
844
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
845
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
846
+ each row of the batch).
847
+ """
848
+ )
849
+ class BioGptForSequenceClassification(BioGptPreTrainedModel):
850
+ def __init__(self, config: BioGptConfig):
851
+ super().__init__(config)
852
+ self.num_labels = config.num_labels
853
+ self.biogpt = BioGptModel(config)
854
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
855
+
856
+ # Initialize weights and apply final processing
857
+ self.post_init()
858
+
859
+ @auto_docstring
860
+ def forward(
861
+ self,
862
+ input_ids: Optional[torch.LongTensor] = None,
863
+ attention_mask: Optional[torch.FloatTensor] = None,
864
+ head_mask: Optional[torch.FloatTensor] = None,
865
+ past_key_values: Optional[Cache] = None,
866
+ inputs_embeds: Optional[torch.FloatTensor] = None,
867
+ labels: Optional[torch.LongTensor] = None,
868
+ use_cache: Optional[bool] = None,
869
+ position_ids: Optional[torch.LongTensor] = None,
870
+ output_attentions: Optional[bool] = None,
871
+ output_hidden_states: Optional[bool] = None,
872
+ return_dict: Optional[bool] = None,
873
+ cache_position: Optional[torch.Tensor] = None,
874
+ logits_to_keep: Union[int, torch.Tensor] = 0,
875
+ ) -> Union[tuple, SequenceClassifierOutputWithPast]:
876
+ r"""
877
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
878
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
879
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
880
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
881
+ """
882
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
883
+
884
+ transformer_outputs = self.biogpt(
885
+ input_ids,
886
+ past_key_values=past_key_values,
887
+ attention_mask=attention_mask,
888
+ head_mask=head_mask,
889
+ inputs_embeds=inputs_embeds,
890
+ use_cache=use_cache,
891
+ position_ids=position_ids,
892
+ output_attentions=output_attentions,
893
+ output_hidden_states=output_hidden_states,
894
+ return_dict=return_dict,
895
+ cache_position=cache_position,
896
+ )
897
+ hidden_states = transformer_outputs[0]
898
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
899
+ logits = self.score(hidden_states[:, slice_indices, :])
900
+
901
+ if input_ids is not None:
902
+ batch_size, sequence_length = input_ids.shape[:2]
903
+ else:
904
+ batch_size, sequence_length = inputs_embeds.shape[:2]
905
+
906
+ if self.config.pad_token_id is None:
907
+ sequence_length = -1
908
+ else:
909
+ if input_ids is not None:
910
+ sequence_length = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
911
+ else:
912
+ sequence_length = -1
913
+ logger.warning_once(
914
+ f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
915
+ "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
916
+ )
917
+
918
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_length]
919
+
920
+ loss = None
921
+ if labels is not None:
922
+ if self.config.problem_type is None:
923
+ if self.num_labels == 1:
924
+ self.config.problem_type = "regression"
925
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
926
+ self.config.problem_type = "single_label_classification"
927
+ else:
928
+ self.config.problem_type = "multi_label_classification"
929
+
930
+ if self.config.problem_type == "regression":
931
+ loss_fct = MSELoss()
932
+ if self.num_labels == 1:
933
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
934
+ else:
935
+ loss = loss_fct(pooled_logits, labels)
936
+ elif self.config.problem_type == "single_label_classification":
937
+ loss_fct = CrossEntropyLoss()
938
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
939
+ elif self.config.problem_type == "multi_label_classification":
940
+ loss_fct = BCEWithLogitsLoss()
941
+ loss = loss_fct(pooled_logits, labels)
942
+ if not return_dict:
943
+ output = (pooled_logits,) + transformer_outputs[1:]
944
+ return ((loss,) + output) if loss is not None else output
945
+
946
+ return SequenceClassifierOutputWithPast(
947
+ loss=loss,
948
+ logits=pooled_logits,
949
+ past_key_values=transformer_outputs.past_key_values,
950
+ hidden_states=transformer_outputs.hidden_states,
951
+ attentions=transformer_outputs.attentions,
952
+ )
953
+
954
+ def get_input_embeddings(self):
955
+ return self.biogpt.embed_tokens
956
+
957
+ def set_input_embeddings(self, value):
958
+ self.biogpt.embed_tokens = value
959
+
960
+
961
+ __all__ = [
962
+ "BioGptForCausalLM",
963
+ "BioGptForTokenClassification",
964
+ "BioGptForSequenceClassification",
965
+ "BioGptModel",
966
+ "BioGptPreTrainedModel",
967
+ ]
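For quick reference, a minimal generation sketch against the causal-LM head added in this file (illustrative only, not part of the diff; the `microsoft/biogpt` checkpoint name and the prompt are assumptions, any compatible BioGPT checkpoint works the same way):

import torch
from transformers import AutoTokenizer, BioGptForCausalLM

# Load an assumed public BioGPT checkpoint and its tokenizer.
tokenizer = AutoTokenizer.from_pretrained("microsoft/biogpt")
model = BioGptForCausalLM.from_pretrained("microsoft/biogpt")

inputs = tokenizer("COVID-19 is", return_tensors="pt")
with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))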
venv/lib/python3.13/site-packages/transformers/models/biogpt/modular_biogpt.py ADDED
@@ -0,0 +1,789 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 The HuggingFace Team and Microsoft Research AI4Science All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch BioGPT model."""
16
+
17
+ import math
18
+ from typing import Optional, Union
19
+
20
+ import torch
21
+ import torch.nn as nn
22
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
23
+
24
+ from ...activations import ACT2FN
25
+ from ...cache_utils import Cache, DynamicCache
26
+ from ...generation import GenerationMixin
27
+ from ...modeling_attn_mask_utils import (
28
+ AttentionMaskConverter,
29
+ )
30
+ from ...modeling_outputs import (
31
+ BaseModelOutputWithPastAndCrossAttentions,
32
+ CausalLMOutputWithCrossAttentions,
33
+ SequenceClassifierOutputWithPast,
34
+ TokenClassifierOutput,
35
+ )
36
+ from ...modeling_utils import PreTrainedModel
37
+ from ...processing_utils import Unpack
38
+ from ...utils import (
39
+ TransformersKwargs,
40
+ auto_docstring,
41
+ is_torch_flex_attn_available,
42
+ logger,
43
+ )
44
+ from ...utils.deprecation import deprecate_kwarg
45
+ from ..bart.modeling_bart import (
46
+ BartAttention,
47
+ BartDecoderLayer,
48
+ BartScaledWordEmbedding,
49
+ )
50
+ from ..opt.modeling_opt import OPTLearnedPositionalEmbedding
51
+ from .configuration_biogpt import BioGptConfig
52
+
53
+
54
+ if is_torch_flex_attn_available():
55
+ from ...integrations.flex_attention import BlockMask, make_flex_block_causal_mask
56
+
57
+
58
+ class BioGptLearnedPositionalEmbedding(OPTLearnedPositionalEmbedding):
59
+ def forward(
60
+ self,
61
+ attention_mask: torch.LongTensor,
62
+ past_key_values_length: int = 0,
63
+ position_ids: Optional[torch.LongTensor] = None,
64
+ ):
65
+ """`input_ids_shape` is expected to be [bsz x seqlen]."""
66
+ return super().forward(attention_mask, past_key_values_length, position_ids)
67
+
68
+
69
+ class BioGptScaledWordEmbedding(BartScaledWordEmbedding):
70
+ pass
71
+
72
+
73
+ class BioGptAttention(BartAttention):
74
+ pass
75
+
76
+
77
+ class BioGptDecoderLayer(BartDecoderLayer):
78
+ def __init__(self, config: BioGptConfig, layer_idx: Optional[int] = None):
79
+ super().__init__(config)
80
+ self.embed_dim = config.hidden_size
81
+
82
+ self.self_attn = BioGptAttention(
83
+ embed_dim=self.embed_dim,
84
+ num_heads=config.num_attention_heads,
85
+ dropout=config.attention_probs_dropout_prob,
86
+ is_decoder=True,
87
+ is_causal=True,
88
+ config=config,
89
+ layer_idx=layer_idx,
90
+ )
91
+ self.dropout = config.hidden_dropout_prob
92
+ self.activation_fn = ACT2FN[config.hidden_act]
93
+
94
+ self.fc1 = nn.Linear(self.embed_dim, config.intermediate_size)
95
+ self.fc2 = nn.Linear(config.intermediate_size, self.embed_dim)
96
+
97
+ del self.encoder_attn
98
+ del self.encoder_attn_layer_norm
99
+
100
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
101
+ def forward(
102
+ self,
103
+ hidden_states: torch.Tensor,
104
+ attention_mask: Optional[torch.Tensor] = None,
105
+ layer_head_mask: Optional[torch.Tensor] = None,
106
+ past_key_values: Optional[Cache] = None,
107
+ output_attentions: Optional[bool] = False,
108
+ use_cache: Optional[bool] = True,
109
+ position_ids: Optional[torch.LongTensor] = None,
110
+ cache_position: Optional[torch.Tensor] = None,
111
+ **kwargs: Unpack[TransformersKwargs],
112
+ ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
113
+ """
114
+ Args:
115
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
116
+ attention_mask (`torch.FloatTensor`): attention mask of size
117
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
118
+ layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
119
+ `(encoder_attention_heads,)`.
120
+ past_key_values (`Cache`): cached past key and value projection states
121
+ output_attentions (`bool`, *optional*):
122
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
123
+ returned tensors for more detail.
124
+ use_cache (`bool`, *optional*):
125
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
126
+ (see `past_key_values`).
127
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
128
+ Indices depicting the position of the input sequence tokens in the sequence. It is used to update the
129
+ cache in the correct position and to infer the complete sequence length.
130
+ """
131
+ residual = hidden_states
132
+
133
+ hidden_states = self.self_attn_layer_norm(hidden_states)
134
+
135
+ # Self Attention
136
+ hidden_states, self_attn_weights = self.self_attn(
137
+ hidden_states=hidden_states,
138
+ past_key_values=past_key_values,
139
+ attention_mask=attention_mask,
140
+ layer_head_mask=layer_head_mask,
141
+ output_attentions=output_attentions,
142
+ position_ids=position_ids,
143
+ cache_position=cache_position,
144
+ **kwargs,
145
+ )
146
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
147
+ hidden_states = residual + hidden_states
148
+
149
+ # Fully Connected
150
+ residual = hidden_states
151
+ hidden_states = self.final_layer_norm(hidden_states)
152
+ hidden_states = self.fc1(hidden_states)
153
+ hidden_states = self.activation_fn(hidden_states)
154
+ hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
155
+ hidden_states = self.fc2(hidden_states)
156
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
157
+ hidden_states = residual + hidden_states
158
+
159
+ outputs = (hidden_states,)
160
+
161
+ if output_attentions:
162
+ outputs += (self_attn_weights,)
163
+
164
+ return outputs
165
+
166
+
167
+ @auto_docstring
168
+ class BioGptPreTrainedModel(PreTrainedModel):
169
+ config: BioGptConfig
170
+ base_model_prefix = "biogpt"
171
+ supports_gradient_checkpointing = True
172
+ _supports_flash_attn = True
173
+ _supports_sdpa = True
174
+ _supports_flex_attn = True
175
+
176
+ _can_compile_fullgraph = True
177
+
178
+ # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_causal_mask
179
+ def _update_causal_mask(
180
+ self,
181
+ attention_mask: Optional[Union[torch.Tensor, "BlockMask"]],
182
+ input_tensor: torch.Tensor,
183
+ cache_position: torch.Tensor,
184
+ past_key_values: Cache,
185
+ ):
186
+ if self.config._attn_implementation == "flex_attention":
187
+ if isinstance(attention_mask, torch.Tensor):
188
+ attention_mask = make_flex_block_causal_mask(attention_mask)
189
+ # Other attention flavors support in-built causal (when `mask is None`)
190
+ # while we need to create our specific block mask regardless
191
+ elif attention_mask is None:
192
+ attention_mask = make_flex_block_causal_mask(
193
+ torch.ones(
194
+ size=(input_tensor.shape[0], input_tensor.shape[1]),
195
+ device=input_tensor.device,
196
+ )
197
+ )
198
+ return attention_mask
199
+
200
+ if self.config._attn_implementation == "flash_attention_2":
201
+ if attention_mask is not None and (attention_mask == 0.0).any():
202
+ return attention_mask
203
+ return None
204
+
205
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
206
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
207
+ # to infer the attention mask.
208
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
209
+ using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
210
+
211
+ # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
212
+ if self.config._attn_implementation == "sdpa" and not using_compilable_cache:
213
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
214
+ attention_mask,
215
+ inputs_embeds=input_tensor,
216
+ past_key_values_length=past_seen_tokens,
217
+ is_training=self.training,
218
+ ):
219
+ return None
220
+
221
+ dtype = input_tensor.dtype
222
+ sequence_length = input_tensor.shape[1]
223
+ if using_compilable_cache:
224
+ target_length = past_key_values.get_max_cache_shape()
225
+ else:
226
+ target_length = (
227
+ attention_mask.shape[-1]
228
+ if isinstance(attention_mask, torch.Tensor)
229
+ else past_seen_tokens + sequence_length + 1
230
+ )
231
+
232
+ # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
233
+ causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
234
+ attention_mask,
235
+ sequence_length=sequence_length,
236
+ target_length=target_length,
237
+ dtype=dtype,
238
+ cache_position=cache_position,
239
+ batch_size=input_tensor.shape[0],
240
+ )
241
+
242
+ if (
243
+ self.config._attn_implementation == "sdpa"
244
+ and attention_mask is not None
245
+ and attention_mask.device.type in ["cuda", "xpu", "npu"]
246
+ ):
247
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
248
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
249
+ # Details: https://github.com/pytorch/pytorch/issues/110213
250
+ min_dtype = torch.finfo(dtype).min
251
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
252
+
253
+ return causal_mask
254
+
255
+ @staticmethod
256
+ # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
257
+ def _prepare_4d_causal_attention_mask_with_cache_position(
258
+ attention_mask: torch.Tensor,
259
+ sequence_length: int,
260
+ target_length: int,
261
+ dtype: torch.dtype,
262
+ cache_position: torch.Tensor,
263
+ batch_size: int,
264
+ **kwargs,
265
+ ):
266
+ """
267
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
268
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
269
+
270
+ Args:
271
+ attention_mask (`torch.Tensor`):
272
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
273
+ `(batch_size, 1, query_length, key_value_length)`.
274
+ sequence_length (`int`):
275
+ The sequence length being processed.
276
+ target_length (`int`):
277
+ The target length: when generating with static cache, the mask should be as long as the static cache,
278
+ to account for the 0 padding, the part of the cache that is not filled yet.
279
+ dtype (`torch.dtype`):
280
+ The dtype to use for the 4D attention mask.
281
+ cache_position (`torch.Tensor`):
282
+ Indices depicting the position of the input sequence tokens in the sequence.
283
+ batch_size (`torch.Tensor`):
284
+ Batch size.
285
+ """
286
+ if attention_mask is not None and attention_mask.dim() == 4:
287
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
288
+ causal_mask = attention_mask
289
+ else:
290
+ min_dtype = torch.finfo(dtype).min
291
+ causal_mask = torch.full(
292
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
293
+ )
294
+ if sequence_length != 1:
295
+ causal_mask = torch.triu(causal_mask, diagonal=1)
296
+ causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
297
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
298
+ if attention_mask is not None:
299
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
300
+ mask_length = attention_mask.shape[-1]
301
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
302
+ causal_mask.device
303
+ )
304
+ padding_mask = padding_mask == 0
305
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
306
+ padding_mask, min_dtype
307
+ )
308
+
309
+ return causal_mask
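# Illustrative note (not part of the diff): with sequence_length=2, target_length=4 and
# cache_position=[2, 3] (two tokens already in the cache), the helper above returns a
# (batch, 1, 2, 4) mask whose query rows are [0, 0, 0, min_dtype] and [0, 0, 0, 0]:
# the query at position 2 may attend to keys 0-2, the query at position 3 to keys 0-3.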
310
+
311
+
312
+ @auto_docstring
313
+ class BioGptModel(BioGptPreTrainedModel):
314
+ def __init__(self, config: BioGptConfig):
315
+ super().__init__(config)
316
+ self.config = config
317
+ self.layerdrop = config.layerdrop
318
+ self.dropout = config.hidden_dropout_prob
319
+ self.embed_dim = config.hidden_size
320
+ self.padding_idx = config.pad_token_id
321
+ embed_scale = math.sqrt(config.hidden_size) if config.scale_embedding else 1.0
322
+
323
+ self.embed_tokens = BioGptScaledWordEmbedding(
324
+ config.vocab_size, self.embed_dim, self.padding_idx, embed_scale=embed_scale
325
+ )
326
+ self.embed_positions = BioGptLearnedPositionalEmbedding(config.max_position_embeddings, self.embed_dim)
327
+
328
+ self.layers = nn.ModuleList([BioGptDecoderLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)])
329
+ self.layer_norm = nn.LayerNorm(self.embed_dim)
330
+
331
+ self.gradient_checkpointing = False
332
+ # Initialize weights and apply final processing
333
+ self.post_init()
334
+
335
+ @auto_docstring
336
+ def forward(
337
+ self,
338
+ input_ids: Optional[torch.LongTensor] = None,
339
+ attention_mask: Optional[torch.FloatTensor] = None,
340
+ head_mask: Optional[torch.FloatTensor] = None,
341
+ inputs_embeds: Optional[torch.FloatTensor] = None,
342
+ past_key_values: Optional[Cache] = None,
343
+ use_cache: Optional[bool] = None,
344
+ position_ids: Optional[torch.LongTensor] = None,
345
+ output_attentions: Optional[bool] = None,
346
+ output_hidden_states: Optional[bool] = None,
347
+ return_dict: Optional[bool] = None,
348
+ cache_position: Optional[torch.Tensor] = None,
349
+ **kwargs: Unpack[TransformersKwargs],
350
+ ) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]:
351
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
352
+ output_hidden_states = (
353
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
354
+ )
355
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
356
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
357
+
358
+ # retrieve input_ids and inputs_embeds
359
+ if (input_ids is None) ^ (inputs_embeds is not None):
360
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
361
+ elif input_ids is not None:
362
+ input = input_ids
363
+ input_shape = input.shape
364
+ input_ids = input_ids.view(-1, input_shape[-1])
365
+ elif inputs_embeds is not None:
366
+ input_shape = inputs_embeds.size()[:-1]
367
+ input = inputs_embeds[:, :, -1]
368
+ else:
369
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
370
+
371
+ if inputs_embeds is None:
372
+ inputs_embeds = self.embed_tokens(input)
373
+
374
+ if self.gradient_checkpointing and self.training:
375
+ if use_cache:
376
+ logger.warning_once(
377
+ "`use_cache=True` is incompatible with gradient checkpointing`. Setting `use_cache=False`..."
378
+ )
379
+ use_cache = False
380
+
381
+ # initialize past_key_values
382
+ if use_cache and past_key_values is None:
383
+ past_key_values = DynamicCache(config=self.config)
384
+ if use_cache and isinstance(past_key_values, tuple):
385
+ logger.warning_once(
386
+ "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
387
+ "You should pass an instance of `DynamicCache` instead, e.g. "
388
+ "`past_key_values=DynamicCache.from_legacy_cache(past_key_values)`."
389
+ )
390
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
391
+
392
+ batch_size, seq_length = inputs_embeds.size()[:-1]
393
+ past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
394
+ if cache_position is None:
395
+ cache_position = torch.arange(
396
+ past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device
397
+ )
398
+
399
+ if attention_mask is None:
400
+ # required mask seq length can be calculated via length of past cache
401
+ mask_seq_length = past_key_values_length + seq_length
402
+ attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
403
+
404
+ self_attn_cache = past_key_values
405
+
406
+ causal_mask = self._update_causal_mask(
407
+ attention_mask,
408
+ inputs_embeds,
409
+ cache_position,
410
+ self_attn_cache,
411
+ )
412
+
413
+ # embed positions
414
+ if position_ids is None:
415
+ # position_ids = cache_position.unsqueeze(0)
416
+ position_ids = torch.cumsum(attention_mask, dim=1)
417
+ position_ids = (position_ids * attention_mask - 1).long()
418
+ # cut positions if `past_seen_tokens` is > 0
419
+ position_ids = position_ids[:, past_key_values_length:]
420
+
421
+ positions = self.embed_positions(attention_mask, past_key_values_length, position_ids=position_ids)
422
+ hidden_states = inputs_embeds + positions
423
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
424
+
425
+ if self.gradient_checkpointing and self.training:
426
+ if use_cache:
427
+ logger.warning_once(
428
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
429
+ )
430
+ use_cache = False
431
+
432
+ all_hidden_states = () if output_hidden_states else None
433
+ all_self_attns = () if output_attentions else None
434
+ all_cross_attentions = None
435
+
436
+ for idx, decoder_layer in enumerate(self.layers):
437
+ # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
438
+ if output_hidden_states:
439
+ all_hidden_states += (hidden_states,)
440
+ if self.training:
441
+ dropout_probability = torch.rand([])
442
+ if dropout_probability < self.layerdrop:
443
+ continue
444
+
445
+ layer_outputs = decoder_layer(
446
+ hidden_states,
447
+ attention_mask=causal_mask,
448
+ layer_head_mask=(head_mask[idx] if head_mask is not None else None),
449
+ past_key_values=past_key_values,
450
+ output_attentions=output_attentions,
451
+ use_cache=use_cache,
452
+ position_ids=position_ids,
453
+ cache_position=cache_position,
454
+ **kwargs,
455
+ )
456
+
457
+ hidden_states = layer_outputs[0]
458
+
459
+ if output_attentions:
460
+ all_self_attns += (layer_outputs[1],)
461
+
462
+ # add hidden states from the last decoder layer
463
+ if output_hidden_states:
464
+ all_hidden_states += (hidden_states,)
465
+
466
+ hidden_states = self.layer_norm(hidden_states)
467
+
468
+ if not return_dict:
469
+ return tuple(
470
+ v
471
+ for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns, all_cross_attentions]
472
+ if v is not None
473
+ )
474
+ return BaseModelOutputWithPastAndCrossAttentions(
475
+ last_hidden_state=hidden_states,
476
+ past_key_values=past_key_values,
477
+ hidden_states=all_hidden_states,
478
+ attentions=all_self_attns,
479
+ cross_attentions=all_cross_attentions,
480
+ )
481
+
482
+
483
+ @auto_docstring(
484
+ custom_intro="""
485
+ BioGPT Model with a `language modeling` head on top for CLM fine-tuning.
486
+ """
487
+ )
488
+ class BioGptForCausalLM(BioGptPreTrainedModel, GenerationMixin):
489
+ _tied_weights_keys = ["output_projection.weight"]
490
+
491
+ def __init__(self, config):
492
+ super().__init__(config)
493
+
494
+ self.biogpt = BioGptModel(config)
495
+ self.output_projection = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
496
+
497
+ # Initialize weights and apply final processing
498
+ self.post_init()
499
+
500
+ def get_output_embeddings(self):
501
+ return self.output_projection
502
+
503
+ def set_output_embeddings(self, new_embeddings):
504
+ self.output_projection = new_embeddings
505
+
506
+ @auto_docstring
507
+ def forward(
508
+ self,
509
+ input_ids: Optional[torch.LongTensor] = None,
510
+ attention_mask: Optional[torch.FloatTensor] = None,
511
+ head_mask: Optional[torch.FloatTensor] = None,
512
+ inputs_embeds: Optional[torch.FloatTensor] = None,
513
+ past_key_values: Optional[Cache] = None,
514
+ labels: Optional[torch.LongTensor] = None,
515
+ use_cache: Optional[bool] = None,
516
+ position_ids: Optional[torch.LongTensor] = None,
517
+ output_attentions: Optional[bool] = None,
518
+ output_hidden_states: Optional[bool] = None,
519
+ return_dict: Optional[bool] = None,
520
+ cache_position: Optional[torch.Tensor] = None,
521
+ **kwargs: Unpack[TransformersKwargs],
522
+ ) -> Union[tuple, CausalLMOutputWithCrossAttentions]:
523
+ r"""
524
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
525
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
526
+ `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100`
527
+ are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`.
528
+ """
529
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
530
+
531
+ outputs = self.biogpt(
532
+ input_ids,
533
+ attention_mask=attention_mask,
534
+ head_mask=head_mask,
535
+ inputs_embeds=inputs_embeds,
536
+ past_key_values=past_key_values,
537
+ use_cache=use_cache,
538
+ position_ids=position_ids,
539
+ output_attentions=output_attentions,
540
+ output_hidden_states=output_hidden_states,
541
+ return_dict=return_dict,
542
+ cache_position=cache_position,
543
+ **kwargs,
544
+ )
545
+
546
+ sequence_output = outputs[0]
547
+ prediction_scores = self.output_projection(sequence_output)
548
+
549
+ lm_loss = None
550
+ if labels is not None:
551
+ lm_loss = self.loss_function(
552
+ prediction_scores,
553
+ labels,
554
+ vocab_size=self.config.vocab_size,
555
+ **kwargs,
556
+ )
557
+
558
+ if not return_dict:
559
+ output = (prediction_scores,) + outputs[1:]
560
+ return ((lm_loss,) + output) if lm_loss is not None else output
561
+
562
+ return CausalLMOutputWithCrossAttentions(
563
+ loss=lm_loss,
564
+ logits=prediction_scores,
565
+ past_key_values=outputs.past_key_values,
566
+ hidden_states=outputs.hidden_states,
567
+ attentions=outputs.attentions,
568
+ cross_attentions=outputs.cross_attentions,
569
+ )
570
+
571
+
572
+ @auto_docstring
573
+ class BioGptForTokenClassification(BioGptPreTrainedModel):
574
+ def __init__(self, config):
575
+ super().__init__(config)
576
+ self.num_labels = config.num_labels
577
+
578
+ self.biogpt = BioGptModel(config)
579
+ if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None:
580
+ classifier_dropout = config.classifier_dropout
581
+ else:
582
+ classifier_dropout = config.hidden_dropout_prob
583
+ self.dropout = nn.Dropout(classifier_dropout)
584
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
585
+
586
+ self.post_init()
587
+
588
+ @auto_docstring
589
+ def forward(
590
+ self,
591
+ input_ids: Optional[torch.LongTensor] = None,
592
+ token_type_ids: Optional[torch.LongTensor] = None,
593
+ attention_mask: Optional[torch.FloatTensor] = None,
594
+ head_mask: Optional[torch.FloatTensor] = None,
595
+ past_key_values: Optional[Cache] = None,
596
+ inputs_embeds: Optional[torch.FloatTensor] = None,
597
+ labels: Optional[torch.LongTensor] = None,
598
+ use_cache: Optional[bool] = None,
599
+ position_ids: Optional[torch.LongTensor] = None,
600
+ output_attentions: Optional[bool] = None,
601
+ output_hidden_states: Optional[bool] = None,
602
+ return_dict: Optional[bool] = None,
603
+ cache_position: Optional[torch.Tensor] = None,
604
+ ) -> Union[tuple, TokenClassifierOutput]:
605
+ r"""
606
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
607
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
608
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
609
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
610
+ """
611
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
612
+
613
+ transformer_outputs = self.biogpt(
614
+ input_ids,
615
+ past_key_values=past_key_values,
616
+ attention_mask=attention_mask,
617
+ head_mask=head_mask,
618
+ inputs_embeds=inputs_embeds,
619
+ use_cache=use_cache,
620
+ position_ids=position_ids,
621
+ output_attentions=output_attentions,
622
+ output_hidden_states=output_hidden_states,
623
+ return_dict=return_dict,
624
+ cache_position=cache_position,
625
+ )
626
+
627
+ hidden_states = transformer_outputs[0]
628
+ hidden_states = self.dropout(hidden_states)
629
+ logits = self.classifier(hidden_states)
630
+
631
+ loss = None
632
+ if labels is not None:
633
+ loss_fct = CrossEntropyLoss()
634
+ # Only keep active parts of the loss
635
+ if attention_mask is not None:
636
+ active_loss = attention_mask.view(-1) == 1
637
+ active_logits = logits.view(-1, self.num_labels)
638
+ active_labels = torch.where(
639
+ active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
640
+ )
641
+ loss = loss_fct(active_logits, active_labels)
642
+ else:
643
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
644
+
645
+ if not return_dict:
646
+ output = (logits,) + transformer_outputs[2:]
647
+ return ((loss,) + output) if loss is not None else output
648
+
649
+ return TokenClassifierOutput(
650
+ loss=loss,
651
+ logits=logits,
652
+ hidden_states=transformer_outputs.hidden_states,
653
+ attentions=transformer_outputs.attentions,
654
+ )
655
+
656
+
657
+ @auto_docstring(
658
+ custom_intro="""
659
+ The BioGpt Model transformer with a sequence classification head on top (linear layer).
660
+
661
+ [`BioGptForSequenceClassification`] uses the last token in order to do the classification, as other causal models
662
+ (e.g. GPT-2) do.
663
+
664
+ Since it does classification on the last token, it is required to know the position of the last token. If a
665
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
666
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
667
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
668
+ each row of the batch).
669
+ """
670
+ )
671
+ class BioGptForSequenceClassification(BioGptPreTrainedModel):
672
+ def __init__(self, config: BioGptConfig):
673
+ super().__init__(config)
674
+ self.num_labels = config.num_labels
675
+ self.biogpt = BioGptModel(config)
676
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
677
+
678
+ # Initialize weights and apply final processing
679
+ self.post_init()
680
+
681
+ @auto_docstring
682
+ def forward(
683
+ self,
684
+ input_ids: Optional[torch.LongTensor] = None,
685
+ attention_mask: Optional[torch.FloatTensor] = None,
686
+ head_mask: Optional[torch.FloatTensor] = None,
687
+ past_key_values: Optional[Cache] = None,
688
+ inputs_embeds: Optional[torch.FloatTensor] = None,
689
+ labels: Optional[torch.LongTensor] = None,
690
+ use_cache: Optional[bool] = None,
691
+ position_ids: Optional[torch.LongTensor] = None,
692
+ output_attentions: Optional[bool] = None,
693
+ output_hidden_states: Optional[bool] = None,
694
+ return_dict: Optional[bool] = None,
695
+ cache_position: Optional[torch.Tensor] = None,
696
+ logits_to_keep: Union[int, torch.Tensor] = 0,
697
+ ) -> Union[tuple, SequenceClassifierOutputWithPast]:
698
+ r"""
699
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
700
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
701
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
702
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
703
+ """
704
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
705
+
706
+ transformer_outputs = self.biogpt(
707
+ input_ids,
708
+ past_key_values=past_key_values,
709
+ attention_mask=attention_mask,
710
+ head_mask=head_mask,
711
+ inputs_embeds=inputs_embeds,
712
+ use_cache=use_cache,
713
+ position_ids=position_ids,
714
+ output_attentions=output_attentions,
715
+ output_hidden_states=output_hidden_states,
716
+ return_dict=return_dict,
717
+ cache_position=cache_position,
718
+ )
719
+ hidden_states = transformer_outputs[0]
720
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
721
+ logits = self.score(hidden_states[:, slice_indices, :])
722
+
723
+ if input_ids is not None:
724
+ batch_size, sequence_length = input_ids.shape[:2]
725
+ else:
726
+ batch_size, sequence_length = inputs_embeds.shape[:2]
727
+
728
+ if self.config.pad_token_id is None:
729
+ sequence_length = -1
730
+ else:
731
+ if input_ids is not None:
732
+ sequence_length = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
733
+ else:
734
+ sequence_length = -1
735
+ logger.warning_once(
736
+ f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
737
+ "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
738
+ )
739
+
740
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_length]
741
+
742
+ loss = None
743
+ if labels is not None:
744
+ if self.config.problem_type is None:
745
+ if self.num_labels == 1:
746
+ self.config.problem_type = "regression"
747
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
748
+ self.config.problem_type = "single_label_classification"
749
+ else:
750
+ self.config.problem_type = "multi_label_classification"
751
+
752
+ if self.config.problem_type == "regression":
753
+ loss_fct = MSELoss()
754
+ if self.num_labels == 1:
755
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
756
+ else:
757
+ loss = loss_fct(pooled_logits, labels)
758
+ elif self.config.problem_type == "single_label_classification":
759
+ loss_fct = CrossEntropyLoss()
760
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
761
+ elif self.config.problem_type == "multi_label_classification":
762
+ loss_fct = BCEWithLogitsLoss()
763
+ loss = loss_fct(pooled_logits, labels)
764
+ if not return_dict:
765
+ output = (pooled_logits,) + transformer_outputs[1:]
766
+ return ((loss,) + output) if loss is not None else output
767
+
768
+ return SequenceClassifierOutputWithPast(
769
+ loss=loss,
770
+ logits=pooled_logits,
771
+ past_key_values=transformer_outputs.past_key_values,
772
+ hidden_states=transformer_outputs.hidden_states,
773
+ attentions=transformer_outputs.attentions,
774
+ )
775
+
776
+ def get_input_embeddings(self):
777
+ return self.biogpt.embed_tokens
778
+
779
+ def set_input_embeddings(self, value):
780
+ self.biogpt.embed_tokens = value
781
+
782
+
783
+ __all__ = [
784
+ "BioGptForCausalLM",
785
+ "BioGptForTokenClassification",
786
+ "BioGptForSequenceClassification",
787
+ "BioGptModel",
788
+ "BioGptPreTrainedModel",
789
+ ]
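For reference, a self-contained sketch of the last-non-padding-token pooling that the `BioGptForSequenceClassification` intro above describes (illustrative only, not part of the diff; the tensors and the `pad_token_id` value are made up):

import torch

pad_token_id = 1
input_ids = torch.tensor([[5, 8, 9, 1, 1],   # 3 real tokens followed by 2 pads
                          [7, 6, 4, 3, 2]])  # no padding
logits = torch.randn(2, 5, 3)                # (batch_size, seq_len, num_labels)

# Index of the last non-padding token in each row, as computed in the forward pass above.
last_non_pad = torch.ne(input_ids, pad_token_id).sum(-1) - 1     # tensor([2, 4])
pooled_logits = logits[torch.arange(input_ids.shape[0]), last_non_pad]
print(pooled_logits.shape)                   # torch.Size([2, 3])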
venv/lib/python3.13/site-packages/transformers/models/biogpt/tokenization_biogpt.py ADDED
@@ -0,0 +1,331 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 The HuggingFace Team and Microsoft Research AI4Science. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Tokenization classes for BioGPT."""
16
+
17
+ import json
18
+ import os
19
+ from typing import Optional
20
+
21
+ from ...tokenization_utils import PreTrainedTokenizer
22
+ from ...utils import logging
23
+
24
+
25
+ logger = logging.get_logger(__name__)
26
+
27
+ VOCAB_FILES_NAMES = {
28
+ "vocab_file": "vocab.json",
29
+ "merges_file": "merges.txt",
30
+ }
31
+
32
+
33
+ def get_pairs(word):
34
+ """
35
+ Return set of symbol pairs in a word. word is represented as tuple of symbols (symbols being variable-length
36
+ strings)
37
+ """
38
+ pairs = set()
39
+ prev_char = word[0]
40
+ for char in word[1:]:
41
+ pairs.add((prev_char, char))
42
+ prev_char = char
43
+ return pairs
44
+
45
+
46
+ class BioGptTokenizer(PreTrainedTokenizer):
47
+ """
48
+ Construct a FAIRSEQ Transformer tokenizer. Moses tokenization followed by Byte-Pair Encoding.
49
+
50
+ This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
51
+ this superclass for more information regarding those methods.
52
+
53
+ Args:
54
+ vocab_file (`str`):
55
+ Path to the vocabulary file.
56
+ merges_file (`str`):
57
+ Merges file.
58
+ unk_token (`str`, *optional*, defaults to `"<unk>"`):
59
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
60
+ token instead.
61
+ bos_token (`str`, *optional*, defaults to `"<s>"`):
62
+ The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier token.
63
+
64
+ <Tip>
65
+
66
+ When building a sequence using special tokens, this is not the token that is used for the beginning of
67
+ sequence. The token used is the `cls_token`.
68
+
69
+ </Tip>
70
+
71
+ eos_token (`str`, *optional*, defaults to `"</s>"`):
72
+ The end of sequence token.
73
+
74
+ <Tip>
75
+
76
+ When building a sequence using special tokens, this is not the token that is used for the end of sequence.
77
+ The token used is the `sep_token`.
78
+
79
+ </Tip>
80
+
81
+ sep_token (`str`, *optional*, defaults to `"</s>"`):
82
+ The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
83
+ sequence classification or for a text and a question for question answering. It is also used as the last
84
+ token of a sequence built with special tokens.
85
+ pad_token (`str`, *optional*, defaults to `"<pad>"`):
86
+ The token used for padding, for example when batching sequences of different lengths.
87
+ """
88
+
89
+ vocab_files_names = VOCAB_FILES_NAMES
90
+ model_input_names = ["input_ids", "attention_mask"]
91
+
92
+ def __init__(
93
+ self,
94
+ vocab_file,
95
+ merges_file,
96
+ unk_token="<unk>",
97
+ bos_token="<s>",
98
+ eos_token="</s>",
99
+ sep_token="</s>",
100
+ pad_token="<pad>",
101
+ **kwargs,
102
+ ):
103
+ try:
104
+ import sacremoses
105
+ except ImportError:
106
+ raise ImportError(
107
+ "You need to install sacremoses to use BioGptTokenizer. "
108
+ "See https://pypi.org/project/sacremoses/ for installation."
109
+ )
110
+
111
+ self.lang = "en"
112
+ self.sm = sacremoses
113
+ # cache of sm.MosesTokenizer instance
114
+ self.cache_moses_tokenizer = {}
115
+ self.cache_moses_detokenizer = {}
116
+
117
+ """ Initialisation"""
118
+ with open(vocab_file, encoding="utf-8") as vocab_handle:
119
+ self.encoder = json.load(vocab_handle)
120
+ self.decoder = {v: k for k, v in self.encoder.items()}
121
+ with open(merges_file, encoding="utf-8") as merges_handle:
122
+ merges = merges_handle.read().split("\n")[:-1]
123
+ merges = [tuple(merge.split()[:2]) for merge in merges]
124
+ self.bpe_ranks = dict(zip(merges, range(len(merges))))
125
+ self.cache = {}
126
+
127
+ super().__init__(
128
+ bos_token=bos_token,
129
+ eos_token=eos_token,
130
+ sep_token=sep_token,
131
+ unk_token=unk_token,
132
+ pad_token=pad_token,
133
+ **kwargs,
134
+ )
135
+
136
+ @property
137
+ def vocab_size(self):
138
+ """Returns vocab size"""
139
+ return len(self.encoder)
140
+
141
+ def get_vocab(self):
142
+ return dict(self.encoder, **self.added_tokens_encoder)
143
+
144
+ def moses_tokenize(self, text, lang):
145
+ if lang not in self.cache_moses_tokenizer:
146
+ moses_tokenizer = self.sm.MosesTokenizer(lang=lang)
147
+ self.cache_moses_tokenizer[lang] = moses_tokenizer
148
+ return self.cache_moses_tokenizer[lang].tokenize(
149
+ text, aggressive_dash_splits=True, return_str=False, escape=True
150
+ )
151
+
152
+ def moses_detokenize(self, tokens, lang):
153
+ if lang not in self.cache_moses_detokenizer:
154
+ moses_detokenizer = self.sm.MosesDetokenizer(lang=lang)
155
+ self.cache_moses_detokenizer[lang] = moses_detokenizer
156
+ return self.cache_moses_detokenizer[lang].detokenize(tokens)
157
+
158
+ def bpe(self, token):
159
+ word = tuple(token[:-1]) + (token[-1] + "</w>",)
160
+ if token in self.cache:
161
+ return self.cache[token]
162
+ pairs = get_pairs(word)
163
+
164
+ if not pairs:
165
+ return token + "</w>"
166
+
167
+ while True:
168
+ bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
169
+ if bigram not in self.bpe_ranks:
170
+ break
171
+ first, second = bigram
172
+ new_word = []
173
+ i = 0
174
+ while i < len(word):
175
+ try:
176
+ j = word.index(first, i)
177
+ except ValueError:
178
+ new_word.extend(word[i:])
179
+ break
180
+ else:
181
+ new_word.extend(word[i:j])
182
+ i = j
183
+
184
+ if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
185
+ new_word.append(first + second)
186
+ i += 2
187
+ else:
188
+ new_word.append(word[i])
189
+ i += 1
190
+ new_word = tuple(new_word)
191
+ word = new_word
192
+ if len(word) == 1:
193
+ break
194
+ else:
195
+ pairs = get_pairs(word)
196
+ word = " ".join(word)
197
+ if word == "\n </w>":
198
+ word = "\n</w>"
199
+ self.cache[token] = word
200
+ return word
201
+
202
+    def _tokenize(self, text, bypass_tokenizer=False):
+        """Returns a list of BPE sub-tokens for the input string."""
+        if bypass_tokenizer:
+            text = text.split()
+        else:
+            text = self.moses_tokenize(text, self.lang)
+
+        split_tokens = []
+        for token in text:
+            if token:
+                split_tokens.extend(list(self.bpe(token).split(" ")))
+
+        return split_tokens
+
+    def _convert_token_to_id(self, token):
+        """Converts a token (str) to an id using the vocab."""
+        return self.encoder.get(token, self.encoder.get(self.unk_token))
+
+    def _convert_id_to_token(self, index):
+        """Converts an index (integer) to a token (str) using the vocab."""
+        return self.decoder.get(index, self.unk_token)
+
+    def convert_tokens_to_string(self, tokens):
+        """Converts a sequence of tokens (strings) to a single string."""
+        # remove BPE
+        tokens = [t.replace(" ", "").replace("</w>", " ") for t in tokens]
+        tokens = "".join(tokens).split()
+        # detokenize
+        text = self.moses_detokenize(tokens, self.lang)
+        return text
+
+    def build_inputs_with_special_tokens(
+        self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
+    ) -> list[int]:
+        """
+        Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating
+        and adding special tokens. A BioGPT sequence has the following format:
+
+        - single sequence: `</s> X `
+        - pair of sequences: `</s> A </s> B `
+
+        Args:
+            token_ids_0 (`List[int]`):
+                List of IDs to which the special tokens will be added.
+            token_ids_1 (`List[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+
+        Returns:
+            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
+        """
+        if token_ids_1 is None:
+            return [self.sep_token_id] + token_ids_0
+        sep = [self.sep_token_id]
+        return sep + token_ids_0 + sep + token_ids_1
+
+    def get_special_tokens_mask(
+        self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None, already_has_special_tokens: bool = False
+    ) -> list[int]:
+        """
+        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
+        special tokens using the tokenizer `prepare_for_model` method.
+
+        Args:
+            token_ids_0 (`List[int]`):
+                List of IDs.
+            token_ids_1 (`List[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
+                Whether or not the token list is already formatted with special tokens for the model.
+
+        Returns:
+            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+        """
+        if already_has_special_tokens:
+            return super().get_special_tokens_mask(
+                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
+            )
+        # no bos used in fairseq
+        if token_ids_1 is not None:
+            return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1))
+        return [1] + ([0] * len(token_ids_0))
+
+    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
+        if not os.path.isdir(save_directory):
+            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
+            return
+        vocab_file = os.path.join(
+            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+        )
+        merge_file = os.path.join(
+            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"]
+        )
+
+        with open(vocab_file, "w", encoding="utf-8") as f:
+            f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
+
+        index = 0
+        with open(merge_file, "w", encoding="utf-8") as writer:
+            for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
+                if index != token_index:
+                    logger.warning(
+                        f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive."
+                        " Please check that the tokenizer is not corrupted!"
+                    )
+                    index = token_index
+                writer.write(" ".join(bpe_tokens) + "\n")
+                index += 1
+
+        return vocab_file, merge_file
+
+    def __getstate__(self):
+        state = self.__dict__.copy()
+        state["sm"] = None
+        return state
+
+    def __setstate__(self, d):
+        self.__dict__ = d
+
+        try:
+            import sacremoses
+        except ImportError:
+            raise ImportError(
+                "You need to install sacremoses to use BioGptTokenizer. "
+                "See https://pypi.org/project/sacremoses/ for installation."
+            )
+
+        self.sm = sacremoses
+
+
+__all__ = ["BioGptTokenizer"]
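
For orientation, here is a minimal usage sketch of the tokenizer defined above: Moses pre-tokenization, BPE splitting with `</w>` end-of-word markers, the `</s> X` special-token format, and the Moses detokenization roundtrip. This is illustrative only; the `microsoft/biogpt` checkpoint name and the sample sentence are assumptions, and `sacremoses` must be installed.

```python
from transformers import BioGptTokenizer

# Assumed checkpoint name; any BioGPT checkpoint providing vocab.json/merges.txt behaves the same way.
tokenizer = BioGptTokenizer.from_pretrained("microsoft/biogpt")

text = "BioGPT generates biomedical text."
tokens = tokenizer.tokenize(text)  # Moses + BPE sub-tokens, each word ending in "</w>"
ids = tokenizer.encode(text)       # build_inputs_with_special_tokens prepends </s> ("</s> X")
mask = tokenizer.get_special_tokens_mask(ids, already_has_special_tokens=True)  # 1 marks </s>

print(tokens)
print(ids)
print(mask)
print(tokenizer.decode(ids, skip_special_tokens=True))  # detokenized roundtrip of the input
```
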
venv/lib/python3.13/site-packages/transformers/models/bit/__init__.py ADDED
@@ -0,0 +1,29 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+    from .configuration_bit import *
+    from .image_processing_bit import *
+    from .image_processing_bit_fast import *
+    from .modeling_bit import *
+else:
+    import sys
+
+    _file = globals()["__file__"]
+    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
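
The `bit/__init__.py` above relies on `_LazyModule`, so the submodules listed under `TYPE_CHECKING` are only imported when one of their attributes is first accessed (the `TYPE_CHECKING` branch exists purely for static type checkers). A rough sketch of the observable effect, assuming `BitConfig` is exported by `configuration_bit`; exact module names and import timing may differ across transformers versions.

```python
import sys

import transformers.models.bit as bit

# Right after import, the lazy package has not pulled in its submodules yet.
print("transformers.models.bit.configuration_bit" in sys.modules)  # expected: False

# The first attribute access makes _LazyModule import the defining submodule.
config = bit.BitConfig()
print("transformers.models.bit.configuration_bit" in sys.modules)  # expected: True
print(type(config).__name__)  # "BitConfig"
```
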