flaubert commited on
Commit
811127f
·
verified ·
1 Parent(s): 359a52d

Upload folder using huggingface_hub

Browse files
config.json ADDED
@@ -0,0 +1,350 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "",
3
+ "activation_dropout": 0.0,
4
+ "add_cross_attention": false,
5
+ "architectures": [
6
+ "Data2Vec2MultiModel"
7
+ ],
8
+ "attention_dropout": 0.1,
9
+ "auto_map": {
10
+ "AutoConfig": "configuration_data2vec2.Data2Vec2MultiConfig",
11
+ "AutoModel": "modeling_data2vec2.Data2Vec2MultiModel"
12
+ },
13
+ "bad_words_ids": null,
14
+ "begin_suppress_tokens": null,
15
+ "bos_token_id": null,
16
+ "chunk_size_feed_forward": 0,
17
+ "clone_batch": 8,
18
+ "cross_attention_hidden_size": null,
19
+ "decoder_start_token_id": null,
20
+ "depth": 12,
21
+ "diversity_penalty": 0.0,
22
+ "do_sample": false,
23
+ "dropout_input": 0.0,
24
+ "dtype": "float32",
25
+ "early_stopping": false,
26
+ "embed_dim": 768,
27
+ "encoder_dropout": 0.1,
28
+ "encoder_no_repeat_ngram_size": 0,
29
+ "end_drop_path_rate": 0.0,
30
+ "end_of_block_targets": false,
31
+ "eos_token_id": null,
32
+ "exponential_decay_length_penalty": null,
33
+ "finetuning_task": null,
34
+ "forced_bos_token_id": null,
35
+ "forced_eos_token_id": null,
36
+ "hidden_size": 768,
37
+ "id2label": {
38
+ "0": "LABEL_0",
39
+ "1": "LABEL_1"
40
+ },
41
+ "is_decoder": false,
42
+ "is_encoder_decoder": false,
43
+ "label2id": {
44
+ "LABEL_0": 0,
45
+ "LABEL_1": 1
46
+ },
47
+ "layer_norm_first": false,
48
+ "layerdrop": 0.0,
49
+ "length_penalty": 1.0,
50
+ "log_norms": true,
51
+ "max_length": 20,
52
+ "min_length": 0,
53
+ "mlp_ratio": 4.0,
54
+ "modalities": {
55
+ "_name_or_path": "",
56
+ "add_cross_attention": false,
57
+ "architectures": null,
58
+ "audio": {
59
+ "_name_or_path": "",
60
+ "add_cross_attention": false,
61
+ "add_masks": false,
62
+ "alibi_max_pos": null,
63
+ "alibi_scale": 1.0,
64
+ "architectures": null,
65
+ "bad_words_ids": null,
66
+ "begin_suppress_tokens": null,
67
+ "bos_token_id": null,
68
+ "chunk_size_feed_forward": 0,
69
+ "conv_pos_depth": 5,
70
+ "conv_pos_groups": 16,
71
+ "conv_pos_pre_ln": false,
72
+ "conv_pos_width": 95,
73
+ "cross_attention_hidden_size": null,
74
+ "decoder_start_token_id": null,
75
+ "diversity_penalty": 0.0,
76
+ "do_sample": false,
77
+ "dtype": null,
78
+ "early_stopping": false,
79
+ "encoder_no_repeat_ngram_size": 0,
80
+ "encoder_zero_mask": true,
81
+ "end_drop_path_rate": 0.0,
82
+ "eos_token_id": null,
83
+ "exponential_decay_length_penalty": null,
84
+ "extractor_mode": "layer_norm",
85
+ "feature_encoder_spec": "[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] + [(512,2,2)]",
86
+ "finetuning_task": null,
87
+ "forced_bos_token_id": null,
88
+ "forced_eos_token_id": null,
89
+ "id2label": {
90
+ "0": "LABEL_0",
91
+ "1": "LABEL_1"
92
+ },
93
+ "init_extra_token_zero": true,
94
+ "inverse_mask": false,
95
+ "is_decoder": false,
96
+ "is_encoder_decoder": false,
97
+ "keep_masked_pct": 0.0,
98
+ "label2id": {
99
+ "LABEL_0": 0,
100
+ "LABEL_1": 1
101
+ },
102
+ "learned_alibi": false,
103
+ "learned_alibi_scale": false,
104
+ "learned_alibi_scale_per_head": false,
105
+ "learned_alibi_scale_per_layer": false,
106
+ "length_penalty": 1.0,
107
+ "local_grad_mult": 1.0,
108
+ "mask_channel_length": 64,
109
+ "mask_channel_prob": 0.0,
110
+ "mask_dropout": 0.0,
111
+ "mask_length": 5,
112
+ "mask_noise_std": 0.01,
113
+ "mask_prob": 0.7,
114
+ "mask_prob_adjust": 0.0,
115
+ "mask_prob_min": null,
116
+ "max_length": 20,
117
+ "min_length": 0,
118
+ "model_depth": 12,
119
+ "model_type": "",
120
+ "no_repeat_ngram_size": 0,
121
+ "num_alibi_heads": 12,
122
+ "num_beam_groups": 1,
123
+ "num_beams": 1,
124
+ "num_extra_tokens": 0,
125
+ "num_return_sequences": 1,
126
+ "output_attentions": false,
127
+ "output_hidden_states": false,
128
+ "output_scores": false,
129
+ "pad_token_id": null,
130
+ "prefix": null,
131
+ "prenet_depth": 4,
132
+ "prenet_dropout": 0.0,
133
+ "prenet_layerdrop": 0.0,
134
+ "problem_type": null,
135
+ "pruned_heads": {},
136
+ "remove_invalid_values": false,
137
+ "remove_masks": false,
138
+ "repetition_penalty": 1.0,
139
+ "return_dict": true,
140
+ "return_dict_in_generate": false,
141
+ "sep_token_id": null,
142
+ "start_drop_path_rate": 0.0,
143
+ "suppress_tokens": null,
144
+ "task_specific_params": null,
145
+ "temperature": 1.0,
146
+ "tie_encoder_decoder": false,
147
+ "tie_word_embeddings": true,
148
+ "tokenizer_class": null,
149
+ "top_k": 50,
150
+ "top_p": 1.0,
151
+ "torchscript": false,
152
+ "type": "AUDIO",
153
+ "typical_p": 1.0,
154
+ "use_alibi_encoder": false
155
+ },
156
+ "bad_words_ids": null,
157
+ "begin_suppress_tokens": null,
158
+ "bos_token_id": null,
159
+ "chunk_size_feed_forward": 0,
160
+ "cross_attention_hidden_size": null,
161
+ "decoder_start_token_id": null,
162
+ "diversity_penalty": 0.0,
163
+ "do_sample": false,
164
+ "dtype": null,
165
+ "early_stopping": false,
166
+ "encoder_no_repeat_ngram_size": 0,
167
+ "eos_token_id": null,
168
+ "exponential_decay_length_penalty": null,
169
+ "finetuning_task": null,
170
+ "forced_bos_token_id": null,
171
+ "forced_eos_token_id": null,
172
+ "id2label": {
173
+ "0": "LABEL_0",
174
+ "1": "LABEL_1"
175
+ },
176
+ "is_decoder": false,
177
+ "is_encoder_decoder": false,
178
+ "label2id": {
179
+ "LABEL_0": 0,
180
+ "LABEL_1": 1
181
+ },
182
+ "length_penalty": 1.0,
183
+ "max_length": 20,
184
+ "min_length": 0,
185
+ "model_type": "",
186
+ "no_repeat_ngram_size": 0,
187
+ "num_beam_groups": 1,
188
+ "num_beams": 1,
189
+ "num_return_sequences": 1,
190
+ "output_attentions": false,
191
+ "output_hidden_states": false,
192
+ "output_scores": false,
193
+ "pad_token_id": null,
194
+ "prefix": null,
195
+ "problem_type": null,
196
+ "pruned_heads": {},
197
+ "remove_invalid_values": false,
198
+ "repetition_penalty": 1.0,
199
+ "return_dict": true,
200
+ "return_dict_in_generate": false,
201
+ "sep_token_id": null,
202
+ "suppress_tokens": null,
203
+ "task_specific_params": null,
204
+ "temperature": 1.0,
205
+ "text": {
206
+ "_name_or_path": "",
207
+ "add_cross_attention": false,
208
+ "add_masks": false,
209
+ "alibi_max_pos": null,
210
+ "alibi_scale": 1.0,
211
+ "architectures": null,
212
+ "bad_words_ids": null,
213
+ "begin_suppress_tokens": null,
214
+ "bos_token_id": 0,
215
+ "chunk_size_feed_forward": 0,
216
+ "cross_attention_hidden_size": null,
217
+ "decoder_start_token_id": null,
218
+ "diversity_penalty": 0.0,
219
+ "do_sample": false,
220
+ "dropout": 0.1,
221
+ "dtype": null,
222
+ "early_stopping": false,
223
+ "encoder_no_repeat_ngram_size": 0,
224
+ "encoder_zero_mask": true,
225
+ "end_drop_path_rate": 0.0,
226
+ "eos_token_id": 2,
227
+ "exponential_decay_length_penalty": null,
228
+ "finetuning_task": null,
229
+ "forced_bos_token_id": null,
230
+ "forced_eos_token_id": null,
231
+ "id2label": {
232
+ "0": "LABEL_0",
233
+ "1": "LABEL_1"
234
+ },
235
+ "init_extra_token_zero": true,
236
+ "inverse_mask": false,
237
+ "is_decoder": false,
238
+ "is_encoder_decoder": false,
239
+ "keep_masked_pct": 0.0,
240
+ "label2id": {
241
+ "LABEL_0": 0,
242
+ "LABEL_1": 1
243
+ },
244
+ "layernorm_embedding": true,
245
+ "learned_alibi": false,
246
+ "learned_alibi_scale": false,
247
+ "learned_alibi_scale_per_head": false,
248
+ "learned_alibi_scale_per_layer": false,
249
+ "learned_pos": true,
250
+ "length_penalty": 1.0,
251
+ "local_grad_mult": 1.0,
252
+ "mask_channel_length": 64,
253
+ "mask_channel_prob": 0.0,
254
+ "mask_dropout": 0.0,
255
+ "mask_length": 3,
256
+ "mask_noise_std": 0.01,
257
+ "mask_prob": 0.6,
258
+ "mask_prob_adjust": 0.0,
259
+ "mask_prob_min": null,
260
+ "max_length": 20,
261
+ "max_source_positions": 512,
262
+ "min_length": 0,
263
+ "model_depth": 12,
264
+ "model_type": "",
265
+ "no_repeat_ngram_size": 0,
266
+ "no_scale_embedding": true,
267
+ "no_token_positional_embeddings": false,
268
+ "num_alibi_heads": 12,
269
+ "num_beam_groups": 1,
270
+ "num_beams": 1,
271
+ "num_extra_tokens": 0,
272
+ "num_return_sequences": 1,
273
+ "output_attentions": false,
274
+ "output_hidden_states": false,
275
+ "output_scores": false,
276
+ "pad_token_id": 1,
277
+ "prefix": null,
278
+ "prenet_depth": 0,
279
+ "prenet_dropout": 0.0,
280
+ "prenet_layerdrop": 0.0,
281
+ "problem_type": null,
282
+ "pruned_heads": {},
283
+ "remove_invalid_values": false,
284
+ "remove_masks": false,
285
+ "repetition_penalty": 1.0,
286
+ "return_dict": true,
287
+ "return_dict_in_generate": false,
288
+ "sep_token_id": null,
289
+ "start_drop_path_rate": 0.0,
290
+ "suppress_tokens": null,
291
+ "task_specific_params": null,
292
+ "temperature": 1.0,
293
+ "tie_encoder_decoder": false,
294
+ "tie_word_embeddings": true,
295
+ "tokenizer_class": null,
296
+ "top_k": 50,
297
+ "top_p": 1.0,
298
+ "torchscript": false,
299
+ "type": "TEXT",
300
+ "typical_p": 1.0,
301
+ "unk_token_id": 3,
302
+ "use_alibi_encoder": false,
303
+ "vocab_size": 50368
304
+ },
305
+ "tie_encoder_decoder": false,
306
+ "tie_word_embeddings": true,
307
+ "tokenizer_class": null,
308
+ "top_k": 50,
309
+ "top_p": 1.0,
310
+ "torchscript": false,
311
+ "typical_p": 1.0
312
+ },
313
+ "model_type": "data2vec2",
314
+ "n_layers": 12,
315
+ "no_repeat_ngram_size": 0,
316
+ "norm_affine": true,
317
+ "norm_eps": 1e-05,
318
+ "num_beam_groups": 1,
319
+ "num_beams": 1,
320
+ "num_heads": 12,
321
+ "num_hidden_layers": 12,
322
+ "num_layers": 12,
323
+ "num_return_sequences": 1,
324
+ "output_attentions": false,
325
+ "output_hidden_states": false,
326
+ "output_scores": false,
327
+ "pad_token_id": null,
328
+ "post_mlp_drop": 0.1,
329
+ "prefix": null,
330
+ "problem_type": null,
331
+ "pruned_heads": {},
332
+ "remove_invalid_values": false,
333
+ "repetition_penalty": 1.0,
334
+ "return_dict": true,
335
+ "return_dict_in_generate": false,
336
+ "sep_token_id": null,
337
+ "start_drop_path_rate": 0.0,
338
+ "supported_modality": "TEXT",
339
+ "suppress_tokens": null,
340
+ "task_specific_params": null,
341
+ "temperature": 1.0,
342
+ "tie_encoder_decoder": false,
343
+ "tie_word_embeddings": true,
344
+ "tokenizer_class": null,
345
+ "top_k": 50,
346
+ "top_p": 1.0,
347
+ "torchscript": false,
348
+ "transformers_version": "4.57.0.dev0",
349
+ "typical_p": 1.0
350
+ }
configuration_data2vec2.py ADDED
@@ -0,0 +1,415 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ #
3
+ # Copyright (c) Facebook, Inc. and its affiliates.
4
+ #
5
+ # This source code is licensed under the MIT license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+ #
8
+ #
9
+ # Copyright 2022 The HuggingFace Inc. team. All rights reserved.
10
+ #
11
+ # Licensed under the Apache License, Version 2.0 (the "License");
12
+ # you may not use this file except in compliance with the License.
13
+ # You may obtain a copy of the License at
14
+ #
15
+ # http://www.apache.org/licenses/LICENSE-2.0
16
+ #
17
+ # Unless required by applicable law or agreed to in writing, software
18
+ # distributed under the License is distributed on an "AS IS" BASIS,
19
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20
+ # See the License for the specific language governing permissions and
21
+ # limitations under the License.
22
+
23
+
24
+ """ Data2Vec2 multi configuration"""
25
+
26
+ import os
27
+ from typing import Union, Dict, Any, Optional
28
+ from transformers.dynamic_module_utils import custom_object_save
29
+ from transformers.utils import logging
30
+ from transformers.configuration_utils import PretrainedConfig, CONFIG_NAME
31
+
32
+
33
+ logger = logging.get_logger(__name__)
34
+
35
+
36
class MyPretrainedConfig(PretrainedConfig):
    """`PretrainedConfig` variant that round-trips nested config objects.

    The stock `PretrainedConfig` serializes with ``use_diff=True`` and
    flat-assigns values when loading from a dict, which loses nested
    ``MyPretrainedConfig`` attributes. This subclass always serializes the
    full configuration and updates nested sub-configs recursively.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def to_json_string(self, use_diff: bool = False) -> str:
        # Default flipped from True to False so nested configs serialize fully.
        return super().to_json_string(use_diff)

    def update(self, config_dict):
        """Recursively copy values from ``config_dict`` onto this config.

        Keys that do not already exist on the instance are silently ignored;
        nested ``MyPretrainedConfig`` attributes are updated in place rather
        than overwritten by plain dicts.
        """
        for name, new_value in config_dict.items():
            if not hasattr(self, name):
                continue
            current = getattr(self, name)
            if isinstance(current, MyPretrainedConfig):
                current.update(config_dict[name])
            else:
                setattr(self, name, new_value)

    # Copied from the parent class, only changed use_diff from True to False to correctly save nested config class
    def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
        """
        Save a configuration object to the directory `save_directory`, so that it can be re-loaded using the
        [`~PretrainedConfig.from_pretrained`] class method.

        Args:
            save_directory (`str` or `os.PathLike`):
                Directory where the configuration JSON file will be saved (will be created if it does not exist).
            push_to_hub (`bool`, *optional*, defaults to `False`):
                Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
                repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
                namespace).
            kwargs (`Dict[str, Any]`, *optional*):
                Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
        """
        self._set_token_in_kwargs(kwargs)

        if os.path.isfile(save_directory):
            raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")

        # Warn (as upstream transformers does) when generation parameters are
        # stored on the model config instead of a GenerationConfig.
        non_default_generation_parameters = {}
        for parameter_name, default_value in self._get_global_generation_defaults().items():
            if hasattr(self, parameter_name) and getattr(self, parameter_name) != default_value:
                non_default_generation_parameters[parameter_name] = getattr(self, parameter_name)
        if non_default_generation_parameters:
            logger.warning(
                "Some non-default generation parameters are set in the model config. These should go into a "
                "GenerationConfig file (https://huggingface.co/docs/transformers/generation_strategies#save-a-custom-decoding-strategy-with-your-model) "
                "instead. This warning will be raised to an exception in v4.41.\n"
                f"Non-default generation parameters: {str(non_default_generation_parameters)}"
            )

        os.makedirs(save_directory, exist_ok=True)

        if push_to_hub:
            commit_message = kwargs.pop("commit_message", None)
            repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
            repo_id = self._create_repo(repo_id, **kwargs)
            files_timestamps = self._get_files_timestamps(save_directory)

        # A custom config class must ship the module defining it alongside the
        # weights so it can be re-imported from the Hub.
        if self._auto_class is not None:
            custom_object_save(self, save_directory, config=self)

        # Saving under the predefined name allows loading via `from_pretrained`.
        output_config_file = os.path.join(save_directory, CONFIG_NAME)
        self.to_json_file(output_config_file, use_diff=False)
        logger.info(f"Configuration saved in {output_config_file}")

        if push_to_hub:
            self._upload_modified_files(
                save_directory,
                repo_id,
                files_timestamps,
                commit_message=commit_message,
                token=kwargs.get("token"),
            )

    # Copied from the parent class, change the instantiation and updating of class from config_dict to correctly load nested config
    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "MyPretrainedConfig":
        """
        Instantiates a [`PretrainedConfig`] from a Python dictionary of parameters.

        Args:
            config_dict (`Dict[str, Any]`):
                Dictionary that will be used to instantiate the configuration object. Such a dictionary can be
                retrieved from a pretrained checkpoint by leveraging the [`~PretrainedConfig.get_config_dict`] method.
            kwargs (`Dict[str, Any]`):
                Additional parameters from which to initialize the configuration object.

        Returns:
            [`PretrainedConfig`]: The configuration object instantiated from those parameters.
        """
        return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
        # Internal telemetry arguments must not leak into `return_unused_kwargs`.
        kwargs.pop("_from_auto", None)
        kwargs.pop("_from_pipeline", None)
        # A commit hash already present in `config_dict` wins over the kwarg.
        if "_commit_hash" in kwargs and "_commit_hash" in config_dict:
            kwargs["_commit_hash"] = config_dict["_commit_hash"]

        # Moved out of kwargs so it does not appear in `return_unused_kwargs`.
        config_dict["attn_implementation"] = kwargs.pop("attn_implementation", None)

        # Instead of `cls(**config_dict)`, build a default instance and copy
        # values over so nested MyPretrainedConfig attributes update in place.
        config = cls()
        for name, new_value in config_dict.items():
            if not hasattr(config, name):
                continue
            current = getattr(config, name)
            if isinstance(current, MyPretrainedConfig):
                current.update(config_dict[name])
            else:
                setattr(config, name, new_value)

        if hasattr(config, "pruned_heads"):
            config.pruned_heads = {int(head): layers for head, layers in config.pruned_heads.items()}

        # Update config with kwargs if needed
        if "num_labels" in kwargs and "id2label" in kwargs:
            num_labels = kwargs["num_labels"]
            id2label = kwargs["id2label"] if kwargs["id2label"] is not None else []
            if len(id2label) != num_labels:
                raise ValueError(
                    f"You passed along `num_labels={num_labels }` with an incompatible id to label map: "
                    f"{kwargs['id2label']}. Since those arguments are inconsistent with each other, you should remove "
                    "one of them."
                )

        consumed = []
        for name, new_value in kwargs.items():
            if hasattr(config, name):
                current_attr = getattr(config, name)
                # To authorize passing a custom subconfig as kwarg in models that have nested configs.
                if isinstance(current_attr, PretrainedConfig) and isinstance(new_value, dict):
                    new_value = current_attr.__class__(**new_value)
                setattr(config, name, new_value)
                if name != "torch_dtype":
                    consumed.append(name)
        for name in consumed:
            kwargs.pop(name, None)

        logger.info(f"Model config {config}")
        return (config, kwargs) if return_unused_kwargs else config
184
+
185
+
186
class D2v2ModalityConfig(MyPretrainedConfig):
    """Configuration options common to every data2vec 2.0 modality encoder.

    Groups the prenet settings, the stochastic drop-path schedule, the
    masking options used during pre-training, and the ALiBi attention-bias
    options shared by the audio and text sub-configs.
    """

    def __init__(
        self,
        type="AUDIO",
        prenet_depth=4,
        prenet_layerdrop=0,
        prenet_dropout=0.0,
        start_drop_path_rate=0.0,
        end_drop_path_rate=0.0,
        num_extra_tokens=0,
        init_extra_token_zero=True,
        mask_noise_std=0.01,
        mask_prob_min=None,
        mask_prob=0.7,
        inverse_mask=False,
        mask_prob_adjust=0.0,
        keep_masked_pct=0.0,
        mask_length=5,
        add_masks=False,
        remove_masks=False,
        mask_dropout=0.0,
        encoder_zero_mask=True,
        mask_channel_prob=0.0,
        mask_channel_length=64,
        local_grad_mult=1.0,
        use_alibi_encoder=False,
        alibi_scale=1.0,
        learned_alibi=False,
        alibi_max_pos=None,
        learned_alibi_scale=False,
        learned_alibi_scale_per_head=False,
        learned_alibi_scale_per_layer=False,
        num_alibi_heads=12,
        model_depth=12,
        ema_local_encoder=False,
        decoder=None,
        **kwargs,
    ):
        # NOTE(review): `ema_local_encoder` and `decoder` are accepted but
        # never stored on the instance — presumably so fairseq-style config
        # dicts can be passed through without error; confirm intentional.
        super().__init__(**kwargs)
        self.type = type

        # Prenet (modality-specific front end) and drop-path schedule.
        self.prenet_depth = prenet_depth
        self.prenet_layerdrop = prenet_layerdrop
        self.prenet_dropout = prenet_dropout
        self.start_drop_path_rate = start_drop_path_rate
        self.end_drop_path_rate = end_drop_path_rate
        self.num_extra_tokens = num_extra_tokens
        self.init_extra_token_zero = init_extra_token_zero

        # Masking options.
        self.mask_noise_std = mask_noise_std
        self.mask_prob_min = mask_prob_min
        self.mask_prob = mask_prob
        self.inverse_mask = inverse_mask
        self.mask_prob_adjust = mask_prob_adjust
        self.keep_masked_pct = keep_masked_pct
        self.mask_length = mask_length
        self.add_masks = add_masks
        self.remove_masks = remove_masks
        self.mask_dropout = mask_dropout
        self.encoder_zero_mask = encoder_zero_mask
        self.mask_channel_prob = mask_channel_prob
        self.mask_channel_length = mask_channel_length

        self.local_grad_mult = local_grad_mult

        # ALiBi attention-bias options.
        self.use_alibi_encoder = use_alibi_encoder
        self.alibi_scale = alibi_scale
        self.learned_alibi = learned_alibi
        self.alibi_max_pos = alibi_max_pos
        self.learned_alibi_scale = learned_alibi_scale
        self.learned_alibi_scale_per_head = learned_alibi_scale_per_head
        self.learned_alibi_scale_per_layer = learned_alibi_scale_per_layer
        self.num_alibi_heads = num_alibi_heads

        self.model_depth = model_depth
256
+
257
+
258
class D2v2AudioConfig(D2v2ModalityConfig):
    """
    Configuration including common args and args specific to audio-only
    pre-training: the convolutional feature extractor and the convolutional
    positional-embedding options.
    """

    def __init__(
        self,
        extractor_mode="layer_norm",
        feature_encoder_spec="[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] + [(512,2,2)]",
        conv_pos_width=95,
        conv_pos_groups=16,
        conv_pos_depth=5,
        conv_pos_pre_ln=False,
        **kwargs,
    ):
        # The modality type is fixed to AUDIO; everything else flows through.
        super().__init__(type="AUDIO", **kwargs)
        self.extractor_mode = extractor_mode
        # Spec string for the conv feature encoder — presumably a list of
        # (dim, kernel, stride) tuples evaluated by the model; confirm there.
        self.feature_encoder_spec = feature_encoder_spec
        self.conv_pos_width = conv_pos_width
        self.conv_pos_groups = conv_pos_groups
        self.conv_pos_depth = conv_pos_depth
        self.conv_pos_pre_ln = conv_pos_pre_ln
279
+
280
+
281
class D2v2TextConfig(D2v2ModalityConfig):
    """
    Configuration including common args and args specific to text-only
    pre-training: vocabulary, special-token ids, and embedding options.
    """

    def __init__(
        self,
        vocab_size=50000,
        unk_token_id=3,
        bos_token_id=0,
        eos_token_id=2,
        pad_token_id=1,
        max_source_positions=512,
        learned_pos=True,
        dropout=0.1,
        no_scale_embedding=True,
        layernorm_embedding=True,
        no_token_positional_embeddings=False,
        **kwargs,
    ):
        # The modality type is fixed to TEXT; everything else flows through.
        super().__init__(type="TEXT", **kwargs)
        # Vocabulary and special-token ids.
        self.vocab_size = vocab_size
        self.unk_token_id = unk_token_id
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id
        # Positional / embedding options.
        self.max_source_positions = max_source_positions
        self.learned_pos = learned_pos
        self.dropout = dropout
        self.no_scale_embedding = no_scale_embedding
        self.layernorm_embedding = layernorm_embedding
        self.no_token_positional_embeddings = no_token_positional_embeddings
312
+
313
+
314
class D2v2ModalitiesConfig(MyPretrainedConfig):
    """Container grouping the per-modality sub-configurations.

    Args:
        audio_config (`D2v2AudioConfig`, *optional*):
            Audio modality configuration; a fresh default is created when
            not given.
        text_config (`D2v2TextConfig`, *optional*):
            Text modality configuration; a fresh default is created when
            not given.
    """

    def __init__(
        self,
        audio_config=None,
        text_config=None,
        **kwargs
    ):
        super().__init__(**kwargs)
        # Fix: the previous defaults (`audio_config=D2v2AudioConfig()`,
        # `text_config=D2v2TextConfig()`) were mutable default arguments
        # evaluated once at class-definition time, so every instance created
        # without explicit values shared — and could mutate — the same
        # sub-config objects. Build fresh ones per instance instead.
        self.audio = audio_config if audio_config is not None else D2v2AudioConfig()
        self.text = text_config if text_config is not None else D2v2TextConfig()
324
+
325
+
326
class Data2Vec2MultiConfig(MyPretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`Data2Vec2MultiModel`]. It is used to instantiate
    an Data2Vec2MultiModel model according to the specified arguments, defining the model architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        depth (`int`, *optional*, defaults to 12):
            Number of Transformer layers in the encoder.
        modalities (`D2v2ModalitiesConfig`, *optional*):
            Per-modality sub-configurations; a fresh default is created when not given.
        supported_modality (`str`, *optional*, defaults to `"AUDIO"`):
            Which modality this checkpoint supports.

    Example:

    ```python
    >>> from transformers import Data2Vec2MultiConfig, Data2Vec2MultiModel

    >>> # Initializing a Data2Vec2MultiConfig for audio
    >>> configuration = Data2Vec2MultiConfig()

    >>> # Initializing a model (with random weights) with the configuration
    >>> model = Data2Vec2MultiModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "data2vec2"

    def __init__(
        self,
        depth=12,
        start_drop_path_rate=0.0,
        end_drop_path_rate=0.0,
        num_heads=12,
        norm_eps=1e-5,
        norm_affine=True,
        encoder_dropout=0.1,
        post_mlp_drop=0.1,
        attention_dropout=0.1,
        activation_dropout=0.0,
        dropout_input=0.0,
        layerdrop=0.0,
        embed_dim=768,
        mlp_ratio=4.0,
        layer_norm_first=False,
        end_of_block_targets=False,
        clone_batch=1,
        log_norms=True,
        modalities=None,
        supported_modality="AUDIO",
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Encoder depth and stochastic drop-path schedule.
        self.depth = depth
        self.start_drop_path_rate = start_drop_path_rate
        self.end_drop_path_rate = end_drop_path_rate

        # Transformer block hyper-parameters.
        self.num_heads = num_heads
        self.norm_eps = norm_eps
        self.norm_affine = norm_affine
        self.post_mlp_drop = post_mlp_drop
        self.encoder_dropout = encoder_dropout
        self.attention_dropout = attention_dropout
        self.activation_dropout = activation_dropout
        self.dropout_input = dropout_input
        self.layerdrop = layerdrop
        self.embed_dim = embed_dim
        self.mlp_ratio = mlp_ratio

        self.layer_norm_first = layer_norm_first
        self.end_of_block_targets = end_of_block_targets
        self.clone_batch = clone_batch
        self.log_norms = log_norms

        # Fix: the previous default `modalities=D2v2ModalitiesConfig()` was a
        # mutable default argument evaluated once at class-definition time, so
        # every config created without an explicit value shared (and could
        # mutate) the same nested instance. Build a fresh one per config.
        self.modalities = modalities if modalities is not None else D2v2ModalitiesConfig()
        self.supported_modality = supported_modality

        # Attributes for hopsparser (aliases of depth/embed_dim).
        self.hidden_size = embed_dim
        self.num_layers = depth
        self.n_layers = depth
        self.num_hidden_layers = depth

        self.auto_map = {
            'AutoConfig': 'configuration_data2vec2.Data2Vec2MultiConfig',
            'AutoModel': 'modeling_data2vec2.Data2Vec2MultiModel',
        }
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d1f3a5b7e501b52e0d9d4ad11a9396cb2956e01e2972e4062c2dbb844c1419b7
3
+ size 496547472
modeling_data2vec2.py ADDED
@@ -0,0 +1,1466 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ #
3
+ # Copyright (c) Facebook, Inc. and its affiliates.
4
+ #
5
+ # This source code is licensed under the MIT license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+ #
8
+ # Copyright 2022 the HuggingFace Inc. team. All rights reserved.
9
+ #
10
+ # Licensed under the Apache License, Version 2.0 (the "License");
11
+ # you may not use this file except in compliance with the License.
12
+ # You may obtain a copy of the License at
13
+ #
14
+ # http://www.apache.org/licenses/LICENSE-2.0
15
+ #
16
+ # Unless required by applicable law or agreed to in writing, software
17
+ # distributed under the License is distributed on an "AS IS" BASIS,
18
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19
+ # See the License for the specific language governing permissions and
20
+ # limitations under the License.
21
+
22
+ # Copyright from Fairseq
23
+
24
+ """ PyTorch Data2Vec2 Multi model."""
25
+ import math
26
+ import warnings
27
+ from typing import Optional, Tuple, Dict, List, Callable, Any
28
+ from functools import partial
29
+
30
+ import numpy as np
31
+
32
+ import torch
33
+ import torch.nn.functional as F
34
+ from torch import nn
35
+ from torch import Tensor
36
+
37
+ from transformers import PreTrainedModel
38
+ from transformers.modeling_outputs import (
39
+ Wav2Vec2BaseModelOutput,
40
+ )
41
+ from .configuration_data2vec2 import (
42
+ Data2Vec2MultiConfig,
43
+ D2v2ModalityConfig,
44
+ D2v2AudioConfig,
45
+ D2v2TextConfig,
46
+ )
47
+
48
+ from .utils_data2vec2 import (
49
+ _learned_alibi_bias,
50
+ gather_unmasked,
51
+ gather_unmasked_mask,
52
+ masked_alibi,
53
+ random_masking,
54
+ get_alibi_bias,
55
+ compute_mask_indices,
56
+ index_put,
57
+ MaskInfo, MaskSeed,
58
+ make_positions,
59
+ )
60
+
61
+
62
+ #################################################
63
+ ### modeling_data2vec2_base.py
64
+ # copied from fairseq.modules.grad_multiply
65
class GradMultiply(torch.autograd.Function):
    """Identity in the forward pass; scales the incoming gradient by a
    constant factor in the backward pass. Used to damp (or boost) the
    gradient flowing into a sub-network without changing activations.
    """

    @staticmethod
    def forward(ctx, x, scale):
        # Remember the factor for the backward pass and hand back a copy
        # of the input (x.new(x) allocates a new tensor from x's data).
        ctx.scale = scale
        return x.new(x)

    @staticmethod
    def backward(ctx, grad):
        # No gradient with respect to the `scale` argument itself.
        return grad * ctx.scale, None
75
+
76
+
77
+ # Copied from fairseq.modules.transpose_last.py
78
class TransposeLast(nn.Module):
    """Swap the last dimension with another one (default: second-to-last),
    optionally indexing into the input first via ``deconstruct_idx``.

    The parameter name ``tranpose_dim`` (sic) is kept as-is: it is part of
    the public constructor signature.
    """

    def __init__(self, deconstruct_idx=None, tranpose_dim=-2):
        super().__init__()
        self.deconstruct_idx = deconstruct_idx
        self.tranpose_dim = tranpose_dim

    def forward(self, x):
        tensor = x if self.deconstruct_idx is None else x[self.deconstruct_idx]
        return tensor.transpose(self.tranpose_dim, -1)
88
+
89
+
90
+ # Copied from fairseq.modules.layer_norm.py
91
class Fp32LayerNorm(nn.LayerNorm):
    """LayerNorm that always computes in float32 for numerical stability,
    then casts the result back to the input's dtype.
    """

    def forward(self, input):
        weight = self.weight.float() if self.weight is not None else None
        bias = self.bias.float() if self.bias is not None else None
        normalized = F.layer_norm(
            input.float(), self.normalized_shape, weight, bias, self.eps
        )
        return normalized.type_as(input)
104
+
105
+
106
def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True):
    """Factory returning a plain ``torch.nn.LayerNorm`` (fairseq-style helper)."""
    return torch.nn.LayerNorm(
        normalized_shape, eps=eps, elementwise_affine=elementwise_affine
    )
108
+
109
+
110
+ # Copied from fairseq.modules.fp32_group_norm.py
111
class Fp32GroupNorm(nn.GroupNorm):
    """GroupNorm that always computes in float32 for numerical stability,
    then casts the result back to the input's dtype.
    """

    def forward(self, input):
        weight = self.weight.float() if self.weight is not None else None
        bias = self.bias.float() if self.bias is not None else None
        normalized = F.group_norm(
            input.float(), self.num_groups, weight, bias, self.eps
        )
        return normalized.type_as(input)
124
+
125
+
126
+ # Copied from fairseq.modules.same_pad.py
127
class SamePad(nn.Module):
    """Trim trailing frames introduced by symmetric convolution padding so the
    output length matches the input ("same" padding for even kernels).
    In causal mode the full ``kernel_size - 1`` lookahead is removed.
    """

    def __init__(self, kernel_size, causal=False):
        super().__init__()
        if causal:
            self.remove = kernel_size - 1
        else:
            # Even kernels overshoot by exactly one frame; odd kernels don't.
            self.remove = 1 if kernel_size % 2 == 0 else 0

    def forward(self, x):
        return x if self.remove <= 0 else x[:, :, : -self.remove]
139
+
140
+
141
+ # Copied from fairseq.models.wav2vec.wav2vec2.py
142
class ConvFeatureExtractionModel(nn.Module):
    """Stack of strided 1-D convolutions that turns a raw waveform of shape
    (B, T) into a feature map of shape (B, C, T').

    Args:
        conv_layers: list of ``(dim, kernel_size, stride)`` tuples, one per layer.
        dropout: dropout applied after every convolution.
        mode: ``"default"`` (fp32 group norm on the first layer only) or
            ``"layer_norm"`` (fp32 layer norm after every layer).
        conv_bias: whether the convolutions use a bias term.
    """

    def __init__(
        self,
        conv_layers: List[Tuple[int, int, int]],
        dropout: float = 0.0,
        mode: str = "default",
        conv_bias: bool = False,
    ):
        super().__init__()

        assert mode in {"default", "layer_norm"}

        def block(
            n_in,
            n_out,
            k,
            stride,
            is_layer_norm=False,
            is_group_norm=False,
            conv_bias=False,
        ):
            # One layer: conv -> dropout -> (optional norm) -> GELU.
            def make_conv():
                conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias)
                nn.init.kaiming_normal_(conv.weight)
                return conv

            assert not (
                is_layer_norm and is_group_norm
            ), "layer norm and group norm are exclusive"

            if is_layer_norm:
                # Fix: normalize over this layer's `n_out` channels instead of
                # relying on the enclosing loop's `dim` variable through a
                # closure. The values are identical at call time, but the
                # closure form only worked because `block` happened to be
                # called after `dim` was assigned.
                return nn.Sequential(
                    make_conv(),
                    nn.Dropout(p=dropout),
                    nn.Sequential(
                        TransposeLast(),
                        Fp32LayerNorm(n_out, elementwise_affine=True),
                        TransposeLast(),
                    ),
                    nn.GELU(),
                )
            elif is_group_norm:
                # Same fix as above: use `n_out`, not the loop's `dim`.
                return nn.Sequential(
                    make_conv(),
                    nn.Dropout(p=dropout),
                    Fp32GroupNorm(n_out, n_out, affine=True),
                    nn.GELU(),
                )
            else:
                return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU())

        in_d = 1
        self.conv_layers = nn.ModuleList()
        for i, cl in enumerate(conv_layers):
            assert len(cl) == 3, "invalid conv definition: " + str(cl)
            (dim, k, stride) = cl

            self.conv_layers.append(
                block(
                    in_d,
                    dim,
                    k,
                    stride,
                    is_layer_norm=mode == "layer_norm",
                    # Group norm is applied to the first layer only.
                    is_group_norm=mode == "default" and i == 0,
                    conv_bias=conv_bias,
                )
            )
            in_d = dim

    def forward(self, x):
        """Map raw audio (B, T) to convolutional features (B, C, T')."""
        # BxT -> BxCxT
        x = x.unsqueeze(1)

        for conv in self.conv_layers:
            x = conv(x)

        return x
221
+
222
+
223
+ # copied from fairseq.examples.data2vec.models.modalities.modules
224
+ class AltAttention(nn.Module):
225
+ def __init__(
226
+ self,
227
+ dim,
228
+ num_heads=8,
229
+ qkv_bias=False,
230
+ qk_scale=None,
231
+ attn_drop=0.0,
232
+ proj_drop=0.0,
233
+ cosine_attention=False,
234
+ ):
235
+ super().__init__()
236
+ self.num_heads = num_heads
237
+ head_dim = dim // num_heads
238
+ self.scale = qk_scale or head_dim ** -0.5
239
+
240
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
241
+ # self.attn_drop = nn.Dropout(attn_drop)
242
+ self.attn_drop = attn_drop
243
+ self.proj = nn.Linear(dim, dim)
244
+ # self.proj_drop = nn.Dropout(proj_drop)
245
+ self.proj_drop = proj_drop
246
+
247
+ self.cosine_attention = cosine_attention
248
+
249
+ if cosine_attention:
250
+ self.logit_scale = nn.Parameter(
251
+ torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True
252
+ )
253
+
254
+ def forward(self, x, padding_mask=None, alibi_bias=None, fast=True):
255
+ B, N, C = x.shape
256
+ qkv = (
257
+ self.qkv(x)
258
+ .reshape(B, N, 3, self.num_heads, C // self.num_heads)
259
+ .permute(2, 0, 3, 1, 4) # qkv x B x H x L x D
260
+ )
261
+ q, k, v = (
262
+ qkv[0],
263
+ qkv[1],
264
+ qkv[2],
265
+ ) # make torchscript happy (cannot use tensor as tuple)
266
+
267
+ dtype = q.dtype
268
+
269
+ if not fast:
270
+ if self.cosine_attention:
271
+ # cosine attention
272
+ attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)
273
+ logit_scale = torch.clamp(
274
+ self.logit_scale, max=torch.log(torch.tensor(1.0 / 0.01))
275
+ ).exp()
276
+ attn = attn * logit_scale
277
+ else:
278
+ q = q * self.scale
279
+ attn = q @ k.transpose(-2, -1) # B x C//H x L x L
280
+
281
+ if alibi_bias is not None:
282
+ attn = attn.type_as(alibi_bias)
283
+ attn[:, : alibi_bias.size(1)] += alibi_bias
284
+
285
+ if padding_mask is not None and padding_mask.any():
286
+ attn = attn.masked_fill(
287
+ padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
288
+ float("-inf"),
289
+ )
290
+
291
+ attn = attn.softmax(dim=-1, dtype=torch.float32).to(dtype=dtype)
292
+ # attn = self.attn_drop(attn)
293
+ attn = F.dropout(attn, p=self.attn_drop)
294
+ x = (attn @ v).transpose(1, 2)
295
+ else:
296
+ # Using pytorch 2's sdpa
297
+ assert not self.cosine_attention, "Not support cosine attention yet"
298
+ # Integrate padding_mask and alibi_bias
299
+ if padding_mask is not None and padding_mask.any():
300
+ if alibi_bias is not None:
301
+ padding_mask = alibi_bias.masked_fill(
302
+ padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
303
+ float("-inf"),
304
+ ).to(dtype=dtype)
305
+ else:
306
+ padding_mask = padding_mask.unsqueeze(1).unsqueeze(2).to(
307
+ torch.bool).to(dtype=dtype)
308
+ else:
309
+ if alibi_bias is not None:
310
+ padding_mask = alibi_bias.to(dtype=dtype)
311
+ else:
312
+ padding_mask = None
313
+
314
+ x = F.scaled_dot_product_attention(q, k, v,
315
+ attn_mask=padding_mask,
316
+ dropout_p=self.attn_drop if self.training else 0.0,
317
+ scale=self.scale).transpose(1, 2)
318
+
319
+ x = x.reshape(B, N, C)
320
+ x = self.proj(x)
321
+ x = F.dropout(x, p=self.proj_drop if self.training else 0.0)
322
+ return x
323
+
324
+
325
+ # copied from fairseq.examples.data2vec.models.modalities.modules.py
326
class AltBlock(nn.Module):
    """Transformer block (self-attention + MLP) used by data2vec 2.0.

    Forward returns a pair ``(x, t)``: ``x`` is the block output passed to
    the next layer, ``t`` is the feature used as a regression target — the
    raw MLP output when ``ffn_targets`` is True, otherwise the final block
    output.
    """

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        mlp_drop=0.0,
        post_mlp_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        layer_norm_first=True,
        ffn_targets=False,
        cosine_attention=False,
    ):
        super().__init__()

        self.layer_norm_first = layer_norm_first
        self.ffn_targets = ffn_targets

        # Deferred third-party import: timm is only required once a block is
        # actually instantiated.
        from timm.models.vision_transformer import DropPath, Mlp

        self.norm1 = norm_layer(dim)
        self.attn = AltAttention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
            cosine_attention=cosine_attention,
        )

        # Stochastic depth on the residual branches; identity when disabled.
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=mlp_drop,
        )
        self.post_mlp_dropout = nn.Dropout(post_mlp_drop, inplace=False)

    def forward(self, x, padding_mask=None, alibi_bias=None):
        if self.layer_norm_first:
            # Pre-norm: attention sub-layer with residual connection.
            x = x + self.drop_path(self.attn(self.norm1(x), padding_mask, alibi_bias))
            # NOTE(review): `r` is bound to the MLP *output* here, so the line
            # below adds the MLP output to (a dropped-out copy of) itself and
            # the pre-MLP activations are not carried through a residual. This
            # transcribes the upstream fairseq code verbatim and pretrained
            # checkpoints depend on the exact arithmetic — do not "fix" it
            # without confirming against upstream fairseq.
            r = x = self.mlp(self.norm2(x))
            t = x
            x = r + self.drop_path(self.post_mlp_dropout(x))
            if not self.ffn_targets:
                t = x
        else:
            # Post-norm variant: normalize after each residual sum.
            x = x + self.drop_path(self.attn(x, padding_mask, alibi_bias))
            r = x = self.norm1(x)
            x = self.mlp(x)
            t = x
            x = self.norm2(r + self.drop_path(self.post_mlp_dropout(x)))
            if not self.ffn_targets:
                t = x

        return x, t
392
+
393
+
394
+ # copied from fairseq.data2vec.models.modalities.modules
395
class BlockEncoder(nn.Module):
    """Runs a stack of ``AltBlock``-style blocks with optional pre/post layer
    norm, input dropout, per-layer alibi scaling, and fairseq-style LayerDrop.
    """

    def __init__(self, blocks, norm_layer, layer_norm_first, layerdrop, dropout):
        super().__init__()
        self.blocks = blocks
        self.norm = norm_layer
        self.layer_norm_first = layer_norm_first
        self.layerdrop = layerdrop
        self.dropout = nn.Dropout(dropout, inplace=True)

    def _keep_layer(self):
        # LayerDrop: always keep a layer at eval time or when disabled;
        # otherwise keep it with probability (1 - layerdrop). The RNG is only
        # consulted when layerdrop is active, preserving the random stream.
        return (
            not self.training
            or self.layerdrop == 0
            or (np.random.random() > self.layerdrop)
        )

    def forward(self, x, padding_mask, alibi_bias, alibi_scale):
        # Post-norm configurations normalize before entering the stack.
        if self.norm is not None and not self.layer_norm_first:
            x = self.norm(x)

        x = self.dropout(x)

        for layer_idx, layer in enumerate(self.blocks):
            if not self._keep_layer():
                continue
            bias = alibi_bias
            if bias is not None and alibi_scale is not None:
                # One learned scale per layer, or a single shared scale.
                if alibi_scale.size(0) > 1:
                    scale = alibi_scale[layer_idx]
                else:
                    scale = alibi_scale.squeeze(0)
                bias = bias * scale.type_as(bias)
            x, _ = layer(x, padding_mask, bias)

        # Pre-norm configurations normalize the final output.
        if self.norm is not None and self.layer_norm_first:
            x = self.norm(x)

        return x
430
+
431
+
432
class ModalitySpecificEncoder(nn.Module):
    """Base class for per-modality front-ends in data2vec 2.0.

    Pipeline: ``local_encoder`` (modality-specific feature extraction) ->
    ``project_features`` -> optional positional encodings -> optional masking
    (with per-sample clone batching) -> ``context_encoder`` (prenet blocks).
    Subclasses (audio/text) supply the concrete sub-modules and override
    ``convert_padding_mask``.

    Note: the ``decoder`` constructor argument is accepted for interface
    compatibility but ignored — ``self.decoder`` is always set to ``None``
    in this port.
    """

    def __init__(
        self,
        modality_cfg: D2v2ModalityConfig,
        embed_dim: int,
        local_encoder: nn.Module,
        project_features: nn.Module,
        fixed_positional_encoder: Optional[nn.Module],
        relative_positional_encoder: Optional[nn.Module],
        context_encoder: nn.Module,
        decoder: nn.Module,
        get_alibi_bias: Optional[Callable[[int, int, str, str], torch.Tensor]],
    ):
        super().__init__()

        self.modality_cfg = modality_cfg
        self.local_encoder = local_encoder
        self.project_features = project_features
        self.fixed_positional_encoder = fixed_positional_encoder
        self.relative_positional_encoder = relative_positional_encoder
        self.context_encoder = context_encoder

        # Decoder intentionally dropped (inference-oriented port).
        self.decoder = None
        self.get_alibi_bias = get_alibi_bias if modality_cfg.use_alibi_encoder else None

        # Gradient multiplier for the local feature extractor (0 freezes it).
        self.local_grad_mult = self.modality_cfg.local_grad_mult

        # Optional learned prefix tokens (e.g. CLS-style) prepended to the
        # sequence in contextualized_features().
        self.extra_tokens = None
        if modality_cfg.num_extra_tokens > 0:
            self.extra_tokens = nn.Parameter(
                torch.zeros(1, modality_cfg.num_extra_tokens, embed_dim)
            )
            if not modality_cfg.init_extra_token_zero:
                nn.init.normal_(self.extra_tokens)
            elif self.extra_tokens.size(1) > 1:
                # First token stays zero; the rest are randomly initialized.
                nn.init.normal_(self.extra_tokens[:, 1:])

        # Learned (or fixed) scaling of the alibi bias, optionally per layer
        # and/or per head; shape (layers, 1, heads, 1, 1) with singleton dims
        # where the scale is shared.
        self.alibi_scale = None
        if self.get_alibi_bias is not None:
            self.alibi_scale = nn.Parameter(
                torch.full(
                    (
                        (modality_cfg.prenet_depth + modality_cfg.model_depth)
                        if modality_cfg.learned_alibi_scale_per_layer
                        else 1,
                        1,
                        self.modality_cfg.num_alibi_heads
                        if modality_cfg.learned_alibi_scale_per_head
                        else 1,
                        1,
                        1,
                    ),
                    modality_cfg.alibi_scale,
                    dtype=torch.float,
                ),
                requires_grad=modality_cfg.learned_alibi_scale,
            )

        if modality_cfg.learned_alibi and self.get_alibi_bias is not None:
            # Fully learned alibi: materialize a bias up to alibi_max_pos as a
            # parameter and serve (cropped) slices of it from then on.
            assert modality_cfg.alibi_max_pos is not None
            alibi_bias = self.get_alibi_bias(
                batch_size=1,
                time_steps=modality_cfg.alibi_max_pos,
                heads=modality_cfg.num_alibi_heads,
                scale=1.0,
                dtype=torch.float,
                device="cpu",
            )
            self.alibi_bias = nn.Parameter(alibi_bias)
            self.get_alibi_bias = partial(
                _learned_alibi_bias, alibi_bias=self.alibi_bias
            )

    # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureEncoder._freeze_parameters
    def _freeze_parameters(self):
        # Disable gradients for every parameter of this encoder.
        for param in self.parameters():
            param.requires_grad = False
        self._requires_grad = False

    def convert_padding_mask(self, x, padding_mask):
        # Default: pass-through; subclasses (e.g. audio) downsample the mask
        # to match the local encoder's output resolution.
        return padding_mask

    def local_features(self, features):
        """Run the local encoder (with optional gradient scaling) and project."""
        if self.local_grad_mult > 0:
            if self.local_grad_mult == 1.0:
                x = self.local_encoder(features)
            else:
                # Scale gradients flowing into the local encoder.
                x = GradMultiply.apply(
                    self.local_encoder(features), self.local_grad_mult
                )
        else:
            # Frozen feature extractor.
            with torch.no_grad():
                x = self.local_encoder(features)

        x = self.project_features(x)
        return x

    def contextualized_features(
        self,
        x,
        padding_mask,
        mask,
        remove_masked,
        clone_batch: int = 1,
        mask_seeds: Optional[torch.Tensor] = None,
        precomputed_mask=None,
    ):
        """Apply positions, (optional) masking/clone-batching, alibi bias and
        the prenet context encoder; returns a dict of intermediate results.
        """

        if padding_mask is not None:
            padding_mask = self.convert_padding_mask(x, padding_mask)

        local_features = x
        if mask and clone_batch == 1:
            # Keep an unmodified copy of the pre-mask features.
            local_features = local_features.clone()

        orig_B, orig_T, _ = x.shape
        pre_mask_B = orig_B
        mask_info = None

        x_pos = None
        if self.fixed_positional_encoder is not None:
            x = x + self.fixed_positional_encoder(x, padding_mask)

        if mask:
            if clone_batch > 1:
                # Repeat each sample `clone_batch` times so several different
                # masks of the same sample are processed in one batch.
                x = x.repeat_interleave(clone_batch, 0)
                if mask_seeds is not None:
                    # Derive a distinct deterministic seed offset per clone.
                    clone_hash = [
                        int(hash((mask_seeds.seed, ind)) % 1e10)
                        for ind in range(clone_batch - 1)
                    ]
                    clone_hash = torch.tensor([0] + clone_hash).long().view(1, -1)

                    id = mask_seeds.ids
                    id = id.repeat_interleave(clone_batch, 0)
                    id = id.view(-1, clone_batch) + clone_hash.to(id)
                    id = id.view(-1)
                    mask_seeds = MaskSeed(
                        seed=mask_seeds.seed, update=mask_seeds.update, ids=id
                    )
                if padding_mask is not None:
                    padding_mask = padding_mask.repeat_interleave(clone_batch, 0)

            x, mask_info = self.compute_mask(
                x,
                padding_mask,
                mask_seed=mask_seeds,
                # Masked positions must be filled in-place (rather than
                # dropped) when relative positions are computed on the full
                # sequence or when masked timesteps are kept.
                apply=self.relative_positional_encoder is not None or not remove_masked,
                precomputed_mask=precomputed_mask,
            )

        if self.relative_positional_encoder is not None:
            x_pos = self.relative_positional_encoder(x)

        masked_padding_mask = padding_mask
        if mask and remove_masked:
            # Keep only unmasked timesteps (MAE-style).
            x = mask_info.x_unmasked
            if x_pos is not None:
                x = x + gather_unmasked(x_pos, mask_info)

            if padding_mask is not None and padding_mask.any():
                masked_padding_mask = gather_unmasked_mask(padding_mask, mask_info)
                if not masked_padding_mask.any():
                    masked_padding_mask = None
            else:
                masked_padding_mask = None

        elif x_pos is not None:
            x = x + x_pos

        alibi_bias = None
        alibi_scale = self.alibi_scale

        if self.get_alibi_bias is not None:
            alibi_bias = self.get_alibi_bias(
                batch_size=pre_mask_B,
                time_steps=orig_T,
                heads=self.modality_cfg.num_alibi_heads,
                dtype=torch.float32,
                device=x.device,
            )

            if alibi_scale is not None:
                alibi_scale = alibi_scale.clamp_min(0)
                if alibi_scale.size(0) == 1:
                    # Single shared scale: fold it into the bias now.
                    alibi_bias = alibi_bias * alibi_scale.squeeze(0).type_as(alibi_bias)
                    alibi_scale = None

            if clone_batch > 1:
                alibi_bias = alibi_bias.repeat_interleave(clone_batch, 0)

            if mask_info is not None and remove_masked:
                # Crop the bias to the kept (unmasked) timesteps.
                alibi_bias = masked_alibi(alibi_bias, mask_info)

        if self.extra_tokens is not None:
            num = self.extra_tokens.size(1)
            x = torch.cat([self.extra_tokens.expand(x.size(0), -1, -1), x], dim=1)
            if masked_padding_mask is not None:
                # B x T
                masked_padding_mask = F.pad(masked_padding_mask, (num, 0))
            if alibi_bias is not None:
                # B x H x T x T
                alibi_bias = F.pad(alibi_bias, (num, 0, num, 0))

        x = self.context_encoder(
            x,
            masked_padding_mask,
            alibi_bias,
            # Prenet layers consume the first prenet_depth per-layer scales.
            alibi_scale[: self.modality_cfg.prenet_depth]
            if alibi_scale is not None
            else None,
        )

        return {
            "x": x,
            "local_features": local_features,
            "padding_mask": masked_padding_mask,
            "alibi_bias": alibi_bias,
            # Remaining per-layer scales are handed to the shared trunk.
            "alibi_scale": alibi_scale[self.modality_cfg.prenet_depth :]
            if alibi_scale is not None and alibi_scale.size(0) > 1
            else alibi_scale,
            "encoder_mask": mask_info,
        }

    def forward(
        self,
        features,
        padding_mask,
        mask: bool,
        remove_masked: bool,
        clone_batch: int = 1,
        mask_seeds: Optional[torch.Tensor] = None,
        precomputed_mask=None,
    ):
        """Extract local features, then contextualize them (see above)."""
        x = self.local_features(features)
        return self.contextualized_features(
            x,
            padding_mask,
            mask,
            remove_masked,
            clone_batch,
            mask_seeds,
            precomputed_mask,
        )

    def compute_mask(
        self,
        x,
        padding_mask,
        mask_seed: Optional[MaskSeed],
        apply,
        precomputed_mask,
    ):
        """Build (or adopt) a time mask and optionally apply it to ``x``.

        Returns ``(x, mask_info)`` where ``mask_info`` is None when no
        masking was performed (mask_prob == 0 and no precomputed mask).
        """
        if precomputed_mask is not None:
            mask = precomputed_mask
            mask_info = self.make_maskinfo(x, mask)
        else:
            B, T, C = x.shape
            cfg = self.modality_cfg

            mask_prob = cfg.mask_prob

            # Optionally sample the mask probability uniformly from
            # [mask_prob_min, mask_prob) per call.
            if (
                cfg.mask_prob_min is not None
                and cfg.mask_prob_min >= 0
                and cfg.mask_prob_min < mask_prob
            ):
                mask_prob = np.random.uniform(cfg.mask_prob_min, mask_prob)

            if mask_prob > 0:
                if cfg.mask_length == 1:
                    mask_info = random_masking(x, mask_prob, mask_seed)
                else:
                    # Span masking; inverse_mask masks the complement.
                    if self.modality_cfg.inverse_mask:
                        mask_prob = 1 - mask_prob

                    mask = compute_mask_indices(
                        (B, T),
                        padding_mask,
                        mask_prob,
                        cfg.mask_length,
                        min_masks=1,
                        require_same_masks=True,
                        mask_dropout=cfg.mask_dropout,
                        add_masks=cfg.add_masks,
                        seed=mask_seed.seed if mask_seed is not None else None,
                        epoch=mask_seed.update if mask_seed is not None else None,
                        indices=mask_seed.ids if mask_seed is not None else None,
                    )

                    mask = torch.from_numpy(mask).to(device=x.device)
                    if self.modality_cfg.inverse_mask:
                        mask = 1 - mask
                    mask_info = self.make_maskinfo(x, mask)
            else:
                mask_info = None

        if apply:
            x = self.apply_mask(x, mask_info)

        return x, mask_info

    def make_maskinfo(self, x, mask, shape=None):
        """Turn a boolean/0-1 time mask into a MaskInfo with gather indices.

        Assumes every sample has the same number of masked timesteps
        (``require_same_masks=True``): ``len_keep`` is derived from row 0.
        """
        if shape is None:
            B, T, D = x.shape
        else:
            B, T, D = shape

        mask = mask.to(torch.uint8)
        # Stable argsort puts unmasked (0) positions first.
        ids_shuffle = mask.argsort(dim=1)
        ids_restore = ids_shuffle.argsort(dim=1).unsqueeze(-1).expand(-1, -1, D)

        len_keep = T - mask[0].sum()
        if self.modality_cfg.keep_masked_pct > 0:
            # Optionally retain a fraction of the masked timesteps as well.
            len_keep += round((T - int(len_keep)) * self.modality_cfg.keep_masked_pct)

        ids_keep = ids_shuffle[:, :len_keep]

        if shape is not None:
            x_unmasked = None
        else:
            ids_keep = ids_keep.unsqueeze(-1).expand(-1, -1, D)
            x_unmasked = torch.gather(x, dim=1, index=ids_keep)

        mask_info = MaskInfo(
            x_unmasked=x_unmasked,
            mask=mask,
            ids_restore=ids_restore,
            ids_keep=ids_keep,
        )
        return mask_info

    def apply_mask(self, x, mask_info):
        """Overwrite masked timesteps in-place-style (zeros or noise) and
        optionally apply channel masking."""
        cfg = self.modality_cfg
        B, T, C = x.shape

        if mask_info is not None:
            mask = mask_info.mask
            if cfg.encoder_zero_mask:
                # Zero out masked timesteps.
                x = x * (1 - mask.type_as(x).unsqueeze(-1))
            else:
                # Replace masked timesteps with Gaussian noise.
                num_masks = mask.sum().item()
                masks = x.new_empty(num_masks, x.size(-1)).normal_(
                    0, cfg.mask_noise_std
                )
                x = index_put(x, mask, masks)
        if cfg.mask_channel_prob > 0:
            # SpecAugment-style channel masking, shared across timesteps.
            mask_channel = compute_mask_indices(
                (B, C),
                None,
                cfg.mask_channel_prob,
                cfg.mask_channel_length,
            )
            mask_channel = (
                torch.from_numpy(mask_channel)
                .to(x.device)
                .unsqueeze(1)
                .expand(-1, T, -1)
            )
            x = index_put(x, mask_channel, 0)
        return x
793
+
794
+
795
class AudioEncoder(ModalitySpecificEncoder):
    """Audio modality front-end: waveform conv feature extractor, convolutional
    relative positional encoder, and prenet transformer blocks."""

    modality_cfg: D2v2AudioConfig

    def __init__(
        self,
        modality_cfg: D2v2AudioConfig,
        embed_dim: int,
        make_block: Callable[[float], nn.ModuleList],
        norm_layer: Callable[[int], nn.LayerNorm],
        layer_norm_first: bool,
        alibi_biases: Dict,
    ):

        # NOTE(review): eval() executes the config string to obtain the list
        # of (dim, kernel, stride) tuples. This mirrors upstream fairseq and
        # is only safe for trusted config files — ast.literal_eval would be
        # the hardened alternative.
        self.feature_enc_layers = eval(modality_cfg.feature_encoder_spec)
        feature_embed_dim = self.feature_enc_layers[-1][0]

        local_encoder = ConvFeatureExtractionModel(
            conv_layers=self.feature_enc_layers,
            dropout=0.0,
            mode=modality_cfg.extractor_mode,
            conv_bias=False,
        )

        # Project conv features (B x C x T) to the model dimension (B x T x E).
        project_features = nn.Sequential(
            TransposeLast(),
            nn.LayerNorm(feature_embed_dim),
            nn.Linear(feature_embed_dim, embed_dim),
        )

        num_pos_layers = modality_cfg.conv_pos_depth
        k = max(3, modality_cfg.conv_pos_width // num_pos_layers)

        # Stack of grouped convolutions acting as a relative positional
        # encoder (wav2vec2-style, split into num_pos_layers layers).
        positional_encoder = nn.Sequential(
            TransposeLast(),
            *[
                nn.Sequential(
                    nn.Conv1d(
                        embed_dim,
                        embed_dim,
                        kernel_size=k,
                        padding=k // 2,
                        groups=modality_cfg.conv_pos_groups,
                    ),
                    SamePad(k),
                    TransposeLast(),
                    LayerNorm(embed_dim, elementwise_affine=False),
                    TransposeLast(),
                    nn.GELU(),
                )
                for _ in range(num_pos_layers)
            ],
            TransposeLast(),
        )

        if modality_cfg.conv_pos_pre_ln:
            positional_encoder = nn.Sequential(LayerNorm(embed_dim), positional_encoder)

        # Per-layer drop-path rates linearly interpolated across the prenet.
        dpr = np.linspace(
            modality_cfg.start_drop_path_rate,
            modality_cfg.end_drop_path_rate,
            modality_cfg.prenet_depth,
        )
        context_encoder = BlockEncoder(
            nn.ModuleList(make_block(dpr[i]) for i in range(modality_cfg.prenet_depth)),
            norm_layer(embed_dim) if not layer_norm_first else None,
            layer_norm_first,
            modality_cfg.prenet_layerdrop,
            modality_cfg.prenet_dropout,
        )

        decoder = None

        alibi_bias_fn = partial(get_alibi_bias, alibi_biases=alibi_biases)

        super().__init__(
            modality_cfg=modality_cfg,
            embed_dim=embed_dim,
            local_encoder=local_encoder,
            project_features=project_features,
            fixed_positional_encoder=None,
            relative_positional_encoder=positional_encoder,
            context_encoder=context_encoder,
            decoder=decoder,
            get_alibi_bias=alibi_bias_fn,
        )

    def convert_padding_mask(self, x, padding_mask):
        """Downsample a sample-level padding mask to the conv output rate.

        ``padding_mask`` uses True for padded positions; the result is a
        boolean mask at the feature frame rate.
        """

        def get_feat_extract_output_lengths(input_lengths: torch.LongTensor):
            """
            Computes the output length of the convolutional layers
            """

            def _conv_out_length(input_length, kernel_size, stride):
                # Standard conv output-length formula (no padding/dilation).
                return torch.floor((input_length - kernel_size) / stride + 1)

            for i in range(len(self.feature_enc_layers)):
                input_lengths = _conv_out_length(
                    input_lengths,
                    self.feature_enc_layers[i][1],
                    self.feature_enc_layers[i][2],
                )

            return input_lengths.to(torch.long)

        if padding_mask is not None:
            # Number of non-padded samples per batch element.
            input_lengths = (1 - padding_mask.long()).sum(-1)
            # apply conv formula to get real output_lengths
            output_lengths = get_feat_extract_output_lengths(input_lengths)

            if padding_mask.any():
                padding_mask = torch.zeros(x.shape[:2], dtype=x.dtype, device=x.device)

                # these two operations makes sure that all values
                # before the output lengths indices are attended to
                padding_mask[
                    (
                        torch.arange(padding_mask.shape[0], device=padding_mask.device),
                        output_lengths - 1,
                    )
                ] = 1
                padding_mask = (
                    1 - padding_mask.flip([-1]).cumsum(-1).flip([-1])
                ).bool()
            else:
                # Nothing is padded: all frames are valid.
                padding_mask = torch.zeros(
                    x.shape[:2], dtype=torch.bool, device=x.device
                )

        return padding_mask
925
+
926
+
927
class LearnedPositionalEmbedding(nn.Embedding):
    """
    This module learns positional embeddings up to a fixed maximum size.
    Padding ids are ignored by either offsetting based on padding_idx
    or by setting padding_idx to None and ensuring that the appropriate
    position ids are passed to the forward function.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int):
        super().__init__(num_embeddings, embedding_dim, padding_idx)
        self.onnx_trace = False
        # Positions start after padding_idx, so that many usable positions
        # are available when a padding index reserves the first slots.
        if self.padding_idx is not None:
            self.max_positions = self.num_embeddings - self.padding_idx - 1
        else:
            self.max_positions = self.num_embeddings

    def forward(
        self,
        input: Tensor,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        positions: Optional[Tensor] = None,
    ):
        """Input is expected to be of size [bsz x seqlen]."""
        assert (positions is None) or (
            self.padding_idx is None
        ), "If positions is pre-computed then padding_idx should not be set."

        if positions is None:
            if incremental_state is not None:
                # positions is the same for every token when decoding a single step
                # Without the int() cast, it doesn't work in some cases when exporting to ONNX
                positions = torch.zeros(
                    (1, 1), device=input.device, dtype=input.dtype
                ).fill_(int(self.padding_idx + input.size(1)))
            else:
                # Derive positions from non-padding tokens (offset by
                # padding_idx, matching fairseq's make_positions).
                positions = make_positions(
                    input, self.padding_idx, onnx_trace=self.onnx_trace
                )
        # Look up the learned table directly with F.embedding so precomputed
        # positions can be used unchanged.
        return F.embedding(
            positions,
            self.weight,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )
974
+
975
+
976
class SinusoidalPositionalEmbedding(nn.Module):
    """This module produces sinusoidal positional embeddings of any length.

    Padding symbols are ignored.
    """

    def __init__(self, embedding_dim, padding_idx, init_size=1024):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.padding_idx = padding_idx if padding_idx is not None else 0
        # Non-persistent buffer: the table is deterministic and recomputed /
        # grown on demand in forward(), so it is never saved in checkpoints.
        self.register_buffer("weights", SinusoidalPositionalEmbedding.get_embedding(
            init_size, embedding_dim, padding_idx
        ), persistent=False)
        self.max_positions = int(1e5)
        self.onnx_trace = False

    def prepare_for_onnx_export_(self):
        self.onnx_trace = True

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        # Ignore some deprecated keys that were used in older versions
        # (older checkpoints persisted the table; drop it so loading does not
        # fail or override the freshly computed buffer).
        deprecated_keys = ["weights", "_float_tensor"]
        for key in deprecated_keys:
            if prefix + key in state_dict:
                del state_dict[prefix + key]
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

    @staticmethod
    def get_embedding(
        num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None
    ):
        """Build sinusoidal embeddings.

        This matches the implementation in tensor2tensor, but differs slightly
        from the description in Section 3.5 of "Attention Is All You Need".
        """
        half_dim = embedding_dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
        emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(
            1
        ) * emb.unsqueeze(0)
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(
            num_embeddings, -1
        )
        if embedding_dim % 2 == 1:
            # zero pad
            emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
        if padding_idx is not None:
            # The padding position gets an all-zero embedding.
            emb[padding_idx, :] = 0
        return emb

    def forward(
        self,
        input,
        incremental_state: Optional[Any] = None,
        timestep: Optional[Tensor] = None,
        positions: Optional[Any] = None,
    ):
        """Input is expected to be of size [bsz x seqlen]."""
        bspair = torch.onnx.operators.shape_as_tensor(input)
        bsz, seq_len = bspair[0], bspair[1]
        max_pos = self.padding_idx + 1 + seq_len
        if max_pos > self.weights.size(0):
            # expand embeddings if needed (rebinding the buffer attribute
            # replaces the registered non-persistent buffer).
            self.weights = SinusoidalPositionalEmbedding.get_embedding(
                max_pos, self.embedding_dim, self.padding_idx
            ).to(self.weights)

        if incremental_state is not None:
            # positions is the same for every token when decoding a single step
            pos = timestep.view(-1)[0] + 1 if timestep is not None else seq_len
            if self.onnx_trace:
                return (
                    self.weights.index_select(index=self.padding_idx + pos, dim=0)
                    .unsqueeze(1)
                    .repeat(bsz, 1, 1)
                )
            return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1)

        # Full-sequence path: positions derived from non-padding tokens.
        positions = make_positions(
            input, self.padding_idx, onnx_trace=self.onnx_trace
        )
        if self.onnx_trace:
            # ONNX export: reshape via tensor ops so shapes stay dynamic.
            flat_embeddings = self.weights.detach().index_select(0, positions.view(-1))
            embedding_shape = torch.cat(
                (bsz.view(1), seq_len.view(1), torch.tensor([-1], dtype=torch.long))
            )
            embeddings = torch.onnx.operators.reshape_from_tensor_shape(
                flat_embeddings, embedding_shape
            )
            return embeddings
        # detach(): sinusoidal embeddings are fixed, never trained.
        return (
            self.weights.index_select(0, positions.view(-1))
            .view(bsz, seq_len, -1)
            .detach()
        )
1073
+
1074
def PositionalEmbedding(
    num_embeddings: int,
    embedding_dim: int,
    padding_idx: int,
    learned: bool = False,
):
    """Factory for positional embeddings.

    Args:
        num_embeddings: number of usable positions.
        embedding_dim: embedding dimension.
        padding_idx: padding token index, or None when no padding is used.
        learned: True for a trained table, False for fixed sinusoidal values.

    Returns:
        A ``LearnedPositionalEmbedding`` or ``SinusoidalPositionalEmbedding``.
    """
    if learned:
        # if padding_idx is specified then offset the embedding ids by
        # this index and adjust num_embeddings appropriately
        # TODO: The right place for this offset would be inside
        # LearnedPositionalEmbedding. Move this there for a cleaner implementation.
        if padding_idx is not None:
            num_embeddings = num_embeddings + padding_idx + 1
        m = LearnedPositionalEmbedding(num_embeddings, embedding_dim, padding_idx)
        nn.init.normal_(m.weight, mean=0, std=embedding_dim**-0.5)
        if padding_idx is not None:
            # Padding position contributes a zero vector.
            nn.init.constant_(m.weight[padding_idx], 0)
    else:
        # Fix: guard against padding_idx=None — the previous unconditional
        # `num_embeddings + padding_idx + 1` raised TypeError in that case.
        offset = padding_idx + 1 if padding_idx is not None else 0
        m = SinusoidalPositionalEmbedding(
            embedding_dim,
            padding_idx,
            init_size=num_embeddings + offset,
        )
    return m
1098
+
1099
+
1100
class TextLocalEncoder(nn.Module):
    """Token embedding front-end for the text modality: scaled token
    embeddings, optional positional embeddings, optional embedding
    LayerNorm, then dropout.

    Args:
        vocab_size: embedding table size.
        embed_dim: embedding dimension.
        max_source_positions: maximum sequence length for positions.
        pad_idx: padding token id.
        no_scale_embedding: if True, skip the sqrt(embed_dim) scaling.
        layernorm_embedding: if True, apply LayerNorm to the embeddings.
        dropout: dropout probability applied to the final output.
        no_token_positional_embeddings: if True, no positional embeddings.
        learned_pos: learned vs. sinusoidal positional embeddings.
    """

    def __init__(
        self,
        vocab_size,
        embed_dim,
        max_source_positions,
        pad_idx,
        no_scale_embedding,
        layernorm_embedding,
        dropout,
        no_token_positional_embeddings,
        learned_pos,
    ):
        super().__init__()
        self.pad_idx = pad_idx
        self.dropout_module = nn.Dropout(dropout)

        self.embed_tokens = nn.Embedding(vocab_size, embed_dim, pad_idx)
        # Fix: this value was assigned twice with the identical expression;
        # keep a single assignment.
        self.embed_scale = 1.0 if no_scale_embedding else math.sqrt(embed_dim)
        self.embed_positions = (
            PositionalEmbedding(
                max_source_positions,
                embed_dim,
                pad_idx,
                learned=learned_pos,
            )
            if not no_token_positional_embeddings
            else None
        )

        self.layernorm_embedding = None
        if layernorm_embedding:
            self.layernorm_embedding = LayerNorm(embed_dim)

    def forward(self, src_tokens):
        """Embed a (B, T) tensor of token ids into (B, T, embed_dim)."""
        x = self.embed_scale * self.embed_tokens(src_tokens)
        if self.embed_positions is not None:
            x = x + self.embed_positions(src_tokens)

        if self.layernorm_embedding is not None:
            x = self.layernorm_embedding(x)
        x = self.dropout_module(x)
        return x
1144
+
1145
+
1146
class TextEncoder(ModalitySpecificEncoder):
    """Text-modality encoder: a ``TextLocalEncoder`` (token + positional
    embeddings) followed by a prenet of transformer blocks wrapped in a
    ``BlockEncoder``, with alibi attention biases."""

    # Declared for type checkers; the attribute itself is set by the
    # ModalitySpecificEncoder base class.
    modality_cfg: D2v2TextConfig

    def __init__(
        self,
        modality_cfg: D2v2TextConfig,
        embed_dim: int,
        make_block: Callable[[float], nn.ModuleList],
        norm_layer: Callable[[int], nn.LayerNorm],
        layer_norm_first: bool,
        alibi_biases: Dict,
    ):
        # NOTE(review): these plain (non-module) attributes are set *before*
        # super().__init__(); nn.Module tolerates this only for non-tensor
        # values, so do not move module assignments above the super() call.
        self.pad_idx = modality_cfg.pad_token_id
        self.vocab_size = modality_cfg.vocab_size

        local_encoder = TextLocalEncoder(
            vocab_size=self.vocab_size,
            embed_dim=embed_dim,
            max_source_positions=modality_cfg.max_source_positions,
            pad_idx=self.pad_idx,
            no_scale_embedding=modality_cfg.no_scale_embedding,
            layernorm_embedding=modality_cfg.layernorm_embedding,
            dropout=modality_cfg.dropout,
            no_token_positional_embeddings=modality_cfg.no_token_positional_embeddings,
            learned_pos=modality_cfg.learned_pos,
        )
        # Linearly spaced stochastic-depth (drop-path) rates across the prenet.
        dpr = np.linspace(
            modality_cfg.start_drop_path_rate,
            modality_cfg.end_drop_path_rate,
            modality_cfg.prenet_depth,
        )
        context_encoder = BlockEncoder(
            nn.ModuleList(make_block(dpr[i]) for i in range(modality_cfg.prenet_depth)),
            # Final norm only when the blocks are post-norm (not layer_norm_first)
            # and the prenet is non-empty.
            norm_layer(embed_dim)
            if not layer_norm_first and modality_cfg.prenet_depth > 0
            else None,
            layer_norm_first,
            modality_cfg.prenet_layerdrop,
            modality_cfg.prenet_dropout if modality_cfg.prenet_depth > 0 else 0.0,
        )
        # No decoder for the text modality in this model.
        decoder = None

        # Shared cache of alibi biases across modalities (mutated by the fn).
        alibi_bias_fn = partial(get_alibi_bias, alibi_biases=alibi_biases)

        super().__init__(
            modality_cfg=modality_cfg,
            embed_dim=embed_dim,
            local_encoder=local_encoder,
            project_features=nn.Identity(),
            fixed_positional_encoder=None,
            relative_positional_encoder=None,
            context_encoder=context_encoder,
            decoder=decoder,
            get_alibi_bias=alibi_bias_fn,
        )

    def convert_padding_mask(self, x, padding_mask):
        """Downsample ``padding_mask`` along time so it matches ``x``.

        A downsampled step is padding only when *all* of its
        ``self.downsample`` source steps were padding. Pads the mask with
        True (padding) on the right when the length is not divisible, and
        truncates if the result is longer than ``x``.
        ``self.downsample`` is presumably provided by the base class —
        TODO confirm.
        """
        if padding_mask is None or padding_mask.size(1) == x.size(1):
            return padding_mask

        diff = self.downsample - padding_mask.size(1) % self.downsample
        if 0 < diff < self.downsample:
            padding_mask = F.pad(padding_mask, (0, diff), value=True)

        padding_mask = padding_mask.view(padding_mask.size(0), -1, self.downsample)
        padding_mask = padding_mask.all(-1)
        if padding_mask.size(1) > x.size(1):
            padding_mask = padding_mask[:, : x.size(1)]

        assert x.size(1) == padding_mask.size(
            1
        ), f"{x.size(1), padding_mask.size(1), diff, self.downsample}"

        return padding_mask
1221
+ #################################################
1222
+
1223
+
1224
class Data2Vec2MultiPreTrainedModel(PreTrainedModel):
    """Base class hooking Hugging Face weight initialization up to the
    fairseq ``init_bert_params``-style scheme used by data2vec 2.0."""

    # use init_bert_params from fairseq
    # copied from fairseq.modules.transformer_sentence_encoder.py
    def _init_weights(self, module):
        """Initialize the weights.

        The ``isinstance`` checks below are deliberately non-exclusive
        (plain ``if``, not ``elif``): e.g. an AltBlock gets its attention
        projection re-initialized in addition to the generic Linear init.
        """

        def normal_(data):
            # with FSDP, module params will be on CUDA, so we cast them back to CPU
            # so that the RNG is consistent with and without FSDP
            data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device))

        def _init(module):
            if isinstance(module, nn.Linear):
                normal_(module.weight.data)
                if module.bias is not None:
                    module.bias.data.zero_()
            if isinstance(module, nn.Embedding):
                normal_(module.weight.data)
                if module.padding_idx is not None:
                    module.weight.data[module.padding_idx].zero_()
            if isinstance(module, AltBlock):
                normal_(module.attn.proj.weight.data)
            # init strategy for audio encoder
            if isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
                if module.bias is not None:
                    module.bias.data.zero_()
                if module.weight is not None:
                    module.weight.data.fill_(1.0)
            if isinstance(module, nn.Conv1d):
                # Kaiming init for conv weights; uniform bias bounded by the
                # usual fan-in-derived limit.
                nn.init.kaiming_normal_(module.weight)
                if module.bias is not None:
                    k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
                    nn.init.uniform_(module.bias, a=-k, b=k)

        # ModuleLists are flattened one level so each child is initialized.
        if isinstance(module, nn.ModuleList):
            for _, mod in enumerate(module):
                _init(mod)
        else:
            _init(module)
1278
+
1279
+
1280
class Data2Vec2MultiModel(Data2Vec2MultiPreTrainedModel):
    """Hugging Face wrapper around the data2vec 2.0 multi-modal encoder.

    Holds one modality-specific feature extractor per supported modality
    plus a shared stack of transformer blocks (``AltBlock``).
    """

    config_class = Data2Vec2MultiConfig
    base_model_prefix = "data2vec2"

    def __init__(self, config: Data2Vec2MultiConfig):
        super().__init__(config)
        self.config = config
        modalities_cfg = config.modalities
        # Only the single configured modality is instantiated here.
        self.modalities = [config.supported_modality]

        make_layer_norm = partial(
            nn.LayerNorm, eps=config.norm_eps, elementwise_affine=config.norm_affine
        )

        def make_block(drop_path, dim=None, heads=None):
            # Factory for one shared transformer block; dim/heads default to
            # the model-wide configuration.
            return AltBlock(
                config.embed_dim if dim is None else dim,
                config.num_heads if heads is None else heads,
                config.mlp_ratio,
                qkv_bias=True,
                drop=config.encoder_dropout,
                attn_drop=config.attention_dropout,
                mlp_drop=config.activation_dropout,
                post_mlp_drop=config.post_mlp_drop,
                drop_path=drop_path,
                norm_layer=make_layer_norm,
                layer_norm_first=config.layer_norm_first,
                ffn_targets=not config.end_of_block_targets,
            )

        # Alibi bias cache shared across modality encoders.
        self.alibi_biases = {}
        self.modality_encoders = nn.ModuleDict()
        for mod in self.modalities:
            mod_cfg = getattr(modalities_cfg, mod.lower())
            enc = self.make_modality_encoder(
                mod_cfg,
                config.embed_dim,
                make_block,
                make_layer_norm,
                config.layer_norm_first,
                self.alibi_biases,
            )
            self.modality_encoders[mod] = enc

        self.dropout_input = nn.Dropout(config.dropout_input)

        # Linearly spaced drop-path rates across the shared block stack.
        dpr = np.linspace(config.start_drop_path_rate, config.end_drop_path_rate, config.depth)

        self.blocks = nn.ModuleList([make_block(dpr[i]) for i in range(config.depth)])

        # Final norm only in pre-norm ("layer_norm_first") configurations.
        self.norm = None
        if config.layer_norm_first:
            self.norm = make_layer_norm(config.embed_dim)

        self.num_updates = 0

        # Initialize weights and apply final processing
        self.post_init()

    def freeze_feature_extractor(self):
        """
        Calling this function will disable the gradient computation for the feature encoder so that its parameters will
        not be updated during training.
        """
        warnings.warn(
            "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
            "Please use the equivalent `freeze_feature_encoder` method instead.",
            FutureWarning,
        )
        self.freeze_feature_encoder()

    def freeze_feature_encoder(self):
        """
        Calling this function will disable the gradient computation for the feature encoder so that its parameter will
        not be updated during training.
        """
        for mod in self.modalities:
            self.modality_encoders[mod]._freeze_parameters()

    def make_modality_encoder(
        self,
        cfg: D2v2ModalityConfig,
        embed_dim: int,
        make_block: Callable[[float], nn.ModuleList],
        norm_layer: Callable[[int], nn.LayerNorm],
        layer_norm_first: bool,
        alibi_biases,
    ) -> ModalitySpecificEncoder:
        """Instantiate the encoder class matching ``cfg.type``.

        Raises:
            Exception: if ``cfg.type`` is neither "AUDIO" nor "TEXT".
        """
        if cfg.type == "AUDIO":
            enc_cls = AudioEncoder
        elif cfg.type == "TEXT":
            enc_cls = TextEncoder
        else:
            raise Exception(f"unsupported modality {cfg.type}")

        return enc_cls(
            cfg,
            embed_dim,
            make_block,
            norm_layer,
            layer_norm_first,
            alibi_biases,
        )

    def forward(
        self,
        input_values=None,  # audio input
        input_ids=None,  # text input
        attention_mask=None,
        padding_mask=None,
        mask=False,
        mode=None,
        output_hidden_states=True,
        return_dict=True,
        ):
        """Run the modality-specific extractor followed by the shared blocks.

        Args:
            input_values: audio waveform input (used when ``input_ids`` is None).
            input_ids: text token ids; if given, text mode is inferred.
            attention_mask: accepted for HF API compatibility; not used here.
            padding_mask: boolean padding mask passed to the extractor.
            mask: whether the extractor should apply feature masking.
            mode: "TEXT" or "AUDIO"; inferred from the inputs when None.
            output_hidden_states: include per-layer outputs in the result.
            return_dict: return a ``Wav2Vec2BaseModelOutput`` instead of a tuple.
        """
        if mode is None:
            mode = "TEXT" if input_ids is not None else "AUDIO"
        feature_extractor = self.modality_encoders[mode]
        # Inference-style extraction: no batch cloning, no mask seeding.
        extractor_out = feature_extractor(
            input_ids if input_ids is not None else input_values,
            padding_mask,
            mask,
            remove_masked=False,
            clone_batch=1,
            mask_seeds=None,
            precomputed_mask=None,
        )
        x = extractor_out["x"]
        extract_features = x

        # encoder_mask = extractor_out["encoder_mask"]
        masked_padding_mask = extractor_out["padding_mask"]
        masked_alibi_bias = extractor_out.get("alibi_bias", None)
        alibi_scale = extractor_out.get("alibi_scale", None)

        if self.dropout_input is not None:
            x = self.dropout_input(x)

        layer_results = []
        for i, blk in enumerate(self.blocks):
            # LayerDrop: each block may be skipped stochastically at train time.
            if (
                not self.training
                or self.config.layerdrop == 0
                or (np.random.random() > self.config.layerdrop)
            ):
                ab = masked_alibi_bias
                if ab is not None and alibi_scale is not None:
                    # Per-layer scale when one scale per block is provided,
                    # otherwise a single shared scale.
                    scale = (
                        alibi_scale[i]
                        if alibi_scale.size(0) > 1
                        else alibi_scale.squeeze(0)
                    )
                    ab = ab * scale.type_as(ab)

                x, lr = blk(
                    x,
                    padding_mask=masked_padding_mask,
                    alibi_bias=ab,
                )
                layer_results.append(lr)

        if self.norm is not None:
            x = self.norm(x)

        # Strip modality-specific extra tokens (e.g. CLS-style prefixes)
        # from both the features and the padding mask.
        x = x[:, feature_extractor.modality_cfg.num_extra_tokens :]
        if masked_padding_mask is not None:
            masked_padding_mask = masked_padding_mask[
                :, feature_extractor.modality_cfg.num_extra_tokens :
            ]

        if not return_dict:
            return tuple(
                v
                for v in [
                    x,
                    extract_features,
                    layer_results,
                ]
                if v is not None
            )

        return Wav2Vec2BaseModelOutput(
            last_hidden_state=x,
            extract_features=extract_features,
            hidden_states=layer_results if output_hidden_states else None,
            attentions=None,  # switch to manual implementation with fast=False in forward pass of AltAttention as pytorch's dspa does not output attention weights
        )
special_tokens_map.json ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": {
3
+ "content": "<s>",
4
+ "lstrip": false,
5
+ "normalized": true,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "cls_token": {
10
+ "content": "<s>",
11
+ "lstrip": false,
12
+ "normalized": true,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "eos_token": {
17
+ "content": "</s>",
18
+ "lstrip": false,
19
+ "normalized": true,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ },
23
+ "mask_token": {
24
+ "content": "<mask>",
25
+ "lstrip": false,
26
+ "normalized": true,
27
+ "rstrip": false,
28
+ "single_word": false
29
+ },
30
+ "pad_token": {
31
+ "content": "<pad>",
32
+ "lstrip": false,
33
+ "normalized": true,
34
+ "rstrip": false,
35
+ "single_word": false
36
+ },
37
+ "sep_token": {
38
+ "content": "</s>",
39
+ "lstrip": false,
40
+ "normalized": true,
41
+ "rstrip": false,
42
+ "single_word": false
43
+ },
44
+ "unk_token": {
45
+ "content": "<unk>",
46
+ "lstrip": false,
47
+ "normalized": true,
48
+ "rstrip": false,
49
+ "single_word": false
50
+ }
51
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,855 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "added_tokens_decoder": {
3
+ "0": {
4
+ "content": "<s>",
5
+ "lstrip": false,
6
+ "normalized": true,
7
+ "rstrip": false,
8
+ "single_word": false,
9
+ "special": true
10
+ },
11
+ "1": {
12
+ "content": "<pad>",
13
+ "lstrip": false,
14
+ "normalized": true,
15
+ "rstrip": false,
16
+ "single_word": false,
17
+ "special": true
18
+ },
19
+ "2": {
20
+ "content": "</s>",
21
+ "lstrip": false,
22
+ "normalized": true,
23
+ "rstrip": false,
24
+ "single_word": false,
25
+ "special": true
26
+ },
27
+ "3": {
28
+ "content": "<unk>",
29
+ "lstrip": false,
30
+ "normalized": true,
31
+ "rstrip": false,
32
+ "single_word": false,
33
+ "special": true
34
+ },
35
+ "4": {
36
+ "content": "<mask>",
37
+ "lstrip": false,
38
+ "normalized": true,
39
+ "rstrip": false,
40
+ "single_word": false,
41
+ "special": true
42
+ },
43
+ "5": {
44
+ "content": "<extra_id_0>",
45
+ "lstrip": false,
46
+ "normalized": false,
47
+ "rstrip": false,
48
+ "single_word": false,
49
+ "special": true
50
+ },
51
+ "6": {
52
+ "content": "<extra_id_1>",
53
+ "lstrip": false,
54
+ "normalized": false,
55
+ "rstrip": false,
56
+ "single_word": false,
57
+ "special": true
58
+ },
59
+ "7": {
60
+ "content": "<extra_id_2>",
61
+ "lstrip": false,
62
+ "normalized": false,
63
+ "rstrip": false,
64
+ "single_word": false,
65
+ "special": true
66
+ },
67
+ "8": {
68
+ "content": "<extra_id_3>",
69
+ "lstrip": false,
70
+ "normalized": false,
71
+ "rstrip": false,
72
+ "single_word": false,
73
+ "special": true
74
+ },
75
+ "9": {
76
+ "content": "<extra_id_4>",
77
+ "lstrip": false,
78
+ "normalized": false,
79
+ "rstrip": false,
80
+ "single_word": false,
81
+ "special": true
82
+ },
83
+ "10": {
84
+ "content": "<extra_id_5>",
85
+ "lstrip": false,
86
+ "normalized": false,
87
+ "rstrip": false,
88
+ "single_word": false,
89
+ "special": true
90
+ },
91
+ "11": {
92
+ "content": "<extra_id_6>",
93
+ "lstrip": false,
94
+ "normalized": false,
95
+ "rstrip": false,
96
+ "single_word": false,
97
+ "special": true
98
+ },
99
+ "12": {
100
+ "content": "<extra_id_7>",
101
+ "lstrip": false,
102
+ "normalized": false,
103
+ "rstrip": false,
104
+ "single_word": false,
105
+ "special": true
106
+ },
107
+ "13": {
108
+ "content": "<extra_id_8>",
109
+ "lstrip": false,
110
+ "normalized": false,
111
+ "rstrip": false,
112
+ "single_word": false,
113
+ "special": true
114
+ },
115
+ "14": {
116
+ "content": "<extra_id_9>",
117
+ "lstrip": false,
118
+ "normalized": false,
119
+ "rstrip": false,
120
+ "single_word": false,
121
+ "special": true
122
+ },
123
+ "15": {
124
+ "content": "<extra_id_10>",
125
+ "lstrip": false,
126
+ "normalized": false,
127
+ "rstrip": false,
128
+ "single_word": false,
129
+ "special": true
130
+ },
131
+ "16": {
132
+ "content": "<extra_id_11>",
133
+ "lstrip": false,
134
+ "normalized": false,
135
+ "rstrip": false,
136
+ "single_word": false,
137
+ "special": true
138
+ },
139
+ "17": {
140
+ "content": "<extra_id_12>",
141
+ "lstrip": false,
142
+ "normalized": false,
143
+ "rstrip": false,
144
+ "single_word": false,
145
+ "special": true
146
+ },
147
+ "18": {
148
+ "content": "<extra_id_13>",
149
+ "lstrip": false,
150
+ "normalized": false,
151
+ "rstrip": false,
152
+ "single_word": false,
153
+ "special": true
154
+ },
155
+ "19": {
156
+ "content": "<extra_id_14>",
157
+ "lstrip": false,
158
+ "normalized": false,
159
+ "rstrip": false,
160
+ "single_word": false,
161
+ "special": true
162
+ },
163
+ "20": {
164
+ "content": "<extra_id_15>",
165
+ "lstrip": false,
166
+ "normalized": false,
167
+ "rstrip": false,
168
+ "single_word": false,
169
+ "special": true
170
+ },
171
+ "21": {
172
+ "content": "<extra_id_16>",
173
+ "lstrip": false,
174
+ "normalized": false,
175
+ "rstrip": false,
176
+ "single_word": false,
177
+ "special": true
178
+ },
179
+ "22": {
180
+ "content": "<extra_id_17>",
181
+ "lstrip": false,
182
+ "normalized": false,
183
+ "rstrip": false,
184
+ "single_word": false,
185
+ "special": true
186
+ },
187
+ "23": {
188
+ "content": "<extra_id_18>",
189
+ "lstrip": false,
190
+ "normalized": false,
191
+ "rstrip": false,
192
+ "single_word": false,
193
+ "special": true
194
+ },
195
+ "24": {
196
+ "content": "<extra_id_19>",
197
+ "lstrip": false,
198
+ "normalized": false,
199
+ "rstrip": false,
200
+ "single_word": false,
201
+ "special": true
202
+ },
203
+ "25": {
204
+ "content": "<extra_id_20>",
205
+ "lstrip": false,
206
+ "normalized": false,
207
+ "rstrip": false,
208
+ "single_word": false,
209
+ "special": true
210
+ },
211
+ "26": {
212
+ "content": "<extra_id_21>",
213
+ "lstrip": false,
214
+ "normalized": false,
215
+ "rstrip": false,
216
+ "single_word": false,
217
+ "special": true
218
+ },
219
+ "27": {
220
+ "content": "<extra_id_22>",
221
+ "lstrip": false,
222
+ "normalized": false,
223
+ "rstrip": false,
224
+ "single_word": false,
225
+ "special": true
226
+ },
227
+ "28": {
228
+ "content": "<extra_id_23>",
229
+ "lstrip": false,
230
+ "normalized": false,
231
+ "rstrip": false,
232
+ "single_word": false,
233
+ "special": true
234
+ },
235
+ "29": {
236
+ "content": "<extra_id_24>",
237
+ "lstrip": false,
238
+ "normalized": false,
239
+ "rstrip": false,
240
+ "single_word": false,
241
+ "special": true
242
+ },
243
+ "30": {
244
+ "content": "<extra_id_25>",
245
+ "lstrip": false,
246
+ "normalized": false,
247
+ "rstrip": false,
248
+ "single_word": false,
249
+ "special": true
250
+ },
251
+ "31": {
252
+ "content": "<extra_id_26>",
253
+ "lstrip": false,
254
+ "normalized": false,
255
+ "rstrip": false,
256
+ "single_word": false,
257
+ "special": true
258
+ },
259
+ "32": {
260
+ "content": "<extra_id_27>",
261
+ "lstrip": false,
262
+ "normalized": false,
263
+ "rstrip": false,
264
+ "single_word": false,
265
+ "special": true
266
+ },
267
+ "33": {
268
+ "content": "<extra_id_28>",
269
+ "lstrip": false,
270
+ "normalized": false,
271
+ "rstrip": false,
272
+ "single_word": false,
273
+ "special": true
274
+ },
275
+ "34": {
276
+ "content": "<extra_id_29>",
277
+ "lstrip": false,
278
+ "normalized": false,
279
+ "rstrip": false,
280
+ "single_word": false,
281
+ "special": true
282
+ },
283
+ "35": {
284
+ "content": "<extra_id_30>",
285
+ "lstrip": false,
286
+ "normalized": false,
287
+ "rstrip": false,
288
+ "single_word": false,
289
+ "special": true
290
+ },
291
+ "36": {
292
+ "content": "<extra_id_31>",
293
+ "lstrip": false,
294
+ "normalized": false,
295
+ "rstrip": false,
296
+ "single_word": false,
297
+ "special": true
298
+ },
299
+ "37": {
300
+ "content": "<extra_id_32>",
301
+ "lstrip": false,
302
+ "normalized": false,
303
+ "rstrip": false,
304
+ "single_word": false,
305
+ "special": true
306
+ },
307
+ "38": {
308
+ "content": "<extra_id_33>",
309
+ "lstrip": false,
310
+ "normalized": false,
311
+ "rstrip": false,
312
+ "single_word": false,
313
+ "special": true
314
+ },
315
+ "39": {
316
+ "content": "<extra_id_34>",
317
+ "lstrip": false,
318
+ "normalized": false,
319
+ "rstrip": false,
320
+ "single_word": false,
321
+ "special": true
322
+ },
323
+ "40": {
324
+ "content": "<extra_id_35>",
325
+ "lstrip": false,
326
+ "normalized": false,
327
+ "rstrip": false,
328
+ "single_word": false,
329
+ "special": true
330
+ },
331
+ "41": {
332
+ "content": "<extra_id_36>",
333
+ "lstrip": false,
334
+ "normalized": false,
335
+ "rstrip": false,
336
+ "single_word": false,
337
+ "special": true
338
+ },
339
+ "42": {
340
+ "content": "<extra_id_37>",
341
+ "lstrip": false,
342
+ "normalized": false,
343
+ "rstrip": false,
344
+ "single_word": false,
345
+ "special": true
346
+ },
347
+ "43": {
348
+ "content": "<extra_id_38>",
349
+ "lstrip": false,
350
+ "normalized": false,
351
+ "rstrip": false,
352
+ "single_word": false,
353
+ "special": true
354
+ },
355
+ "44": {
356
+ "content": "<extra_id_39>",
357
+ "lstrip": false,
358
+ "normalized": false,
359
+ "rstrip": false,
360
+ "single_word": false,
361
+ "special": true
362
+ },
363
+ "45": {
364
+ "content": "<extra_id_40>",
365
+ "lstrip": false,
366
+ "normalized": false,
367
+ "rstrip": false,
368
+ "single_word": false,
369
+ "special": true
370
+ },
371
+ "46": {
372
+ "content": "<extra_id_41>",
373
+ "lstrip": false,
374
+ "normalized": false,
375
+ "rstrip": false,
376
+ "single_word": false,
377
+ "special": true
378
+ },
379
+ "47": {
380
+ "content": "<extra_id_42>",
381
+ "lstrip": false,
382
+ "normalized": false,
383
+ "rstrip": false,
384
+ "single_word": false,
385
+ "special": true
386
+ },
387
+ "48": {
388
+ "content": "<extra_id_43>",
389
+ "lstrip": false,
390
+ "normalized": false,
391
+ "rstrip": false,
392
+ "single_word": false,
393
+ "special": true
394
+ },
395
+ "49": {
396
+ "content": "<extra_id_44>",
397
+ "lstrip": false,
398
+ "normalized": false,
399
+ "rstrip": false,
400
+ "single_word": false,
401
+ "special": true
402
+ },
403
+ "50": {
404
+ "content": "<extra_id_45>",
405
+ "lstrip": false,
406
+ "normalized": false,
407
+ "rstrip": false,
408
+ "single_word": false,
409
+ "special": true
410
+ },
411
+ "51": {
412
+ "content": "<extra_id_46>",
413
+ "lstrip": false,
414
+ "normalized": false,
415
+ "rstrip": false,
416
+ "single_word": false,
417
+ "special": true
418
+ },
419
+ "52": {
420
+ "content": "<extra_id_47>",
421
+ "lstrip": false,
422
+ "normalized": false,
423
+ "rstrip": false,
424
+ "single_word": false,
425
+ "special": true
426
+ },
427
+ "53": {
428
+ "content": "<extra_id_48>",
429
+ "lstrip": false,
430
+ "normalized": false,
431
+ "rstrip": false,
432
+ "single_word": false,
433
+ "special": true
434
+ },
435
+ "54": {
436
+ "content": "<extra_id_49>",
437
+ "lstrip": false,
438
+ "normalized": false,
439
+ "rstrip": false,
440
+ "single_word": false,
441
+ "special": true
442
+ },
443
+ "55": {
444
+ "content": "<extra_id_50>",
445
+ "lstrip": false,
446
+ "normalized": false,
447
+ "rstrip": false,
448
+ "single_word": false,
449
+ "special": true
450
+ },
451
+ "56": {
452
+ "content": "<extra_id_51>",
453
+ "lstrip": false,
454
+ "normalized": false,
455
+ "rstrip": false,
456
+ "single_word": false,
457
+ "special": true
458
+ },
459
+ "57": {
460
+ "content": "<extra_id_52>",
461
+ "lstrip": false,
462
+ "normalized": false,
463
+ "rstrip": false,
464
+ "single_word": false,
465
+ "special": true
466
+ },
467
+ "58": {
468
+ "content": "<extra_id_53>",
469
+ "lstrip": false,
470
+ "normalized": false,
471
+ "rstrip": false,
472
+ "single_word": false,
473
+ "special": true
474
+ },
475
+ "59": {
476
+ "content": "<extra_id_54>",
477
+ "lstrip": false,
478
+ "normalized": false,
479
+ "rstrip": false,
480
+ "single_word": false,
481
+ "special": true
482
+ },
483
+ "60": {
484
+ "content": "<extra_id_55>",
485
+ "lstrip": false,
486
+ "normalized": false,
487
+ "rstrip": false,
488
+ "single_word": false,
489
+ "special": true
490
+ },
491
+ "61": {
492
+ "content": "<extra_id_56>",
493
+ "lstrip": false,
494
+ "normalized": false,
495
+ "rstrip": false,
496
+ "single_word": false,
497
+ "special": true
498
+ },
499
+ "62": {
500
+ "content": "<extra_id_57>",
501
+ "lstrip": false,
502
+ "normalized": false,
503
+ "rstrip": false,
504
+ "single_word": false,
505
+ "special": true
506
+ },
507
+ "63": {
508
+ "content": "<extra_id_58>",
509
+ "lstrip": false,
510
+ "normalized": false,
511
+ "rstrip": false,
512
+ "single_word": false,
513
+ "special": true
514
+ },
515
+ "64": {
516
+ "content": "<extra_id_59>",
517
+ "lstrip": false,
518
+ "normalized": false,
519
+ "rstrip": false,
520
+ "single_word": false,
521
+ "special": true
522
+ },
523
+ "65": {
524
+ "content": "<extra_id_60>",
525
+ "lstrip": false,
526
+ "normalized": false,
527
+ "rstrip": false,
528
+ "single_word": false,
529
+ "special": true
530
+ },
531
+ "66": {
532
+ "content": "<extra_id_61>",
533
+ "lstrip": false,
534
+ "normalized": false,
535
+ "rstrip": false,
536
+ "single_word": false,
537
+ "special": true
538
+ },
539
+ "67": {
540
+ "content": "<extra_id_62>",
541
+ "lstrip": false,
542
+ "normalized": false,
543
+ "rstrip": false,
544
+ "single_word": false,
545
+ "special": true
546
+ },
547
+ "68": {
548
+ "content": "<extra_id_63>",
549
+ "lstrip": false,
550
+ "normalized": false,
551
+ "rstrip": false,
552
+ "single_word": false,
553
+ "special": true
554
+ },
555
+ "69": {
556
+ "content": "<extra_id_64>",
557
+ "lstrip": false,
558
+ "normalized": false,
559
+ "rstrip": false,
560
+ "single_word": false,
561
+ "special": true
562
+ },
563
+ "70": {
564
+ "content": "<extra_id_65>",
565
+ "lstrip": false,
566
+ "normalized": false,
567
+ "rstrip": false,
568
+ "single_word": false,
569
+ "special": true
570
+ },
571
+ "71": {
572
+ "content": "<extra_id_66>",
573
+ "lstrip": false,
574
+ "normalized": false,
575
+ "rstrip": false,
576
+ "single_word": false,
577
+ "special": true
578
+ },
579
+ "72": {
580
+ "content": "<extra_id_67>",
581
+ "lstrip": false,
582
+ "normalized": false,
583
+ "rstrip": false,
584
+ "single_word": false,
585
+ "special": true
586
+ },
587
+ "73": {
588
+ "content": "<extra_id_68>",
589
+ "lstrip": false,
590
+ "normalized": false,
591
+ "rstrip": false,
592
+ "single_word": false,
593
+ "special": true
594
+ },
595
+ "74": {
596
+ "content": "<extra_id_69>",
597
+ "lstrip": false,
598
+ "normalized": false,
599
+ "rstrip": false,
600
+ "single_word": false,
601
+ "special": true
602
+ },
603
+ "75": {
604
+ "content": "<extra_id_70>",
605
+ "lstrip": false,
606
+ "normalized": false,
607
+ "rstrip": false,
608
+ "single_word": false,
609
+ "special": true
610
+ },
611
+ "76": {
612
+ "content": "<extra_id_71>",
613
+ "lstrip": false,
614
+ "normalized": false,
615
+ "rstrip": false,
616
+ "single_word": false,
617
+ "special": true
618
+ },
619
+ "77": {
620
+ "content": "<extra_id_72>",
621
+ "lstrip": false,
622
+ "normalized": false,
623
+ "rstrip": false,
624
+ "single_word": false,
625
+ "special": true
626
+ },
627
+ "78": {
628
+ "content": "<extra_id_73>",
629
+ "lstrip": false,
630
+ "normalized": false,
631
+ "rstrip": false,
632
+ "single_word": false,
633
+ "special": true
634
+ },
635
+ "79": {
636
+ "content": "<extra_id_74>",
637
+ "lstrip": false,
638
+ "normalized": false,
639
+ "rstrip": false,
640
+ "single_word": false,
641
+ "special": true
642
+ },
643
+ "80": {
644
+ "content": "<extra_id_75>",
645
+ "lstrip": false,
646
+ "normalized": false,
647
+ "rstrip": false,
648
+ "single_word": false,
649
+ "special": true
650
+ },
651
+ "81": {
652
+ "content": "<extra_id_76>",
653
+ "lstrip": false,
654
+ "normalized": false,
655
+ "rstrip": false,
656
+ "single_word": false,
657
+ "special": true
658
+ },
659
+ "82": {
660
+ "content": "<extra_id_77>",
661
+ "lstrip": false,
662
+ "normalized": false,
663
+ "rstrip": false,
664
+ "single_word": false,
665
+ "special": true
666
+ },
667
+ "83": {
668
+ "content": "<extra_id_78>",
669
+ "lstrip": false,
670
+ "normalized": false,
671
+ "rstrip": false,
672
+ "single_word": false,
673
+ "special": true
674
+ },
675
+ "84": {
676
+ "content": "<extra_id_79>",
677
+ "lstrip": false,
678
+ "normalized": false,
679
+ "rstrip": false,
680
+ "single_word": false,
681
+ "special": true
682
+ },
683
+ "85": {
684
+ "content": "<extra_id_80>",
685
+ "lstrip": false,
686
+ "normalized": false,
687
+ "rstrip": false,
688
+ "single_word": false,
689
+ "special": true
690
+ },
691
+ "86": {
692
+ "content": "<extra_id_81>",
693
+ "lstrip": false,
694
+ "normalized": false,
695
+ "rstrip": false,
696
+ "single_word": false,
697
+ "special": true
698
+ },
699
+ "87": {
700
+ "content": "<extra_id_82>",
701
+ "lstrip": false,
702
+ "normalized": false,
703
+ "rstrip": false,
704
+ "single_word": false,
705
+ "special": true
706
+ },
707
+ "88": {
708
+ "content": "<extra_id_83>",
709
+ "lstrip": false,
710
+ "normalized": false,
711
+ "rstrip": false,
712
+ "single_word": false,
713
+ "special": true
714
+ },
715
+ "89": {
716
+ "content": "<extra_id_84>",
717
+ "lstrip": false,
718
+ "normalized": false,
719
+ "rstrip": false,
720
+ "single_word": false,
721
+ "special": true
722
+ },
723
+ "90": {
724
+ "content": "<extra_id_85>",
725
+ "lstrip": false,
726
+ "normalized": false,
727
+ "rstrip": false,
728
+ "single_word": false,
729
+ "special": true
730
+ },
731
+ "91": {
732
+ "content": "<extra_id_86>",
733
+ "lstrip": false,
734
+ "normalized": false,
735
+ "rstrip": false,
736
+ "single_word": false,
737
+ "special": true
738
+ },
739
+ "92": {
740
+ "content": "<extra_id_87>",
741
+ "lstrip": false,
742
+ "normalized": false,
743
+ "rstrip": false,
744
+ "single_word": false,
745
+ "special": true
746
+ },
747
+ "93": {
748
+ "content": "<extra_id_88>",
749
+ "lstrip": false,
750
+ "normalized": false,
751
+ "rstrip": false,
752
+ "single_word": false,
753
+ "special": true
754
+ },
755
+ "94": {
756
+ "content": "<extra_id_89>",
757
+ "lstrip": false,
758
+ "normalized": false,
759
+ "rstrip": false,
760
+ "single_word": false,
761
+ "special": true
762
+ },
763
+ "95": {
764
+ "content": "<extra_id_90>",
765
+ "lstrip": false,
766
+ "normalized": false,
767
+ "rstrip": false,
768
+ "single_word": false,
769
+ "special": true
770
+ },
771
+ "96": {
772
+ "content": "<extra_id_91>",
773
+ "lstrip": false,
774
+ "normalized": false,
775
+ "rstrip": false,
776
+ "single_word": false,
777
+ "special": true
778
+ },
779
+ "97": {
780
+ "content": "<extra_id_92>",
781
+ "lstrip": false,
782
+ "normalized": false,
783
+ "rstrip": false,
784
+ "single_word": false,
785
+ "special": true
786
+ },
787
+ "98": {
788
+ "content": "<extra_id_93>",
789
+ "lstrip": false,
790
+ "normalized": false,
791
+ "rstrip": false,
792
+ "single_word": false,
793
+ "special": true
794
+ },
795
+ "99": {
796
+ "content": "<extra_id_94>",
797
+ "lstrip": false,
798
+ "normalized": false,
799
+ "rstrip": false,
800
+ "single_word": false,
801
+ "special": true
802
+ },
803
+ "100": {
804
+ "content": "<extra_id_95>",
805
+ "lstrip": false,
806
+ "normalized": false,
807
+ "rstrip": false,
808
+ "single_word": false,
809
+ "special": true
810
+ },
811
+ "101": {
812
+ "content": "<extra_id_96>",
813
+ "lstrip": false,
814
+ "normalized": false,
815
+ "rstrip": false,
816
+ "single_word": false,
817
+ "special": true
818
+ },
819
+ "102": {
820
+ "content": "<extra_id_97>",
821
+ "lstrip": false,
822
+ "normalized": false,
823
+ "rstrip": false,
824
+ "single_word": false,
825
+ "special": true
826
+ },
827
+ "103": {
828
+ "content": "<extra_id_98>",
829
+ "lstrip": false,
830
+ "normalized": false,
831
+ "rstrip": false,
832
+ "single_word": false,
833
+ "special": true
834
+ },
835
+ "104": {
836
+ "content": "<extra_id_99>",
837
+ "lstrip": false,
838
+ "normalized": false,
839
+ "rstrip": false,
840
+ "single_word": false,
841
+ "special": true
842
+ }
843
+ },
844
+ "bos_token": "<s>",
845
+ "clean_up_tokenization_spaces": false,
846
+ "cls_token": "<s>",
847
+ "eos_token": "</s>",
848
+ "extra_special_tokens": {},
849
+ "mask_token": "<mask>",
850
+ "model_max_length": 1000000000000000019884624838656,
851
+ "pad_token": "<pad>",
852
+ "sep_token": "</s>",
853
+ "tokenizer_class": "PreTrainedTokenizerFast",
854
+ "unk_token": "<unk>"
855
+ }
utils_data2vec2.py ADDED
@@ -0,0 +1,439 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ #
3
+ # Copyright (c) Facebook, Inc. and its affiliates.
4
+ #
5
+ # This source code is licensed under the MIT license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+ #
8
+
9
+ import math
10
+ import numpy as np
11
+ from collections import namedtuple
12
+ from typing import Optional, Tuple
13
+
14
+ import torch
15
+ import torch.nn.functional as F
16
+
17
+
18
+ MaskSeed = namedtuple("MaskSeed", ["seed", "update", "ids"])
19
+ MaskInfo = namedtuple("MaskInfo", ["x_unmasked", "mask", "ids_restore", "ids_keep"])
20
+
21
+
22
def gather_unmasked(x: torch.Tensor, mask_info: MaskInfo) -> torch.Tensor:
    """Select only the unmasked (kept) timesteps of ``x`` along dim 1.

    ``mask_info.ids_keep`` is already expanded over the feature dimension,
    so a single gather suffices.
    """
    kept_index = mask_info.ids_keep
    return torch.gather(x, 1, kept_index)
28
+
29
+
30
def gather_unmasked_mask(x: torch.Tensor, mask_info: MaskInfo) -> torch.Tensor:
    """Gather the kept positions of a per-timestep tensor with no feature dim.

    ``ids_keep`` carries an expanded feature dimension; slice it off before
    gathering so the index matches a 2-D ``x``.
    """
    time_index = mask_info.ids_keep[..., 0]
    return torch.gather(x, 1, time_index)
36
+
37
+
38
def masked_alibi(alibi_bias, mask_info):
    """Restrict a (B, H, T, T) alibi bias to the kept (unmasked) timesteps.

    The same kept indices are applied to both the query (row) and key
    (column) axes, yielding a (B, H, K, K) tensor where K is the number of
    kept positions.
    """
    num_heads = alibi_bias.size(1)

    # (B, 1, K, 1): kept timestep indices without the feature dimension.
    keep = mask_info.ids_keep.unsqueeze(1)[..., 0].unsqueeze(-1)

    # First restrict the query (row) axis ...
    rows = torch.gather(
        alibi_bias,
        dim=-2,
        index=keep.expand(-1, num_heads, -1, mask_info.ids_restore.size(1)),
    )
    # ... then the key (column) axis.
    restricted = torch.gather(
        rows,
        dim=-1,
        index=keep.transpose(-1, -2).expand(-1, num_heads, rows.size(-2), -1),
    )

    return restricted
56
+
57
+
58
def random_masking(x, mask_ratio, mask_seed: Optional[MaskSeed]):
    """MAE-style random masking: keep a random subset of timesteps of ``x``.

    Returns a ``MaskInfo`` holding the kept features, the binary mask
    (0 = kept, 1 = removed) in original order, and the shuffle/restore
    index tensors (both expanded over the feature dimension).
    """
    batch, seq_len, dim = x.shape
    num_keep = int(seq_len * (1 - mask_ratio))

    gen = None
    if mask_seed is not None:
        # Derive a reproducible seed from the update step and sample ids so
        # clones of the same batch mask identically.
        derived = int(
            hash((mask_seed.seed, mask_seed.update, mask_seed.ids.sum().item())) % 1e6
        )
        gen = torch.Generator(device=x.device)
        gen.manual_seed(derived)

    # One uniform score per position; the smallest scores are kept.
    scores = torch.rand(batch, seq_len, generator=gen, device=x.device)
    shuffle = scores.argsort(dim=1)  # ascending: head of the order is kept
    restore = shuffle.argsort(dim=1)

    keep = shuffle[:, :num_keep].unsqueeze(-1).expand(-1, -1, dim)
    x_unmasked = torch.gather(x, dim=1, index=keep)

    # Build the binary mask in shuffled order, then unshuffle it.
    mask = torch.ones([batch, seq_len], dtype=x.dtype, device=x.device)
    mask[:, :num_keep] = 0
    mask = torch.gather(mask, dim=1, index=restore)

    restore = restore.unsqueeze(-1).expand(-1, -1, dim)

    return MaskInfo(
        x_unmasked=x_unmasked, mask=mask, ids_restore=restore, ids_keep=keep
    )
92
+
93
+
94
def get_alibi(
    max_positions: int,
    attention_heads: int,
    dims: int = 1,
    distance: str = "manhattan",
):
    """Build an ALiBi (Attention with Linear Biases) position-bias tensor.

    Args:
        max_positions: number of positions T; the result is (heads, T, T).
            For ``dims == 2`` this must be a perfect square (a flattened grid).
        attention_heads: number of attention heads (one slope per head).
        dims: 1 for sequences, 2 for square 2-D grids.
        distance: grid distance for ``dims == 2``: "manhattan" or "euclidean".

    Returns:
        Tensor of shape (attention_heads, max_positions, max_positions) with
        0 on the diagonal and increasingly negative values with distance.

    Raises:
        ValueError: for an unsupported ``dims`` or ``distance`` value.
            (Previously an unknown ``distance`` crashed later with a
            NameError on the unbound ``df``.)
    """

    def get_slopes(n):
        # Geometric sequence of per-head slopes from the ALiBi paper.
        def get_slopes_power_of_2(n):
            start = 2 ** (-(2 ** -(math.log2(n) - 3)))
            ratio = start
            return [start * ratio**i for i in range(n)]

        # The closed form only has the desired properties when n is a power
        # of 2; otherwise interleave slopes from the two nearest powers of 2.
        if math.log2(n).is_integer():
            return get_slopes_power_of_2(n)
        closest_power_of_2 = 2 ** math.floor(math.log2(n))
        return (
            get_slopes_power_of_2(closest_power_of_2)
            + get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
        )

    slopes = torch.tensor(get_slopes(attention_heads), dtype=torch.float)

    if dims == 1:
        # Non-autoregressive use: symmetric bias with 0 on the diagonal and
        # linearly decreasing values away from it.
        positions = torch.arange(max_positions)
        pos_bias = -torch.abs(positions.unsqueeze(0) - positions.unsqueeze(1))
    elif dims == 2:
        if distance == "manhattan":
            df = lambda x1, y1, x2, y2: abs(x1 - x2) + abs(y1 - y2)
        elif distance == "euclidean":
            df = lambda x1, y1, x2, y2: math.sqrt((x1 - x2) ** 2 + (y1 - y2) ** 2)
        else:
            raise ValueError(f"unsupported distance metric: {distance}")

        side = math.sqrt(max_positions)
        assert side.is_integer(), side
        side = int(side)

        pos_bias = torch.zeros((max_positions, max_positions))

        # Flattened grid: cell (i, j) maps to index i * side + j.
        for i in range(side):
            for j in range(side):
                for k in range(side):
                    for m in range(side):
                        pos_bias[i * side + j, k * side + m] = -df(i, j, k, m)
    else:
        raise ValueError(f"unsupported number of alibi dims: {dims}")

    alibi_bias = slopes.unsqueeze(1).unsqueeze(1) * pos_bias.unsqueeze(0).expand(
        attention_heads, -1, -1
    )

    return alibi_bias
161
+
162
+
163
def get_alibi_bias(
    alibi_biases,
    batch_size,
    time_steps,
    heads,
    dtype,
    device,
    dims=1,
    distance="manhattan",
):
    """Fetch a cached ALiBi bias, growing the cache lazily as needed.

    ``alibi_biases`` is a dict used as a cache keyed by (dims, heads,
    distance). The cached buffer is regenerated whenever it is missing, too
    small for this request, or on the wrong dtype/device; it only ever
    grows. Returns a (batch_size, heads, time_steps, time_steps) view.
    """
    key = f"{dims}_{heads}_{distance}"
    cached = alibi_biases.get(key, None)

    needed_rows = heads * batch_size
    stale = (
        cached is None
        or cached.size(0) < needed_rows
        or cached.size(1) < time_steps
        or cached.dtype != dtype
        or cached.device != device
    )

    if stale:
        # Grow to the max of the requested and previously cached extents.
        prev_rows = 0 if cached is None else cached.size(0)
        prev_t = 0 if cached is None else cached.size(1)
        new_t = max(time_steps, prev_t)
        new_batch = max(needed_rows, prev_rows) // heads

        cached = (
            get_alibi(new_t, heads, dims=dims, distance=distance)
            .to(dtype=dtype, device=device)
            .repeat(new_batch, 1, 1)
        )
        alibi_biases[key] = cached

    view = cached[:needed_rows, :time_steps, :time_steps]
    return view.view(batch_size, heads, time_steps, time_steps)
199
+
200
+
201
def is_xla_tensor(tensor):
    """Return True iff ``tensor`` is a torch tensor on an XLA device."""
    if not torch.is_tensor(tensor):
        return False
    return tensor.device.type == "xla"
203
+
204
+
205
def index_put(tensor, indices, value):
    """Write ``value`` into ``tensor`` at positions selected by ``indices``.

    On XLA devices boolean advanced-index assignment is avoided: the mask is
    broadcast up to the tensor's shape and the update is expressed as an
    arithmetic blend. Elsewhere a plain in-place indexed assignment is used.
    """
    on_xla = torch.is_tensor(tensor) and tensor.device.type == "xla"
    if on_xla:
        # Broadcast the boolean mask up to tensor's rank and width.
        while indices.dim() < tensor.dim():
            indices = indices.unsqueeze(-1)
        if indices.size(-1) < tensor.size(-1):
            indices = indices.expand_as(tensor)
        tensor = torch.mul(tensor, ~indices) + torch.mul(value, indices)
    else:
        tensor[indices] = value
    return tensor
215
+
216
+
217
def compute_mask_indices(
    shape: Tuple[int, int],
    padding_mask: Optional[torch.Tensor],
    mask_prob: float,
    mask_length: int,
    mask_type: str = "static",
    mask_other: float = 0.0,
    min_masks: int = 0,
    no_overlap: bool = False,
    min_space: int = 0,
    require_same_masks: bool = True,
    mask_dropout: float = 0.0,
    add_masks: bool = False,
    seed: Optional[int] = None,
    epoch: Optional[int] = None,
    indices: Optional[torch.Tensor] = None,
    idc_select_ver: int = 1,  # 2 to reproduce mask_tokens_dataset
    num_mask_ver: int = 2,  # 2 to reproduce mask_tokens_dataset
) -> np.ndarray:
    """
    Computes random mask spans for a given shape

    Args:
        shape: the shape for which to compute masks.
            should be of size 2 where first element is batch size and 2nd is timesteps
        padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
        mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
            number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
            however due to overlaps, the actual number will be smaller (unless no_overlap is True)
        mask_type: how to compute mask lengths
            static = fixed size
            uniform = sample from uniform distribution [mask_other, mask_length*2]
            normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
            poisson = sample from possion distribution with lambda = mask length
        min_masks: minimum number of masked spans
        no_overlap: if false, will switch to an alternative recursive algorithm that prevents spans from overlapping
        min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans
        require_same_masks: if true, will randomly drop out masks until same amount of masks remains in each sample
        mask_dropout: randomly dropout this percentage of masks in each example

    Returns:
        boolean ndarray of shape ``shape``; True marks a masked position.
    """

    bsz, all_sz = shape
    mask = np.full((bsz, all_sz), False)

    if num_mask_ver == 1:
        all_num_mask = int(
            # add a random number for probabilistic rounding
            mask_prob * all_sz / float(mask_length)
            + np.random.rand()
        )
        all_num_mask = max(min_masks, all_num_mask)

    mask_idcs = []
    for i in range(bsz):
        if seed is not None and epoch is not None and indices is not None:
            # ints hash to themselves, so this per-sample seed is stable
            # across runs despite PYTHONHASHSEED randomization
            seed_i = int(hash((seed, epoch, indices[i].item())) % 1e6)
        else:
            seed_i = None

        rng = np.random.default_rng(seed_i)

        if padding_mask is not None:
            sz = all_sz - padding_mask[i].long().sum().item()
            assert sz >= 0, sz
        else:
            sz = all_sz

        if num_mask_ver == 1:
            if padding_mask is not None:
                num_mask = int(
                    # add a random number for probabilistic rounding
                    mask_prob * sz / float(mask_length)
                    + np.random.rand()
                )
                num_mask = max(min_masks, num_mask)
            else:
                num_mask = all_num_mask
        elif num_mask_ver == 2:
            num_mask = int(
                # add a random number for probabilistic rounding
                mask_prob * sz / float(mask_length)
                + rng.random()
            )
            num_mask = max(min_masks, num_mask)
        else:
            raise ValueError()

        if mask_type == "static":
            lengths = np.full(num_mask, mask_length)
        elif mask_type == "uniform":
            # fix: np.random.Generator has no `randint`; `integers` is the
            # equivalent (high is exclusive, hence the +1)
            lengths = rng.integers(int(mask_other), mask_length * 2 + 1, size=num_mask)
        elif mask_type == "normal":
            lengths = rng.normal(mask_length, mask_other, size=num_mask)
            lengths = [max(1, int(round(x))) for x in lengths]
        elif mask_type == "poisson":
            lengths = rng.poisson(mask_length, size=num_mask)
            lengths = [int(round(x)) for x in lengths]
        else:
            raise Exception("unknown mask selection " + mask_type)

        if sum(lengths) == 0:
            if mask_type == "static":
                raise ValueError(f"this should never happens")
            else:
                lengths = [min(mask_length, sz - 1)]

        if no_overlap:
            mask_idc = []

            def arrange(s, e, length, keep_length):
                # fix: Generator.integers replaces the removed `randint`
                span_start = rng.integers(s, e - length)
                mask_idc.extend(span_start + i for i in range(length))

                new_parts = []
                if span_start - s - min_space >= keep_length:
                    new_parts.append((s, span_start - min_space + 1))
                if e - span_start - length - min_space > keep_length:
                    new_parts.append((span_start + length + min_space, e))
                return new_parts

            parts = [(0, sz)]
            min_length = min(lengths)
            for length in sorted(lengths, reverse=True):
                lens = np.fromiter(
                    (e - s if e - s >= length + min_space else 0 for s, e in parts),
                    int,  # fix: `np.int` was removed in NumPy 1.24
                )
                l_sum = np.sum(lens)
                if l_sum == 0:
                    break
                probs = lens / np.sum(lens)
                c = rng.choice(len(parts), p=probs)
                s, e = parts.pop(c)
                parts.extend(arrange(s, e, length, min_length))
            mask_idc = np.asarray(mask_idc)
        else:
            if idc_select_ver == 1:
                min_len = min(lengths)
                if sz - min_len <= num_mask:
                    min_len = sz - num_mask - 1
                mask_idc = rng.choice(sz - min_len, num_mask, replace=False)
            elif idc_select_ver == 2:
                mask_idc = rng.choice(sz, num_mask, replace=False)
            else:
                raise ValueError()

            # expand each chosen start index into a full span
            mask_idc = np.asarray(
                [
                    mask_idc[j] + offset
                    for j in range(len(mask_idc))
                    for offset in range(lengths[j])
                ]
            )

        mask_idc = np.unique(mask_idc[mask_idc < sz])
        if len(mask_idc) >= sz:
            raise ValueError(
                (
                    # fix: message previously printed the literal text
                    # "mask_idc[mask_idc]" instead of the value
                    f"the entire sequence is masked. "
                    f"sz={sz}; mask_idc={mask_idc}; "
                    f"index={indices[i] if indices is not None else None}"
                )
            )
        mask_idcs.append(mask_idc)

    target_len = None
    if require_same_masks:
        if add_masks:
            target_len = max([len(m) for m in mask_idcs])
        else:
            target_len = min([len(m) for m in mask_idcs])

    # NOTE(review): `rng` below is the generator left over from the last
    # sample of the loop above; preserved as-is for parity with fairseq.
    for i, mask_idc in enumerate(mask_idcs):
        if target_len is not None and len(mask_idc) > target_len:
            mask_idc = rng.choice(mask_idc, target_len, replace=False)

        mask[i, mask_idc] = True

        if target_len is not None and len(mask_idc) < target_len:
            unmasked = np.flatnonzero(~mask[i])
            to_mask = rng.choice(unmasked, target_len - len(mask_idc), replace=False)
            mask[i, to_mask] = True

        if mask_dropout > 0:
            masked = np.flatnonzero(mask[i])
            num_holes = np.rint(len(masked) * mask_dropout).astype(int)
            to_drop = rng.choice(masked, num_holes, replace=False)
            mask[i, to_drop] = False

    return mask
407
+
408
+
409
def _learned_alibi_bias(
    alibi_bias,
    batch_size,
    time_steps,
    heads,
    scale,
    dtype,
    device,
):
    """Expand a learned (1, H, T0, T0) alibi bias to (B, H, T, T), scaled.

    If the stored bias is narrower than ``time_steps`` it is symmetrically
    replicate-padded first, then cropped to exactly ``time_steps``.
    """
    assert alibi_bias.size(1) == heads, alibi_bias.shape
    assert alibi_bias.dtype == dtype, alibi_bias.dtype
    assert alibi_bias.device == device, alibi_bias.device

    deficit = time_steps - alibi_bias.size(-1)
    if deficit > 0:
        # Pad both sides of the last two dims by half the deficit (rounded up).
        pad = math.ceil(deficit / 2)
        alibi_bias = F.pad(alibi_bias, (pad, pad, pad, pad), mode="replicate")

    scaled = alibi_bias.expand(batch_size, -1, -1, -1) * scale
    return scaled[..., :time_steps, :time_steps]
428
+
429
def make_positions(tensor, padding_idx: int, onnx_trace: bool = False):
    """Replace non-padding symbols with their position numbers.

    Position numbers begin at ``padding_idx + 1``; padding symbols keep the
    value ``padding_idx``. (``onnx_trace`` is unused in the body but kept
    for interface compatibility with callers.)
    """
    # The int cast + cumsum + type_as sequence keeps this exportable to ONNX
    # and friendly to XLA: XLA prefers ints, cumsum defaults to long output,
    # and ONNX cannot handle the dtype kwarg of cumsum.
    not_pad = tensor.ne(padding_idx).int()
    positions = torch.cumsum(not_pad, dim=1).type_as(not_pad) * not_pad
    return positions.long() + padding_idx