sleepyhead111 committed on
Commit 4626a8f · verified · 1 parent: b3360fe

Add files using upload-large-folder tool

This view is limited to 50 files because it contains too many changes.
Files changed (50):
  1. fairseq-0.10.2/fairseq/data/__pycache__/backtranslation_dataset.cpython-310.pyc +0 -0
  2. fairseq-0.10.2/fairseq/data/__pycache__/base_wrapper_dataset.cpython-310.pyc +0 -0
  3. fairseq-0.10.2/fairseq/data/__pycache__/concat_sentences_dataset.cpython-310.pyc +0 -0
  4. fairseq-0.10.2/fairseq/data/__pycache__/data_utils.cpython-310.pyc +0 -0
  5. fairseq-0.10.2/fairseq/data/__pycache__/denoising_dataset.cpython-310.pyc +0 -0
  6. fairseq-0.10.2/fairseq/data/__pycache__/fasta_dataset.cpython-310.pyc +0 -0
  7. fairseq-0.10.2/fairseq/data/__pycache__/iterators.cpython-310.pyc +0 -0
  8. fairseq-0.10.2/fairseq/data/__pycache__/list_dataset.cpython-310.pyc +0 -0
  9. fairseq-0.10.2/fairseq/data/__pycache__/lm_context_window_dataset.cpython-310.pyc +0 -0
  10. fairseq-0.10.2/fairseq/data/__pycache__/lru_cache_dataset.cpython-310.pyc +0 -0
  11. fairseq-0.10.2/fairseq/data/__pycache__/mask_tokens_dataset.cpython-310.pyc +0 -0
  12. fairseq-0.10.2/fairseq/data/__pycache__/noising.cpython-310.pyc +0 -0
  13. fairseq-0.10.2/fairseq/data/__pycache__/num_samples_dataset.cpython-310.pyc +0 -0
  14. fairseq-0.10.2/fairseq/data/__pycache__/numel_dataset.cpython-310.pyc +0 -0
  15. fairseq-0.10.2/fairseq/data/__pycache__/pad_dataset.cpython-310.pyc +0 -0
  16. fairseq-0.10.2/fairseq/data/__pycache__/plasma_utils.cpython-310.pyc +0 -0
  17. fairseq-0.10.2/fairseq/data/__pycache__/prepend_token_dataset.cpython-310.pyc +0 -0
  18. fairseq-0.10.2/fairseq/data/__pycache__/raw_label_dataset.cpython-310.pyc +0 -0
  19. fairseq-0.10.2/fairseq/data/__pycache__/replace_dataset.cpython-310.pyc +0 -0
  20. fairseq-0.10.2/fairseq/data/__pycache__/shorten_dataset.cpython-310.pyc +0 -0
  21. fairseq-0.10.2/fairseq/data/__pycache__/sort_dataset.cpython-310.pyc +0 -0
  22. fairseq-0.10.2/fairseq/data/__pycache__/strip_token_dataset.cpython-310.pyc +0 -0
  23. fairseq-0.10.2/fairseq/data/__pycache__/token_block_dataset.cpython-310.pyc +0 -0
  24. fairseq-0.10.2/fairseq/data/__pycache__/transform_eos_dataset.cpython-310.pyc +0 -0
  25. fairseq-0.10.2/fairseq/data/__pycache__/transform_eos_lang_pair_dataset.cpython-310.pyc +0 -0
  26. fairseq-0.10.2/fairseq/models/__pycache__/composite_encoder.cpython-310.pyc +0 -0
  27. fairseq-0.10.2/fairseq/models/__pycache__/transformer_align.cpython-310.pyc +0 -0
  28. fairseq-0.10.2/fairseq/models/__pycache__/transformer_from_pretrained_xlm.cpython-310.pyc +0 -0
  29. fairseq-0.10.2/fairseq/models/bart/__pycache__/__init__.cpython-310.pyc +0 -0
  30. fairseq-0.10.2/fairseq/models/bart/__pycache__/hub_interface.cpython-310.pyc +0 -0
  31. fairseq-0.10.2/fairseq/models/bart/__pycache__/model.cpython-310.pyc +0 -0
  32. fairseq-0.10.2/fairseq/models/bart/model.py +368 -0
  33. fairseq-0.10.2/fairseq/models/nat/__init__.py +13 -0
  34. fairseq-0.10.2/fairseq/models/nat/__pycache__/cmlm_transformer.cpython-310.pyc +0 -0
  35. fairseq-0.10.2/fairseq/models/nat/cmlm_transformer.py +162 -0
  36. fairseq-0.10.2/fairseq/models/nat/fairseq_nat_model.py +159 -0
  37. fairseq-0.10.2/fairseq/models/nat/insertion_transformer.py +280 -0
  38. fairseq-0.10.2/fairseq/models/nat/iterative_nonautoregressive_transformer.py +228 -0
  39. fairseq-0.10.2/fairseq/models/nat/levenshtein_utils.py +293 -0
  40. fairseq-0.10.2/fairseq/models/nat/nat_crf_transformer.py +121 -0
  41. fairseq-0.10.2/fairseq/models/nat/nonautoregressive_ensembles.py +254 -0
  42. fairseq-0.10.2/fairseq/models/nat/nonautoregressive_transformer.py +440 -0
  43. fairseq-0.10.2/fairseq/models/roberta/__init__.py +9 -0
  44. fairseq-0.10.2/fairseq/models/roberta/__pycache__/model.cpython-310.pyc +0 -0
  45. fairseq-0.10.2/fairseq/models/roberta/alignment_utils.py +118 -0
  46. fairseq-0.10.2/fairseq/models/roberta/hub_interface.py +235 -0
  47. fairseq-0.10.2/fairseq/models/roberta/model.py +524 -0
  48. fairseq-0.10.2/fairseq/models/roberta/model_camembert.py +50 -0
  49. fairseq-0.10.2/fairseq/models/wav2vec/__init__.py +8 -0
  50. fairseq-0.10.2/fairseq/models/wav2vec/__pycache__/__init__.cpython-310.pyc +0 -0
fairseq-0.10.2/fairseq/data/__pycache__/backtranslation_dataset.cpython-310.pyc ADDED
Binary file (6.74 kB).

fairseq-0.10.2/fairseq/data/__pycache__/base_wrapper_dataset.cpython-310.pyc ADDED
Binary file (3.23 kB).

fairseq-0.10.2/fairseq/data/__pycache__/concat_sentences_dataset.cpython-310.pyc ADDED
Binary file (3.34 kB).

fairseq-0.10.2/fairseq/data/__pycache__/data_utils.cpython-310.pyc ADDED
Binary file (15.6 kB).

fairseq-0.10.2/fairseq/data/__pycache__/denoising_dataset.cpython-310.pyc ADDED
Binary file (10.9 kB).

fairseq-0.10.2/fairseq/data/__pycache__/fasta_dataset.cpython-310.pyc ADDED
Binary file (3.83 kB).

fairseq-0.10.2/fairseq/data/__pycache__/iterators.cpython-310.pyc ADDED
Binary file (18.2 kB).

fairseq-0.10.2/fairseq/data/__pycache__/list_dataset.cpython-310.pyc ADDED
Binary file (1.39 kB).

fairseq-0.10.2/fairseq/data/__pycache__/lm_context_window_dataset.cpython-310.pyc ADDED
Binary file (2.99 kB).

fairseq-0.10.2/fairseq/data/__pycache__/lru_cache_dataset.cpython-310.pyc ADDED
Binary file (973 Bytes).

fairseq-0.10.2/fairseq/data/__pycache__/mask_tokens_dataset.cpython-310.pyc ADDED
Binary file (5.1 kB).

fairseq-0.10.2/fairseq/data/__pycache__/noising.cpython-310.pyc ADDED
Binary file (9.37 kB).

fairseq-0.10.2/fairseq/data/__pycache__/num_samples_dataset.cpython-310.pyc ADDED
Binary file (792 Bytes).

fairseq-0.10.2/fairseq/data/__pycache__/numel_dataset.cpython-310.pyc ADDED
Binary file (1.16 kB).

fairseq-0.10.2/fairseq/data/__pycache__/pad_dataset.cpython-310.pyc ADDED
Binary file (1.36 kB).

fairseq-0.10.2/fairseq/data/__pycache__/plasma_utils.cpython-310.pyc ADDED
Binary file (2.57 kB).

fairseq-0.10.2/fairseq/data/__pycache__/prepend_token_dataset.cpython-310.pyc ADDED
Binary file (1.41 kB).

fairseq-0.10.2/fairseq/data/__pycache__/raw_label_dataset.cpython-310.pyc ADDED
Binary file (1.03 kB).

fairseq-0.10.2/fairseq/data/__pycache__/replace_dataset.cpython-310.pyc ADDED
Binary file (1.53 kB).

fairseq-0.10.2/fairseq/data/__pycache__/shorten_dataset.cpython-310.pyc ADDED
Binary file (2.84 kB).

fairseq-0.10.2/fairseq/data/__pycache__/sort_dataset.cpython-310.pyc ADDED
Binary file (1.01 kB).

fairseq-0.10.2/fairseq/data/__pycache__/strip_token_dataset.cpython-310.pyc ADDED
Binary file (922 Bytes).

fairseq-0.10.2/fairseq/data/__pycache__/token_block_dataset.cpython-310.pyc ADDED
Binary file (5.03 kB).

fairseq-0.10.2/fairseq/data/__pycache__/transform_eos_dataset.cpython-310.pyc ADDED
Binary file (4.15 kB).

fairseq-0.10.2/fairseq/data/__pycache__/transform_eos_lang_pair_dataset.cpython-310.pyc ADDED
Binary file (3.68 kB).

fairseq-0.10.2/fairseq/models/__pycache__/composite_encoder.cpython-310.pyc ADDED
Binary file (2.36 kB).

fairseq-0.10.2/fairseq/models/__pycache__/transformer_align.cpython-310.pyc ADDED
Binary file (3 kB).

fairseq-0.10.2/fairseq/models/__pycache__/transformer_from_pretrained_xlm.cpython-310.pyc ADDED
Binary file (5.34 kB).

fairseq-0.10.2/fairseq/models/bart/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (215 Bytes).

fairseq-0.10.2/fairseq/models/bart/__pycache__/hub_interface.cpython-310.pyc ADDED
Binary file (7.4 kB).

fairseq-0.10.2/fairseq/models/bart/__pycache__/model.cpython-310.pyc ADDED
Binary file (9.81 kB).

fairseq-0.10.2/fairseq/models/bart/model.py ADDED
@@ -0,0 +1,368 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+ """
+ BART: Denoising Sequence-to-Sequence Pre-training for
+ Natural Language Generation, Translation, and Comprehension
+ """
+
+ import logging
+
+ import torch
+ import torch.nn as nn
+ from fairseq import utils
+ from fairseq.models import register_model, register_model_architecture
+ from fairseq.models.transformer import TransformerModel
+ from fairseq.modules.transformer_sentence_encoder import init_bert_params
+
+ from .hub_interface import BARTHubInterface
+
+
+ logger = logging.getLogger(__name__)
+
+
+ @register_model("bart")
+ class BARTModel(TransformerModel):
+     @classmethod
+     def hub_models(cls):
+         return {
+             "bart.base": "http://dl.fbaipublicfiles.com/fairseq/models/bart.base.tar.gz",
+             "bart.large": "http://dl.fbaipublicfiles.com/fairseq/models/bart.large.tar.gz",
+             "bart.large.mnli": "http://dl.fbaipublicfiles.com/fairseq/models/bart.large.mnli.tar.gz",
+             "bart.large.cnn": "http://dl.fbaipublicfiles.com/fairseq/models/bart.large.cnn.tar.gz",
+             "bart.large.xsum": "http://dl.fbaipublicfiles.com/fairseq/models/bart.large.xsum.tar.gz",
+         }
+
+     def __init__(self, args, encoder, decoder):
+         super().__init__(args, encoder, decoder)
+
+         # We follow BERT's random weight initialization
+         self.apply(init_bert_params)
+
+         self.classification_heads = nn.ModuleDict()
+
+     @staticmethod
+     def add_args(parser):
+         super(BARTModel, BARTModel).add_args(parser)
+         parser.add_argument(
+             "--pooler-dropout",
+             type=float,
+             metavar="D",
+             help="dropout probability in the masked_lm pooler layers",
+         )
+         parser.add_argument(
+             "--pooler-activation-fn",
+             choices=utils.get_available_activation_fns(),
+             help="activation function to use for pooler layer",
+         )
+         parser.add_argument(
+             "--spectral-norm-classification-head",
+             action="store_true",
+             help="Apply spectral normalization on the classification head",
+         )
+
+     @property
+     def supported_targets(self):
+         return {"self"}
+
+     def forward(
+         self,
+         src_tokens,
+         src_lengths,
+         prev_output_tokens,
+         features_only=False,
+         classification_head_name=None,
+         token_embeddings=None,
+         **kwargs,
+     ):
+         if classification_head_name is not None:
+             features_only = True
+
+         encoder_out = self.encoder(
+             src_tokens,
+             src_lengths=src_lengths,
+             token_embeddings=token_embeddings,
+             **kwargs,
+         )
+         x, extra = self.decoder(
+             prev_output_tokens,
+             encoder_out=encoder_out,
+             features_only=features_only,
+             **kwargs,
+         )
+
+         if classification_head_name is not None:
+             sentence_representation = x[
+                 src_tokens.eq(self.encoder.dictionary.eos()), :
+             ].view(x.size(0), -1, x.size(-1))[:, -1, :]
+             x = self.classification_heads[classification_head_name](
+                 sentence_representation
+             )
+         return x, extra
+
+     @classmethod
+     def from_pretrained(
+         cls,
+         model_name_or_path,
+         checkpoint_file="model.pt",
+         data_name_or_path=".",
+         bpe="gpt2",
+         **kwargs,
+     ):
+         from fairseq import hub_utils
+
+         x = hub_utils.from_pretrained(
+             model_name_or_path,
+             checkpoint_file,
+             data_name_or_path,
+             archive_map=cls.hub_models(),
+             bpe=bpe,
+             load_checkpoint_heads=True,
+             **kwargs,
+         )
+         return BARTHubInterface(x["args"], x["task"], x["models"][0])
+
+     def register_classification_head(
+         self, name, num_classes=None, inner_dim=None, **kwargs
+     ):
+         """Register a classification head."""
+         logger.info("Registering classification head: {0}".format(name))
+         if name in self.classification_heads:
+             prev_num_classes = self.classification_heads[name].out_proj.out_features
+             prev_inner_dim = self.classification_heads[name].dense.out_features
+             if num_classes != prev_num_classes or inner_dim != prev_inner_dim:
+                 logger.warning(
+                     're-registering head "{}" with num_classes {} (prev: {}) '
+                     "and inner_dim {} (prev: {})".format(
+                         name, num_classes, prev_num_classes, inner_dim, prev_inner_dim
+                     )
+                 )
+         self.classification_heads[name] = BARTClassificationHead(
+             input_dim=self.args.encoder_embed_dim,
+             inner_dim=inner_dim or self.args.encoder_embed_dim,
+             num_classes=num_classes,
+             activation_fn=self.args.pooler_activation_fn,
+             pooler_dropout=self.args.pooler_dropout,
+             do_spectral_norm=self.args.spectral_norm_classification_head,
+         )
+
+     def upgrade_state_dict_named(self, state_dict, name):
+         super().upgrade_state_dict_named(state_dict, name)
+
+         prefix = name + "." if name != "" else ""
+         current_head_names = (
+             []
+             if not hasattr(self, "classification_heads")
+             else self.classification_heads.keys()
+         )
+
+         # Handle new classification heads present in the state dict.
+         keys_to_delete = []
+         for k in state_dict.keys():
+             if not k.startswith(prefix + "classification_heads."):
+                 continue
+
+             head_name = k[len(prefix + "classification_heads.") :].split(".")[0]
+             num_classes = state_dict[
+                 prefix + "classification_heads." + head_name + ".out_proj.weight"
+             ].size(0)
+             inner_dim = state_dict[
+                 prefix + "classification_heads." + head_name + ".dense.weight"
+             ].size(0)
+
+             if getattr(self.args, "load_checkpoint_heads", False):
+                 if head_name not in current_head_names:
+                     self.register_classification_head(head_name, num_classes, inner_dim)
+             else:
+                 if head_name not in current_head_names:
+                     logger.warning(
+                         "deleting classification head ({}) from checkpoint "
+                         "not present in current model: {}".format(head_name, k)
+                     )
+                     keys_to_delete.append(k)
+                 elif (
+                     num_classes
+                     != self.classification_heads[head_name].out_proj.out_features
+                     or inner_dim
+                     != self.classification_heads[head_name].dense.out_features
+                 ):
+                     logger.warning(
+                         "deleting classification head ({}) from checkpoint "
+                         "with different dimensions than current model: {}".format(
+                             head_name, k
+                         )
+                     )
+                     keys_to_delete.append(k)
+         for k in keys_to_delete:
+             del state_dict[k]
+
+         def truncate_emb(key):
+             if key in state_dict:
+                 state_dict[key] = state_dict[key][:-1, :]
+
+         # When finetuning on translation task, remove last row of
+         # embedding matrix that corresponds to mask_idx token.
+         loaded_dict_size = state_dict["encoder.embed_tokens.weight"].size(0)
+         if (
+             loaded_dict_size == len(self.encoder.dictionary) + 1
+             and "<mask>" not in self.encoder.dictionary
+         ):
+             truncate_emb("encoder.embed_tokens.weight")
+             truncate_emb("decoder.embed_tokens.weight")
+             truncate_emb("encoder.output_projection.weight")
+             truncate_emb("decoder.output_projection.weight")
+
+         # When continued pretraining on new set of languages for mbart,
+         # add extra lang embeddings at the end of embed_tokens.
+         # Note: newly added languages are assumed to have been added at the end.
+         if self.args.task == "multilingual_denoising" and loaded_dict_size < len(
+             self.encoder.dictionary
+         ):
+             logger.info(
+                 "Adding extra language embeddings not found in pretrained model for "
+                 "continued pretraining of MBART on new set of languages."
+             )
+             loaded_mask_token_embedding = state_dict["encoder.embed_tokens.weight"][
+                 -1, :
+             ]
+
+             num_langids_to_add = len(self.encoder.dictionary) - loaded_dict_size
+             embed_dim = state_dict["encoder.embed_tokens.weight"].size(1)
+
+             new_lang_embed_to_add = torch.zeros(num_langids_to_add, embed_dim)
+             nn.init.normal_(new_lang_embed_to_add, mean=0, std=embed_dim ** -0.5)
+             new_lang_embed_to_add = new_lang_embed_to_add.to(
+                 dtype=state_dict["encoder.embed_tokens.weight"].dtype,
+             )
+
+             state_dict["encoder.embed_tokens.weight"] = torch.cat(
+                 [
+                     state_dict["encoder.embed_tokens.weight"][
+                         : loaded_dict_size - 1, :
+                     ],
+                     new_lang_embed_to_add,
+                     loaded_mask_token_embedding.unsqueeze(0),
+                 ]
+             )
+             state_dict["decoder.embed_tokens.weight"] = torch.cat(
+                 [
+                     state_dict["decoder.embed_tokens.weight"][
+                         : loaded_dict_size - 1, :
+                     ],
+                     new_lang_embed_to_add,
+                     loaded_mask_token_embedding.unsqueeze(0),
+                 ]
+             )
+
+         # Copy any newly-added classification heads into the state dict
+         # with their current weights.
+         if hasattr(self, "classification_heads"):
+             cur_state = self.classification_heads.state_dict()
+             for k, v in cur_state.items():
+                 if prefix + "classification_heads." + k not in state_dict:
+                     logger.info("Overwriting " + prefix + "classification_heads." + k)
+                     state_dict[prefix + "classification_heads." + k] = v
+
+
+ class BARTClassificationHead(nn.Module):
+     """Head for sentence-level classification tasks."""
+
+     def __init__(
+         self,
+         input_dim,
+         inner_dim,
+         num_classes,
+         activation_fn,
+         pooler_dropout,
+         do_spectral_norm=False,
+     ):
+         super().__init__()
+         self.dense = nn.Linear(input_dim, inner_dim)
+         self.activation_fn = utils.get_activation_fn(activation_fn)
+         self.dropout = nn.Dropout(p=pooler_dropout)
+         self.out_proj = nn.Linear(inner_dim, num_classes)
+
+         if do_spectral_norm:
+             self.out_proj = torch.nn.utils.spectral_norm(self.out_proj)
+
+     def forward(self, features, **kwargs):
+         x = features
+         x = self.dropout(x)
+         x = self.dense(x)
+         x = self.activation_fn(x)
+         x = self.dropout(x)
+         x = self.out_proj(x)
+         return x
+
+
+ @register_model_architecture("bart", "bart_large")
+ def bart_large_architecture(args):
+     args.encoder_embed_path = getattr(args, "encoder_embed_path", None)
+     args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
+     args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4 * 1024)
+     args.encoder_layers = getattr(args, "encoder_layers", 12)
+     args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
+     args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
+     args.encoder_learned_pos = getattr(args, "encoder_learned_pos", True)
+     args.decoder_embed_path = getattr(args, "decoder_embed_path", None)
+     args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim)
+     args.decoder_ffn_embed_dim = getattr(
+         args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim
+     )
+     args.decoder_layers = getattr(args, "decoder_layers", 12)
+     args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16)
+     args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False)
+     args.decoder_learned_pos = getattr(args, "decoder_learned_pos", True)
+     args.attention_dropout = getattr(args, "attention_dropout", 0.0)
+     args.relu_dropout = getattr(args, "relu_dropout", 0.0)
+     args.dropout = getattr(args, "dropout", 0.1)
+     args.max_target_positions = getattr(args, "max_target_positions", 1024)
+     args.max_source_positions = getattr(args, "max_source_positions", 1024)
+     args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
+     args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
+     args.share_decoder_input_output_embed = getattr(
+         args, "share_decoder_input_output_embed", True
+     )
+     args.share_all_embeddings = getattr(args, "share_all_embeddings", True)
+
+     args.decoder_output_dim = getattr(
+         args, "decoder_output_dim", args.decoder_embed_dim
+     )
+     args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)
+
+     args.no_scale_embedding = getattr(args, "no_scale_embedding", True)
+     args.layernorm_embedding = getattr(args, "layernorm_embedding", True)
+
+     args.activation_fn = getattr(args, "activation_fn", "gelu")
+     args.pooler_activation_fn = getattr(args, "pooler_activation_fn", "tanh")
+     args.pooler_dropout = getattr(args, "pooler_dropout", 0.0)
+
+
+ @register_model_architecture("bart", "bart_base")
+ def bart_base_architecture(args):
+     args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 768)
+     args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4 * 768)
+     args.encoder_layers = getattr(args, "encoder_layers", 6)
+     args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 12)
+     args.decoder_layers = getattr(args, "decoder_layers", 6)
+     args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 12)
+     bart_large_architecture(args)
+
+
+ @register_model_architecture("bart", "mbart_large")
+ def mbart_large_architecture(args):
+     args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
+     bart_large_architecture(args)
+
+
+ @register_model_architecture("bart", "mbart_base")
+ def mbart_base_architecture(args):
+     args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
+     bart_base_architecture(args)
+
+
+ @register_model_architecture("bart", "mbart_base_wmt20")
+ def mbart_base_wmt20_architecture(args):
+     args.layernorm_embedding = getattr(args, "layernorm_embedding", False)
+     mbart_base_architecture(args)
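
A minimal usage sketch for the `from_pretrained` API added above (the checkpoint directory is a hypothetical local path; `bart.large.mnli` is one of the `hub_models()` entries, and `encode`/`predict` come from the `BARTHubInterface` this classmethod returns):

    # Minimal sketch, assuming a bart.large.mnli checkpoint has been
    # downloaded and unpacked locally (directory path is hypothetical).
    from fairseq.models.bart import BARTModel

    bart = BARTModel.from_pretrained(
        "checkpoints/bart.large.mnli",  # hypothetical directory containing model.pt
        checkpoint_file="model.pt",
    )
    bart.eval()  # disable dropout for deterministic inference

    # The hub interface handles BPE; predict() routes through the "mnli"
    # classification head stored in the checkpoint.
    tokens = bart.encode("BART is a seq2seq model.", "BART is not sequence to sequence.")
    logits = bart.predict("mnli", tokens)
    print(logits.argmax().item())  # index of the predicted MNLI label
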
fairseq-0.10.2/fairseq/models/nat/__init__.py ADDED
@@ -0,0 +1,13 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+ """isort:skip_file"""
+
+ from .fairseq_nat_model import *
+ from .nonautoregressive_transformer import *
+ from .nat_crf_transformer import *
+ from .iterative_nonautoregressive_transformer import *
+ from .cmlm_transformer import *
+ from .levenshtein_transformer import *
+ from .insertion_transformer import *
fairseq-0.10.2/fairseq/models/nat/__pycache__/cmlm_transformer.cpython-310.pyc ADDED
Binary file (4.38 kB).
 
fairseq-0.10.2/fairseq/models/nat/cmlm_transformer.py ADDED
@@ -0,0 +1,162 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ """
+ This file implements:
+ Ghazvininejad, Marjan, et al.
+ "Constant-time machine translation with conditional masked language models."
+ arXiv preprint arXiv:1904.09324 (2019).
+ """
+
+ from fairseq.models import register_model, register_model_architecture
+ from fairseq.models.nat import NATransformerModel
+ from fairseq.utils import new_arange
+
+
+ def _skeptical_unmasking(output_scores, output_masks, p):
+     sorted_index = output_scores.sort(-1)[1]
+     boundary_len = (
+         (output_masks.sum(1, keepdim=True).type_as(output_scores) - 2) * p
+     ).long()
+     skeptical_mask = new_arange(output_masks) < boundary_len
+     return skeptical_mask.scatter(1, sorted_index, skeptical_mask)
+
+
+ @register_model("cmlm_transformer")
+ class CMLMNATransformerModel(NATransformerModel):
+     @staticmethod
+     def add_args(parser):
+         NATransformerModel.add_args(parser)
+
+     def forward(
+         self, src_tokens, src_lengths, prev_output_tokens, tgt_tokens, **kwargs
+     ):
+         assert not self.decoder.src_embedding_copy, "do not support embedding copy."
+
+         # encoding
+         encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
+         # length prediction
+         length_out = self.decoder.forward_length(
+             normalize=False, encoder_out=encoder_out
+         )
+         length_tgt = self.decoder.forward_length_prediction(
+             length_out, encoder_out, tgt_tokens
+         )
+
+         # decoding
+         word_ins_out = self.decoder(
+             normalize=False,
+             prev_output_tokens=prev_output_tokens,
+             encoder_out=encoder_out,
+         )
+         word_ins_mask = prev_output_tokens.eq(self.unk)
+
+         return {
+             "word_ins": {
+                 "out": word_ins_out,
+                 "tgt": tgt_tokens,
+                 "mask": word_ins_mask,
+                 "ls": self.args.label_smoothing,
+                 "nll_loss": True,
+             },
+             "length": {
+                 "out": length_out,
+                 "tgt": length_tgt,
+                 "factor": self.decoder.length_loss_factor,
+             },
+         }
+
+     def forward_decoder(self, decoder_out, encoder_out, decoding_format=None, **kwargs):
+
+         step = decoder_out.step
+         max_step = decoder_out.max_step
+
+         output_tokens = decoder_out.output_tokens
+         output_scores = decoder_out.output_scores
+         history = decoder_out.history
+
+         # execute the decoder
+         output_masks = output_tokens.eq(self.unk)
+         _scores, _tokens = self.decoder(
+             normalize=True,
+             prev_output_tokens=output_tokens,
+             encoder_out=encoder_out,
+         ).max(-1)
+         output_tokens.masked_scatter_(output_masks, _tokens[output_masks])
+         output_scores.masked_scatter_(output_masks, _scores[output_masks])
+
+         if history is not None:
+             history.append(output_tokens.clone())
+
+         # skeptical decoding (depends on the maximum decoding steps)
+         if (step + 1) < max_step:
+             skeptical_mask = _skeptical_unmasking(
+                 output_scores, output_tokens.ne(self.pad), 1 - (step + 1) / max_step
+             )
+
+             output_tokens.masked_fill_(skeptical_mask, self.unk)
+             output_scores.masked_fill_(skeptical_mask, 0.0)
+
+             if history is not None:
+                 history.append(output_tokens.clone())
+
+         return decoder_out._replace(
+             output_tokens=output_tokens,
+             output_scores=output_scores,
+             attn=None,
+             history=history,
+         )
+
+
+ @register_model_architecture("cmlm_transformer", "cmlm_transformer")
+ def cmlm_base_architecture(args):
+     args.encoder_embed_path = getattr(args, "encoder_embed_path", None)
+     args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
+     args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
+     args.encoder_layers = getattr(args, "encoder_layers", 6)
+     args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)
+     args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
+     args.encoder_learned_pos = getattr(args, "encoder_learned_pos", False)
+     args.decoder_embed_path = getattr(args, "decoder_embed_path", None)
+     args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim)
+     args.decoder_ffn_embed_dim = getattr(
+         args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim
+     )
+     args.decoder_layers = getattr(args, "decoder_layers", 6)
+     args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
+     args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False)
+     args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
+     args.attention_dropout = getattr(args, "attention_dropout", 0.0)
+     args.activation_dropout = getattr(args, "activation_dropout", 0.0)
+     args.activation_fn = getattr(args, "activation_fn", "relu")
+     args.dropout = getattr(args, "dropout", 0.1)
+     args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
+     args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
+     args.share_decoder_input_output_embed = getattr(
+         args, "share_decoder_input_output_embed", False
+     )
+     args.share_all_embeddings = getattr(args, "share_all_embeddings", True)
+     args.no_token_positional_embeddings = getattr(
+         args, "no_token_positional_embeddings", False
+     )
+     args.adaptive_input = getattr(args, "adaptive_input", False)
+     args.apply_bert_init = getattr(args, "apply_bert_init", False)
+
+     args.decoder_output_dim = getattr(
+         args, "decoder_output_dim", args.decoder_embed_dim
+     )
+     args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)
+
+     # --- special arguments ---
+     args.sg_length_pred = getattr(args, "sg_length_pred", False)
+     args.pred_length_offset = getattr(args, "pred_length_offset", False)
+     args.length_loss_factor = getattr(args, "length_loss_factor", 0.1)
+     args.ngram_predictor = getattr(args, "ngram_predictor", 1)
+     args.src_embedding_copy = getattr(args, "src_embedding_copy", False)
+
+
+ @register_model_architecture("cmlm_transformer", "cmlm_transformer_wmt_en_de")
+ def cmlm_wmt_en_de(args):
+     cmlm_base_architecture(args)
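
For intuition about `_skeptical_unmasking` above: it re-masks the lowest-scoring `p` fraction of non-pad positions so the next refinement iteration can re-predict them (the mask-predict schedule). A self-contained toy version, using plain `torch.arange` in place of fairseq's `new_arange` so it runs standalone:

    import torch

    def skeptical_unmasking(output_scores, output_masks, p):
        # Re-mask the p fraction of lowest-scoring positions; the "-2"
        # discounts bos/eos from the count, as in the original.
        sorted_index = output_scores.sort(-1)[1]  # ascending: worst positions first
        boundary_len = ((output_masks.sum(1, keepdim=True).float() - 2) * p).long()
        arange = torch.arange(output_masks.size(1)).expand_as(output_masks)
        skeptical_mask = arange < boundary_len  # first k rank slots get re-masked
        return skeptical_mask.scatter(1, sorted_index, skeptical_mask)

    scores = torch.tensor([[0.9, 0.1, 0.8, 0.2, 0.7]])
    masks = torch.ones(1, 5, dtype=torch.bool)
    print(skeptical_unmasking(scores, masks, 2 / 3))
    # tensor([[False,  True, False,  True, False]]) -- the two lowest scores re-masked
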
fairseq-0.10.2/fairseq/models/nat/fairseq_nat_model.py ADDED
@@ -0,0 +1,159 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import math
+
+ import torch
+ from fairseq.models.transformer import (
+     TransformerDecoder,
+     TransformerEncoder,
+     TransformerModel,
+ )
+ from fairseq.modules.transformer_sentence_encoder import init_bert_params
+
+
+ def ensemble_encoder(func):
+     def wrapper(self, *args, **kwargs):
+         if self.ensemble_models is None or len(self.ensemble_models) == 1:
+             return func(self, *args, **kwargs)
+         encoder_outs = [func(model, *args, **kwargs) for model in self.ensemble_models]
+         _encoder_out = encoder_outs[0]
+
+         def stack(key):
+             outs = [getattr(e, key) for e in encoder_outs]
+             return torch.stack(outs, -1) if outs[0] is not None else None
+
+         return _encoder_out._replace(
+             encoder_out=stack("encoder_out"),
+             encoder_embedding=stack("encoder_embedding"),
+             encoder_states=stack("encoder_states"),
+         )
+
+     return wrapper
+
+
+ def ensemble_decoder(func):
+     def wrapper(self, normalize=False, encoder_out=None, *args, **kwargs):
+         if self.ensemble_models is None or len(self.ensemble_models) == 1:
+             return func(
+                 self, normalize=normalize, encoder_out=encoder_out, *args, **kwargs
+             )
+
+         action_outs = [
+             func(
+                 model,
+                 normalize=normalize,
+                 encoder_out=encoder_out._replace(
+                     encoder_out=encoder_out.encoder_out[:, :, :, i]
+                 ),
+                 *args,
+                 **kwargs
+             )
+             for i, model in enumerate(self.ensemble_models)
+         ]
+
+         if not isinstance(action_outs[0], tuple):  # the model returned a single value
+             action_outs = [[a] for a in action_outs]
+         else:
+             action_outs = [list(a) for a in action_outs]
+
+         ensembled_outs = []
+         for i in range(len(action_outs[0])):
+             if i == 0 and normalize:
+                 ensembled_outs += [
+                     torch.logsumexp(
+                         torch.stack([a[i] for a in action_outs], -1), dim=-1
+                     )
+                     - math.log(len(self.ensemble_models))
+                 ]
+             elif action_outs[0][i] is not None:
+                 ensembled_outs += [torch.stack([a[i] for a in action_outs], -1)]
+             else:
+                 ensembled_outs += [None]
+
+         if len(ensembled_outs) == 1:
+             return ensembled_outs[0]
+         return tuple(ensembled_outs)
+
+     return wrapper
+
+
+ class FairseqNATModel(TransformerModel):
+     """
+     Abstract class for all nonautoregressive-based models
+     """
+
+     def __init__(self, args, encoder, decoder):
+         super().__init__(args, encoder, decoder)
+         self.tgt_dict = decoder.dictionary
+         self.bos = decoder.dictionary.bos()
+         self.eos = decoder.dictionary.eos()
+         self.pad = decoder.dictionary.pad()
+         self.unk = decoder.dictionary.unk()
+
+         self.ensemble_models = None
+
+     @property
+     def allow_length_beam(self):
+         return False
+
+     @property
+     def allow_ensemble(self):
+         return True
+
+     def enable_ensemble(self, models):
+         self.encoder.ensemble_models = [m.encoder for m in models]
+         self.decoder.ensemble_models = [m.decoder for m in models]
+
+     @staticmethod
+     def add_args(parser):
+         TransformerModel.add_args(parser)
+         parser.add_argument(
+             "--apply-bert-init",
+             action="store_true",
+             help="use custom param initialization for BERT",
+         )
+
+     @classmethod
+     def build_decoder(cls, args, tgt_dict, embed_tokens):
+         decoder = FairseqNATDecoder(args, tgt_dict, embed_tokens)
+         if getattr(args, "apply_bert_init", False):
+             decoder.apply(init_bert_params)
+         return decoder
+
+     @classmethod
+     def build_encoder(cls, args, src_dict, embed_tokens):
+         encoder = FairseqNATEncoder(args, src_dict, embed_tokens)
+         if getattr(args, "apply_bert_init", False):
+             encoder.apply(init_bert_params)
+         return encoder
+
+     def forward_encoder(self, encoder_inputs):
+         return self.encoder(*encoder_inputs)
+
+     def forward_decoder(self, *args, **kwargs):
+         raise NotImplementedError
+
+     def initialize_output_tokens(self, *args, **kwargs):
+         raise NotImplementedError
+
+     def forward(self, *args, **kwargs):
+         raise NotImplementedError
+
+
+ class FairseqNATEncoder(TransformerEncoder):
+     def __init__(self, args, dictionary, embed_tokens):
+         super().__init__(args, dictionary, embed_tokens)
+         self.ensemble_models = None
+
+     @ensemble_encoder
+     def forward(self, *args, **kwargs):
+         return super().forward(*args, **kwargs)
+
+
+ class FairseqNATDecoder(TransformerDecoder):
+     def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False):
+         super().__init__(args, dictionary, embed_tokens, no_encoder_attn)
+         self.ensemble_models = None
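
The `ensemble_decoder` wrapper above averages normalized model outputs in log space via `logsumexp(stack(log_p), -1) - log(N)`. A small stand-alone check that this is exactly the log of the mean probability:

    import math
    import torch

    # Three fake per-model log-probability tensors (batch=1, len=2, vocab=4).
    log_ps = [torch.log_softmax(torch.randn(1, 2, 4), dim=-1) for _ in range(3)]

    # Ensemble rule used above: logsumexp over the model dimension, minus log(N).
    ensembled = torch.logsumexp(torch.stack(log_ps, -1), dim=-1) - math.log(len(log_ps))

    # Reference: average the probabilities directly, then take the log.
    reference = torch.log(torch.stack(log_ps, -1).exp().mean(-1))

    assert torch.allclose(ensembled, reference, atol=1e-6)
    print("log-space ensembling equals the mean-probability ensemble")
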
fairseq-0.10.2/fairseq/models/nat/insertion_transformer.py ADDED
@@ -0,0 +1,280 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ from fairseq.models import register_model, register_model_architecture
+ from fairseq.models.nat import (
+     FairseqNATModel,
+     LevenshteinTransformerDecoder,
+     LevenshteinTransformerModel,
+     ensemble_decoder,
+ )
+ from fairseq.models.transformer import Linear
+ from fairseq.modules.transformer_sentence_encoder import init_bert_params
+ from fairseq.utils import new_arange
+
+
+ class NegativeDistanceScore(object):
+     def __init__(self):
+
+         # pre-compute some values
+         self.scores = {}
+
+         self.scores[0.5] = self.compute_score_full(50, 0.5)
+         self.scores[1.0] = self.compute_score_full(50, 1.0)
+         self.scores[2.0] = self.compute_score_full(50, 2.0)
+
+     def __call__(self, i, L, tau):
+         if (tau is None) or (tau > 1000):
+             return 1 / L
+
+         if tau in self.scores:
+             if L < self.scores[tau].shape[0]:
+                 return self.scores[tau][L - 1, i]
+         return self.compute_score(L, tau)[i]
+
+     def compute_score(self, L, tau):
+         s = np.array([-abs(L / 2 - i) / tau for i in range(L)])
+         s = np.exp(s - s.max())
+         return s / s.sum()
+
+     def compute_score_full(self, L, tau):
+         s = -abs(np.arange(0, L - 1)[:, None] / 2 - np.arange(L)[None, :]) / tau
+         s = np.tril(s, 0) + np.triu(s - float("inf"), 1)
+         s = np.exp(s - s.max(1, keepdims=True))
+         return s / s.sum(1, keepdims=True)
+
+
+ neg_scorer = NegativeDistanceScore()
+
+
+ def _get_ins_targets(in_tokens, out_tokens, padding_idx, unk_idx, vocab_size, tau=None):
+     try:
+         from fairseq import libnat
+     except ImportError as e:
+         import sys
+
+         sys.stderr.write("ERROR: missing libnat. run `pip install --editable .`\n")
+         raise e
+
+     B = in_tokens.size(0)
+     T = in_tokens.size(1)
+     V = vocab_size
+
+     with torch.cuda.device_of(in_tokens):
+         in_tokens_list = [
+             [t for t in s if t != padding_idx] for i, s in enumerate(in_tokens.tolist())
+         ]
+         out_tokens_list = [
+             [t for t in s if t != padding_idx]
+             for i, s in enumerate(out_tokens.tolist())
+         ]
+
+     full_labels = libnat.suggested_ed2_path(
+         in_tokens_list, out_tokens_list, padding_idx
+     )
+     insert_labels = [a[:-1] for a in full_labels]
+
+     # numericalize
+     insert_label_tensors = in_tokens.new_zeros(B * (T - 1) * V).float()
+     insert_index, insert_labels = zip(
+         *[
+             (w + (j + i * (T - 1)) * V, neg_scorer(k, len(label), tau))
+             for i, labels in enumerate(insert_labels)
+             for j, label in enumerate(labels[1:-1])
+             for k, w in enumerate(label)
+         ]
+     )  # HACK 1:-1
+     insert_index, insert_labels = [
+         torch.tensor(list(a), device=in_tokens.device)
+         for a in [insert_index, insert_labels]
+     ]
+     insert_label_tensors.scatter_(0, insert_index.long(), insert_labels)
+     insert_label_tensors = insert_label_tensors.view(B, T - 1, V)
+
+     return insert_label_tensors
+
+
+ def _apply_ins_words(in_tokens, in_scores, word_ins_pred, word_ins_scores, padding_idx):
+
+     padding_masks = in_tokens[:, 1:].eq(padding_idx)
+     word_ins_scores.masked_fill_(padding_masks, 0.0)
+     word_ins_pred.masked_fill_(padding_masks, padding_idx)
+
+     in_coords = new_arange(in_tokens).type_as(in_scores)
+
+     # shift all padding predictions to infinite
+     out_coords = (in_coords[:, 1:] - 0.5).masked_fill(
+         word_ins_pred.eq(padding_idx), float("inf")
+     )
+     out_coords = torch.cat([in_coords, out_coords], 1).sort(-1)[1]
+     out_tokens = torch.cat([in_tokens, word_ins_pred], 1).gather(1, out_coords)
+     out_scores = torch.cat([in_scores, word_ins_scores], 1).gather(1, out_coords)
+     return out_tokens, out_scores
+
+
+ @register_model("insertion_transformer")
+ class InsertionTransformerModel(LevenshteinTransformerModel):
+     def __init__(self, args, encoder, decoder):
+         super().__init__(args, encoder, decoder)
+
+     @staticmethod
+     def add_args(parser):
+         FairseqNATModel.add_args(parser)
+         parser.add_argument("--label-tau", default=None, type=float)
+
+     @classmethod
+     def build_decoder(cls, args, tgt_dict, embed_tokens):
+         decoder = InsertionTransformerDecoder(args, tgt_dict, embed_tokens)
+         if getattr(args, "apply_bert_init", False):
+             decoder.apply(init_bert_params)
+         return decoder
+
+     def forward(
+         self, src_tokens, src_lengths, prev_output_tokens, tgt_tokens, **kwargs
+     ):
+
+         assert tgt_tokens is not None, "forward function only supports training."
+
+         # encoding
+         encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
+
+         # generate training labels for insertion
+         word_ins_out = self.decoder.forward_word_ins(
+             normalize=False,
+             prev_output_tokens=prev_output_tokens,
+             encoder_out=encoder_out,
+         )
+
+         word_ins_tgt = _get_ins_targets(
+             prev_output_tokens,
+             tgt_tokens,
+             self.pad,
+             self.unk,
+             len(self.tgt_dict),
+             tau=self.decoder.label_tau,
+         ).type_as(word_ins_out)
+         word_ins_masks = prev_output_tokens[:, 1:].ne(self.pad)
+
+         return {
+             "word_ins": {
+                 "out": word_ins_out,
+                 "tgt": word_ins_tgt,
+                 "mask": word_ins_masks,
+                 "ls": self.args.label_smoothing,
+                 "nll_loss": True,
+             }
+         }
+
+     def forward_decoder(
+         self, decoder_out, encoder_out, eos_penalty=0.0, max_ratio=None, **kwargs
+     ):
+
+         output_tokens = decoder_out.output_tokens
+         output_scores = decoder_out.output_scores
+         history = decoder_out.history
+
+         # TODO: decoding for InsertionTransformer
+         word_ins_score = self.decoder.forward_word_ins(
+             normalize=True, prev_output_tokens=output_tokens, encoder_out=encoder_out
+         )
+
+         if eos_penalty > 0.0:
+             word_ins_score[:, :, self.pad] -= eos_penalty
+         word_ins_score, word_ins_pred = word_ins_score.max(-1)
+         output_tokens, output_scores = _apply_ins_words(
+             output_tokens, output_scores, word_ins_pred, word_ins_score, self.pad
+         )
+
+         # delete some unnecessary paddings
+         cut_off = output_tokens.ne(self.pad).sum(1).max()
+         output_tokens = output_tokens[:, :cut_off]
+         output_scores = output_scores[:, :cut_off]
+
+         if history is not None:
+             history.append(output_tokens.clone())
+
+         return decoder_out._replace(
+             output_tokens=output_tokens,
+             output_scores=output_scores,
+             attn=None,
+             history=history,
+         )
+
+
+ class InsertionTransformerDecoder(LevenshteinTransformerDecoder):
+     def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False):
+         # use the TransformerDecoder's __init__
+         super(LevenshteinTransformerDecoder, self).__init__(
+             args, dictionary, embed_tokens, no_encoder_attn=no_encoder_attn
+         )
+
+         self.dictionary = dictionary
+         self.bos = dictionary.bos()
+         self.unk = dictionary.unk()
+         self.eos = dictionary.eos()
+         self.pool_out = Linear(self.output_embed_dim * 2, self.output_embed_dim)
+
+         self.label_tau = getattr(args, "label_tau", None)
+
+     @ensemble_decoder
+     def forward_word_ins(self, normalize, encoder_out, prev_output_tokens):
+         features = self.extract_features(prev_output_tokens, encoder_out=encoder_out)[0]
+         features = self.pool_out(
+             torch.cat([features[:, :-1, :], features[:, 1:, :]], 2)
+         )
+         decoder_out = self.output_layer(features)
+         return F.log_softmax(decoder_out, -1) if normalize else decoder_out
+
+     def forward_mask_ins(self, *args, **kwargs):
+         raise NotImplementedError
+
+     def forward_word_del(self, *args, **kwargs):
+         raise NotImplementedError
+
+
+ @register_model_architecture("insertion_transformer", "insertion_transformer")
+ def insertion_base_architecture(args):
+     args.encoder_embed_path = getattr(args, "encoder_embed_path", None)
+     args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
+     args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
+     args.encoder_layers = getattr(args, "encoder_layers", 6)
+     args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)
+     args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
+     args.encoder_learned_pos = getattr(args, "encoder_learned_pos", False)
+     args.decoder_embed_path = getattr(args, "decoder_embed_path", None)
+     args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim)
+     args.decoder_ffn_embed_dim = getattr(
+         args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim
+     )
+     args.decoder_layers = getattr(args, "decoder_layers", 6)
+     args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
+     args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False)
+     args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
+     args.attention_dropout = getattr(args, "attention_dropout", 0.0)
+     args.activation_dropout = getattr(args, "activation_dropout", 0.0)
+     args.activation_fn = getattr(args, "activation_fn", "relu")
+     args.dropout = getattr(args, "dropout", 0.1)
+     args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
+     args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
+     args.share_decoder_input_output_embed = getattr(
+         args, "share_decoder_input_output_embed", False
+     )
+     args.share_all_embeddings = getattr(args, "share_all_embeddings", False)
+     args.no_token_positional_embeddings = getattr(
+         args, "no_token_positional_embeddings", False
+     )
+     args.adaptive_input = getattr(args, "adaptive_input", False)
+     args.apply_bert_init = getattr(args, "apply_bert_init", False)
+
+     args.decoder_output_dim = getattr(
+         args, "decoder_output_dim", args.decoder_embed_dim
+     )
+     args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)
+
+     # special for insertion transformer
+     args.label_tau = getattr(args, "label_tau", None)
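
`NegativeDistanceScore` above turns a span of insertable tokens into a soft target distribution that favors the center position (the balanced-binary-tree ordering of Stern et al., 2019). A stand-alone sketch of the same softmax-over-negative-distance computation as `compute_score`:

    import numpy as np

    def slot_weights(L, tau):
        # Weight position i of an L-token span by softmax(-|L/2 - i| / tau):
        # center insertions get the most credit; tau -> inf tends to uniform 1/L.
        s = np.array([-abs(L / 2 - i) / tau for i in range(L)])
        s = np.exp(s - s.max())
        return s / s.sum()

    for tau in (0.5, 2.0, 1000.0):
        print(tau, np.round(slot_weights(5, tau), 3))
    # Small tau concentrates mass on the middle tokens; large tau flattens
    # toward uniform, matching the `tau > 1000` shortcut in __call__ above.
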
fairseq-0.10.2/fairseq/models/nat/iterative_nonautoregressive_transformer.py ADDED
@@ -0,0 +1,228 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import torch
+ from fairseq.models import register_model, register_model_architecture
+ from fairseq.models.nat import NATransformerModel
+
+
+ def _sequential_poisoning(s, V, beta=0.33, bos=2, eos=3, pad=1):
+     # s: input batch
+     # V: vocabulary size
+     rand_words = torch.randint(low=4, high=V, size=s.size(), device=s.device)
+     choices = torch.rand(size=s.size(), device=s.device)
+     choices.masked_fill_((s == pad) | (s == bos) | (s == eos), 1)
+
+     replace = choices < beta / 3
+     repeat = (choices >= beta / 3) & (choices < beta * 2 / 3)
+     swap = (choices >= beta * 2 / 3) & (choices < beta)
+     safe = choices >= beta
+
+     for i in range(s.size(1) - 1):
+         rand_word = rand_words[:, i]
+         next_word = s[:, i + 1]
+         self_word = s[:, i]
+
+         replace_i = replace[:, i]
+         swap_i = swap[:, i] & (next_word != 3)
+         repeat_i = repeat[:, i] & (next_word != 3)
+         safe_i = safe[:, i] | ((next_word == 3) & (~replace_i))
+
+         s[:, i] = (
+             self_word * (safe_i | repeat_i).long()
+             + next_word * swap_i.long()
+             + rand_word * replace_i.long()
+         )
+         s[:, i + 1] = (
+             next_word * (safe_i | replace_i).long()
+             + self_word * (swap_i | repeat_i).long()
+         )
+     return s
+
+
+ def gumbel_noise(input, TINY=1e-8):
+     return (
+         input.new_zeros(*input.size())
+         .uniform_()
+         .add_(TINY)
+         .log_()
+         .neg_()
+         .add_(TINY)
+         .log_()
+         .neg_()
+     )
+
+
+ @register_model("iterative_nonautoregressive_transformer")
+ class IterNATransformerModel(NATransformerModel):
+     @staticmethod
+     def add_args(parser):
+         NATransformerModel.add_args(parser)
+         parser.add_argument(
+             "--train-step",
+             type=int,
+             help="number of refinement iterations during training",
+         )
+         parser.add_argument(
+             "--dae-ratio",
+             type=float,
+             help="the probability of switching to the denoising auto-encoder loss",
+         )
+         parser.add_argument(
+             "--stochastic-approx",
+             action="store_true",
+             help="sampling from the decoder as the inputs for next iteration",
+         )
+
+     @classmethod
+     def build_model(cls, args, task):
+         model = super().build_model(args, task)
+         model.train_step = getattr(args, "train_step", 4)
+         model.dae_ratio = getattr(args, "dae_ratio", 0.5)
+         model.stochastic_approx = getattr(args, "stochastic_approx", False)
+         return model
+
+     def forward(
+         self, src_tokens, src_lengths, prev_output_tokens, tgt_tokens, **kwargs
+     ):
+
+         B, T = prev_output_tokens.size()
+
+         # encoding
+         encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
+
+         # length prediction
+         length_out = self.decoder.forward_length(
+             normalize=False, encoder_out=encoder_out
+         )
+         length_tgt = self.decoder.forward_length_prediction(
+             length_out, encoder_out, tgt_tokens
+         )
+
+         # decoding
+         word_ins_outs, word_ins_tgts, word_ins_masks = [], [], []
+         for t in range(self.train_step):
+             word_ins_out = self.decoder(
+                 normalize=False,
+                 prev_output_tokens=prev_output_tokens,
+                 encoder_out=encoder_out,
+                 step=t,
+             )
+             word_ins_tgt = tgt_tokens
+             word_ins_mask = word_ins_tgt.ne(self.pad)
+
+             word_ins_outs.append(word_ins_out)
+             word_ins_tgts.append(word_ins_tgt)
+             word_ins_masks.append(word_ins_mask)
+
+             if t < (self.train_step - 1):
+                 # prediction for next iteration
+                 if self.stochastic_approx:
+                     word_ins_prediction = (
+                         word_ins_out + gumbel_noise(word_ins_out)
+                     ).max(-1)[1]
+                 else:
+                     word_ins_prediction = word_ins_out.max(-1)[1]
+
+                 prev_output_tokens = prev_output_tokens.masked_scatter(
+                     word_ins_mask, word_ins_prediction[word_ins_mask]
+                 )
+
+                 if self.dae_ratio > 0:
+                     # we do not perform denoising for the first iteration
+                     corrupted = (
+                         torch.rand(size=(B,), device=prev_output_tokens.device)
+                         < self.dae_ratio
+                     )
+                     corrupted_tokens = _sequential_poisoning(
+                         tgt_tokens[corrupted],
+                         len(self.tgt_dict),
+                         0.33,
+                         self.bos,
+                         self.eos,
+                         self.pad,
+                     )
+                     prev_output_tokens[corrupted] = corrupted_tokens
+
+         # concat everything
+         word_ins_out = torch.cat(word_ins_outs, 0)
+         word_ins_tgt = torch.cat(word_ins_tgts, 0)
+         word_ins_mask = torch.cat(word_ins_masks, 0)
+
+         return {
+             "word_ins": {
+                 "out": word_ins_out,
+                 "tgt": word_ins_tgt,
+                 "mask": word_ins_mask,
+                 "ls": self.args.label_smoothing,
+                 "nll_loss": True,
+             },
+             "length": {
+                 "out": length_out,
+                 "tgt": length_tgt,
+                 "factor": self.decoder.length_loss_factor,
+             },
+         }
+
+
+ @register_model_architecture(
+     "iterative_nonautoregressive_transformer", "iterative_nonautoregressive_transformer"
+ )
+ def inat_base_architecture(args):
+     args.encoder_embed_path = getattr(args, "encoder_embed_path", None)
+     args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
+     args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
+     args.encoder_layers = getattr(args, "encoder_layers", 6)
+     args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)
+     args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
+     args.encoder_learned_pos = getattr(args, "encoder_learned_pos", False)
+     args.decoder_embed_path = getattr(args, "decoder_embed_path", None)
+     args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim)
+     args.decoder_ffn_embed_dim = getattr(
+         args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim
+     )
+     args.decoder_layers = getattr(args, "decoder_layers", 6)
+     args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
+     args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False)
+     args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
+     args.attention_dropout = getattr(args, "attention_dropout", 0.0)
+     args.activation_dropout = getattr(args, "activation_dropout", 0.0)
+     args.activation_fn = getattr(args, "activation_fn", "relu")
+     args.dropout = getattr(args, "dropout", 0.1)
+     args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
+     args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
+     args.share_decoder_input_output_embed = getattr(
+         args, "share_decoder_input_output_embed", False
+     )
+     args.share_all_embeddings = getattr(args, "share_all_embeddings", False)
+     args.no_token_positional_embeddings = getattr(
+         args, "no_token_positional_embeddings", False
+     )
+     args.adaptive_input = getattr(args, "adaptive_input", False)
+     args.apply_bert_init = getattr(args, "apply_bert_init", False)
+
+     args.decoder_output_dim = getattr(
+         args, "decoder_output_dim", args.decoder_embed_dim
+     )
+     args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)
+
+     # --- special arguments ---
+     args.sg_length_pred = getattr(args, "sg_length_pred", False)
+     args.pred_length_offset = getattr(args, "pred_length_offset", False)
+     args.length_loss_factor = getattr(args, "length_loss_factor", 0.1)
+     args.ngram_predictor = getattr(args, "ngram_predictor", 1)
+     args.src_embedding_copy = getattr(args, "src_embedding_copy", False)
+
+     args.train_step = getattr(args, "train_step", 4)
+     args.dae_ratio = getattr(args, "dae_ratio", 0.5)
+     args.stochastic_approx = getattr(args, "stochastic_approx", False)
+
+
+ @register_model_architecture(
+     "iterative_nonautoregressive_transformer",
+     "iterative_nonautoregressive_transformer_wmt_en_de",
+ )
+ def iter_nat_wmt_en_de(args):
+     inat_base_architecture(args)
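
`gumbel_noise` above implements the noise term of the Gumbel-max trick used by `--stochastic-approx`: adding `-log(-log(U))` noise to logits and taking the argmax draws a sample from the softmax distribution. A quick empirical check of that equivalence:

    import torch

    def gumbel_noise(x, tiny=1e-8):
        # -log(-log(U)) noise, matching the chained in-place version above.
        u = torch.rand_like(x)
        return -torch.log(-torch.log(u + tiny) + tiny)

    logits = torch.tensor([2.0, 0.5, 0.0])
    samples = torch.stack([(logits + gumbel_noise(logits)).argmax() for _ in range(20000)])
    empirical = torch.bincount(samples, minlength=3).float() / samples.numel()
    print("softmax   :", torch.softmax(logits, -1))
    print("empirical :", empirical)  # should closely match the softmax probabilities
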
fairseq-0.10.2/fairseq/models/nat/levenshtein_utils.py ADDED
@@ -0,0 +1,293 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import torch
+ from fairseq.utils import new_arange
+
+
+ # -------------- Helper Functions --------------------------------------------------- #
+
+
+ def load_libnat():
+     try:
+         from fairseq import libnat_cuda
+
+         return libnat_cuda, True
+
+     except ImportError as e:
+         print(str(e) + "... fall back to CPU version")
+
+         try:
+             from fairseq import libnat
+
+             return libnat, False
+
+         except ImportError as e:
+             import sys
+
+             sys.stderr.write(
+                 "ERROR: missing libnat_cuda. run `python setup.py build_ext --inplace`\n"
+             )
+             raise e
+
+
+ def _get_ins_targets(in_tokens, out_tokens, padding_idx, unk_idx):
+     libnat, use_cuda = load_libnat()
+
+     def _get_ins_targets_cuda(in_tokens, out_tokens, padding_idx, unk_idx):
+         in_masks = in_tokens.ne(padding_idx)
+         out_masks = out_tokens.ne(padding_idx)
+         mask_ins_targets, masked_tgt_masks = libnat.generate_insertion_labels(
+             out_tokens.int(),
+             libnat.levenshtein_distance(
+                 in_tokens.int(),
+                 out_tokens.int(),
+                 in_masks.sum(1).int(),
+                 out_masks.sum(1).int(),
+             ),
+         )
+         masked_tgt_masks = masked_tgt_masks.bool() & out_masks
+         mask_ins_targets = mask_ins_targets.type_as(in_tokens)[
+             :, 1 : in_masks.size(1)
+         ].masked_fill_(~in_masks[:, 1:], 0)
+         masked_tgt_tokens = out_tokens.masked_fill(masked_tgt_masks, unk_idx)
+         return masked_tgt_masks, masked_tgt_tokens, mask_ins_targets
+
+     def _get_ins_targets_cpu(in_tokens, out_tokens, padding_idx, unk_idx):
+         in_seq_len, out_seq_len = in_tokens.size(1), out_tokens.size(1)
+
+         in_tokens_list = [
+             [t for t in s if t != padding_idx] for i, s in enumerate(in_tokens.tolist())
+         ]
+         out_tokens_list = [
+             [t for t in s if t != padding_idx]
+             for i, s in enumerate(out_tokens.tolist())
+         ]
+
+         full_labels = libnat.suggested_ed2_path(
+             in_tokens_list, out_tokens_list, padding_idx
+         )
+         mask_inputs = [
+             [len(c) if c[0] != padding_idx else 0 for c in a[:-1]] for a in full_labels
+         ]
+
+         # generate labels
+         masked_tgt_masks = []
+         for mask_input in mask_inputs:
+             mask_label = []
+             for beam_size in mask_input[1:-1]:  # HACK 1:-1
+                 mask_label += [0] + [1 for _ in range(beam_size)]
+             masked_tgt_masks.append(
+                 mask_label + [0 for _ in range(out_seq_len - len(mask_label))]
+             )
+         mask_ins_targets = [
+             mask_input[1:-1]
+             + [0 for _ in range(in_seq_len - 1 - len(mask_input[1:-1]))]
+             for mask_input in mask_inputs
+         ]
+
+         # transform to tensor
+         masked_tgt_masks = torch.tensor(
+             masked_tgt_masks, device=out_tokens.device
+         ).bool()
+         mask_ins_targets = torch.tensor(mask_ins_targets, device=in_tokens.device)
+         masked_tgt_tokens = out_tokens.masked_fill(masked_tgt_masks, unk_idx)
+         return masked_tgt_masks, masked_tgt_tokens, mask_ins_targets
+
+     if use_cuda:
+         return _get_ins_targets_cuda(in_tokens, out_tokens, padding_idx, unk_idx)
+     return _get_ins_targets_cpu(in_tokens, out_tokens, padding_idx, unk_idx)
+
+
+ def _get_del_targets(in_tokens, out_tokens, padding_idx):
+     libnat, use_cuda = load_libnat()
+
+     def _get_del_targets_cuda(in_tokens, out_tokens, padding_idx):
+         in_masks = in_tokens.ne(padding_idx)
+         out_masks = out_tokens.ne(padding_idx)
+
+         word_del_targets = libnat.generate_deletion_labels(
+             in_tokens.int(),
+             libnat.levenshtein_distance(
+                 in_tokens.int(),
+                 out_tokens.int(),
+                 in_masks.sum(1).int(),
+                 out_masks.sum(1).int(),
+             ),
+         )
+         word_del_targets = word_del_targets.type_as(in_tokens).masked_fill_(
+             ~in_masks, 0
+         )
+         return word_del_targets
+
+     def _get_del_targets_cpu(in_tokens, out_tokens, padding_idx):
+         out_seq_len = out_tokens.size(1)
+         with torch.cuda.device_of(in_tokens):
+             in_tokens_list = [
+                 [t for t in s if t != padding_idx]
+                 for i, s in enumerate(in_tokens.tolist())
+             ]
+             out_tokens_list = [
+                 [t for t in s if t != padding_idx]
+                 for i, s in enumerate(out_tokens.tolist())
+             ]
+
+         full_labels = libnat.suggested_ed2_path(
+             in_tokens_list, out_tokens_list, padding_idx
+         )
+         word_del_targets = [b[-1] for b in full_labels]
+         word_del_targets = [
+             labels + [0 for _ in range(out_seq_len - len(labels))]
+             for labels in word_del_targets
+         ]
+
+         # transform to tensor
+         word_del_targets = torch.tensor(word_del_targets, device=out_tokens.device)
+         return word_del_targets
+
+     if use_cuda:
+         return _get_del_targets_cuda(in_tokens, out_tokens, padding_idx)
+     return _get_del_targets_cpu(in_tokens, out_tokens, padding_idx)
+
+
+ def _apply_ins_masks(
+     in_tokens, in_scores, mask_ins_pred, padding_idx, unk_idx, eos_idx
+ ):
+
+     in_masks = in_tokens.ne(padding_idx)
+     in_lengths = in_masks.sum(1)
+
+     # HACK: hacky way to shift all the paddings to eos first.
+     in_tokens.masked_fill_(~in_masks, eos_idx)
+     mask_ins_pred.masked_fill_(~in_masks[:, 1:], 0)
+
+     out_lengths = in_lengths + mask_ins_pred.sum(1)
+     out_max_len = out_lengths.max()
+     out_masks = new_arange(out_lengths, out_max_len)[None, :] < out_lengths[:, None]
+
+     reordering = (mask_ins_pred + in_masks[:, 1:].long()).cumsum(1)
+     out_tokens = (
+         in_tokens.new_zeros(in_tokens.size(0), out_max_len)
+         .fill_(padding_idx)
+         .masked_fill_(out_masks, unk_idx)
+     )
+     out_tokens[:, 0] = in_tokens[:, 0]
+     out_tokens.scatter_(1, reordering, in_tokens[:, 1:])
+
+     out_scores = None
+     if in_scores is not None:
+         in_scores.masked_fill_(~in_masks, 0)
+         out_scores = in_scores.new_zeros(*out_tokens.size())
+         out_scores[:, 0] = in_scores[:, 0]
+         out_scores.scatter_(1, reordering, in_scores[:, 1:])
+
+     return out_tokens, out_scores
+
+
+ def _apply_ins_words(in_tokens, in_scores, word_ins_pred, word_ins_scores, unk_idx):
+     word_ins_masks = in_tokens.eq(unk_idx)
+     out_tokens = in_tokens.masked_scatter(word_ins_masks, word_ins_pred[word_ins_masks])
+
+     if in_scores is not None:
+         out_scores = in_scores.masked_scatter(
+             word_ins_masks, word_ins_scores[word_ins_masks]
+         )
+     else:
+         out_scores = None
+
+     return out_tokens, out_scores
+
+
+ def _apply_del_words(
+     in_tokens, in_scores, in_attn, word_del_pred, padding_idx, bos_idx, eos_idx
+ ):
+     # apply deletion to a tensor
+     in_masks = in_tokens.ne(padding_idx)
+     bos_eos_masks = in_tokens.eq(bos_idx) | in_tokens.eq(eos_idx)
+
+     max_len = in_tokens.size(1)
+     word_del_pred.masked_fill_(~in_masks, 1)
+     word_del_pred.masked_fill_(bos_eos_masks, 0)
+
+     reordering = new_arange(in_tokens).masked_fill_(word_del_pred, max_len).sort(1)[1]
+
+     out_tokens = in_tokens.masked_fill(word_del_pred, padding_idx).gather(1, reordering)
+
+     out_scores = None
+     if in_scores is not None:
+         out_scores = in_scores.masked_fill(word_del_pred, 0).gather(1, reordering)
+
+     out_attn = None
+     if in_attn is not None:
+         _mask = word_del_pred[:, :, None].expand_as(in_attn)
+         _reordering = reordering[:, :, None].expand_as(in_attn)
+         out_attn = in_attn.masked_fill(_mask, 0.0).gather(1, _reordering)
+
+     return out_tokens, out_scores, out_attn
+
+
+ def _skip(x, mask):
+     """
+     Slice a tensor along dim 0 by mask; supports tensors and lists/dicts of tensors.
+     """
+     if isinstance(x, int):
+         return x
+
+     if x is None:
+         return None
+
+     if isinstance(x, torch.Tensor):
+         if x.size(0) == mask.size(0):
+             return x[mask]
+         elif x.size(1) == mask.size(0):
+             return x[:, mask]
+
+     if isinstance(x, list):
+         return [_skip(x_i, mask) for x_i in x]
+
+     if isinstance(x, dict):
+         return {k: _skip(v, mask) for k, v in x.items()}
+
+     raise NotImplementedError
+
+
+ def _skip_encoder_out(encoder, encoder_out, mask):
+     if not mask.any():
+         return encoder_out
+     else:
+         return encoder.reorder_encoder_out(
+             encoder_out, mask.nonzero(as_tuple=False).squeeze()
+         )
+
+
+ def _fill(x, mask, y, padding_idx):
+     """
+     Fill tensor x with y at the masked positions (dim 0).
+     """
+     if x is None:
+         return y
+     assert x.dim() == y.dim() and mask.size(0) == x.size(0)
+     assert x.dim() == 2 or (x.dim() == 3 and x.size(2) == y.size(2))
+     n_selected = mask.sum()
+     assert n_selected == y.size(0)
+
+     if n_selected == x.size(0):
+         return y
+
+     if x.size(1) < y.size(1):
+         dims = [x.size(0), y.size(1) - x.size(1)]
+         if x.dim() == 3:
+             dims.append(x.size(2))
+         x = torch.cat([x, x.new_zeros(*dims).fill_(padding_idx)], 1)
+         x[mask] = y
+     elif x.size(1) > y.size(1):
+         x[mask] = padding_idx
+         if x.dim() == 2:
+             x[mask, : y.size(1)] = y
+         else:
+             x[mask, : y.size(1), :] = y
+     else:
+         x[mask] = y
+     return x
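The `_skip`/`_fill` pair above implements a select-update-write-back pattern: slice out the rows (sentences) that are still active, run a refinement step on just those, then merge the results back into the full batch. A toy round-trip, assuming it runs where `_skip` and `_fill` are in scope (e.g. this module, or imported from fairseq.models.nat):

    import torch

    x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
    mask = torch.tensor([True, False, True])

    active = _skip(x, mask)     # select rows 0 and 2
    active = active + 100       # stand-in for one refinement step
    x = _fill(x, mask, active, padding_idx=0)
    print(x)                    # rows 0 and 2 updated, row 1 untouched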
fairseq-0.10.2/fairseq/models/nat/nat_crf_transformer.py ADDED
@@ -0,0 +1,121 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+
+ from fairseq.models import register_model, register_model_architecture
+ from fairseq.models.nat import NATransformerModel, base_architecture
+ from fairseq.modules import DynamicCRF
+
+
+ @register_model("nacrf_transformer")
+ class NACRFTransformerModel(NATransformerModel):
+     def __init__(self, args, encoder, decoder):
+         super().__init__(args, encoder, decoder)
+         self.crf_layer = DynamicCRF(
+             num_embedding=len(self.tgt_dict),
+             low_rank=args.crf_lowrank_approx,
+             beam_size=args.crf_beam_approx,
+         )
+
+     @property
+     def allow_ensemble(self):
+         return False
+
+     @staticmethod
+     def add_args(parser):
+         NATransformerModel.add_args(parser)
+         parser.add_argument(
+             "--crf-lowrank-approx",
+             type=int,
+             help="the dimension of the low-rank approximation of the transition matrix",
+         )
+         parser.add_argument(
+             "--crf-beam-approx",
+             type=int,
+             help="the beam size for approximating the normalizing factor",
+         )
+         parser.add_argument(
+             "--word-ins-loss-factor",
+             type=float,
+             help="weight on the NAT loss used for co-training with the CRF loss",
+         )
+
+     def forward(
+         self, src_tokens, src_lengths, prev_output_tokens, tgt_tokens, **kwargs
+     ):
+         # encoding
+         encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
+
+         # length prediction
+         length_out = self.decoder.forward_length(
+             normalize=False, encoder_out=encoder_out
+         )
+         length_tgt = self.decoder.forward_length_prediction(
+             length_out, encoder_out, tgt_tokens
+         )
+
+         # decoding
+         word_ins_out = self.decoder(
+             normalize=False,
+             prev_output_tokens=prev_output_tokens,
+             encoder_out=encoder_out,
+         )
+         word_ins_tgt, word_ins_mask = tgt_tokens, tgt_tokens.ne(self.pad)
+
+         # compute the log-likelihood of CRF
+         crf_nll = -self.crf_layer(word_ins_out, word_ins_tgt, word_ins_mask)
+         crf_nll = (crf_nll / word_ins_mask.type_as(crf_nll).sum(-1)).mean()
+
+         return {
+             "word_ins": {
+                 "out": word_ins_out,
+                 "tgt": word_ins_tgt,
+                 "mask": word_ins_mask,
+                 "ls": self.args.label_smoothing,
+                 "nll_loss": True,
+                 "factor": self.args.word_ins_loss_factor,
+             },
+             "word_crf": {"loss": crf_nll},
+             "length": {
+                 "out": length_out,
+                 "tgt": length_tgt,
+                 "factor": self.decoder.length_loss_factor,
+             },
+         }
+
+     def forward_decoder(self, decoder_out, encoder_out, decoding_format=None, **kwargs):
+         output_tokens = decoder_out.output_tokens
+         output_scores = decoder_out.output_scores
+         history = decoder_out.history
+
+         # execute the decoder and get emission scores
+         output_masks = output_tokens.ne(self.pad)
+         word_ins_out = self.decoder(
+             normalize=False, prev_output_tokens=output_tokens, encoder_out=encoder_out
+         )
+
+         # run viterbi decoding through CRF
+         _scores, _tokens = self.crf_layer.forward_decoder(word_ins_out, output_masks)
+         output_tokens.masked_scatter_(output_masks, _tokens[output_masks])
+         output_scores.masked_scatter_(output_masks, _scores[output_masks])
+         if history is not None:
+             history.append(output_tokens.clone())
+
+         return decoder_out._replace(
+             output_tokens=output_tokens,
+             output_scores=output_scores,
+             attn=None,
+             history=history,
+         )
+
+
+ @register_model_architecture("nacrf_transformer", "nacrf_transformer")
+ def nacrf_base_architecture(args):
+     args.crf_lowrank_approx = getattr(args, "crf_lowrank_approx", 32)
+     args.crf_beam_approx = getattr(args, "crf_beam_approx", 64)
+     args.word_ins_loss_factor = getattr(args, "word_ins_loss_factor", 0.5)
+     args.encoder_normalize_before = getattr(args, "encoder_normalize_before", True)
+     args.decoder_normalize_before = getattr(args, "decoder_normalize_before", True)
+     base_architecture(args)
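A quick numeric sanity check (plain torch, made-up values) of the normalization applied to `crf_nll` above: each sentence's negative log-likelihood is divided by its unpadded length before averaging over the batch, so long sentences do not dominate the loss:

    import torch

    crf_nll = torch.tensor([6.0, 10.0])  # hypothetical per-sentence NLLs
    lengths = torch.tensor([3.0, 5.0])   # non-pad token counts per sentence
    print((crf_nll / lengths).mean())    # tensor(2.) -- length-normalized batch mean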
fairseq-0.10.2/fairseq/models/nat/nonautoregressive_ensembles.py ADDED
@@ -0,0 +1,254 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import math
+
+ import torch
+ import torch.nn.functional as F
+ from fairseq.models.nat import (
+     _apply_del_words,
+     _apply_ins_masks,
+     _apply_ins_words,
+     _fill,
+     _skip,
+     _skip_encoder_out,
+ )
+
+
+ class _EnsembleModelEncoder(object):
+     def __init__(self, models):
+         self.models = models
+
+     def reorder_encoder_out(self, encoder_outs, new_order):
+         encoder_outs = [
+             model.encoder.reorder_encoder_out(encoder_out, new_order)
+             for model, encoder_out in zip(self.models, encoder_outs)
+         ]
+         return encoder_outs
+
+
+ class BasicEnsembleModel(torch.nn.Module):
+     """A wrapper around an ensemble of models."""
+
+     def __init__(self, models):
+         super().__init__()
+         self.models = torch.nn.ModuleList(models)
+         self.bos = self.models[0].decoder.dictionary.bos()
+         self.eos = self.models[0].decoder.dictionary.eos()
+         self.pad = self.models[0].decoder.dictionary.pad()
+         self.unk = self.models[0].decoder.dictionary.unk()
+         self.encoder = _EnsembleModelEncoder(self.models)
+
+     def has_encoder(self):
+         return hasattr(self.models[0], "encoder")
+
+     def max_decoder_positions(self):
+         return min(m.max_decoder_positions() for m in self.models)
+
+     @torch.no_grad()
+     def forward_encoder(self, encoder_input):
+         if not self.has_encoder():
+             return None
+         return [model.forward_encoder(encoder_input) for model in self.models]
+
+     @torch.no_grad()
+     def forward_decoder(self, *inputs):
+         raise NotImplementedError
+
+     def initialize_output_tokens(self, *inputs):
+         raise NotImplementedError
+
+
+ class EnsembleLevT(BasicEnsembleModel):
+     """A wrapper around an ensemble of Levenshtein Transformer models."""
+
+     def __init__(self, models):
+         super().__init__(models)
+
+     @torch.no_grad()
+     def forward_decoder(
+         self, decoder_out, encoder_outs, eos_penalty=0.0, max_ratio=None, **kwargs
+     ):
+         # LevT ensembling is a pipeline of three steps: deletion, placeholder
+         # insertion, and word insertion. Scores must be averaged within each
+         # step, because each step depends on the output of the previous one.
+         output_tokens = decoder_out.output_tokens
+         output_scores = decoder_out.output_scores
+         attn = decoder_out.attn
+
+         bsz = output_tokens.size(0)
+         if max_ratio is None:
+             max_lens = output_tokens.new().fill_(255)
+         else:
+             if encoder_outs[0].encoder_padding_mask is None:
+                 src_lens = (
+                     encoder_outs[0]
+                     .encoder_out.new(bsz)
+                     .fill_(encoder_outs[0].encoder_out.size(1))
+                 )
+             else:
+                 src_lens = (~encoder_outs[0].encoder_padding_mask).sum(1)
+             max_lens = (src_lens * max_ratio).clamp(min=10).long()
+
+         # delete words
+         # never delete the <s> and </s> tokens
+         can_del_word = output_tokens.ne(self.pad).sum(1) > 2
+         if can_del_word.sum() != 0:  # skip this step if no sentence can delete
+             output_tokens, output_scores, attn = self.forward_word_del(
+                 encoder_outs,
+                 output_tokens,
+                 output_scores,
+                 attn,
+                 can_del_word,
+             )
+
+         # insert placeholders
+         can_ins_mask = output_tokens.ne(self.pad).sum(1) < max_lens
+         if can_ins_mask.sum() != 0:
+             output_tokens, output_scores = self.forward_mask_ins(
+                 encoder_outs,
+                 output_tokens,
+                 output_scores,
+                 can_ins_mask,
+                 eos_penalty,
+                 max_lens,
+             )
+
+         # insert words
+         can_ins_word = output_tokens.eq(self.unk).sum(1) > 0
+         if can_ins_word.sum() != 0:
+             output_tokens, output_scores, attn = self.forward_word_ins(
+                 encoder_outs,
+                 output_tokens,
+                 output_scores,
+                 attn,
+                 can_ins_word,
+             )
+
+         # delete some unnecessary paddings
+         cut_off = output_tokens.ne(self.pad).sum(1).max()
+         output_tokens = output_tokens[:, :cut_off]
+         output_scores = output_scores[:, :cut_off]
+         attn = None if attn is None else attn[:, :cut_off, :]
+         return decoder_out._replace(
+             output_tokens=output_tokens,
+             output_scores=output_scores,
+             attn=attn,
+             history=None,
+         )
+
+     def forward_word_del(
+         self, encoder_outs, output_tokens, output_scores, attn, can_del_word
+     ):
+         word_del_score_avg = []
+         word_del_attn_avg = []
+         for model, encoder_out in zip(self.models, encoder_outs):
+             word_del_out, word_del_attn = model.decoder.forward_word_del(
+                 _skip(output_tokens, can_del_word),
+                 _skip_encoder_out(model.encoder, encoder_out, can_del_word),
+             )
+             word_del_score = F.log_softmax(word_del_out, 2)
+             word_del_score_avg.append(word_del_score)
+             word_del_attn_avg.append(word_del_attn)
+         word_del_score_avg = torch.logsumexp(
+             torch.stack(word_del_score_avg, dim=0), dim=0
+         ) - math.log(len(self.models))
+         word_del_pred = word_del_score_avg.max(-1)[1].bool()
+         if word_del_attn_avg[0] is not None:
+             word_del_attn_avg = torch.stack(word_del_attn_avg, dim=0) / len(self.models)
+         else:
+             word_del_attn_avg = None
+
+         _tokens, _scores, _attn = _apply_del_words(
+             output_tokens[can_del_word],
+             output_scores[can_del_word],
+             word_del_attn_avg,
+             word_del_pred,
+             self.pad,
+             self.bos,
+             self.eos,
+         )
+         output_tokens = _fill(output_tokens, can_del_word, _tokens, self.pad)
+         output_scores = _fill(output_scores, can_del_word, _scores, 0)
+         attn = _fill(attn, can_del_word, _attn, 0.0)
+         return output_tokens, output_scores, attn
+
+     def forward_mask_ins(
+         self,
+         encoder_outs,
+         output_tokens,
+         output_scores,
+         can_ins_mask,
+         eos_penalty,
+         max_lens,
+     ):
+         mask_ins_score_avg = []
+         for model, encoder_out in zip(self.models, encoder_outs):
+             mask_ins_out, _ = model.decoder.forward_mask_ins(
+                 _skip(output_tokens, can_ins_mask),
+                 _skip_encoder_out(model.encoder, encoder_out, can_ins_mask),
+             )
+             mask_ins_score = F.log_softmax(mask_ins_out, 2)
+             if eos_penalty > 0.0:
+                 mask_ins_score[:, :, 0] -= eos_penalty
+             mask_ins_score_avg.append(mask_ins_score)
+         mask_ins_score_avg = torch.logsumexp(
+             torch.stack(mask_ins_score_avg, dim=0), dim=0
+         ) - math.log(len(self.models))
+         mask_ins_pred = mask_ins_score_avg.max(-1)[1]
+         mask_ins_pred = torch.min(
+             mask_ins_pred, max_lens[can_ins_mask, None].expand_as(mask_ins_pred)
+         )
+         _tokens, _scores = _apply_ins_masks(
+             output_tokens[can_ins_mask],
+             output_scores[can_ins_mask],
+             mask_ins_pred,
+             self.pad,
+             self.unk,
+             self.eos,
+         )
+         output_tokens = _fill(output_tokens, can_ins_mask, _tokens, self.pad)
+         output_scores = _fill(output_scores, can_ins_mask, _scores, 0)
+         return output_tokens, output_scores
+
+     def forward_word_ins(
+         self, encoder_outs, output_tokens, output_scores, attn, can_ins_word
+     ):
+         word_ins_score_avg = []
+         word_ins_attn_avg = []
+         for model, encoder_out in zip(self.models, encoder_outs):
+             word_ins_out, word_ins_attn = model.decoder.forward_word_ins(
+                 _skip(output_tokens, can_ins_word),
+                 _skip_encoder_out(model.encoder, encoder_out, can_ins_word),
+             )
+             word_ins_score = F.log_softmax(word_ins_out, 2)
+             word_ins_score_avg.append(word_ins_score)
+             word_ins_attn_avg.append(word_ins_attn)
+         word_ins_score_avg = torch.logsumexp(
+             torch.stack(word_ins_score_avg, dim=0), dim=0
+         ) - math.log(len(self.models))
+         if word_ins_attn_avg[0] is not None:
+             word_ins_attn_avg = torch.stack(word_ins_attn_avg, dim=0) / len(self.models)
+         else:
+             word_ins_attn_avg = None
+         word_ins_score_max, word_ins_pred = word_ins_score_avg.max(-1)
+
+         _tokens, _scores = _apply_ins_words(
+             output_tokens[can_ins_word],
+             output_scores[can_ins_word],
+             word_ins_pred,
+             word_ins_score_max,
+             self.unk,
+         )
+
+         output_tokens = _fill(output_tokens, can_ins_word, _tokens, self.pad)
+         output_scores = _fill(output_scores, can_ins_word, _scores, 0)
+         # use the averaged attention computed above (the original passed the
+         # unaveraged per-model variable here, which looks like a bug)
+         attn = _fill(attn, can_ins_word, word_ins_attn_avg, 0.0)
+         return output_tokens, output_scores, attn
+
+     def initialize_output_tokens(self, encoder_outs, src_tokens):
+         # LevT doesn't do length prediction.
+         return self.models[0].initialize_output_tokens(encoder_outs[0], src_tokens)
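All three forward_* steps above average the per-model distributions in probability space: logsumexp over the stacked log-probs minus log(N) is exactly the log of the arithmetic mean of the probabilities. A two-model check:

    import math

    import torch

    log_p = torch.tensor([[0.7, 0.3], [0.5, 0.5]]).log()  # two models' distributions
    avg = torch.logsumexp(log_p, dim=0) - math.log(2)
    print(avg.exp())  # tensor([0.6000, 0.4000]) -- the arithmetic mean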
fairseq-0.10.2/fairseq/models/nat/nonautoregressive_transformer.py ADDED
@@ -0,0 +1,440 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import torch
+ import torch.nn.functional as F
+ from fairseq import utils
+ from fairseq.iterative_refinement_generator import DecoderOut
+ from fairseq.models import register_model, register_model_architecture
+ from fairseq.models.nat import FairseqNATDecoder, FairseqNATModel, ensemble_decoder
+ from fairseq.models.transformer import Embedding
+ from fairseq.modules.transformer_sentence_encoder import init_bert_params
+
+
+ def _mean_pooling(enc_feats, src_masks):
+     # enc_feats: T x B x C
+     # src_masks: B x T or None
+     if src_masks is None:
+         enc_feats = enc_feats.mean(0)
+     else:
+         src_masks = (~src_masks).transpose(0, 1).type_as(enc_feats)
+         enc_feats = (
+             (enc_feats / src_masks.sum(0)[None, :, None]) * src_masks[:, :, None]
+         ).sum(0)
+     return enc_feats
+
+
+ def _argmax(x, dim):
+     return (x == x.max(dim, keepdim=True)[0]).type_as(x)
+
+
+ def _uniform_assignment(src_lens, trg_lens):
+     max_trg_len = trg_lens.max()
+     steps = (src_lens.float() - 1) / (trg_lens.float() - 1)  # step-size
+     # target positions 0 .. max_trg_len - 1
+     index_t = utils.new_arange(trg_lens, max_trg_len).float()
+     index_t = steps[:, None] * index_t[None, :]  # batch_size X max_trg_len
+     index_t = torch.round(index_t).long().detach()
+     return index_t
+
+
+ @register_model("nonautoregressive_transformer")
+ class NATransformerModel(FairseqNATModel):
+     @property
+     def allow_length_beam(self):
+         return True
+
+     @staticmethod
+     def add_args(parser):
+         FairseqNATModel.add_args(parser)
+
+         # length prediction
+         parser.add_argument(
+             "--src-embedding-copy",
+             action="store_true",
+             help="copy encoder word embeddings as the initial input of the decoder",
+         )
+         parser.add_argument(
+             "--pred-length-offset",
+             action="store_true",
+             help="predict the length difference between the target and source sentences",
+         )
+         parser.add_argument(
+             "--sg-length-pred",
+             action="store_true",
+             help="stop the gradients back-propagated from the length predictor",
+         )
+         parser.add_argument(
+             "--length-loss-factor",
+             type=float,
+             help="weight on the length prediction loss",
+         )
+
+     @classmethod
+     def build_decoder(cls, args, tgt_dict, embed_tokens):
+         decoder = NATransformerDecoder(args, tgt_dict, embed_tokens)
+         if getattr(args, "apply_bert_init", False):
+             decoder.apply(init_bert_params)
+         return decoder
+
+     def forward(
+         self, src_tokens, src_lengths, prev_output_tokens, tgt_tokens, **kwargs
+     ):
+         # encoding
+         encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
+
+         # length prediction
+         length_out = self.decoder.forward_length(
+             normalize=False, encoder_out=encoder_out
+         )
+         length_tgt = self.decoder.forward_length_prediction(
+             length_out, encoder_out, tgt_tokens
+         )
+
+         # decoding
+         word_ins_out = self.decoder(
+             normalize=False,
+             prev_output_tokens=prev_output_tokens,
+             encoder_out=encoder_out,
+         )
+
+         return {
+             "word_ins": {
+                 "out": word_ins_out,
+                 "tgt": tgt_tokens,
+                 "mask": tgt_tokens.ne(self.pad),
+                 "ls": self.args.label_smoothing,
+                 "nll_loss": True,
+             },
+             "length": {
+                 "out": length_out,
+                 "tgt": length_tgt,
+                 "factor": self.decoder.length_loss_factor,
+             },
+         }
+
+     def forward_decoder(self, decoder_out, encoder_out, decoding_format=None, **kwargs):
+         step = decoder_out.step
+         output_tokens = decoder_out.output_tokens
+         output_scores = decoder_out.output_scores
+         history = decoder_out.history
+
+         # execute the decoder
+         output_masks = output_tokens.ne(self.pad)
+         _scores, _tokens = self.decoder(
+             normalize=True,
+             prev_output_tokens=output_tokens,
+             encoder_out=encoder_out,
+             step=step,
+         ).max(-1)
+
+         output_tokens.masked_scatter_(output_masks, _tokens[output_masks])
+         output_scores.masked_scatter_(output_masks, _scores[output_masks])
+         if history is not None:
+             history.append(output_tokens.clone())
+
+         return decoder_out._replace(
+             output_tokens=output_tokens,
+             output_scores=output_scores,
+             attn=None,
+             history=history,
+         )
+
+     def initialize_output_tokens(self, encoder_out, src_tokens):
+         # length prediction
+         length_tgt = self.decoder.forward_length_prediction(
+             self.decoder.forward_length(normalize=True, encoder_out=encoder_out),
+             encoder_out=encoder_out,
+         )
+
+         max_length = length_tgt.clamp_(min=2).max()
+         idx_length = utils.new_arange(src_tokens, max_length)
+
+         initial_output_tokens = src_tokens.new_zeros(
+             src_tokens.size(0), max_length
+         ).fill_(self.pad)
+         initial_output_tokens.masked_fill_(
+             idx_length[None, :] < length_tgt[:, None], self.unk
+         )
+         initial_output_tokens[:, 0] = self.bos
+         initial_output_tokens.scatter_(1, length_tgt[:, None] - 1, self.eos)
+
+         initial_output_scores = initial_output_tokens.new_zeros(
+             *initial_output_tokens.size()
+         ).type_as(encoder_out.encoder_out)
+
+         return DecoderOut(
+             output_tokens=initial_output_tokens,
+             output_scores=initial_output_scores,
+             attn=None,
+             step=0,
+             max_step=0,
+             history=None,
+         )
+
+     def regenerate_length_beam(self, decoder_out, beam_size):
+         output_tokens = decoder_out.output_tokens
+         length_tgt = output_tokens.ne(self.pad).sum(1)
+         length_tgt = (
+             length_tgt[:, None]
+             + utils.new_arange(length_tgt, 1, beam_size)
+             - beam_size // 2
+         )
+         length_tgt = length_tgt.view(-1).clamp_(min=2)
+         max_length = length_tgt.max()
+         idx_length = utils.new_arange(length_tgt, max_length)
+
+         initial_output_tokens = output_tokens.new_zeros(
+             length_tgt.size(0), max_length
+         ).fill_(self.pad)
+         initial_output_tokens.masked_fill_(
+             idx_length[None, :] < length_tgt[:, None], self.unk
+         )
+         initial_output_tokens[:, 0] = self.bos
+         initial_output_tokens.scatter_(1, length_tgt[:, None] - 1, self.eos)
+
+         initial_output_scores = initial_output_tokens.new_zeros(
+             *initial_output_tokens.size()
+         ).type_as(decoder_out.output_scores)
+
+         return decoder_out._replace(
+             output_tokens=initial_output_tokens, output_scores=initial_output_scores
+         )
+
+
+ class NATransformerDecoder(FairseqNATDecoder):
+     def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False):
+         super().__init__(
+             args, dictionary, embed_tokens, no_encoder_attn=no_encoder_attn
+         )
+         self.dictionary = dictionary
+         self.bos = dictionary.bos()
+         self.unk = dictionary.unk()
+         self.eos = dictionary.eos()
+
+         self.encoder_embed_dim = args.encoder_embed_dim
+         self.sg_length_pred = getattr(args, "sg_length_pred", False)
+         self.pred_length_offset = getattr(args, "pred_length_offset", False)
+         self.length_loss_factor = getattr(args, "length_loss_factor", 0.1)
+         self.src_embedding_copy = getattr(args, "src_embedding_copy", False)
+         self.embed_length = Embedding(256, self.encoder_embed_dim, None)
+
+     @ensemble_decoder
+     def forward(self, normalize, encoder_out, prev_output_tokens, step=0, **unused):
+         features, _ = self.extract_features(
+             prev_output_tokens,
+             encoder_out=encoder_out,
+             embedding_copy=(step == 0) & self.src_embedding_copy,
+         )
+         decoder_out = self.output_layer(features)
+         return F.log_softmax(decoder_out, -1) if normalize else decoder_out
+
+     @ensemble_decoder
+     def forward_length(self, normalize, encoder_out):
+         enc_feats = encoder_out.encoder_out  # T x B x C
+         src_masks = encoder_out.encoder_padding_mask  # B x T or None
+         enc_feats = _mean_pooling(enc_feats, src_masks)
+         if self.sg_length_pred:
+             enc_feats = enc_feats.detach()
+         length_out = F.linear(enc_feats, self.embed_length.weight)
+         return F.log_softmax(length_out, -1) if normalize else length_out
+
+     def extract_features(
+         self,
+         prev_output_tokens,
+         encoder_out=None,
+         early_exit=None,
+         embedding_copy=False,
+         **unused
+     ):
+         """
+         Similar to *forward* but only return features.
+
+         Inputs:
+             prev_output_tokens: Tensor(B, T)
+             encoder_out: a dictionary of hidden states and masks
+
+         Returns:
+             tuple:
+                 - the decoder's features of shape `(batch, tgt_len, embed_dim)`
+                 - a dictionary with any model-specific outputs
+             the LevenshteinTransformer decoder has full-attention to all generated tokens
+         """
+         # embedding
+         if embedding_copy:
+             src_embd = encoder_out.encoder_embedding
+             src_mask = encoder_out.encoder_padding_mask
+             src_mask = (
+                 ~src_mask
+                 if src_mask is not None
+                 else prev_output_tokens.new_ones(*src_embd.size()[:2]).bool()
+             )
+
+             x, decoder_padding_mask = self.forward_embedding(
+                 prev_output_tokens,
+                 self.forward_copying_source(
+                     src_embd, src_mask, prev_output_tokens.ne(self.padding_idx)
+                 ),
+             )
+
+         else:
+
+             x, decoder_padding_mask = self.forward_embedding(prev_output_tokens)
+
+         # B x T x C -> T x B x C
+         x = x.transpose(0, 1)
+         attn = None
+         inner_states = [x]
+
+         # decoder layers
+         for i, layer in enumerate(self.layers):
+
+             # early exit from the decoder.
+             if (early_exit is not None) and (i >= early_exit):
+                 break
+
+             x, attn, _ = layer(
+                 x,
+                 encoder_out.encoder_out if encoder_out is not None else None,
+                 encoder_out.encoder_padding_mask if encoder_out is not None else None,
+                 self_attn_mask=None,
+                 self_attn_padding_mask=decoder_padding_mask,
+             )
+             inner_states.append(x)
+
+         if self.layer_norm:
+             x = self.layer_norm(x)
+
+         # T x B x C -> B x T x C
+         x = x.transpose(0, 1)
+
+         if self.project_out_dim is not None:
+             x = self.project_out_dim(x)
+
+         return x, {"attn": attn, "inner_states": inner_states}
+
+     def forward_embedding(self, prev_output_tokens, states=None):
+         # embed positions
+         positions = (
+             self.embed_positions(prev_output_tokens)
+             if self.embed_positions is not None
+             else None
+         )
+
+         # embed tokens and positions
+         if states is None:
+             x = self.embed_scale * self.embed_tokens(prev_output_tokens)
+             if self.project_in_dim is not None:
+                 x = self.project_in_dim(x)
+         else:
+             x = states
+
+         if positions is not None:
+             x += positions
+         x = self.dropout_module(x)
+         decoder_padding_mask = prev_output_tokens.eq(self.padding_idx)
+         return x, decoder_padding_mask
+
+     def forward_copying_source(self, src_embeds, src_masks, tgt_masks):
+         length_sources = src_masks.sum(1)
+         length_targets = tgt_masks.sum(1)
+         mapped_inputs = _uniform_assignment(length_sources, length_targets).masked_fill(
+             ~tgt_masks, 0
+         )
+         copied_embedding = torch.gather(
+             src_embeds,
+             1,
+             mapped_inputs.unsqueeze(-1).expand(
+                 *mapped_inputs.size(), src_embeds.size(-1)
+             ),
+         )
+         return copied_embedding
+
+     def forward_length_prediction(self, length_out, encoder_out, tgt_tokens=None):
+         enc_feats = encoder_out.encoder_out  # T x B x C
+         src_masks = encoder_out.encoder_padding_mask  # B x T or None
+         if self.pred_length_offset:
+             if src_masks is None:
+                 src_lengs = enc_feats.new_ones(enc_feats.size(1)).fill_(
+                     enc_feats.size(0)
+                 )
+             else:
+                 src_lengs = (~src_masks).transpose(0, 1).type_as(enc_feats).sum(0)
+             src_lengs = src_lengs.long()
+
+         if tgt_tokens is not None:
+             # obtain the length target
+             tgt_lengs = tgt_tokens.ne(self.padding_idx).sum(1).long()
+             if self.pred_length_offset:
+                 length_tgt = tgt_lengs - src_lengs + 128
+             else:
+                 length_tgt = tgt_lengs
+             length_tgt = length_tgt.clamp(min=0, max=255)
+
+         else:
+             # predict the length target (greedy for now)
+             # TODO: implementing length-beam
+             pred_lengs = length_out.max(-1)[1]
+             if self.pred_length_offset:
+                 length_tgt = pred_lengs - 128 + src_lengs
+             else:
+                 length_tgt = pred_lengs
+
+         return length_tgt
+
+
+ @register_model_architecture(
+     "nonautoregressive_transformer", "nonautoregressive_transformer"
+ )
+ def base_architecture(args):
+     args.encoder_embed_path = getattr(args, "encoder_embed_path", None)
+     args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
+     args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
+     args.encoder_layers = getattr(args, "encoder_layers", 6)
+     args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)
+     args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
+     args.encoder_learned_pos = getattr(args, "encoder_learned_pos", False)
+     args.decoder_embed_path = getattr(args, "decoder_embed_path", None)
+     args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim)
+     args.decoder_ffn_embed_dim = getattr(
+         args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim
+     )
+     args.decoder_layers = getattr(args, "decoder_layers", 6)
+     args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
+     args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False)
+     args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
+     args.attention_dropout = getattr(args, "attention_dropout", 0.0)
+     args.activation_dropout = getattr(args, "activation_dropout", 0.0)
+     args.activation_fn = getattr(args, "activation_fn", "relu")
+     args.dropout = getattr(args, "dropout", 0.1)
+     args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
+     args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
+     args.share_decoder_input_output_embed = getattr(
+         args, "share_decoder_input_output_embed", False
+     )
+     args.share_all_embeddings = getattr(args, "share_all_embeddings", False)
+     args.no_token_positional_embeddings = getattr(
+         args, "no_token_positional_embeddings", False
+     )
+     args.adaptive_input = getattr(args, "adaptive_input", False)
+     args.apply_bert_init = getattr(args, "apply_bert_init", False)
+
+     args.decoder_output_dim = getattr(
+         args, "decoder_output_dim", args.decoder_embed_dim
+     )
+     args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)
+
+     # --- special arguments ---
+     args.sg_length_pred = getattr(args, "sg_length_pred", False)
+     args.pred_length_offset = getattr(args, "pred_length_offset", False)
+     args.length_loss_factor = getattr(args, "length_loss_factor", 0.1)
+     args.src_embedding_copy = getattr(args, "src_embedding_copy", False)
+
+
+ @register_model_architecture(
+     "nonautoregressive_transformer", "nonautoregressive_transformer_wmt_en_de"
+ )
+ def nonautoregressive_transformer_wmt_en_de(args):
+     base_architecture(args)
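The `_uniform_assignment` helper used by `forward_copying_source` maps target position i to source position round(i * (S-1)/(T-1)), stretching the source embeddings uniformly to the predicted target length. A standalone check for S=5, T=8:

    import torch

    src_len, tgt_len = 5, 8
    steps = (src_len - 1) / (tgt_len - 1)  # 4/7
    index_t = torch.round(steps * torch.arange(tgt_len, dtype=torch.float)).long()
    print(index_t.tolist())  # [0, 1, 1, 2, 2, 3, 3, 4] -- source slot per target slot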
fairseq-0.10.2/fairseq/models/roberta/__init__.py ADDED
@@ -0,0 +1,9 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from .hub_interface import *  # noqa
+ from .model import *  # noqa
+ from .model_camembert import *  # noqa
+ from .model_xlmr import *  # noqa
fairseq-0.10.2/fairseq/models/roberta/__pycache__/model.cpython-310.pyc ADDED
Binary file (14 kB).
 
fairseq-0.10.2/fairseq/models/roberta/alignment_utils.py ADDED
@@ -0,0 +1,118 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from collections import Counter
+ from typing import List
+
+ import torch
+
+
+ def align_bpe_to_words(roberta, bpe_tokens: torch.LongTensor, other_tokens: List[str]):
+     """
+     Helper to align GPT-2 BPE to other tokenization formats (e.g., spaCy).
+
+     Args:
+         roberta (RobertaHubInterface): RoBERTa instance
+         bpe_tokens (torch.LongTensor): GPT-2 BPE tokens of shape `(T_bpe)`
+         other_tokens (List[str]): other tokens of shape `(T_words)`
+
+     Returns:
+         List[List[int]]: mapping from each of *other_tokens* to the indices
+         of its corresponding *bpe_tokens*.
+     """
+     assert bpe_tokens.dim() == 1
+     assert bpe_tokens[0] == 0
+
+     def clean(text):
+         return text.strip()
+
+     # remove whitespaces to simplify alignment
+     bpe_tokens = [roberta.task.source_dictionary.string([x]) for x in bpe_tokens]
+     bpe_tokens = [
+         clean(roberta.bpe.decode(x) if x not in {"<s>", ""} else x) for x in bpe_tokens
+     ]
+     other_tokens = [clean(str(o)) for o in other_tokens]
+
+     # strip leading <s>
+     bpe_tokens = bpe_tokens[1:]
+     assert "".join(bpe_tokens) == "".join(other_tokens)
+
+     # create alignment from every word to a list of BPE tokens
+     alignment = []
+     bpe_toks = filter(lambda item: item[1] != "", enumerate(bpe_tokens, start=1))
+     j, bpe_tok = next(bpe_toks)
+     for other_tok in other_tokens:
+         bpe_indices = []
+         while True:
+             if other_tok.startswith(bpe_tok):
+                 bpe_indices.append(j)
+                 other_tok = other_tok[len(bpe_tok) :]
+                 try:
+                     j, bpe_tok = next(bpe_toks)
+                 except StopIteration:
+                     j, bpe_tok = None, None
+             elif bpe_tok.startswith(other_tok):
+                 # other_tok spans multiple BPE tokens
+                 bpe_indices.append(j)
+                 bpe_tok = bpe_tok[len(other_tok) :]
+                 other_tok = ""
+             else:
+                 raise Exception('Cannot align "{}" and "{}"'.format(other_tok, bpe_tok))
+             if other_tok == "":
+                 break
+         assert len(bpe_indices) > 0
+         alignment.append(bpe_indices)
+     assert len(alignment) == len(other_tokens)
+
+     return alignment
+
+
+ def align_features_to_words(roberta, features, alignment):
+     """
+     Align given features to words.
+
+     Args:
+         roberta (RobertaHubInterface): RoBERTa instance
+         features (torch.Tensor): features to align of shape `(T_bpe x C)`
+         alignment: alignment between BPE tokens and words returned by
+             func:`align_bpe_to_words`.
+     """
+     assert features.dim() == 2
+
+     bpe_counts = Counter(j for bpe_indices in alignment for j in bpe_indices)
+     assert bpe_counts[0] == 0  # <s> shouldn't be aligned
+     denom = features.new([bpe_counts.get(j, 1) for j in range(len(features))])
+     weighted_features = features / denom.unsqueeze(-1)
+
+     output = [weighted_features[0]]
+     largest_j = -1
+     for bpe_indices in alignment:
+         output.append(weighted_features[bpe_indices].sum(dim=0))
+         largest_j = max(largest_j, *bpe_indices)
+     for j in range(largest_j + 1, len(features)):
+         output.append(weighted_features[j])
+     output = torch.stack(output)
+     assert torch.all(torch.abs(output.sum(dim=0) - features.sum(dim=0)) < 1e-4)
+     return output
+
+
+ def spacy_nlp():
+     if getattr(spacy_nlp, "_nlp", None) is None:
+         try:
+             from spacy.lang.en import English
+
+             spacy_nlp._nlp = English()
+         except ImportError:
+             raise ImportError("Please install spacy with: pip install spacy")
+     return spacy_nlp._nlp
+
+
+ def spacy_tokenizer():
+     if getattr(spacy_tokenizer, "_tokenizer", None) is None:
+         try:
+             nlp = spacy_nlp()
+             spacy_tokenizer._tokenizer = nlp.Defaults.create_tokenizer(nlp)
+         except ImportError:
+             raise ImportError("Please install spacy with: pip install spacy")
+     return spacy_tokenizer._tokenizer
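The greedy character-consumption loop in `align_bpe_to_words` is easier to see on plain strings. A stripped-down toy version (`toy_align` is hypothetical and, unlike the real function, does not handle a BPE piece spanning multiple words):

    def toy_align(bpe, words):
        alignment, j = [], 0
        for w in words:
            idxs = []
            while w:
                piece = bpe[j]
                assert w.startswith(piece), (w, piece)  # toy: no cross-word pieces
                idxs.append(j)
                w = w[len(piece):]
                j += 1
            alignment.append(idxs)
        return alignment

    print(toy_align(["Hel", "lo", "world"], ["Hello", "world"]))  # [[0, 1], [2]]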
fairseq-0.10.2/fairseq/models/roberta/hub_interface.py ADDED
@@ -0,0 +1,235 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from fairseq import utils
+ from fairseq.data import encoders
+
+
+ class RobertaHubInterface(nn.Module):
+     """A simple PyTorch Hub interface to RoBERTa.
+
+     Usage: https://github.com/pytorch/fairseq/tree/master/examples/roberta
+     """
+
+     def __init__(self, args, task, model):
+         super().__init__()
+         self.args = args
+         self.task = task
+         self.model = model
+
+         self.bpe = encoders.build_bpe(args)
+
+         # this is useful for determining the device
+         self.register_buffer("_float_tensor", torch.tensor([0], dtype=torch.float))
+
+     @property
+     def device(self):
+         return self._float_tensor.device
+
+     def encode(
+         self, sentence: str, *addl_sentences, no_separator=False
+     ) -> torch.LongTensor:
+         """
+         BPE-encode a sentence (or multiple sentences).
+
+         Every sequence begins with a beginning-of-sentence (`<s>`) symbol.
+         Every sentence ends with an end-of-sentence (`</s>`) and we use an
+         extra end-of-sentence (`</s>`) as a separator.
+
+         Example (single sentence): `<s> a b c </s>`
+         Example (sentence pair): `<s> d e f </s> </s> 1 2 3 </s>`
+
+         The BPE encoding follows GPT-2. One subtle detail is that the GPT-2 BPE
+         requires leading spaces. For example::
+
+             >>> roberta.encode('Hello world').tolist()
+             [0, 31414, 232, 2]
+             >>> roberta.encode(' world').tolist()
+             [0, 232, 2]
+             >>> roberta.encode('world').tolist()
+             [0, 8331, 2]
+         """
+         bpe_sentence = "<s> " + self.bpe.encode(sentence) + " </s>"
+         for s in addl_sentences:
+             bpe_sentence += " </s>" if not no_separator else ""
+             bpe_sentence += " " + self.bpe.encode(s) + " </s>"
+         tokens = self.task.source_dictionary.encode_line(
+             bpe_sentence, append_eos=False, add_if_not_exist=False
+         )
+         return tokens.long()
+
+     def decode(self, tokens: torch.LongTensor):
+         assert tokens.dim() == 1
+         tokens = tokens.numpy()
+         if tokens[0] == self.task.source_dictionary.bos():
+             tokens = tokens[1:]  # remove <s>
+         eos_mask = tokens == self.task.source_dictionary.eos()
+         doc_mask = eos_mask[1:] & eos_mask[:-1]
+         sentences = np.split(tokens, doc_mask.nonzero()[0] + 1)
+         sentences = [
+             self.bpe.decode(self.task.source_dictionary.string(s)) for s in sentences
+         ]
+         if len(sentences) == 1:
+             return sentences[0]
+         return sentences
+
+     def extract_features(
+         self, tokens: torch.LongTensor, return_all_hiddens: bool = False
+     ) -> torch.Tensor:
+         if tokens.dim() == 1:
+             tokens = tokens.unsqueeze(0)
+         if tokens.size(-1) > self.model.max_positions():
+             raise ValueError(
+                 "tokens exceeds maximum length: {} > {}".format(
+                     tokens.size(-1), self.model.max_positions()
+                 )
+             )
+         features, extra = self.model(
+             tokens.to(device=self.device),
+             features_only=True,
+             return_all_hiddens=return_all_hiddens,
+         )
+         if return_all_hiddens:
+             # convert from T x B x C -> B x T x C
+             inner_states = extra["inner_states"]
+             return [inner_state.transpose(0, 1) for inner_state in inner_states]
+         else:
+             return features  # just the last layer's features
+
+     def register_classification_head(
+         self, name: str, num_classes: int = None, embedding_size: int = None, **kwargs
+     ):
+         self.model.register_classification_head(
+             name, num_classes=num_classes, embedding_size=embedding_size, **kwargs
+         )
+
+     def predict(self, head: str, tokens: torch.LongTensor, return_logits: bool = False):
+         features = self.extract_features(tokens.to(device=self.device))
+         logits = self.model.classification_heads[head](features)
+         if return_logits:
+             return logits
+         return F.log_softmax(logits, dim=-1)
+
+     def extract_features_aligned_to_words(
+         self, sentence: str, return_all_hiddens: bool = False
+     ) -> torch.Tensor:
+         """Extract RoBERTa features, aligned to spaCy's word-level tokenizer."""
+         from fairseq.models.roberta import alignment_utils
+         from spacy.tokens import Doc
+
+         nlp = alignment_utils.spacy_nlp()
+         tokenizer = alignment_utils.spacy_tokenizer()
+
+         # tokenize both with GPT-2 BPE and spaCy
+         bpe_toks = self.encode(sentence)
+         spacy_toks = tokenizer(sentence)
+         spacy_toks_ws = [t.text_with_ws for t in tokenizer(sentence)]
+         alignment = alignment_utils.align_bpe_to_words(self, bpe_toks, spacy_toks_ws)
+
+         # extract features and align them
+         features = self.extract_features(
+             bpe_toks, return_all_hiddens=return_all_hiddens
+         )
+         features = features.squeeze(0)
+         aligned_feats = alignment_utils.align_features_to_words(
+             self, features, alignment
+         )
+
+         # wrap in spaCy Doc
+         doc = Doc(
+             nlp.vocab,
+             words=["<s>"] + [x.text for x in spacy_toks] + ["</s>"],
+             spaces=[True]
+             + [x.endswith(" ") for x in spacy_toks_ws[:-1]]
+             + [True, False],
+         )
+         assert len(doc) == aligned_feats.size(0)
+         doc.user_token_hooks["vector"] = lambda token: aligned_feats[token.i]
+         return doc
+
+     def fill_mask(self, masked_input: str, topk: int = 5):
+         masked_token = "<mask>"
+         assert (
+             masked_token in masked_input and masked_input.count(masked_token) == 1
+         ), "Please add one {0} token to the input, e.g.: 'He is a {0} guy'".format(
+             masked_token
+         )
+
+         text_spans = masked_input.split(masked_token)
+         text_spans_bpe = (
+             (" {0} ".format(masked_token))
+             .join([self.bpe.encode(text_span.rstrip()) for text_span in text_spans])
+             .strip()
+         )
+         tokens = self.task.source_dictionary.encode_line(
+             "<s> " + text_spans_bpe + " </s>",
+             append_eos=False,
+             add_if_not_exist=False,
+         )
+
+         masked_index = (tokens == self.task.mask_idx).nonzero()
+         if tokens.dim() == 1:
+             tokens = tokens.unsqueeze(0)
+
+         with utils.model_eval(self.model):
+             features, extra = self.model(
+                 tokens.long().to(device=self.device),
+                 features_only=False,
+                 return_all_hiddens=False,
+             )
+         logits = features[0, masked_index, :].squeeze()
+         prob = logits.softmax(dim=0)
+         values, index = prob.topk(k=topk, dim=0)
+         topk_predicted_token_bpe = self.task.source_dictionary.string(index)
+
+         topk_filled_outputs = []
+         for index, predicted_token_bpe in enumerate(
+             topk_predicted_token_bpe.split(" ")
+         ):
+             predicted_token = self.bpe.decode(predicted_token_bpe)
+             # Quick hack to fix https://github.com/pytorch/fairseq/issues/1306
+             if predicted_token_bpe.startswith("\u2581"):
+                 predicted_token = " " + predicted_token
+             if " {0}".format(masked_token) in masked_input:
+                 topk_filled_outputs.append(
+                     (
+                         masked_input.replace(
+                             " {0}".format(masked_token), predicted_token
+                         ),
+                         values[index].item(),
+                         predicted_token,
+                     )
+                 )
+             else:
+                 topk_filled_outputs.append(
+                     (
+                         masked_input.replace(masked_token, predicted_token),
+                         values[index].item(),
+                         predicted_token,
+                     )
+                 )
+         return topk_filled_outputs
+
+     def disambiguate_pronoun(self, sentence: str) -> bool:
+         """
+         Usage::
+
+             >>> disambiguate_pronoun('The _trophy_ would not fit in the brown suitcase because [it] was too big.')
+             True
+
+             >>> disambiguate_pronoun('The trophy would not fit in the brown suitcase because [it] was too big.')
+             'The trophy'
+         """
+         assert hasattr(
+             self.task, "disambiguate_pronoun"
+         ), "roberta.disambiguate_pronoun() requires a model trained with the WSC task."
+         with utils.model_eval(self.model):
+             return self.task.disambiguate_pronoun(
+                 self.model, sentence, use_cuda=self.device.type == "cuda"
+             )
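Typical usage of this interface, as documented in the fairseq RoBERTa examples (requires network access to download the pretrained checkpoint on first use):

    import torch

    roberta = torch.hub.load('pytorch/fairseq', 'roberta.base')
    roberta.eval()  # disable dropout

    tokens = roberta.encode('Hello world!')
    print(tokens.tolist())         # [0, 31414, 232, 328, 2]
    print(roberta.decode(tokens))  # 'Hello world!'

    # top-3 (filled sentence, probability, token) tuples for the masked position
    print(roberta.fill_mask('The first Star Wars movie came out in <mask>', topk=3))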
fairseq-0.10.2/fairseq/models/roberta/model.py ADDED
@@ -0,0 +1,524 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+ """
+ RoBERTa: A Robustly Optimized BERT Pretraining Approach.
+ """
+
+ import logging
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from fairseq import utils
+ from fairseq.models import (
+     FairseqEncoder,
+     FairseqEncoderModel,
+     register_model,
+     register_model_architecture,
+ )
+ from fairseq.modules import LayerNorm, TransformerSentenceEncoder
+ from fairseq.modules.quant_noise import quant_noise as apply_quant_noise_
+ from fairseq.modules.transformer_sentence_encoder import init_bert_params
+
+ from .hub_interface import RobertaHubInterface
+
+
+ logger = logging.getLogger(__name__)
+
+
+ @register_model("roberta")
+ class RobertaModel(FairseqEncoderModel):
+     @classmethod
+     def hub_models(cls):
+         return {
+             "roberta.base": "http://dl.fbaipublicfiles.com/fairseq/models/roberta.base.tar.gz",
+             "roberta.large": "http://dl.fbaipublicfiles.com/fairseq/models/roberta.large.tar.gz",
+             "roberta.large.mnli": "http://dl.fbaipublicfiles.com/fairseq/models/roberta.large.mnli.tar.gz",
+             "roberta.large.wsc": "http://dl.fbaipublicfiles.com/fairseq/models/roberta.large.wsc.tar.gz",
+         }
+
+     def __init__(self, args, encoder):
+         super().__init__(encoder)
+         self.args = args
+
+         # We follow BERT's random weight initialization
+         self.apply(init_bert_params)
+
+         self.classification_heads = nn.ModuleDict()
+
+     @staticmethod
+     def add_args(parser):
+         """Add model-specific arguments to the parser."""
+         parser.add_argument(
+             "--encoder-layers", type=int, metavar="L", help="num encoder layers"
+         )
+         parser.add_argument(
+             "--encoder-embed-dim",
+             type=int,
+             metavar="H",
+             help="encoder embedding dimension",
+         )
+         parser.add_argument(
+             "--encoder-ffn-embed-dim",
+             type=int,
+             metavar="F",
+             help="encoder embedding dimension for FFN",
+         )
+         parser.add_argument(
+             "--encoder-attention-heads",
+             type=int,
+             metavar="A",
+             help="num encoder attention heads",
+         )
+         parser.add_argument(
+             "--activation-fn",
+             choices=utils.get_available_activation_fns(),
+             help="activation function to use",
+         )
+         parser.add_argument(
+             "--pooler-activation-fn",
+             choices=utils.get_available_activation_fns(),
+             help="activation function to use for pooler layer",
+         )
+         parser.add_argument(
+             "--encoder-normalize-before",
+             action="store_true",
+             help="apply layernorm before each encoder block",
+         )
+         parser.add_argument(
+             "--dropout", type=float, metavar="D", help="dropout probability"
+         )
+         parser.add_argument(
+             "--attention-dropout",
+             type=float,
+             metavar="D",
+             help="dropout probability for attention weights",
+         )
+         parser.add_argument(
+             "--activation-dropout",
+             type=float,
+             metavar="D",
+             help="dropout probability after activation in FFN",
+         )
+         parser.add_argument(
+             "--pooler-dropout",
+             type=float,
+             metavar="D",
+             help="dropout probability in the masked_lm pooler layers",
+         )
+         parser.add_argument(
+             "--max-positions", type=int, help="number of positional embeddings to learn"
+         )
+         parser.add_argument(
+             "--load-checkpoint-heads",
+             action="store_true",
+             help="(re-)register and load heads when loading checkpoints",
+         )
+         # args for "Reducing Transformer Depth on Demand with Structured Dropout" (Fan et al., 2019)
+         parser.add_argument(
+             "--encoder-layerdrop",
+             type=float,
+             metavar="D",
+             default=0,
+             help="LayerDrop probability for encoder",
+         )
+         parser.add_argument(
+             "--encoder-layers-to-keep",
+             default=None,
+             help="which layers to *keep* when pruning as a comma-separated list",
+         )
+         # args for Training with Quantization Noise for Extreme Model Compression ({Fan*, Stock*} et al., 2020)
+         parser.add_argument(
+             "--quant-noise-pq",
+             type=float,
+             metavar="D",
+             default=0,
+             help="iterative PQ quantization noise at training time",
+         )
+         parser.add_argument(
+             "--quant-noise-pq-block-size",
+             type=int,
+             metavar="D",
+             default=8,
+             help="block size of quantization noise at training time",
+         )
+         parser.add_argument(
+             "--quant-noise-scalar",
+             type=float,
+             metavar="D",
+             default=0,
+             help="scalar quantization noise and scalar quantization at training time",
+         )
+         parser.add_argument(
+             "--untie-weights-roberta",
+             action="store_true",
+             help="Untie weights between embeddings and classifiers in RoBERTa",
+         )
+         parser.add_argument(
+             "--spectral-norm-classification-head",
+             action="store_true",
+             default=False,
+             help="Apply spectral normalization on the classification head",
+         )
+
+     @classmethod
+     def build_model(cls, args, task):
+         """Build a new model instance."""
+
+         # make sure all arguments are present
+         base_architecture(args)
+
+         if not hasattr(args, "max_positions"):
+             args.max_positions = args.tokens_per_sample
+
+         encoder = RobertaEncoder(args, task.source_dictionary)
+         return cls(args, encoder)
+
+     def forward(
+         self,
+         src_tokens,
+         features_only=False,
+         return_all_hiddens=False,
+         classification_head_name=None,
+         **kwargs
+     ):
+         if classification_head_name is not None:
+             features_only = True
+
+         x, extra = self.encoder(src_tokens, features_only, return_all_hiddens, **kwargs)
+
+         if classification_head_name is not None:
+             x = self.classification_heads[classification_head_name](x)
+         return x, extra
+
+     def get_normalized_probs(self, net_output, log_probs, sample=None):
+         """Get normalized probabilities (or log probs) from a net's output."""
+         logits = net_output[0].float()
+         if log_probs:
+             return F.log_softmax(logits, dim=-1)
+         else:
+             return F.softmax(logits, dim=-1)
+
+     def register_classification_head(
+         self, name, num_classes=None, inner_dim=None, **kwargs
+     ):
+         """Register a classification head."""
+         if name in self.classification_heads:
+             prev_num_classes = self.classification_heads[name].out_proj.out_features
+             prev_inner_dim = self.classification_heads[name].dense.out_features
+             if num_classes != prev_num_classes or inner_dim != prev_inner_dim:
+                 logger.warning(
+                     're-registering head "{}" with num_classes {} (prev: {}) '
+                     "and inner_dim {} (prev: {})".format(
+                         name, num_classes, prev_num_classes, inner_dim, prev_inner_dim
+                     )
+                 )
+         self.classification_heads[name] = RobertaClassificationHead(
+             input_dim=self.args.encoder_embed_dim,
+             inner_dim=inner_dim or self.args.encoder_embed_dim,
+             num_classes=num_classes,
+             activation_fn=self.args.pooler_activation_fn,
+             pooler_dropout=self.args.pooler_dropout,
+             q_noise=self.args.quant_noise_pq,
+             qn_block_size=self.args.quant_noise_pq_block_size,
226
+ do_spectral_norm=self.args.spectral_norm_classification_head,
227
+ )
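+    # Usage sketch (editorial note, not in the upstream file; the head name
+    # and `tokens` tensor are illustrative): fine-tuning code typically
+    # registers a head once, then routes forward() through it:
+    #
+    #   model.register_classification_head("sentence_cls", num_classes=2)
+    #   logits, _ = model(tokens, classification_head_name="sentence_cls")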
+
+    @property
+    def supported_targets(self):
+        return {"self"}
+
+    @classmethod
+    def from_pretrained(
+        cls,
+        model_name_or_path,
+        checkpoint_file="model.pt",
+        data_name_or_path=".",
+        bpe="gpt2",
+        **kwargs
+    ):
+        from fairseq import hub_utils
+
+        x = hub_utils.from_pretrained(
+            model_name_or_path,
+            checkpoint_file,
+            data_name_or_path,
+            archive_map=cls.hub_models(),
+            bpe=bpe,
+            load_checkpoint_heads=True,
+            **kwargs,
+        )
+        cls.upgrade_args(x["args"])
+
+        logger.info(x["args"])
+        return RobertaHubInterface(x["args"], x["task"], x["models"][0])
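+    # Note (editorial): `model_name_or_path` may be a key from hub_models(),
+    # in which case hub_utils resolves and downloads the archive, or a local
+    # directory containing `checkpoint_file`, which skips the download.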
+
+    def upgrade_state_dict_named(self, state_dict, name):
+        prefix = name + "." if name != "" else ""
+
+        # rename decoder -> encoder before upgrading children modules
+        for k in list(state_dict.keys()):
+            if k.startswith(prefix + "decoder"):
+                new_k = prefix + "encoder" + k[len(prefix + "decoder") :]
+                state_dict[new_k] = state_dict[k]
+                del state_dict[k]
+
+        # upgrade children modules
+        super().upgrade_state_dict_named(state_dict, name)
+
+        # Handle new classification heads present in the state dict.
+        current_head_names = (
+            []
+            if not hasattr(self, "classification_heads")
+            else self.classification_heads.keys()
+        )
+        keys_to_delete = []
+        for k in state_dict.keys():
+            if not k.startswith(prefix + "classification_heads."):
+                continue
+
+            head_name = k[len(prefix + "classification_heads.") :].split(".")[0]
+            num_classes = state_dict[
+                prefix + "classification_heads." + head_name + ".out_proj.weight"
+            ].size(0)
+            inner_dim = state_dict[
+                prefix + "classification_heads." + head_name + ".dense.weight"
+            ].size(0)
+
+            if getattr(self.args, "load_checkpoint_heads", False):
+                if head_name not in current_head_names:
+                    self.register_classification_head(head_name, num_classes, inner_dim)
+            else:
+                if head_name not in current_head_names:
+                    logger.warning(
+                        "deleting classification head ({}) from checkpoint "
+                        "not present in current model: {}".format(head_name, k)
+                    )
+                    keys_to_delete.append(k)
+                elif (
+                    num_classes
+                    != self.classification_heads[head_name].out_proj.out_features
+                    or inner_dim
+                    != self.classification_heads[head_name].dense.out_features
+                ):
+                    logger.warning(
+                        "deleting classification head ({}) from checkpoint "
+                        "with different dimensions than current model: {}".format(
+                            head_name, k
+                        )
+                    )
+                    keys_to_delete.append(k)
+        for k in keys_to_delete:
+            del state_dict[k]
+
+        # Copy any newly-added classification heads into the state dict
+        # with their current weights.
+        if hasattr(self, "classification_heads"):
+            cur_state = self.classification_heads.state_dict()
+            for k, v in cur_state.items():
+                if prefix + "classification_heads." + k not in state_dict:
+                    logger.info("Overwriting " + prefix + "classification_heads." + k)
+                    state_dict[prefix + "classification_heads." + k] = v
+
+
+class RobertaLMHead(nn.Module):
+    """Head for masked language modeling."""
+
+    def __init__(self, embed_dim, output_dim, activation_fn, weight=None):
+        super().__init__()
+        self.dense = nn.Linear(embed_dim, embed_dim)
+        self.activation_fn = utils.get_activation_fn(activation_fn)
+        self.layer_norm = LayerNorm(embed_dim)
+
+        if weight is None:
+            weight = nn.Linear(embed_dim, output_dim, bias=False).weight
+        self.weight = weight
+        self.bias = nn.Parameter(torch.zeros(output_dim))
+
+    def forward(self, features, masked_tokens=None, **kwargs):
+        # Only project the masked tokens while training,
+        # saves both memory and computation
+        if masked_tokens is not None:
+            features = features[masked_tokens, :]
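+            # Shape note (editorial, inferred from this indexing): a boolean
+            # `masked_tokens` of shape (batch, seq_len) selects a flattened
+            # (num_masked, embed_dim) view, so the vocabulary projection
+            # below runs only at masked positions.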
+
+        x = self.dense(features)
+        x = self.activation_fn(x)
+        x = self.layer_norm(x)
+        # project back to size of vocabulary with bias
+        x = F.linear(x, self.weight) + self.bias
+        return x
+
+
+class RobertaClassificationHead(nn.Module):
+    """Head for sentence-level classification tasks."""
+
+    def __init__(
+        self,
+        input_dim,
+        inner_dim,
+        num_classes,
+        activation_fn,
+        pooler_dropout,
+        q_noise=0,
+        qn_block_size=8,
+        do_spectral_norm=False,
+    ):
+        super().__init__()
+        self.dense = nn.Linear(input_dim, inner_dim)
+        self.activation_fn = utils.get_activation_fn(activation_fn)
+        self.dropout = nn.Dropout(p=pooler_dropout)
+        self.out_proj = apply_quant_noise_(
+            nn.Linear(inner_dim, num_classes), q_noise, qn_block_size
+        )
+        if do_spectral_norm:
+            if q_noise != 0:
+                raise NotImplementedError(
+                    "Attempting to use Spectral Normalization with Quant Noise. This is not officially supported"
+                )
+            self.out_proj = torch.nn.utils.spectral_norm(self.out_proj)
+
+    def forward(self, features, **kwargs):
+        x = features[:, 0, :]  # take <s> token (equiv. to [CLS])
+        x = self.dropout(x)
+        x = self.dense(x)
+        x = self.activation_fn(x)
+        x = self.dropout(x)
+        x = self.out_proj(x)
+        return x
+
+
+class RobertaEncoder(FairseqEncoder):
+    """RoBERTa encoder."""
+
+    def __init__(self, args, dictionary):
+        super().__init__(dictionary)
+        self.args = args
+
+        if args.encoder_layers_to_keep:
+            args.encoder_layers = len(args.encoder_layers_to_keep.split(","))
+
+        self.sentence_encoder = TransformerSentenceEncoder(
+            padding_idx=dictionary.pad(),
+            vocab_size=len(dictionary),
+            num_encoder_layers=args.encoder_layers,
+            embedding_dim=args.encoder_embed_dim,
+            ffn_embedding_dim=args.encoder_ffn_embed_dim,
+            num_attention_heads=args.encoder_attention_heads,
+            dropout=args.dropout,
+            attention_dropout=args.attention_dropout,
+            activation_dropout=args.activation_dropout,
+            layerdrop=args.encoder_layerdrop,
+            max_seq_len=args.max_positions,
+            num_segments=0,
+            encoder_normalize_before=True,
+            apply_bert_init=True,
+            activation_fn=args.activation_fn,
+            q_noise=args.quant_noise_pq,
+            qn_block_size=args.quant_noise_pq_block_size,
+        )
+        args.untie_weights_roberta = getattr(args, "untie_weights_roberta", False)
+
+        self.lm_head = RobertaLMHead(
+            embed_dim=args.encoder_embed_dim,
+            output_dim=len(dictionary),
+            activation_fn=args.activation_fn,
+            weight=(
+                self.sentence_encoder.embed_tokens.weight
+                if not args.untie_weights_roberta
+                else None
+            ),
+        )
+
+    def forward(
+        self,
+        src_tokens,
+        features_only=False,
+        return_all_hiddens=False,
+        masked_tokens=None,
+        **unused
+    ):
+        """
+        Args:
+            src_tokens (LongTensor): input tokens of shape `(batch, src_len)`
+            features_only (bool, optional): skip LM head and just return
+                features. If True, the output will be of shape
+                `(batch, src_len, embed_dim)`.
+            return_all_hiddens (bool, optional): also return all of the
+                intermediate hidden states (default: False).
+
+        Returns:
+            tuple:
+                - the LM output of shape `(batch, src_len, vocab)`
+                - a dictionary of additional data, where 'inner_states'
+                  is a list of hidden states. Note that the hidden
+                  states have shape `(src_len, batch, embed_dim)`.
+        """
+        x, extra = self.extract_features(
+            src_tokens, return_all_hiddens=return_all_hiddens
+        )
+        if not features_only:
+            x = self.output_layer(x, masked_tokens=masked_tokens)
+        return x, extra
+
+    def extract_features(self, src_tokens, return_all_hiddens=False, **kwargs):
+        inner_states, _ = self.sentence_encoder(
+            src_tokens,
+            last_state_only=not return_all_hiddens,
+            token_embeddings=kwargs.get("token_embeddings", None),
+        )
+        features = inner_states[-1].transpose(0, 1)  # T x B x C -> B x T x C
+        return features, {"inner_states": inner_states if return_all_hiddens else None}
+
+    def output_layer(self, features, masked_tokens=None, **unused):
+        return self.lm_head(features, masked_tokens)
+
+    def max_positions(self):
+        """Maximum output length supported by the encoder."""
+        return self.args.max_positions
+
+
+@register_model_architecture("roberta", "roberta")
+def base_architecture(args):
+    args.encoder_layers = getattr(args, "encoder_layers", 12)
+    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 768)
+    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 3072)
+    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 12)
+
+    args.activation_fn = getattr(args, "activation_fn", "gelu")
+    args.pooler_activation_fn = getattr(args, "pooler_activation_fn", "tanh")
+
+    args.dropout = getattr(args, "dropout", 0.1)
+    args.attention_dropout = getattr(args, "attention_dropout", 0.1)
+    args.activation_dropout = getattr(args, "activation_dropout", 0.0)
+    args.pooler_dropout = getattr(args, "pooler_dropout", 0.0)
+    args.encoder_layers_to_keep = getattr(args, "encoder_layers_to_keep", None)
+    args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0.0)
+    args.spectral_norm_classification_head = getattr(
+        args, "spectral_norm_classification_head", False
+    )
+
+
+@register_model_architecture("roberta", "roberta_base")
+def roberta_base_architecture(args):
+    base_architecture(args)
+
+
+@register_model_architecture("roberta", "roberta_large")
+def roberta_large_architecture(args):
+    args.encoder_layers = getattr(args, "encoder_layers", 24)
+    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
+    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4096)
+    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
+    base_architecture(args)
+
+
+@register_model_architecture("roberta", "xlm")
+def xlm_architecture(args):
+    args.encoder_layers = getattr(args, "encoder_layers", 16)
+    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1280)
+    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 1280 * 4)
+    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
+    base_architecture(args)
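
A minimal usage sketch for the model above (an editorial example, not part of the fairseq source). It assumes network access so the "roberta.base" key can be resolved and downloaded through hub_models(); `encode` and `extract_features` come from RobertaHubInterface:

    import torch
    from fairseq.models.roberta import RobertaModel

    # Resolves "roberta.base" via the hub_models() URL map and caches the archive.
    roberta = RobertaModel.from_pretrained("roberta.base")
    roberta.eval()  # disable dropout and LayerDrop for deterministic outputs

    tokens = roberta.encode("Hello world!")  # 1-D LongTensor of GPT-2 BPE ids
    with torch.no_grad():
        features = roberta.extract_features(tokens)  # (1, seq_len, 768) for roberta.base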
fairseq-0.10.2/fairseq/models/roberta/model_camembert.py ADDED
@@ -0,0 +1,50 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+"""
+CamemBERT: a Tasty French Language Model
+"""
+
+from fairseq.models import register_model
+
+from .hub_interface import RobertaHubInterface
+from .model import RobertaModel
+
+
+@register_model("camembert")
+class CamembertModel(RobertaModel):
+    @classmethod
+    def hub_models(cls):
+        return {
+            "camembert": "http://dl.fbaipublicfiles.com/fairseq/models/camembert-base.tar.gz",
+            "camembert.v0": "http://dl.fbaipublicfiles.com/fairseq/models/camembert-base.tar.gz",
+            "camembert-base": "http://dl.fbaipublicfiles.com/fairseq/models/camembert-base.tar.gz",
+            "camembert-large": "http://dl.fbaipublicfiles.com/fairseq/models/camembert-large.tar.gz",
+            "camembert-base-ccnet": "http://dl.fbaipublicfiles.com/fairseq/models/camembert-base-ccnet.tar.gz",
+            "camembert-base-ccnet-4gb": "http://dl.fbaipublicfiles.com/fairseq/models/camembert-base-ccnet-4gb.tar.gz",
+            "camembert-base-wikipedia-4gb": "http://dl.fbaipublicfiles.com/fairseq/models/camembert-base-wikipedia-4gb.tar.gz",
+            "camembert-base-oscar-4gb": "http://dl.fbaipublicfiles.com/fairseq/models/camembert-base-oscar-4gb.tar.gz",
+        }
+
+    @classmethod
+    def from_pretrained(
+        cls,
+        model_name_or_path,
+        checkpoint_file="model.pt",
+        data_name_or_path=".",
+        bpe="sentencepiece",
+        **kwargs
+    ):
+        from fairseq import hub_utils
+
+        x = hub_utils.from_pretrained(
+            model_name_or_path,
+            checkpoint_file,
+            data_name_or_path,
+            archive_map=cls.hub_models(),
+            bpe=bpe,
+            load_checkpoint_heads=True,
+            **kwargs,
+        )
+        return RobertaHubInterface(x["args"], x["task"], x["models"][0])
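
Usage mirrors the RoBERTa interface; a short sketch (editorial, assuming the "camembert" archive is reachable through the URL map above), using the fill-mask helper that RobertaHubInterface provides:

    from fairseq.models.roberta import CamembertModel

    camembert = CamembertModel.from_pretrained("camembert")
    camembert.eval()
    # top-3 candidates for the masked token, decoded with sentencepiece BPE
    camembert.fill_mask("Le camembert est <mask> :)", topk=3)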
fairseq-0.10.2/fairseq/models/wav2vec/__init__.py ADDED
@@ -0,0 +1,8 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .wav2vec import *  # noqa
+from .wav2vec2 import *  # noqa
+from .wav2vec2_asr import *  # noqa
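
These wildcard imports run for their side effects: importing each submodule executes its @register_model / @register_model_architecture decorators, which is what makes the wav2vec architectures visible to fairseq's registry. A sketch of that registration pattern (the model name and class here are hypothetical, for illustration only):

    from fairseq.models import BaseFairseqModel, register_model

    @register_model("toy_wav2vec")  # hypothetical name, not a real fairseq arch
    class ToyWav2Vec(BaseFairseqModel):
        @classmethod
        def build_model(cls, args, task):
            return cls()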
fairseq-0.10.2/fairseq/models/wav2vec/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (239 Bytes). View file