flaubert committed on
Commit
9dd2d61
·
verified ·
1 Parent(s): f214f73

Upload folder using huggingface_hub

Browse files
config.json ADDED
@@ -0,0 +1,356 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "",
3
+ "activation_dropout": 0.0,
4
+ "add_cross_attention": false,
5
+ "architectures": [
6
+ "PantagruelUniForMaskedLM"
7
+ ],
8
+ "attention_dropout": 0.1,
9
+ "auto_map": {
10
+ "AutoConfig": "configuration_pantagruel_uni.PantagruelUniConfig",
11
+ "AutoModel": "modeling_pantagruel_uni.PantagruelUniModel",
12
+ "AutoModelForMaskedLM": "modeling_pantagruel_uni.PantagruelUniForMaskedLM",
13
+ "AutoModelForMultipleChoice": "modeling_pantagruel_uni.PantagruelUniForMultipleChoice",
14
+ "AutoModelForQuestionAnswering": "modeling_pantagruel_uni.PantagruelUniForQuestionAnswering",
15
+ "AutoModelForSequenceClassification": "modeling_pantagruel_uni.PantagruelUniForSequenceClassification",
16
+ "AutoModelForTokenClassification": "modeling_pantagruel_uni.PantagruelUniForTokenClassification"
17
+ },
18
+ "bad_words_ids": null,
19
+ "begin_suppress_tokens": null,
20
+ "bos_token_id": null,
21
+ "chunk_size_feed_forward": 0,
22
+ "classifier_dropout": null,
23
+ "clone_batch": 8,
24
+ "cross_attention_hidden_size": null,
25
+ "decoder_start_token_id": null,
26
+ "depth": 12,
27
+ "diversity_penalty": 0.0,
28
+ "do_sample": false,
29
+ "dropout_input": 0.0,
30
+ "dtype": "float32",
31
+ "early_stopping": false,
32
+ "embed_dim": 768,
33
+ "encoder_dropout": 0.1,
34
+ "encoder_no_repeat_ngram_size": 0,
35
+ "end_drop_path_rate": 0.0,
36
+ "end_of_block_targets": false,
37
+ "eos_token_id": null,
38
+ "exponential_decay_length_penalty": null,
39
+ "finetuning_task": null,
40
+ "forced_bos_token_id": null,
41
+ "forced_eos_token_id": null,
42
+ "hidden_size": 768,
43
+ "id2label": {
44
+ "0": "LABEL_0",
45
+ "1": "LABEL_1"
46
+ },
47
+ "is_decoder": false,
48
+ "is_encoder_decoder": false,
49
+ "label2id": {
50
+ "LABEL_0": 0,
51
+ "LABEL_1": 1
52
+ },
53
+ "layer_norm_first": false,
54
+ "layerdrop": 0.0,
55
+ "length_penalty": 1.0,
56
+ "log_norms": true,
57
+ "max_length": 20,
58
+ "min_length": 0,
59
+ "mlp_ratio": 4.0,
60
+ "modalities": {
61
+ "_name_or_path": "",
62
+ "add_cross_attention": false,
63
+ "architectures": null,
64
+ "audio": {
65
+ "_name_or_path": "",
66
+ "add_cross_attention": false,
67
+ "add_masks": false,
68
+ "alibi_max_pos": null,
69
+ "alibi_scale": 1.0,
70
+ "architectures": null,
71
+ "bad_words_ids": null,
72
+ "begin_suppress_tokens": null,
73
+ "bos_token_id": null,
74
+ "chunk_size_feed_forward": 0,
75
+ "conv_pos_depth": 5,
76
+ "conv_pos_groups": 16,
77
+ "conv_pos_pre_ln": false,
78
+ "conv_pos_width": 95,
79
+ "cross_attention_hidden_size": null,
80
+ "decoder_start_token_id": null,
81
+ "diversity_penalty": 0.0,
82
+ "do_sample": false,
83
+ "dtype": null,
84
+ "early_stopping": false,
85
+ "encoder_no_repeat_ngram_size": 0,
86
+ "encoder_zero_mask": true,
87
+ "end_drop_path_rate": 0.0,
88
+ "eos_token_id": null,
89
+ "exponential_decay_length_penalty": null,
90
+ "extractor_mode": "layer_norm",
91
+ "feature_encoder_spec": "[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] + [(512,2,2)]",
92
+ "finetuning_task": null,
93
+ "forced_bos_token_id": null,
94
+ "forced_eos_token_id": null,
95
+ "id2label": {
96
+ "0": "LABEL_0",
97
+ "1": "LABEL_1"
98
+ },
99
+ "init_extra_token_zero": true,
100
+ "inverse_mask": false,
101
+ "is_decoder": false,
102
+ "is_encoder_decoder": false,
103
+ "keep_masked_pct": 0.0,
104
+ "label2id": {
105
+ "LABEL_0": 0,
106
+ "LABEL_1": 1
107
+ },
108
+ "learned_alibi": false,
109
+ "learned_alibi_scale": false,
110
+ "learned_alibi_scale_per_head": false,
111
+ "learned_alibi_scale_per_layer": false,
112
+ "length_penalty": 1.0,
113
+ "local_grad_mult": 1.0,
114
+ "mask_channel_length": 64,
115
+ "mask_channel_prob": 0.0,
116
+ "mask_dropout": 0.0,
117
+ "mask_length": 5,
118
+ "mask_noise_std": 0.01,
119
+ "mask_prob": 0.7,
120
+ "mask_prob_adjust": 0.0,
121
+ "mask_prob_min": null,
122
+ "max_length": 20,
123
+ "min_length": 0,
124
+ "model_depth": 12,
125
+ "model_type": "",
126
+ "no_repeat_ngram_size": 0,
127
+ "num_alibi_heads": 12,
128
+ "num_beam_groups": 1,
129
+ "num_beams": 1,
130
+ "num_extra_tokens": 0,
131
+ "num_return_sequences": 1,
132
+ "output_attentions": false,
133
+ "output_hidden_states": false,
134
+ "output_scores": false,
135
+ "pad_token_id": null,
136
+ "prefix": null,
137
+ "prenet_depth": 4,
138
+ "prenet_dropout": 0.0,
139
+ "prenet_layerdrop": 0.0,
140
+ "problem_type": null,
141
+ "pruned_heads": {},
142
+ "remove_invalid_values": false,
143
+ "remove_masks": false,
144
+ "repetition_penalty": 1.0,
145
+ "return_dict": true,
146
+ "return_dict_in_generate": false,
147
+ "sep_token_id": null,
148
+ "start_drop_path_rate": 0.0,
149
+ "suppress_tokens": null,
150
+ "task_specific_params": null,
151
+ "temperature": 1.0,
152
+ "tie_encoder_decoder": false,
153
+ "tie_word_embeddings": true,
154
+ "tokenizer_class": null,
155
+ "top_k": 50,
156
+ "top_p": 1.0,
157
+ "torchscript": false,
158
+ "type": "AUDIO",
159
+ "typical_p": 1.0,
160
+ "use_alibi_encoder": false
161
+ },
162
+ "bad_words_ids": null,
163
+ "begin_suppress_tokens": null,
164
+ "bos_token_id": null,
165
+ "chunk_size_feed_forward": 0,
166
+ "cross_attention_hidden_size": null,
167
+ "decoder_start_token_id": null,
168
+ "diversity_penalty": 0.0,
169
+ "do_sample": false,
170
+ "dtype": null,
171
+ "early_stopping": false,
172
+ "encoder_no_repeat_ngram_size": 0,
173
+ "eos_token_id": null,
174
+ "exponential_decay_length_penalty": null,
175
+ "finetuning_task": null,
176
+ "forced_bos_token_id": null,
177
+ "forced_eos_token_id": null,
178
+ "id2label": {
179
+ "0": "LABEL_0",
180
+ "1": "LABEL_1"
181
+ },
182
+ "is_decoder": false,
183
+ "is_encoder_decoder": false,
184
+ "label2id": {
185
+ "LABEL_0": 0,
186
+ "LABEL_1": 1
187
+ },
188
+ "length_penalty": 1.0,
189
+ "max_length": 20,
190
+ "min_length": 0,
191
+ "model_type": "",
192
+ "no_repeat_ngram_size": 0,
193
+ "num_beam_groups": 1,
194
+ "num_beams": 1,
195
+ "num_return_sequences": 1,
196
+ "output_attentions": false,
197
+ "output_hidden_states": false,
198
+ "output_scores": false,
199
+ "pad_token_id": null,
200
+ "prefix": null,
201
+ "problem_type": null,
202
+ "pruned_heads": {},
203
+ "remove_invalid_values": false,
204
+ "repetition_penalty": 1.0,
205
+ "return_dict": true,
206
+ "return_dict_in_generate": false,
207
+ "sep_token_id": null,
208
+ "suppress_tokens": null,
209
+ "task_specific_params": null,
210
+ "temperature": 1.0,
211
+ "text": {
212
+ "_name_or_path": "",
213
+ "add_cross_attention": false,
214
+ "add_masks": false,
215
+ "alibi_max_pos": null,
216
+ "alibi_scale": 1.0,
217
+ "architectures": null,
218
+ "bad_words_ids": null,
219
+ "begin_suppress_tokens": null,
220
+ "bos_token_id": 0,
221
+ "chunk_size_feed_forward": 0,
222
+ "cross_attention_hidden_size": null,
223
+ "decoder_start_token_id": null,
224
+ "diversity_penalty": 0.0,
225
+ "do_sample": false,
226
+ "dropout": 0.1,
227
+ "dtype": null,
228
+ "early_stopping": false,
229
+ "encoder_no_repeat_ngram_size": 0,
230
+ "encoder_zero_mask": true,
231
+ "end_drop_path_rate": 0.0,
232
+ "eos_token_id": 2,
233
+ "exponential_decay_length_penalty": null,
234
+ "finetuning_task": null,
235
+ "forced_bos_token_id": null,
236
+ "forced_eos_token_id": null,
237
+ "id2label": {
238
+ "0": "LABEL_0",
239
+ "1": "LABEL_1"
240
+ },
241
+ "init_extra_token_zero": true,
242
+ "inverse_mask": false,
243
+ "is_decoder": false,
244
+ "is_encoder_decoder": false,
245
+ "keep_masked_pct": 0.0,
246
+ "label2id": {
247
+ "LABEL_0": 0,
248
+ "LABEL_1": 1
249
+ },
250
+ "layernorm_embedding": true,
251
+ "learned_alibi": false,
252
+ "learned_alibi_scale": true,
253
+ "learned_alibi_scale_per_head": true,
254
+ "learned_alibi_scale_per_layer": false,
255
+ "learned_pos": true,
256
+ "length_penalty": 1.0,
257
+ "local_grad_mult": 1.0,
258
+ "mask_channel_length": 64,
259
+ "mask_channel_prob": 0.0,
260
+ "mask_dropout": 0.0,
261
+ "mask_length": 3,
262
+ "mask_noise_std": 0.01,
263
+ "mask_prob": 0.6,
264
+ "mask_prob_adjust": 0.0,
265
+ "mask_prob_min": null,
266
+ "max_length": 20,
267
+ "max_source_positions": 512,
268
+ "min_length": 0,
269
+ "model_depth": 12,
270
+ "model_type": "",
271
+ "no_repeat_ngram_size": 0,
272
+ "no_scale_embedding": true,
273
+ "no_token_positional_embeddings": false,
274
+ "num_alibi_heads": 12,
275
+ "num_beam_groups": 1,
276
+ "num_beams": 1,
277
+ "num_extra_tokens": 0,
278
+ "num_return_sequences": 1,
279
+ "output_attentions": false,
280
+ "output_hidden_states": false,
281
+ "output_scores": false,
282
+ "pad_token_id": 1,
283
+ "prefix": null,
284
+ "prenet_depth": 0,
285
+ "prenet_dropout": 0.0,
286
+ "prenet_layerdrop": 0.0,
287
+ "problem_type": null,
288
+ "pruned_heads": {},
289
+ "remove_invalid_values": false,
290
+ "remove_masks": false,
291
+ "repetition_penalty": 1.0,
292
+ "return_dict": true,
293
+ "return_dict_in_generate": false,
294
+ "sep_token_id": null,
295
+ "start_drop_path_rate": 0.0,
296
+ "suppress_tokens": null,
297
+ "task_specific_params": null,
298
+ "temperature": 1.0,
299
+ "tie_encoder_decoder": false,
300
+ "tie_word_embeddings": true,
301
+ "tokenizer_class": null,
302
+ "top_k": 50,
303
+ "top_p": 1.0,
304
+ "torchscript": false,
305
+ "type": "TEXT",
306
+ "typical_p": 1.0,
307
+ "unk_token_id": 3,
308
+ "use_alibi_encoder": true,
309
+ "vocab_size": 50368
310
+ },
311
+ "tie_encoder_decoder": false,
312
+ "tie_word_embeddings": true,
313
+ "tokenizer_class": null,
314
+ "top_k": 50,
315
+ "top_p": 1.0,
316
+ "torchscript": false,
317
+ "typical_p": 1.0
318
+ },
319
+ "model_type": "pantagruel_uni",
320
+ "n_layers": 12,
321
+ "no_repeat_ngram_size": 0,
322
+ "norm_affine": true,
323
+ "norm_eps": 1e-05,
324
+ "num_beam_groups": 1,
325
+ "num_beams": 1,
326
+ "num_heads": 12,
327
+ "num_hidden_layers": 12,
328
+ "num_layers": 12,
329
+ "num_return_sequences": 1,
330
+ "output_attentions": false,
331
+ "output_hidden_states": false,
332
+ "output_scores": false,
333
+ "pad_token_id": null,
334
+ "post_mlp_drop": 0.1,
335
+ "prefix": null,
336
+ "problem_type": null,
337
+ "pruned_heads": {},
338
+ "remove_invalid_values": false,
339
+ "repetition_penalty": 1.0,
340
+ "return_dict": true,
341
+ "return_dict_in_generate": false,
342
+ "sep_token_id": null,
343
+ "start_drop_path_rate": 0.0,
344
+ "supported_modality": "TEXT",
345
+ "suppress_tokens": null,
346
+ "task_specific_params": null,
347
+ "temperature": 1.0,
348
+ "tie_encoder_decoder": false,
349
+ "tie_word_embeddings": true,
350
+ "tokenizer_class": null,
351
+ "top_k": 50,
352
+ "top_p": 1.0,
353
+ "torchscript": false,
354
+ "transformers_version": "4.57.0.dev0",
355
+ "typical_p": 1.0
356
+ }
configuration_pantagruel_uni.py ADDED
@@ -0,0 +1,423 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ #
3
+ # Copyright (c) Facebook, Inc. and its affiliates.
4
+ #
5
+ # This source code is licensed under the MIT license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+ #
8
+ #
9
+ # Copyright 2022 The HuggingFace Inc. team. All rights reserved.
10
+ #
11
+ # Licensed under the Apache License, Version 2.0 (the "License");
12
+ # you may not use this file except in compliance with the License.
13
+ # You may obtain a copy of the License at
14
+ #
15
+ # http://www.apache.org/licenses/LICENSE-2.0
16
+ #
17
+ # Unless required by applicable law or agreed to in writing, software
18
+ # distributed under the License is distributed on an "AS IS" BASIS,
19
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20
+ # See the License for the specific language governing permissions and
21
+ # limitations under the License.
22
+
23
+
24
+ """ Pantagruel unimodal configuration"""
25
+
26
+ import os
27
+ from typing import Union, Dict, Any, Optional
28
+ from transformers.dynamic_module_utils import custom_object_save
29
+ from transformers.utils import logging
30
+ from transformers.configuration_utils import PretrainedConfig, CONFIG_NAME
31
+
32
+
33
+ logger = logging.get_logger(__name__)
34
+
35
+
36
class MyPretrainedConfig(PretrainedConfig):
    """`PretrainedConfig` variant that saves and loads nested config objects.

    Differences from the parent class visible in this file:

    * `to_json_string` defaults to `use_diff=False` and `save_pretrained`
      writes the file with `use_diff=False`.
    * `update` and `from_dict` recurse into attributes that are themselves
      `MyPretrainedConfig` instances instead of replacing them, and ignore
      keys for which the config has no attribute.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def to_json_string(self, use_diff: bool = False) -> str:
        """Serialize to JSON; `use_diff` defaults to `False` here."""
        return super().to_json_string(use_diff)

    def update(self, config_dict):
        """
        Update attributes of this config in place from `config_dict`.

        Args:
            config_dict (`Dict[str, Any]`):
                Mapping of attribute names to new values. Keys that are not already attributes of this config
                are skipped. When the current attribute is a `MyPretrainedConfig`, the value is forwarded to
                that nested config's `update`; otherwise the attribute is overwritten.
        """
        for key, value in config_dict.items():
            if not hasattr(self, key):
                continue
            current = getattr(self, key)
            if isinstance(current, MyPretrainedConfig):
                current.update(value)
            else:
                setattr(self, key, value)

    # Copied from the parent class, only changed use_diff from True to False to correctly save nested config class
    def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
        """
        Save a configuration object to the directory `save_directory`, so that it can be re-loaded using the
        [`~PretrainedConfig.from_pretrained`] class method.

        Args:
            save_directory (`str` or `os.PathLike`):
                Directory where the configuration JSON file will be saved (will be created if it does not exist).
            push_to_hub (`bool`, *optional*, defaults to `False`):
                Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
                repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
                namespace).
            kwargs (`Dict[str, Any]`, *optional*):
                Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.

        Raises:
            AssertionError: If `save_directory` points to an existing file.
        """
        self._set_token_in_kwargs(kwargs)

        if os.path.isfile(save_directory):
            raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")

        non_default_generation_parameters = {}
        for parameter_name, default_value in self._get_global_generation_defaults().items():
            if hasattr(self, parameter_name) and getattr(self, parameter_name) != default_value:
                non_default_generation_parameters[parameter_name] = getattr(self, parameter_name)
        if len(non_default_generation_parameters) > 0:
            logger.warning(
                "Some non-default generation parameters are set in the model config. These should go into a "
                "GenerationConfig file (https://huggingface.co/docs/transformers/generation_strategies#save-a-custom-decoding-strategy-with-your-model) "
                "instead. This warning will be raised to an exception in v4.41.\n"
                f"Non-default generation parameters: {str(non_default_generation_parameters)}"
            )

        os.makedirs(save_directory, exist_ok=True)

        if push_to_hub:
            commit_message = kwargs.pop("commit_message", None)
            # `os.fspath` makes the default repo name work for `os.PathLike` inputs too; a bare `.split` only
            # exists on `str`.
            repo_id = kwargs.pop("repo_id", os.fspath(save_directory).split(os.path.sep)[-1])
            repo_id = self._create_repo(repo_id, **kwargs)
            files_timestamps = self._get_files_timestamps(save_directory)

        # If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be
        # loaded from the Hub.
        if self._auto_class is not None:
            custom_object_save(self, save_directory, config=self)

        # If we save using the predefined names, we can load using `from_pretrained`
        output_config_file = os.path.join(save_directory, CONFIG_NAME)

        self.to_json_file(output_config_file, use_diff=False)
        logger.info(f"Configuration saved in {output_config_file}")

        if push_to_hub:
            self._upload_modified_files(
                save_directory,
                repo_id,
                files_timestamps,
                commit_message=commit_message,
                token=kwargs.get("token"),
            )

    # Copied from the parent class, change the instantiation and updating of class from config_dict to correctly load nested config
    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "MyPretrainedConfig":
        """
        Instantiates a [`PretrainedConfig`] from a Python dictionary of parameters.

        The config is first built with `cls()` (all defaults) and then updated from `config_dict` through
        [`MyPretrainedConfig.update`], so nested configs are updated in place and unknown keys are ignored.

        Args:
            config_dict (`Dict[str, Any]`):
                Dictionary that will be used to instantiate the configuration object. Such a dictionary can be
                retrieved from a pretrained checkpoint by leveraging the [`~PretrainedConfig.get_config_dict`] method.
                It is not modified by this call.
            kwargs (`Dict[str, Any]`):
                Additional parameters from which to initialize the configuration object.

        Returns:
            [`PretrainedConfig`]: The configuration object instantiated from those parameters. When
            `return_unused_kwargs=True` is passed, a `(config, unused_kwargs)` tuple is returned instead.

        Raises:
            ValueError: If `num_labels` and `id2label` are both passed in `kwargs` and disagree in length.
        """
        return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
        # Those arguments may be passed along for our internal telemetry.
        # We remove them so they don't appear in `return_unused_kwargs`.
        kwargs.pop("_from_auto", None)
        kwargs.pop("_from_pipeline", None)
        # The commit hash might have been updated in the `config_dict`, we don't want the kwargs to erase that update.
        if "_commit_hash" in kwargs and "_commit_hash" in config_dict:
            kwargs["_commit_hash"] = config_dict["_commit_hash"]

        # Work on a shallow copy so the caller's dictionary does not gain an `attn_implementation` entry.
        config_dict = dict(config_dict)
        # We remove it from kwargs so that it does not appear in `return_unused_kwargs`.
        config_dict["attn_implementation"] = kwargs.pop("attn_implementation", None)

        # Build a default config, then update it so nested configs are updated instead of replaced by raw dicts.
        config = cls()
        config.update(config_dict)

        if hasattr(config, "pruned_heads"):
            config.pruned_heads = {int(key): value for key, value in config.pruned_heads.items()}

        # Update config with kwargs if needed
        if "num_labels" in kwargs and "id2label" in kwargs:
            num_labels = kwargs["num_labels"]
            id2label = kwargs["id2label"] if kwargs["id2label"] is not None else []
            if len(id2label) != num_labels:
                raise ValueError(
                    f"You passed along `num_labels={num_labels }` with an incompatible id to label map: "
                    f"{kwargs['id2label']}. Since those arguments are inconsistent with each other, you should remove "
                    "one of them."
                )
        to_remove = []
        for key, value in kwargs.items():
            if hasattr(config, key):
                current_attr = getattr(config, key)
                # To authorize passing a custom subconfig as kwarg in models that have nested configs.
                if isinstance(current_attr, PretrainedConfig) and isinstance(value, dict):
                    value = current_attr.__class__(**value)
                setattr(config, key, value)
                if key != "torch_dtype":
                    to_remove.append(key)
        for key in to_remove:
            kwargs.pop(key, None)

        logger.info(f"Model config {config}")
        if return_unused_kwargs:
            return config, kwargs
        else:
            return config
184
+
185
+
186
class PantagruelModalityConfig(MyPretrainedConfig):
    """
    Settings shared by the modality-specific configs (`PantagruelAudioConfig`, `PantagruelTextConfig`).

    Every constructor argument except `ema_local_encoder` and `decoder` is stored unchanged as an attribute
    of the same name. Remaining keyword arguments are forwarded to the parent constructor.

    NOTE(review): the parameter names suggest a pre-net, masking and ALiBi attention-bias setup; the code that
    consumes these values is not in this file, so their exact semantics should be checked against the model.
    """
    def __init__(
        self,
        type="AUDIO",  # modality tag; subclasses in this file pass "AUDIO" or "TEXT"
        prenet_depth=4,
        prenet_layerdrop=0,
        prenet_dropout=0.0,
        start_drop_path_rate=0.0,
        end_drop_path_rate=0.0,
        num_extra_tokens=0,
        init_extra_token_zero=True,
        mask_noise_std=0.01,
        mask_prob_min=None,
        mask_prob=0.7,
        inverse_mask=False,
        mask_prob_adjust=0.0,
        keep_masked_pct=0.0,
        mask_length=5,
        add_masks=False,
        remove_masks=False,
        mask_dropout=0.0,
        encoder_zero_mask=True,
        mask_channel_prob=0.0,
        mask_channel_length=64,
        local_grad_mult=1.0,
        use_alibi_encoder=False,
        alibi_scale=1.0,
        learned_alibi=False,
        alibi_max_pos=None,
        learned_alibi_scale=False,
        learned_alibi_scale_per_head=False,
        learned_alibi_scale_per_layer=False,
        num_alibi_heads=12,
        model_depth=12,
        ema_local_encoder=False,  # accepted but never stored on the config
        decoder=None,  # accepted but never stored on the config
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.type = type
        self.prenet_depth = prenet_depth
        self.prenet_layerdrop = prenet_layerdrop
        self.prenet_dropout = prenet_dropout
        self.start_drop_path_rate = start_drop_path_rate
        self.end_drop_path_rate = end_drop_path_rate
        self.num_extra_tokens = num_extra_tokens
        self.init_extra_token_zero = init_extra_token_zero
        self.mask_noise_std = mask_noise_std
        self.mask_prob_min = mask_prob_min
        self.mask_prob = mask_prob
        self.inverse_mask = inverse_mask
        self.mask_prob_adjust = mask_prob_adjust
        self.keep_masked_pct = keep_masked_pct
        self.mask_length = mask_length
        self.add_masks = add_masks
        self.remove_masks = remove_masks
        self.mask_dropout = mask_dropout
        self.encoder_zero_mask = encoder_zero_mask
        self.mask_channel_prob = mask_channel_prob
        self.mask_channel_length = mask_channel_length
        self.local_grad_mult = local_grad_mult
        self.use_alibi_encoder = use_alibi_encoder
        self.alibi_scale = alibi_scale
        self.learned_alibi = learned_alibi
        self.alibi_max_pos = alibi_max_pos
        self.learned_alibi_scale = learned_alibi_scale
        self.learned_alibi_scale_per_head = learned_alibi_scale_per_head
        self.learned_alibi_scale_per_layer = learned_alibi_scale_per_layer
        self.num_alibi_heads = num_alibi_heads
        self.model_depth = model_depth
256
+
257
+
258
class PantagruelAudioConfig(PantagruelModalityConfig):
    """
    Configuration including common args and args specific to audio-only pre-training.

    The modality `type` is always `"AUDIO"`. The audio-specific arguments below are stored unchanged as
    attributes of the same name; all other keyword arguments go to `PantagruelModalityConfig`.

    Args:
        extractor_mode (`str`, *optional*, defaults to `"layer_norm"`):
            Stored as-is. NOTE(review): presumably the normalization mode of the feature extractor - confirm
            against the model code.
        feature_encoder_spec (`str`, *optional*):
            Python-expression string describing a list of 3-tuples. NOTE(review): presumably
            `(channels, kernel, stride)` per convolution layer - confirm against the model code.
        conv_pos_width (`int`, *optional*, defaults to 95): Stored as-is.
        conv_pos_groups (`int`, *optional*, defaults to 16): Stored as-is.
        conv_pos_depth (`int`, *optional*, defaults to 5): Stored as-is.
        conv_pos_pre_ln (`bool`, *optional*, defaults to `False`): Stored as-is.
    """
    def __init__(
        self,
        extractor_mode="layer_norm",
        feature_encoder_spec="[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] + [(512,2,2)]",
        conv_pos_width=95,
        conv_pos_groups=16,
        conv_pos_depth=5,
        conv_pos_pre_ln=False,
        **kwargs,
    ):
        # A serialized config carries its own "type" entry. Drop it so that rebuilding from such a dict does
        # not raise "got multiple values for keyword argument 'type'"; the modality is fixed for this class.
        kwargs.pop("type", None)
        super().__init__(type="AUDIO", **kwargs)
        self.extractor_mode = extractor_mode
        self.feature_encoder_spec = feature_encoder_spec
        self.conv_pos_width = conv_pos_width
        self.conv_pos_groups = conv_pos_groups
        self.conv_pos_depth = conv_pos_depth
        self.conv_pos_pre_ln = conv_pos_pre_ln
279
+
280
+
281
class PantagruelTextConfig(PantagruelModalityConfig):
    """
    Configuration including common args and args specific to text-only pre-training.

    The modality `type` is always `"TEXT"`. The text-specific arguments below are stored unchanged as
    attributes of the same name; all other keyword arguments go to `PantagruelModalityConfig`.

    Args:
        vocab_size (`int`, *optional*, defaults to 50000): Stored as-is.
        unk_token_id (`int`, *optional*, defaults to 3): Stored as-is.
        bos_token_id (`int`, *optional*, defaults to 0): Stored as-is.
        eos_token_id (`int`, *optional*, defaults to 2): Stored as-is.
        pad_token_id (`int`, *optional*, defaults to 1): Stored as-is.
        max_source_positions (`int`, *optional*, defaults to 512): Stored as-is.
        learned_pos (`bool`, *optional*, defaults to `True`): Stored as-is.
        dropout (`float`, *optional*, defaults to 0.1): Stored as-is.
        no_scale_embedding (`bool`, *optional*, defaults to `True`): Stored as-is.
        layernorm_embedding (`bool`, *optional*, defaults to `True`): Stored as-is.
        no_token_positional_embeddings (`bool`, *optional*, defaults to `False`): Stored as-is.
    """
    def __init__(
        self,
        vocab_size=50000,
        unk_token_id=3,
        bos_token_id=0,
        eos_token_id=2,
        pad_token_id=1,
        max_source_positions=512,
        learned_pos=True,
        dropout=0.1,
        no_scale_embedding=True,
        layernorm_embedding=True,
        no_token_positional_embeddings=False,
        **kwargs,
    ):
        # A serialized config carries its own "type" entry. Drop it so that rebuilding from such a dict does
        # not raise "got multiple values for keyword argument 'type'"; the modality is fixed for this class.
        kwargs.pop("type", None)
        super().__init__(type="TEXT", **kwargs)
        # The token ids are assigned after the parent constructor ran, so the explicit arguments above are
        # the values that end up on the config.
        self.vocab_size = vocab_size
        self.unk_token_id = unk_token_id
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id
        self.max_source_positions = max_source_positions
        self.learned_pos = learned_pos
        self.dropout = dropout
        self.no_scale_embedding = no_scale_embedding
        self.layernorm_embedding = layernorm_embedding
        self.no_token_positional_embeddings = no_token_positional_embeddings
312
+
313
+
314
class PantagruelModalitiesConfig(MyPretrainedConfig):
    """
    Container holding one config per modality, exposed as the `audio` and `text` attributes.

    Args:
        audio_config (`PantagruelAudioConfig`, *optional*):
            Audio modality config, stored as `self.audio`. A fresh default `PantagruelAudioConfig` is created
            when omitted.
        text_config (`PantagruelTextConfig`, *optional*):
            Text modality config, stored as `self.text`. A fresh default `PantagruelTextConfig` is created
            when omitted.
        kwargs: Forwarded to the parent constructor.
    """
    def __init__(
        self,
        audio_config=None,
        text_config=None,
        **kwargs
    ):
        super().__init__(**kwargs)
        # Build the defaults per instance. Config objects used as default argument values are created once
        # and shared, and `update`/`from_dict` modify nested configs in place, so one loaded config would
        # otherwise leak its values into every other instance.
        self.audio = PantagruelAudioConfig() if audio_config is None else audio_config
        self.text = PantagruelTextConfig() if text_config is None else text_config
324
+
325
+
326
class PantagruelUniConfig(MyPretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`PantagruelUniModel`]. It is used to instantiate
    a PantagruelUniModel model according to the specified arguments, defining the model architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        depth (`int`, *optional*, defaults to 12):
            Number of Transformer layers in the encoder. Also exposed as `num_layers`, `n_layers` and
            `num_hidden_layers`.
        start_drop_path_rate (`float`, *optional*, defaults to 0.0): Stored as-is.
        end_drop_path_rate (`float`, *optional*, defaults to 0.0): Stored as-is.
        num_heads (`int`, *optional*, defaults to 12): Stored as-is.
        norm_eps (`float`, *optional*, defaults to 1e-5): Stored as-is.
        norm_affine (`bool`, *optional*, defaults to `True`): Stored as-is.
        encoder_dropout (`float`, *optional*, defaults to 0.1): Stored as-is.
        post_mlp_drop (`float`, *optional*, defaults to 0.1): Stored as-is.
        attention_dropout (`float`, *optional*, defaults to 0.1): Stored as-is.
        activation_dropout (`float`, *optional*, defaults to 0.0): Stored as-is.
        dropout_input (`float`, *optional*, defaults to 0.0): Stored as-is.
        layerdrop (`float`, *optional*, defaults to 0.0): Stored as-is.
        embed_dim (`int`, *optional*, defaults to 768):
            Stored as `embed_dim` and also exposed as `hidden_size`.
        mlp_ratio (`float`, *optional*, defaults to 4.0): Stored as-is.
        layer_norm_first (`bool`, *optional*, defaults to `False`): Stored as-is.
        end_of_block_targets (`bool`, *optional*, defaults to `False`): Stored as-is.
        clone_batch (`int`, *optional*, defaults to 1): Stored as-is.
        log_norms (`bool`, *optional*, defaults to `True`): Stored as-is.
        modalities (`PantagruelModalitiesConfig`, *optional*):
            Per-modality settings. A fresh default `PantagruelModalitiesConfig` (with fresh audio and text
            configs) is created when omitted.
        supported_modality (`str`, *optional*, defaults to `"AUDIO"`): Stored as-is.
        classifier_dropout (`float`, *optional*): Stored as-is.

    Example:

    ```python
    >>> from configuration_pantagruel_uni import PantagruelUniConfig
    >>> from modeling_pantagruel_uni import PantagruelUniModel

    >>> # Initializing a PantagruelUniConfig for audio
    >>> configuration = PantagruelUniConfig()

    >>> # Initializing a model (with random weights) with the configuration
    >>> model = PantagruelUniModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "pantagruel_uni"

    def __init__(
        self,
        depth=12,
        start_drop_path_rate=0.0,
        end_drop_path_rate=0.0,
        num_heads=12,
        norm_eps=1e-5,
        norm_affine=True,
        encoder_dropout=0.1,
        post_mlp_drop=0.1,
        attention_dropout=0.1,
        activation_dropout=0.0,
        dropout_input=0.0,
        layerdrop=0.0,
        embed_dim=768,
        mlp_ratio=4.0,
        layer_norm_first=False,
        end_of_block_targets=False,
        clone_batch=1,
        log_norms=True,
        modalities=None,
        supported_modality="AUDIO",
        classifier_dropout=None,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.depth = depth
        self.start_drop_path_rate = start_drop_path_rate
        self.end_drop_path_rate = end_drop_path_rate

        self.num_heads = num_heads
        self.norm_eps = norm_eps
        self.norm_affine = norm_affine
        self.post_mlp_drop = post_mlp_drop
        self.encoder_dropout = encoder_dropout
        self.attention_dropout = attention_dropout
        self.activation_dropout = activation_dropout
        self.dropout_input = dropout_input
        self.layerdrop = layerdrop
        self.embed_dim = embed_dim
        self.mlp_ratio = mlp_ratio

        self.layer_norm_first = layer_norm_first
        self.end_of_block_targets = end_of_block_targets
        self.clone_batch = clone_batch
        self.log_norms = log_norms

        if modalities is None:
            # Build the nested configs per instance. A config object used as a default argument value is
            # created once and shared, and `update`/`from_dict` modify nested configs in place, so loading
            # one checkpoint would otherwise change the defaults seen by every later config.
            modalities = PantagruelModalitiesConfig(
                audio_config=PantagruelAudioConfig(),
                text_config=PantagruelTextConfig(),
            )
        self.modalities = modalities
        self.supported_modality = supported_modality

        # Attributes for hopsparser
        self.hidden_size = embed_dim
        self.num_layers = depth
        self.n_layers = depth
        self.num_hidden_layers = depth

        self.classifier_dropout = classifier_dropout

        # Maps the Auto* entry points to the custom classes shipped next to the checkpoint.
        self.auto_map = {
            'AutoConfig': 'configuration_pantagruel_uni.PantagruelUniConfig',
            'AutoModel': 'modeling_pantagruel_uni.PantagruelUniModel',
            'AutoModelForMaskedLM': 'modeling_pantagruel_uni.PantagruelUniForMaskedLM',
            'AutoModelForSequenceClassification': 'modeling_pantagruel_uni.PantagruelUniForSequenceClassification',
            'AutoModelForMultipleChoice': 'modeling_pantagruel_uni.PantagruelUniForMultipleChoice',
            'AutoModelForTokenClassification': 'modeling_pantagruel_uni.PantagruelUniForTokenClassification',
            'AutoModelForQuestionAnswering': 'modeling_pantagruel_uni.PantagruelUniForQuestionAnswering',
        }
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ba6004aeb8e399c4ce6d312d18943bc476d9b4da5cf5df1708a1312ba05cacfb
3
+ size 653850992
modeling_pantagruel_uni.py ADDED
@@ -0,0 +1,1964 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ #
3
+ # Copyright (c) Facebook, Inc. and its affiliates.
4
+ #
5
+ # This source code is licensed under the MIT license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+ #
8
+ # Copyright 2022 the HuggingFace Inc. team. All rights reserved.
9
+ #
10
+ # Licensed under the Apache License, Version 2.0 (the "License");
11
+ # you may not use this file except in compliance with the License.
12
+ # You may obtain a copy of the License at
13
+ #
14
+ # http://www.apache.org/licenses/LICENSE-2.0
15
+ #
16
+ # Unless required by applicable law or agreed to in writing, software
17
+ # distributed under the License is distributed on an "AS IS" BASIS,
18
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19
+ # See the License for the specific language governing permissions and
20
+ # limitations under the License.
21
+
22
+ # Copyright from Fairseq
23
+
24
+ """ PantagruelUni model."""
25
+ import math
26
+ import warnings
27
+ from typing import Optional, Tuple, Dict, List, Callable, Any, Union
28
+ from functools import partial
29
+ from dataclasses import dataclass
30
+
31
+ import numpy as np
32
+
33
+ import torch
34
+ import torch.nn.functional as F
35
+ from torch import nn
36
+ from torch import Tensor
37
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
38
+
39
+ from transformers import PreTrainedModel
40
+ from transformers.utils import (
41
+ ModelOutput, TransformersKwargs, auto_docstring
42
+ )
43
+ from transformers.activations import ACT2FN, gelu
44
+ from transformers.modeling_attn_mask_utils import (
45
+ _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa
46
+ )
47
+ from transformers.utils.generic import can_return_tuple
48
+ from transformers.processing_utils import Unpack
49
+ from transformers.modeling_outputs import (
50
+ MaskedLMOutput,
51
+ MultipleChoiceModelOutput,
52
+ QuestionAnsweringModelOutput,
53
+ SequenceClassifierOutput,
54
+ TokenClassifierOutput,
55
+ )
56
+ from .configuration_pantagruel_uni import (
57
+ PantagruelUniConfig,
58
+ PantagruelModalityConfig,
59
+ PantagruelAudioConfig,
60
+ PantagruelTextConfig,
61
+ )
62
+
63
+ from .utils_pantagruel_uni import (
64
+ _learned_alibi_bias,
65
+ gather_unmasked,
66
+ gather_unmasked_mask,
67
+ masked_alibi,
68
+ random_masking,
69
+ get_alibi_bias,
70
+ compute_mask_indices,
71
+ index_put,
72
+ MaskInfo, MaskSeed,
73
+ make_positions,
74
+ )
75
+
76
+
77
@dataclass
class PantagruelUniBaseModelOutput(ModelOutput):
    """Output container returned by the base PantagruelUni model.

    Every field is optional and defaults to ``None``.
    """

    # Output of the encoder-only model.
    last_hidden_state: Optional[torch.FloatTensor] = None
    # Pooled output for text tasks, which is the first token representation
    # followed by a dense layer and activation function.
    pooler_output: Optional[torch.FloatTensor] = None
    # Features before the Transformer encoder.
    local_features: Optional[torch.FloatTensor] = None
    # NOTE(review): presumably one tensor per layer -- confirm against the model's forward.
    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
    # NOTE(review): presumably one attention tensor per layer -- confirm against the model's forward.
    attentions: Optional[tuple[torch.FloatTensor, ...]] = None
84
+
85
+
86
+ #################################################
87
+ ### modeling_pantagruel_uni_base.py
88
+ # copied from fairseq.modules.grad_multiply
89
class GradMultiply(torch.autograd.Function):
    """Pass the input through unchanged while rescaling its gradient.

    ``forward`` returns a new tensor built from ``x``; ``backward`` multiplies
    the incoming gradient by the ``scale`` that was given to ``forward``.
    """

    @staticmethod
    def forward(ctx, x, scale):
        # Remember the factor so backward() can apply it.
        ctx.scale = scale
        return x.new(x)

    @staticmethod
    def backward(ctx, grad):
        # One entry per forward() input: a scaled gradient for ``x`` and
        # nothing for the non-tensor ``scale`` argument.
        scaled_grad = grad * ctx.scale
        return scaled_grad, None
99
+
100
+
101
+ # Copied from fairseq.modules.transpose_last.py
102
class TransposeLast(nn.Module):
    """Swap the last dimension of a tensor with ``tranpose_dim``.

    When ``deconstruct_idx`` is not ``None`` the input is first indexed with
    it (e.g. to pick one element out of a tuple) before transposing.
    """

    def __init__(self, deconstruct_idx=None, tranpose_dim=-2):
        super().__init__()
        self.tranpose_dim = tranpose_dim
        self.deconstruct_idx = deconstruct_idx

    def forward(self, x):
        selected = x if self.deconstruct_idx is None else x[self.deconstruct_idx]
        return selected.transpose(self.tranpose_dim, -1)
112
+
113
+
114
+ # Copied from fairseq.modules.layer_norm.py
115
class Fp32LayerNorm(nn.LayerNorm):
    """LayerNorm that computes in float32 and casts back to the input dtype."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, input):
        # The affine parameters are optional; cast them only when present.
        weight = None if self.weight is None else self.weight.float()
        bias = None if self.bias is None else self.bias.float()
        normed = F.layer_norm(
            input.float(), self.normalized_shape, weight, bias, self.eps
        )
        return normed.type_as(input)
128
+
129
+
130
def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True):
    """Build a plain ``torch.nn.LayerNorm`` with the given settings."""
    layer = torch.nn.LayerNorm(
        normalized_shape, eps=eps, elementwise_affine=elementwise_affine
    )
    return layer
132
+
133
+
134
+ # Copied from fairseq.modules.fp32_group_norm.py
135
class Fp32GroupNorm(nn.GroupNorm):
    """GroupNorm that computes in float32 and casts back to the input dtype."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, input):
        # The affine parameters are optional; cast them only when present.
        weight = None if self.weight is None else self.weight.float()
        bias = None if self.bias is None else self.bias.float()
        normed = F.group_norm(input.float(), self.num_groups, weight, bias, self.eps)
        return normed.type_as(input)
148
+
149
+
150
+ # Copied from fairseq.modules.same_pad.py
151
class SamePad(nn.Module):
    """Trim trailing positions from the last axis of a 3-D tensor.

    The number of removed positions is ``kernel_size - 1`` when ``causal`` is
    set; otherwise it is 1 for an even ``kernel_size`` and 0 for an odd one.
    """

    def __init__(self, kernel_size, causal=False):
        super().__init__()
        if causal:
            trim = kernel_size - 1
        elif kernel_size % 2 == 0:
            trim = 1
        else:
            trim = 0
        self.remove = trim

    def forward(self, x):
        if self.remove <= 0:
            return x
        return x[:, :, : -self.remove]
163
+
164
+
165
+ # Copied from fairseq.models.wav2vec.wav2vec2.py
166
class ConvFeatureExtractionModel(nn.Module):
    """Stack of 1-D convolution blocks applied to a ``B x T`` input.

    Args:
        conv_layers: one ``(dim, kernel_size, stride)`` triple per block.
        dropout: dropout probability applied after every convolution.
        mode: ``"default"`` puts a group norm on the first block only;
            ``"layer_norm"`` puts a layer norm on every block.
        conv_bias: whether the convolutions carry a bias term.

    Raises:
        AssertionError: if ``mode`` is unknown or a layer spec does not have
            exactly three entries.
    """

    def __init__(
        self,
        conv_layers: List[Tuple[int, int, int]],
        dropout: float = 0.0,
        mode: str = "default",
        conv_bias: bool = False,
    ):
        super().__init__()

        assert mode in {"default", "layer_norm"}

        def block(
            n_in,
            n_out,
            k,
            stride,
            is_layer_norm=False,
            is_group_norm=False,
            conv_bias=False,
        ):
            def make_conv():
                conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias)
                nn.init.kaiming_normal_(conv.weight)
                return conv

            assert not (
                is_layer_norm and is_group_norm
            ), "layer norm and group norm are exclusive"

            # The normalisation width is this block's own output width
            # (``n_out``), not a variable captured from the caller's loop.
            if is_layer_norm:
                return nn.Sequential(
                    make_conv(),
                    nn.Dropout(p=dropout),
                    nn.Sequential(
                        TransposeLast(),
                        Fp32LayerNorm(n_out, elementwise_affine=True),
                        TransposeLast(),
                    ),
                    nn.GELU(),
                )
            elif is_group_norm:
                return nn.Sequential(
                    make_conv(),
                    nn.Dropout(p=dropout),
                    Fp32GroupNorm(n_out, n_out, affine=True),
                    nn.GELU(),
                )
            else:
                return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU())

        in_d = 1  # the input is unsqueezed to a single channel in forward()
        self.conv_layers = nn.ModuleList()
        for i, cl in enumerate(conv_layers):
            assert len(cl) == 3, "invalid conv definition: " + str(cl)
            (dim, k, stride) = cl

            self.conv_layers.append(
                block(
                    in_d,
                    dim,
                    k,
                    stride,
                    is_layer_norm=mode == "layer_norm",
                    is_group_norm=mode == "default" and i == 0,
                    conv_bias=conv_bias,
                )
            )
            in_d = dim

    def forward(self, x):
        """Run every block in order and return the resulting ``B x C x T'`` tensor."""
        # BxT -> BxCxT
        x = x.unsqueeze(1)

        for conv in self.conv_layers:
            x = conv(x)

        return x
245
+
246
+
247
+ # copied from fairseq.examples.data2vec.models.modalities.modules
248
+ class AltAttention(nn.Module):
249
+ def __init__(
250
+ self,
251
+ dim,
252
+ num_heads=8,
253
+ qkv_bias=False,
254
+ qk_scale=None,
255
+ attn_drop=0.0,
256
+ proj_drop=0.0,
257
+ cosine_attention=False,
258
+ ):
259
+ super().__init__()
260
+ self.num_heads = num_heads
261
+ head_dim = dim // num_heads
262
+ self.scale = qk_scale or head_dim ** -0.5
263
+
264
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
265
+ # self.attn_drop = nn.Dropout(attn_drop)
266
+ self.attn_drop = attn_drop
267
+ self.proj = nn.Linear(dim, dim)
268
+ # self.proj_drop = nn.Dropout(proj_drop)
269
+ self.proj_drop = proj_drop
270
+
271
+ self.cosine_attention = cosine_attention
272
+
273
+ if cosine_attention:
274
+ self.logit_scale = nn.Parameter(
275
+ torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True
276
+ )
277
+
278
+ def forward(self, x, padding_mask=None, alibi_bias=None, fast=True):
279
+ B, N, C = x.shape
280
+ qkv = (
281
+ self.qkv(x)
282
+ .reshape(B, N, 3, self.num_heads, C // self.num_heads)
283
+ .permute(2, 0, 3, 1, 4) # qkv x B x H x L x D
284
+ )
285
+ q, k, v = (
286
+ qkv[0],
287
+ qkv[1],
288
+ qkv[2],
289
+ ) # make torchscript happy (cannot use tensor as tuple)
290
+
291
+ dtype = q.dtype
292
+
293
+ attn = None
294
+ if not fast:
295
+ if self.cosine_attention:
296
+ # cosine attention
297
+ attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)
298
+ logit_scale = torch.clamp(
299
+ self.logit_scale, max=torch.log(torch.tensor(1.0 / 0.01))
300
+ ).exp()
301
+ attn = attn * logit_scale
302
+ else:
303
+ q = q * self.scale
304
+ attn = q @ k.transpose(-2, -1) # B x C//H x L x L
305
+
306
+ if alibi_bias is not None:
307
+ attn = attn.type_as(alibi_bias)
308
+ attn[:, : alibi_bias.size(1)] += alibi_bias
309
+
310
+ if padding_mask is not None and padding_mask.any():
311
+ attn = attn.masked_fill(
312
+ padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
313
+ float("-inf"),
314
+ )
315
+
316
+ attn = attn.softmax(dim=-1, dtype=torch.float32).to(dtype=dtype)
317
+ # attn = self.attn_drop(attn)
318
+ attn = F.dropout(attn, p=self.attn_drop if self.training else 0.0)
319
+ x = (attn @ v).transpose(1, 2)
320
+ else:
321
+ # Using pytorch 2's sdpa
322
+ assert not self.cosine_attention, "Not support cosine attention yet"
323
+ # Integrate padding_mask and alibi_bias
324
+ if padding_mask is not None and padding_mask.any():
325
+ if alibi_bias is not None:
326
+ padding_mask = alibi_bias.masked_fill(
327
+ padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
328
+ float("-inf"),
329
+ ).to(dtype=dtype)
330
+ else:
331
+ padding_mask = padding_mask.unsqueeze(1).unsqueeze(2).to(
332
+ torch.bool).to(dtype=dtype)
333
+ else:
334
+ if alibi_bias is not None:
335
+ padding_mask = alibi_bias.to(dtype=dtype)
336
+ else:
337
+ padding_mask = None
338
+
339
+ x = F.scaled_dot_product_attention(q, k, v,
340
+ attn_mask=padding_mask,
341
+ dropout_p=self.attn_drop if self.training else 0.0,
342
+ scale=self.scale).transpose(1, 2)
343
+
344
+ x = x.reshape(B, N, C)
345
+ x = self.proj(x)
346
+ x = F.dropout(x, p=self.proj_drop if self.training else 0.0)
347
+
348
+ return x, attn
349
+
350
+
351
+ # copied from fairseq.examples.data2vec.models.modalities.modules.py
352
class AltBlock(nn.Module):
    """Transformer block: ``AltAttention`` followed by an MLP.

    ``layer_norm_first`` selects pre-norm vs post-norm ordering. ``forward``
    returns ``(x, t, attn)`` where ``t`` is the raw MLP output when
    ``ffn_targets`` is set and the block output otherwise.
    """

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        mlp_drop=0.0,
        post_mlp_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        layer_norm_first=True,
        ffn_targets=False,
        cosine_attention=False,
    ):
        super().__init__()

        # timm is only needed when a block is actually built.
        from timm.models.vision_transformer import DropPath, Mlp

        self.layer_norm_first = layer_norm_first
        self.ffn_targets = ffn_targets

        # Sub-modules are created in a fixed order: norm1, attn, drop_path,
        # norm2, mlp, post_mlp_dropout.
        self.norm1 = norm_layer(dim)
        self.attn = AltAttention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
            cosine_attention=cosine_attention,
        )

        if drop_path > 0.0:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=mlp_drop,
        )
        self.post_mlp_dropout = nn.Dropout(post_mlp_drop, inplace=False)

    def forward(self, x, padding_mask=None, alibi_bias=None, fast=True):
        if self.layer_norm_first:
            attn_out, attn_weights = self.attn(
                self.norm1(x), padding_mask, alibi_bias, fast=fast
            )
            x = x + self.drop_path(attn_out)
            # The MLP output is used both as the residual base and as the
            # dropout input.
            mlp_out = self.mlp(self.norm2(x))
            x = mlp_out + self.drop_path(self.post_mlp_dropout(mlp_out))
        else:
            attn_out, attn_weights = self.attn(x, padding_mask, alibi_bias, fast=fast)
            x = x + self.drop_path(attn_out)
            residual = self.norm1(x)
            mlp_out = self.mlp(residual)
            x = self.norm2(residual + self.drop_path(self.post_mlp_dropout(mlp_out)))

        target = mlp_out if self.ffn_targets else x
        return x, target, attn_weights
420
+
421
+
422
+ # copied from fairseq.data2vec.models.modalities.modules
423
class BlockEncoder(nn.Module):
    """Run a sequence of blocks with optional norm, input dropout and layerdrop.

    ``norm_layer`` is applied before the blocks when ``layer_norm_first`` is
    false and after them when it is true. During training each block is
    skipped when a fresh ``np.random.random()`` draw is not above
    ``layerdrop``.
    """

    def __init__(self, blocks, norm_layer, layer_norm_first, layerdrop, dropout):
        super().__init__()
        self.blocks = blocks
        self.norm = norm_layer
        self.layer_norm_first = layer_norm_first
        self.layerdrop = layerdrop
        self.dropout = nn.Dropout(dropout, inplace=True)

    def forward(self, x, padding_mask, alibi_bias, alibi_scale):
        has_norm = self.norm is not None
        if has_norm and not self.layer_norm_first:
            x = self.norm(x)

        x = self.dropout(x)

        rescale_bias = alibi_bias is not None and alibi_scale is not None
        for layer_idx, layer in enumerate(self.blocks):
            # Layerdrop: only draw a random number while training with a
            # non-zero drop rate.
            if (
                self.training
                and self.layerdrop != 0
                and np.random.random() <= self.layerdrop
            ):
                continue

            bias = alibi_bias
            if rescale_bias:
                # Either one scale per layer or a single shared scale.
                if alibi_scale.size(0) > 1:
                    factor = alibi_scale[layer_idx]
                else:
                    factor = alibi_scale.squeeze(0)
                bias = bias * factor.type_as(bias)
            x, _, _ = layer(x, padding_mask, bias)

        if has_norm and self.layer_norm_first:
            x = self.norm(x)

        return x
458
+
459
+
460
class ModalitySpecificEncoder(nn.Module):
    """Base encoder that turns raw modality input into contextualized features.

    The pipeline is: ``local_encoder`` -> ``project_features`` -> optional
    positional encoders -> optional masking -> optional extra tokens ->
    ``context_encoder``. Subclasses supply the concrete sub-modules.
    """

    def __init__(
        self,
        modality_cfg: PantagruelModalityConfig,
        embed_dim: int,
        local_encoder: nn.Module,
        project_features: nn.Module,
        fixed_positional_encoder: Optional[nn.Module],
        relative_positional_encoder: Optional[nn.Module],
        context_encoder: nn.Module,
        decoder: nn.Module,
        get_alibi_bias: Optional[Callable[[int, int, str, str], torch.Tensor]],
    ):
        super().__init__()

        self.modality_cfg = modality_cfg
        self.local_encoder = local_encoder
        self.project_features = project_features
        self.fixed_positional_encoder = fixed_positional_encoder
        self.relative_positional_encoder = relative_positional_encoder
        self.context_encoder = context_encoder

        # The ``decoder`` argument is accepted but not stored.
        self.decoder = None
        # ALiBi is only kept when the config enables it for the encoder.
        self.get_alibi_bias = get_alibi_bias if modality_cfg.use_alibi_encoder else None

        self.local_grad_mult = self.modality_cfg.local_grad_mult

        # Optional learned tokens that forward processing prepends to the sequence.
        self.extra_tokens = None
        if modality_cfg.num_extra_tokens > 0:
            self.extra_tokens = nn.Parameter(
                torch.zeros(1, modality_cfg.num_extra_tokens, embed_dim)
            )
            if not modality_cfg.init_extra_token_zero:
                nn.init.normal_(self.extra_tokens)
            elif self.extra_tokens.size(1) > 1:
                # Keep the first token at zero, randomise the remaining ones.
                nn.init.normal_(self.extra_tokens[:, 1:])

        # Scale applied to the ALiBi bias: one entry per layer and/or per head
        # when the config asks for it, otherwise a single value.
        self.alibi_scale = None
        if self.get_alibi_bias is not None:
            self.alibi_scale = nn.Parameter(
                torch.full(
                    (
                        (modality_cfg.prenet_depth + modality_cfg.model_depth)
                        if modality_cfg.learned_alibi_scale_per_layer
                        else 1,
                        1,
                        self.modality_cfg.num_alibi_heads
                        if modality_cfg.learned_alibi_scale_per_head
                        else 1,
                        1,
                        1,
                    ),
                    modality_cfg.alibi_scale,
                    dtype=torch.float,
                ),
                requires_grad=modality_cfg.learned_alibi_scale,
            )

        # With a learned ALiBi bias, the bias is built once at ``alibi_max_pos``
        # and stored as a parameter; later lookups go through
        # ``_learned_alibi_bias`` bound to that parameter.
        if modality_cfg.learned_alibi and self.get_alibi_bias is not None:
            assert modality_cfg.alibi_max_pos is not None
            alibi_bias = self.get_alibi_bias(
                batch_size=1,
                time_steps=modality_cfg.alibi_max_pos,
                heads=modality_cfg.num_alibi_heads,
                scale=1.0,
                dtype=torch.float,
                device="cpu",
            )
            self.alibi_bias = nn.Parameter(alibi_bias)
            self.get_alibi_bias = partial(
                _learned_alibi_bias, alibi_bias=self.alibi_bias
            )

    # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureEncoder._freeze_parameters
    def _freeze_parameters(self):
        """Disable gradients for every parameter of this module."""
        for param in self.parameters():
            param.requires_grad = False
        self._requires_grad = False

    def convert_padding_mask(self, x, padding_mask):
        """Return ``padding_mask`` unchanged; subclasses may override this."""
        return padding_mask

    def local_features(self, features):
        """Run the local encoder and the feature projection.

        ``local_grad_mult`` controls the gradient reaching the local encoder:
        ``1.0`` leaves it untouched, another positive value routes the output
        through ``GradMultiply``, and a value ``<= 0`` runs the local encoder
        under ``torch.no_grad()``.
        """
        if self.local_grad_mult > 0:
            if self.local_grad_mult == 1.0:
                x = self.local_encoder(features)
            else:
                x = GradMultiply.apply(
                    self.local_encoder(features), self.local_grad_mult
                )
        else:
            with torch.no_grad():
                x = self.local_encoder(features)

        x = self.project_features(x)
        return x

    def contextualized_features(
        self,
        x,
        padding_mask,
        mask,
        remove_masked,
        clone_batch: int = 1,
        mask_seeds: Optional[torch.Tensor] = None,
        precomputed_mask=None,
    ):
        """Add positions, optionally mask, and run the context encoder.

        Returns a dict with the keys ``x``, ``local_features``,
        ``padding_mask``, ``alibi_bias``, ``alibi_scale`` and ``encoder_mask``.
        """

        if padding_mask is not None:
            padding_mask = self.convert_padding_mask(x, padding_mask)

        local_features = x
        if mask and clone_batch == 1:
            local_features = local_features.clone()

        orig_B, orig_T, _ = x.shape
        pre_mask_B = orig_B
        mask_info = None

        x_pos = None
        if self.fixed_positional_encoder is not None:
            x = x + self.fixed_positional_encoder(x, padding_mask)

        if mask:
            if clone_batch > 1:
                # Each sample is repeated ``clone_batch`` times before masking.
                x = x.repeat_interleave(clone_batch, 0)
                if mask_seeds is not None:
                    # Offset the ids of the clones so that every clone gets a
                    # different id; the first clone keeps the original id.
                    clone_hash = [
                        int(hash((mask_seeds.seed, ind)) % 1e10)
                        for ind in range(clone_batch - 1)
                    ]
                    clone_hash = torch.tensor([0] + clone_hash).long().view(1, -1)

                    # NOTE(review): ``id`` shadows the Python builtin inside this method.
                    id = mask_seeds.ids
                    id = id.repeat_interleave(clone_batch, 0)
                    id = id.view(-1, clone_batch) + clone_hash.to(id)
                    id = id.view(-1)
                    mask_seeds = MaskSeed(
                        seed=mask_seeds.seed, update=mask_seeds.update, ids=id
                    )
                if padding_mask is not None:
                    padding_mask = padding_mask.repeat_interleave(clone_batch, 0)

            x, mask_info = self.compute_mask(
                x,
                padding_mask,
                mask_seed=mask_seeds,
                apply=self.relative_positional_encoder is not None or not remove_masked,
                precomputed_mask=precomputed_mask,
            )

        if self.relative_positional_encoder is not None:
            x_pos = self.relative_positional_encoder(x)

        masked_padding_mask = padding_mask
        if mask and remove_masked:
            # NOTE(review): ``mask_info`` is None when compute_mask() ends up
            # with ``mask_prob <= 0``; this line would then raise -- confirm
            # that callers never combine that with ``remove_masked``.
            x = mask_info.x_unmasked
            if x_pos is not None:
                x = x + gather_unmasked(x_pos, mask_info)

            if padding_mask is not None and padding_mask.any():
                masked_padding_mask = gather_unmasked_mask(padding_mask, mask_info)
                if not masked_padding_mask.any():
                    masked_padding_mask = None
            else:
                masked_padding_mask = None

        elif x_pos is not None:
            x = x + x_pos

        alibi_bias = None
        alibi_scale = self.alibi_scale

        if self.get_alibi_bias is not None:
            alibi_bias = self.get_alibi_bias(
                batch_size=pre_mask_B,
                time_steps=orig_T,
                heads=self.modality_cfg.num_alibi_heads,
                dtype=torch.float32,
                device=x.device,
            )

            if alibi_scale is not None:
                alibi_scale = alibi_scale.clamp_min(0)
                if alibi_scale.size(0) == 1:
                    # A single shared scale is folded into the bias right away.
                    alibi_bias = alibi_bias * alibi_scale.squeeze(0).type_as(alibi_bias)
                    alibi_scale = None

            if clone_batch > 1:
                alibi_bias = alibi_bias.repeat_interleave(clone_batch, 0)

            if mask_info is not None and remove_masked:
                alibi_bias = masked_alibi(alibi_bias, mask_info)

        if self.extra_tokens is not None:
            num = self.extra_tokens.size(1)
            x = torch.cat([self.extra_tokens.expand(x.size(0), -1, -1), x], dim=1)
            if masked_padding_mask is not None:
                # B x T
                masked_padding_mask = F.pad(masked_padding_mask, (num, 0))
            if alibi_bias is not None:
                # B x H x T x T
                alibi_bias = F.pad(alibi_bias, (num, 0, num, 0))

        # Only the first ``prenet_depth`` per-layer scales belong to this encoder.
        x = self.context_encoder(
            x,
            masked_padding_mask,
            alibi_bias,
            alibi_scale[: self.modality_cfg.prenet_depth]
            if alibi_scale is not None
            else None,
        )

        return {
            "x": x,
            "local_features": local_features,
            "padding_mask": masked_padding_mask,
            "alibi_bias": alibi_bias,
            "alibi_scale": alibi_scale[self.modality_cfg.prenet_depth :]
            if alibi_scale is not None and alibi_scale.size(0) > 1
            else alibi_scale,
            "encoder_mask": mask_info,
        }

    def forward(
        self,
        features,
        padding_mask,
        mask: bool,
        remove_masked: bool,
        clone_batch: int = 1,
        mask_seeds: Optional[torch.Tensor] = None,
        precomputed_mask=None,
    ):
        """Compute local features, then delegate to ``contextualized_features``."""
        x = self.local_features(features)
        return self.contextualized_features(
            x,
            padding_mask,
            mask,
            remove_masked,
            clone_batch,
            mask_seeds,
            precomputed_mask,
        )

    def compute_mask(
        self,
        x,
        padding_mask,
        mask_seed: Optional[MaskSeed],
        apply,
        precomputed_mask,
    ):
        """Build the mask description for ``x`` and optionally apply it.

        Returns ``(x, mask_info)``; ``mask_info`` is ``None`` when no
        precomputed mask is given and the effective mask probability is not
        positive.
        """
        if precomputed_mask is not None:
            mask = precomputed_mask
            mask_info = self.make_maskinfo(x, mask)
        else:
            B, T, C = x.shape
            cfg = self.modality_cfg

            mask_prob = cfg.mask_prob

            # Optionally sample the probability from [mask_prob_min, mask_prob).
            if (
                cfg.mask_prob_min is not None
                and cfg.mask_prob_min >= 0
                and cfg.mask_prob_min < mask_prob
            ):
                mask_prob = np.random.uniform(cfg.mask_prob_min, mask_prob)

            if mask_prob > 0:
                if cfg.mask_length == 1:
                    mask_info = random_masking(x, mask_prob, mask_seed)
                else:
                    if self.modality_cfg.inverse_mask:
                        mask_prob = 1 - mask_prob

                    mask = compute_mask_indices(
                        (B, T),
                        padding_mask,
                        mask_prob,
                        cfg.mask_length,
                        min_masks=1,
                        require_same_masks=True,
                        mask_dropout=cfg.mask_dropout,
                        add_masks=cfg.add_masks,
                        seed=mask_seed.seed if mask_seed is not None else None,
                        epoch=mask_seed.update if mask_seed is not None else None,
                        indices=mask_seed.ids if mask_seed is not None else None,
                    )

                    mask = torch.from_numpy(mask).to(device=x.device)
                    if self.modality_cfg.inverse_mask:
                        mask = 1 - mask
                    mask_info = self.make_maskinfo(x, mask)
            else:
                mask_info = None

        if apply:
            x = self.apply_mask(x, mask_info)

        return x, mask_info

    def make_maskinfo(self, x, mask, shape=None):
        """Derive keep/restore indices (and the unmasked rows) from ``mask``.

        When ``shape`` is given it replaces ``x.shape`` and ``x_unmasked`` is
        left as ``None``.
        """
        if shape is None:
            B, T, D = x.shape
        else:
            B, T, D = shape

        mask = mask.to(torch.uint8)
        # Sorting the 0/1 mask puts the unmasked positions first.
        ids_shuffle = mask.argsort(dim=1)
        ids_restore = ids_shuffle.argsort(dim=1).unsqueeze(-1).expand(-1, -1, D)

        # NOTE(review): the keep length is taken from the first row only, so
        # every row presumably carries the same number of masked positions -- confirm.
        len_keep = T - mask[0].sum()
        if self.modality_cfg.keep_masked_pct > 0:
            len_keep += round((T - int(len_keep)) * self.modality_cfg.keep_masked_pct)

        ids_keep = ids_shuffle[:, :len_keep]

        if shape is not None:
            x_unmasked = None
        else:
            ids_keep = ids_keep.unsqueeze(-1).expand(-1, -1, D)
            x_unmasked = torch.gather(x, dim=1, index=ids_keep)

        mask_info = MaskInfo(
            x_unmasked=x_unmasked,
            mask=mask,
            ids_restore=ids_restore,
            ids_keep=ids_keep,
        )
        return mask_info

    def apply_mask(self, x, mask_info):
        """Overwrite masked time steps, then optionally mask whole channels.

        Masked positions are zeroed when ``encoder_zero_mask`` is set and
        otherwise replaced with normal noise of std ``mask_noise_std``.
        """
        cfg = self.modality_cfg
        B, T, C = x.shape

        if mask_info is not None:
            mask = mask_info.mask
            if cfg.encoder_zero_mask:
                x = x * (1 - mask.type_as(x).unsqueeze(-1))
            else:
                num_masks = mask.sum().item()
                masks = x.new_empty(num_masks, x.size(-1)).normal_(
                    0, cfg.mask_noise_std
                )
                x = index_put(x, mask, masks)
        if cfg.mask_channel_prob > 0:
            mask_channel = compute_mask_indices(
                (B, C),
                None,
                cfg.mask_channel_prob,
                cfg.mask_channel_length,
            )
            # The same channel mask is used for every time step.
            mask_channel = (
                torch.from_numpy(mask_channel)
                .to(x.device)
                .unsqueeze(1)
                .expand(-1, T, -1)
            )
            x = index_put(x, mask_channel, 0)
        return x
821
+
822
+
823
class AudioEncoder(ModalitySpecificEncoder):
    """Audio front end: convolutional feature extractor, convolutional
    positional encoder and a prenet of Transformer blocks."""

    modality_cfg: PantagruelAudioConfig

    def __init__(
        self,
        modality_cfg: PantagruelAudioConfig,
        embed_dim: int,
        make_block: Callable[[float], nn.ModuleList],
        norm_layer: Callable[[int], nn.LayerNorm],
        layer_norm_first: bool,
        alibi_biases: Dict,
    ):

        # NOTE(review): eval() executes whatever string the config carries.
        # The spec presumably uses list arithmetic, which ast.literal_eval
        # cannot parse -- only load configs from trusted sources.
        self.feature_enc_layers = eval(modality_cfg.feature_encoder_spec)
        # Output width of the last convolution layer.
        feature_embed_dim = self.feature_enc_layers[-1][0]

        local_encoder = ConvFeatureExtractionModel(
            conv_layers=self.feature_enc_layers,
            dropout=0.0,
            mode=modality_cfg.extractor_mode,
            conv_bias=False,
        )

        # B x C x T -> B x T x C, then normalise and project to ``embed_dim``.
        project_features = nn.Sequential(
            TransposeLast(),
            nn.LayerNorm(feature_embed_dim),
            nn.Linear(feature_embed_dim, embed_dim),
        )

        num_pos_layers = modality_cfg.conv_pos_depth
        # The total positional width is split across the layers, with a
        # minimum kernel size of 3.
        k = max(3, modality_cfg.conv_pos_width // num_pos_layers)

        positional_encoder = nn.Sequential(
            TransposeLast(),
            *[
                nn.Sequential(
                    nn.Conv1d(
                        embed_dim,
                        embed_dim,
                        kernel_size=k,
                        padding=k // 2,
                        groups=modality_cfg.conv_pos_groups,
                    ),
                    SamePad(k),
                    TransposeLast(),
                    LayerNorm(embed_dim, elementwise_affine=False),
                    TransposeLast(),
                    nn.GELU(),
                )
                for _ in range(num_pos_layers)
            ],
            TransposeLast(),
        )

        if modality_cfg.conv_pos_pre_ln:
            positional_encoder = nn.Sequential(LayerNorm(embed_dim), positional_encoder)

        # Drop-path rate grows linearly across the prenet blocks.
        dpr = np.linspace(
            modality_cfg.start_drop_path_rate,
            modality_cfg.end_drop_path_rate,
            modality_cfg.prenet_depth,
        )
        context_encoder = BlockEncoder(
            nn.ModuleList(make_block(dpr[i]) for i in range(modality_cfg.prenet_depth)),
            norm_layer(embed_dim) if not layer_norm_first else None,
            layer_norm_first,
            modality_cfg.prenet_layerdrop,
            modality_cfg.prenet_dropout,
        )

        decoder = None

        alibi_bias_fn = partial(get_alibi_bias, alibi_biases=alibi_biases)

        super().__init__(
            modality_cfg=modality_cfg,
            embed_dim=embed_dim,
            local_encoder=local_encoder,
            project_features=project_features,
            fixed_positional_encoder=None,
            relative_positional_encoder=positional_encoder,
            context_encoder=context_encoder,
            decoder=decoder,
            get_alibi_bias=alibi_bias_fn,
        )

    def convert_padding_mask(self, x, padding_mask):
        """Translate a sample-level padding mask into a frame-level one.

        ``x`` supplies the target ``(batch, frames)`` shape via
        ``x.shape[:2]``. Returns ``None`` when ``padding_mask`` is ``None``.
        """
        def get_feat_extract_output_lengths(input_lengths: torch.LongTensor):
            """
            Computes the output length of the convolutional layers
            """

            def _conv_out_length(input_length, kernel_size, stride):
                return torch.floor((input_length - kernel_size) / stride + 1)

            for i in range(len(self.feature_enc_layers)):
                input_lengths = _conv_out_length(
                    input_lengths,
                    self.feature_enc_layers[i][1],
                    self.feature_enc_layers[i][2],
                )

            return input_lengths.to(torch.long)

        if padding_mask is not None:
            # Number of non-padded input samples per batch element.
            input_lengths = (1 - padding_mask.long()).sum(-1)
            # apply conv formula to get real output_lengths
            output_lengths = get_feat_extract_output_lengths(input_lengths)

            if padding_mask.any():
                padding_mask = torch.zeros(x.shape[:2], dtype=x.dtype, device=x.device)

                # these two operations makes sure that all values
                # before the output lengths indices are attended to
                padding_mask[
                    (
                        torch.arange(padding_mask.shape[0], device=padding_mask.device),
                        output_lengths - 1,
                    )
                ] = 1
                padding_mask = (
                    1 - padding_mask.flip([-1]).cumsum(-1).flip([-1])
                ).bool()
            else:
                # Nothing is padded: return an all-False mask of the frame shape.
                padding_mask = torch.zeros(
                    x.shape[:2], dtype=torch.bool, device=x.device
                )

        return padding_mask
953
+
954
+
955
class LearnedPositionalEmbedding(nn.Embedding):
    """
    Learned positional embeddings with a fixed maximum size.

    Padding is handled in one of two ways: positions are offset by
    ``padding_idx``, or ``padding_idx`` is ``None`` and the caller passes the
    position ids to ``forward`` explicitly.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int):
        super().__init__(num_embeddings, embedding_dim, padding_idx)
        self.onnx_trace = False
        if self.padding_idx is None:
            self.max_positions = self.num_embeddings
        else:
            # Rows up to and including padding_idx are not usable positions.
            self.max_positions = self.num_embeddings - self.padding_idx - 1

    def forward(
        self,
        input: Tensor,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        positions: Optional[Tensor] = None,
    ):
        """Input is expected to be of size [bsz x seqlen]."""
        assert (positions is None) or (
            self.padding_idx is None
        ), "If positions is pre-computed then padding_idx should not be set."

        if positions is None:
            if incremental_state is None:
                positions = make_positions(
                    input, self.padding_idx, onnx_trace=self.onnx_trace
                )
            else:
                # Single-step decoding: every token shares the same position.
                # The int() cast is needed for some ONNX export cases.
                step = int(self.padding_idx + input.size(1))
                positions = torch.zeros(
                    (1, 1), device=input.device, dtype=input.dtype
                ).fill_(step)

        return F.embedding(
            positions,
            self.weight,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )
1002
+
1003
+
1004
class SinusoidalPositionalEmbedding(nn.Module):
    """Produce fixed (non-learned) sinusoidal positional embeddings of any length.

    The lookup table lives in the non-persistent buffer ``weights`` and is
    rebuilt on the fly whenever a longer sequence is seen. Positions are
    obtained from ``make_positions`` using ``padding_idx``; when a padding
    index is given, its row in the table is zeroed.
    """

    def __init__(self, embedding_dim, padding_idx, init_size=1024):
        """Build the initial table with ``init_size`` rows of width ``embedding_dim``."""
        super().__init__()
        self.embedding_dim = embedding_dim
        # Positions are offset by the padding index; fall back to 0 when unset.
        self.padding_idx = padding_idx if padding_idx is not None else 0
        # The raw (possibly None) padding_idx is forwarded so that no row is
        # zeroed when the caller did not specify a padding index.
        self.register_buffer(
            "weights",
            SinusoidalPositionalEmbedding.get_embedding(
                init_size, embedding_dim, padding_idx
            ),
            persistent=False,
        )
        self.max_positions = int(1e5)
        self.onnx_trace = False

    def prepare_for_onnx_export_(self):
        """Switch ``forward`` to the ONNX-traceable code paths."""
        self.onnx_trace = True

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        """Drop deprecated keys from ``state_dict`` before the default loading."""
        # Ignore some deprecated keys that were used in older versions
        deprecated_keys = ["weights", "_float_tensor"]
        for key in deprecated_keys:
            if prefix + key in state_dict:
                del state_dict[prefix + key]
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

    @staticmethod
    def get_embedding(
        num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None
    ):
        """Build sinusoidal embeddings.

        This matches the implementation in tensor2tensor, but differs slightly
        from the description in Section 3.5 of "Attention Is All You Need".

        Args:
            num_embeddings: number of rows (positions) in the table.
            embedding_dim: width of each row; an odd width gets one trailing
                zero column.
            padding_idx: if given, that row is set to zero.

        Returns:
            A float tensor of shape ``(num_embeddings, embedding_dim)`` whose
            first half of columns holds sines and second half cosines.
        """
        half_dim = embedding_dim // 2
        # Guard the denominator: embedding_dim 2 or 3 gives half_dim == 1 and
        # previously raised ZeroDivisionError. Larger widths are unaffected.
        emb = math.log(10000) / max(half_dim - 1, 1)
        emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
        emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(
            1
        ) * emb.unsqueeze(0)
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(
            num_embeddings, -1
        )
        if embedding_dim % 2 == 1:
            # zero pad
            emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
        if padding_idx is not None:
            emb[padding_idx, :] = 0
        return emb

    def forward(
        self,
        input,
        incremental_state: Optional[Any] = None,
        timestep: Optional[Tensor] = None,
        positions: Optional[Any] = None,
    ):
        """Input is expected to be of size [bsz x seqlen].

        Args:
            input: token ids; only its shape and (via ``make_positions``) its
                padding pattern are used.
            incremental_state: when not None, a single position embedding is
                returned and broadcast over the batch.
            timestep: optional tensor whose first element (+1) selects the
                position in the incremental path; otherwise ``seq_len`` is used.
            positions: accepted for interface compatibility; the value passed
                in is not read.

        Returns:
            Detached embeddings of shape ``(bsz, seq_len, dim)``, or
            ``(bsz, 1, dim)`` in the incremental path.
        """
        bspair = torch.onnx.operators.shape_as_tensor(input)
        bsz, seq_len = bspair[0], bspair[1]
        max_pos = self.padding_idx + 1 + seq_len
        if max_pos > self.weights.size(0):
            # expand embeddings if needed
            self.weights = SinusoidalPositionalEmbedding.get_embedding(
                max_pos, self.embedding_dim, self.padding_idx
            ).to(self.weights)

        if incremental_state is not None:
            # positions is the same for every token when decoding a single step
            pos = timestep.view(-1)[0] + 1 if timestep is not None else seq_len
            if self.onnx_trace:
                return (
                    self.weights.index_select(index=self.padding_idx + pos, dim=0)
                    .unsqueeze(1)
                    .repeat(bsz, 1, 1)
                )
            return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1)

        positions = make_positions(
            input, self.padding_idx, onnx_trace=self.onnx_trace
        )
        if self.onnx_trace:
            flat_embeddings = self.weights.detach().index_select(0, positions.view(-1))
            embedding_shape = torch.cat(
                (bsz.view(1), seq_len.view(1), torch.tensor([-1], dtype=torch.long))
            )
            embeddings = torch.onnx.operators.reshape_from_tensor_shape(
                flat_embeddings, embedding_shape
            )
            return embeddings
        return (
            self.weights.index_select(0, positions.view(-1))
            .view(bsz, seq_len, -1)
            .detach()
        )
1101
+
1102
def PositionalEmbedding(
    num_embeddings: int,
    embedding_dim: int,
    padding_idx: Optional[int],
    learned: bool = False,
):
    """Create a learned or sinusoidal positional-embedding module.

    Args:
        num_embeddings: number of positions the table must cover, not counting
            the offset reserved for the padding index.
        embedding_dim: width of each positional embedding.
        padding_idx: padding index used to offset positions, or None.
        learned: build a ``LearnedPositionalEmbedding`` when True, otherwise a
            ``SinusoidalPositionalEmbedding``.

    Returns:
        The constructed positional-embedding module.
    """
    if learned:
        # if padding_idx is specified then offset the embedding ids by
        # this index and adjust num_embeddings appropriately
        # TODO: The right place for this offset would be inside
        # LearnedPositionalEmbedding. Move this there for a cleaner implementation.
        if padding_idx is not None:
            num_embeddings = num_embeddings + padding_idx + 1
        m = LearnedPositionalEmbedding(num_embeddings, embedding_dim, padding_idx)
        nn.init.normal_(m.weight, mean=0, std=embedding_dim**-0.5)
        if padding_idx is not None:
            nn.init.constant_(m.weight[padding_idx], 0)
    else:
        # SinusoidalPositionalEmbedding treats a missing padding index as 0,
        # so mirror that here instead of failing on ``None + 1``.
        offset = padding_idx if padding_idx is not None else 0
        m = SinusoidalPositionalEmbedding(
            embedding_dim,
            padding_idx,
            init_size=num_embeddings + offset + 1,
        )
    return m
1126
+
1127
+
1128
class TextLocalEncoder(nn.Module):
    """Embed token ids: scaled token embeddings plus optional positional
    embeddings, optional layer norm, then dropout."""

    def __init__(
        self,
        vocab_size,
        embed_dim,
        max_source_positions,
        pad_idx,
        no_scale_embedding,
        layernorm_embedding,
        dropout,
        no_token_positional_embeddings,
        learned_pos,
    ):
        super().__init__()
        self.pad_idx = pad_idx
        self.dropout_module = nn.Dropout(dropout)

        self.embed_tokens = nn.Embedding(vocab_size, embed_dim, pad_idx)

        # Positional embeddings can be disabled entirely.
        if no_token_positional_embeddings:
            self.embed_positions = None
        else:
            self.embed_positions = PositionalEmbedding(
                max_source_positions,
                embed_dim,
                pad_idx,
                learned=learned_pos,
            )

        # Token embeddings are multiplied by sqrt(embed_dim) unless scaling is off.
        if no_scale_embedding:
            self.embed_scale = 1.0
        else:
            self.embed_scale = math.sqrt(embed_dim)

        self.layernorm_embedding = LayerNorm(embed_dim) if layernorm_embedding else None

    def forward(self, src_tokens):
        """Return the embedded representation of ``src_tokens``."""
        embedded = self.embed_scale * self.embed_tokens(src_tokens)

        position_table = self.embed_positions
        if position_table is not None:
            embedded = embedded + position_table(src_tokens)

        norm = self.layernorm_embedding
        if norm is not None:
            embedded = norm(embedded)

        return self.dropout_module(embedded)
1172
+
1173
+
1174
class TextEncoder(ModalitySpecificEncoder):
    """Text modality encoder: a ``TextLocalEncoder`` front-end followed by a
    ``BlockEncoder`` of ``prenet_depth`` blocks, with no decoder and no
    fixed/relative positional encoders."""

    modality_cfg: PantagruelTextConfig

    def __init__(
        self,
        modality_cfg: PantagruelTextConfig,
        embed_dim: int,
        make_block: Callable[[float], nn.ModuleList],
        norm_layer: Callable[[int], nn.LayerNorm],
        layer_norm_first: bool,
        alibi_biases: Dict,
    ):
        # Plain attributes are set before the parent constructor runs, as in
        # the original ordering.
        self.pad_idx = modality_cfg.pad_token_id
        self.vocab_size = modality_cfg.vocab_size

        token_encoder = TextLocalEncoder(
            vocab_size=self.vocab_size,
            embed_dim=embed_dim,
            max_source_positions=modality_cfg.max_source_positions,
            pad_idx=self.pad_idx,
            no_scale_embedding=modality_cfg.no_scale_embedding,
            layernorm_embedding=modality_cfg.layernorm_embedding,
            dropout=modality_cfg.dropout,
            no_token_positional_embeddings=modality_cfg.no_token_positional_embeddings,
            learned_pos=modality_cfg.learned_pos,
        )

        # One drop-path rate per prenet block, linearly spaced.
        prenet_depth = modality_cfg.prenet_depth
        drop_path_rates = np.linspace(
            modality_cfg.start_drop_path_rate,
            modality_cfg.end_drop_path_rate,
            prenet_depth,
        )
        prenet_blocks = nn.ModuleList(
            make_block(drop_path_rates[idx]) for idx in range(prenet_depth)
        )

        # A trailing norm is only used for post-norm stacks that have blocks.
        has_prenet = prenet_depth > 0
        if has_prenet and not layer_norm_first:
            prenet_norm = norm_layer(embed_dim)
        else:
            prenet_norm = None
        prenet_dropout = modality_cfg.prenet_dropout if has_prenet else 0.0

        prenet = BlockEncoder(
            prenet_blocks,
            prenet_norm,
            layer_norm_first,
            modality_cfg.prenet_layerdrop,
            prenet_dropout,
        )

        super().__init__(
            modality_cfg=modality_cfg,
            embed_dim=embed_dim,
            local_encoder=token_encoder,
            project_features=nn.Identity(),
            fixed_positional_encoder=None,
            relative_positional_encoder=None,
            context_encoder=prenet,
            decoder=None,
            get_alibi_bias=partial(get_alibi_bias, alibi_biases=alibi_biases),
        )

    def convert_padding_mask(self, x, padding_mask):
        """Downsample ``padding_mask`` so its length matches ``x.size(1)``.

        Returns the mask unchanged when it is None or already has the right
        length. Otherwise it is right-padded with True up to a multiple of
        ``self.downsample``, reduced with ``all`` over each window, and
        truncated to the length of ``x``.
        """
        if padding_mask is None:
            return padding_mask

        target_len = x.size(1)
        if padding_mask.size(1) == target_len:
            return padding_mask

        diff = self.downsample - padding_mask.size(1) % self.downsample
        if 0 < diff < self.downsample:
            padding_mask = F.pad(padding_mask, (0, diff), value=True)

        # A downsampled step counts as padding only if its whole window is padding.
        windows = padding_mask.view(padding_mask.size(0), -1, self.downsample)
        padding_mask = windows.all(-1)
        if padding_mask.size(1) > target_len:
            padding_mask = padding_mask[:, :target_len]

        assert x.size(1) == padding_mask.size(
            1
        ), f"{x.size(1), padding_mask.size(1), diff, self.downsample}"

        return padding_mask
1249
+ #################################################
1250
+
1251
+
1252
# Follows the first-token pooler pattern of the transformers data2vec text model
# (presumably Data2VecTextPooler -- TODO confirm the exact upstream source).
class PantagruelUniTextPooler(nn.Module):
    """Pool a sequence by projecting its first token through Linear + Tanh."""

    def __init__(self, config):
        super().__init__()
        width = config.embed_dim
        self.dense = nn.Linear(width, width)
        self.activation = nn.Tanh()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return the pooled representation of ``hidden_states``."""
        # "Pooling" here simply means taking the hidden state of the first token.
        projected = self.dense(hidden_states[:, 0])
        return self.activation(projected)
1266
+
1267
+
1268
class PantagruelUniPreTrainedModel(PreTrainedModel):
    """Base class wiring the Pantagruel config and weight initialisation into
    the transformers ``PreTrainedModel`` machinery."""

    config_class = PantagruelUniConfig
    base_model_prefix = "pantagruel_uni"

    # use init_bert_params from fairseq
    # copied from fairseq.modules.transformer_sentence_encoder.py
    def _init_weights(self, module):
        """Initialize the weights"""

        def _draw_normal(tensor):
            # with FSDP, module params will be on CUDA, so we cast them back to CPU
            # so that the RNG is consistent with and without FSDP
            if tensor.is_meta:
                return tensor
            sampled = tensor.cpu().normal_(mean=0.0, std=0.02)
            tensor.copy_(sampled.to(tensor.device))
            return tensor

        def _init_one(target):
            # Each check is independent on purpose: a module is tested against
            # every category rather than only the first match.
            if isinstance(target, nn.Linear):
                _draw_normal(target.weight.data)
                if target.bias is not None:
                    target.bias.data.zero_()
            if isinstance(target, nn.Embedding):
                _draw_normal(target.weight.data)
                if target.padding_idx is not None:
                    target.weight.data[target.padding_idx].zero_()
            if isinstance(target, AltBlock):
                _draw_normal(target.attn.proj.weight.data)
            # init strategy for audio encoder
            if isinstance(target, (nn.LayerNorm, nn.GroupNorm)):
                if target.bias is not None:
                    target.bias.data.zero_()
                if target.weight is not None:
                    target.weight.data.fill_(1.0)
            if isinstance(target, nn.Conv1d):
                nn.init.kaiming_normal_(target.weight)
                if target.bias is not None:
                    fan = target.in_channels * target.kernel_size[0]
                    k = math.sqrt(target.groups / fan)
                    nn.init.uniform_(target.bias, a=-k, b=k)

        if not isinstance(module, nn.ModuleList):
            _init_one(module)
            return
        # A ModuleList is initialised member by member.
        for member in module:
            _init_one(member)
1312
+
1313
+ # @classmethod
1314
+ # def from_pretrained(
1315
+ # cls,
1316
+ # pretrained_model_name_or_path,
1317
+ # *model_args,
1318
+ # **kwargs,
1319
+ # ):
1320
+ # config = cls.config_class()
1321
+ # config.from_pretrained(pretrained_model_name_or_path)
1322
+ # print(f"Loading configuration from pre-trained model: {type(config)}")
1323
+ # return super().from_pretrained(pretrained_model_name_or_path,
1324
+ # *model_args,
1325
+ # config,
1326
+ # **kwargs,)
1327
+
1328
+
1329
class PantagruelUniModel(PantagruelUniPreTrainedModel):
    """Backbone: one modality-specific encoder followed by a stack of
    ``AltBlock`` layers, with an optional first-token pooler for text."""

    def __init__(
        self, config: PantagruelUniConfig, add_pooling_layer: bool = True
    ):
        """Build the encoder for ``config.supported_modality`` and the block stack.

        Args:
            config: model configuration.
            add_pooling_layer: attach the text pooler; only honoured when the
                supported modality is ``"TEXT"``.
        """
        super().__init__(config)
        self.config = config
        modalities_cfg = config.modalities
        self.modalities = [config.supported_modality]

        make_layer_norm = partial(
            nn.LayerNorm, eps=config.norm_eps, elementwise_affine=config.norm_affine
        )

        def make_block(drop_path, dim=None, heads=None):
            return AltBlock(
                config.embed_dim if dim is None else dim,
                config.num_heads if heads is None else heads,
                config.mlp_ratio,
                qkv_bias=True,
                drop=config.encoder_dropout,
                attn_drop=config.attention_dropout,
                mlp_drop=config.activation_dropout,
                post_mlp_drop=config.post_mlp_drop,
                drop_path=drop_path,
                norm_layer=make_layer_norm,
                layer_norm_first=config.layer_norm_first,
                ffn_targets=not config.end_of_block_targets,
            )

        # Shared dict handed to every modality encoder's alibi-bias function.
        self.alibi_biases = {}
        self.modality_encoders = nn.ModuleDict()
        for mod in self.modalities:
            mod_cfg = getattr(modalities_cfg, mod.lower())
            enc = self.make_modality_encoder(
                mod_cfg,
                config.embed_dim,
                make_block,
                make_layer_norm,
                config.layer_norm_first,
                self.alibi_biases,
            )
            self.modality_encoders[mod] = enc

        self.dropout_input = nn.Dropout(config.dropout_input)

        # One drop-path rate per block, linearly spaced over the depth.
        dpr = np.linspace(config.start_drop_path_rate, config.end_drop_path_rate, config.depth)

        self.blocks = nn.ModuleList([make_block(dpr[i]) for i in range(config.depth)])

        self.text_pooler = None
        if add_pooling_layer and config.supported_modality == "TEXT":
            self.text_pooler = PantagruelUniTextPooler(config)

        self.norm = None
        if config.layer_norm_first:
            self.norm = make_layer_norm(config.embed_dim)

        self.num_updates = 0

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding of the TEXT modality encoder."""
        return self.modality_encoders["TEXT"].local_encoder.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the token embedding of the TEXT modality encoder."""
        self.modality_encoders["TEXT"].local_encoder.embed_tokens = value

    def freeze_feature_extractor(self):
        """
        Calling this function will disable the gradient computation for the feature encoder so that its parameters will
        not be updated during training.
        """
        warnings.warn(
            "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
            "Please use the equivalent `freeze_feature_encoder` method instead.",
            FutureWarning,
        )
        self.freeze_feature_encoder()

    def freeze_feature_encoder(self):
        """
        Calling this function will disable the gradient computation for the feature encoder so that its parameter will
        not be updated during training.

        Note that this also sets ``requires_grad = False`` on every parameter
        of the main block stack.
        """
        for mod in self.modalities:
            self.modality_encoders[mod]._freeze_parameters()
        for block in self.blocks:
            for p in block.parameters():
                p.requires_grad = False

    def make_modality_encoder(
        self,
        cfg: PantagruelModalityConfig,
        embed_dim: int,
        make_block: Callable[[float], nn.ModuleList],
        norm_layer: Callable[[int], nn.LayerNorm],
        layer_norm_first: bool,
        alibi_biases,
    ) -> ModalitySpecificEncoder:
        """Instantiate the encoder class matching ``cfg.type``.

        Raises:
            Exception: if ``cfg.type`` is neither ``"AUDIO"`` nor ``"TEXT"``.
        """
        if cfg.type == "AUDIO":
            enc_cls = AudioEncoder
        elif cfg.type == "TEXT":
            enc_cls = TextEncoder
        else:
            raise Exception(f"unsupported modality {cfg.type}")

        return enc_cls(
            cfg,
            embed_dim,
            make_block,
            norm_layer,
            layer_norm_first,
            alibi_biases,
        )

    def forward(
        self,
        input_values=None,  # audio input
        input_ids=None,  # text input
        attention_mask=None,
        padding_mask=None,
        mask=False,
        mode=None,
        output_hidden_states=True,
        output_attn_weights=False,
        return_dict=True,
    ):
        """Encode the input and run it through the block stack.

        Args:
            input_values: audio input, used when ``input_ids`` is None.
            input_ids: text token ids.
            attention_mask: 1/True marks positions to attend to. Only used to
                derive ``padding_mask`` when that argument is None.
            padding_mask: True marks padded positions.
            mask: forwarded to the modality encoder.
            mode: ``"TEXT"`` or ``"AUDIO"``; inferred from which input is given
                when None.
            output_hidden_states: include per-layer results in the output object.
            output_attn_weights: include attention weights in the output
                object; also disables the blocks' ``fast`` path.
            return_dict: return a ``PantagruelUniBaseModelOutput`` instead of a
                tuple of the non-None values.
        """
        if mode is None:
            mode = "TEXT" if input_ids is not None else "AUDIO"

        if padding_mask is None and attention_mask is not None:
            # attention mask: 1 = attend, 0 = ignore; padding mask is its
            # logical inverse. Cast to bool first: `~` on an integer tensor is
            # a bitwise NOT (1 -> -2, 0 -> -1), which would mark every
            # position as padded.
            padding_mask = ~attention_mask.to(torch.bool)

        feature_extractor = self.modality_encoders[mode]
        extractor_out = feature_extractor(
            input_ids if input_ids is not None else input_values,
            padding_mask,
            mask,
            remove_masked=False,
            clone_batch=1,
            mask_seeds=None,
            precomputed_mask=None,
        )
        x = extractor_out["x"]
        local_features = x

        # encoder_mask = extractor_out["encoder_mask"]
        masked_padding_mask = extractor_out["padding_mask"]
        masked_alibi_bias = extractor_out.get("alibi_bias", None)
        alibi_scale = extractor_out.get("alibi_scale", None)

        if self.dropout_input is not None:
            x = self.dropout_input(x)

        layer_results = []
        attn_weights = []
        for i, blk in enumerate(self.blocks):
            # LayerDrop: during training each block is skipped with
            # probability config.layerdrop; it never applies in eval mode.
            if (
                not self.training
                or self.config.layerdrop == 0
                or (np.random.random() > self.config.layerdrop)
            ):
                ab = masked_alibi_bias
                if ab is not None and alibi_scale is not None:
                    # Per-layer scale when several are provided, else a shared one.
                    scale = (
                        alibi_scale[i]
                        if alibi_scale.size(0) > 1
                        else alibi_scale.squeeze(0)
                    )
                    ab = ab * scale.type_as(ab)

                x, lr, _attn = blk(
                    x,
                    padding_mask=masked_padding_mask,
                    alibi_bias=ab,
                    fast=not output_attn_weights,
                )
                layer_results.append(lr)
                attn_weights.append(_attn)

        if self.norm is not None:
            x = self.norm(x)

        # Drop the modality's extra leading tokens from outputs and mask.
        x = x[:, feature_extractor.modality_cfg.num_extra_tokens :]
        if masked_padding_mask is not None:
            masked_padding_mask = masked_padding_mask[
                :, feature_extractor.modality_cfg.num_extra_tokens :
            ]

        txt_pooled_output = (
            self.text_pooler(x) if self.text_pooler is not None else None
        )

        if not return_dict:
            return tuple(
                v
                for v in [
                    x,
                    txt_pooled_output,
                    local_features,
                    layer_results,
                    attn_weights,
                ]
                if v is not None
            )

        return PantagruelUniBaseModelOutput(
            last_hidden_state=x,
            pooler_output=txt_pooled_output,
            local_features=local_features,
            hidden_states=layer_results if output_hidden_states else None,
            attentions=attn_weights if output_attn_weights else None,
        )
1544
+
1545
+
1546
class PantagruelTextLMHead(nn.Module):
    """PantagruelText Head for masked language modeling."""

    def __init__(self, config):
        super().__init__()
        width = config.embed_dim
        vocab = config.modalities.text.vocab_size

        self.dense = nn.Linear(width, width)
        self.layer_norm = nn.LayerNorm(width, eps=config.norm_eps)

        self.decoder = nn.Linear(width, vocab)
        # The output bias is exposed as its own parameter and shared with the decoder.
        self.bias = nn.Parameter(torch.zeros(vocab))
        self.decoder.bias = self.bias

    def forward(self, features, **kwargs):
        """Map hidden features to vocabulary logits: dense -> gelu -> layer norm -> decoder."""
        hidden = self.layer_norm(gelu(self.dense(features)))

        # project back to size of vocabulary with bias
        return self.decoder(hidden)

    def _tie_weights(self):
        # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
        # For accelerate compatibility and to not break backward compatibility
        if self.decoder.bias.device.type != "meta":
            self.bias = self.decoder.bias
        else:
            self.decoder.bias = self.bias
1575
+
1576
+
1577
class PantagruelTextClassificationHead(nn.Module):
    """Head for sentence-level classification tasks."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.embed_dim, config.embed_dim)

        # Fall back to the encoder dropout when no classifier dropout is configured.
        drop_prob = config.classifier_dropout
        if drop_prob is None:
            drop_prob = config.encoder_dropout
        self.dropout = nn.Dropout(drop_prob)

        self.out_proj = nn.Linear(config.embed_dim, config.num_labels)

    def forward(self, features, **kwargs):
        """Classify from the first token: dropout -> dense -> tanh -> dropout -> projection."""
        first_token = features[:, 0, :]  # take <s> token (equiv. to [CLS])
        hidden = torch.tanh(self.dense(self.dropout(first_token)))
        return self.out_proj(self.dropout(hidden))
1597
+
1598
+
1599
@auto_docstring
class PantagruelUniForMaskedLM(PantagruelUniPreTrainedModel):
    # Backbone without pooler plus a PantagruelTextLMHead producing vocabulary logits.
    _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]

    def __init__(self, config):
        super().__init__(config)

        if config.is_decoder:
            logger.warning(
                "If you want to use `PantagruelUniForMaskedLM` make sure `config.is_decoder=False` for "
                "bi-directional self-attention."
            )

        self.pantagruel_uni = PantagruelUniModel(config, add_pooling_layer=False)
        self.lm_head = PantagruelTextLMHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        return self.lm_head.decoder

    def set_output_embeddings(self, new_embeddings):
        self.lm_head.decoder = new_embeddings

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        padding_mask: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Union[tuple, MaskedLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
            loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
        """
        outputs = self.pantagruel_uni(
            input_ids=input_ids,
            attention_mask=attention_mask,
            padding_mask=padding_mask,
            mask=False,
            mode="TEXT",
            return_dict=True,
        )
        # `last_hidden_state` is already the full tensor; indexing it with [0]
        # would keep only the first sequence of the batch.
        sequence_output = outputs.last_hidden_state
        prediction_scores = self.lm_head(sequence_output)

        masked_lm_loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()

            labels = labels.to(prediction_scores.device)
            # Flatten using the logits' own last dimension so the reshape always
            # matches the vocabulary size the LM head was built with.
            masked_lm_loss = loss_fct(
                prediction_scores.view(-1, prediction_scores.size(-1)), labels.view(-1)
            )

        return MaskedLMOutput(
            loss=masked_lm_loss,
            logits=prediction_scores,
            # Per-layer results from the backbone, not the final hidden state.
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
1664
+
1665
+
1666
@auto_docstring(
    custom_intro="""
    PantagruelText Model transformer with a sequence classification/regression head on top (a linear layer on top of the
    pooled output) e.g. for GLUE tasks.
    """
)
class PantagruelUniForSequenceClassification(PantagruelUniPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        # The classification head reads the first token itself, so the
        # backbone is built without its pooler.
        self.pantagruel_uni = PantagruelUniModel(config, add_pooling_layer=False)
        self.classifier = PantagruelTextClassificationHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        padding_mask: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Union[tuple, SequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        outputs = self.pantagruel_uni(
            input_ids=input_ids,
            attention_mask=attention_mask,
            padding_mask=padding_mask,
            mask=False,
            mode="TEXT",
            return_dict=True,
        )
        sequence_output = outputs.last_hidden_state
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)

            # Infer the problem type once from num_labels and the label dtype,
            # and remember it on the config.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            # Per-layer results from the backbone, not the final hidden state.
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
1742
+
1743
+
1744
@auto_docstring
class PantagruelUniForMultipleChoice(PantagruelUniPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        # The pooler is kept (default add_pooling_layer=True): the classifier
        # scores each choice from the backbone's pooled output.
        self.pantagruel_uni = PantagruelUniModel(config)
        self.dropout = nn.Dropout(config.encoder_dropout)
        self.classifier = nn.Linear(config.embed_dim, 1)

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        padding_mask: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Union[tuple, MultipleChoiceModelOutput]:
        r"""
        input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        token_type_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
            num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
            `input_ids` above)
        """
        # The backbone's forward has no `inputs_embeds` parameter, so token
        # ids are mandatory. `token_type_ids`, `position_ids` and
        # `inputs_embeds` are accepted for interface compatibility only.
        if input_ids is None:
            raise ValueError(
                "PantagruelUniForMultipleChoice requires `input_ids`; `inputs_embeds` is not supported."
            )
        num_choices = input_ids.shape[1]

        # Fold the choice dimension into the batch dimension.
        flat_input_ids = input_ids.view(-1, input_ids.size(-1))
        flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
        flat_padding_mask = padding_mask.view(-1, padding_mask.size(-1)) if padding_mask is not None else None

        # Pass the real padding mask (or None, letting the backbone derive it
        # from the attention mask). Passing the attention mask as the padding
        # mask would invert which tokens count as padding.
        outputs = self.pantagruel_uni(
            input_ids=flat_input_ids,
            attention_mask=flat_attention_mask,
            padding_mask=flat_padding_mask,
            mask=False,
            mode="TEXT",
            return_dict=True,
        )
        pooled_output = outputs.pooler_output

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        # One score per (example, choice) pair.
        reshaped_logits = logits.view(-1, num_choices)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()

            labels = labels.to(reshaped_logits.device)
            loss = loss_fct(reshaped_logits, labels)

        return MultipleChoiceModelOutput(
            loss=loss,
            logits=reshaped_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
1829
+
1830
+
1831
@auto_docstring
class PantagruelUniForTokenClassification(PantagruelUniPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.pantagruel_uni = PantagruelUniModel(config, add_pooling_layer=False)

        # Fall back to the encoder dropout when no classifier dropout is configured.
        drop_prob = config.classifier_dropout
        if drop_prob is None:
            drop_prob = config.encoder_dropout
        self.dropout = nn.Dropout(drop_prob)
        self.classifier = nn.Linear(config.embed_dim, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        padding_mask: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Union[tuple, TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        """
        encoded = self.pantagruel_uni(
            input_ids=input_ids,
            attention_mask=attention_mask,
            padding_mask=padding_mask,
            mask=False,
            mode="TEXT",
            return_dict=True,
        )

        # Per-token logits from the final hidden states.
        token_states = self.dropout(encoded.last_hidden_state)
        logits = self.classifier(token_states)

        if labels is None:
            loss = None
        else:
            criterion = CrossEntropyLoss()
            targets = labels.to(logits.device)
            loss = criterion(logits.view(-1, self.num_labels), targets.view(-1))

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=encoded.hidden_states,
            attentions=encoded.attentions,
        )
1888
+
1889
+
1890
@auto_docstring
class PantagruelUniForQuestionAnswering(PantagruelUniPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.pantagruel_uni = PantagruelUniModel(config, add_pooling_layer=False)
        # forward() splits the output into start and end logits, which
        # presumably requires config.num_labels == 2 -- TODO confirm.
        self.qa_outputs = nn.Linear(config.embed_dim, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        padding_mask: Optional[torch.FloatTensor] = None,
        start_positions: Optional[torch.LongTensor] = None,
        end_positions: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Union[tuple, QuestionAnsweringModelOutput]:

        outputs = self.pantagruel_uni(
            input_ids=input_ids,
            attention_mask=attention_mask,
            padding_mask=padding_mask,
            mask=False,
            mode="TEXT",
            return_dict=True,
        )

        # `last_hidden_state` is already the full tensor; indexing it with [0]
        # would drop the batch dimension and break `start_logits.size(1)` below.
        sequence_output = outputs.last_hidden_state

        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, split add a dimension
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # sometimes the start/end positions are outside our model inputs, we ignore these terms
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
1954
+
1955
+
1956
+ __all__ = [
1957
+ "PantagruelUniForMaskedLM",
1958
+ "PantagruelUniForMultipleChoice",
1959
+ "PantagruelUniForQuestionAnswering",
1960
+ "PantagruelUniForSequenceClassification",
1961
+ "PantagruelUniForTokenClassification",
1962
+ "PantagruelUniModel",
1963
+ "PantagruelUniPreTrainedModel",
1964
+ ]
special_tokens_map.json ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": {
3
+ "content": "<s>",
4
+ "lstrip": false,
5
+ "normalized": true,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "cls_token": {
10
+ "content": "<s>",
11
+ "lstrip": false,
12
+ "normalized": true,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "eos_token": {
17
+ "content": "</s>",
18
+ "lstrip": false,
19
+ "normalized": true,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ },
23
+ "mask_token": {
24
+ "content": "<mask>",
25
+ "lstrip": false,
26
+ "normalized": true,
27
+ "rstrip": false,
28
+ "single_word": false
29
+ },
30
+ "pad_token": {
31
+ "content": "<pad>",
32
+ "lstrip": false,
33
+ "normalized": true,
34
+ "rstrip": false,
35
+ "single_word": false
36
+ },
37
+ "sep_token": {
38
+ "content": "</s>",
39
+ "lstrip": false,
40
+ "normalized": true,
41
+ "rstrip": false,
42
+ "single_word": false
43
+ },
44
+ "unk_token": {
45
+ "content": "<unk>",
46
+ "lstrip": false,
47
+ "normalized": true,
48
+ "rstrip": false,
49
+ "single_word": false
50
+ }
51
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
The diff for this file is too large to render. See raw diff
 
utils_pantagruel_uni.py ADDED
@@ -0,0 +1,439 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ #
3
+ # Copyright (c) Facebook, Inc. and its affiliates.
4
+ #
5
+ # This source code is licensed under the MIT license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+ #
8
+
9
+ import math
10
+ import numpy as np
11
+ from collections import namedtuple
12
+ from typing import Optional, Tuple
13
+
14
+ import torch
15
+ import torch.nn.functional as F
16
+
17
+
18
# Inputs for a reproducible masking seed: random_masking hashes `seed`,
# `update` and the sum of `ids` together to seed its random generator.
MaskSeed = namedtuple("MaskSeed", ["seed", "update", "ids"])
# Result of a masking pass as built by random_masking: the kept entries
# (`x_unmasked`), the binary mask (0 = keep, 1 = remove), the indices that
# undo the shuffle (`ids_restore`) and the indices of kept entries (`ids_keep`).
MaskInfo = namedtuple("MaskInfo", ["x_unmasked", "mask", "ids_restore", "ids_keep"])
20
+
21
+
22
def gather_unmasked(x: torch.Tensor, mask_info: MaskInfo) -> torch.Tensor:
    """Gather from ``x`` along dim 1 using ``mask_info.ids_keep`` as the index.

    Args:
        x: tensor to select from.
        mask_info: masking result whose ``ids_keep`` field is the gather index.

    Returns:
        The gathered tensor (same shape as ``mask_info.ids_keep``).
    """
    keep_index = mask_info.ids_keep
    return x.gather(1, keep_index)
28
+
29
+
30
def gather_unmasked_mask(x: torch.Tensor, mask_info: MaskInfo) -> torch.Tensor:
    """Gather from a mask-like tensor ``x`` along dim 1 at the kept positions.

    Args:
        x: tensor to select from; it has no feature dimension, so only the
            first slice of ``ids_keep`` along its last dim is used as index.
        mask_info: masking result providing ``ids_keep``.

    Returns:
        The gathered tensor.
    """
    # ids_keep carries a trailing feature dimension; index 0 of it is enough.
    keep_index = mask_info.ids_keep[..., 0]
    return x.gather(1, keep_index)
36
+
37
+
38
def masked_alibi(alibi_bias, mask_info):
    """Restrict an alibi bias to the kept positions on its last two dims.

    The result satisfies ``out[b, h, i, j] == alibi_bias[b, h, keep[b, i], keep[b, j]]``
    where ``keep`` is ``mask_info.ids_keep[..., 0]``.

    Args:
        alibi_bias: bias tensor; dim 1 is the head dim and the last two dims
            are indexed by position (presumably (batch, heads, T, T)).
        mask_info: masking result providing ``ids_keep`` and ``ids_restore``.

    Returns:
        The bias gathered first along dim -2, then along dim -1.
    """
    num_heads = alibi_bias.size(1)
    full_len = mask_info.ids_restore.size(1)

    # Kept position ids without the feature dim, with singleton head and
    # trailing dims so they can be expanded for either gather.
    keep = mask_info.ids_keep.unsqueeze(1)[..., 0].unsqueeze(-1)

    # First select the kept rows ...
    row_index = keep.expand(-1, num_heads, -1, full_len)
    rows = torch.gather(alibi_bias, dim=-2, index=row_index)

    # ... then the kept columns of those rows.
    col_index = keep.transpose(-1, -2).expand(-1, num_heads, rows.size(-2), -1)
    return torch.gather(rows, dim=-1, index=col_index)
56
+
57
+
58
def random_masking(x, mask_ratio, mask_seed: Optional[MaskSeed]):
    """Randomly keep a fraction of the positions of ``x`` and drop the rest.

    Args:
        x: input of shape (batch, length, dim).
        mask_ratio: fraction of positions to remove; ``int(length * (1 - mask_ratio))``
            positions are kept per sample.
        mask_seed: when given, the noise generator is seeded from a hash of
            ``(seed, update, ids.sum())`` so the mask is reproducible.

    Returns:
        MaskInfo with the kept entries, the binary mask (0 = keep, 1 = remove),
        and the restore / keep indices expanded over the feature dim.
    """
    batch, length, dim = x.shape
    num_keep = int(length * (1 - mask_ratio))

    if mask_seed is None:
        rng = None
    else:
        key = (mask_seed.seed, mask_seed.update, mask_seed.ids.sum().item())
        rng = torch.Generator(device=x.device)
        rng.manual_seed(int(hash(key) % 1e6))

    # Uniform noise in [0, 1); sorting it yields a random permutation per sample.
    noise = torch.rand(batch, length, generator=rng, device=x.device)

    # Ascending sort: the smallest noise values are the positions we keep.
    shuffle_order = noise.argsort(dim=1)
    restore_order = shuffle_order.argsort(dim=1)

    keep_index = shuffle_order[:, :num_keep].unsqueeze(-1).expand(-1, -1, dim)
    kept = torch.gather(x, dim=1, index=keep_index)

    # Build the mask in shuffled order (kept first), then unshuffle it.
    shuffled_mask = torch.ones([batch, length], dtype=x.dtype, device=x.device)
    shuffled_mask[:, :num_keep] = 0
    binary_mask = torch.gather(shuffled_mask, dim=1, index=restore_order)

    return MaskInfo(
        x_unmasked=kept,
        mask=binary_mask,
        ids_restore=restore_order.unsqueeze(-1).expand(-1, -1, dim),
        ids_keep=keep_index,
    )
92
+
93
+
94
def get_alibi(
    max_positions: int,
    attention_heads: int,
    dims: int = 1,
    distance: str = "manhattan",
):
    """Build the per-head ALiBi position bias.

    Args:
        max_positions: number of positions; for ``dims == 2`` it must be a
            perfect square (positions are laid out on an n x n grid).
        attention_heads: number of attention heads (one slope per head).
        dims: 1 for a sequence, 2 for a square grid.
        distance: grid distance used when ``dims == 2``: "manhattan" or
            "euclidean".

    Returns:
        Tensor of shape (attention_heads, max_positions, max_positions) holding
        ``slope[h] * -distance(p, q)``.

    Raises:
        ValueError: if ``dims == 2`` and ``distance`` is not supported.
        Exception: if ``dims`` is neither 1 nor 2.
    """

    def get_slopes(n):
        def get_slopes_power_of_2(n):
            start = 2 ** (-(2 ** -(math.log2(n) - 3)))
            ratio = start
            return [start * ratio**i for i in range(n)]

        # In the paper, we only train models that have 2^a heads for some
        # a. This function has some good properties that only occur when
        # the input is a power of 2. To maintain that even when the number
        # of heads is not a power of 2, we use this workaround.
        if math.log2(n).is_integer():
            return get_slopes_power_of_2(n)
        else:
            closest_power_of_2 = 2 ** math.floor(math.log2(n))
            return (
                get_slopes_power_of_2(closest_power_of_2)
                + get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
            )

    maxpos = max_positions
    attn_heads = attention_heads
    slopes = torch.Tensor(get_slopes(attn_heads))

    if dims == 1:
        # Prepare the alibi position linear bias. Note that wav2vec2 is a
        # non-autoregressive model, so we want a symmetric mask with 0 on the
        # diagonal and otherwise linearly decreasing values.
        pos_bias = (
            torch.abs(
                torch.arange(maxpos).unsqueeze(0) - torch.arange(maxpos).unsqueeze(1)
            )
            * -1
        )
    elif dims == 2:
        if distance == "manhattan":

            def df(x1, y1, x2, y2):
                return abs(x1 - x2) + abs(y1 - y2)

        elif distance == "euclidean":

            def df(x1, y1, x2, y2):
                return math.sqrt((x1 - x2) ** 2 + (y1 - y2) ** 2)

        else:
            # Fail early and clearly instead of hitting an unbound `df` below.
            raise ValueError(f"unsupported alibi distance: {distance}")

        n = math.sqrt(max_positions)
        assert n.is_integer(), n
        n = int(n)

        pos_bias = torch.zeros((max_positions, max_positions))

        # Flattened grid index is `first * n + second`; the distance function
        # receives the four grid coordinates in loop order.
        for i in range(n):
            for j in range(n):
                for k in range(n):
                    for m in range(n):
                        new_x = i * n + j
                        new_y = k * n + m
                        pos_bias[new_x, new_y] = -df(i, j, k, m)

    else:
        raise Exception(f"unsupported number of alibi dims: {dims}")

    alibi_bias = slopes.unsqueeze(1).unsqueeze(1) * pos_bias.unsqueeze(0).expand(
        attn_heads, -1, -1
    )

    return alibi_bias
161
+
162
+
163
def get_alibi_bias(
    alibi_biases,
    batch_size,
    time_steps,
    heads,
    dtype,
    device,
    dims=1,
    distance="manhattan",
):
    """Return a (batch_size, heads, time_steps, time_steps) alibi bias, cached.

    Args:
        alibi_biases: dict used as a cache, keyed by ``"{dims}_{heads}_{distance}"``;
            it is updated in place when the buffer has to be rebuilt.
        batch_size: number of samples.
        time_steps: number of positions needed.
        heads: number of attention heads.
        dtype: required dtype of the bias.
        device: required device of the bias.
        dims: forwarded to ``get_alibi``.
        distance: forwarded to ``get_alibi``.

    Returns:
        A view of the cached buffer with the requested shape.
    """
    cache_key = f"{dims}_{heads}_{distance}"
    cached = alibi_biases.get(cache_key, None)

    rows_needed = heads * batch_size

    # Rebuild when there is no buffer, it is too small, or dtype/device differ.
    needs_rebuild = cached is None or (
        cached.size(0) < rows_needed
        or cached.size(1) < time_steps
        or cached.dtype != dtype
        or cached.device != device
    )

    if needs_rebuild:
        # Never shrink: keep at least the size of the previous buffer.
        prev_steps = 0 if cached is None else cached.size(1)
        prev_rows = 0 if cached is None else cached.size(0)
        steps = max(time_steps, prev_steps)
        copies = max(rows_needed, prev_rows) // heads

        base = get_alibi(steps, heads, dims=dims, distance=distance)
        cached = base.to(dtype=dtype, device=device).repeat(copies, 1, 1)
        alibi_biases[cache_key] = cached

    window = cached[:rows_needed, :time_steps, :time_steps]
    return window.view(batch_size, heads, time_steps, time_steps)
199
+
200
+
201
def is_xla_tensor(tensor):
    """Return True only if ``tensor`` is a torch tensor on an "xla" device."""
    if not torch.is_tensor(tensor):
        return False
    return tensor.device.type == "xla"
203
+
204
+
205
def index_put(tensor, indices, value):
    """Write ``value`` into ``tensor`` where ``indices`` selects, and return it.

    For non-XLA tensors this is an in-place ``tensor[indices] = value``. For
    XLA tensors the result is computed arithmetically from a broadcast mask
    and returned as a new tensor.

    Args:
        tensor: destination tensor.
        indices: index for the in-place path; on the XLA path it is inverted
            with ``~`` and multiplied, so presumably a boolean mask.
        value: value to write.

    Returns:
        The updated tensor.
    """
    if not is_xla_tensor(tensor):
        tensor[indices] = value
        return tensor

    # Append trailing singleton dims until the mask has the tensor's rank.
    for _ in range(indices.dim(), tensor.dim()):
        indices = indices.unsqueeze(-1)
    if indices.size(-1) < tensor.size(-1):
        indices = indices.expand_as(tensor)
    # Keep the old values where the mask is off, take `value` where it is on.
    return torch.mul(tensor, ~indices) + torch.mul(value, indices)
215
+
216
+
217
def compute_mask_indices(
    shape: Tuple[int, int],
    padding_mask: Optional[torch.Tensor],
    mask_prob: float,
    mask_length: int,
    mask_type: str = "static",
    mask_other: float = 0.0,
    min_masks: int = 0,
    no_overlap: bool = False,
    min_space: int = 0,
    require_same_masks: bool = True,
    mask_dropout: float = 0.0,
    add_masks: bool = False,
    seed: Optional[int] = None,
    epoch: Optional[int] = None,
    indices: Optional[torch.Tensor] = None,
    idc_select_ver: int = 1,  # 2 to reproduce mask_tokens_dataset
    num_mask_ver: int = 2,  # 2 to reproduce mask_tokens_dataset
) -> np.ndarray:
    """
    Computes random mask spans for a given shape

    Args:
        shape: the shape for which to compute masks.
            should be of size 2 where first element is batch size and 2nd is timesteps
        padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
        mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
            number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
            however due to overlaps, the actual number will be smaller (unless no_overlap is True)
        mask_length: span length ("static"), or the parameter of the length distribution (see mask_type)
        mask_type: how to compute mask lengths
            static = fixed size
            uniform = sample from uniform distribution [mask_other, mask_length*2]
            normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
            poisson = sample from poisson distribution with lambda = mask length
        mask_other: secondary parameter of the length distribution (see mask_type)
        min_masks: minimum number of masked spans
        no_overlap: if true, will switch to an alternative recursive algorithm that prevents spans from overlapping
        min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans
        require_same_masks: if true, will randomly drop out masks until same amount of masks remains in each sample
        mask_dropout: randomly dropout this percentage of masks in each example
        add_masks: only used if require_same_masks is True; if true, samples are topped up to the largest
            per-sample mask count instead of being cut down to the smallest
        seed, epoch, indices: when all three are given, sample i uses a generator seeded from
            hash((seed, epoch, indices[i])); otherwise the generator is unseeded
        idc_select_ver: span start selection variant (1 or 2)
        num_mask_ver: span count computation variant (1 or 2)

    Returns:
        Boolean array of shape (batch, timesteps); True marks a masked element.

    Raises:
        ValueError: on an unknown idc_select_ver / num_mask_ver, if a static mask yields no span,
            or if an entire sequence would be masked.
        Exception: on an unknown mask_type.
    """

    bsz, all_sz = shape
    mask = np.full((bsz, all_sz), False)

    if num_mask_ver == 1:
        all_num_mask = int(
            # add a random number for probabilistic rounding
            mask_prob * all_sz / float(mask_length)
            + np.random.rand()
        )
        all_num_mask = max(min_masks, all_num_mask)

    mask_idcs = []
    for i in range(bsz):
        if seed is not None and epoch is not None and indices is not None:
            seed_i = int(hash((seed, epoch, indices[i].item())) % 1e6)
        else:
            seed_i = None

        rng = np.random.default_rng(seed_i)

        if padding_mask is not None:
            sz = all_sz - padding_mask[i].long().sum().item()
            assert sz >= 0, sz
        else:
            sz = all_sz

        if num_mask_ver == 1:
            if padding_mask is not None:
                num_mask = int(
                    # add a random number for probabilistic rounding
                    mask_prob * sz / float(mask_length)
                    + np.random.rand()
                )
                num_mask = max(min_masks, num_mask)
            else:
                num_mask = all_num_mask
        elif num_mask_ver == 2:
            num_mask = int(
                # add a random number for probabilistic rounding
                mask_prob * sz / float(mask_length)
                + rng.random()
            )
            num_mask = max(min_masks, num_mask)
        else:
            raise ValueError()

        if mask_type == "static":
            lengths = np.full(num_mask, mask_length)
        elif mask_type == "uniform":
            # `rng` is a numpy Generator: it exposes `integers`, not the
            # legacy `randint`. The upper bound is exclusive in both APIs.
            lengths = rng.integers(
                int(mask_other), mask_length * 2 + 1, size=num_mask
            )
        elif mask_type == "normal":
            lengths = rng.normal(mask_length, mask_other, size=num_mask)
            lengths = [max(1, int(round(x))) for x in lengths]
        elif mask_type == "poisson":
            lengths = rng.poisson(mask_length, size=num_mask)
            lengths = [int(round(x)) for x in lengths]
        else:
            raise Exception("unknown mask selection " + mask_type)

        if sum(lengths) == 0:
            if mask_type == "static":
                raise ValueError("this should never happens")
            else:
                lengths = [min(mask_length, sz - 1)]

        if no_overlap:
            mask_idc = []

            def arrange(s, e, length, keep_length):
                # Place one span inside [s, e) and return the leftover parts
                # that are still large enough to host another span.
                if e - length <= s:
                    # The part fits the span exactly: only one placement.
                    span_start = s
                else:
                    span_start = rng.integers(s, e - length)
                mask_idc.extend(span_start + i for i in range(length))

                new_parts = []
                if span_start - s - min_space >= keep_length:
                    new_parts.append((s, span_start - min_space + 1))
                if e - span_start - length - min_space > keep_length:
                    new_parts.append((span_start + length + min_space, e))
                return new_parts

            parts = [(0, sz)]
            min_length = min(lengths)
            for length in sorted(lengths, reverse=True):
                # Weight each part by its size; parts too small get weight 0.
                lens = np.fromiter(
                    (e - s if e - s >= length + min_space else 0 for s, e in parts),
                    dtype=np.int64,
                )
                l_sum = np.sum(lens)
                if l_sum == 0:
                    break
                probs = lens / l_sum
                c = rng.choice(len(parts), p=probs)
                s, e = parts.pop(c)
                parts.extend(arrange(s, e, length, min_length))
            # Force an integer dtype so an empty result is still a valid index.
            mask_idc = np.asarray(mask_idc, dtype=np.int64)
        else:
            if idc_select_ver == 1:
                min_len = min(lengths)
                if sz - min_len <= num_mask:
                    min_len = sz - num_mask - 1
                mask_idc = rng.choice(sz - min_len, num_mask, replace=False)
            elif idc_select_ver == 2:
                mask_idc = rng.choice(sz, num_mask, replace=False)
            else:
                raise ValueError()

            # Expand every span start into the indices it covers.
            mask_idc = np.asarray(
                [
                    mask_idc[j] + offset
                    for j in range(len(mask_idc))
                    for offset in range(lengths[j])
                ],
                dtype=np.int64,
            )

        mask_idc = np.unique(mask_idc[mask_idc < sz])
        if len(mask_idc) >= sz:
            raise ValueError(
                (
                    f"the entire sequence is masked. "
                    f"sz={sz}; mask_idc={mask_idc}; "
                    f"index={indices[i] if indices is not None else None}"
                )
            )
        mask_idcs.append(mask_idc)

    target_len = None
    if require_same_masks:
        if add_masks:
            target_len = max([len(m) for m in mask_idcs])
        else:
            target_len = min([len(m) for m in mask_idcs])

    # NOTE: `rng` here is the generator of the last sample of the loop above.
    for i, mask_idc in enumerate(mask_idcs):
        if target_len is not None and len(mask_idc) > target_len:
            mask_idc = rng.choice(mask_idc, target_len, replace=False)

        mask[i, mask_idc] = True

        if target_len is not None and len(mask_idc) < target_len:
            unmasked = np.flatnonzero(~mask[i])
            to_mask = rng.choice(unmasked, target_len - len(mask_idc), replace=False)
            mask[i, to_mask] = True

        if mask_dropout > 0:
            masked = np.flatnonzero(mask[i])
            num_holes = np.rint(len(masked) * mask_dropout).astype(int)
            to_drop = rng.choice(masked, num_holes, replace=False)
            mask[i, to_drop] = False

    return mask
407
+
408
+
409
+ def _learned_alibi_bias(
410
+ alibi_bias,
411
+ batch_size,
412
+ time_steps,
413
+ heads,
414
+ scale,
415
+ dtype,
416
+ device,
417
+ ):
418
+ assert alibi_bias.size(1) == heads, alibi_bias.shape
419
+ assert alibi_bias.dtype == dtype, alibi_bias.dtype
420
+ assert alibi_bias.device == device, alibi_bias.device
421
+
422
+ if alibi_bias.size(-1) < time_steps:
423
+ psz = math.ceil((time_steps - alibi_bias.size(-1)) / 2)
424
+ alibi_bias = F.pad(alibi_bias, (psz, psz, psz, psz), mode="replicate")
425
+
426
+ alibi_bias = alibi_bias.expand(batch_size, -1, -1, -1) * scale
427
+ return alibi_bias[..., :time_steps, :time_steps]
428
+
429
def make_positions(tensor, padding_idx: int, onnx_trace: bool = False):
    """Replace non-padding symbols with their position numbers.

    Position numbers begin at padding_idx+1. Padding symbols are ignored
    (they come out as ``padding_idx``).

    Args:
        tensor: token ids; positions are counted along dim 1.
        padding_idx: id of the padding symbol.
        onnx_trace: not used by this implementation.

    Returns:
        A long tensor of position numbers.
    """
    # The casts are deliberate: XLA prefers ints, cumsum outputs longs by
    # default, and ONNX cannot handle the dtype kwarg of cumsum.
    not_pad = tensor.ne(padding_idx).int()
    running = torch.cumsum(not_pad, dim=1).type_as(not_pad)
    return (running * not_pad).long() + padding_idx