flaubert committed on
Commit b547b8e · verified · 1 Parent(s): ca8bc38

Upload folder using huggingface_hub

added_tokens.json ADDED
@@ -0,0 +1,3 @@
+{
+  "<mask>": 32004
+}
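
The new added_tokens.json registers a single extra token, `<mask>`, with id 32004 on top of the base vocabulary. As a quick sketch of what that means in practice (the repository id below is a placeholder, not something this commit confirms), the tokenizer loaded from this repo should expose the mapping directly:

```python
from transformers import AutoTokenizer

# Placeholder repo id; substitute the actual repository this commit belongs to.
tok = AutoTokenizer.from_pretrained("flaubert/pantagruel-uni", trust_remote_code=True)

# Entries from added_tokens.json appear in the added-vocabulary mapping.
print(tok.get_added_vocab())                 # expected to contain {"<mask>": 32004}
print(tok.convert_tokens_to_ids("<mask>"))   # expected: 32004
```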
config.json CHANGED
@@ -3,17 +3,25 @@
   "activation_dropout": 0.0,
   "add_cross_attention": false,
   "architectures": [
-    "Data2Vec2MultiModel"
+    "PantagruelUniModel"
   ],
   "attention_dropout": 0.1,
   "auto_map": {
-    "AutoConfig": "configuration_data2vec2.Data2Vec2MultiConfig",
-    "AutoModel": "modeling_data2vec2.Data2Vec2MultiModel"
+    "AutoConfig": "configuration_pantagruel_uni.PantagruelUniConfig",
+    "AutoModel": "modeling_pantagruel_uni.PantagruelUniModel",
+    "AutoModelForAudioFrameClassification": "modeling_pantagruel_uni.PantagruelUniForAudioFrameClassification",
+    "AutoModelForCTC": "modeling_pantagruel_uni.PantagruelUniForCTC",
+    "AutoModelForMaskedLM": "modeling_pantagruel_uni.PantagruelUniForMaskedLM",
+    "AutoModelForMultipleChoice": "modeling_pantagruel_uni.PantagruelUniForMultipleChoice",
+    "AutoModelForQuestionAnswering": "modeling_pantagruel_uni.PantagruelUniForQuestionAnswering",
+    "AutoModelForSequenceClassification": "modeling_pantagruel_uni.PantagruelUniForSequenceClassification",
+    "AutoModelForTokenClassification": "modeling_pantagruel_uni.PantagruelUniForTokenClassification"
   },
   "bad_words_ids": null,
   "begin_suppress_tokens": null,
   "bos_token_id": null,
   "chunk_size_feed_forward": 0,
+  "classifier_dropout": null,
   "clone_batch": 8,
   "cross_attention_hidden_size": null,
   "decoder_start_token_id": null,
@@ -30,6 +38,7 @@
   "end_of_block_targets": false,
   "eos_token_id": null,
   "exponential_decay_length_penalty": null,
+  "final_dropout": 0.1,
   "finetuning_task": null,
   "forced_bos_token_id": null,
   "forced_eos_token_id": null,
@@ -57,6 +66,9 @@
   "architectures": null,
   "audio": {
     "_name_or_path": "",
+    "adapter_kernel_size": 3,
+    "adapter_stride": 2,
+    "add_adapter": false,
     "add_cross_attention": false,
     "add_masks": false,
     "alibi_max_pos": null,
@@ -66,11 +78,14 @@
     "begin_suppress_tokens": null,
     "bos_token_id": null,
     "chunk_size_feed_forward": 0,
+    "classifier_proj_size": 256,
     "conv_pos_depth": 5,
     "conv_pos_groups": 16,
     "conv_pos_pre_ln": false,
     "conv_pos_width": 95,
     "cross_attention_hidden_size": null,
+    "ctc_loss_reduction": "sum",
+    "ctc_zero_infinity": false,
     "decoder_start_token_id": null,
     "diversity_penalty": 0.0,
     "do_sample": false,
@@ -108,22 +123,30 @@
     "mask_channel_length": 64,
     "mask_channel_prob": 0.0,
     "mask_dropout": 0.0,
+    "mask_feature_length": 10,
+    "mask_feature_min_masks": 0,
+    "mask_feature_prob": 0.0,
     "mask_length": 5,
     "mask_noise_std": 0.01,
     "mask_prob": 0.7,
     "mask_prob_adjust": 0.0,
     "mask_prob_min": null,
+    "mask_time_length": 10,
+    "mask_time_min_masks": 2,
+    "mask_time_prob": 0.05,
     "max_length": 20,
     "min_length": 0,
     "model_depth": 12,
     "model_type": "",
     "no_repeat_ngram_size": 0,
+    "num_adapter_layers": 3,
     "num_alibi_heads": 12,
     "num_beam_groups": 1,
     "num_beams": 1,
     "num_extra_tokens": 0,
     "num_return_sequences": 1,
     "output_attentions": false,
+    "output_hidden_size": null,
     "output_hidden_states": false,
     "output_scores": false,
     "pad_token_id": null,
@@ -142,6 +165,27 @@
     "start_drop_path_rate": 0.0,
     "suppress_tokens": null,
     "task_specific_params": null,
+    "tdnn_dilation": [
+      1,
+      2,
+      3,
+      1,
+      1
+    ],
+    "tdnn_dim": [
+      512,
+      512,
+      512,
+      512,
+      1500
+    ],
+    "tdnn_kernel": [
+      5,
+      3,
+      3,
+      1,
+      1
+    ],
     "temperature": 1.0,
     "tie_encoder_decoder": false,
     "tie_word_embeddings": true,
@@ -151,7 +195,10 @@
     "torchscript": false,
     "type": "AUDIO",
     "typical_p": 1.0,
-    "use_alibi_encoder": false
+    "use_alibi_encoder": false,
+    "use_weighted_layer_sum": false,
+    "vocab_size": 80,
+    "xvector_output_dim": 512
   },
   "bad_words_ids": null,
   "begin_suppress_tokens": null,
@@ -310,7 +357,7 @@
     "torchscript": false,
     "typical_p": 1.0
   },
-  "model_type": "data2vec2",
+  "model_type": "pantagruel_uni",
   "n_layers": 12,
   "no_repeat_ngram_size": 0,
   "norm_affine": true,
configuration_data2vec2.py ADDED
@@ -0,0 +1,415 @@
1
+ # coding=utf-8
2
+ #
3
+ # Copyright (c) Facebook, Inc. and its affiliates.
4
+ #
5
+ # This source code is licensed under the MIT license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+ #
8
+ #
9
+ # Copyright 2022 The HuggingFace Inc. team. All rights reserved.
10
+ #
11
+ # Licensed under the Apache License, Version 2.0 (the "License");
12
+ # you may not use this file except in compliance with the License.
13
+ # You may obtain a copy of the License at
14
+ #
15
+ # http://www.apache.org/licenses/LICENSE-2.0
16
+ #
17
+ # Unless required by applicable law or agreed to in writing, software
18
+ # distributed under the License is distributed on an "AS IS" BASIS,
19
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20
+ # See the License for the specific language governing permissions and
21
+ # limitations under the License.
22
+
23
+
24
+ """ Data2Vec2 multi configuration"""
25
+
26
+ import os
27
+ from typing import Union, Dict, Any, Optional
28
+ from transformers.dynamic_module_utils import custom_object_save
29
+ from transformers.utils import logging
30
+ from transformers.configuration_utils import PretrainedConfig, CONFIG_NAME
31
+
32
+
33
+ logger = logging.get_logger(__name__)
34
+
35
+
36
+ class MyPretrainedConfig(PretrainedConfig):
37
+ def __init__(self, **kwargs):
38
+ super().__init__(**kwargs)
39
+
40
+ def to_json_string(self, use_diff: bool = False) -> str:
41
+ return super().to_json_string(use_diff)
42
+
43
+ def update(self, config_dict):
44
+ for key, value in config_dict.items():
45
+ if not hasattr(self, key):
46
+ continue
47
+ if isinstance(getattr(self, key), MyPretrainedConfig):
48
+ getattr(self, key).update(config_dict[key])
49
+ else:
50
+ setattr(self, key, value)
51
+
52
+ # Copied from the parent class, only changed use_diff from True to False to correctly save nested config class
53
+ def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
54
+ """
55
+ Save a configuration object to the directory `save_directory`, so that it can be re-loaded using the
56
+ [`~PretrainedConfig.from_pretrained`] class method.
57
+
58
+ Args:
59
+ save_directory (`str` or `os.PathLike`):
60
+ Directory where the configuration JSON file will be saved (will be created if it does not exist).
61
+ push_to_hub (`bool`, *optional*, defaults to `False`):
62
+ Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
63
+ repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
64
+ namespace).
65
+ kwargs (`Dict[str, Any]`, *optional*):
66
+ Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
67
+ """
68
+ self._set_token_in_kwargs(kwargs)
69
+
70
+ if os.path.isfile(save_directory):
71
+ raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
72
+
73
+ non_default_generation_parameters = {}
74
+ for parameter_name, default_value in self._get_global_generation_defaults().items():
75
+ if hasattr(self, parameter_name) and getattr(self, parameter_name) != default_value:
76
+ non_default_generation_parameters[parameter_name] = getattr(self, parameter_name)
77
+ if len(non_default_generation_parameters) > 0:
78
+ logger.warning(
79
+ "Some non-default generation parameters are set in the model config. These should go into a "
80
+ "GenerationConfig file (https://huggingface.co/docs/transformers/generation_strategies#save-a-custom-decoding-strategy-with-your-model) "
81
+ "instead. This warning will be raised to an exception in v4.41.\n"
82
+ f"Non-default generation parameters: {str(non_default_generation_parameters)}"
83
+ )
84
+
85
+ os.makedirs(save_directory, exist_ok=True)
86
+
87
+ if push_to_hub:
88
+ commit_message = kwargs.pop("commit_message", None)
89
+ repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
90
+ repo_id = self._create_repo(repo_id, **kwargs)
91
+ files_timestamps = self._get_files_timestamps(save_directory)
92
+
93
+ # If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be
94
+ # loaded from the Hub.
95
+ if self._auto_class is not None:
96
+ custom_object_save(self, save_directory, config=self)
97
+
98
+ # If we save using the predefined names, we can load using `from_pretrained`
99
+ output_config_file = os.path.join(save_directory, CONFIG_NAME)
100
+
101
+ self.to_json_file(output_config_file, use_diff=False)
102
+ logger.info(f"Configuration saved in {output_config_file}")
103
+
104
+ if push_to_hub:
105
+ self._upload_modified_files(
106
+ save_directory,
107
+ repo_id,
108
+ files_timestamps,
109
+ commit_message=commit_message,
110
+ token=kwargs.get("token"),
111
+ )
112
+
113
+ # Copied from the parent class, change the instantiation and updating of class from config_dict to correctly load nested config
114
+ @classmethod
115
+ def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "MyPretrainedConfig":
116
+ """
117
+ Instantiates a [`PretrainedConfig`] from a Python dictionary of parameters.
118
+
119
+ Args:
120
+ config_dict (`Dict[str, Any]`):
121
+ Dictionary that will be used to instantiate the configuration object. Such a dictionary can be
122
+ retrieved from a pretrained checkpoint by leveraging the [`~PretrainedConfig.get_config_dict`] method.
123
+ kwargs (`Dict[str, Any]`):
124
+ Additional parameters from which to initialize the configuration object.
125
+
126
+ Returns:
127
+ [`PretrainedConfig`]: The configuration object instantiated from those parameters.
128
+ """
129
+ return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
130
+ # Those arguments may be passed along for our internal telemetry.
131
+ # We remove them so they don't appear in `return_unused_kwargs`.
132
+ kwargs.pop("_from_auto", None)
133
+ kwargs.pop("_from_pipeline", None)
134
+ # The commit hash might have been updated in the `config_dict`, we don't want the kwargs to erase that update.
135
+ if "_commit_hash" in kwargs and "_commit_hash" in config_dict:
136
+ kwargs["_commit_hash"] = config_dict["_commit_hash"]
137
+
138
+ # We remove it from kwargs so that it does not appear in `return_unused_kwargs`.
139
+ config_dict["attn_implementation"] = kwargs.pop("attn_implementation", None)
140
+
141
+ # config = cls(**config_dict)
142
+ # My updated config
143
+ config = cls()
144
+ for key, value in config_dict.items():
145
+ if not hasattr(config, key):
146
+ continue
147
+ if isinstance(getattr(config, key), MyPretrainedConfig):
148
+ getattr(config, key).update(config_dict[key])
149
+ else:
150
+ setattr(config, key, value)
151
+
152
+
153
+ if hasattr(config, "pruned_heads"):
154
+ config.pruned_heads = {int(key): value for key, value in config.pruned_heads.items()}
155
+
156
+ # Update config with kwargs if needed
157
+ if "num_labels" in kwargs and "id2label" in kwargs:
158
+ num_labels = kwargs["num_labels"]
159
+ id2label = kwargs["id2label"] if kwargs["id2label"] is not None else []
160
+ if len(id2label) != num_labels:
161
+ raise ValueError(
162
+ f"You passed along `num_labels={num_labels }` with an incompatible id to label map: "
163
+ f"{kwargs['id2label']}. Since those arguments are inconsistent with each other, you should remove "
164
+ "one of them."
165
+ )
166
+ to_remove = []
167
+ for key, value in kwargs.items():
168
+ if hasattr(config, key):
169
+ current_attr = getattr(config, key)
170
+ # To authorize passing a custom subconfig as kwarg in models that have nested configs.
171
+ if isinstance(current_attr, PretrainedConfig) and isinstance(value, dict):
172
+ value = current_attr.__class__(**value)
173
+ setattr(config, key, value)
174
+ if key != "torch_dtype":
175
+ to_remove.append(key)
176
+ for key in to_remove:
177
+ kwargs.pop(key, None)
178
+
179
+ logger.info(f"Model config {config}")
180
+ if return_unused_kwargs:
181
+ return config, kwargs
182
+ else:
183
+ return config
184
+
185
+
186
+ class D2v2ModalityConfig(MyPretrainedConfig):
187
+ def __init__(
188
+ self,
189
+ type="AUDIO",
190
+ prenet_depth=4,
191
+ prenet_layerdrop=0,
192
+ prenet_dropout=0.0,
193
+ start_drop_path_rate=0.0,
194
+ end_drop_path_rate=0.0,
195
+ num_extra_tokens=0,
196
+ init_extra_token_zero=True,
197
+ mask_noise_std=0.01,
198
+ mask_prob_min=None,
199
+ mask_prob=0.7,
200
+ inverse_mask=False,
201
+ mask_prob_adjust=0.0,
202
+ keep_masked_pct=0.0,
203
+ mask_length=5,
204
+ add_masks=False,
205
+ remove_masks=False,
206
+ mask_dropout=0.0,
207
+ encoder_zero_mask=True,
208
+ mask_channel_prob=0.0,
209
+ mask_channel_length=64,
210
+ local_grad_mult=1.0,
211
+ use_alibi_encoder=False,
212
+ alibi_scale=1.0,
213
+ learned_alibi=False,
214
+ alibi_max_pos=None,
215
+ learned_alibi_scale=False,
216
+ learned_alibi_scale_per_head=False,
217
+ learned_alibi_scale_per_layer=False,
218
+ num_alibi_heads=12,
219
+ model_depth=12,
220
+ ema_local_encoder=False,
221
+ decoder=None,
222
+ **kwargs,
223
+ ):
224
+ super().__init__(**kwargs)
225
+ self.type = type
226
+ self.prenet_depth = prenet_depth
227
+ self.prenet_layerdrop = prenet_layerdrop
228
+ self.prenet_dropout = prenet_dropout
229
+ self.start_drop_path_rate = start_drop_path_rate
230
+ self.end_drop_path_rate = end_drop_path_rate
231
+ self.num_extra_tokens = num_extra_tokens
232
+ self.init_extra_token_zero = init_extra_token_zero
233
+ self.mask_noise_std = mask_noise_std
234
+ self.mask_prob_min = mask_prob_min
235
+ self.mask_prob = mask_prob
236
+ self.inverse_mask = inverse_mask
237
+ self.mask_prob_adjust = mask_prob_adjust
238
+ self.keep_masked_pct = keep_masked_pct
239
+ self.mask_length = mask_length
240
+ self.add_masks = add_masks
241
+ self.remove_masks = remove_masks
242
+ self.mask_dropout = mask_dropout
243
+ self.encoder_zero_mask = encoder_zero_mask
244
+ self.mask_channel_prob = mask_channel_prob
245
+ self.mask_channel_length = mask_channel_length
246
+ self.local_grad_mult = local_grad_mult
247
+ self.use_alibi_encoder = use_alibi_encoder
248
+ self.alibi_scale = alibi_scale
249
+ self.learned_alibi = learned_alibi
250
+ self.alibi_max_pos = alibi_max_pos
251
+ self.learned_alibi_scale = learned_alibi_scale
252
+ self.learned_alibi_scale_per_head = learned_alibi_scale_per_head
253
+ self.learned_alibi_scale_per_layer = learned_alibi_scale_per_layer
254
+ self.num_alibi_heads = num_alibi_heads
255
+ self.model_depth = model_depth
256
+
257
+
258
+ class D2v2AudioConfig(D2v2ModalityConfig):
259
+ """
260
+ Configuration including common args and args specific to audio-only pre-training
261
+ """
262
+ def __init__(
263
+ self,
264
+ extractor_mode="layer_norm",
265
+ feature_encoder_spec="[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] + [(512,2,2)]",
266
+ conv_pos_width=95,
267
+ conv_pos_groups=16,
268
+ conv_pos_depth=5,
269
+ conv_pos_pre_ln=False,
270
+ **kwargs,
271
+ ):
272
+ super().__init__(type="AUDIO", **kwargs)
273
+ self.extractor_mode = extractor_mode
274
+ self.feature_encoder_spec = feature_encoder_spec
275
+ self.conv_pos_width = conv_pos_width
276
+ self.conv_pos_groups = conv_pos_groups
277
+ self.conv_pos_depth = conv_pos_depth
278
+ self.conv_pos_pre_ln = conv_pos_pre_ln
279
+
280
+
281
+ class D2v2TextConfig(D2v2ModalityConfig):
282
+ """
283
+ Configuration including common args and args specific to text-only pre-training
284
+ """
285
+ def __init__(
286
+ self,
287
+ vocab_size=50000,
288
+ unk_token_id=3,
289
+ bos_token_id=0,
290
+ eos_token_id=2,
291
+ pad_token_id=1,
292
+ max_source_positions=512,
293
+ learned_pos=True,
294
+ dropout=0.1,
295
+ no_scale_embedding=True,
296
+ layernorm_embedding=True,
297
+ no_token_positional_embeddings=False,
298
+ **kwargs,
299
+ ):
300
+ super().__init__(type="TEXT", **kwargs)
301
+ self.vocab_size = vocab_size
302
+ self.unk_token_id = unk_token_id
303
+ self.bos_token_id = bos_token_id
304
+ self.eos_token_id = eos_token_id
305
+ self.pad_token_id = pad_token_id
306
+ self.max_source_positions = max_source_positions
307
+ self.learned_pos = learned_pos
308
+ self.dropout = dropout
309
+ self.no_scale_embedding = no_scale_embedding
310
+ self.layernorm_embedding = layernorm_embedding
311
+ self.no_token_positional_embeddings = no_token_positional_embeddings
312
+
313
+
314
+ class D2v2ModalitiesConfig(MyPretrainedConfig):
315
+ def __init__(
316
+ self,
317
+ audio_config=D2v2AudioConfig(),
318
+ text_config=D2v2TextConfig(),
319
+ **kwargs
320
+ ):
321
+ super().__init__(**kwargs)
322
+ self.audio = audio_config
323
+ self.text = text_config
324
+
325
+
326
+ class Data2Vec2MultiConfig(MyPretrainedConfig):
327
+ r"""
328
+ This is the configuration class to store the configuration of a [`Data2Vec2MultiModel`]. It is used to instantiate
329
+ an Data2Vec2MultiModel model according to the specified arguments, defining the model architecture.
330
+
331
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
332
+ documentation from [`PretrainedConfig`] for more information.
333
+
334
+
335
+ Args:
336
+ depth (`int`, *optional*, defaults to 12):
337
+ Number of Transformer layers in the encoder.
338
+
339
+ Example:
340
+
341
+ ```python
342
+ >>> from transformers import Data2Vec2MultiConfig, Data2Vec2MultiModel
343
+
344
+ >>> # Initializing a Data2Vec2MultiConfig for audio
345
+ >>> configuration = Data2Vec2MultiConfig()
346
+
347
+ >>> # Initializing a model (with random weights) with the configuration
348
+ >>> model = Data2Vec2MultiModel(configuration)
349
+
350
+ >>> # Accessing the model configuration
351
+ >>> configuration = model.config
352
+ ```"""
353
+
354
+ model_type = "data2vec2"
355
+
356
+ def __init__(
357
+ self,
358
+ depth=12,
359
+ start_drop_path_rate=0.0,
360
+ end_drop_path_rate=0.0,
361
+ num_heads=12,
362
+ norm_eps=1e-5,
363
+ norm_affine=True,
364
+ encoder_dropout=0.1,
365
+ post_mlp_drop=0.1,
366
+ attention_dropout=0.1,
367
+ activation_dropout=0.0,
368
+ dropout_input=0.0,
369
+ layerdrop=0.0,
370
+ embed_dim=768,
371
+ mlp_ratio=4.0,
372
+ layer_norm_first=False,
373
+ end_of_block_targets=False,
374
+ clone_batch=1,
375
+ log_norms=True,
376
+ modalities=D2v2ModalitiesConfig(),
377
+ supported_modality="AUDIO",
378
+ **kwargs,
379
+ ):
380
+ super().__init__(**kwargs)
381
+
382
+ self.depth = depth
383
+ self.start_drop_path_rate = start_drop_path_rate
384
+ self.end_drop_path_rate = end_drop_path_rate
385
+
386
+ self.num_heads = num_heads
387
+ self.norm_eps = norm_eps
388
+ self.norm_affine = norm_affine
389
+ self.post_mlp_drop = post_mlp_drop
390
+ self.encoder_dropout = encoder_dropout
391
+ self.attention_dropout = attention_dropout
392
+ self.activation_dropout = activation_dropout
393
+ self.dropout_input = dropout_input
394
+ self.layerdrop = layerdrop
395
+ self.embed_dim = embed_dim
396
+ self.mlp_ratio = mlp_ratio
397
+
398
+ self.layer_norm_first = layer_norm_first
399
+ self.end_of_block_targets = end_of_block_targets
400
+ self.clone_batch = clone_batch
401
+ self.log_norms = log_norms
402
+
403
+ self.modalities = modalities
404
+ self.supported_modality = supported_modality
405
+
406
+ # Attributes for hopsparser
407
+ self.hidden_size = embed_dim
408
+ self.num_layers = depth
409
+ self.n_layers = depth
410
+ self.num_hidden_layers = depth
411
+
412
+ self.auto_map = {
413
+ 'AutoConfig': 'configuration_data2vec2.Data2Vec2MultiConfig',
414
+ 'AutoModel': 'modeling_data2vec2.Data2Vec2MultiModel',
415
+ }
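
The MyPretrainedConfig overrides above exist to make nested configuration objects survive a save/load round trip: save_pretrained writes the full dictionary (use_diff=False), and the custom from_dict/update recurse into nested MyPretrainedConfig attributes (modalities → audio/text) instead of replacing them with plain dicts. A minimal local sketch of that round trip (the output directory is illustrative):

```python
from configuration_data2vec2 import Data2Vec2MultiConfig

# Start from defaults and change a value nested two levels deep.
cfg = Data2Vec2MultiConfig(depth=12)
cfg.modalities.audio.conv_pos_depth = 7

# The overridden save_pretrained serializes the complete nested config.
cfg.save_pretrained("./tmp_d2v2_config")

# from_dict/update walk into the nested audio config, so the override is restored
# as an attribute of a D2v2AudioConfig rather than as a raw dict.
reloaded = Data2Vec2MultiConfig.from_pretrained("./tmp_d2v2_config")
assert reloaded.modalities.audio.conv_pos_depth == 7
```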
configuration_pantagruel_uni.py ADDED
@@ -0,0 +1,488 @@
1
+ # coding=utf-8
2
+ #
3
+ # Copyright (c) Facebook, Inc. and its affiliates.
4
+ #
5
+ # This source code is licensed under the MIT license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+ #
8
+ #
9
+ # Copyright 2022 The HuggingFace Inc. team. All rights reserved.
10
+ #
11
+ # Licensed under the Apache License, Version 2.0 (the "License");
12
+ # you may not use this file except in compliance with the License.
13
+ # You may obtain a copy of the License at
14
+ #
15
+ # http://www.apache.org/licenses/LICENSE-2.0
16
+ #
17
+ # Unless required by applicable law or agreed to in writing, software
18
+ # distributed under the License is distributed on an "AS IS" BASIS,
19
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20
+ # See the License for the specific language governing permissions and
21
+ # limitations under the License.
22
+
23
+
24
+ """ Pantagruel unimodal configuration"""
25
+
26
+ import os
27
+ from typing import Union, Dict, Any, Optional
28
+ from transformers.dynamic_module_utils import custom_object_save
29
+ from transformers.utils import logging
30
+ from transformers.configuration_utils import PretrainedConfig, CONFIG_NAME
31
+
32
+
33
+ logger = logging.get_logger(__name__)
34
+
35
+
36
+ class MyPretrainedConfig(PretrainedConfig):
37
+ def __init__(self, **kwargs):
38
+ super().__init__(**kwargs)
39
+
40
+ def to_json_string(self, use_diff: bool = False) -> str:
41
+ return super().to_json_string(use_diff)
42
+
43
+ def update(self, config_dict):
44
+ for key, value in config_dict.items():
45
+ if not hasattr(self, key):
46
+ continue
47
+ if isinstance(getattr(self, key), MyPretrainedConfig):
48
+ getattr(self, key).update(config_dict[key])
49
+ else:
50
+ setattr(self, key, value)
51
+
52
+ # Copied from the parent class, only changed use_diff from True to False to correctly save nested config class
53
+ def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
54
+ """
55
+ Save a configuration object to the directory `save_directory`, so that it can be re-loaded using the
56
+ [`~PretrainedConfig.from_pretrained`] class method.
57
+
58
+ Args:
59
+ save_directory (`str` or `os.PathLike`):
60
+ Directory where the configuration JSON file will be saved (will be created if it does not exist).
61
+ push_to_hub (`bool`, *optional*, defaults to `False`):
62
+ Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
63
+ repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
64
+ namespace).
65
+ kwargs (`Dict[str, Any]`, *optional*):
66
+ Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
67
+ """
68
+ self._set_token_in_kwargs(kwargs)
69
+
70
+ if os.path.isfile(save_directory):
71
+ raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
72
+
73
+ non_default_generation_parameters = {}
74
+ for parameter_name, default_value in self._get_global_generation_defaults().items():
75
+ if hasattr(self, parameter_name) and getattr(self, parameter_name) != default_value:
76
+ non_default_generation_parameters[parameter_name] = getattr(self, parameter_name)
77
+ if len(non_default_generation_parameters) > 0:
78
+ logger.warning(
79
+ "Some non-default generation parameters are set in the model config. These should go into a "
80
+ "GenerationConfig file (https://huggingface.co/docs/transformers/generation_strategies#save-a-custom-decoding-strategy-with-your-model) "
81
+ "instead. This warning will be raised to an exception in v4.41.\n"
82
+ f"Non-default generation parameters: {str(non_default_generation_parameters)}"
83
+ )
84
+
85
+ os.makedirs(save_directory, exist_ok=True)
86
+
87
+ if push_to_hub:
88
+ commit_message = kwargs.pop("commit_message", None)
89
+ repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
90
+ repo_id = self._create_repo(repo_id, **kwargs)
91
+ files_timestamps = self._get_files_timestamps(save_directory)
92
+
93
+ # If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be
94
+ # loaded from the Hub.
95
+ if self._auto_class is not None:
96
+ custom_object_save(self, save_directory, config=self)
97
+
98
+ # If we save using the predefined names, we can load using `from_pretrained`
99
+ output_config_file = os.path.join(save_directory, CONFIG_NAME)
100
+
101
+ self.to_json_file(output_config_file, use_diff=False)
102
+ logger.info(f"Configuration saved in {output_config_file}")
103
+
104
+ if push_to_hub:
105
+ self._upload_modified_files(
106
+ save_directory,
107
+ repo_id,
108
+ files_timestamps,
109
+ commit_message=commit_message,
110
+ token=kwargs.get("token"),
111
+ )
112
+
113
+ # Copied from the parent class, change the instantiation and updating of class from config_dict to correctly load nested config
114
+ @classmethod
115
+ def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "MyPretrainedConfig":
116
+ """
117
+ Instantiates a [`PretrainedConfig`] from a Python dictionary of parameters.
118
+
119
+ Args:
120
+ config_dict (`Dict[str, Any]`):
121
+ Dictionary that will be used to instantiate the configuration object. Such a dictionary can be
122
+ retrieved from a pretrained checkpoint by leveraging the [`~PretrainedConfig.get_config_dict`] method.
123
+ kwargs (`Dict[str, Any]`):
124
+ Additional parameters from which to initialize the configuration object.
125
+
126
+ Returns:
127
+ [`PretrainedConfig`]: The configuration object instantiated from those parameters.
128
+ """
129
+ return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
130
+ # Those arguments may be passed along for our internal telemetry.
131
+ # We remove them so they don't appear in `return_unused_kwargs`.
132
+ kwargs.pop("_from_auto", None)
133
+ kwargs.pop("_from_pipeline", None)
134
+ # The commit hash might have been updated in the `config_dict`, we don't want the kwargs to erase that update.
135
+ if "_commit_hash" in kwargs and "_commit_hash" in config_dict:
136
+ kwargs["_commit_hash"] = config_dict["_commit_hash"]
137
+
138
+ # We remove it from kwargs so that it does not appear in `return_unused_kwargs`.
139
+ config_dict["attn_implementation"] = kwargs.pop("attn_implementation", None)
140
+
141
+ # config = cls(**config_dict)
142
+ # My updated config
143
+ config = cls()
144
+ for key, value in config_dict.items():
145
+ if not hasattr(config, key):
146
+ continue
147
+ if isinstance(getattr(config, key), MyPretrainedConfig):
148
+ getattr(config, key).update(config_dict[key])
149
+ else:
150
+ setattr(config, key, value)
151
+
152
+
153
+ if hasattr(config, "pruned_heads"):
154
+ config.pruned_heads = {int(key): value for key, value in config.pruned_heads.items()}
155
+
156
+ # Update config with kwargs if needed
157
+ if "num_labels" in kwargs and "id2label" in kwargs:
158
+ num_labels = kwargs["num_labels"]
159
+ id2label = kwargs["id2label"] if kwargs["id2label"] is not None else []
160
+ if len(id2label) != num_labels:
161
+ raise ValueError(
162
+ f"You passed along `num_labels={num_labels }` with an incompatible id to label map: "
163
+ f"{kwargs['id2label']}. Since those arguments are inconsistent with each other, you should remove "
164
+ "one of them."
165
+ )
166
+ to_remove = []
167
+ for key, value in kwargs.items():
168
+ if hasattr(config, key):
169
+ current_attr = getattr(config, key)
170
+ # To authorize passing a custom subconfig as kwarg in models that have nested configs.
171
+ if isinstance(current_attr, PretrainedConfig) and isinstance(value, dict):
172
+ value = current_attr.__class__(**value)
173
+ setattr(config, key, value)
174
+ if key != "torch_dtype":
175
+ to_remove.append(key)
176
+ for key in to_remove:
177
+ kwargs.pop(key, None)
178
+
179
+ logger.info(f"Model config {config}")
180
+ if return_unused_kwargs:
181
+ return config, kwargs
182
+ else:
183
+ return config
184
+
185
+
186
+ class PantagruelModalityConfig(MyPretrainedConfig):
187
+ """
188
+ Configuration including common args to both speech and text modality
189
+ """
190
+ def __init__(
191
+ self,
192
+ type="AUDIO",
193
+ prenet_depth=4,
194
+ prenet_layerdrop=0,
195
+ prenet_dropout=0.0,
196
+ start_drop_path_rate=0.0,
197
+ end_drop_path_rate=0.0,
198
+ num_extra_tokens=0,
199
+ init_extra_token_zero=True,
200
+ mask_noise_std=0.01,
201
+ mask_prob_min=None,
202
+ mask_prob=0.7,
203
+ inverse_mask=False,
204
+ mask_prob_adjust=0.0,
205
+ keep_masked_pct=0.0,
206
+ mask_length=5,
207
+ add_masks=False,
208
+ remove_masks=False,
209
+ mask_dropout=0.0,
210
+ encoder_zero_mask=True,
211
+ mask_channel_prob=0.0,
212
+ mask_channel_length=64,
213
+ local_grad_mult=1.0,
214
+ use_alibi_encoder=False,
215
+ alibi_scale=1.0,
216
+ learned_alibi=False,
217
+ alibi_max_pos=None,
218
+ learned_alibi_scale=False,
219
+ learned_alibi_scale_per_head=False,
220
+ learned_alibi_scale_per_layer=False,
221
+ num_alibi_heads=12,
222
+ model_depth=12,
223
+ ema_local_encoder=False,
224
+ decoder=None,
225
+ **kwargs,
226
+ ):
227
+ super().__init__(**kwargs)
228
+ self.type = type
229
+ self.prenet_depth = prenet_depth
230
+ self.prenet_layerdrop = prenet_layerdrop
231
+ self.prenet_dropout = prenet_dropout
232
+ self.start_drop_path_rate = start_drop_path_rate
233
+ self.end_drop_path_rate = end_drop_path_rate
234
+ self.num_extra_tokens = num_extra_tokens
235
+ self.init_extra_token_zero = init_extra_token_zero
236
+ self.mask_noise_std = mask_noise_std
237
+ self.mask_prob_min = mask_prob_min
238
+ self.mask_prob = mask_prob
239
+ self.inverse_mask = inverse_mask
240
+ self.mask_prob_adjust = mask_prob_adjust
241
+ self.keep_masked_pct = keep_masked_pct
242
+ self.mask_length = mask_length
243
+ self.add_masks = add_masks
244
+ self.remove_masks = remove_masks
245
+ self.mask_dropout = mask_dropout
246
+ self.encoder_zero_mask = encoder_zero_mask
247
+ self.mask_channel_prob = mask_channel_prob
248
+ self.mask_channel_length = mask_channel_length
249
+ self.local_grad_mult = local_grad_mult
250
+ self.use_alibi_encoder = use_alibi_encoder
251
+ self.alibi_scale = alibi_scale
252
+ self.learned_alibi = learned_alibi
253
+ self.alibi_max_pos = alibi_max_pos
254
+ self.learned_alibi_scale = learned_alibi_scale
255
+ self.learned_alibi_scale_per_head = learned_alibi_scale_per_head
256
+ self.learned_alibi_scale_per_layer = learned_alibi_scale_per_layer
257
+ self.num_alibi_heads = num_alibi_heads
258
+ self.model_depth = model_depth
259
+
260
+
261
+ class PantagruelAudioConfig(PantagruelModalityConfig):
262
+ """
263
+ Configuration including args specific to audio-only tasks
264
+ """
265
+ def __init__(
266
+ self,
267
+ vocab_size=80,
268
+ extractor_mode="layer_norm",
269
+ feature_encoder_spec="[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] + [(512,2,2)]",
270
+ conv_pos_width=95,
271
+ conv_pos_groups=16,
272
+ conv_pos_depth=5,
273
+ conv_pos_pre_ln=False,
274
+ mask_time_prob=0.05,
275
+ mask_time_length=10,
276
+ mask_time_min_masks=2,
277
+ mask_feature_prob=0.0,
278
+ mask_feature_length=10,
279
+ mask_feature_min_masks=0,
280
+ ctc_loss_reduction="sum",
281
+ ctc_zero_infinity=False,
282
+ use_weighted_layer_sum=False,
283
+ classifier_proj_size=256,
284
+ tdnn_dim=(512, 512, 512, 512, 1500),
285
+ tdnn_kernel=(5, 3, 3, 1, 1),
286
+ tdnn_dilation=(1, 2, 3, 1, 1),
287
+ xvector_output_dim=512,
288
+ pad_token_id=0,
289
+ bos_token_id=1,
290
+ eos_token_id=2,
291
+ add_adapter=False,
292
+ adapter_kernel_size=3,
293
+ adapter_stride=2,
294
+ num_adapter_layers=3,
295
+ output_hidden_size=None,
296
+ **kwargs,
297
+ ):
298
+ super().__init__(type="AUDIO", **kwargs)
299
+ self.extractor_mode = extractor_mode
300
+ self.feature_encoder_spec = feature_encoder_spec
301
+ self.conv_pos_width = conv_pos_width
302
+ self.conv_pos_groups = conv_pos_groups
303
+ self.conv_pos_depth = conv_pos_depth
304
+ self.conv_pos_pre_ln = conv_pos_pre_ln
305
+
306
+ self.vocab_size = vocab_size
307
+ self.use_weighted_layer_sum = use_weighted_layer_sum
308
+
309
+ # fine-tuning config parameters for SpecAugment: https://huggingface.co/papers/1904.08779
310
+ self.mask_time_prob = mask_time_prob
311
+ self.mask_time_length = mask_time_length
312
+ self.mask_time_min_masks = mask_time_min_masks
313
+ self.mask_feature_prob = mask_feature_prob
314
+ self.mask_feature_length = mask_feature_length
315
+ self.mask_feature_min_masks = mask_feature_min_masks
316
+
317
+ # ctc loss
318
+ self.ctc_loss_reduction = ctc_loss_reduction
319
+ self.ctc_zero_infinity = ctc_zero_infinity
320
+
321
+ # adapter
322
+ self.add_adapter = add_adapter
323
+ self.adapter_kernel_size = adapter_kernel_size
324
+ self.adapter_stride = adapter_stride
325
+ self.num_adapter_layers = num_adapter_layers
326
+ self.output_hidden_size = output_hidden_size
327
+
328
+ # SequenceClassification-specific parameter. Feel free to ignore for other classes.
329
+ self.classifier_proj_size = classifier_proj_size
330
+
331
+ # XVector-specific parameters. Feel free to ignore for other classes.
332
+ self.tdnn_dim = list(tdnn_dim)
333
+ self.tdnn_kernel = list(tdnn_kernel)
334
+ self.tdnn_dilation = list(tdnn_dilation)
335
+ self.xvector_output_dim = xvector_output_dim
336
+
337
+
338
+ class PantagruelTextConfig(PantagruelModalityConfig):
339
+ """
340
+ Configuration including args specific to text-only tasks
341
+ """
342
+ def __init__(
343
+ self,
344
+ vocab_size=50000,
345
+ unk_token_id=3,
346
+ bos_token_id=0,
347
+ eos_token_id=2,
348
+ pad_token_id=1,
349
+ max_source_positions=512,
350
+ learned_pos=True,
351
+ dropout=0.1,
352
+ no_scale_embedding=True,
353
+ layernorm_embedding=True,
354
+ no_token_positional_embeddings=False,
355
+ **kwargs,
356
+ ):
357
+ super().__init__(type="TEXT", **kwargs)
358
+ self.vocab_size = vocab_size
359
+ self.unk_token_id = unk_token_id
360
+ self.bos_token_id = bos_token_id
361
+ self.eos_token_id = eos_token_id
362
+ self.pad_token_id = pad_token_id
363
+ self.max_source_positions = max_source_positions
364
+ self.learned_pos = learned_pos
365
+ self.dropout = dropout
366
+ self.no_scale_embedding = no_scale_embedding
367
+ self.layernorm_embedding = layernorm_embedding
368
+ self.no_token_positional_embeddings = no_token_positional_embeddings
369
+
370
+
371
+ class PantagruelModalitiesConfig(MyPretrainedConfig):
372
+ """
373
+ Container class for both audio and text modality configurations
374
+ """
375
+ def __init__(
376
+ self,
377
+ audio_config=PantagruelAudioConfig(),
378
+ text_config=PantagruelTextConfig(),
379
+ **kwargs
380
+ ):
381
+ super().__init__(**kwargs)
382
+ self.audio = audio_config
383
+ self.text = text_config
384
+
385
+
386
+ class PantagruelUniConfig(MyPretrainedConfig):
387
+ r"""
388
+ This is the configuration class to store the configuration of a [`PantagruelUniModel`].
389
+ It is used to instantiate an PantagruelUniModel model according to the specified arguments,
390
+ defining the model architecture.
391
+
392
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to
393
+ control the model outputs. Read the documentation from [`PretrainedConfig`] for more information.
394
+
395
+ Args:
396
+ depth (`int`, *optional*, defaults to 12):
397
+ Number of Transformer layers in the encoder.
398
+
399
+ Example:
400
+
401
+ ```python
402
+ >>> from transformers import PantagruelUniConfig, PantagruelUniModel
403
+
404
+ >>> # Initializing a PantagruelUniConfig for audio
405
+ >>> configuration = PantagruelUniConfig()
406
+
407
+ >>> # Initializing a model (with random weights) with the configuration
408
+ >>> model = PantagruelUniModel(configuration)
409
+
410
+ >>> # Accessing the model configuration
411
+ >>> configuration = model.config
412
+ ```
413
+ """
414
+
415
+ model_type = "pantagruel_uni"
416
+
417
+ def __init__(
418
+ self,
419
+ depth=12,
420
+ start_drop_path_rate=0.0,
421
+ end_drop_path_rate=0.0,
422
+ num_heads=12,
423
+ norm_eps=1e-5,
424
+ norm_affine=True,
425
+ encoder_dropout=0.1,
426
+ post_mlp_drop=0.1,
427
+ attention_dropout=0.1,
428
+ activation_dropout=0.0,
429
+ dropout_input=0.0,
430
+ final_dropout=0.1,
431
+ layerdrop=0.0,
432
+ embed_dim=768,
433
+ mlp_ratio=4.0,
434
+ layer_norm_first=False,
435
+ end_of_block_targets=False,
436
+ clone_batch=1,
437
+ log_norms=True,
438
+ modalities=PantagruelModalitiesConfig(),
439
+ supported_modality="AUDIO",
440
+ classifier_dropout=None,
441
+ **kwargs,
442
+ ):
443
+ super().__init__(**kwargs)
444
+
445
+ self.depth = depth
446
+ self.start_drop_path_rate = start_drop_path_rate
447
+ self.end_drop_path_rate = end_drop_path_rate
448
+
449
+ self.num_heads = num_heads
450
+ self.norm_eps = norm_eps
451
+ self.norm_affine = norm_affine
452
+ self.post_mlp_drop = post_mlp_drop
453
+ self.encoder_dropout = encoder_dropout
454
+ self.attention_dropout = attention_dropout
455
+ self.activation_dropout = activation_dropout
456
+ self.dropout_input = dropout_input
457
+ self.final_dropout = final_dropout
458
+ self.layerdrop = layerdrop
459
+ self.embed_dim = embed_dim
460
+ self.mlp_ratio = mlp_ratio
461
+
462
+ self.layer_norm_first = layer_norm_first
463
+ self.end_of_block_targets = end_of_block_targets
464
+ self.clone_batch = clone_batch
465
+ self.log_norms = log_norms
466
+
467
+ self.modalities = modalities
468
+ self.supported_modality = supported_modality
469
+
470
+ # Attributes for hopsparser
471
+ self.hidden_size = embed_dim
472
+ self.num_layers = depth
473
+ self.n_layers = depth
474
+ self.num_hidden_layers = depth
475
+
476
+ self.classifier_dropout = classifier_dropout
477
+
478
+ self.auto_map = {
479
+ 'AutoConfig': 'configuration_pantagruel_uni.PantagruelUniConfig',
480
+ 'AutoModel': 'modeling_pantagruel_uni.PantagruelUniModel',
481
+ 'AutoModelForMaskedLM': 'modeling_pantagruel_uni.PantagruelUniForMaskedLM',
482
+ 'AutoModelForSequenceClassification': 'modeling_pantagruel_uni.PantagruelUniForSequenceClassification',
483
+ 'AutoModelForMultipleChoice': 'modeling_pantagruel_uni.PantagruelUniForMultipleChoice',
484
+ 'AutoModelForTokenClassification': 'modeling_pantagruel_uni.PantagruelUniForTokenClassification',
485
+ 'AutoModelForQuestionAnswering': 'modeling_pantagruel_uni.PantagruelUniForQuestionAnswering',
486
+ 'AutoModelForAudioFrameClassification': 'modeling_pantagruel_uni.PantagruelUniForAudioFrameClassification',
487
+ 'AutoModelForCTC': 'modeling_pantagruel_uni.PantagruelUniForCTC',
488
+ }
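
Compared with the data2vec2 version, PantagruelAudioConfig adds the fine-tuning fields that also appear in the new config.json: SpecAugment-style time/feature masking, CTC loss options, adapter settings, and the x-vector/TDNN parameters. A hedged sketch of overriding a few of them when assembling a config for a CTC run (the values are illustrative, not recommendations from this commit):

```python
from configuration_pantagruel_uni import (
    PantagruelAudioConfig,
    PantagruelModalitiesConfig,
    PantagruelUniConfig,
)

# Audio-side overrides: CTC vocabulary size and SpecAugment time masking.
audio_cfg = PantagruelAudioConfig(
    vocab_size=80,             # number of CTC output symbols
    mask_time_prob=0.1,        # illustrative SpecAugment setting
    ctc_loss_reduction="mean",
)

config = PantagruelUniConfig(
    final_dropout=0.1,
    modalities=PantagruelModalitiesConfig(audio_config=audio_cfg),
    supported_modality="AUDIO",
)
print(config.modalities.audio.ctc_loss_reduction)  # "mean"
```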
model.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:89a7dbbd1f3880bc6ffe840ab10a8d3d5a5423bd8988b96f20a0a39e0a770f1d
-size 442498896
+oid sha256:317e4bbe04553ef8b52c350a9221e58ab6d086c0e5ef4c7594007d7dd019aaf1
+size 440136336
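
The weights file is stored through Git LFS; the pointer above records the new blob's SHA-256 and size (440,136,336 bytes). A small sketch for checking a locally downloaded model.safetensors against that pointer (the local path is illustrative):

```python
import hashlib

# Expected digest, taken from the LFS pointer committed here.
expected = "317e4bbe04553ef8b52c350a9221e58ab6d086c0e5ef4c7594007d7dd019aaf1"

h = hashlib.sha256()
with open("model.safetensors", "rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):
        h.update(chunk)

assert h.hexdigest() == expected, "model.safetensors does not match this commit"
```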
modeling_data2vec2.py ADDED
@@ -0,0 +1,1505 @@
1
+ # coding=utf-8
2
+ #
3
+ # Copyright (c) Facebook, Inc. and its affiliates.
4
+ #
5
+ # This source code is licensed under the MIT license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+ #
8
+ # Copyright 2022 the HuggingFace Inc. team. All rights reserved.
9
+ #
10
+ # Licensed under the Apache License, Version 2.0 (the "License");
11
+ # you may not use this file except in compliance with the License.
12
+ # You may obtain a copy of the License at
13
+ #
14
+ # http://www.apache.org/licenses/LICENSE-2.0
15
+ #
16
+ # Unless required by applicable law or agreed to in writing, software
17
+ # distributed under the License is distributed on an "AS IS" BASIS,
18
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19
+ # See the License for the specific language governing permissions and
20
+ # limitations under the License.
21
+
22
+ # Copyright from Fairseq
23
+
24
+ """ PyTorch Data2Vec2 Multi model."""
25
+ import math
26
+ import warnings
27
+ from typing import Optional, Tuple, Dict, List, Callable, Any
28
+ from functools import partial
29
+ from dataclasses import dataclass
30
+
31
+ import numpy as np
32
+
33
+ import torch
34
+ import torch.nn.functional as F
35
+ from torch import nn
36
+ from torch import Tensor
37
+
38
+ from transformers import PreTrainedModel
39
+ from transformers.utils import ModelOutput
40
+ from .configuration_data2vec2 import (
41
+ Data2Vec2MultiConfig,
42
+ D2v2ModalityConfig,
43
+ D2v2AudioConfig,
44
+ D2v2TextConfig,
45
+ )
46
+
47
+ from .utils_data2vec2 import (
48
+ _learned_alibi_bias,
49
+ gather_unmasked,
50
+ gather_unmasked_mask,
51
+ masked_alibi,
52
+ random_masking,
53
+ get_alibi_bias,
54
+ compute_mask_indices,
55
+ index_put,
56
+ MaskInfo, MaskSeed,
57
+ make_positions,
58
+ )
59
+
60
+
61
+ @dataclass
62
+ class Data2vec2BaseModelOutput(ModelOutput):
63
+ last_hidden_state: Optional[torch.FloatTensor] = None # output of the encoder-only model
64
+ pooler_output: Optional[torch.FloatTensor] = None # pooled output for text tasks, which is the first token representation followed by a dense layer and activation function
65
+ local_features: Optional[torch.FloatTensor] = None # features before the Transformer encoder
66
+ hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
67
+ attentions: Optional[tuple[torch.FloatTensor, ...]] = None # TODO: only support manual implementation with fast=False in the forward pass of AltAttention as pytorch's dspa does not output attention weights
68
+
69
+
70
+ #################################################
71
+ ### modeling_data2vec2_base.py
72
+ # copied from fairseq.modules.grad_multiply
73
+ class GradMultiply(torch.autograd.Function):
74
+ @staticmethod
75
+ def forward(ctx, x, scale):
76
+ ctx.scale = scale
77
+ res = x.new(x)
78
+ return res
79
+
80
+ @staticmethod
81
+ def backward(ctx, grad):
82
+ return grad * ctx.scale, None
83
+
84
+
85
+ # Copied from fairseq.modules.transpose_last.py
86
+ class TransposeLast(nn.Module):
87
+ def __init__(self, deconstruct_idx=None, tranpose_dim=-2):
88
+ super().__init__()
89
+ self.deconstruct_idx = deconstruct_idx
90
+ self.tranpose_dim = tranpose_dim
91
+
92
+ def forward(self, x):
93
+ if self.deconstruct_idx is not None:
94
+ x = x[self.deconstruct_idx]
95
+ return x.transpose(self.tranpose_dim, -1)
96
+
97
+
98
+ # Copied from fairseq.modules.layer_norm.py
99
+ class Fp32LayerNorm(nn.LayerNorm):
100
+ def __init__(self, *args, **kwargs):
101
+ super().__init__(*args, **kwargs)
102
+
103
+ def forward(self, input):
104
+ output = F.layer_norm(
105
+ input.float(),
106
+ self.normalized_shape,
107
+ self.weight.float() if self.weight is not None else None,
108
+ self.bias.float() if self.bias is not None else None,
109
+ self.eps,
110
+ )
111
+ return output.type_as(input)
112
+
113
+
114
+ def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True):
115
+ return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)
116
+
117
+
118
+ # Copied from fairseq.modules.fp32_group_norm.py
119
+ class Fp32GroupNorm(nn.GroupNorm):
120
+ def __init__(self, *args, **kwargs):
121
+ super().__init__(*args, **kwargs)
122
+
123
+ def forward(self, input):
124
+ output = F.group_norm(
125
+ input.float(),
126
+ self.num_groups,
127
+ self.weight.float() if self.weight is not None else None,
128
+ self.bias.float() if self.bias is not None else None,
129
+ self.eps,
130
+ )
131
+ return output.type_as(input)
132
+
133
+
134
+ # Copied from fairseq.modules.same_pad.py
135
+ class SamePad(nn.Module):
136
+ def __init__(self, kernel_size, causal=False):
137
+ super().__init__()
138
+ if causal:
139
+ self.remove = kernel_size - 1
140
+ else:
141
+ self.remove = 1 if kernel_size % 2 == 0 else 0
142
+
143
+ def forward(self, x):
144
+ if self.remove > 0:
145
+ x = x[:, :, : -self.remove]
146
+ return x
147
+
148
+
149
+ # Copied from fairseq.models.wav2vec.wav2vec2.py
150
+ class ConvFeatureExtractionModel(nn.Module):
151
+ def __init__(
152
+ self,
153
+ conv_layers: List[Tuple[int, int, int]],
154
+ dropout: float = 0.0,
155
+ mode: str = "default",
156
+ conv_bias: bool = False,
157
+ ):
158
+ super().__init__()
159
+
160
+ assert mode in {"default", "layer_norm"}
161
+
162
+ def block(
163
+ n_in,
164
+ n_out,
165
+ k,
166
+ stride,
167
+ is_layer_norm=False,
168
+ is_group_norm=False,
169
+ conv_bias=False,
170
+ ):
171
+ def make_conv():
172
+ conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias)
173
+ nn.init.kaiming_normal_(conv.weight)
174
+ return conv
175
+
176
+ assert (
177
+ is_layer_norm and is_group_norm
178
+ ) == False, "layer norm and group norm are exclusive"
179
+
180
+ if is_layer_norm:
181
+ return nn.Sequential(
182
+ make_conv(),
183
+ nn.Dropout(p=dropout),
184
+ nn.Sequential(
185
+ TransposeLast(),
186
+ Fp32LayerNorm(dim, elementwise_affine=True),
187
+ TransposeLast(),
188
+ ),
189
+ nn.GELU(),
190
+ )
191
+ elif is_group_norm:
192
+ return nn.Sequential(
193
+ make_conv(),
194
+ nn.Dropout(p=dropout),
195
+ Fp32GroupNorm(dim, dim, affine=True),
196
+ nn.GELU(),
197
+ )
198
+ else:
199
+ return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU())
200
+
201
+ in_d = 1
202
+ self.conv_layers = nn.ModuleList()
203
+ for i, cl in enumerate(conv_layers):
204
+ assert len(cl) == 3, "invalid conv definition: " + str(cl)
205
+ (dim, k, stride) = cl
206
+
207
+ self.conv_layers.append(
208
+ block(
209
+ in_d,
210
+ dim,
211
+ k,
212
+ stride,
213
+ is_layer_norm=mode == "layer_norm",
214
+ is_group_norm=mode == "default" and i == 0,
215
+ conv_bias=conv_bias,
216
+ )
217
+ )
218
+ in_d = dim
219
+
220
+ def forward(self, x):
221
+
222
+ # BxT -> BxCxT
223
+ x = x.unsqueeze(1)
224
+
225
+ for conv in self.conv_layers:
226
+ x = conv(x)
227
+
228
+ return x
229
+
230
+
231
+ # copied from fairseq.examples.data2vec.models.modalities.modules
232
+ class AltAttention(nn.Module):
233
+ def __init__(
234
+ self,
235
+ dim,
236
+ num_heads=8,
237
+ qkv_bias=False,
238
+ qk_scale=None,
239
+ attn_drop=0.0,
240
+ proj_drop=0.0,
241
+ cosine_attention=False,
242
+ ):
243
+ super().__init__()
244
+ self.num_heads = num_heads
245
+ head_dim = dim // num_heads
246
+ self.scale = qk_scale or head_dim ** -0.5
247
+
248
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
249
+ # self.attn_drop = nn.Dropout(attn_drop)
250
+ self.attn_drop = attn_drop
251
+ self.proj = nn.Linear(dim, dim)
252
+ # self.proj_drop = nn.Dropout(proj_drop)
253
+ self.proj_drop = proj_drop
254
+
255
+ self.cosine_attention = cosine_attention
256
+
257
+ if cosine_attention:
258
+ self.logit_scale = nn.Parameter(
259
+ torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True
260
+ )
261
+
262
+ def forward(self, x, padding_mask=None, alibi_bias=None, fast=True):
263
+ B, N, C = x.shape
264
+ qkv = (
265
+ self.qkv(x)
266
+ .reshape(B, N, 3, self.num_heads, C // self.num_heads)
267
+ .permute(2, 0, 3, 1, 4) # qkv x B x H x L x D
268
+ )
269
+ q, k, v = (
270
+ qkv[0],
271
+ qkv[1],
272
+ qkv[2],
273
+ ) # make torchscript happy (cannot use tensor as tuple)
274
+
275
+ dtype = q.dtype
276
+
277
+ if not fast:
278
+ if self.cosine_attention:
279
+ # cosine attention
280
+ attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)
281
+ logit_scale = torch.clamp(
282
+ self.logit_scale, max=torch.log(torch.tensor(1.0 / 0.01))
283
+ ).exp()
284
+ attn = attn * logit_scale
285
+ else:
286
+ q = q * self.scale
287
+ attn = q @ k.transpose(-2, -1) # B x C//H x L x L
288
+
289
+ if alibi_bias is not None:
290
+ attn = attn.type_as(alibi_bias)
291
+ attn[:, : alibi_bias.size(1)] += alibi_bias
292
+
293
+ if padding_mask is not None and padding_mask.any():
294
+ attn = attn.masked_fill(
295
+ padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
296
+ float("-inf"),
297
+ )
298
+
299
+ attn = attn.softmax(dim=-1, dtype=torch.float32).to(dtype=dtype)
300
+ # attn = self.attn_drop(attn)
301
+ attn = F.dropout(attn, p=self.attn_drop)
302
+ x = (attn @ v).transpose(1, 2)
303
+ else:
304
+ # Using pytorch 2's sdpa
305
+ assert not self.cosine_attention, "Not support cosine attention yet"
306
+ # Integrate padding_mask and alibi_bias
307
+ if padding_mask is not None and padding_mask.any():
308
+ if alibi_bias is not None:
309
+ padding_mask = alibi_bias.masked_fill(
310
+ padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
311
+ float("-inf"),
312
+ ).to(dtype=dtype)
313
+ else:
314
+ padding_mask = padding_mask.unsqueeze(1).unsqueeze(2).to(
315
+ torch.bool).to(dtype=dtype)
316
+ else:
317
+ if alibi_bias is not None:
318
+ padding_mask = alibi_bias.to(dtype=dtype)
319
+ else:
320
+ padding_mask = None
321
+
322
+ x = F.scaled_dot_product_attention(q, k, v,
323
+ attn_mask=padding_mask,
324
+ dropout_p=self.attn_drop if self.training else 0.0,
325
+ scale=self.scale).transpose(1, 2)
326
+
327
+ x = x.reshape(B, N, C)
328
+ x = self.proj(x)
329
+ x = F.dropout(x, p=self.proj_drop if self.training else 0.0)
330
+ return x
331
+
332
+
333
+ # copied from fairseq.examples.data2vec.models.modalities.modules.py
334
+ class AltBlock(nn.Module):
335
+ def __init__(
336
+ self,
337
+ dim,
338
+ num_heads,
339
+ mlp_ratio=4.0,
340
+ qkv_bias=False,
341
+ qk_scale=None,
342
+ drop=0.0,
343
+ attn_drop=0.0,
344
+ mlp_drop=0.0,
345
+ post_mlp_drop=0.0,
346
+ drop_path=0.0,
347
+ act_layer=nn.GELU,
348
+ norm_layer=nn.LayerNorm,
349
+ layer_norm_first=True,
350
+ ffn_targets=False,
351
+ cosine_attention=False,
352
+ ):
353
+ super().__init__()
354
+
355
+ self.layer_norm_first = layer_norm_first
356
+ self.ffn_targets = ffn_targets
357
+
358
+ from timm.models.vision_transformer import DropPath, Mlp
359
+
360
+ self.norm1 = norm_layer(dim)
361
+ self.attn = AltAttention(
362
+ dim,
363
+ num_heads=num_heads,
364
+ qkv_bias=qkv_bias,
365
+ qk_scale=qk_scale,
366
+ attn_drop=attn_drop,
367
+ proj_drop=drop,
368
+ cosine_attention=cosine_attention,
369
+ )
370
+
371
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
372
+ self.norm2 = norm_layer(dim)
373
+ mlp_hidden_dim = int(dim * mlp_ratio)
374
+ self.mlp = Mlp(
375
+ in_features=dim,
376
+ hidden_features=mlp_hidden_dim,
377
+ act_layer=act_layer,
378
+ drop=mlp_drop,
379
+ )
380
+ self.post_mlp_dropout = nn.Dropout(post_mlp_drop, inplace=False)
381
+
382
+ def forward(self, x, padding_mask=None, alibi_bias=None):
383
+ if self.layer_norm_first:
384
+ x = x + self.drop_path(self.attn(self.norm1(x), padding_mask, alibi_bias))
385
+ r = x = self.mlp(self.norm2(x))
386
+ t = x
387
+ x = r + self.drop_path(self.post_mlp_dropout(x))
388
+ if not self.ffn_targets:
389
+ t = x
390
+ else:
391
+ x = x + self.drop_path(self.attn(x, padding_mask, alibi_bias))
392
+ r = x = self.norm1(x)
393
+ x = self.mlp(x)
394
+ t = x
395
+ x = self.norm2(r + self.drop_path(self.post_mlp_dropout(x)))
396
+ if not self.ffn_targets:
397
+ t = x
398
+
399
+ return x, t
400
+
401
+
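An illustrative usage sketch, not part of the committed file: it assumes the AltBlock class above and a working timm install (AltBlock imports DropPath and Mlp from timm). Each block returns both its output and the per-layer target t used by data2vec-style distillation.

import torch

block = AltBlock(dim=64, num_heads=4, qkv_bias=True, ffn_targets=True)
x = torch.randn(2, 10, 64)        # B x T x C
out, target = block(x)            # target is the FFN output when ffn_targets=True
print(out.shape, target.shape)    # torch.Size([2, 10, 64]) torch.Size([2, 10, 64])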
402
+ # copied from fairseq.data2vec.models.modalities.modules
403
+ class BlockEncoder(nn.Module):
404
+ def __init__(self, blocks, norm_layer, layer_norm_first, layerdrop, dropout):
405
+ super().__init__()
406
+ self.blocks = blocks
407
+ self.norm = norm_layer
408
+ self.layer_norm_first = layer_norm_first
409
+ self.layerdrop = layerdrop
410
+ self.dropout = nn.Dropout(dropout, inplace=True)
411
+
412
+ def forward(self, x, padding_mask, alibi_bias, alibi_scale):
413
+ if self.norm is not None and not self.layer_norm_first:
414
+ x = self.norm(x)
415
+
416
+ x = self.dropout(x)
417
+
418
+ for i, blk in enumerate(self.blocks):
419
+ if (
420
+ not self.training
421
+ or self.layerdrop == 0
422
+ or (np.random.random() > self.layerdrop)
423
+ ):
424
+ ab = alibi_bias
425
+ if ab is not None and alibi_scale is not None:
426
+ scale = (
427
+ alibi_scale[i]
428
+ if alibi_scale.size(0) > 1
429
+ else alibi_scale.squeeze(0)
430
+ )
431
+ ab = ab * scale.type_as(ab)
432
+ x, _ = blk(x, padding_mask, ab)
433
+
434
+ if self.norm is not None and self.layer_norm_first:
435
+ x = self.norm(x)
436
+
437
+ return x
438
+
439
+
440
+ class ModalitySpecificEncoder(nn.Module):
441
+ def __init__(
442
+ self,
443
+ modality_cfg: D2v2ModalityConfig,
444
+ embed_dim: int,
445
+ local_encoder: nn.Module,
446
+ project_features: nn.Module,
447
+ fixed_positional_encoder: Optional[nn.Module],
448
+ relative_positional_encoder: Optional[nn.Module],
449
+ context_encoder: nn.Module,
450
+ decoder: nn.Module,
451
+ get_alibi_bias: Optional[Callable[[int, int, str, str], torch.Tensor]],
452
+ ):
453
+ super().__init__()
454
+
455
+ self.modality_cfg = modality_cfg
456
+ self.local_encoder = local_encoder
457
+ self.project_features = project_features
458
+ self.fixed_positional_encoder = fixed_positional_encoder
459
+ self.relative_positional_encoder = relative_positional_encoder
460
+ self.context_encoder = context_encoder
461
+
462
+ self.decoder = None
463
+ self.get_alibi_bias = get_alibi_bias if modality_cfg.use_alibi_encoder else None
464
+
465
+ self.local_grad_mult = self.modality_cfg.local_grad_mult
466
+
467
+ self.extra_tokens = None
468
+ if modality_cfg.num_extra_tokens > 0:
469
+ self.extra_tokens = nn.Parameter(
470
+ torch.zeros(1, modality_cfg.num_extra_tokens, embed_dim)
471
+ )
472
+ if not modality_cfg.init_extra_token_zero:
473
+ nn.init.normal_(self.extra_tokens)
474
+ elif self.extra_tokens.size(1) > 1:
475
+ nn.init.normal_(self.extra_tokens[:, 1:])
476
+
477
+ self.alibi_scale = None
478
+ if self.get_alibi_bias is not None:
479
+ self.alibi_scale = nn.Parameter(
480
+ torch.full(
481
+ (
482
+ (modality_cfg.prenet_depth + modality_cfg.model_depth)
483
+ if modality_cfg.learned_alibi_scale_per_layer
484
+ else 1,
485
+ 1,
486
+ self.modality_cfg.num_alibi_heads
487
+ if modality_cfg.learned_alibi_scale_per_head
488
+ else 1,
489
+ 1,
490
+ 1,
491
+ ),
492
+ modality_cfg.alibi_scale,
493
+ dtype=torch.float,
494
+ ),
495
+ requires_grad=modality_cfg.learned_alibi_scale,
496
+ )
497
+
498
+ if modality_cfg.learned_alibi and self.get_alibi_bias is not None:
499
+ assert modality_cfg.alibi_max_pos is not None
500
+ alibi_bias = self.get_alibi_bias(
501
+ batch_size=1,
502
+ time_steps=modality_cfg.alibi_max_pos,
503
+ heads=modality_cfg.num_alibi_heads,
504
+ scale=1.0,
505
+ dtype=torch.float,
506
+ device="cpu",
507
+ )
508
+ self.alibi_bias = nn.Parameter(alibi_bias)
509
+ self.get_alibi_bias = partial(
510
+ _learned_alibi_bias, alibi_bias=self.alibi_bias
511
+ )
512
+
513
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureEncoder._freeze_parameters
514
+ def _freeze_parameters(self):
515
+ for param in self.parameters():
516
+ param.requires_grad = False
517
+ self._requires_grad = False
518
+
519
+ def convert_padding_mask(self, x, padding_mask):
520
+ return padding_mask
521
+
522
+ def local_features(self, features):
523
+ if self.local_grad_mult > 0:
524
+ if self.local_grad_mult == 1.0:
525
+ x = self.local_encoder(features)
526
+ else:
527
+ x = GradMultiply.apply(
528
+ self.local_encoder(features), self.local_grad_mult
529
+ )
530
+ else:
531
+ with torch.no_grad():
532
+ x = self.local_encoder(features)
533
+
534
+ x = self.project_features(x)
535
+ return x
536
+
537
+ def contextualized_features(
538
+ self,
539
+ x,
540
+ padding_mask,
541
+ mask,
542
+ remove_masked,
543
+ clone_batch: int = 1,
544
+ mask_seeds: Optional[torch.Tensor] = None,
545
+ precomputed_mask=None,
546
+ ):
547
+
548
+ if padding_mask is not None:
549
+ padding_mask = self.convert_padding_mask(x, padding_mask)
550
+
551
+ local_features = x
552
+ if mask and clone_batch == 1:
553
+ local_features = local_features.clone()
554
+
555
+ orig_B, orig_T, _ = x.shape
556
+ pre_mask_B = orig_B
557
+ mask_info = None
558
+
559
+ x_pos = None
560
+ if self.fixed_positional_encoder is not None:
561
+ x = x + self.fixed_positional_encoder(x, padding_mask)
562
+
563
+ if mask:
564
+ if clone_batch > 1:
565
+ x = x.repeat_interleave(clone_batch, 0)
566
+ if mask_seeds is not None:
567
+ clone_hash = [
568
+ int(hash((mask_seeds.seed, ind)) % 1e10)
569
+ for ind in range(clone_batch - 1)
570
+ ]
571
+ clone_hash = torch.tensor([0] + clone_hash).long().view(1, -1)
572
+
573
+ id = mask_seeds.ids
574
+ id = id.repeat_interleave(clone_batch, 0)
575
+ id = id.view(-1, clone_batch) + clone_hash.to(id)
576
+ id = id.view(-1)
577
+ mask_seeds = MaskSeed(
578
+ seed=mask_seeds.seed, update=mask_seeds.update, ids=id
579
+ )
580
+ if padding_mask is not None:
581
+ padding_mask = padding_mask.repeat_interleave(clone_batch, 0)
582
+
583
+ x, mask_info = self.compute_mask(
584
+ x,
585
+ padding_mask,
586
+ mask_seed=mask_seeds,
587
+ apply=self.relative_positional_encoder is not None or not remove_masked,
588
+ precomputed_mask=precomputed_mask,
589
+ )
590
+
591
+ if self.relative_positional_encoder is not None:
592
+ x_pos = self.relative_positional_encoder(x)
593
+
594
+ masked_padding_mask = padding_mask
595
+ if mask and remove_masked:
596
+ x = mask_info.x_unmasked
597
+ if x_pos is not None:
598
+ x = x + gather_unmasked(x_pos, mask_info)
599
+
600
+ if padding_mask is not None and padding_mask.any():
601
+ masked_padding_mask = gather_unmasked_mask(padding_mask, mask_info)
602
+ if not masked_padding_mask.any():
603
+ masked_padding_mask = None
604
+ else:
605
+ masked_padding_mask = None
606
+
607
+ elif x_pos is not None:
608
+ x = x + x_pos
609
+
610
+ alibi_bias = None
611
+ alibi_scale = self.alibi_scale
612
+
613
+ if self.get_alibi_bias is not None:
614
+ alibi_bias = self.get_alibi_bias(
615
+ batch_size=pre_mask_B,
616
+ time_steps=orig_T,
617
+ heads=self.modality_cfg.num_alibi_heads,
618
+ dtype=torch.float32,
619
+ device=x.device,
620
+ )
621
+
622
+ if alibi_scale is not None:
623
+ alibi_scale = alibi_scale.clamp_min(0)
624
+ if alibi_scale.size(0) == 1:
625
+ alibi_bias = alibi_bias * alibi_scale.squeeze(0).type_as(alibi_bias)
626
+ alibi_scale = None
627
+
628
+ if clone_batch > 1:
629
+ alibi_bias = alibi_bias.repeat_interleave(clone_batch, 0)
630
+
631
+ if mask_info is not None and remove_masked:
632
+ alibi_bias = masked_alibi(alibi_bias, mask_info)
633
+
634
+ if self.extra_tokens is not None:
635
+ num = self.extra_tokens.size(1)
636
+ x = torch.cat([self.extra_tokens.expand(x.size(0), -1, -1), x], dim=1)
637
+ if masked_padding_mask is not None:
638
+ # B x T
639
+ masked_padding_mask = F.pad(masked_padding_mask, (num, 0))
640
+ if alibi_bias is not None:
641
+ # B x H x T x T
642
+ alibi_bias = F.pad(alibi_bias, (num, 0, num, 0))
643
+
644
+ x = self.context_encoder(
645
+ x,
646
+ masked_padding_mask,
647
+ alibi_bias,
648
+ alibi_scale[: self.modality_cfg.prenet_depth]
649
+ if alibi_scale is not None
650
+ else None,
651
+ )
652
+
653
+ return {
654
+ "x": x,
655
+ "local_features": local_features,
656
+ "padding_mask": masked_padding_mask,
657
+ "alibi_bias": alibi_bias,
658
+ "alibi_scale": alibi_scale[self.modality_cfg.prenet_depth :]
659
+ if alibi_scale is not None and alibi_scale.size(0) > 1
660
+ else alibi_scale,
661
+ "encoder_mask": mask_info,
662
+ }
663
+
664
+ def forward(
665
+ self,
666
+ features,
667
+ padding_mask,
668
+ mask: bool,
669
+ remove_masked: bool,
670
+ clone_batch: int = 1,
671
+ mask_seeds: Optional[torch.Tensor] = None,
672
+ precomputed_mask=None,
673
+ ):
674
+ x = self.local_features(features)
675
+ return self.contextualized_features(
676
+ x,
677
+ padding_mask,
678
+ mask,
679
+ remove_masked,
680
+ clone_batch,
681
+ mask_seeds,
682
+ precomputed_mask,
683
+ )
684
+
685
+ def compute_mask(
686
+ self,
687
+ x,
688
+ padding_mask,
689
+ mask_seed: Optional[MaskSeed],
690
+ apply,
691
+ precomputed_mask,
692
+ ):
693
+ if precomputed_mask is not None:
694
+ mask = precomputed_mask
695
+ mask_info = self.make_maskinfo(x, mask)
696
+ else:
697
+ B, T, C = x.shape
698
+ cfg = self.modality_cfg
699
+
700
+ mask_prob = cfg.mask_prob
701
+
702
+ if (
703
+ cfg.mask_prob_min is not None
704
+ and cfg.mask_prob_min >= 0
705
+ and cfg.mask_prob_min < mask_prob
706
+ ):
707
+ mask_prob = np.random.uniform(cfg.mask_prob_min, mask_prob)
708
+
709
+ if mask_prob > 0:
710
+ if cfg.mask_length == 1:
711
+ mask_info = random_masking(x, mask_prob, mask_seed)
712
+ else:
713
+ if self.modality_cfg.inverse_mask:
714
+ mask_prob = 1 - mask_prob
715
+
716
+ mask = compute_mask_indices(
717
+ (B, T),
718
+ padding_mask,
719
+ mask_prob,
720
+ cfg.mask_length,
721
+ min_masks=1,
722
+ require_same_masks=True,
723
+ mask_dropout=cfg.mask_dropout,
724
+ add_masks=cfg.add_masks,
725
+ seed=mask_seed.seed if mask_seed is not None else None,
726
+ epoch=mask_seed.update if mask_seed is not None else None,
727
+ indices=mask_seed.ids if mask_seed is not None else None,
728
+ )
729
+
730
+ mask = torch.from_numpy(mask).to(device=x.device)
731
+ if self.modality_cfg.inverse_mask:
732
+ mask = 1 - mask
733
+ mask_info = self.make_maskinfo(x, mask)
734
+ else:
735
+ mask_info = None
736
+
737
+ if apply:
738
+ x = self.apply_mask(x, mask_info)
739
+
740
+ return x, mask_info
741
+
742
+ def make_maskinfo(self, x, mask, shape=None):
743
+ if shape is None:
744
+ B, T, D = x.shape
745
+ else:
746
+ B, T, D = shape
747
+
748
+ mask = mask.to(torch.uint8)
749
+ ids_shuffle = mask.argsort(dim=1)
750
+ ids_restore = ids_shuffle.argsort(dim=1).unsqueeze(-1).expand(-1, -1, D)
751
+
752
+ len_keep = T - mask[0].sum()
753
+ if self.modality_cfg.keep_masked_pct > 0:
754
+ len_keep += round((T - int(len_keep)) * self.modality_cfg.keep_masked_pct)
755
+
756
+ ids_keep = ids_shuffle[:, :len_keep]
757
+
758
+ if shape is not None:
759
+ x_unmasked = None
760
+ else:
761
+ ids_keep = ids_keep.unsqueeze(-1).expand(-1, -1, D)
762
+ x_unmasked = torch.gather(x, dim=1, index=ids_keep)
763
+
764
+ mask_info = MaskInfo(
765
+ x_unmasked=x_unmasked,
766
+ mask=mask,
767
+ ids_restore=ids_restore,
768
+ ids_keep=ids_keep,
769
+ )
770
+ return mask_info
771
+
772
+ def apply_mask(self, x, mask_info):
773
+ cfg = self.modality_cfg
774
+ B, T, C = x.shape
775
+
776
+ if mask_info is not None:
777
+ mask = mask_info.mask
778
+ if cfg.encoder_zero_mask:
779
+ x = x * (1 - mask.type_as(x).unsqueeze(-1))
780
+ else:
781
+ num_masks = mask.sum().item()
782
+ masks = x.new_empty(num_masks, x.size(-1)).normal_(
783
+ 0, cfg.mask_noise_std
784
+ )
785
+ x = index_put(x, mask, masks)
786
+ if cfg.mask_channel_prob > 0:
787
+ mask_channel = compute_mask_indices(
788
+ (B, C),
789
+ None,
790
+ cfg.mask_channel_prob,
791
+ cfg.mask_channel_length,
792
+ )
793
+ mask_channel = (
794
+ torch.from_numpy(mask_channel)
795
+ .to(x.device)
796
+ .unsqueeze(1)
797
+ .expand(-1, T, -1)
798
+ )
799
+ x = index_put(x, mask_channel, 0)
800
+ return x
801
+
802
+
803
+ class AudioEncoder(ModalitySpecificEncoder):
804
+
805
+ modality_cfg: D2v2AudioConfig
806
+
807
+ def __init__(
808
+ self,
809
+ modality_cfg: D2v2AudioConfig,
810
+ embed_dim: int,
811
+ make_block: Callable[[float], nn.ModuleList],
812
+ norm_layer: Callable[[int], nn.LayerNorm],
813
+ layer_norm_first: bool,
814
+ alibi_biases: Dict,
815
+ ):
816
+
817
+ self.feature_enc_layers = eval(modality_cfg.feature_encoder_spec)
818
+ feature_embed_dim = self.feature_enc_layers[-1][0]
819
+
820
+ local_encoder = ConvFeatureExtractionModel(
821
+ conv_layers=self.feature_enc_layers,
822
+ dropout=0.0,
823
+ mode=modality_cfg.extractor_mode,
824
+ conv_bias=False,
825
+ )
826
+
827
+ project_features = nn.Sequential(
828
+ TransposeLast(),
829
+ nn.LayerNorm(feature_embed_dim),
830
+ nn.Linear(feature_embed_dim, embed_dim),
831
+ )
832
+
833
+ num_pos_layers = modality_cfg.conv_pos_depth
834
+ k = max(3, modality_cfg.conv_pos_width // num_pos_layers)
835
+
836
+ positional_encoder = nn.Sequential(
837
+ TransposeLast(),
838
+ *[
839
+ nn.Sequential(
840
+ nn.Conv1d(
841
+ embed_dim,
842
+ embed_dim,
843
+ kernel_size=k,
844
+ padding=k // 2,
845
+ groups=modality_cfg.conv_pos_groups,
846
+ ),
847
+ SamePad(k),
848
+ TransposeLast(),
849
+ LayerNorm(embed_dim, elementwise_affine=False),
850
+ TransposeLast(),
851
+ nn.GELU(),
852
+ )
853
+ for _ in range(num_pos_layers)
854
+ ],
855
+ TransposeLast(),
856
+ )
857
+
858
+ if modality_cfg.conv_pos_pre_ln:
859
+ positional_encoder = nn.Sequential(LayerNorm(embed_dim), positional_encoder)
860
+
861
+ dpr = np.linspace(
862
+ modality_cfg.start_drop_path_rate,
863
+ modality_cfg.end_drop_path_rate,
864
+ modality_cfg.prenet_depth,
865
+ )
866
+ context_encoder = BlockEncoder(
867
+ nn.ModuleList(make_block(dpr[i]) for i in range(modality_cfg.prenet_depth)),
868
+ norm_layer(embed_dim) if not layer_norm_first else None,
869
+ layer_norm_first,
870
+ modality_cfg.prenet_layerdrop,
871
+ modality_cfg.prenet_dropout,
872
+ )
873
+
874
+ decoder = None
875
+
876
+ alibi_bias_fn = partial(get_alibi_bias, alibi_biases=alibi_biases)
877
+
878
+ super().__init__(
879
+ modality_cfg=modality_cfg,
880
+ embed_dim=embed_dim,
881
+ local_encoder=local_encoder,
882
+ project_features=project_features,
883
+ fixed_positional_encoder=None,
884
+ relative_positional_encoder=positional_encoder,
885
+ context_encoder=context_encoder,
886
+ decoder=decoder,
887
+ get_alibi_bias=alibi_bias_fn,
888
+ )
889
+
890
+ def convert_padding_mask(self, x, padding_mask):
891
+ def get_feat_extract_output_lengths(input_lengths: torch.LongTensor):
892
+ """
893
+ Computes the output length of the convolutional layers
894
+ """
895
+
896
+ def _conv_out_length(input_length, kernel_size, stride):
897
+ return torch.floor((input_length - kernel_size) / stride + 1)
898
+
899
+ for i in range(len(self.feature_enc_layers)):
900
+ input_lengths = _conv_out_length(
901
+ input_lengths,
902
+ self.feature_enc_layers[i][1],
903
+ self.feature_enc_layers[i][2],
904
+ )
905
+
906
+ return input_lengths.to(torch.long)
907
+
908
+ if padding_mask is not None:
909
+ input_lengths = (1 - padding_mask.long()).sum(-1)
910
+ # apply conv formula to get real output_lengths
911
+ output_lengths = get_feat_extract_output_lengths(input_lengths)
912
+
913
+ if padding_mask.any():
914
+ padding_mask = torch.zeros(x.shape[:2], dtype=x.dtype, device=x.device)
915
+
916
+ # these two operations make sure that all values
917
+ # before the output length indices are attended to
918
+ padding_mask[
919
+ (
920
+ torch.arange(padding_mask.shape[0], device=padding_mask.device),
921
+ output_lengths - 1,
922
+ )
923
+ ] = 1
924
+ padding_mask = (
925
+ 1 - padding_mask.flip([-1]).cumsum(-1).flip([-1])
926
+ ).bool()
927
+ else:
928
+ padding_mask = torch.zeros(
929
+ x.shape[:2], dtype=torch.bool, device=x.device
930
+ )
931
+
932
+ return padding_mask
933
+
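A small worked example (not part of the committed file) of the convolution output-length formula used in convert_padding_mask above. The layer spec below is the usual wav2vec2-style feature extractor and is only an assumption; the real spec comes from modality_cfg.feature_encoder_spec.

def conv_out_length(input_length, kernel_size, stride):
    # same formula as _conv_out_length above, on plain ints
    return (input_length - kernel_size) // stride + 1

spec = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2
n = 16000  # one second of 16 kHz audio
for _, kernel, stride in spec:
    n = conv_out_length(n, kernel, stride)
print(n)  # 49 frames, i.e. roughly one frame per 20 ms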
934
+
935
+ class LearnedPositionalEmbedding(nn.Embedding):
936
+ """
937
+ This module learns positional embeddings up to a fixed maximum size.
938
+ Padding ids are ignored by either offsetting based on padding_idx
939
+ or by setting padding_idx to None and ensuring that the appropriate
940
+ position ids are passed to the forward function.
941
+ """
942
+
943
+ def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int):
944
+ super().__init__(num_embeddings, embedding_dim, padding_idx)
945
+ self.onnx_trace = False
946
+ if self.padding_idx is not None:
947
+ self.max_positions = self.num_embeddings - self.padding_idx - 1
948
+ else:
949
+ self.max_positions = self.num_embeddings
950
+
951
+ def forward(
952
+ self,
953
+ input: Tensor,
954
+ incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
955
+ positions: Optional[Tensor] = None,
956
+ ):
957
+ """Input is expected to be of size [bsz x seqlen]."""
958
+ assert (positions is None) or (
959
+ self.padding_idx is None
960
+ ), "If positions is pre-computed then padding_idx should not be set."
961
+
962
+ if positions is None:
963
+ if incremental_state is not None:
964
+ # positions is the same for every token when decoding a single step
965
+ # Without the int() cast, it doesn't work in some cases when exporting to ONNX
966
+ positions = torch.zeros(
967
+ (1, 1), device=input.device, dtype=input.dtype
968
+ ).fill_(int(self.padding_idx + input.size(1)))
969
+ else:
970
+ positions = make_positions(
971
+ input, self.padding_idx, onnx_trace=self.onnx_trace
972
+ )
973
+ return F.embedding(
974
+ positions,
975
+ self.weight,
976
+ self.padding_idx,
977
+ self.max_norm,
978
+ self.norm_type,
979
+ self.scale_grad_by_freq,
980
+ self.sparse,
981
+ )
982
+
983
+
984
+ class SinusoidalPositionalEmbedding(nn.Module):
985
+ """This module produces sinusoidal positional embeddings of any length.
986
+
987
+ Padding symbols are ignored.
988
+ """
989
+
990
+ def __init__(self, embedding_dim, padding_idx, init_size=1024):
991
+ super().__init__()
992
+ self.embedding_dim = embedding_dim
993
+ self.padding_idx = padding_idx if padding_idx is not None else 0
994
+ self.register_buffer("weights", SinusoidalPositionalEmbedding.get_embedding(
995
+ init_size, embedding_dim, padding_idx
996
+ ), persistent=False)
997
+ self.max_positions = int(1e5)
998
+ self.onnx_trace = False
999
+
1000
+ def prepare_for_onnx_export_(self):
1001
+ self.onnx_trace = True
1002
+
1003
+ def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
1004
+ # Ignore some deprecated keys that were used in older versions
1005
+ deprecated_keys = ["weights", "_float_tensor"]
1006
+ for key in deprecated_keys:
1007
+ if prefix + key in state_dict:
1008
+ del state_dict[prefix + key]
1009
+ super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
1010
+
1011
+ @staticmethod
1012
+ def get_embedding(
1013
+ num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None
1014
+ ):
1015
+ """Build sinusoidal embeddings.
1016
+
1017
+ This matches the implementation in tensor2tensor, but differs slightly
1018
+ from the description in Section 3.5 of "Attention Is All You Need".
1019
+ """
1020
+ half_dim = embedding_dim // 2
1021
+ emb = math.log(10000) / (half_dim - 1)
1022
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
1023
+ emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(
1024
+ 1
1025
+ ) * emb.unsqueeze(0)
1026
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(
1027
+ num_embeddings, -1
1028
+ )
1029
+ if embedding_dim % 2 == 1:
1030
+ # zero pad
1031
+ emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
1032
+ if padding_idx is not None:
1033
+ emb[padding_idx, :] = 0
1034
+ return emb
1035
+
1036
+ def forward(
1037
+ self,
1038
+ input,
1039
+ incremental_state: Optional[Any] = None,
1040
+ timestep: Optional[Tensor] = None,
1041
+ positions: Optional[Any] = None,
1042
+ ):
1043
+ """Input is expected to be of size [bsz x seqlen]."""
1044
+ bspair = torch.onnx.operators.shape_as_tensor(input)
1045
+ bsz, seq_len = bspair[0], bspair[1]
1046
+ max_pos = self.padding_idx + 1 + seq_len
1047
+ if max_pos > self.weights.size(0):
1048
+ # expand embeddings if needed
1049
+ self.weights = SinusoidalPositionalEmbedding.get_embedding(
1050
+ max_pos, self.embedding_dim, self.padding_idx
1051
+ ).to(self.weights)
1052
+
1053
+ if incremental_state is not None:
1054
+ # positions is the same for every token when decoding a single step
1055
+ pos = timestep.view(-1)[0] + 1 if timestep is not None else seq_len
1056
+ if self.onnx_trace:
1057
+ return (
1058
+ self.weights.index_select(index=self.padding_idx + pos, dim=0)
1059
+ .unsqueeze(1)
1060
+ .repeat(bsz, 1, 1)
1061
+ )
1062
+ return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1)
1063
+
1064
+ positions = make_positions(
1065
+ input, self.padding_idx, onnx_trace=self.onnx_trace
1066
+ )
1067
+ if self.onnx_trace:
1068
+ flat_embeddings = self.weights.detach().index_select(0, positions.view(-1))
1069
+ embedding_shape = torch.cat(
1070
+ (bsz.view(1), seq_len.view(1), torch.tensor([-1], dtype=torch.long))
1071
+ )
1072
+ embeddings = torch.onnx.operators.reshape_from_tensor_shape(
1073
+ flat_embeddings, embedding_shape
1074
+ )
1075
+ return embeddings
1076
+ return (
1077
+ self.weights.index_select(0, positions.view(-1))
1078
+ .view(bsz, seq_len, -1)
1079
+ .detach()
1080
+ )
1081
+
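An illustrative check, not part of the committed file: unlike the interleaved sin/cos layout described in "Attention Is All You Need", the table built by get_embedding above puts all sin components in the first half_dim columns and all cos components in the second half, and zeroes the padding row.

import math
import torch

emb = SinusoidalPositionalEmbedding.get_embedding(
    num_embeddings=8, embedding_dim=6, padding_idx=0
)
print(emb.shape)   # torch.Size([8, 6])
print(emb[0])      # padding row is all zeros

half = 3
freqs = torch.exp(torch.arange(half, dtype=torch.float) * -(math.log(10000) / (half - 1)))
print(torch.allclose(emb[1, :half], torch.sin(1 * freqs)))  # True: sin block comes first
print(torch.allclose(emb[1, half:], torch.cos(1 * freqs)))  # True: cos block comes second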
1082
+ def PositionalEmbedding(
1083
+ num_embeddings: int,
1084
+ embedding_dim: int,
1085
+ padding_idx: int,
1086
+ learned: bool = False,
1087
+ ):
1088
+ if learned:
1089
+ # if padding_idx is specified then offset the embedding ids by
1090
+ # this index and adjust num_embeddings appropriately
1091
+ # TODO: The right place for this offset would be inside
1092
+ # LearnedPositionalEmbedding. Move this there for a cleaner implementation.
1093
+ if padding_idx is not None:
1094
+ num_embeddings = num_embeddings + padding_idx + 1
1095
+ m = LearnedPositionalEmbedding(num_embeddings, embedding_dim, padding_idx)
1096
+ nn.init.normal_(m.weight, mean=0, std=embedding_dim**-0.5)
1097
+ if padding_idx is not None:
1098
+ nn.init.constant_(m.weight[padding_idx], 0)
1099
+ else:
1100
+ m = SinusoidalPositionalEmbedding(
1101
+ embedding_dim,
1102
+ padding_idx,
1103
+ init_size=num_embeddings + padding_idx + 1,
1104
+ )
1105
+ return m
1106
+
1107
+
1108
+ class TextLocalEncoder(nn.Module):
1109
+ def __init__(
1110
+ self,
1111
+ vocab_size,
1112
+ embed_dim,
1113
+ max_source_positions,
1114
+ pad_idx,
1115
+ no_scale_embedding,
1116
+ layernorm_embedding,
1117
+ dropout,
1118
+ no_token_positional_embeddings,
1119
+ learned_pos,
1120
+ ):
1121
+ super().__init__()
1122
+ self.pad_idx = pad_idx
1123
+ self.dropout_module = nn.Dropout(dropout)
1124
+
1125
+ self.embed_tokens = nn.Embedding(vocab_size, embed_dim, pad_idx)
1126
+ self.embed_scale = 1.0 if no_scale_embedding else math.sqrt(embed_dim)
1127
+ self.embed_positions = (
1128
+ PositionalEmbedding(
1129
+ max_source_positions,
1130
+ embed_dim,
1131
+ pad_idx,
1132
+ learned=learned_pos,
1133
+ )
1134
+ if not no_token_positional_embeddings
1135
+ else None
1136
+ )
1137
+ self.embed_scale = 1.0 if no_scale_embedding else math.sqrt(embed_dim)
1138
+
1139
+ self.layernorm_embedding = None
1140
+ if layernorm_embedding:
1141
+ self.layernorm_embedding = LayerNorm(embed_dim)
1142
+
1143
+ def forward(self, src_tokens):
1144
+ x = self.embed_scale * self.embed_tokens(src_tokens)
1145
+ if self.embed_positions is not None:
1146
+ x = x + self.embed_positions(src_tokens)
1147
+
1148
+ if self.layernorm_embedding is not None:
1149
+ x = self.layernorm_embedding(x)
1150
+ x = self.dropout_module(x)
1151
+ return x
1152
+
1153
+
1154
+ class TextEncoder(ModalitySpecificEncoder):
1155
+
1156
+ modality_cfg: D2v2TextConfig
1157
+
1158
+ def __init__(
1159
+ self,
1160
+ modality_cfg: D2v2TextConfig,
1161
+ embed_dim: int,
1162
+ make_block: Callable[[float], nn.ModuleList],
1163
+ norm_layer: Callable[[int], nn.LayerNorm],
1164
+ layer_norm_first: bool,
1165
+ alibi_biases: Dict,
1166
+ ):
1167
+ self.pad_idx = modality_cfg.pad_token_id
1168
+ self.vocab_size = modality_cfg.vocab_size
1169
+
1170
+ local_encoder = TextLocalEncoder(
1171
+ vocab_size=self.vocab_size,
1172
+ embed_dim=embed_dim,
1173
+ max_source_positions=modality_cfg.max_source_positions,
1174
+ pad_idx=self.pad_idx,
1175
+ no_scale_embedding=modality_cfg.no_scale_embedding,
1176
+ layernorm_embedding=modality_cfg.layernorm_embedding,
1177
+ dropout=modality_cfg.dropout,
1178
+ no_token_positional_embeddings=modality_cfg.no_token_positional_embeddings,
1179
+ learned_pos=modality_cfg.learned_pos,
1180
+ )
1181
+ dpr = np.linspace(
1182
+ modality_cfg.start_drop_path_rate,
1183
+ modality_cfg.end_drop_path_rate,
1184
+ modality_cfg.prenet_depth,
1185
+ )
1186
+ context_encoder = BlockEncoder(
1187
+ nn.ModuleList(make_block(dpr[i]) for i in range(modality_cfg.prenet_depth)),
1188
+ norm_layer(embed_dim)
1189
+ if not layer_norm_first and modality_cfg.prenet_depth > 0
1190
+ else None,
1191
+ layer_norm_first,
1192
+ modality_cfg.prenet_layerdrop,
1193
+ modality_cfg.prenet_dropout if modality_cfg.prenet_depth > 0 else 0.0,
1194
+ )
1195
+ decoder = None
1196
+
1197
+ alibi_bias_fn = partial(get_alibi_bias, alibi_biases=alibi_biases)
1198
+
1199
+ super().__init__(
1200
+ modality_cfg=modality_cfg,
1201
+ embed_dim=embed_dim,
1202
+ local_encoder=local_encoder,
1203
+ project_features=nn.Identity(),
1204
+ fixed_positional_encoder=None,
1205
+ relative_positional_encoder=None,
1206
+ context_encoder=context_encoder,
1207
+ decoder=decoder,
1208
+ get_alibi_bias=alibi_bias_fn,
1209
+ )
1210
+
1211
+ def convert_padding_mask(self, x, padding_mask):
1212
+ if padding_mask is None or padding_mask.size(1) == x.size(1):
1213
+ return padding_mask
1214
+
1215
+ diff = self.downsample - padding_mask.size(1) % self.downsample
1216
+ if 0 < diff < self.downsample:
1217
+ padding_mask = F.pad(padding_mask, (0, diff), value=True)
1218
+
1219
+ padding_mask = padding_mask.view(padding_mask.size(0), -1, self.downsample)
1220
+ padding_mask = padding_mask.all(-1)
1221
+ if padding_mask.size(1) > x.size(1):
1222
+ padding_mask = padding_mask[:, : x.size(1)]
1223
+
1224
+ assert x.size(1) == padding_mask.size(
1225
+ 1
1226
+ ), f"{x.size(1), padding_mask.size(1), diff, self.downsample}"
1227
+
1228
+ return padding_mask
1229
+ #################################################
1230
+
1231
+
1232
+ # copied from transformers.models.data2vec.modeling_data2vec.Data2VecTextPooler
1233
+ class Data2VecTextPooler(nn.Module):
1234
+ def __init__(self, config):
1235
+ super().__init__()
1236
+ self.dense = nn.Linear(config.embed_dim, config.embed_dim)
1237
+ self.activation = nn.Tanh()
1238
+
1239
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
1240
+ # We "pool" the model by simply taking the hidden state corresponding
1241
+ # to the first token.
1242
+ first_token_tensor = hidden_states[:, 0]
1243
+ pooled_output = self.dense(first_token_tensor)
1244
+ pooled_output = self.activation(pooled_output)
1245
+ return pooled_output
1246
+
1247
+
1248
+ class Data2Vec2MultiPreTrainedModel(PreTrainedModel):
1249
+ # use init_bert_params from fairseq
1250
+ # copied from fairseq.modules.transformer_sentence_encoder.py
1251
+ def _init_weights(self, module):
1252
+ """Initialize the weights"""
1253
+
1254
+ def normal_(data):
1255
+ # with FSDP, module params will be on CUDA, so we cast them back to CPU
1256
+ # so that the RNG is consistent with and without FSDP
1257
+ data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device))
1258
+
1259
+ def _init(module):
1260
+ if isinstance(module, nn.Linear):
1261
+ normal_(module.weight.data)
1262
+ if module.bias is not None:
1263
+ module.bias.data.zero_()
1264
+ if isinstance(module, nn.Embedding):
1265
+ normal_(module.weight.data)
1266
+ if module.padding_idx is not None:
1267
+ module.weight.data[module.padding_idx].zero_()
1268
+ if isinstance(module, AltBlock):
1269
+ normal_(module.attn.proj.weight.data)
1270
+ # init strategy for audio encoder
1271
+ if isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
1272
+ if module.bias is not None:
1273
+ module.bias.data.zero_()
1274
+ if module.weight is not None:
1275
+ module.weight.data.fill_(1.0)
1276
+ if isinstance(module, nn.Conv1d):
1277
+ nn.init.kaiming_normal_(module.weight)
1278
+ if module.bias is not None:
1279
+ k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
1280
+ nn.init.uniform_(module.bias, a=-k, b=k)
1281
+
1282
+ if isinstance(module, nn.ModuleList):
1283
+ for _, mod in enumerate(module):
1284
+ _init(mod)
1285
+ else:
1286
+ _init(module)
1287
+
1288
+ # @classmethod
1289
+ # def from_pretrained(
1290
+ # cls,
1291
+ # pretrained_model_name_or_path,
1292
+ # *model_args,
1293
+ # **kwargs,
1294
+ # ):
1295
+ # config = cls.config_class()
1296
+ # config.from_pretrained(pretrained_model_name_or_path)
1297
+ # print(f"Loading configuration from pre-trained model: {type(config)}")
1298
+ # return super().from_pretrained(pretrained_model_name_or_path,
1299
+ # *model_args,
1300
+ # config,
1301
+ # **kwargs,)
1302
+
1303
+
1304
+ class Data2Vec2MultiModel(Data2Vec2MultiPreTrainedModel):
1305
+ config_class = Data2Vec2MultiConfig
1306
+ base_model_prefix = "data2vec2"
1307
+
1308
+ def __init__(
1309
+ self, config: Data2Vec2MultiConfig, add_pooling_layer: bool = True
1310
+ ):
1311
+ super().__init__(config)
1312
+ self.config = config
1313
+ modalities_cfg = config.modalities
1314
+ self.modalities = [config.supported_modality]
1315
+
1316
+ make_layer_norm = partial(
1317
+ nn.LayerNorm, eps=config.norm_eps, elementwise_affine=config.norm_affine
1318
+ )
1319
+
1320
+ def make_block(drop_path, dim=None, heads=None):
1321
+ return AltBlock(
1322
+ config.embed_dim if dim is None else dim,
1323
+ config.num_heads if heads is None else heads,
1324
+ config.mlp_ratio,
1325
+ qkv_bias=True,
1326
+ drop=config.encoder_dropout,
1327
+ attn_drop=config.attention_dropout,
1328
+ mlp_drop=config.activation_dropout,
1329
+ post_mlp_drop=config.post_mlp_drop,
1330
+ drop_path=drop_path,
1331
+ norm_layer=make_layer_norm,
1332
+ layer_norm_first=config.layer_norm_first,
1333
+ ffn_targets=not config.end_of_block_targets,
1334
+ )
1335
+
1336
+ self.alibi_biases = {}
1337
+ self.modality_encoders = nn.ModuleDict()
1338
+ for mod in self.modalities:
1339
+ mod_cfg = getattr(modalities_cfg, mod.lower())
1340
+ enc = self.make_modality_encoder(
1341
+ mod_cfg,
1342
+ config.embed_dim,
1343
+ make_block,
1344
+ make_layer_norm,
1345
+ config.layer_norm_first,
1346
+ self.alibi_biases,
1347
+ )
1348
+ self.modality_encoders[mod] = enc
1349
+
1350
+ self.dropout_input = nn.Dropout(config.dropout_input)
1351
+
1352
+ dpr = np.linspace(config.start_drop_path_rate, config.end_drop_path_rate, config.depth)
1353
+
1354
+ self.blocks = nn.ModuleList([make_block(dpr[i]) for i in range(config.depth)])
1355
+
1356
+ self.text_pooler = None
1357
+ if add_pooling_layer and config.supported_modality == "TEXT":
1358
+ self.text_pooler = Data2VecTextPooler(config)
1359
+
1360
+ self.norm = None
1361
+ if config.layer_norm_first:
1362
+ self.norm = make_layer_norm(config.embed_dim)
1363
+
1364
+ self.num_updates = 0
1365
+
1366
+ # Initialize weights and apply final processing
1367
+ self.post_init()
1368
+
1369
+ def freeze_feature_extractor(self):
1370
+ """
1371
+ Calling this function will disable the gradient computation for the feature encoder so that its parameters will
1372
+ not be updated during training.
1373
+ """
1374
+ warnings.warn(
1375
+ "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
1376
+ "Please use the equivalent `freeze_feature_encoder` method instead.",
1377
+ FutureWarning,
1378
+ )
1379
+ self.freeze_feature_encoder()
1380
+
1381
+ def freeze_feature_encoder(self):
1382
+ """
1383
+ Calling this function will disable the gradient computation for the feature encoder so that its parameters will
1384
+ not be updated during training.
1385
+ """
1386
+ for mod in self.modalities:
1387
+ self.modality_encoders[mod]._freeze_parameters()
1388
+ for block in self.blocks:
1389
+ for p in block.parameters():
1390
+ p.requires_grad = False
1391
+
1392
+ def make_modality_encoder(
1393
+ self,
1394
+ cfg: D2v2ModalityConfig,
1395
+ embed_dim: int,
1396
+ make_block: Callable[[float], nn.ModuleList],
1397
+ norm_layer: Callable[[int], nn.LayerNorm],
1398
+ layer_norm_first: bool,
1399
+ alibi_biases,
1400
+ ) -> ModalitySpecificEncoder:
1401
+ if cfg.type == "AUDIO":
1402
+ enc_cls = AudioEncoder
1403
+ elif cfg.type == "TEXT":
1404
+ enc_cls = TextEncoder
1405
+ else:
1406
+ raise Exception(f"unsupported modality {cfg.type}")
1407
+
1408
+ return enc_cls(
1409
+ cfg,
1410
+ embed_dim,
1411
+ make_block,
1412
+ norm_layer,
1413
+ layer_norm_first,
1414
+ alibi_biases,
1415
+ )
1416
+
1417
+ def forward(
1418
+ self,
1419
+ input_values=None, # audio input
1420
+ input_ids=None, # text input
1421
+ attention_mask=None,
1422
+ padding_mask=None,
1423
+ mask=False,
1424
+ mode=None,
1425
+ output_hidden_states=True,
1426
+ return_dict=True,
1427
+ ):
1428
+ if mode is None:
1429
+ mode = "TEXT" if input_ids is not None else "AUDIO"
1430
+ feature_extractor = self.modality_encoders[mode]
1431
+ extractor_out = feature_extractor(
1432
+ input_ids if input_ids is not None else input_values,
1433
+ padding_mask,
1434
+ mask,
1435
+ remove_masked=False,
1436
+ clone_batch=1,
1437
+ mask_seeds=None,
1438
+ precomputed_mask=None,
1439
+ )
1440
+ x = extractor_out["x"]
1441
+ local_features = x
1442
+
1443
+ # encoder_mask = extractor_out["encoder_mask"]
1444
+ masked_padding_mask = extractor_out["padding_mask"]
1445
+ masked_alibi_bias = extractor_out.get("alibi_bias", None)
1446
+ alibi_scale = extractor_out.get("alibi_scale", None)
1447
+
1448
+ if self.dropout_input is not None:
1449
+ x = self.dropout_input(x)
1450
+
1451
+ layer_results = []
1452
+ for i, blk in enumerate(self.blocks):
1453
+ if (
1454
+ not self.training
1455
+ or self.config.layerdrop == 0
1456
+ or (np.random.random() > self.config.layerdrop)
1457
+ ):
1458
+ ab = masked_alibi_bias
1459
+ if ab is not None and alibi_scale is not None:
1460
+ scale = (
1461
+ alibi_scale[i]
1462
+ if alibi_scale.size(0) > 1
1463
+ else alibi_scale.squeeze(0)
1464
+ )
1465
+ ab = ab * scale.type_as(ab)
1466
+
1467
+ x, lr = blk(
1468
+ x,
1469
+ padding_mask=masked_padding_mask,
1470
+ alibi_bias=ab,
1471
+ )
1472
+ layer_results.append(lr)
1473
+
1474
+ if self.norm is not None:
1475
+ x = self.norm(x)
1476
+
1477
+ x = x[:, feature_extractor.modality_cfg.num_extra_tokens :]
1478
+ if masked_padding_mask is not None:
1479
+ masked_padding_mask = masked_padding_mask[
1480
+ :, feature_extractor.modality_cfg.num_extra_tokens :
1481
+ ]
1482
+
1483
+ txt_pooled_output = (
1484
+ self.text_pooler(x) if self.text_pooler is not None else None
1485
+ )
1486
+
1487
+ if not return_dict:
1488
+ return tuple(
1489
+ v
1490
+ for v in [
1491
+ x,
1492
+ txt_pooled_output,
1493
+ local_features,
1494
+ layer_results,
1495
+ ]
1496
+ if v is not None
1497
+ )
1498
+
1499
+ return Data2vec2BaseModelOutput(
1500
+ last_hidden_state=x,
1501
+ pooler_output=txt_pooled_output,
1502
+ local_features=local_features,
1503
+ hidden_states=layer_results if output_hidden_states else None,
1504
+ attentions=None, # switch to the manual implementation (fast=False in AltAttention.forward) to obtain attention weights; PyTorch's SDPA does not return them
1505
+ )
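An illustrative loading sketch, not part of the committed files; the repository id is a placeholder and the example assumes an audio checkpoint whose custom code is exposed through trust_remote_code.

import torch
from transformers import AutoModel

model = AutoModel.from_pretrained("org/repo", trust_remote_code=True)  # placeholder repo id
model.eval()
with torch.no_grad():
    out = model(input_values=torch.randn(1, 16000), mode="AUDIO")
print(out.last_hidden_state.shape)  # (1, num_frames, embed_dim)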
modeling_pantagruel_uni.py ADDED
The diff for this file is too large to render. See raw diff
 
utils_data2vec2.py ADDED
@@ -0,0 +1,439 @@
1
+ # coding=utf-8
2
+ #
3
+ # Copyright (c) Facebook, Inc. and its affiliates.
4
+ #
5
+ # This source code is licensed under the MIT license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+ #
8
+
9
+ import math
10
+ import numpy as np
11
+ from collections import namedtuple
12
+ from typing import Optional, Tuple
13
+
14
+ import torch
15
+ import torch.nn.functional as F
16
+
17
+
18
+ MaskSeed = namedtuple("MaskSeed", ["seed", "update", "ids"])
19
+ MaskInfo = namedtuple("MaskInfo", ["x_unmasked", "mask", "ids_restore", "ids_keep"])
20
+
21
+
22
+ def gather_unmasked(x: torch.Tensor, mask_info: MaskInfo) -> torch.Tensor:
23
+ return torch.gather(
24
+ x,
25
+ dim=1,
26
+ index=mask_info.ids_keep,
27
+ )
28
+
29
+
30
+ def gather_unmasked_mask(x: torch.Tensor, mask_info: MaskInfo) -> torch.Tensor:
31
+ return torch.gather(
32
+ x,
33
+ dim=1,
34
+ index=mask_info.ids_keep[..., 0], # ignore the feature dimension
35
+ )
36
+
37
+
38
+ def masked_alibi(alibi_bias, mask_info):
39
+ H = alibi_bias.size(1)
40
+
41
+ orig_bias = alibi_bias
42
+
43
+ index = mask_info.ids_keep.unsqueeze(1)[..., 0].unsqueeze(-1)
44
+ alibi_bias = torch.gather(
45
+ orig_bias,
46
+ dim=-2,
47
+ index=index.expand(-1, H, -1, mask_info.ids_restore.size(1)),
48
+ )
49
+ alibi_bias = torch.gather(
50
+ alibi_bias,
51
+ dim=-1,
52
+ index=index.transpose(-1, -2).expand(-1, H, alibi_bias.size(-2), -1),
53
+ )
54
+
55
+ return alibi_bias
56
+
57
+
58
+ def random_masking(x, mask_ratio, mask_seed: Optional[MaskSeed]):
59
+ N, L, D = x.shape # batch, length, dim
60
+ len_keep = int(L * (1 - mask_ratio))
61
+
62
+ generator = None
63
+ if mask_seed is not None:
64
+ seed = int(
65
+ hash((mask_seed.seed, mask_seed.update, mask_seed.ids.sum().item())) % 1e6
66
+ )
67
+ generator = torch.Generator(device=x.device)
68
+ generator.manual_seed(seed)
69
+
70
+ noise = torch.rand(N, L, generator=generator, device=x.device) # noise in [0, 1]
71
+
72
+ # sort noise for each sample
73
+ ids_shuffle = noise.argsort(dim=1) # ascend: small is keep, large is remove
74
+ ids_restore = ids_shuffle.argsort(dim=1)
75
+
76
+ # keep the first subset
77
+ ids_keep = ids_shuffle[:, :len_keep]
78
+ ids_keep = ids_keep.unsqueeze(-1).expand(-1, -1, D)
79
+ x_unmasked = torch.gather(x, dim=1, index=ids_keep)
80
+
81
+ # generate the binary mask: 0 is keep, 1 is remove
82
+ mask = torch.ones([N, L], dtype=x.dtype, device=x.device)
83
+ mask[:, :len_keep] = 0
84
+ # unshuffle to get the binary mask
85
+ mask = torch.gather(mask, dim=1, index=ids_restore)
86
+
87
+ ids_restore = ids_restore.unsqueeze(-1).expand(-1, -1, D)
88
+
89
+ return MaskInfo(
90
+ x_unmasked=x_unmasked, mask=mask, ids_restore=ids_restore, ids_keep=ids_keep
91
+ )
92
+
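An illustrative sketch, not part of the committed file: MAE-style random masking keeps a fixed fraction of timesteps per sample and records the indices needed to restore the original order.

import torch

x = torch.randn(2, 8, 4)                      # B x T x D
info = random_masking(x, mask_ratio=0.5, mask_seed=None)
print(info.x_unmasked.shape)                  # torch.Size([2, 4, 4]) -- kept timesteps
print(info.mask.sum(-1))                      # tensor([4., 4.])    -- masked per sample
print(info.ids_restore.shape)                 # torch.Size([2, 8, 4])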
93
+
94
+ def get_alibi(
95
+ max_positions: int,
96
+ attention_heads: int,
97
+ dims: int = 1,
98
+ distance: str = "manhattan",
99
+ ):
100
+ def get_slopes(n):
101
+ def get_slopes_power_of_2(n):
102
+ start = 2 ** (-(2 ** -(math.log2(n) - 3)))
103
+ ratio = start
104
+ return [start * ratio**i for i in range(n)]
105
+
106
+ # In the paper, we only train models that have 2^a heads for some
107
+ # a. This function has some good properties that only occur when
108
+ # the input is a power of 2. To maintain that even when the number
109
+ # of heads is not a power of 2, we use this workaround.
110
+ if math.log2(n).is_integer():
111
+ return get_slopes_power_of_2(n)
112
+ else:
113
+ closest_power_of_2 = 2 ** math.floor(math.log2(n))
114
+ return (
115
+ get_slopes_power_of_2(closest_power_of_2)
116
+ + get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
117
+ )
118
+
119
+ maxpos = max_positions
120
+ attn_heads = attention_heads
121
+ slopes = torch.Tensor(get_slopes(attn_heads))
122
+
123
+ if dims == 1:
124
+ # prepare alibi position linear bias. Note that wav2vec2 is a non-
125
+ # autoregressive model, so we want a symmetric bias with 0 on the
126
+ # diagonal and otherwise linearly decreasing values
127
+ pos_bias = (
128
+ torch.abs(
129
+ torch.arange(maxpos).unsqueeze(0) - torch.arange(maxpos).unsqueeze(1)
130
+ )
131
+ * -1
132
+ )
133
+ elif dims == 2:
134
+ if distance == "manhattan":
135
+ df = lambda x1, y1, x2, y2: abs(x1 - x2) + abs(y1 - y2)
136
+ elif distance == "euclidean":
137
+ df = lambda x1, y1, x2, y2: math.sqrt((x1 - x2) ** 2 + (y1 - y2) ** 2)
138
+
139
+ n = math.sqrt(max_positions)
140
+ assert n.is_integer(), n
141
+ n = int(n)
142
+
143
+ pos_bias = torch.zeros((max_positions, max_positions))
144
+
145
+ for i in range(n):
146
+ for j in range(n):
147
+ for k in range(n):
148
+ for l in range(n):
149
+ new_x = i * n + j
150
+ new_y = k * n + l
151
+ pos_bias[new_x, new_y] = -df(i, j, k, l)
152
+
153
+ else:
154
+ raise Exception(f"unsupported number of alibi dims: {dims}")
155
+
156
+ alibi_bias = slopes.unsqueeze(1).unsqueeze(1) * pos_bias.unsqueeze(0).expand(
157
+ attn_heads, -1, -1
158
+ )
159
+
160
+ return alibi_bias
161
+
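An illustrative sketch, not part of the committed file: for dims=1 the bias is a symmetric distance matrix (zero on the diagonal, increasingly negative away from it), scaled per head by the standard ALiBi slopes.

bias = get_alibi(max_positions=4, attention_heads=2)
print(bias.shape)   # torch.Size([2, 4, 4])
print(bias[0])      # entry (i, j) is -slope_0 * |i - j|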
162
+
163
+ def get_alibi_bias(
164
+ alibi_biases,
165
+ batch_size,
166
+ time_steps,
167
+ heads,
168
+ dtype,
169
+ device,
170
+ dims=1,
171
+ distance="manhattan",
172
+ ):
173
+ cache_key = f"{dims}_{heads}_{distance}"
174
+
175
+ buffered = alibi_biases.get(cache_key, None)
176
+
177
+ target_size = heads * batch_size
178
+ if (
179
+ buffered is None
180
+ or buffered.size(0) < target_size
181
+ or buffered.size(1) < time_steps
182
+ or buffered.dtype != dtype
183
+ or buffered.device != device
184
+ ):
185
+ bt = max(time_steps, buffered.size(1) if buffered is not None else 0)
186
+ bn = max(target_size, buffered.size(0) if buffered is not None else 0) // heads
187
+
188
+ buffered = (
189
+ get_alibi(bt, heads, dims=dims, distance=distance)
190
+ .to(dtype=dtype, device=device)
191
+ .repeat(bn, 1, 1)
192
+ )
193
+
194
+ alibi_biases[cache_key] = buffered
195
+
196
+ b = buffered[:target_size, :time_steps, :time_steps]
197
+ b = b.view(batch_size, heads, time_steps, time_steps)
198
+ return b
199
+
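An illustrative sketch, not part of the committed file: the dict passed as alibi_biases acts as a cache keyed by (dims, heads, distance); later calls reuse the buffered bias and slice it to the requested batch size and length.

import torch

cache = {}
b1 = get_alibi_bias(cache, batch_size=2, time_steps=5, heads=2,
                    dtype=torch.float32, device=torch.device("cpu"))
b2 = get_alibi_bias(cache, batch_size=2, time_steps=3, heads=2,
                    dtype=torch.float32, device=torch.device("cpu"))
print(b1.shape, b2.shape, len(cache))  # torch.Size([2, 2, 5, 5]) torch.Size([2, 2, 3, 3]) 1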
200
+
201
+ def is_xla_tensor(tensor):
202
+ return torch.is_tensor(tensor) and tensor.device.type == "xla"
203
+
204
+
205
+ def index_put(tensor, indices, value):
206
+ if is_xla_tensor(tensor):
207
+ for _ in range(indices.dim(), tensor.dim()):
208
+ indices = indices.unsqueeze(-1)
209
+ if indices.size(-1) < tensor.size(-1):
210
+ indices = indices.expand_as(tensor)
211
+ tensor = torch.mul(tensor, ~indices) + torch.mul(value, indices)
212
+ else:
213
+ tensor[indices] = value
214
+ return tensor
215
+
216
+
217
+ def compute_mask_indices(
218
+ shape: Tuple[int, int],
219
+ padding_mask: Optional[torch.Tensor],
220
+ mask_prob: float,
221
+ mask_length: int,
222
+ mask_type: str = "static",
223
+ mask_other: float = 0.0,
224
+ min_masks: int = 0,
225
+ no_overlap: bool = False,
226
+ min_space: int = 0,
227
+ require_same_masks: bool = True,
228
+ mask_dropout: float = 0.0,
229
+ add_masks: bool = False,
230
+ seed: Optional[int] = None,
231
+ epoch: Optional[int] = None,
232
+ indices: Optional[torch.Tensor] = None,
233
+ idc_select_ver: int = 1, # 2 to reproduce mask_tokens_dataset
234
+ num_mask_ver: int = 2, # 2 to reproduce mask_tokens_dataset
235
+ ) -> np.ndarray:
236
+ """
237
+ Computes random mask spans for a given shape
238
+
239
+ Args:
240
+ shape: the shape for which to compute masks.
241
+ should be of size 2 where first element is batch size and 2nd is timesteps
242
+ padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
243
+ mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
244
+ number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
245
+ however due to overlaps, the actual number will be smaller (unless no_overlap is True)
246
+ mask_type: how to compute mask lengths
247
+ static = fixed size
248
+ uniform = sample from uniform distribution [mask_other, mask_length*2]
249
+ normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
250
+ poisson = sample from poisson distribution with lambda = mask length
251
+ min_masks: minimum number of masked spans
252
+ no_overlap: if true, will switch to an alternative recursive algorithm that prevents spans from overlapping
253
+ min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans
254
+ require_same_masks: if true, will randomly drop out masks until same amount of masks remains in each sample
255
+ mask_dropout: randomly dropout this percentage of masks in each example
256
+ """
257
+
258
+ bsz, all_sz = shape
259
+ mask = np.full((bsz, all_sz), False)
260
+
261
+ if num_mask_ver == 1:
262
+ all_num_mask = int(
263
+ # add a random number for probabilistic rounding
264
+ mask_prob * all_sz / float(mask_length)
265
+ + np.random.rand()
266
+ )
267
+ all_num_mask = max(min_masks, all_num_mask)
268
+
269
+ mask_idcs = []
270
+ for i in range(bsz):
271
+ if seed is not None and epoch is not None and indices is not None:
272
+ seed_i = int(hash((seed, epoch, indices[i].item())) % 1e6)
273
+ else:
274
+ seed_i = None
275
+
276
+ rng = np.random.default_rng(seed_i)
277
+
278
+ if padding_mask is not None:
279
+ sz = all_sz - padding_mask[i].long().sum().item()
280
+ assert sz >= 0, sz
281
+ else:
282
+ sz = all_sz
283
+
284
+ if num_mask_ver == 1:
285
+ if padding_mask is not None:
286
+ num_mask = int(
287
+ # add a random number for probabilistic rounding
288
+ mask_prob * sz / float(mask_length)
289
+ + np.random.rand()
290
+ )
291
+ num_mask = max(min_masks, num_mask)
292
+ else:
293
+ num_mask = all_num_mask
294
+ elif num_mask_ver == 2:
295
+ num_mask = int(
296
+ # add a random number for probabilistic rounding
297
+ mask_prob * sz / float(mask_length)
298
+ + rng.random()
299
+ )
300
+ num_mask = max(min_masks, num_mask)
301
+ else:
302
+ raise ValueError()
303
+
304
+ if mask_type == "static":
305
+ lengths = np.full(num_mask, mask_length)
306
+ elif mask_type == "uniform":
307
+ lengths = rng.integers(int(mask_other), mask_length * 2 + 1, size=num_mask)  # np.random.Generator has no randint
308
+ elif mask_type == "normal":
309
+ lengths = rng.normal(mask_length, mask_other, size=num_mask)
310
+ lengths = [max(1, int(round(x))) for x in lengths]
311
+ elif mask_type == "poisson":
312
+ lengths = rng.poisson(mask_length, size=num_mask)
313
+ lengths = [int(round(x)) for x in lengths]
314
+ else:
315
+ raise Exception("unknown mask selection " + mask_type)
316
+
317
+ if sum(lengths) == 0:
318
+ if mask_type == "static":
319
+ raise ValueError(f"this should never happens")
320
+ else:
321
+ lengths = [min(mask_length, sz - 1)]
322
+
323
+ if no_overlap:
324
+ mask_idc = []
325
+
326
+ def arrange(s, e, length, keep_length):
327
+ span_start = rng.integers(s, e - length)  # np.random.Generator has no randint
328
+ mask_idc.extend(span_start + i for i in range(length))
329
+
330
+ new_parts = []
331
+ if span_start - s - min_space >= keep_length:
332
+ new_parts.append((s, span_start - min_space + 1))
333
+ if e - span_start - length - min_space > keep_length:
334
+ new_parts.append((span_start + length + min_space, e))
335
+ return new_parts
336
+
337
+ parts = [(0, sz)]
338
+ min_length = min(lengths)
339
+ for length in sorted(lengths, reverse=True):
340
+ lens = np.fromiter(
341
+ (e - s if e - s >= length + min_space else 0 for s, e in parts),
342
+ int,  # np.int was removed in NumPy 1.24
343
+ )
344
+ l_sum = np.sum(lens)
345
+ if l_sum == 0:
346
+ break
347
+ probs = lens / np.sum(lens)
348
+ c = rng.choice(len(parts), p=probs)
349
+ s, e = parts.pop(c)
350
+ parts.extend(arrange(s, e, length, min_length))
351
+ mask_idc = np.asarray(mask_idc)
352
+ else:
353
+ if idc_select_ver == 1:
354
+ min_len = min(lengths)
355
+ if sz - min_len <= num_mask:
356
+ min_len = sz - num_mask - 1
357
+ mask_idc = rng.choice(sz - min_len, num_mask, replace=False)
358
+ elif idc_select_ver == 2:
359
+ mask_idc = rng.choice(sz, num_mask, replace=False)
360
+ else:
361
+ raise ValueError()
362
+
363
+ mask_idc = np.asarray(
364
+ [
365
+ mask_idc[j] + offset
366
+ for j in range(len(mask_idc))
367
+ for offset in range(lengths[j])
368
+ ]
369
+ )
370
+
371
+ mask_idc = np.unique(mask_idc[mask_idc < sz])
372
+ if len(mask_idc) >= sz:
373
+ raise ValueError(
374
+ (
375
+ f"the entire sequence is masked. "
376
+ f"sz={sz}; mask_idc[mask_idc]; "
377
+ f"index={indices[i] if indices is not None else None}"
378
+ )
379
+ )
380
+ mask_idcs.append(mask_idc)
381
+
382
+ target_len = None
383
+ if require_same_masks:
384
+ if add_masks:
385
+ target_len = max([len(m) for m in mask_idcs])
386
+ else:
387
+ target_len = min([len(m) for m in mask_idcs])
388
+
389
+ for i, mask_idc in enumerate(mask_idcs):
390
+ if target_len is not None and len(mask_idc) > target_len:
391
+ mask_idc = rng.choice(mask_idc, target_len, replace=False)
392
+
393
+ mask[i, mask_idc] = True
394
+
395
+ if target_len is not None and len(mask_idc) < target_len:
396
+ unmasked = np.flatnonzero(~mask[i])
397
+ to_mask = rng.choice(unmasked, target_len - len(mask_idc), replace=False)
398
+ mask[i, to_mask] = True
399
+
400
+ if mask_dropout > 0:
401
+ masked = np.flatnonzero(mask[i])
402
+ num_holes = np.rint(len(masked) * mask_dropout).astype(int)
403
+ to_drop = rng.choice(masked, num_holes, replace=False)
404
+ mask[i, to_drop] = False
405
+
406
+ return mask
407
+
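An illustrative sketch, not part of the committed file: span masking over a (batch=2, timesteps=10) grid with spans of length 2; require_same_masks (the default) keeps the masked count identical across rows.

mask = compute_mask_indices((2, 10), padding_mask=None, mask_prob=0.5,
                            mask_length=2, min_masks=1)
print(mask.shape, mask.dtype)  # (2, 10) bool
print(mask.sum(-1))            # same count in both rows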
408
+
409
+ def _learned_alibi_bias(
410
+ alibi_bias,
411
+ batch_size,
412
+ time_steps,
413
+ heads,
414
+ scale,
415
+ dtype,
416
+ device,
417
+ ):
418
+ assert alibi_bias.size(1) == heads, alibi_bias.shape
419
+ assert alibi_bias.dtype == dtype, alibi_bias.dtype
420
+ assert alibi_bias.device == device, alibi_bias.device
421
+
422
+ if alibi_bias.size(-1) < time_steps:
423
+ psz = math.ceil((time_steps - alibi_bias.size(-1)) / 2)
424
+ alibi_bias = F.pad(alibi_bias, (psz, psz, psz, psz), mode="replicate")
425
+
426
+ alibi_bias = alibi_bias.expand(batch_size, -1, -1, -1) * scale
427
+ return alibi_bias[..., :time_steps, :time_steps]
428
+
429
+ def make_positions(tensor, padding_idx: int, onnx_trace: bool = False):
430
+ """Replace non-padding symbols with their position numbers.
431
+
432
+ Position numbers begin at padding_idx+1. Padding symbols are ignored.
433
+ """
434
+ # The series of casts and type-conversions here are carefully
435
+ # balanced to both work with ONNX export and XLA. In particular XLA
436
+ # prefers ints, cumsum defaults to output longs, and ONNX doesn't know
437
+ # how to handle the dtype kwarg in cumsum.
438
+ mask = tensor.ne(padding_idx).int()
439
+ return (torch.cumsum(mask, dim=1).type_as(mask) * mask).long() + padding_idx
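An illustrative sketch, not part of the committed file (the padding index 1 is just an assumption): position numbers start at padding_idx + 1 and padding positions keep padding_idx.

import torch

tokens = torch.tensor([[5, 6, 7, 1, 1]])        # 1 is the padding index here
print(make_positions(tokens, padding_idx=1))    # tensor([[2, 3, 4, 1, 1]])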
utils_pantagruel_uni.py ADDED
@@ -0,0 +1,439 @@
1
+ # coding=utf-8
2
+ #
3
+ # Copyright (c) Facebook, Inc. and its affiliates.
4
+ #
5
+ # This source code is licensed under the MIT license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+ #
8
+
9
+ import math
10
+ import numpy as np
11
+ from collections import namedtuple
12
+ from typing import Optional, Tuple
13
+
14
+ import torch
15
+ import torch.nn.functional as F
16
+
17
+
18
+ MaskSeed = namedtuple("MaskSeed", ["seed", "update", "ids"])
19
+ MaskInfo = namedtuple("MaskInfo", ["x_unmasked", "mask", "ids_restore", "ids_keep"])
20
+
21
+
22
+ def gather_unmasked(x: torch.Tensor, mask_info: MaskInfo) -> torch.Tensor:
23
+ return torch.gather(
24
+ x,
25
+ dim=1,
26
+ index=mask_info.ids_keep,
27
+ )
28
+
29
+
30
+ def gather_unmasked_mask(x: torch.Tensor, mask_info: MaskInfo) -> torch.Tensor:
31
+ return torch.gather(
32
+ x,
33
+ dim=1,
34
+ index=mask_info.ids_keep[..., 0], # ignore the feature dimension
35
+ )
36
+
37
+
38
+ def masked_alibi(alibi_bias, mask_info):
39
+ H = alibi_bias.size(1)
40
+
41
+ orig_bias = alibi_bias
42
+
43
+ index = mask_info.ids_keep.unsqueeze(1)[..., 0].unsqueeze(-1)
44
+ alibi_bias = torch.gather(
45
+ orig_bias,
46
+ dim=-2,
47
+ index=index.expand(-1, H, -1, mask_info.ids_restore.size(1)),
48
+ )
49
+ alibi_bias = torch.gather(
50
+ alibi_bias,
51
+ dim=-1,
52
+ index=index.transpose(-1, -2).expand(-1, H, alibi_bias.size(-2), -1),
53
+ )
54
+
55
+ return alibi_bias
56
+
57
+
58
+ def random_masking(x, mask_ratio, mask_seed: Optional[MaskSeed]):
59
+ N, L, D = x.shape # batch, length, dim
60
+ len_keep = int(L * (1 - mask_ratio))
61
+
62
+ generator = None
63
+ if mask_seed is not None:
64
+ seed = int(
65
+ hash((mask_seed.seed, mask_seed.update, mask_seed.ids.sum().item())) % 1e6
66
+ )
67
+ generator = torch.Generator(device=x.device)
68
+ generator.manual_seed(seed)
69
+
70
+ noise = torch.rand(N, L, generator=generator, device=x.device) # noise in [0, 1]
71
+
72
+ # sort noise for each sample
73
+ ids_shuffle = noise.argsort(dim=1) # ascend: small is keep, large is remove
74
+ ids_restore = ids_shuffle.argsort(dim=1)
75
+
76
+ # keep the first subset
77
+ ids_keep = ids_shuffle[:, :len_keep]
78
+ ids_keep = ids_keep.unsqueeze(-1).expand(-1, -1, D)
79
+ x_unmasked = torch.gather(x, dim=1, index=ids_keep)
80
+
81
+ # generate the binary mask: 0 is keep, 1 is remove
82
+ mask = torch.ones([N, L], dtype=x.dtype, device=x.device)
83
+ mask[:, :len_keep] = 0
84
+ # unshuffle to get the binary mask
85
+ mask = torch.gather(mask, dim=1, index=ids_restore)
86
+
87
+ ids_restore = ids_restore.unsqueeze(-1).expand(-1, -1, D)
88
+
89
+ return MaskInfo(
90
+ x_unmasked=x_unmasked, mask=mask, ids_restore=ids_restore, ids_keep=ids_keep
91
+ )
92
+
93
+
94
+ def get_alibi(
+     max_positions: int,
+     attention_heads: int,
+     dims: int = 1,
+     distance: str = "manhattan",
+ ):
+     def get_slopes(n):
+         def get_slopes_power_of_2(n):
+             start = 2 ** (-(2 ** -(math.log2(n) - 3)))
+             ratio = start
+             return [start * ratio**i for i in range(n)]
+
+         # In the paper, we only train models that have 2^a heads for some
+         # a. This function has some good properties that only occur when
+         # the input is a power of 2. To maintain that even when the number
+         # of heads is not a power of 2, we use this workaround.
+         if math.log2(n).is_integer():
+             return get_slopes_power_of_2(n)
+         else:
+             closest_power_of_2 = 2 ** math.floor(math.log2(n))
+             return (
+                 get_slopes_power_of_2(closest_power_of_2)
+                 + get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
+             )
+
+     maxpos = max_positions
+     attn_heads = attention_heads
+     slopes = torch.Tensor(get_slopes(attn_heads))
+
+     if dims == 1:
+         # prepare alibi position linear bias. Note that wav2vec2 is a
+         # non-autoregressive model, so we want a symmetric bias with 0 on the
+         # diagonal and otherwise linearly decreasing values
+         pos_bias = (
+             torch.abs(
+                 torch.arange(maxpos).unsqueeze(0) - torch.arange(maxpos).unsqueeze(1)
+             )
+             * -1
+         )
+     elif dims == 2:
+         if distance == "manhattan":
+             df = lambda x1, y1, x2, y2: abs(x1 - x2) + abs(y1 - y2)
+         elif distance == "euclidean":
+             df = lambda x1, y1, x2, y2: math.sqrt((x1 - x2) ** 2 + (y1 - y2) ** 2)
+
+         n = math.sqrt(max_positions)
+         assert n.is_integer(), n
+         n = int(n)
+
+         pos_bias = torch.zeros((max_positions, max_positions))
+
+         for i in range(n):
+             for j in range(n):
+                 for k in range(n):
+                     for l in range(n):
+                         new_x = i * n + j
+                         new_y = k * n + l
+                         pos_bias[new_x, new_y] = -df(i, j, k, l)
+
+     else:
+         raise Exception(f"unsupported number of alibi dims: {dims}")
+
+     alibi_bias = slopes.unsqueeze(1).unsqueeze(1) * pos_bias.unsqueeze(0).expand(
+         attn_heads, -1, -1
+     )
+
+     return alibi_bias
+
+
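As a quick sanity check, a sketch of the 1-D bias this produces for a small, assumed configuration (4 heads, 6 positions): each head's bias is symmetric, zero on the diagonal, and decays linearly with distance, scaled by that head's slope.

    bias = get_alibi(max_positions=6, attention_heads=4, dims=1)
    assert bias.shape == (4, 6, 6)
    assert torch.all(bias.diagonal(dim1=-2, dim2=-1) == 0)   # zero on the diagonal
    assert bias[0, 0, 5] == 5 * bias[0, 0, 1]                # linear decay with distance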
+ def get_alibi_bias(
+     alibi_biases,
+     batch_size,
+     time_steps,
+     heads,
+     dtype,
+     device,
+     dims=1,
+     distance="manhattan",
+ ):
+     cache_key = f"{dims}_{heads}_{distance}"
+
+     buffered = alibi_biases.get(cache_key, None)
+
+     target_size = heads * batch_size
+     if (
+         buffered is None
+         or buffered.size(0) < target_size
+         or buffered.size(1) < time_steps
+         or buffered.dtype != dtype
+         or buffered.device != device
+     ):
+         bt = max(time_steps, buffered.size(1) if buffered is not None else 0)
+         bn = max(target_size, buffered.size(0) if buffered is not None else 0) // heads
+
+         buffered = (
+             get_alibi(bt, heads, dims=dims, distance=distance)
+             .to(dtype=dtype, device=device)
+             .repeat(bn, 1, 1)
+         )
+
+         alibi_biases[cache_key] = buffered
+
+     b = buffered[:target_size, :time_steps, :time_steps]
+     b = b.view(batch_size, heads, time_steps, time_steps)
+     return b
+
+
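A usage sketch for the caching wrapper, assuming a plain dict serves as the shared `alibi_biases` cache (the sizes are illustrative):

    cache = {}
    b = get_alibi_bias(
        cache, batch_size=2, time_steps=6, heads=4,
        dtype=torch.float32, device=torch.device("cpu"),
    )
    assert b.shape == (2, 4, 6, 6)
    assert "1_4_manhattan" in cache       # buffered bias is keyed by dims, heads and distance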
+ def is_xla_tensor(tensor):
+     return torch.is_tensor(tensor) and tensor.device.type == "xla"
+
+
+ def index_put(tensor, indices, value):
+     if is_xla_tensor(tensor):
+         for _ in range(indices.dim(), tensor.dim()):
+             indices = indices.unsqueeze(-1)
+         if indices.size(-1) < tensor.size(-1):
+             indices = indices.expand_as(tensor)
+         tensor = torch.mul(tensor, ~indices) + torch.mul(value, indices)
+     else:
+         tensor[indices] = value
+     return tensor
+
+
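On CPU or CUDA, `index_put` reduces to plain boolean assignment; the arithmetic branch exists only because indexed in-place assignment is problematic on XLA devices. A small sketch on CPU with assumed toy values:

    t = torch.zeros(2, 3)
    idx = torch.tensor([[True, False, True], [False, True, False]])
    t = index_put(t, idx, torch.tensor(1.0))
    # tensor([[1., 0., 1.],
    #         [0., 1., 0.]])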
+ def compute_mask_indices(
+     shape: Tuple[int, int],
+     padding_mask: Optional[torch.Tensor],
+     mask_prob: float,
+     mask_length: int,
+     mask_type: str = "static",
+     mask_other: float = 0.0,
+     min_masks: int = 0,
+     no_overlap: bool = False,
+     min_space: int = 0,
+     require_same_masks: bool = True,
+     mask_dropout: float = 0.0,
+     add_masks: bool = False,
+     seed: Optional[int] = None,
+     epoch: Optional[int] = None,
+     indices: Optional[torch.Tensor] = None,
+     idc_select_ver: int = 1,  # 2 to reproduce mask_tokens_dataset
+     num_mask_ver: int = 2,  # 2 to reproduce mask_tokens_dataset
+ ) -> np.ndarray:
+     """
+     Computes random mask spans for a given shape
+
+     Args:
+         shape: the shape for which to compute masks.
+             should be of size 2 where first element is batch size and 2nd is timesteps
+         padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
+         mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
+             number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
+             however due to overlaps, the actual number will be smaller (unless no_overlap is True)
+         mask_type: how to compute mask lengths
+             static = fixed size
+             uniform = sample from uniform distribution [mask_other, mask_length*2]
+             normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
+             poisson = sample from Poisson distribution with lambda = mask_length
+         min_masks: minimum number of masked spans
+         no_overlap: if true, will switch to an alternative recursive algorithm that prevents spans from overlapping
+         min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans
+         require_same_masks: if true, will randomly drop out masks until the same number of masks remains in each sample
+         mask_dropout: randomly dropout this percentage of masks in each example
+     """
+
+     bsz, all_sz = shape
+     mask = np.full((bsz, all_sz), False)
+
+     if num_mask_ver == 1:
+         all_num_mask = int(
+             # add a random number for probabilistic rounding
+             mask_prob * all_sz / float(mask_length)
+             + np.random.rand()
+         )
+         all_num_mask = max(min_masks, all_num_mask)
+
+     mask_idcs = []
+     for i in range(bsz):
+         if seed is not None and epoch is not None and indices is not None:
+             seed_i = int(hash((seed, epoch, indices[i].item())) % 1e6)
+         else:
+             seed_i = None
+
+         rng = np.random.default_rng(seed_i)
+
+         if padding_mask is not None:
+             sz = all_sz - padding_mask[i].long().sum().item()
+             assert sz >= 0, sz
+         else:
+             sz = all_sz
+
+         if num_mask_ver == 1:
+             if padding_mask is not None:
+                 num_mask = int(
+                     # add a random number for probabilistic rounding
+                     mask_prob * sz / float(mask_length)
+                     + np.random.rand()
+                 )
+                 num_mask = max(min_masks, num_mask)
+             else:
+                 num_mask = all_num_mask
+         elif num_mask_ver == 2:
+             num_mask = int(
+                 # add a random number for probabilistic rounding
+                 mask_prob * sz / float(mask_length)
+                 + rng.random()
+             )
+             num_mask = max(min_masks, num_mask)
+         else:
+             raise ValueError()
+
+         if mask_type == "static":
+             lengths = np.full(num_mask, mask_length)
+         elif mask_type == "uniform":
+             # np.random.Generator exposes integers(), not the legacy randint()
+             lengths = rng.integers(mask_other, mask_length * 2 + 1, size=num_mask)
+         elif mask_type == "normal":
+             lengths = rng.normal(mask_length, mask_other, size=num_mask)
+             lengths = [max(1, int(round(x))) for x in lengths]
+         elif mask_type == "poisson":
+             lengths = rng.poisson(mask_length, size=num_mask)
+             lengths = [int(round(x)) for x in lengths]
+         else:
+             raise Exception("unknown mask selection " + mask_type)
+
+         if sum(lengths) == 0:
+             if mask_type == "static":
+                 raise ValueError("this should never happen")
+             else:
+                 lengths = [min(mask_length, sz - 1)]
+
+         if no_overlap:
+             mask_idc = []
+
+             def arrange(s, e, length, keep_length):
+                 span_start = rng.integers(s, e - length)
+                 mask_idc.extend(span_start + i for i in range(length))
+
+                 new_parts = []
+                 if span_start - s - min_space >= keep_length:
+                     new_parts.append((s, span_start - min_space + 1))
+                 if e - span_start - length - min_space > keep_length:
+                     new_parts.append((span_start + length + min_space, e))
+                 return new_parts
+
+             parts = [(0, sz)]
+             min_length = min(lengths)
+             for length in sorted(lengths, reverse=True):
+                 lens = np.fromiter(
+                     (e - s if e - s >= length + min_space else 0 for s, e in parts),
+                     int,  # np.int was removed from NumPy; the builtin int is equivalent here
+                 )
+                 l_sum = np.sum(lens)
+                 if l_sum == 0:
+                     break
+                 probs = lens / np.sum(lens)
+                 c = rng.choice(len(parts), p=probs)
+                 s, e = parts.pop(c)
+                 parts.extend(arrange(s, e, length, min_length))
+             mask_idc = np.asarray(mask_idc)
+         else:
+             if idc_select_ver == 1:
+                 min_len = min(lengths)
+                 if sz - min_len <= num_mask:
+                     min_len = sz - num_mask - 1
+                 mask_idc = rng.choice(sz - min_len, num_mask, replace=False)
+             elif idc_select_ver == 2:
+                 mask_idc = rng.choice(sz, num_mask, replace=False)
+             else:
+                 raise ValueError()
+
+             mask_idc = np.asarray(
+                 [
+                     mask_idc[j] + offset
+                     for j in range(len(mask_idc))
+                     for offset in range(lengths[j])
+                 ]
+             )
+
+         mask_idc = np.unique(mask_idc[mask_idc < sz])
+         if len(mask_idc) >= sz:
+             raise ValueError(
+                 (
+                     f"the entire sequence is masked. "
+                     f"sz={sz}; mask_idc={mask_idc}; "
+                     f"index={indices[i] if indices is not None else None}"
+                 )
+             )
+         mask_idcs.append(mask_idc)
+
+     target_len = None
+     if require_same_masks:
+         if add_masks:
+             target_len = max([len(m) for m in mask_idcs])
+         else:
+             target_len = min([len(m) for m in mask_idcs])
+
+     for i, mask_idc in enumerate(mask_idcs):
+         if target_len is not None and len(mask_idc) > target_len:
+             mask_idc = rng.choice(mask_idc, target_len, replace=False)
+
+         mask[i, mask_idc] = True
+
+         if target_len is not None and len(mask_idc) < target_len:
+             unmasked = np.flatnonzero(~mask[i])
+             to_mask = rng.choice(unmasked, target_len - len(mask_idc), replace=False)
+             mask[i, to_mask] = True
+
+         if mask_dropout > 0:
+             masked = np.flatnonzero(mask[i])
+             num_holes = np.rint(len(masked) * mask_dropout).astype(int)
+             to_drop = rng.choice(masked, num_holes, replace=False)
+             mask[i, to_drop] = False
+
+     return mask
+
+
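A usage sketch for `compute_mask_indices`; the shape, probability and span length below are illustrative assumptions (static spans, in the wav2vec 2.0 style):

    mask = compute_mask_indices(
        shape=(2, 100),
        padding_mask=None,
        mask_prob=0.65,
        mask_length=10,
    )
    assert mask.shape == (2, 100) and mask.dtype == bool
    # with require_same_masks=True (the default), every sample masks the same number of timesteps
    assert len(set(mask.sum(axis=1))) == 1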
+ def _learned_alibi_bias(
+     alibi_bias,
+     batch_size,
+     time_steps,
+     heads,
+     scale,
+     dtype,
+     device,
+ ):
+     assert alibi_bias.size(1) == heads, alibi_bias.shape
+     assert alibi_bias.dtype == dtype, alibi_bias.dtype
+     assert alibi_bias.device == device, alibi_bias.device
+
+     if alibi_bias.size(-1) < time_steps:
+         psz = math.ceil((time_steps - alibi_bias.size(-1)) / 2)
+         alibi_bias = F.pad(alibi_bias, (psz, psz, psz, psz), mode="replicate")
+
+     alibi_bias = alibi_bias.expand(batch_size, -1, -1, -1) * scale
+     return alibi_bias[..., :time_steps, :time_steps]
+
+
+ def make_positions(tensor, padding_idx: int, onnx_trace: bool = False):
+     """Replace non-padding symbols with their position numbers.
+
+     Position numbers begin at padding_idx+1. Padding symbols are ignored.
+     """
+     # The series of casts and type-conversions here are carefully
+     # balanced to both work with ONNX export and XLA. In particular XLA
+     # prefers ints, cumsum defaults to output longs, and ONNX doesn't know
+     # how to handle the dtype kwarg in cumsum.
+     mask = tensor.ne(padding_idx).int()
+     return (torch.cumsum(mask, dim=1).type_as(mask) * mask).long() + padding_idx
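A small sketch of `make_positions`, assuming a padding index of 1:

    tokens = torch.tensor([[5, 6, 7, 1, 1],
                           [5, 6, 1, 1, 1]])
    positions = make_positions(tokens, padding_idx=1)
    # tensor([[2, 3, 4, 1, 1],
    #         [2, 3, 1, 1, 1]])   -- positions start at padding_idx + 1; pads keep padding_idx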