sleepyhead111 commited on
Commit
3771248
·
verified ·
1 Parent(s): 3160b62

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. fairseq-0.10.2/fairseq/__pycache__/pdb.cpython-310.pyc +0 -0
  2. fairseq-0.10.2/fairseq/hub_utils.py +294 -0
  3. fairseq-0.10.2/fairseq/iterative_refinement_generator.py +359 -0
  4. fairseq-0.10.2/fairseq/legacy_distributed_data_parallel.py +171 -0
  5. fairseq-0.10.2/fairseq/model_parallel/__pycache__/megatron_trainer.cpython-310.pyc +0 -0
  6. fairseq-0.10.2/fairseq/model_parallel/criterions/__pycache__/vocab_parallel_cross_entropy.cpython-310.pyc +0 -0
  7. fairseq-0.10.2/fairseq/model_parallel/megatron_trainer.py +66 -0
  8. fairseq-0.10.2/fairseq/model_parallel/models/__pycache__/__init__.cpython-310.pyc +0 -0
  9. fairseq-0.10.2/fairseq/model_parallel/models/pipeline_parallel_transformer/__init__.py +6 -0
  10. fairseq-0.10.2/fairseq/model_parallel/models/pipeline_parallel_transformer/__pycache__/__init__.cpython-310.pyc +0 -0
  11. fairseq-0.10.2/fairseq/model_parallel/models/pipeline_parallel_transformer/layers.py +600 -0
  12. fairseq-0.10.2/fairseq/model_parallel/models/roberta/model.py +287 -0
  13. fairseq-0.10.2/fairseq/model_parallel/modules/__pycache__/transformer_sentence_encoder_layer.cpython-310.pyc +0 -0
  14. fairseq-0.10.2/fairseq/model_parallel/modules/transformer_sentence_encoder_layer.py +77 -0
  15. fairseq-0.10.2/fairseq/modules/__init__.py +76 -0
  16. fairseq-0.10.2/fairseq/modules/__pycache__/adaptive_softmax.cpython-310.pyc +0 -0
  17. fairseq-0.10.2/fairseq/modules/__pycache__/dynamic_convolution.cpython-310.pyc +0 -0
  18. fairseq-0.10.2/fairseq/modules/__pycache__/grad_multiply.cpython-310.pyc +0 -0
  19. fairseq-0.10.2/fairseq/modules/__pycache__/gumbel_vector_quantizer.cpython-310.pyc +0 -0
  20. fairseq-0.10.2/fairseq/modules/__pycache__/kmeans_vector_quantizer.cpython-310.pyc +0 -0
  21. fairseq-0.10.2/fairseq/modules/__pycache__/learned_positional_embedding.cpython-310.pyc +0 -0
  22. fairseq-0.10.2/fairseq/modules/__pycache__/lightweight_convolution.cpython-310.pyc +0 -0
  23. fairseq-0.10.2/fairseq/modules/__pycache__/transformer_sentence_encoder_layer.cpython-310.pyc +0 -0
  24. fairseq-0.10.2/fairseq/modules/__pycache__/vggblock.cpython-310.pyc +0 -0
  25. fairseq-0.10.2/fairseq/modules/adaptive_softmax.py +268 -0
  26. fairseq-0.10.2/fairseq/modules/beamable_mm.py +49 -0
  27. fairseq-0.10.2/fairseq/modules/character_token_embedder.py +214 -0
  28. fairseq-0.10.2/fairseq/modules/cross_entropy.py +59 -0
  29. fairseq-0.10.2/fairseq/modules/fp32_group_norm.py +25 -0
  30. fairseq-0.10.2/fairseq/modules/lightconv_layer/lightconv_cuda_kernel.cu +375 -0
  31. fairseq-0.10.2/fairseq/modules/linearized_convolution.py +104 -0
  32. fairseq-0.10.2/fairseq/modules/multihead_attention.py +488 -0
  33. fairseq-0.10.2/fairseq/modules/positional_embedding.py +35 -0
  34. fairseq-0.10.2/fairseq/modules/quant_noise.py +107 -0
  35. fairseq-0.10.2/fairseq/modules/quantization/__pycache__/quantization_options.cpython-310.pyc +0 -0
  36. fairseq-0.10.2/fairseq/modules/quantization/pq/__pycache__/__init__.cpython-310.pyc +0 -0
  37. fairseq-0.10.2/fairseq/modules/quantization/pq/__pycache__/pq.cpython-310.pyc +0 -0
  38. fairseq-0.10.2/fairseq/modules/quantization/pq/__pycache__/utils.cpython-310.pyc +0 -0
  39. fairseq-0.10.2/fairseq/modules/quantization/pq/modules/__init__.py +8 -0
  40. fairseq-0.10.2/fairseq/modules/quantization/pq/modules/__pycache__/__init__.cpython-310.pyc +0 -0
  41. fairseq-0.10.2/fairseq/modules/quantization/pq/modules/__pycache__/qconv.cpython-310.pyc +0 -0
  42. fairseq-0.10.2/fairseq/modules/quantization/pq/modules/__pycache__/qemb.cpython-310.pyc +0 -0
  43. fairseq-0.10.2/fairseq/modules/quantization/pq/modules/qemb.py +107 -0
  44. fairseq-0.10.2/fairseq/modules/quantization/pq/modules/qlinear.py +71 -0
  45. fairseq-0.10.2/fairseq/modules/quantization/pq/pq.py +128 -0
  46. fairseq-0.10.2/fairseq/modules/quantization/quantization_options.py +44 -0
  47. fairseq-0.10.2/fairseq/modules/quantization/scalar/__pycache__/ops.cpython-310.pyc +0 -0
  48. fairseq-0.10.2/fairseq/modules/quantization/scalar/modules/__pycache__/qemb.cpython-310.pyc +0 -0
  49. fairseq-0.10.2/fairseq/modules/quantization/scalar/modules/qemb.py +147 -0
  50. fairseq-0.10.2/fairseq/modules/sparse_transformer_sentence_encoder.py +96 -0
fairseq-0.10.2/fairseq/__pycache__/pdb.cpython-310.pyc ADDED
Binary file (1.33 kB). View file
 
fairseq-0.10.2/fairseq/hub_utils.py ADDED
@@ -0,0 +1,294 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3 -u
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+ #
4
+ # This source code is licensed under the MIT license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import argparse
8
+ import copy
9
+ import logging
10
+ import os
11
+ from typing import Any, Dict, Iterator, List, Tuple
12
+
13
+ import torch
14
+ from fairseq import utils
15
+ from fairseq.data import encoders
16
+ from torch import nn
17
+
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
def from_pretrained(
    model_name_or_path,
    checkpoint_file="model.pt",
    data_name_or_path=".",
    archive_map=None,
    **kwargs
):
    """Load a pre-trained model ensemble and its task from an archive.

    Resolves *model_name_or_path* (optionally through *archive_map*),
    downloads/extracts the archive, wires up data/BPE paths found inside
    it, and loads the checkpoint(s).

    Returns:
        dict with keys ``args``, ``task`` and ``models``.
    """
    from fairseq import checkpoint_utils, file_utils

    if archive_map is not None:
        if model_name_or_path in archive_map:
            model_name_or_path = archive_map[model_name_or_path]
        if data_name_or_path is not None and data_name_or_path in archive_map:
            data_name_or_path = archive_map[data_name_or_path]

        # allow archive_map to set default arg_overrides (e.g., tokenizer, bpe)
        # for each model
        if isinstance(model_name_or_path, dict):
            for key, value in model_name_or_path.items():
                if key == "checkpoint_file":
                    checkpoint_file = value
                elif key != "path" and key not in kwargs:
                    # only set kwargs that don't already have overrides
                    kwargs[key] = value
            model_name_or_path = model_name_or_path["path"]

    model_path = file_utils.load_archive_file(model_name_or_path)

    # convenience hack for loading data and BPE codes from model archive
    if data_name_or_path.startswith("."):
        kwargs["data"] = os.path.abspath(os.path.join(model_path, data_name_or_path))
    else:
        kwargs["data"] = file_utils.load_archive_file(data_name_or_path)

    # pick up BPE / sentencepiece artifacts shipped inside the archive
    aux_files = {
        "code": "bpe_codes",
        "bpecodes": "bpe_codes",
        "sentencepiece.bpe.model": "sentencepiece_model",
    }
    for fname, arg_name in aux_files.items():
        candidate = os.path.join(model_path, fname)
        if os.path.exists(candidate):
            kwargs[arg_name] = candidate

    if "user_dir" in kwargs:
        utils.import_user_module(argparse.Namespace(user_dir=kwargs["user_dir"]))

    # checkpoint_file may name several checkpoints separated by os.pathsep
    checkpoint_paths = [
        os.path.join(model_path, cpt) for cpt in checkpoint_file.split(os.pathsep)
    ]
    models, args, task = checkpoint_utils.load_model_ensemble_and_task(
        checkpoint_paths,
        arg_overrides=kwargs,
    )

    return {
        "args": args,
        "task": task,
        "models": models,
    }
80
+
81
+
82
class GeneratorHubInterface(nn.Module):
    """
    PyTorch Hub interface for generating sequences from a pre-trained
    translation or language model.

    Wraps a fairseq task plus an ensemble of models and exposes simple
    string-in / string-out helpers (``translate``, ``sample``, ``score``)
    on top of the lower-level ``generate``.
    """

    def __init__(self, args, task, models):
        """
        Args:
            args: the argparse.Namespace-style config the models were built with
            task: fairseq task providing dictionaries, batching and generation
            models: list of trained fairseq models forming the ensemble
        """
        super().__init__()
        self.args = args
        self.task = task
        # nn.ModuleList so .to()/.cuda()/.eval() on this module reach the ensemble
        self.models = nn.ModuleList(models)
        self.src_dict = task.source_dictionary
        self.tgt_dict = task.target_dictionary

        # optimize model for generation
        for model in self.models:
            model.prepare_for_inference_(args)

        # Load alignment dictionary for unknown word replacement
        # (None if no unknown word replacement, empty if no path to align dictionary)
        self.align_dict = utils.load_align_dict(getattr(args, "replace_unk", None))

        # tokenizer/bpe may be None if args does not configure them
        self.tokenizer = encoders.build_tokenizer(args)
        self.bpe = encoders.build_bpe(args)

        # tightest max-positions constraint across the task and all models
        self.max_positions = utils.resolve_max_positions(
            self.task.max_positions(), *[model.max_positions() for model in models]
        )

        # this is useful for determining the device
        self.register_buffer("_float_tensor", torch.tensor([0], dtype=torch.float))

    @property
    def device(self):
        # the buffer follows the module across .to()/.cuda(), so its device
        # is the module's device
        return self._float_tensor.device

    def translate(
        self, sentences: List[str], beam: int = 5, verbose: bool = False, **kwargs
    ) -> List[str]:
        """Translate sentences; alias for :meth:`sample` with beam=5 default."""
        return self.sample(sentences, beam, verbose, **kwargs)

    def sample(
        self, sentences: List[str], beam: int = 1, verbose: bool = False, **kwargs
    ) -> List[str]:
        """Generate one output string per input sentence (top hypothesis only).

        A single ``str`` input returns a single ``str`` output.
        """
        if isinstance(sentences, str):
            # unwrap: recurse on a singleton batch and return its only element
            return self.sample([sentences], beam=beam, verbose=verbose, **kwargs)[0]
        tokenized_sentences = [self.encode(sentence) for sentence in sentences]
        batched_hypos = self.generate(tokenized_sentences, beam, verbose, **kwargs)
        return [self.decode(hypos[0]["tokens"]) for hypos in batched_hypos]

    def score(self, sentences: List[str], **kwargs):
        """Score each sentence under the model; returns the top hypothesis dict
        (with ``positional_scores`` etc.) per sentence."""
        if isinstance(sentences, str):
            return self.score([sentences], **kwargs)[0]
        # NOTE: this doesn't support translation tasks currently
        tokenized_sentences = [self.encode(sentence) for sentence in sentences]
        return [
            hypos[0]
            for hypos in self.generate(
                tokenized_sentences, score_reference=True, **kwargs
            )
        ]

    def generate(
        self,
        tokenized_sentences: List[torch.LongTensor],
        beam: int = 5,
        verbose: bool = False,
        skip_invalid_size_inputs=False,
        inference_step_args=None,
        **kwargs
    ) -> List[List[Dict[str, torch.Tensor]]]:
        """Run generation on pre-binarized inputs.

        Args:
            tokenized_sentences: list of 1-D LongTensors (or a single 1-D tensor)
            beam: beam size override for the generator
            verbose: log source/hypothesis/score lines per sentence
            skip_invalid_size_inputs: drop inputs that exceed max positions
            inference_step_args: extra kwargs forwarded to task.inference_step

        Returns:
            per input sentence, a list of hypothesis dicts, in input order.
        """
        if torch.is_tensor(tokenized_sentences) and tokenized_sentences.dim() == 1:
            # single-sentence convenience path
            return self.generate(
                tokenized_sentences.unsqueeze(0), beam=beam, verbose=verbose, **kwargs
            )[0]

        # build generator using current args as well as any kwargs
        gen_args = copy.copy(self.args)
        gen_args.beam = beam
        for k, v in kwargs.items():
            setattr(gen_args, k, v)
        generator = self.task.build_generator(self.models, gen_args)

        inference_step_args = inference_step_args or {}
        results = []
        for batch in self._build_batches(tokenized_sentences, skip_invalid_size_inputs):
            batch = utils.apply_to_sample(lambda t: t.to(self.device), batch)
            translations = self.task.inference_step(
                generator, self.models, batch, **inference_step_args
            )
            # keep (original-position, hypotheses) pairs; batching may reorder
            for id, hypos in zip(batch["id"].tolist(), translations):
                results.append((id, hypos))

        # sort output to match input order
        outputs = [hypos for _, hypos in sorted(results, key=lambda x: x[0])]

        if verbose:

            def getarg(name, default):
                # prefer the per-call override, then the stored args
                return getattr(gen_args, name, getattr(self.args, name, default))

            for source_tokens, target_hypotheses in zip(tokenized_sentences, outputs):
                src_str_with_unk = self.string(source_tokens)
                logger.info("S\t{}".format(src_str_with_unk))
                for hypo in target_hypotheses:
                    hypo_str = self.decode(hypo["tokens"])
                    logger.info("H\t{}\t{}".format(hypo["score"], hypo_str))
                    logger.info(
                        "P\t{}".format(
                            " ".join(
                                map(
                                    lambda x: "{:.4f}".format(x),
                                    hypo["positional_scores"].tolist(),
                                )
                            )
                        )
                    )
                    if hypo["alignment"] is not None and getarg(
                        "print_alignment", False
                    ):
                        logger.info(
                            "A\t{}".format(
                                " ".join(
                                    [
                                        "{}-{}".format(src_idx, tgt_idx)
                                        for src_idx, tgt_idx in hypo["alignment"]
                                    ]
                                )
                            )
                        )
        return outputs

    def encode(self, sentence: str) -> torch.LongTensor:
        """String -> token ids: tokenize, apply BPE, then binarize."""
        sentence = self.tokenize(sentence)
        sentence = self.apply_bpe(sentence)
        return self.binarize(sentence)

    def decode(self, tokens: torch.LongTensor) -> str:
        """Token ids -> string: stringify, remove BPE, then detokenize
        (inverse of :meth:`encode`)."""
        sentence = self.string(tokens)
        sentence = self.remove_bpe(sentence)
        return self.detokenize(sentence)

    def tokenize(self, sentence: str) -> str:
        """Apply the configured tokenizer; no-op if none was built."""
        if self.tokenizer is not None:
            sentence = self.tokenizer.encode(sentence)
        return sentence

    def detokenize(self, sentence: str) -> str:
        """Invert :meth:`tokenize`; no-op if no tokenizer was built."""
        if self.tokenizer is not None:
            sentence = self.tokenizer.decode(sentence)
        return sentence

    def apply_bpe(self, sentence: str) -> str:
        """Apply the configured BPE; no-op if none was built."""
        if self.bpe is not None:
            sentence = self.bpe.encode(sentence)
        return sentence

    def remove_bpe(self, sentence: str) -> str:
        """Invert :meth:`apply_bpe`; no-op if no BPE was built."""
        if self.bpe is not None:
            sentence = self.bpe.decode(sentence)
        return sentence

    def binarize(self, sentence: str) -> torch.LongTensor:
        """Map a (tokenized, BPE'd) string to source-dictionary ids."""
        return self.src_dict.encode_line(sentence, add_if_not_exist=False).long()

    def string(self, tokens: torch.LongTensor) -> str:
        """Map target-dictionary ids back to a (BPE'd) string."""
        return self.tgt_dict.string(tokens)

    def _build_batches(
        self, tokens: List[List[int]], skip_invalid_size_inputs: bool
    ) -> Iterator[Dict[str, Any]]:
        """Yield model-ready batches for the given token sequences."""
        lengths = torch.LongTensor([t.numel() for t in tokens])
        batch_iterator = self.task.get_batch_iterator(
            dataset=self.task.build_dataset_for_inference(tokens, lengths),
            max_tokens=self.args.max_tokens,
            max_sentences=self.args.batch_size,
            max_positions=self.max_positions,
            ignore_invalid_inputs=skip_invalid_size_inputs,
            # inference inputs vary call-to-call; don't reuse cached iterators
            disable_iterator_cache=True,
        ).next_epoch_itr(shuffle=False)
        return batch_iterator
263
+
264
+
265
class BPEHubInterface(object):
    """PyTorch Hub interface for Byte-Pair Encoding (BPE)."""

    def __init__(self, bpe, **kwargs):
        """Build the BPE codec named by *bpe*; extra kwargs configure it."""
        super().__init__()
        ns = argparse.Namespace(bpe=bpe, **kwargs)
        self.bpe = encoders.build_bpe(ns)
        assert self.bpe is not None

    def encode(self, sentence: str) -> str:
        """Segment *sentence* into BPE units."""
        return self.bpe.encode(sentence)

    def decode(self, sentence: str) -> str:
        """Join BPE units back into a plain string."""
        return self.bpe.decode(sentence)
279
+
280
+
281
class TokenizerHubInterface(object):
    """PyTorch Hub interface for tokenization."""

    def __init__(self, tokenizer, **kwargs):
        """Build the tokenizer named by *tokenizer*; extra kwargs configure it."""
        super().__init__()
        ns = argparse.Namespace(tokenizer=tokenizer, **kwargs)
        self.tokenizer = encoders.build_tokenizer(ns)
        assert self.tokenizer is not None

    def encode(self, sentence: str) -> str:
        """Tokenize *sentence*."""
        return self.tokenizer.encode(sentence)

    def decode(self, sentence: str) -> str:
        """Detokenize *sentence*."""
        return self.tokenizer.decode(sentence)
fairseq-0.10.2/fairseq/iterative_refinement_generator.py ADDED
@@ -0,0 +1,359 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from collections import namedtuple
7
+
8
+ import numpy as np
9
+ import torch
10
+ from fairseq import utils
11
+
12
+
13
# Immutable record of the decoder state carried between refinement steps.
DecoderOut = namedtuple(
    "IterativeRefinementDecoderOut",
    ["output_tokens", "output_scores", "attn", "step", "max_step", "history"],
)


class IterativeRefinementGenerator(object):
    """Sequence generator for non-autoregressive models that iteratively
    refine a full output sequence instead of decoding left-to-right."""

    def __init__(
        self,
        tgt_dict,
        models=None,
        eos_penalty=0.0,
        max_iter=10,
        max_ratio=2,
        beam_size=1,
        decoding_format=None,
        retain_dropout=False,
        adaptive=True,
        retain_history=False,
        reranking=False,
    ):
        """
        Generates translations based on iterative refinement.

        Args:
            tgt_dict: target dictionary
            eos_penalty: if > 0.0, it penalizes early-stopping in decoding
            max_iter: maximum number of refinement iterations
            max_ratio: generate sequences of maximum length ax, where x is the source length
            decoding_format: decoding mode in {'unigram', 'ensemble', 'vote', 'dp', 'bs'}
            retain_dropout: retaining dropout in the inference
            adaptive: decoding with early stop
        """
        # cache special-symbol ids from the target dictionary
        self.bos = tgt_dict.bos()
        self.pad = tgt_dict.pad()
        self.unk = tgt_dict.unk()
        self.eos = tgt_dict.eos()
        self.vocab_size = len(tgt_dict)
        self.eos_penalty = eos_penalty
        self.max_iter = max_iter
        self.max_ratio = max_ratio
        self.beam_size = beam_size
        self.reranking = reranking
        self.decoding_format = decoding_format
        self.retain_dropout = retain_dropout
        self.retain_history = retain_history
        self.adaptive = adaptive
        self.models = models

    def generate_batched_itr(
        self,
        data_itr,
        maxlen_a=None,
        maxlen_b=None,
        cuda=False,
        timer=None,
        prefix_size=0,
    ):
        """Iterate over a batched dataset and yield individual translations.

        Args:
            maxlen_a/b: generate sequences of maximum length ax + b,
                where x is the source sentence length.
            cuda: use GPU for generation
            timer: StopwatchMeter for timing generations.

        NOTE(review): maxlen_a/maxlen_b/cuda are accepted but unused in
        this body — presumably kept for API compatibility; verify against
        callers before relying on them.
        """

        for sample in data_itr:
            if "net_input" not in sample:
                # skip non-input entries (e.g. dummy batches)
                continue
            if timer is not None:
                timer.start()
            with torch.no_grad():
                hypos = self.generate(
                    self.models,
                    sample,
                    prefix_tokens=sample["target"][:, :prefix_size]
                    if prefix_size > 0
                    else None,
                )
            if timer is not None:
                timer.stop(sample["ntokens"])
            for i, id in enumerate(sample["id"]):
                # remove padding
                src = utils.strip_pad(sample["net_input"]["src_tokens"][i, :], self.pad)
                ref = utils.strip_pad(sample["target"][i, :], self.pad)
                yield id, src, ref, hypos[i]

    @torch.no_grad()
    def generate(self, models, sample, prefix_tokens=None, constraints=None):
        """Generate hypotheses for one batch via iterative refinement.

        Returns, per sentence, a one-element list containing a hypothesis
        dict (see ``finalized_hypos`` below for its keys).

        Raises:
            NotImplementedError: if ``constraints`` is given.
        """
        if constraints is not None:
            raise NotImplementedError(
                "Constrained decoding with the IterativeRefinementGenerator is not supported"
            )

        # TODO: iterative refinement generator does not support ensemble for now.
        if not self.retain_dropout:
            for model in models:
                model.eval()

        model, reranker = models[0], None
        if self.reranking:
            # last model in the list acts as an autoregressive reranker
            assert len(models) > 1, "Assuming the last checkpoint is the reranker"
            assert (
                self.beam_size > 1
            ), "Reranking requires multiple translation for each example"

            reranker = models[-1]
            models = models[:-1]

        if len(models) > 1 and hasattr(model, "enable_ensemble"):
            assert model.allow_ensemble, "{} does not support ensembling".format(
                model.__class__.__name__
            )
            model.enable_ensemble(models)

        # TODO: better encoder inputs?
        src_tokens = sample["net_input"]["src_tokens"]
        src_lengths = sample["net_input"]["src_lengths"]
        bsz, src_len = src_tokens.size()

        # initialize
        encoder_out = model.forward_encoder([src_tokens, src_lengths])
        prev_decoder_out = model.initialize_output_tokens(encoder_out, src_tokens)

        if self.beam_size > 1:
            assert (
                model.allow_length_beam
            ), "{} does not support decoding with length beam.".format(
                model.__class__.__name__
            )

            # regenerate data based on length-beam: each sentence is expanded
            # into beam_size candidates of different target lengths
            length_beam_order = (
                utils.new_arange(src_tokens, self.beam_size, bsz).t().reshape(-1)
            )
            encoder_out = model.encoder.reorder_encoder_out(
                encoder_out, length_beam_order
            )
            prev_decoder_out = model.regenerate_length_beam(
                prev_decoder_out, self.beam_size
            )
            bsz = bsz * self.beam_size

        # sent_idxs maps rows of the shrinking working batch back to their
        # original positions
        sent_idxs = torch.arange(bsz)
        prev_output_tokens = prev_decoder_out.output_tokens.clone()

        if self.retain_history:
            prev_decoder_out = prev_decoder_out._replace(history=[prev_output_tokens])

        finalized = [[] for _ in range(bsz)]

        def is_a_loop(x, y, s, a):
            # A sentence has converged when the refined output equals the
            # previous output; pad the shorter of the two so they compare
            # elementwise. Returns (per-row terminated mask, padded y, s, a).
            b, l_x, l_y = x.size(0), x.size(1), y.size(1)
            if l_x > l_y:
                y = torch.cat([y, x.new_zeros(b, l_x - l_y).fill_(self.pad)], 1)
                s = torch.cat([s, s.new_zeros(b, l_x - l_y)], 1)
                if a is not None:
                    a = torch.cat([a, a.new_zeros(b, l_x - l_y, a.size(2))], 1)
            elif l_x < l_y:
                x = torch.cat([x, y.new_zeros(b, l_y - l_x).fill_(self.pad)], 1)
            return (x == y).all(1), y, s, a

        def finalized_hypos(step, prev_out_token, prev_out_score, prev_out_attn):
            # Strip padding and package one finished hypothesis as a dict.
            cutoff = prev_out_token.ne(self.pad)
            tokens = prev_out_token[cutoff]
            if prev_out_score is None:
                scores, score = None, None
            else:
                scores = prev_out_score[cutoff]
                score = scores.mean()

            if prev_out_attn is None:
                hypo_attn, alignment = None, None
            else:
                hypo_attn = prev_out_attn[cutoff]
                alignment = hypo_attn.max(dim=1)[1]
            return {
                "steps": step,
                "tokens": tokens,
                "positional_scores": scores,
                "score": score,
                "hypo_attn": hypo_attn,
                "alignment": alignment,
            }

        for step in range(self.max_iter + 1):

            decoder_options = {
                "eos_penalty": self.eos_penalty,
                "max_ratio": self.max_ratio,
                "decoding_format": self.decoding_format,
            }
            prev_decoder_out = prev_decoder_out._replace(
                step=step,
                max_step=self.max_iter + 1,
            )

            decoder_out = model.forward_decoder(
                prev_decoder_out, encoder_out, **decoder_options
            )

            if self.adaptive:
                # terminate if there is a loop
                terminated, out_tokens, out_scores, out_attn = is_a_loop(
                    prev_output_tokens,
                    decoder_out.output_tokens,
                    decoder_out.output_scores,
                    decoder_out.attn,
                )
                decoder_out = decoder_out._replace(
                    output_tokens=out_tokens,
                    output_scores=out_scores,
                    attn=out_attn,
                )

            else:
                # non-adaptive: nothing terminates early
                terminated = decoder_out.output_tokens.new_zeros(
                    decoder_out.output_tokens.size(0)
                ).bool()

            if step == self.max_iter:  # reach last iteration, terminate
                terminated.fill_(1)

            # collect finalized sentences
            finalized_idxs = sent_idxs[terminated]
            finalized_tokens = decoder_out.output_tokens[terminated]
            finalized_scores = decoder_out.output_scores[terminated]
            finalized_attn = (
                None
                if (decoder_out.attn is None or decoder_out.attn.size(0) == 0)
                else decoder_out.attn[terminated]
            )

            if self.retain_history:
                finalized_history_tokens = [h[terminated] for h in decoder_out.history]

            for i in range(finalized_idxs.size(0)):
                finalized[finalized_idxs[i]] = [
                    finalized_hypos(
                        step,
                        finalized_tokens[i],
                        finalized_scores[i],
                        None if finalized_attn is None else finalized_attn[i],
                    )
                ]

                if self.retain_history:
                    # attach per-step snapshots of the refined tokens
                    finalized[finalized_idxs[i]][0]["history"] = []
                    for j in range(len(finalized_history_tokens)):
                        finalized[finalized_idxs[i]][0]["history"].append(
                            finalized_hypos(
                                step, finalized_history_tokens[j][i], None, None
                            )
                        )

            # check if all terminated
            if terminated.sum() == terminated.size(0):
                break

            # for next step: keep only the rows that have not terminated
            not_terminated = ~terminated
            prev_decoder_out = decoder_out._replace(
                output_tokens=decoder_out.output_tokens[not_terminated],
                output_scores=decoder_out.output_scores[not_terminated],
                attn=decoder_out.attn[not_terminated]
                if (decoder_out.attn is not None and decoder_out.attn.size(0) > 0)
                else None,
                history=[h[not_terminated] for h in decoder_out.history]
                if decoder_out.history is not None
                else None,
            )
            encoder_out = model.encoder.reorder_encoder_out(
                encoder_out, not_terminated.nonzero(as_tuple=False).squeeze()
            )
            sent_idxs = sent_idxs[not_terminated]
            prev_output_tokens = prev_decoder_out.output_tokens.clone()

        if self.beam_size > 1:
            if reranker is not None:
                finalized = self.rerank(
                    reranker, finalized, [src_tokens, src_lengths], self.beam_size
                )

            # aggregate information from length beam: keep the best-scoring
            # candidate among the beam_size length variants of each sentence
            finalized = [
                finalized[
                    np.argmax(
                        [
                            finalized[self.beam_size * i + j][0]["score"]
                            for j in range(self.beam_size)
                        ]
                    )
                    + self.beam_size * i
                ]
                for i in range(len(finalized) // self.beam_size)
            ]

        return finalized

    def rerank(self, reranker, finalized, encoder_input, beam_size):
        """Rescore finalized hypotheses with an autoregressive *reranker*
        model, overwriting each hypothesis' ``score`` in place."""

        def rebuild_batch(finalized):
            # pad the top hypotheses into a single (batch, maxlen) tensor
            finalized_tokens = [f[0]["tokens"] for f in finalized]
            finalized_maxlen = max(f.size(0) for f in finalized_tokens)
            final_output_tokens = (
                finalized_tokens[0]
                .new_zeros(len(finalized_tokens), finalized_maxlen)
                .fill_(self.pad)
            )
            for i, f in enumerate(finalized_tokens):
                final_output_tokens[i, : f.size(0)] = f
            return final_output_tokens

        final_output_tokens = rebuild_batch(finalized)
        final_output_tokens[
            :, 0
        ] = self.eos  # autoregressive model assumes starting with EOS

        reranker_encoder_out = reranker.encoder(*encoder_input)
        # repeat encoder states beam_size times to match the expanded batch
        length_beam_order = (
            utils.new_arange(
                final_output_tokens, beam_size, reranker_encoder_out.encoder_out.size(1)
            )
            .t()
            .reshape(-1)
        )
        reranker_encoder_out = reranker.encoder.reorder_encoder_out(
            reranker_encoder_out, length_beam_order
        )
        reranking_scores = reranker.get_normalized_probs(
            reranker.decoder(final_output_tokens[:, :-1], reranker_encoder_out),
            True,
            None,
        )
        # pick the log-prob of each actually-generated next token
        reranking_scores = reranking_scores.gather(2, final_output_tokens[:, 1:, None])
        reranking_masks = final_output_tokens[:, 1:].ne(self.pad)
        # mean log-prob over non-pad positions
        reranking_scores = (
            reranking_scores[:, :, 0].masked_fill_(~reranking_masks, 0).sum(1)
        )
        reranking_scores = reranking_scores / reranking_masks.sum(1).type_as(
            reranking_scores
        )

        for i in range(len(finalized)):
            finalized[i][0]["score"] = reranking_scores[i]

        return finalized
fairseq-0.10.2/fairseq/legacy_distributed_data_parallel.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ """
7
+ A modified version of the legacy DistributedDataParallel module that uses c10d
8
+ communication primitives. This version is simpler than the latest PyTorch
9
+ version and is useful for debugging. Notably it does not overlap gradient
10
+ communication with the backward pass, which makes it slower but more robust
11
+ than the PyTorch version.
12
+
13
+ This version also supports the *no_sync* context manager, which allows faster
14
+ training with `--update-freq`.
15
+ """
16
+
17
+ import copy
18
+ from collections import OrderedDict
19
+ from contextlib import contextmanager
20
+
21
+ import torch
22
+ from torch import nn
23
+ from torch.autograd import Variable
24
+
25
+ from . import distributed_utils
26
+
27
+
28
class LegacyDistributedDataParallel(nn.Module):
    """Implements distributed data parallelism at the module level.

    A simplified version of :class:`torch.nn.parallel.DistributedDataParallel`.
    This version uses a c10d process group for communication and does not
    broadcast buffers.

    Args:
        module (~torch.nn.Module): module to be parallelized
        world_size (int): number of parallel workers
        process_group (optional): the c10d process group to be used for
            distributed data all-reduction. If None, the default process group
            will be used.
        buffer_size (int, optional): number of elements to buffer before
            performing all-reduce (default: 256M).
    """

    def __init__(self, module, world_size, process_group=None, buffer_size=2 ** 28):
        super().__init__()

        self.module = module
        self.world_size = world_size
        self.process_group = process_group

        # Never use a bigger buffer than the number of model params
        self.buffer_size = min(buffer_size, sum(p.numel() for p in module.parameters()))
        # flat communication buffer; allocated lazily on first all_reduce()
        self.buffer = None

        # We can also forcibly accumulate grads locally and only do the
        # all-reduce at some later time
        self.accumulate_grads = False

        # make per-device lists of parameters
        paramlists = OrderedDict()
        for param in self.module.parameters():
            device = param.device
            if paramlists.get(device) is None:
                paramlists[device] = []
            paramlists[device] += [param]
        self.per_device_params = list(paramlists.values())

    def __getstate__(self):
        # shallow-copy __dict__ so pickling does not mutate the live object
        attrs = copy.copy(self.__dict__)
        return attrs

    def __setstate__(self, state):
        # NOTE(review): state is not applied here; delegates to the parent —
        # confirm intended before relying on unpickled instances
        super().__setstate__(state)

    @contextmanager
    def no_sync(self):
        """A context manager to disable gradient synchronization."""
        old_accumulate_grads = self.accumulate_grads
        self.accumulate_grads = True
        yield
        # restore whatever mode was active before entering
        self.accumulate_grads = old_accumulate_grads

    def forward(self, *inputs, **kwargs):
        # plain pass-through; no scatter/gather, no hooks
        return self.module(*inputs, **kwargs)

    def all_reduce(self):
        """
        This function must be called explicitly after backward to reduce
        gradients. There is no automatic hook like c10d.
        """

        def all_reduce_params(params):
            # Flatten the grads of `params` into one contiguous buffer,
            # all-reduce it once, then scatter the averaged grads back.
            buffer = self.buffer
            nonzero_buffer = False
            if len(params) > 1:
                offset = 0
                for p in params:
                    sz = p.numel()
                    if p.grad is not None:
                        buffer[offset : offset + sz].copy_(p.grad.data.view(-1))
                        nonzero_buffer = True
                    else:
                        # param received no grad this step; contribute zeros
                        buffer[offset : offset + sz].zero_()
                    offset += sz
            else:
                # we only have a single grad to all-reduce
                p = params[0]
                if p.grad is not None:
                    # reduce in place on the grad itself; no copy needed
                    buffer = p.grad.data
                    nonzero_buffer = True
                elif p.numel() <= self.buffer.numel():
                    buffer = buffer[: p.numel()]
                    buffer.zero_()
                else:
                    buffer = torch.zeros_like(p)

            if nonzero_buffer:
                # pre-divide so the all-reduce sum yields the mean
                buffer.div_(self.world_size)

            distributed_utils.all_reduce(buffer, self.process_group)

            # copy all-reduced grads back into their original place
            offset = 0
            for p in params:
                sz = p.numel()
                if p.grad is not None:
                    p.grad.data.copy_(buffer[offset : offset + sz].view_as(p))
                else:
                    p.grad = buffer[offset : offset + sz].view_as(p).clone()
                offset += sz

        def reduction_fn():
            # This function only needs to be called once
            if self.accumulate_grads:
                # inside no_sync(): skip communication entirely
                return

            if self.buffer is None:
                # lazy allocation, matching the dtype/device of the params
                self.buffer = next(self.module.parameters()).new(self.buffer_size)

            for params in self.per_device_params:
                # All-reduce the gradients in buckets
                offset = 0
                buffered_params = []
                for param in params:
                    if not param.requires_grad:
                        continue
                    if param.grad is None:
                        param.grad = torch.zeros_like(param)
                    if param.grad.requires_grad:
                        raise RuntimeError(
                            "DistributedDataParallel only works "
                            "with gradients that don't require "
                            "grad"
                        )
                    sz = param.numel()
                    if sz > self.buffer.numel():
                        # all-reduce big params directly
                        all_reduce_params([param])
                    else:
                        if offset + sz > self.buffer.numel():
                            # bucket full: flush it before adding this param
                            all_reduce_params(buffered_params)
                            offset = 0
                            buffered_params.clear()
                        buffered_params.append(param)
                        offset += sz

                # flush the final partially-filled bucket
                if len(buffered_params) > 0:
                    all_reduce_params(buffered_params)

        reduction_fn()
fairseq-0.10.2/fairseq/model_parallel/__pycache__/megatron_trainer.cpython-310.pyc ADDED
Binary file (2.41 kB). View file
 
fairseq-0.10.2/fairseq/model_parallel/criterions/__pycache__/vocab_parallel_cross_entropy.cpython-310.pyc ADDED
Binary file (3.5 kB). View file
 
fairseq-0.10.2/fairseq/model_parallel/megatron_trainer.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ """
7
+ Train a network across multiple GPUs.
8
+ """
9
+
10
+ from fairseq import distributed_utils
11
+ from fairseq.trainer import Trainer
12
+
13
+
14
+ try:
15
+ from fairseq.model_parallel.megatron.mpu import (
16
+ get_data_parallel_group,
17
+ get_data_parallel_rank,
18
+ get_data_parallel_world_size,
19
+ get_model_parallel_group,
20
+ get_model_parallel_src_rank,
21
+ )
22
+
23
+ has_megatron_submodule = True
24
+ except (ImportError, ModuleNotFoundError):
25
+ has_megatron_submodule = False
26
+
27
+
28
class MegatronTrainer(Trainer):
    """Main class for model parallel with data parallel training.

    Overrides the data-parallel topology queries of :class:`Trainer` to use
    Megatron's mpu process groups, so gradient all-reduce happens only across
    the data-parallel group while grad norms are aggregated across the
    model-parallel group.
    """

    def __init__(self, args, task, model, criterion):
        # the mpu helpers imported at module level are required below;
        # fail fast with install instructions if the submodule is missing
        if not has_megatron_submodule:
            raise ImportError(
                "\n\nPlease install the megatron submodule:"
                "\n\n git submodule update --init "
                "fairseq/model_parallel/megatron"
            )
        super().__init__(args, task, model, criterion)

    @property
    def data_parallel_world_size(self):
        # size of the data-parallel group, not the global world size
        return get_data_parallel_world_size()

    @property
    def data_parallel_process_group(self):
        return get_data_parallel_group()

    @property
    def data_parallel_rank(self):
        return get_data_parallel_rank()

    @property
    def is_data_parallel_master(self):
        # the source rank of the model-parallel group acts as the master
        return get_model_parallel_src_rank() == 0

    def clip_grad_norm(self, clip_norm):
        def _aggregate_model_parallel_grad_norm(total_norm):
            # each model-parallel rank only holds a shard of the parameters,
            # so combine squared norms across the model-parallel group
            total_norm = total_norm ** 2
            distributed_utils.all_reduce(total_norm, group=get_model_parallel_group())
            total_norm = total_norm ** 0.5
            return total_norm

        return self.optimizer.clip_grad_norm(
            clip_norm,
            aggregate_norm_fn=_aggregate_model_parallel_grad_norm,
        )
fairseq-0.10.2/fairseq/model_parallel/models/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (544 Bytes). View file
 
fairseq-0.10.2/fairseq/model_parallel/models/pipeline_parallel_transformer/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from .model import * # noqa
fairseq-0.10.2/fairseq/model_parallel/models/pipeline_parallel_transformer/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (230 Bytes). View file
 
fairseq-0.10.2/fairseq/model_parallel/models/pipeline_parallel_transformer/layers.py ADDED
@@ -0,0 +1,600 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+ from collections import namedtuple
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ from fairseq import options, utils
13
+ from fairseq.modules import (
14
+ AdaptiveSoftmax,
15
+ LayerNorm,
16
+ MultiheadAttention,
17
+ PositionalEmbedding,
18
+ )
19
+
20
+
21
# Immutable container for the encoder pipeline's forward outputs.
_ENCODER_OUT_FIELDS = (
    "encoder_out",  # T x B x C
    "encoder_padding_mask",  # B x T
    "encoder_embedding",  # B x T x C
    "encoder_states",  # List[T x B x C]
)

EncoderOut = namedtuple("TransformerEncoderOut", _ENCODER_OUT_FIELDS)
30
+
31
+
32
class TransformerEncoderEmbedding(nn.Module):
    """Encoder token embedding + positional embedding.

    First stage of the pipeline-parallel encoder: embeds ``src_tokens`` and
    threads ``prev_output_tokens`` through unchanged so later (decoder)
    pipeline stages can consume them.
    """

    def __init__(self, args, embed_tokens):
        super().__init__()
        self.dropout = args.dropout
        self.max_source_positions = args.max_source_positions
        self.embed_tokens = embed_tokens
        # embed_tokens may be an nn.ModuleList of partial embeddings whose
        # outputs are concatenated; the effective dim is the sum of parts
        if isinstance(embed_tokens, nn.ModuleList):
            self.padding_idx = embed_tokens[0].padding_idx
            embed_dim = sum(e.embedding_dim for e in embed_tokens)
        else:
            self.padding_idx = embed_tokens.padding_idx
            embed_dim = embed_tokens.embedding_dim
        # scale embeddings by sqrt(dim), as in "Attention Is All You Need"
        self.embed_scale = math.sqrt(embed_dim)
        self.embed_positions = (
            PositionalEmbedding(
                args.max_source_positions,
                embed_dim,
                self.padding_idx,
                learned=args.encoder_learned_pos,
            )
            if not args.no_token_positional_embeddings
            else None
        )
        if getattr(args, "layernorm_embedding", False):
            self.layernorm_embedding = LayerNorm(embed_dim)
        else:
            self.layernorm_embedding = None

    def forward(self, input):
        """Embed ``input[0]`` (src_tokens), passing ``input[2]`` through.

        Returns ``(x, encoder_padding_mask, prev_output_tokens)`` with ``x``
        in T x B x C layout. ``input[1]`` (src_lengths) is discarded.
        """
        # embed tokens and positions
        src_tokens = input[0]
        prev_output_tokens = input[2]
        if isinstance(self.embed_tokens, nn.ModuleList):
            x_embed_list = []
            for embed_tokens_part in self.embed_tokens:
                x_embed_list.append(embed_tokens_part(src_tokens))

            embedded = torch.cat(x_embed_list, dim=-1)
        else:
            embedded = self.embed_tokens(src_tokens)
        x = embed = self.embed_scale * embedded
        if self.embed_positions is not None:
            x = embed + self.embed_positions(src_tokens)
        if self.layernorm_embedding:
            x = self.layernorm_embedding(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        # compute padding mask
        encoder_padding_mask = src_tokens.eq(self.padding_idx)
        return (x, encoder_padding_mask, prev_output_tokens)
86
+
87
+
88
class TransformerEncoderLayerNorm(nn.Module):
    """Optional final layer norm applied after all encoder layers.

    Active only when ``args.encoder_normalize_before`` is True; otherwise
    this stage is a pass-through.
    """

    def __init__(self, args, embed_dim):
        super().__init__()
        self.layer_norm = (
            LayerNorm(embed_dim) if args.encoder_normalize_before else None
        )

    def forward(self, input):
        x, encoder_padding_mask, prev_output_tokens = input
        if self.layer_norm:
            x = self.layer_norm(x)
        # keeping track of the incremental_state is not supported yet
        return (x, encoder_padding_mask, prev_output_tokens)
109
+
110
+
111
class TransformerDecoderEmbedding(nn.Module):
    """Decoder token embedding + positional embedding (first decoder stage).

    Accepts either the encoder pipeline's 3-tuple output, a generic tuple
    whose first entry is ``prev_output_tokens``, or a bare tensor, and emits
    embedded decoder inputs in T x B x C layout.
    """

    def __init__(self, args, embed_tokens):
        super().__init__()
        self.dropout = args.dropout
        self.share_input_output_embed = args.share_decoder_input_output_embed
        # embed_tokens may be an nn.ModuleList of partial embeddings whose
        # outputs are concatenated; dims are summed in that case
        input_embed_dim = (
            sum(e.embedding_dim for e in embed_tokens)
            if isinstance(embed_tokens, nn.ModuleList)
            else embed_tokens.embedding_dim
        )
        embed_dim = args.decoder_embed_dim
        self.output_embed_dim = args.decoder_output_dim

        padding_idx = (
            embed_tokens[0].padding_idx
            if isinstance(embed_tokens, nn.ModuleList)
            else embed_tokens.padding_idx
        )
        self.max_target_positions = args.max_target_positions

        self.embed_tokens = embed_tokens
        self.embed_scale = math.sqrt(embed_dim)  # todo: try with input_embed_dim

        # project the input embedding up/down to the decoder model dim
        self.project_in_dim = (
            Linear(input_embed_dim, embed_dim, bias=False)
            if embed_dim != input_embed_dim
            else None
        )

        self.embed_positions = (
            PositionalEmbedding(
                args.max_target_positions,
                embed_dim,
                padding_idx,
                learned=args.decoder_learned_pos,
            )
            if not args.no_token_positional_embeddings
            else None
        )

    def forward(self, input):
        """Embed previous output tokens.

        Returns ``(x, encoder_out, encoder_padding_mask)`` when invoked with
        the MT pipeline's 3-tuple, otherwise just ``x`` (T x B x C).
        """
        mt_task = False
        if isinstance(input, tuple):
            if len(input) == 3:
                encoder_out = input[0]
                encoder_padding_mask = input[1]
                prev_output_tokens = input[2]
                incremental_state = None  # Hardcoding to avoid passing of None objects
                mt_task = True
            else:
                # HACK for now, need to fix (TODO sidgoyal)
                prev_output_tokens = input[0]
                # discard "src_lengths"
                encoder_out = None
                encoder_padding_mask = None
                incremental_state = None

        else:
            prev_output_tokens = input
            encoder_out = None
            encoder_padding_mask = None
            incremental_state = None

        positions = (
            self.embed_positions(
                prev_output_tokens,
                incremental_state=incremental_state,
            )
            if self.embed_positions is not None
            else None
        )

        # dead branch while incremental_state is hardcoded to None above;
        # kept for when incremental decoding is supported
        if incremental_state is not None:
            prev_output_tokens = prev_output_tokens[:, -1:]
            if positions is not None:
                positions = positions[:, -1:]

        # embed tokens and positions

        if isinstance(self.embed_tokens, nn.ModuleList):
            x_embed_list = []
            for embed_tokens_part in self.embed_tokens:
                x_embed_list.append(embed_tokens_part(prev_output_tokens))

            x = self.embed_scale * torch.cat(x_embed_list, dim=-1)
        else:
            x = self.embed_scale * self.embed_tokens(prev_output_tokens)

        if self.project_in_dim is not None:
            x = self.project_in_dim(x)

        if positions is not None:
            x += positions
        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)
        if mt_task:
            return (x, encoder_out, encoder_padding_mask)
        return x
213
+
214
+
215
class TransformerDecoderOutputLayer(nn.Module):
    """Final decoder stage: optional layer norm, optional output projection,
    and projection to vocabulary logits (shared/untied linear or adaptive
    softmax)."""

    def __init__(self, args, embed_tokens, dictionary):
        super().__init__()
        self.share_input_output_embed = args.share_decoder_input_output_embed
        self.embed_tokens = embed_tokens
        self.output_embed_dim = args.decoder_output_dim
        embed_dim = args.decoder_embed_dim

        self.project_out_dim = (
            Linear(embed_dim, self.output_embed_dim, bias=False)
            if embed_dim != self.output_embed_dim and not args.tie_adaptive_weights
            else None
        )
        self.adaptive_softmax = None
        if args.adaptive_softmax_cutoff is not None:
            # adaptive softmax does not support concatenated embeddings
            assert not isinstance(embed_tokens, nn.ModuleList)
            self.adaptive_softmax = AdaptiveSoftmax(
                len(dictionary),
                self.output_embed_dim,
                options.eval_str_list(args.adaptive_softmax_cutoff, type=int),
                dropout=args.adaptive_softmax_dropout,
                adaptive_inputs=embed_tokens if args.tie_adaptive_weights else None,
                factor=args.adaptive_softmax_factor,
                tie_proj=args.tie_adaptive_proj,
            )
        elif not self.share_input_output_embed:
            # untied output projection: NOTE this rebinds self.embed_tokens
            # from a module to a plain Parameter (see output_layer below)
            self.embed_tokens = nn.Parameter(
                torch.Tensor(len(dictionary), self.output_embed_dim)
            )
            nn.init.normal_(
                self.embed_tokens, mean=0, std=self.output_embed_dim ** -0.5
            )

        if args.decoder_normalize_before and not getattr(
            args, "no_decoder_final_norm", False
        ):
            self.layer_norm = LayerNorm(embed_dim)
        else:
            self.layer_norm = None

    def forward(self, input, apply_final_proj=True):
        """Normalize, transpose to B x T x C, and (optionally) project to
        vocabulary logits. ``input`` may be a pipeline tuple whose first
        element is the decoder features, or a bare tensor."""
        if isinstance(input, tuple):
            x = input[0]
        else:
            x = input

        if self.layer_norm:
            x = self.layer_norm(x)

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        if self.project_out_dim is not None:
            x = self.project_out_dim(x)
        if apply_final_proj:
            x = self.output_layer(x)
        return x

    def output_layer(self, features, **kwargs):
        """Project features to the vocabulary size."""
        if self.adaptive_softmax is None:
            # project back to size of vocabulary
            if self.share_input_output_embed:
                if isinstance(self.embed_tokens, nn.ModuleList):
                    # each partial embedding projects its slice of the
                    # feature dim; logits are summed across parts
                    output = None
                    for i, emb in enumerate(self.embed_tokens):
                        sidx = i * emb.embedding_dim
                        eidx = (i + 1) * emb.embedding_dim
                        if output is None:
                            output = F.linear(features[:, :, sidx:eidx], emb.weight)
                        else:
                            output += F.linear(features[:, :, sidx:eidx], emb.weight)

                    return output
                else:
                    return F.linear(features, self.embed_tokens.weight)
            else:
                # untied case: embed_tokens is the Parameter created above
                return F.linear(features, self.embed_tokens)
        else:
            # adaptive softmax computes log-probs itself at loss time
            return features
295
+
296
+
297
class TransformerEncoderLayer(nn.Module):
    """Encoder layer block.
    In the original paper each operation (multi-head attention or FFN) is
    postprocessed with: `dropout -> add residual -> layernorm`. In the
    tensor2tensor code they suggest that learning is more robust when
    preprocessing each layer with layernorm and postprocessing with:
    `dropout -> add residual`. We default to the approach in the paper, but the
    tensor2tensor approach can be enabled by setting
    *args.encoder_normalize_before* to ``True``.

    Args:
        args (argparse.Namespace): parsed command-line arguments
    """

    def __init__(self, args):
        super().__init__()
        self.embed_dim = args.encoder_embed_dim
        self.self_attn = MultiheadAttention(
            self.embed_dim,
            args.encoder_attention_heads,
            dropout=args.attention_dropout,
            self_attention=True,
        )
        self.self_attn_layer_norm = LayerNorm(self.embed_dim)
        self.dropout = args.dropout
        self.activation_fn = utils.get_activation_fn(
            activation=getattr(args, "activation_fn", "relu")
        )
        self.activation_dropout = getattr(args, "activation_dropout", 0)
        if self.activation_dropout == 0:
            # for backwards compatibility with models that use args.relu_dropout
            self.activation_dropout = getattr(args, "relu_dropout", 0)
        self.normalize_before = args.encoder_normalize_before
        self.fc1 = Linear(self.embed_dim, args.encoder_ffn_embed_dim)
        self.fc2 = Linear(args.encoder_ffn_embed_dim, self.embed_dim)
        self.final_layer_norm = LayerNorm(self.embed_dim)

    def upgrade_state_dict_named(self, state_dict, name):
        """
        Rename layer norm states from `...layer_norms.0.weight` to
        `...self_attn_layer_norm.weight` and `...layer_norms.1.weight` to
        `...final_layer_norm.weight`
        """
        layer_norm_map = {"0": "self_attn_layer_norm", "1": "final_layer_norm"}
        for old, new in layer_norm_map.items():
            for m in ("weight", "bias"):
                k = "{}.layer_norms.{}.{}".format(name, old, m)
                if k in state_dict:
                    state_dict["{}.{}.{}".format(name, new, m)] = state_dict[k]
                    del state_dict[k]

    def forward(self, input):
        """
        Args:
            input (Tuple):
                input[0] (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
                input[1] (ByteTensor/FloatTensor): encoder padding mask -
                    binary ByteTensor of shape `(batch, src_len)` where padding elements
                    are indicated by ``1``.
                input[2] (LongTensor): previous decoder outputs of shape
                    `(batch, tgt_len)`, for teacher forcing)
        Returns:
            output (Tuple):
                output[0] (Tensor): encoded output of shape `(seq_len, batch, embed_dim)`
                output[1] (ByteTensor/FloatTensor): encoder padding mask
                output[2] (LongTensor): previous decoder outputs
        """
        x = input[0]
        encoder_padding_mask = input[1]
        # threaded through untouched for downstream decoder pipeline stages
        prev_output_tokens = input[2]

        # self-attention sub-block: (pre-)norm -> attn -> dropout -> residual
        residual = x
        x = self.maybe_layer_norm(self.self_attn_layer_norm, x, before=True)
        x, _ = self.self_attn(
            query=x, key=x, value=x, key_padding_mask=encoder_padding_mask
        )
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(self.self_attn_layer_norm, x, after=True)

        # feed-forward sub-block
        residual = x
        x = self.maybe_layer_norm(self.final_layer_norm, x, before=True)
        x = self.activation_fn(self.fc1(x))
        x = F.dropout(x, p=self.activation_dropout, training=self.training)
        x = self.fc2(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(self.final_layer_norm, x, after=True)
        return (x, encoder_padding_mask, prev_output_tokens)

    def maybe_layer_norm(self, layer_norm, x, before=False, after=False):
        # applies layer_norm either pre- (normalize_before) or post-residual,
        # depending on which call site (before/after) matches the config
        assert before ^ after
        if after ^ self.normalize_before:
            return layer_norm(x)
        else:
            return x
392
+
393
+
394
class TransformerDecoderLayer(nn.Module):
    """Decoder layer block.

    In the original paper each operation (multi-head attention, encoder
    attention or FFN) is postprocessed with: `dropout -> add residual ->
    layernorm`. In the tensor2tensor code they suggest that learning is more
    robust when preprocessing each layer with layernorm and postprocessing with:
    `dropout -> add residual`. We default to the approach in the paper, but the
    tensor2tensor approach can be enabled by setting
    *args.decoder_normalize_before* to ``True``.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        no_encoder_attn (bool, optional): whether to attend to encoder outputs
            (default: False).
    """

    def __init__(
        self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False
    ):
        super().__init__()
        self.embed_dim = args.decoder_embed_dim
        self.self_attn = MultiheadAttention(
            embed_dim=self.embed_dim,
            num_heads=args.decoder_attention_heads,
            dropout=args.attention_dropout,
            add_bias_kv=add_bias_kv,
            add_zero_attn=add_zero_attn,
            self_attention=True,
        )
        self.dropout = args.dropout
        self.activation_fn = utils.get_activation_fn(
            activation=getattr(args, "activation_fn", "relu")
        )
        self.activation_dropout = getattr(args, "activation_dropout", 0)
        if self.activation_dropout == 0:
            # for backwards compatibility with models that use args.relu_dropout
            self.activation_dropout = getattr(args, "relu_dropout", 0)
        self.normalize_before = args.decoder_normalize_before

        # use layerNorm rather than FusedLayerNorm for exporting.
        # char_inputs can be used to determint this.
        # TODO  remove this once we update apex with the fix
        export = getattr(args, "char_inputs", False)
        self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=export)

        if no_encoder_attn:
            self.encoder_attn = None
            self.encoder_attn_layer_norm = None
        else:
            self.encoder_attn = MultiheadAttention(
                self.embed_dim,
                args.decoder_attention_heads,
                kdim=getattr(args, "encoder_embed_dim", None),
                vdim=getattr(args, "encoder_embed_dim", None),
                dropout=args.attention_dropout,
                encoder_decoder_attention=True,
            )
            self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, export=export)

        self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim)
        self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim)

        self.final_layer_norm = LayerNorm(self.embed_dim, export=export)
        self.need_attn = True

        self.onnx_trace = False

    def prepare_for_onnx_export_(self):
        self.onnx_trace = True

    def forward(self, input):
        """
        Args:
            input (Tuple):
                input[0] (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
                input[1] (Tensor): encoder output of shape `(batch, src_len, embed_dim)`
                input[2] (ByteTensor/FloatTensor): encoder padding mask -
                    binary ByteTensor of shape `(batch, src_len)` where padding elements
                    are indicated by ``1``.
        Returns:
            output (Tuple):
                output[0] (Tensor): decoded output of shape `(seq_len, batch, embed_dim)`
                output[1] (Tensor): encoder output
                output[2] (ByteTensor/FloatTensor): encoder padding mask
        """
        # Note: incremental state is not yet supported
        mt_task = False
        if isinstance(input, tuple):
            x = input[0]
            encoder_out = input[1]
            encoder_padding_mask = input[2]
            incremental_state = None
            mt_task = True
        else:
            # language-model style: bare decoder features, no encoder context
            x = input
            encoder_out = None
            encoder_padding_mask = None
            incremental_state = None

        # always masks the future here since incremental_state is None above
        if incremental_state is None:
            self_attn_mask = self.buffered_future_mask(x)
        else:
            self_attn_mask = None

        # TODO: add back prev_self_attn_state, prev_attn_state,
        # self_attn_padding_mask
        prev_self_attn_state = None
        prev_attn_state = None
        self_attn_padding_mask = None

        # self-attention sub-block: (pre-)norm -> attn -> dropout -> residual
        residual = x
        x = self.maybe_layer_norm(self.self_attn_layer_norm, x, before=True)
        if prev_self_attn_state is not None:
            if incremental_state is None:
                incremental_state = {}
            prev_key, prev_value = prev_self_attn_state
            saved_state = {"prev_key": prev_key, "prev_value": prev_value}
            self.self_attn._set_input_buffer(incremental_state, saved_state)
        x, attn = self.self_attn(
            query=x,
            key=x,
            value=x,
            key_padding_mask=self_attn_padding_mask,
            incremental_state=incremental_state,
            need_weights=False,
            attn_mask=self_attn_mask,
        )
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(self.self_attn_layer_norm, x, after=True)

        # encoder-decoder attention sub-block (skipped when no_encoder_attn)
        if self.encoder_attn is not None:
            residual = x
            x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, before=True)
            if prev_attn_state is not None:
                if incremental_state is None:
                    incremental_state = {}
                prev_key, prev_value = prev_attn_state
                saved_state = {"prev_key": prev_key, "prev_value": prev_value}
                self.encoder_attn._set_input_buffer(incremental_state, saved_state)
            x, attn = self.encoder_attn(
                query=x,
                key=encoder_out,
                value=encoder_out,
                key_padding_mask=encoder_padding_mask,
                incremental_state=incremental_state,
                static_kv=True,
                need_weights=(not self.training and self.need_attn),
            )
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = residual + x
            x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, after=True)

        # feed-forward sub-block
        residual = x
        x = self.maybe_layer_norm(self.final_layer_norm, x, before=True)
        x = self.activation_fn(self.fc1(x))
        x = F.dropout(x, p=self.activation_dropout, training=self.training)
        x = self.fc2(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(self.final_layer_norm, x, after=True)

        if mt_task:
            return (x, encoder_out, encoder_padding_mask)
        return x

    def buffered_future_mask(self, tensor):
        """Return (and cache) an upper-triangular -inf mask of size
        seq_len x seq_len for causal self-attention."""
        dim = tensor.size(0)
        if (
            not hasattr(self, "_future_mask")
            or self._future_mask is None
            or self._future_mask.device != tensor.device
        ):
            self._future_mask = torch.triu(
                utils.fill_with_neg_inf(tensor.new(dim, dim)), 1
            )
        # grow the cached mask in place if a longer sequence arrives
        if self._future_mask.size(0) < dim:
            self._future_mask = torch.triu(
                utils.fill_with_neg_inf(self._future_mask.resize_(dim, dim)), 1
            )
        return self._future_mask[:dim, :dim]

    def maybe_layer_norm(self, layer_norm, x, before=False, after=False):
        # applies layer_norm either pre- (normalize_before) or post-residual,
        # depending on which call site (before/after) matches the config
        assert before ^ after
        if after ^ self.normalize_before:
            return layer_norm(x)
        else:
            return x

    def make_generation_fast_(self, need_attn=False, **kwargs):
        self.need_attn = need_attn
586
+
587
+
588
def Embedding(num_embeddings, embedding_dim, padding_idx):
    """Build an nn.Embedding with N(0, dim^-0.5) weights and a zeroed
    padding row."""
    emb = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
    std = embedding_dim ** -0.5
    nn.init.normal_(emb.weight, mean=0, std=std)
    nn.init.constant_(emb.weight[padding_idx], 0)
    return emb
593
+
594
+
595
def Linear(in_features, out_features, bias=True):
    """Build an nn.Linear with Xavier-uniform weights and a zero bias."""
    layer = nn.Linear(in_features, out_features, bias)
    nn.init.xavier_uniform_(layer.weight)
    if bias:
        nn.init.constant_(layer.bias, 0.0)
    return layer
fairseq-0.10.2/fairseq/model_parallel/models/roberta/model.py ADDED
@@ -0,0 +1,287 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ """
6
+ RoBERTa: A Robustly Optimized BERT Pretraining Approach.
7
+ """
8
+
9
+ import logging
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ from fairseq import utils
15
+ from fairseq.model_parallel.modules import ModelParallelTransformerSentenceEncoder
16
+ from fairseq.models import FairseqEncoder, register_model, register_model_architecture
17
+ from fairseq.models.roberta import (
18
+ RobertaClassificationHead,
19
+ RobertaEncoder,
20
+ RobertaLMHead,
21
+ RobertaModel,
22
+ )
23
+ from fairseq.modules import LayerNorm, TransformerSentenceEncoder
24
+ from fairseq.modules.transformer_sentence_encoder import init_bert_params
25
+
26
+
27
+ try:
28
+ from fairseq.model_parallel.megatron.mpu import (
29
+ copy_to_model_parallel_region,
30
+ gather_from_model_parallel_region,
31
+ ColumnParallelLinear,
32
+ RowParallelLinear,
33
+ )
34
+
35
+ has_megatron_submodule = True
36
+ except (ImportError, ModuleNotFoundError):
37
+ has_megatron_submodule = False
38
+
39
+ logger = logging.getLogger(__name__)
40
+
41
+
42
@register_model("model_parallel_roberta")
class ModelParallelRobertaModel(RobertaModel):
    """RoBERTa with a Megatron model-parallel sentence encoder.

    Reuses :class:`RobertaModel`'s interface; classification heads use the
    model-parallel variants defined below.
    """

    def __init__(self, args, encoder):
        super().__init__(args, encoder)

        self.classification_heads = nn.ModuleDict()

    @staticmethod
    def add_args(parser):
        super(ModelParallelRobertaModel, ModelParallelRobertaModel).add_args(parser)

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance."""

        # make sure all arguments are present
        base_architecture(args)

        # vocab size must divide evenly across model-parallel ranks
        task.source_dictionary.pad_to_multiple_(args.model_parallel_size * 8)
        task.target_dictionary.pad_to_multiple_(args.model_parallel_size * 8)

        if not hasattr(args, "max_positions"):
            args.max_positions = args.tokens_per_sample

        if getattr(args, "untie_weights_roberta", False):
            raise NotImplementedError(
                "--untie-weights-roberta is not supported in model parallel mode"
            )

        encoder = ModelParallelRobertaEncoder(args, task.source_dictionary)
        return cls(args, encoder)

    def forward(
        self,
        src_tokens,
        features_only=False,
        return_all_hiddens=False,
        classification_head_name=None,
        **kwargs
    ):
        # a classification head consumes features, so skip the LM head
        if classification_head_name is not None:
            features_only = True

        x, extra = self.encoder(src_tokens, features_only, return_all_hiddens, **kwargs)

        if classification_head_name is not None:
            x = self.classification_heads[classification_head_name](x)
        return x, extra

    def register_classification_head(
        self, name, num_classes=None, inner_dim=None, **kwargs
    ):
        """Register a classification head."""
        if name in self.classification_heads:
            # warn (but allow) re-registration with a different geometry
            prev_num_classes = self.classification_heads[name].out_proj.out_features
            prev_inner_dim = self.classification_heads[name].dense.out_features
            if num_classes != prev_num_classes or inner_dim != prev_inner_dim:
                logger.warning(
                    're-registering head "{}" with num_classes {} (prev: {}) '
                    "and inner_dim {} (prev: {})".format(
                        name, num_classes, prev_num_classes, inner_dim, prev_inner_dim
                    )
                )
        self.classification_heads[name] = ModelParallelRobertaClassificationHead(
            self.args.encoder_embed_dim,
            inner_dim or self.args.encoder_embed_dim,
            num_classes,
            self.args.pooler_activation_fn,
            self.args.pooler_dropout,
        )
112
+
113
+
114
class ModelParallelRobertaLMHead(nn.Module):
    """Head for masked language modeling.

    The dense layer is column-parallel; the vocabulary projection runs in the
    model-parallel region (each rank projects against its weight shard) and
    the shards are gathered back before adding the bias.
    """

    def __init__(self, embed_dim, output_dim, activation_fn, weight=None):
        super().__init__()
        self.dense = ColumnParallelLinear(embed_dim, embed_dim, gather_output=True)
        self.activation_fn = utils.get_activation_fn(activation_fn)
        self.layer_norm = LayerNorm(embed_dim)

        # `weight` is typically the (sharded) input embedding for weight tying
        if weight is None:
            weight = nn.Linear(embed_dim, output_dim, bias=False).weight
        self.weight = weight
        self.bias = nn.Parameter(torch.zeros(output_dim))

    def forward(self, features, masked_tokens=None, **kwargs):
        # Only project the unmasked tokens while training,
        # saves both memory and computation
        if masked_tokens is not None:
            features = features[masked_tokens, :]

        x = self.dense(features)
        x = self.activation_fn(x)
        x = self.layer_norm(x)

        x = copy_to_model_parallel_region(x)
        # project back to size of vocabulary with bias
        x = F.linear(x, self.weight)
        x = gather_from_model_parallel_region(x).contiguous()
        x = x + self.bias
        return x
144
+
145
+
146
class ModelParallelRobertaClassificationHead(nn.Module):
    """Head for sentence-level classification tasks.

    Pools the first token, then dense (column-parallel) -> activation ->
    dropout -> output projection.
    """

    def __init__(
        self, input_dim, inner_dim, num_classes, activation_fn, pooler_dropout
    ):
        super().__init__()
        self.dense = ColumnParallelLinear(input_dim, inner_dim, gather_output=True)
        self.activation_fn = utils.get_activation_fn(activation_fn)
        self.dropout = nn.Dropout(p=pooler_dropout)
        self.out_proj = nn.Linear(inner_dim, num_classes)

    def forward(self, features, **kwargs):
        x = features[:, 0, :]  # take <s> token (equiv. to [CLS])
        x = self.dropout(x)
        x = self.dense(x)
        x = self.activation_fn(x)
        x = self.dropout(x)
        x = self.out_proj(x)
        return x
166
+
167
+
168
class ModelParallelRobertaEncoder(FairseqEncoder):
    """RoBERTa encoder built from model-parallel components.

    Implements the :class:`~fairseq.models.FairseqDecoder` interface required
    by :class:`~fairseq.models.FairseqLanguageModel`.
    """

    def __init__(self, args, dictionary):
        super().__init__(dictionary)
        self.args = args

        # RoBERTa is a sentence encoder model, so users will intuitively trim
        # encoder layers. However, the implementation uses the fairseq
        # decoder, so translate the pruning option accordingly.
        if args.encoder_layers_to_keep:
            args.encoder_layers = len(args.encoder_layers_to_keep.split(","))
            args.decoder_layers_to_keep = args.encoder_layers_to_keep
            args.encoder_layers_to_keep = None

        self.sentence_encoder = ModelParallelTransformerSentenceEncoder(
            padding_idx=dictionary.pad(),
            vocab_size=len(dictionary),
            num_encoder_layers=args.encoder_layers,
            embedding_dim=args.encoder_embed_dim,
            ffn_embedding_dim=args.encoder_ffn_embed_dim,
            num_attention_heads=args.encoder_attention_heads,
            dropout=args.dropout,
            attention_dropout=args.attention_dropout,
            activation_dropout=args.activation_dropout,
            layerdrop=args.encoder_layerdrop,
            max_seq_len=args.max_positions,
            num_segments=0,
            encoder_normalize_before=False,
            apply_bert_init=False,
            activation_fn=args.activation_fn,
        )
        # The LM head ties its output projection to the input embeddings.
        self.lm_head = ModelParallelRobertaLMHead(
            embed_dim=args.encoder_embed_dim,
            output_dim=len(dictionary),
            activation_fn=args.activation_fn,
            weight=self.sentence_encoder.embed_tokens.weight,
        )

    def forward(
        self,
        src_tokens,
        features_only=False,
        return_all_hiddens=False,
        masked_tokens=None,
        **unused
    ):
        """
        Args:
            src_tokens (LongTensor): input tokens of shape `(batch, src_len)`
            features_only (bool, optional): skip LM head and just return
                features. If True, the output will be of shape
                `(batch, src_len, embed_dim)`.
            return_all_hiddens (bool, optional): also return all of the
                intermediate hidden states (default: False).
            masked_tokens (optional): mask selecting the positions to project
                through the LM head (saves memory during masked-LM training).

        Returns:
            tuple:
                - the LM output of shape `(batch, src_len, vocab)`
                - a dictionary of additional data, where 'inner_states'
                  is a list of hidden states. Note that the hidden
                  states have shape `(src_len, batch, vocab)`.
        """
        features, extra = self.extract_features(
            src_tokens, return_all_hiddens=return_all_hiddens
        )
        if features_only:
            return features, extra
        return self.output_layer(features, masked_tokens=masked_tokens), extra

    def extract_features(self, src_tokens, return_all_hiddens=False, **unused):
        inner_states, _ = self.sentence_encoder(
            src_tokens,
            last_state_only=not return_all_hiddens,
        )
        # T x B x C -> B x T x C
        features = inner_states[-1].transpose(0, 1)
        return features, {"inner_states": inner_states if return_all_hiddens else None}

    def output_layer(self, features, masked_tokens=None, **unused):
        return self.lm_head(features, masked_tokens)

    def max_positions(self):
        """Maximum output length supported by the encoder."""
        return self.args.max_positions
256
+
257
+
258
@register_model_architecture("model_parallel_roberta", "model_parallel_roberta")
def base_architecture(args):
    """Fill in missing hyperparameters with RoBERTa-base defaults."""
    defaults = {
        "encoder_layers": 12,
        "encoder_embed_dim": 768,
        "encoder_ffn_embed_dim": 3072,
        "encoder_attention_heads": 12,
        "activation_fn": "gelu",
        "pooler_activation_fn": "tanh",
        "dropout": 0.1,
        "attention_dropout": 0.1,
        "activation_dropout": 0.0,
        "pooler_dropout": 0.0,
        "encoder_layers_to_keep": None,
        "encoder_layerdrop": 0.0,
    }
    # Only set a value when the caller has not already provided one.
    for key, value in defaults.items():
        setattr(args, key, getattr(args, key, value))
274
+
275
+
276
@register_model_architecture("model_parallel_roberta", "model_parallel_roberta_base")
def roberta_base_architecture(args):
    """Alias of the base architecture (RoBERTa-base sized)."""
    base_architecture(args)
279
+
280
+
281
@register_model_architecture("model_parallel_roberta", "model_parallel_roberta_large")
def roberta_large_architecture(args):
    """RoBERTa-large sized defaults; falls back to base for the rest."""
    for key, value in (
        ("encoder_layers", 24),
        ("encoder_embed_dim", 1024),
        ("encoder_ffn_embed_dim", 4096),
        ("encoder_attention_heads", 16),
    ):
        setattr(args, key, getattr(args, key, value))
    base_architecture(args)
fairseq-0.10.2/fairseq/model_parallel/modules/__pycache__/transformer_sentence_encoder_layer.cpython-310.pyc ADDED
Binary file (2.46 kB). View file
 
fairseq-0.10.2/fairseq/model_parallel/modules/transformer_sentence_encoder_layer.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from fairseq import utils
9
+ from fairseq.model_parallel.modules import ModelParallelMultiheadAttention
10
+ from fairseq.modules import TransformerSentenceEncoderLayer
11
+
12
+
13
+ try:
14
+ from fairseq.model_parallel.megatron.mpu import (
15
+ ColumnParallelLinear,
16
+ RowParallelLinear,
17
+ )
18
+
19
+ has_megatron_submodule = True
20
+ except (ImportError, ModuleNotFoundError):
21
+ has_megatron_submodule = False
22
+
23
+
24
class ModelParallelTransformerSentenceEncoderLayer(TransformerSentenceEncoderLayer):
    """
    Model-parallel Transformer encoder layer used in BERT/XLM style
    pre-trained models. The FFN is sharded column-wise then row-wise across
    model-parallel ranks; attention uses the model-parallel variant.
    """

    def build_fc1(self, input_dim, output_dim, **unused):
        # Each rank keeps a slice of the FFN hidden units (no gather needed,
        # fc2 consumes the partitioned activations directly).
        return ColumnParallelLinear(input_dim, output_dim, gather_output=False)

    def build_fc2(self, input_dim, output_dim, **unused):
        return RowParallelLinear(input_dim, output_dim, input_is_parallel=True)

    def build_self_attention(
        self,
        embed_dim,
        num_attention_heads,
        dropout,
        **kwargs,
    ):
        return ModelParallelMultiheadAttention(
            embed_dim, num_attention_heads, dropout=dropout, self_attention=True
        )

    def forward(
        self,
        x: torch.Tensor,
        self_attn_mask: torch.Tensor = None,
        self_attn_padding_mask: torch.Tensor = None,
    ):
        """
        Pre-LayerNorm layer: LayerNorm is applied before the self-attention
        and FFN sub-blocks, similar to the original Transformer implementation.
        """
        # --- self-attention sub-block ---
        shortcut = x
        x = self.self_attn_layer_norm(x)
        x, _ = self.self_attn(
            query=x,
            key=x,
            value=x,
            key_padding_mask=self_attn_padding_mask,
            need_weights=False,
            attn_mask=self_attn_mask,
        )
        x = shortcut + self.dropout_module(x)

        # --- feed-forward sub-block ---
        shortcut = x
        x = self.final_layer_norm(x)
        x = self.activation_fn(self.fc1(x))
        x = self.activation_dropout_module(x)
        x = self.fc2(x)
        x = shortcut + self.dropout_module(x)
        return x, None
fairseq-0.10.2/fairseq/modules/__init__.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ """isort:skip_file"""
6
+
7
+ from .adaptive_input import AdaptiveInput
8
+ from .adaptive_softmax import AdaptiveSoftmax
9
+ from .beamable_mm import BeamableMM
10
+ from .character_token_embedder import CharacterTokenEmbedder
11
+ from .conv_tbc import ConvTBC
12
+ from .cross_entropy import cross_entropy
13
+ from .downsampled_multihead_attention import DownsampledMultiHeadAttention
14
+ from .dynamic_convolution import DynamicConv, DynamicConv1dTBC
15
+ from .dynamic_crf_layer import DynamicCRF
16
+ from .fairseq_dropout import FairseqDropout
17
+ from .fp32_group_norm import Fp32GroupNorm
18
+ from .gelu import gelu, gelu_accurate
19
+ from .grad_multiply import GradMultiply
20
+ from .gumbel_vector_quantizer import GumbelVectorQuantizer
21
+ from .kmeans_vector_quantizer import KmeansVectorQuantizer
22
+ from .layer_drop import LayerDropModuleList
23
+ from .layer_norm import Fp32LayerNorm, LayerNorm
24
+ from .learned_positional_embedding import LearnedPositionalEmbedding
25
+ from .lightweight_convolution import LightweightConv, LightweightConv1dTBC
26
+ from .linearized_convolution import LinearizedConvolution
27
+ from .multihead_attention import MultiheadAttention
28
+ from .positional_embedding import PositionalEmbedding
29
+ from .same_pad import SamePad
30
+ from .scalar_bias import ScalarBias
31
+ from .sinusoidal_positional_embedding import SinusoidalPositionalEmbedding
32
+ from .transformer_sentence_encoder_layer import TransformerSentenceEncoderLayer
33
+ from .transformer_sentence_encoder import TransformerSentenceEncoder
34
+ from .transpose_last import TransposeLast
35
+ from .unfold import unfold1d
36
+ from .transformer_layer import TransformerDecoderLayer, TransformerEncoderLayer
37
+ from .vggblock import VGGBlock
38
+
39
+ __all__ = [
40
+ "AdaptiveInput",
41
+ "AdaptiveSoftmax",
42
+ "BeamableMM",
43
+ "CharacterTokenEmbedder",
44
+ "ConvTBC",
45
+ "cross_entropy",
46
+ "DownsampledMultiHeadAttention",
47
+ "DynamicConv1dTBC",
48
+ "DynamicConv",
49
+ "DynamicCRF",
50
+ "FairseqDropout",
51
+ "Fp32GroupNorm",
52
+ "Fp32LayerNorm",
53
+ "gelu",
54
+ "gelu_accurate",
55
+ "GradMultiply",
56
+ "GumbelVectorQuantizer",
57
+ "KmeansVectorQuantizer",
58
+ "LayerDropModuleList",
59
+ "LayerNorm",
60
+ "LearnedPositionalEmbedding",
61
+ "LightweightConv1dTBC",
62
+ "LightweightConv",
63
+ "LinearizedConvolution",
64
+ "MultiheadAttention",
65
+ "PositionalEmbedding",
66
+ "SamePad",
67
+ "ScalarBias",
68
+ "SinusoidalPositionalEmbedding",
69
+ "TransformerSentenceEncoderLayer",
70
+ "TransformerSentenceEncoder",
71
+ "TransformerDecoderLayer",
72
+ "TransformerEncoderLayer",
73
+ "TransposeLast",
74
+ "VGGBlock",
75
+ "unfold1d",
76
+ ]
fairseq-0.10.2/fairseq/modules/__pycache__/adaptive_softmax.cpython-310.pyc ADDED
Binary file (6.81 kB). View file
 
fairseq-0.10.2/fairseq/modules/__pycache__/dynamic_convolution.cpython-310.pyc ADDED
Binary file (8.24 kB). View file
 
fairseq-0.10.2/fairseq/modules/__pycache__/grad_multiply.cpython-310.pyc ADDED
Binary file (703 Bytes). View file
 
fairseq-0.10.2/fairseq/modules/__pycache__/gumbel_vector_quantizer.cpython-310.pyc ADDED
Binary file (5.99 kB). View file
 
fairseq-0.10.2/fairseq/modules/__pycache__/kmeans_vector_quantizer.cpython-310.pyc ADDED
Binary file (3.86 kB). View file
 
fairseq-0.10.2/fairseq/modules/__pycache__/learned_positional_embedding.cpython-310.pyc ADDED
Binary file (2.01 kB). View file
 
fairseq-0.10.2/fairseq/modules/__pycache__/lightweight_convolution.cpython-310.pyc ADDED
Binary file (8.87 kB). View file
 
fairseq-0.10.2/fairseq/modules/__pycache__/transformer_sentence_encoder_layer.cpython-310.pyc ADDED
Binary file (3.29 kB). View file
 
fairseq-0.10.2/fairseq/modules/__pycache__/vggblock.cpython-310.pyc ADDED
Binary file (3.49 kB). View file
 
fairseq-0.10.2/fairseq/modules/adaptive_softmax.py ADDED
@@ -0,0 +1,268 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import functools
7
+ import operator
8
+
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from fairseq.modules.fairseq_dropout import FairseqDropout
12
+ from fairseq.modules.quant_noise import quant_noise
13
+ from torch import nn
14
+
15
+
16
class TiedLinear(nn.Module):
    """Bias-free linear projection over an externally owned (tied) weight.

    When ``transpose`` is True the weight is transposed at call time, which
    lets the same parameter serve both as an embedding and as a projection.
    """

    def __init__(self, weight, transpose):
        super().__init__()
        self.weight = weight
        self.transpose = transpose

    def forward(self, input):
        weight = self.weight.t() if self.transpose else self.weight
        return F.linear(input, weight)
24
+
25
+
26
class TiedHeadModule(nn.Module):
    """Adaptive-softmax head whose word projection reuses tied embeddings.

    Emits ``num_words + num_classes`` scores per position: the head-vocabulary
    word scores followed by one score per tail cluster.
    """

    def __init__(self, weights, input_dim, num_classes, q_noise, qn_block_size):
        super().__init__()
        tied_emb, _ = weights
        self.num_words, emb_dim = tied_emb.size()

        word_proj = quant_noise(
            TiedLinear(tied_emb, transpose=False), q_noise, qn_block_size
        )
        if input_dim != emb_dim:
            # Adapt the input to the embedding dimension before projecting.
            word_proj = nn.Sequential(
                quant_noise(
                    nn.Linear(input_dim, emb_dim, bias=False), q_noise, qn_block_size
                ),
                word_proj,
            )
        self.word_proj = word_proj

        self.class_proj = quant_noise(
            nn.Linear(input_dim, num_classes, bias=False), q_noise, qn_block_size
        )
        self.out_dim = self.num_words + num_classes

        # Dummy buffer used only so outputs get the module's dtype/device.
        self.register_buffer("_float_tensor", torch.FloatTensor(1))

    def forward(self, input):
        flat = input.view(-1, input.size(-1))
        out = self._float_tensor.new(flat.size(0), self.out_dim)
        out[:, : self.num_words] = self.word_proj(flat)
        out[:, self.num_words :] = self.class_proj(flat)
        return out
56
+
57
+
58
class AdaptiveSoftmax(nn.Module):
    """
    Efficient softmax approximation for GPUs, as described in the paper
    "Efficient softmax approximation for GPUs"
    (http://arxiv.org/abs/1609.04309).

    The vocabulary is split at ``cutoff`` boundaries into a frequent "head"
    and progressively smaller "tail" clusters whose scores are computed only
    for positions whose target falls in that cluster.
    """

    def __init__(
        self,
        vocab_size,
        input_dim,
        cutoff,
        dropout,
        factor=4.0,
        adaptive_inputs=None,
        tie_proj=False,
        q_noise=0,
        qn_block_size=8,
    ):
        super().__init__()

        if vocab_size > cutoff[-1]:
            cutoff = cutoff + [vocab_size]
        else:
            assert (
                vocab_size == cutoff[-1]
            ), "cannot specify cutoff larger than vocab size"

        # Head emits the cutoff[0] frequent words plus one slot per tail cluster.
        output_dim = cutoff[0] + len(cutoff) - 1

        self.vocab_size = vocab_size
        self.cutoff = cutoff
        self.dropout_module = FairseqDropout(
            dropout, module_name=self.__class__.__name__
        )
        self.input_dim = input_dim
        self.factor = factor
        self.q_noise = q_noise
        self.qn_block_size = qn_block_size

        self.lsm = nn.LogSoftmax(dim=1)

        if adaptive_inputs is not None:
            # Tie the head projection to the adaptive input embeddings.
            self.head = TiedHeadModule(
                adaptive_inputs.weights_for_band(0),
                input_dim,
                len(cutoff) - 1,
                self.q_noise,
                self.qn_block_size,
            )
        else:
            self.head = quant_noise(
                nn.Linear(input_dim, output_dim, bias=False),
                self.q_noise,
                self.qn_block_size,
            )

        self._make_tail(adaptive_inputs, tie_proj)

        def init_weights(m):
            if (
                hasattr(m, "weight")
                and not isinstance(m, TiedLinear)
                and not isinstance(m, TiedHeadModule)
            ):
                nn.init.xavier_uniform_(m.weight)

        self.apply(init_weights)

        self.register_buffer("version", torch.LongTensor([1]))

    def _make_tail(self, adaptive_inputs=None, tie_proj=False):
        self.tail = nn.ModuleList()
        for i in range(len(self.cutoff) - 1):
            # Tail clusters use progressively smaller hidden dimensions.
            dim = int(self.input_dim // self.factor ** (i + 1))

            if adaptive_inputs is not None:
                tied_emb, tied_proj = adaptive_inputs.weights_for_band(i + 1)
            else:
                tied_emb, tied_proj = None, None

            if tied_proj is not None:
                if tie_proj:
                    proj = quant_noise(
                        TiedLinear(tied_proj, transpose=True),
                        self.q_noise,
                        self.qn_block_size,
                    )
                else:
                    proj = quant_noise(
                        nn.Linear(tied_proj.size(0), tied_proj.size(1), bias=False),
                        self.q_noise,
                        self.qn_block_size,
                    )
            else:
                proj = quant_noise(
                    nn.Linear(self.input_dim, dim, bias=False),
                    self.q_noise,
                    self.qn_block_size,
                )

            if tied_emb is None:
                out_proj = nn.Linear(
                    dim, self.cutoff[i + 1] - self.cutoff[i], bias=False
                )
            else:
                out_proj = TiedLinear(tied_emb, transpose=False)

            self.tail.append(
                nn.Sequential(
                    proj,
                    nn.Dropout(self.dropout_module.p),
                    quant_noise(out_proj, self.q_noise, self.qn_block_size),
                )
            )

    def upgrade_state_dict_named(self, state_dict, name):
        version_name = name + ".version"
        if version_name not in state_dict:
            raise Exception("This version of the model is no longer supported")

    def adapt_target(self, target):
        """
        In order to be efficient, the AdaptiveSoftMax does not compute the
        scores for all the words of the vocabulary for all the examples. It is
        thus necessary to call the method adapt_target of the AdaptiveSoftMax
        layer inside each forward pass.
        """
        target = target.view(-1)
        new_target = [target.clone()]
        target_idxs = []

        for i in range(len(self.cutoff) - 1):
            # Positions whose target falls inside the i-th tail cluster.
            mask = target.ge(self.cutoff[i]).mul(target.lt(self.cutoff[i + 1]))
            # In the head, those positions point at the cluster's slot.
            new_target[0][mask] = self.cutoff[0] + i

            if mask.any():
                target_idxs.append(mask.nonzero(as_tuple=False).squeeze(1))
                # Within-cluster target index.
                new_target.append(target[mask].add(-self.cutoff[i]))
            else:
                target_idxs.append(None)
                new_target.append(None)

        return new_target, target_idxs

    def forward(self, input, target):
        """
        Args:
            input: (b x t x d)
            target: (b x t)
        Returns:
            2 lists: output for each cutoff section and new targets by cut off
        """
        input = input.contiguous().view(-1, input.size(-1))
        input = self.dropout_module(input)

        new_target, target_idxs = self.adapt_target(target)
        output = [self.head(input)]

        for i, idxs in enumerate(target_idxs):
            if idxs is None:
                output.append(None)
            else:
                output.append(self.tail[i](input.index_select(0, idxs)))

        return output, new_target

    def get_log_prob(self, input, target):
        """
        Computes the log probabilities for all the words of the vocabulary,
        given a 2D tensor of hidden vectors.
        """
        bsz, length, dim = input.size()
        input = input.contiguous().view(-1, dim)

        if target is not None:
            _, target_idxs = self.adapt_target(target)
        else:
            target_idxs = None

        head_y = self.head(input)
        log_probs = head_y.new_zeros(input.size(0), self.vocab_size)

        head_sz = self.cutoff[0] + len(self.tail)
        log_probs[:, :head_sz] = self.lsm(head_y)
        # Log-prob of each tail cluster; added to the in-cluster scores below.
        tail_priors = log_probs[:, self.cutoff[0] : head_sz].clone()

        for i in range(len(self.tail)):
            start = self.cutoff[i]
            end = self.cutoff[i + 1]

            if target_idxs is None:
                tail_out = log_probs[:, start:end]
                tail_out.copy_(self.tail[i](input))
                log_probs[:, start:end] = self.lsm(tail_out).add_(
                    tail_priors[:, i, None]
                )
            elif target_idxs[i] is not None:
                idxs = target_idxs[i]
                tail_out = log_probs[idxs, start:end]
                tail_out.copy_(self.tail[i](input[idxs]))
                log_probs[idxs, start:end] = self.lsm(tail_out).add_(
                    tail_priors[idxs, i, None]
                )

        return log_probs.view(bsz, length, -1)
fairseq-0.10.2/fairseq/modules/beamable_mm.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+
10
class BeamableMM(nn.Module):
    """Optimized batched matmul for beam decoding with attention.

    During beam-search inference the source side is replicated ``beam`` times
    and the target side has width one, so the inputs can be folded from
    {(bsz x 1 x nhu), (bsz x sz2 x nhu)} down to
    {(bsz/beam x beam x nhu), (bsz/beam x sz2 x nhu)} before multiplying.
    """

    def __init__(self, beam_size=None):
        super(BeamableMM, self).__init__()
        self.beam_size = beam_size

    def forward(self, input1, input2):
        # The fast path applies only in eval mode with a configured beam size,
        # a batched 3-D input, and a single-step (width-1) target side.
        fast_path = (
            not self.training
            and self.beam_size is not None
            and input1.dim() == 3
            and input1.size(1) == 1
        )
        if not fast_path:
            return input1.bmm(input2)

        bsz, beam = input1.size(0), self.beam_size

        # bsz x 1 x nhu --> bsz/beam x beam x nhu
        input1 = input1[:, 0, :].unfold(0, beam, beam).transpose(2, 1)

        # fold the replicated source side, keeping one copy per beam group
        input2 = input2.unfold(0, beam, beam)[:, :, :, 0]

        if input1.size(0) == 1:
            # single group: a plain mm avoids the bmm overhead
            output = torch.mm(input1[0, :, :], input2[0, :, :])
        else:
            output = input1.bmm(input2)
        return output.view(bsz, 1, -1)

    def set_beam_size(self, beam_size):
        self.beam_size = beam_size
fairseq-0.10.2/fairseq/modules/character_token_embedder.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import logging
7
+ from typing import List, Tuple
8
+
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from fairseq.data import Dictionary
12
+ from torch import nn
13
+
14
+
15
+ CHAR_PAD_IDX = 0
16
+ CHAR_EOS_IDX = 257
17
+
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
class CharacterTokenEmbedder(torch.nn.Module):
    """Builds word embeddings from character-level (byte) convolutions.

    Characters are raw bytes shifted by +1 so that 0 stays the padding index;
    <eos> and <unk> get dedicated learned symbol embeddings.
    """

    def __init__(
        self,
        vocab: Dictionary,
        filters: List[Tuple[int, int]],
        char_embed_dim: int,
        word_embed_dim: int,
        highway_layers: int,
        max_char_len: int = 50,
        char_inputs: bool = False,
    ):
        super(CharacterTokenEmbedder, self).__init__()

        self.onnx_trace = False
        self.embedding_dim = word_embed_dim
        self.max_char_len = max_char_len
        # 256 shifted byte values plus padding index 0.
        self.char_embeddings = nn.Embedding(257, char_embed_dim, padding_idx=0)
        # Row 0: <eos>, row 1: <unk>.
        self.symbol_embeddings = nn.Parameter(torch.FloatTensor(2, word_embed_dim))
        self.eos_idx, self.unk_idx = 0, 1
        self.char_inputs = char_inputs

        self.convolutions = nn.ModuleList(
            nn.Conv1d(char_embed_dim, out_channels, kernel_size=width)
            for width, out_channels in filters
        )

        last_dim = sum(out_channels for _, out_channels in filters)

        self.highway = Highway(last_dim, highway_layers) if highway_layers > 0 else None

        self.projection = nn.Linear(last_dim, word_embed_dim)

        assert (
            vocab is not None or char_inputs
        ), "vocab must be set if not using char inputs"
        self.vocab = None
        if vocab is not None:
            self.set_vocab(vocab, max_char_len)

        self.reset_parameters()

    def prepare_for_onnx_export_(self):
        self.onnx_trace = True

    def set_vocab(self, vocab, max_char_len):
        """Precompute the word-index -> character-index lookup table."""
        word_to_char = torch.LongTensor(len(vocab), max_char_len)

        truncated = 0
        for idx in range(len(vocab)):
            if idx < vocab.nspecial:
                # Special symbols are handled via symbol_embeddings, not chars.
                char_idxs = [0] * max_char_len
            else:
                chars = vocab[idx].encode()
                # +1 shift keeps byte value 0 distinct from the padding index.
                char_idxs = [c + 1 for c in chars] + [0] * (max_char_len - len(chars))
                if len(char_idxs) > max_char_len:
                    truncated += 1
                    char_idxs = char_idxs[:max_char_len]
            word_to_char[idx] = torch.LongTensor(char_idxs)

        if truncated > 0:
            logger.info(
                "truncated {} words longer than {} characters".format(
                    truncated, max_char_len
                )
            )

        self.vocab = vocab
        self.word_to_char = word_to_char

    @property
    def padding_idx(self):
        return Dictionary().pad() if self.vocab is None else self.vocab.pad()

    def reset_parameters(self):
        nn.init.xavier_normal_(self.char_embeddings.weight)
        nn.init.xavier_normal_(self.symbol_embeddings)
        nn.init.xavier_uniform_(self.projection.weight)

        nn.init.constant_(
            self.char_embeddings.weight[self.char_embeddings.padding_idx], 0.0
        )
        nn.init.constant_(self.projection.bias, 0.0)

    def forward(
        self,
        input: torch.Tensor,
    ):
        if self.char_inputs:
            chars = input.view(-1, self.max_char_len)
            pads = chars[:, 0].eq(CHAR_PAD_IDX)
            eos = chars[:, 0].eq(CHAR_EOS_IDX)
            if eos.any():
                # Zero out eos rows before embedding (257 is out of range).
                if self.onnx_trace:
                    chars = torch.where(eos.unsqueeze(1), chars.new_zeros(1), chars)
                else:
                    chars[eos] = 0

            unk = None
        else:
            flat_words = input.view(-1)
            chars = self.word_to_char[flat_words.type_as(self.word_to_char)].type_as(
                input
            )
            pads = flat_words.eq(self.vocab.pad())
            eos = flat_words.eq(self.vocab.eos())
            unk = flat_words.eq(self.vocab.unk())

        word_embs = self._convolve(chars)
        if self.onnx_trace:
            # torch.where keeps the graph traceable for ONNX export.
            if pads.any():
                word_embs = torch.where(
                    pads.unsqueeze(1), word_embs.new_zeros(1), word_embs
                )
            if eos.any():
                word_embs = torch.where(
                    eos.unsqueeze(1), self.symbol_embeddings[self.eos_idx], word_embs
                )
            if unk is not None and unk.any():
                word_embs = torch.where(
                    unk.unsqueeze(1), self.symbol_embeddings[self.unk_idx], word_embs
                )
        else:
            if pads.any():
                word_embs[pads] = 0
            if eos.any():
                word_embs[eos] = self.symbol_embeddings[self.eos_idx]
            if unk is not None and unk.any():
                word_embs[unk] = self.symbol_embeddings[self.unk_idx]

        return word_embs.view(input.size()[:2] + (-1,))

    def _convolve(
        self,
        char_idxs: torch.Tensor,
    ):
        embedded = self.char_embeddings(char_idxs)
        embedded = embedded.transpose(1, 2)  # BTC -> BCT

        pooled = []
        for conv in self.convolutions:
            feat = conv(embedded)
            # Max-pool over time, then rectify.
            feat, _ = torch.max(feat, -1)
            pooled.append(F.relu(feat))

        out = torch.cat(pooled, dim=-1)

        if self.highway is not None:
            out = self.highway(out)
        return self.projection(out)
177
+
178
+
179
class Highway(torch.nn.Module):
    """
    A `Highway layer <https://arxiv.org/abs/1505.00387>`_.
    Adopted from the AllenNLP implementation.
    """

    def __init__(self, input_dim: int, num_layers: int = 1):
        super(Highway, self).__init__()
        self.input_dim = input_dim
        # Each layer emits 2*input_dim units: first half is the candidate
        # activation, second half the carry gate pre-activation.
        self.layers = nn.ModuleList(
            nn.Linear(input_dim, input_dim * 2) for _ in range(num_layers)
        )
        self.activation = nn.ReLU()

        self.reset_parameters()

    def reset_parameters(self):
        for layer in self.layers:
            # Bias the layer to carry its input forward (AllenNLP trick):
            # a positive bias on the gate half pushes the sigmoid gate high,
            # so at initialization the layer is close to the identity.
            nn.init.constant_(layer.bias[self.input_dim :], 1)
            nn.init.constant_(layer.bias[: self.input_dim], 0)
            nn.init.xavier_normal_(layer.weight)

    def forward(self, x: torch.Tensor):
        for layer in self.layers:
            proj_x, gate = layer(x).chunk(2, dim=-1)
            proj_x = self.activation(proj_x)
            gate = torch.sigmoid(gate)
            x = gate * x + (gate.new_tensor([1]) - gate) * proj_x
        return x
fairseq-0.10.2/fairseq/modules/cross_entropy.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import logging
7
+
8
+ import torch
9
+ import torch.nn.functional as F
10
+
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
+ def _cross_entropy_pytorch(logits, target, ignore_index=None, reduction="mean"):
16
+ lprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
17
+ return F.nll_loss(
18
+ lprobs,
19
+ target,
20
+ ignore_index=ignore_index,
21
+ reduction=reduction,
22
+ )
23
+
24
+
25
try:
    import xentropy_cuda
    from apex.contrib import xentropy

    logger.info("using fused cross entropy")

    def cross_entropy(logits, target, ignore_index=-100, reduction="mean"):
        """Fused cross entropy (apex); falls back to PyTorch on CPU tensors."""
        if logits.device == torch.device("cpu"):
            return _cross_entropy_pytorch(logits, target, ignore_index, reduction)
        half_to_float = logits.dtype == torch.half
        losses = xentropy.SoftmaxCrossEntropyLoss.apply(
            logits,
            target,
            0.0,
            ignore_index,
            half_to_float,
        )
        if reduction == "sum":
            return losses.sum()
        if reduction == "mean":
            if ignore_index >= 0:
                # Average only over non-ignored targets.
                return losses.sum() / target.ne(ignore_index).sum()
            return losses.mean()
        if reduction == "none":
            return losses
        raise NotImplementedError


except ImportError:

    def cross_entropy(logits, target, ignore_index=-100, reduction="mean"):
        """Pure-PyTorch cross entropy (apex fused kernel unavailable)."""
        return _cross_entropy_pytorch(logits, target, ignore_index, reduction)
fairseq-0.10.2/fairseq/modules/fp32_group_norm.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ """
6
+ Layer norm done in fp32 (for fp16 training)
7
+ """
8
+
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+
12
+
13
class Fp32GroupNorm(nn.GroupNorm):
    """GroupNorm computed in fp32 regardless of the input dtype.

    Casting up avoids precision issues during fp16 training; the result is
    cast back to the input dtype.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, input):
        weight = None if self.weight is None else self.weight.float()
        bias = None if self.bias is None else self.bias.float()
        normalized = F.group_norm(
            input.float(),
            self.num_groups,
            weight,
            bias,
            self.eps,
        )
        return normalized.type_as(input)
fairseq-0.10.2/fairseq/modules/lightconv_layer/lightconv_cuda_kernel.cu ADDED
@@ -0,0 +1,375 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #include "lightconv_cuda.cuh"
9
+ #include "lightconv_cuda_forward.cu"
10
+ #include "lightconv_cuda_backward.cu"
11
+ #include "../cuda_utils.cu"
12
+
13
// Forward pass of the lightweight convolution.
// Each block handles one (batch, feature) pair; threads cover SB positions
// of the sequence per iteration. The filter (length FS) is shared by all
// features in the same filter block.
// NOTE(review): zeroSharedMem/load_input_to_shared/divUp come from the
// included headers; their exact padding semantics are assumed, not shown here.
template<int FS, int SB, int padding_l, typename scalar_t>
__global__
void lightconv_forward_kernel(const scalar_t* input,
                              const scalar_t* filters,
                              int minibatch, int sequenceLength,
                              int numFeatures, int numFiltersInBlock,
                              scalar_t* output) {

  const int tid = threadIdx.x;
  const int batchIdx = blockIdx.x;
  const int featureIdx = blockIdx.y;
  // Features are grouped: numFiltersInBlock consecutive features share a filter.
  const int filterIdx = featureIdx / numFiltersInBlock;

  // Input and output are laid out as (minibatch, numFeatures, sequenceLength).
  const int IOOffset = numFeatures * sequenceLength * batchIdx + featureIdx * sequenceLength;
  const scalar_t* inputFeature = &input[IOOffset];
  scalar_t* outputFeature = &output[IOOffset];
  const scalar_t* inputFilter = &filters[filterIdx * FS];

  assert(blockDim.x == SB);

  // Copy the filter into registers; FS is a compile-time constant so the
  // loop fully unrolls.
  scalar_t filter[FS];
  #pragma unroll
  for (int i = 0; i < FS; ++i) {
    filter[i] = inputFilter[i];
  }

  // Shared tile holds SB outputs' worth of input plus FS halo elements.
  __shared__ scalar_t temp[SB + FS];
  zeroSharedMem<FS, SB, padding_l>(temp);

  const int numIterations = divUp<int, int>(sequenceLength, SB);

  for (int i = 0; i < numIterations; ++i) {
    // Read input into shared memory
    const int inputOffset = i * SB;

    load_input_to_shared<FS, SB, padding_l>(inputFeature, inputOffset, sequenceLength,
                                            i, numIterations, (numIterations == 1), temp);

    __syncthreads();

    // Dot product of the filter with the thread's window of the tile.
    scalar_t out = 0;
    #pragma unroll
    for (int j = 0; j < FS; ++j) {
      out += filter[j] * temp[tid + j];
    }

    // Write output
    const int outputOffset = inputOffset;
    if ((outputOffset + tid) < sequenceLength) {
      outputFeature[outputOffset + tid] = out;
    }

    // Second barrier prevents the next iteration's tile load from racing
    // with this iteration's reads of `temp`.
    __syncthreads();
  }
}
68
+
69
// Gradient w.r.t. the input. Structurally identical to the forward kernel,
// except the filter is applied reversed and the padding is mirrored
// (correlation with a flipped filter == transposed convolution).
template<int FS, int SB, int padding_l, typename scalar_t>
__global__
void lightconv_grad_wrt_input_kernel(
    const scalar_t* input,
    const scalar_t* filters,
    int minibatch,
    int sequenceLength,
    int numFeatures,
    int numFiltersInBlock,
    scalar_t* output) {

  // input grad kernel is similar to forward kernel
  const int tid = threadIdx.x;
  const int batchIdx = blockIdx.x;
  const int featureIdx = blockIdx.y;
  const int filterIdx = featureIdx / numFiltersInBlock;

  const int IOOffset = numFeatures * sequenceLength * batchIdx + featureIdx * sequenceLength;
  const scalar_t* inputFeature = &input[IOOffset];
  scalar_t* outputFeature = &output[IOOffset];
  const scalar_t* inputFilter = &filters[filterIdx * FS];

  assert(blockDim.x == SB);

  scalar_t filter[FS];

  // The only change is loading the filter in reverse
  #pragma unroll
  for (int i = 0; i < FS; ++i) {
    filter[i] = inputFilter[FS - i - 1];
  }

  __shared__ scalar_t temp[SB + FS];
  // Mirrored padding: left padding of the backward pass is the right
  // padding of the forward pass.
  const int padding = FS - padding_l - 1;
  zeroSharedMem<FS, SB, padding>(temp);

  __syncthreads();

  const int numIterations = divUp<int, int>(sequenceLength, SB);

  for (int i = 0; i < numIterations; ++i) {
    // Read input into shared memory
    const int inputOffset = i * SB;

    load_input_to_shared<FS, SB, padding>(inputFeature, inputOffset, sequenceLength,
                                          i, numIterations, false, temp);

    __syncthreads();

    scalar_t out = 0;
    #pragma unroll
    for (int j = 0; j < FS; ++j) {
      out += filter[j] * temp[tid + j];
    }

    // Write output
    const int outputOffset = inputOffset;
    if ((outputOffset + tid) < sequenceLength) {
      outputFeature[outputOffset + tid] = out;
    }

    __syncthreads();
  }
}
133
+
134
// This is by far the most expensive kernel in terms of time taken.
// Can be 16x slower than the forward or grad_wrt_input when filter size is 31
//
// First pass of the filter-weight gradient for SHORT sequences
// (sequenceLength <= SB): each block handles one (batch, filter) pair and
// accumulates, per thread, the outer product of input windows with the
// output gradient; a block reduction then produces one partial sum per
// filter tap per batch element. The second-pass kernel sums over batches.
template<int FS, int SB, int padding_l, typename scalar_t>
__global__
void lightconv_grad_wrt_weights_firstpass_short_kernel(
    const scalar_t* input,
    const scalar_t* gradInput,
    int minibatch,
    int sequenceLength,
    int numFeatures,
    int numFiltersInBlock,
    int numHeads,
    float* output) {

  const int tid = threadIdx.x;
  const int batchIdx = blockIdx.x;
  const int filterIdx = blockIdx.y;

  const int numIterations = divUp<int, int>(sequenceLength, SB);

  // Partial results are staged in fp32 regardless of scalar_t for accuracy.
  float* tempOutputGradWeight = &output[filterIdx * FS * minibatch];

  assert(blockDim.x == SB);

  __shared__ scalar_t tempInput[SB + FS];
  __shared__ scalar_t tempGradInput[SB + FS];

  // local weight accumulation
  float accumWeights[FS];

  // Initialize memory
  for (int i = 0; i < FS; ++i) {
    accumWeights[i] = float(0.0);
  }


  // loop over each sequence within filterblock
  for (int idxInFilterBlock = 0; idxInFilterBlock < numFiltersInBlock; ++idxInFilterBlock) {

    const int featureOffset = batchIdx * numFeatures * sequenceLength + (filterIdx * numFiltersInBlock + idxInFilterBlock) * sequenceLength;
    const scalar_t* inputFeature = &input[featureOffset];
    const scalar_t* gradInputFeature = &gradInput[featureOffset];

    // The gradient tile uses centered (FS/2) padding while the input tile
    // keeps the forward pass's padding_l.
    zeroSharedMem<FS, SB, padding_l>(tempInput);
    zeroSharedMem<FS, SB, (FS/2)>(tempGradInput);
    __syncthreads();

    for (int i = 0; i < numIterations; ++i) {

      const int inputOffset = i * SB;

      load_input_to_shared<FS, SB, padding_l>(inputFeature, inputOffset, sequenceLength,
                                              i, numIterations, false, tempInput);
      load_input_to_shared<FS, SB, (FS/2)>(gradInputFeature, inputOffset, sequenceLength,
                                           i, numIterations, false, tempGradInput);

      __syncthreads();

      const int gradIndex = (FS/2) + tid;
      scalar_t tempGrad = tempGradInput[gradIndex];

      // dL/dw[j] accumulates input[t + j] * gradOutput[t] over positions t.
      #pragma unroll
      for (int j = 0; j < FS; j++) {
        const int inputIndex = tid + j;
        accumWeights[j] += tempInput[inputIndex] * tempGrad;
      }

      __syncthreads();

    }

  }

  // Row-major sum
  for (int filterWeightIdx = 0; filterWeightIdx < FS; ++filterWeightIdx) {

    float temp;
    // Threads beyond the sequence length contribute zero — relies on the
    // "short" precondition sequenceLength <= SB.
    if (tid < sequenceLength) {
      temp = accumWeights[filterWeightIdx];
    } else {
      temp = float(0.0);
    }

    const int outputOffset = filterWeightIdx * minibatch + batchIdx;

    temp = blockReduce(temp);

    if (tid == 0) {
      tempOutputGradWeight[outputOffset] = temp;
    }
  }
}
226
+
227
// Second pass of the short-sequence weight gradient: reduces the per-batch
// partials produced by the first pass over the minibatch dimension and
// writes the final gradient for one (filter, tap) pair per block.
template<int FS, int SB, typename scalar_t>
__global__
void lightconv_grad_wrt_weights_secondpass_short_kernel(
    const float* input,
    const int minibatch,
    const int numFiltersInBlock,
    scalar_t* output) {

  assert(blockDim.x == SB);

  const int tid = threadIdx.x;

  const int filterIdx = blockIdx.x;
  const int filterWeightIdx = blockIdx.y;

  const int inputOffset = filterIdx * FS * minibatch +
                          filterWeightIdx * minibatch;
  const float* tempInput = &input[inputOffset];

  // read into shared memory for reduction
  int readIndex = tid;

  // Grid-stride-style loop over the minibatch with stride SB so any
  // minibatch size is covered by SB threads.
  float sum = 0.0;
  while (readIndex < minibatch) {
    sum += tempInput[readIndex];
    readIndex += SB;
  }

  float temp = blockReduce(sum);

  if (tid == 0) {
    output[blockIdx.x * FS + blockIdx.y] = temp;
  }
}
261
+
262
// This is by far the most expensive kernel in terms of time taken.
// Can be 16x slower than the forward or grad_wrt_input when filter size is 31
//
// General (long-sequence) first pass of the filter-weight gradient: one
// block per (batch, feature). Unlike the "short" variant, partials are kept
// separate per feature within a filter block and reduced in the matching
// second-pass kernel.
template<int FS, int SB, int padding_l, typename scalar_t>
__global__
void lightconv_grad_wrt_weights_firstpass_kernel(
    const scalar_t* input,
    const scalar_t* gradInput,
    int minibatch,
    int sequenceLength,
    int numFeatures,
    int numFiltersInBlock,
    float* output) {

  assert(blockDim.x == SB);

  const int tid = threadIdx.x;
  const int batchIdx = blockIdx.x;
  const int featureIdx = blockIdx.y;
  const int filterIdx = featureIdx / numFiltersInBlock;
  const int idxInFilterBlock = featureIdx % numFiltersInBlock;

  const int numIterations = divUp<int, int>(sequenceLength, SB);

  float temp;

  __shared__ scalar_t tempInput[SB + FS];
  __shared__ scalar_t tempGradInput[SB + FS];
  zeroSharedMem<FS, SB, padding_l>(tempInput);
  zeroSharedMem<FS, SB, (FS/2)>(tempGradInput);
  __syncthreads();

  // Per-thread fp32 accumulator, one slot per filter tap.
  float accumWeights[FS];

  for (int i = 0; i < FS; ++i) {
    accumWeights[i] = float(0.0);
  }

  const int IOOffset = batchIdx * numFeatures * sequenceLength + featureIdx * sequenceLength;
  const scalar_t* inputFeature = &input[IOOffset];
  const scalar_t* gradInputFeature = &gradInput[IOOffset];
  float* tempOutputGradWeight = &output[filterIdx * FS * minibatch * numFiltersInBlock];

  for (int i = 0; i < numIterations; ++i) {
    const int inputOffset = i * SB;

    load_input_to_shared<FS, SB, padding_l>(inputFeature, inputOffset, sequenceLength,
                                            i, numIterations, false, tempInput);
    load_input_to_shared<FS, SB, (FS/2)>(gradInputFeature, inputOffset, sequenceLength,
                                         i, numIterations, false, tempGradInput);
    __syncthreads();

    // dL/dw[j] += input[t + j] * gradOutput[t] for this thread's position t.
    #pragma unroll
    for (int j = 0; j < FS; ++j) {
      accumWeights[j] += tempInput[tid + j] * tempGradInput[tid + (FS/2)];
    }

    __syncthreads();
  }

  // Row-major sum
  for (int filterWeightIdx = 0; filterWeightIdx < FS; ++filterWeightIdx) {

    // Write to shared memory before reduction
    // NOTE(review): tid < sequenceLength only masks the final, partially
    // filled tile correctly when out-of-range tile slots are zero — assumed
    // guaranteed by zeroSharedMem/load_input_to_shared; confirm in headers.
    if (tid < sequenceLength) {
      temp = accumWeights[filterWeightIdx];
    } else {
      temp = float(0.0);
    }

    temp = blockReduce(temp);

    // Layout: (tap, batch, idxInFilterBlock) for the second-pass reduction.
    const int outputOffset = filterWeightIdx * minibatch * numFiltersInBlock +
                             batchIdx * numFiltersInBlock +
                             idxInFilterBlock;

    if (tid == 0) {
      tempOutputGradWeight[outputOffset] = temp;
    }
  }
}
342
+
343
// Second pass of the general weight gradient: sums the first pass's
// partials over (minibatch x numFiltersInBlock) and emits the final
// gradient for one (filter, tap) pair per block.
template<int FS, int SB, typename scalar_t>
__global__
void lightconv_grad_wrt_weights_secondpass_kernel(
    const float* input,
    const int minibatch,
    const int numFiltersInBlock,
    scalar_t* output) {

  assert(blockDim.x == SB);
  const int tid = threadIdx.x;

  // What is the id within a minibatch
  const int filterIdx = blockIdx.x;
  const int filterWeightIdx = blockIdx.y;

  const int inputOffset = filterIdx * FS * minibatch * numFiltersInBlock +
                          filterWeightIdx * minibatch * numFiltersInBlock;
  const float* tempInput = &input[inputOffset];

  int readIndex = tid;

  // Strided accumulation so SB threads cover all partials.
  float sum = float(0.0);
  while (readIndex < (minibatch * numFiltersInBlock)) {
    sum += tempInput[readIndex];
    readIndex += SB;
  }

  float temp = blockReduce(sum);

  if (tid == 0) {
    output[blockIdx.x * FS + blockIdx.y] = temp;
  }
}
fairseq-0.10.2/fairseq/modules/linearized_convolution.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from fairseq import utils
9
+ from fairseq.incremental_decoding_utils import with_incremental_state
10
+
11
+ from .conv_tbc import ConvTBC
12
+
13
+
14
+ @with_incremental_state
15
+ class LinearizedConvolution(ConvTBC):
16
+ """An optimized version of nn.Conv1d.
17
+
18
+ At training time, this module uses ConvTBC, which is an optimized version
19
+ of Conv1d. At inference time, it optimizes incremental generation (i.e.,
20
+ one time step at a time) by replacing the convolutions with linear layers.
21
+ Note that the input order changes from training to inference.
22
+ """
23
+
24
+ def __init__(self, in_channels, out_channels, kernel_size, **kwargs):
25
+ super().__init__(in_channels, out_channels, kernel_size, **kwargs)
26
+ self._linearized_weight = None
27
+ self.register_backward_hook(self._clear_linearized_weight)
28
+
29
+ def state_dict(self, destination=None, prefix="", keep_vars=False):
30
+ state = ConvTBC.state_dict(self, destination, prefix, keep_vars=keep_vars)
31
+ # don't store redundant _linearized_weight in checkpoints
32
+ if prefix + "_linearized_weight" in state:
33
+ del state[prefix + "_linearized_weight"]
34
+ return state
35
+
36
+ def upgrade_state_dict_named(self, state_dict, name):
37
+ prefix = name + "." if name != "" else ""
38
+ if prefix + "_linearized_weight" in state_dict:
39
+ del state_dict[prefix + "_linearized_weight"]
40
+
41
+ def forward(self, input, incremental_state=None):
42
+ """
43
+ Args:
44
+ incremental_state: Used to buffer signal; if not None, then input is
45
+ expected to contain a single frame. If the input order changes
46
+ between time steps, call reorder_incremental_state.
47
+ Input:
48
+ Time x Batch x Channel during training
49
+ Batch x Time x Channel during inference
50
+ """
51
+ if incremental_state is None:
52
+ output = super().forward(input)
53
+ if self.kernel_size[0] > 1 and self.padding[0] > 0:
54
+ # remove future timesteps added by padding
55
+ output = output[: -self.padding[0], :, :]
56
+ return output
57
+
58
+ # reshape weight
59
+ weight = self._get_linearized_weight()
60
+ kw = self.kernel_size[0]
61
+
62
+ bsz = input.size(0) # input: bsz x len x dim
63
+ if kw > 1:
64
+ input = input.data
65
+ input_buffer = self._get_input_buffer(incremental_state)
66
+ if input_buffer is None:
67
+ input_buffer = input.new(bsz, kw, input.size(2)).zero_()
68
+ self._set_input_buffer(incremental_state, input_buffer)
69
+ else:
70
+ # shift buffer
71
+ input_buffer[:, :-1, :] = input_buffer[:, 1:, :].clone()
72
+ # append next input
73
+ input_buffer[:, -1, :] = input[:, -1, :]
74
+ input = input_buffer
75
+ with torch.no_grad():
76
+ output = F.linear(input.view(bsz, -1), weight, self.bias)
77
+ return output.view(bsz, 1, -1)
78
+
79
+ def reorder_incremental_state(self, incremental_state, new_order):
80
+ input_buffer = self._get_input_buffer(incremental_state)
81
+ if input_buffer is not None:
82
+ input_buffer = input_buffer.index_select(0, new_order)
83
+ self._set_input_buffer(incremental_state, input_buffer)
84
+
85
+ def _get_input_buffer(self, incremental_state):
86
+ return utils.get_incremental_state(self, incremental_state, "input_buffer")
87
+
88
+ def _set_input_buffer(self, incremental_state, new_buffer):
89
+ return utils.set_incremental_state(
90
+ self, incremental_state, "input_buffer", new_buffer
91
+ )
92
+
93
+ def _get_linearized_weight(self):
94
+ if self._linearized_weight is None:
95
+ kw = self.kernel_size[0]
96
+ weight = self.weight.transpose(2, 1).transpose(1, 0).contiguous()
97
+ assert weight.size() == (self.out_channels, kw, self.in_channels)
98
+ self._linearized_weight = torch.nn.Parameter(
99
+ weight.view(self.out_channels, -1)
100
+ )
101
+ return self._linearized_weight
102
+
103
+ def _clear_linearized_weight(self, *args):
104
+ self._linearized_weight = None
fairseq-0.10.2/fairseq/modules/multihead_attention.py ADDED
@@ -0,0 +1,488 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+ from typing import Dict, Optional, Tuple
8
+
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from fairseq import utils
12
+ from fairseq.incremental_decoding_utils import with_incremental_state
13
+ from fairseq.modules.fairseq_dropout import FairseqDropout
14
+ from fairseq.modules.quant_noise import quant_noise
15
+ from torch import Tensor, nn
16
+ from torch.nn import Parameter
17
+
18
+
19
@with_incremental_state
class MultiheadAttention(nn.Module):
    """Multi-headed attention.

    See "Attention Is All You Need" for more details.

    Supports self-attention and encoder-decoder attention, incremental
    decoding (cached keys/values), optional key/value bias vectors, and an
    extra all-zero attention slot.
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        kdim=None,
        vdim=None,
        dropout=0.0,
        bias=True,
        add_bias_kv=False,
        add_zero_attn=False,
        self_attention=False,
        encoder_decoder_attention=False,
        q_noise=0.0,
        qn_block_size=8,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        # Key/value dims may differ from embed_dim (cross-attention).
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout_module = FairseqDropout(
            dropout, module_name=self.__class__.__name__
        )

        self.head_dim = embed_dim // num_heads
        assert (
            self.head_dim * num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"
        # 1/sqrt(d_k) scaling applied to queries before the dot product.
        self.scaling = self.head_dim ** -0.5

        self.self_attention = self_attention
        self.encoder_decoder_attention = encoder_decoder_attention

        assert not self.self_attention or self.qkv_same_dim, (
            "Self-attention requires query, key and " "value to be of the same size"
        )

        # quant_noise is a no-op wrapper when q_noise == 0.
        self.k_proj = quant_noise(
            nn.Linear(self.kdim, embed_dim, bias=bias), q_noise, qn_block_size
        )
        self.v_proj = quant_noise(
            nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size
        )
        self.q_proj = quant_noise(
            nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
        )

        self.out_proj = quant_noise(
            nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
        )

        if add_bias_kv:
            self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
            self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
        else:
            self.bias_k = self.bias_v = None

        self.add_zero_attn = add_zero_attn

        self.reset_parameters()

        self.onnx_trace = False
        self.tpu = False

    def prepare_for_onnx_export_(self):
        self.onnx_trace = True

    def prepare_for_tpu_(self, **kwargs):
        self.tpu = True

    def reset_parameters(self):
        if self.qkv_same_dim:
            # Empirically observed the convergence to be much better with
            # the scaled initialization
            nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
            nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
            nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
        else:
            nn.init.xavier_uniform_(self.k_proj.weight)
            nn.init.xavier_uniform_(self.v_proj.weight)
            nn.init.xavier_uniform_(self.q_proj.weight)

        nn.init.xavier_uniform_(self.out_proj.weight)
        if self.out_proj.bias is not None:
            nn.init.constant_(self.out_proj.bias, 0.0)
        if self.bias_k is not None:
            nn.init.xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            nn.init.xavier_normal_(self.bias_v)

    def forward(
        self,
        query,
        key: Optional[Tensor],
        value: Optional[Tensor],
        key_padding_mask: Optional[Tensor] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        need_weights: bool = True,
        static_kv: bool = False,
        attn_mask: Optional[Tensor] = None,
        before_softmax: bool = False,
        need_head_weights: bool = False,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        """Input shape: Time x Batch x Channel

        Args:
            key_padding_mask (ByteTensor, optional): mask to exclude
                keys that are pads, of shape `(batch, src_len)`, where
                padding elements are indicated by 1s.
            need_weights (bool, optional): return the attention weights,
                averaged over heads (default: False).
            attn_mask (ByteTensor, optional): typically used to
                implement causal attention, where the mask prevents the
                attention from looking forward in time (default: None).
            before_softmax (bool, optional): return the raw attention
                weights and values before the attention softmax.
            need_head_weights (bool, optional): return the attention
                weights for each head. Implies *need_weights*. Default:
                return the average attention weights over all heads.
        """
        if need_head_weights:
            need_weights = True

        tgt_len, bsz, embed_dim = query.size()
        assert embed_dim == self.embed_dim
        assert list(query.size()) == [tgt_len, bsz, embed_dim]

        # Fast path: defer to PyTorch's fused implementation when no
        # incremental decoding or special tracing is involved.
        if (
            not self.onnx_trace
            and not self.tpu  # don't use PyTorch version on TPUs
            and incremental_state is None
            and not static_kv
            # A workaround for quantization to work. Otherwise JIT compilation
            # treats bias in linear module as method.
            and not torch.jit.is_scripting()
        ):
            assert key is not None and value is not None
            return F.multi_head_attention_forward(
                query,
                key,
                value,
                self.embed_dim,
                self.num_heads,
                torch.empty([0]),
                torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
                self.bias_k,
                self.bias_v,
                self.add_zero_attn,
                self.dropout_module.p,
                self.out_proj.weight,
                self.out_proj.bias,
                self.training or self.dropout_module.apply_during_inference,
                key_padding_mask,
                need_weights,
                attn_mask,
                use_separate_proj_weight=True,
                q_proj_weight=self.q_proj.weight,
                k_proj_weight=self.k_proj.weight,
                v_proj_weight=self.v_proj.weight,
            )

        if incremental_state is not None:
            saved_state = self._get_input_buffer(incremental_state)
            if saved_state is not None and "prev_key" in saved_state:
                # previous time steps are cached - no need to recompute
                # key and value if they are static
                if static_kv:
                    assert self.encoder_decoder_attention and not self.self_attention
                    key = value = None
        else:
            saved_state = None

        if self.self_attention:
            q = self.q_proj(query)
            k = self.k_proj(query)
            v = self.v_proj(query)
        elif self.encoder_decoder_attention:
            # encoder-decoder attention
            q = self.q_proj(query)
            if key is None:
                # key is None only when reusing cached static K/V above.
                assert value is None
                k = v = None
            else:
                k = self.k_proj(key)
                v = self.v_proj(key)

        else:
            assert key is not None and value is not None
            q = self.q_proj(query)
            k = self.k_proj(key)
            v = self.v_proj(value)
        q *= self.scaling

        if self.bias_k is not None:
            assert self.bias_v is not None
            # Append the learned bias K/V as one extra source position and
            # extend the masks accordingly.
            k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
            v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
            if attn_mask is not None:
                attn_mask = torch.cat(
                    [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
                )
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [
                        key_padding_mask,
                        key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
                    ],
                    dim=1,
                )

        # Fold heads into the batch dim: (bsz*num_heads, len, head_dim).
        q = (
            q.contiguous()
            .view(tgt_len, bsz * self.num_heads, self.head_dim)
            .transpose(0, 1)
        )
        if k is not None:
            k = (
                k.contiguous()
                .view(-1, bsz * self.num_heads, self.head_dim)
                .transpose(0, 1)
            )
        if v is not None:
            v = (
                v.contiguous()
                .view(-1, bsz * self.num_heads, self.head_dim)
                .transpose(0, 1)
            )

        if saved_state is not None:
            # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
            if "prev_key" in saved_state:
                _prev_key = saved_state["prev_key"]
                assert _prev_key is not None
                prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
                if static_kv:
                    k = prev_key
                else:
                    assert k is not None
                    k = torch.cat([prev_key, k], dim=1)
            if "prev_value" in saved_state:
                _prev_value = saved_state["prev_value"]
                assert _prev_value is not None
                prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
                if static_kv:
                    v = prev_value
                else:
                    assert v is not None
                    v = torch.cat([prev_value, v], dim=1)
            prev_key_padding_mask: Optional[Tensor] = None
            if "prev_key_padding_mask" in saved_state:
                prev_key_padding_mask = saved_state["prev_key_padding_mask"]
            assert k is not None and v is not None
            key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
                key_padding_mask=key_padding_mask,
                prev_key_padding_mask=prev_key_padding_mask,
                batch_size=bsz,
                src_len=k.size(1),
                static_kv=static_kv,
            )

            saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
            saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
            saved_state["prev_key_padding_mask"] = key_padding_mask
            # In this branch incremental_state is never None
            assert incremental_state is not None
            incremental_state = self._set_input_buffer(incremental_state, saved_state)
        assert k is not None
        src_len = k.size(1)

        # This is part of a workaround to get around fork/join parallelism
        # not supporting Optional types.
        if key_padding_mask is not None and key_padding_mask.dim() == 0:
            key_padding_mask = None

        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bsz
            assert key_padding_mask.size(1) == src_len

        if self.add_zero_attn:
            assert v is not None
            # Extra all-zero source position that attention can "fall back" to.
            src_len += 1
            k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
            v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
            if attn_mask is not None:
                attn_mask = torch.cat(
                    [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
                )
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [
                        key_padding_mask,
                        torch.zeros(key_padding_mask.size(0), 1).type_as(
                            key_padding_mask
                        ),
                    ],
                    dim=1,
                )

        attn_weights = torch.bmm(q, k.transpose(1, 2))
        attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)

        assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

        if attn_mask is not None:
            attn_mask = attn_mask.unsqueeze(0)
            if self.onnx_trace:
                attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1)
            attn_weights += attn_mask

        if key_padding_mask is not None:
            # don't attend to padding symbols
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            if not self.tpu:
                attn_weights = attn_weights.masked_fill(
                    key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
                    float("-inf"),
                )
            else:
                # masked_fill with a broadcasted mask is slow on TPU;
                # transpose so the mask applies along dim 0 instead.
                attn_weights = attn_weights.transpose(0, 2)
                attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
                attn_weights = attn_weights.transpose(0, 2)
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if before_softmax:
            return attn_weights, v

        # Softmax in fp32 for numerical stability, then cast back.
        attn_weights_float = utils.softmax(
            attn_weights, dim=-1, onnx_trace=self.onnx_trace
        )
        attn_weights = attn_weights_float.type_as(attn_weights)
        attn_probs = self.dropout_module(attn_weights)

        assert v is not None
        attn = torch.bmm(attn_probs, v)
        assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
        if self.onnx_trace and attn.size(1) == 1:
            # when ONNX tracing a single decoder step (sequence length == 1)
            # the transpose is a no-op copy before view, thus unnecessary
            attn = attn.contiguous().view(tgt_len, bsz, embed_dim)
        else:
            attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
        attn = self.out_proj(attn)
        attn_weights: Optional[Tensor] = None
        if need_weights:
            attn_weights = attn_weights_float.view(
                bsz, self.num_heads, tgt_len, src_len
            ).transpose(1, 0)
            if not need_head_weights:
                # average attention weights over heads
                attn_weights = attn_weights.mean(dim=0)

        return attn, attn_weights

    @staticmethod
    def _append_prev_key_padding_mask(
        key_padding_mask: Optional[Tensor],
        prev_key_padding_mask: Optional[Tensor],
        batch_size: int,
        src_len: int,
        static_kv: bool,
    ) -> Optional[Tensor]:
        # saved key padding masks have shape (bsz, seq_len)
        if prev_key_padding_mask is not None and static_kv:
            new_key_padding_mask = prev_key_padding_mask
        elif prev_key_padding_mask is not None and key_padding_mask is not None:
            new_key_padding_mask = torch.cat(
                [prev_key_padding_mask.float(), key_padding_mask.float()], dim=1
            )
        # During incremental decoding, as the padding token enters and
        # leaves the frame, there will be a time when prev or current
        # is None
        elif prev_key_padding_mask is not None:
            filler = torch.zeros(
                (batch_size, src_len - prev_key_padding_mask.size(1)),
                device=prev_key_padding_mask.device,
            )
            new_key_padding_mask = torch.cat(
                [prev_key_padding_mask.float(), filler.float()], dim=1
            )
        elif key_padding_mask is not None:
            filler = torch.zeros(
                (batch_size, src_len - key_padding_mask.size(1)),
                device=key_padding_mask.device,
            )
            new_key_padding_mask = torch.cat(
                [filler.float(), key_padding_mask.float()], dim=1
            )
        else:
            new_key_padding_mask = prev_key_padding_mask
        return new_key_padding_mask

    @torch.jit.export
    def reorder_incremental_state(
        self,
        incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
        new_order: Tensor,
    ):
        """Reorder buffered internal state (for incremental generation)."""
        input_buffer = self._get_input_buffer(incremental_state)
        if input_buffer is not None:
            for k in input_buffer.keys():
                input_buffer_k = input_buffer[k]
                if input_buffer_k is not None:
                    # Static encoder-side buffers are already in the right
                    # order when sizes match; nothing to reorder.
                    if self.encoder_decoder_attention and input_buffer_k.size(
                        0
                    ) == new_order.size(0):
                        break
                    input_buffer[k] = input_buffer_k.index_select(0, new_order)
            incremental_state = self._set_input_buffer(incremental_state, input_buffer)
        return incremental_state

    def _get_input_buffer(
        self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
    ) -> Dict[str, Optional[Tensor]]:
        result = self.get_incremental_state(incremental_state, "attn_state")
        if result is not None:
            return result
        else:
            empty_result: Dict[str, Optional[Tensor]] = {}
            return empty_result

    def _set_input_buffer(
        self,
        incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
        buffer: Dict[str, Optional[Tensor]],
    ):
        return self.set_incremental_state(incremental_state, "attn_state", buffer)

    def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int):
        # Hook for subclasses (e.g. sparse attention); identity here.
        return attn_weights

    def upgrade_state_dict_named(self, state_dict, name):
        # Split legacy fused in_proj_{weight,bias} checkpoints into separate
        # q/k/v projections.
        prefix = name + "." if name != "" else ""
        items_to_add = {}
        keys_to_remove = []
        for k in state_dict.keys():
            if k.endswith(prefix + "in_proj_weight"):
                # in_proj_weight used to be q + k + v with same dimensions
                dim = int(state_dict[k].shape[0] / 3)
                items_to_add[prefix + "q_proj.weight"] = state_dict[k][:dim]
                items_to_add[prefix + "k_proj.weight"] = state_dict[k][dim : 2 * dim]
                items_to_add[prefix + "v_proj.weight"] = state_dict[k][2 * dim :]

                keys_to_remove.append(k)

                k_bias = prefix + "in_proj_bias"
                if k_bias in state_dict.keys():
                    # `k` is still the in_proj_weight key; its first dim
                    # gives the same per-projection size as the bias.
                    dim = int(state_dict[k].shape[0] / 3)
                    items_to_add[prefix + "q_proj.bias"] = state_dict[k_bias][:dim]
                    items_to_add[prefix + "k_proj.bias"] = state_dict[k_bias][
                        dim : 2 * dim
                    ]
                    items_to_add[prefix + "v_proj.bias"] = state_dict[k_bias][2 * dim :]

                    keys_to_remove.append(prefix + "in_proj_bias")

        for k in keys_to_remove:
            del state_dict[k]

        for key, value in items_to_add.items():
            state_dict[key] = value
fairseq-0.10.2/fairseq/modules/positional_embedding.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch.nn as nn
7
+
8
+ from .learned_positional_embedding import LearnedPositionalEmbedding
9
+ from .sinusoidal_positional_embedding import SinusoidalPositionalEmbedding
10
+
11
+
12
def PositionalEmbedding(
    num_embeddings: int,
    embedding_dim: int,
    padding_idx: int,
    learned: bool = False,
):
    """Build a positional embedding module.

    Args:
        num_embeddings: maximum number of positions to support.
        embedding_dim: dimensionality of each position vector.
        padding_idx: padding token index, or ``None`` when positions are not
            offset by padding.
        learned: if True, use a trainable ``LearnedPositionalEmbedding``;
            otherwise use fixed ``SinusoidalPositionalEmbedding``.
    """
    if learned:
        # if padding_idx is specified then offset the embedding ids by
        # this index and adjust num_embeddings appropriately
        # TODO: The right place for this offset would be inside
        # LearnedPositionalEmbedding. Move this there for a cleaner implementation.
        if padding_idx is not None:
            num_embeddings = num_embeddings + padding_idx + 1
        m = LearnedPositionalEmbedding(num_embeddings, embedding_dim, padding_idx)
        nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
        if padding_idx is not None:
            nn.init.constant_(m.weight[padding_idx], 0)
    else:
        m = SinusoidalPositionalEmbedding(
            embedding_dim,
            padding_idx,
            # mirror the None-handling of the learned branch; the previous
            # unconditional `padding_idx + 1` raised TypeError for None
            init_size=num_embeddings + padding_idx + 1
            if padding_idx is not None
            else num_embeddings,
        )
    return m
fairseq-0.10.2/fairseq/modules/quant_noise.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+
10
def quant_noise(module, p, block_size):
    """
    Wraps modules and applies quantization noise to the weights for
    subsequent quantization with Iterative Product Quantization as
    described in "Training with Quantization Noise for Extreme Model Compression"

    Args:
        - module: nn.Module (one of nn.Linear, nn.Embedding, nn.Conv2d)
        - p: amount of Quantization Noise (probability of dropping a block)
        - block_size: size of the blocks for subsequent quantization with iPQ

    Returns:
        the same module, with a forward pre-hook registered that injects
        the noise during training.

    Remarks:
        - Module weights must have the right sizes wrt the block size
        - Only Linear, Embedding and Conv2d modules are supported for the moment
        - For more detail on how to quantize by blocks with convolutional weights,
          see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks"
        - We implement the simplest form of noise here as stated in the paper
          which consists in randomly dropping blocks
    """

    # if no quantization noise, don't register hook
    if p <= 0:
        return module

    # supported modules
    assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d))

    # test whether module.weight has the right sizes wrt block_size
    is_conv = module.weight.ndim == 4

    # 2D matrix
    if not is_conv:
        assert (
            module.weight.size(1) % block_size == 0
        ), "Input features must be a multiple of block sizes"

    # 4D matrix
    else:
        # 1x1 convolutions
        if module.kernel_size == (1, 1):
            assert (
                module.in_channels % block_size == 0
            ), "Input channels must be a multiple of block sizes"
        # regular convolutions
        else:
            k = module.kernel_size[0] * module.kernel_size[1]
            assert k % block_size == 0, "Kernel size must be a multiple of block size"

    def _forward_pre_hook(mod, input):
        # no noise for evaluation
        if mod.training:
            if not is_conv:
                # gather weight and sizes
                weight = mod.weight
                in_features = weight.size(1)
                out_features = weight.size(0)

                # split weight matrix into blocks and randomly drop selected blocks:
                # one mask entry per block of `block_size` input features
                mask = torch.zeros(
                    in_features // block_size * out_features, device=weight.device
                )
                mask.bernoulli_(p)
                mask = mask.repeat_interleave(block_size, -1).view(-1, in_features)

            else:
                # gather weight and sizes
                weight = mod.weight
                in_channels = mod.in_channels
                out_channels = mod.out_channels

                # split weight matrix into blocks and randomly drop selected blocks
                if mod.kernel_size == (1, 1):
                    # 1x1 conv is equivalent to a linear layer: block over channels
                    mask = torch.zeros(
                        int(in_channels // block_size * out_channels),
                        device=weight.device,
                    )
                    mask.bernoulli_(p)
                    mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels)
                else:
                    # regular conv: drop whole spatial filters per (out, in) pair
                    mask = torch.zeros(
                        weight.size(0), weight.size(1), device=weight.device
                    )
                    mask.bernoulli_(p)
                    mask = (
                        mask.unsqueeze(2)
                        .unsqueeze(3)
                        .repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1])
                    )

            # scale weights and apply mask; 1/(1-p) keeps the expected
            # magnitude of the weights unchanged (inverted-dropout style)
            mask = mask.to(
                torch.bool
            )  # x.bool() is not currently supported in TorchScript
            s = 1 / (1 - p)
            # NOTE(review): this writes the noised, rescaled weights back into
            # mod.weight.data in place on every training forward pass, with no
            # restore step — confirm this cumulative mutation of the stored
            # parameters is the intended training dynamic.
            mod.weight.data = s * weight.masked_fill(mask, 0)

    module.register_forward_pre_hook(_forward_pre_hook)
    return module
fairseq-0.10.2/fairseq/modules/quantization/__pycache__/quantization_options.cpython-310.pyc ADDED
Binary file (1.34 kB). View file
 
fairseq-0.10.2/fairseq/modules/quantization/pq/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (247 Bytes). View file
 
fairseq-0.10.2/fairseq/modules/quantization/pq/__pycache__/pq.cpython-310.pyc ADDED
Binary file (3.55 kB). View file
 
fairseq-0.10.2/fairseq/modules/quantization/pq/__pycache__/utils.cpython-310.pyc ADDED
Binary file (9.88 kB). View file
 
fairseq-0.10.2/fairseq/modules/quantization/pq/modules/__init__.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from .qconv import PQConv2d # NOQA
7
+ from .qemb import PQEmbedding # NOQA
8
+ from .qlinear import PQLinear # NOQA
fairseq-0.10.2/fairseq/modules/quantization/pq/modules/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (306 Bytes). View file
 
fairseq-0.10.2/fairseq/modules/quantization/pq/modules/__pycache__/qconv.cpython-310.pyc ADDED
Binary file (3.84 kB). View file
 
fairseq-0.10.2/fairseq/modules/quantization/pq/modules/__pycache__/qemb.cpython-310.pyc ADDED
Binary file (3.1 kB). View file
 
fairseq-0.10.2/fairseq/modules/quantization/pq/modules/qemb.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+
11
class PQEmbedding(nn.Module):
    """
    Quantized drop-in replacement for nn.Embedding. Only the centroid
    codebook and the per-block assignments are stored; the dense embedding
    matrix is rebuilt from them on every forward pass.

    Args:
        - centroids: codebook of shape n_centroids x block_size
        - assignments: centroid index for each block of the embedding matrix,
          laid out block-column-major (num_embeddings entries per block column)
        - num_embeddings / embedding_dim and the remaining arguments: as in
          nn.Embedding

    Remarks:
        - We refer the reader to the official documentation of the nn.Embedding
          module for the other arguments and the behavior of the module
        - Performance tests on GPU show that this implementation is 10% slower
          than the non-quantized nn.Embedding module for a standard training loop.
    """

    def __init__(
        self,
        centroids,
        assignments,
        num_embeddings,
        embedding_dim,
        padding_idx=None,
        max_norm=None,
        norm_type=2.0,
        scale_grad_by_freq=False,
        sparse=False,
        _weight=None,
    ):
        super(PQEmbedding, self).__init__()
        self.n_centroids, self.block_size = centroids.size()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        if padding_idx is not None:
            if padding_idx > 0:
                assert (
                    padding_idx < self.num_embeddings
                ), "Padding_idx must be within num_embeddings"
            elif padding_idx < 0:
                assert (
                    padding_idx >= -self.num_embeddings
                ), "Padding_idx must be within num_embeddings"
                # normalize negative indices, mirroring nn.Embedding
                padding_idx = self.num_embeddings + padding_idx
        self.padding_idx = padding_idx
        self.max_norm = max_norm
        self.norm_type = norm_type
        self.scale_grad_by_freq = scale_grad_by_freq
        self.sparse = sparse
        # the codebook must tile the embedding matrix exactly
        if (
            self.embedding_dim % self.block_size != 0
            or len(assignments) % self.num_embeddings != 0
        ):
            raise ValueError("Wrong PQ sizes")
        # learnable codebook; assignments and usage counts are fixed buffers
        self.centroids = nn.Parameter(centroids, requires_grad=True)
        self.register_buffer("assignments", assignments)
        self.register_buffer("counts", torch.bincount(assignments).type_as(centroids))

    @property
    def weight(self):
        # gather one codeword per block, group blocks by column, then lay
        # them out row-wise to recover the (num_embeddings, embedding_dim) matrix
        codewords = self.centroids[self.assignments]
        grouped = codewords.reshape(-1, self.num_embeddings, self.block_size)
        return grouped.permute(1, 0, 2).flatten(1, 2)

    def forward(self, input):
        return F.embedding(
            input,
            self.weight,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )

    def extra_repr(self):
        # mirror nn.Embedding's repr, plus the PQ-specific fields
        parts = ["{num_embeddings}, {embedding_dim}"]
        if self.padding_idx is not None:
            parts.append("padding_idx={padding_idx}")
        if self.max_norm is not None:
            parts.append("max_norm={max_norm}")
        if self.norm_type != 2:
            parts.append("norm_type={norm_type}")
        if self.scale_grad_by_freq is not False:
            parts.append("scale_grad_by_freq={scale_grad_by_freq}")
        if self.sparse is not False:
            parts.append("sparse=True")
        parts.append("n_centroids={n_centroids}, block_size={block_size}")
        return ", ".join(parts).format(**self.__dict__)
fairseq-0.10.2/fairseq/modules/quantization/pq/modules/qlinear.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+
11
class PQLinear(nn.Module):
    """
    Quantized drop-in replacement for nn.Linear. Stores the centroid codebook,
    the per-block assignments and the non-quantized bias; the dense weight
    matrix is rebuilt from the codebook on every forward pass.

    Args:
        - centroids: codebook of shape n_centroids x block_size
        - assignments: centroid index for each block of the weight matrix
        - bias: the non-quantized bias (or None)
        - in_features / out_features: as in nn.Linear

    Remarks:
        - We refer the reader to the official documentation of the nn.Linear
          module for the other arguments and the behavior of the module
        - Performance tests on GPU show that this implementation is 15% slower
          than the non-quantized nn.Linear module for a standard training loop.
    """

    def __init__(self, centroids, assignments, bias, in_features, out_features):
        super(PQLinear, self).__init__()
        self.n_centroids, self.block_size = centroids.size()
        self.in_features = in_features
        self.out_features = out_features
        # the codebook must tile the weight matrix exactly
        if self.in_features % self.block_size != 0:
            raise ValueError("Wrong PQ sizes")
        if len(assignments) % self.out_features != 0:
            raise ValueError("Wrong PQ sizes")
        # learnable codebook; assignments and usage counts are fixed buffers
        self.centroids = nn.Parameter(centroids, requires_grad=True)
        self.register_buffer("assignments", assignments)
        self.register_buffer("counts", torch.bincount(assignments).type_as(centroids))
        if bias is None:
            self.register_parameter("bias", None)
        else:
            self.bias = nn.Parameter(bias)

    @property
    def weight(self):
        # gather one codeword per block and lay blocks out row-wise to
        # recover the (out_features, in_features) weight matrix
        codewords = self.centroids[self.assignments]
        grouped = codewords.reshape(-1, self.out_features, self.block_size)
        return grouped.permute(1, 0, 2).flatten(1, 2)

    def forward(self, x):
        return F.linear(x, self.weight, self.bias)

    def extra_repr(self):
        return f"in_features={self.in_features},\
            out_features={self.out_features},\
            n_centroids={self.n_centroids},\
            block_size={self.block_size},\
            bias={self.bias is not None}"
fairseq-0.10.2/fairseq/modules/quantization/pq/pq.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from .em import EM, EmptyClusterResolveError
7
+
8
+
9
class PQ(EM):
    """
    Quantizes the layer weights W with the standard Product Quantization
    technique, learning a codebook of centroids of size block_size from W.
    For further reference on using PQ to quantize neural networks, see
    "And the Bit Goes Down: Revisiting the Quantization of Neural Networks",
    Stock et al., ICLR 2020.

    PQ is performed in two steps:
        (1) The matrix W (weights of a fully-connected or convolutional layer)
            is reshaped to (block_size, -1).
            - If W is fully-connected (2D), its columns are split into
              blocks of size block_size.
            - If W is convolutional (4D), its filters are split along the
              spatial dimension.
        (2) The standard EM/k-means algorithm is applied to the reshaped matrix.

    Args:
        - W: weight matrix to quantize of size (in_features x out_features)
        - block_size: size of the blocks (subvectors)
        - n_centroids: number of centroids
        - n_iter: number of k-means iterations
        - eps: for cluster reassignment when an empty cluster is found
        - max_tentatives: for cluster reassignment when an empty cluster is found
        - verbose: print information after each iteration

    Remarks:
        - block_size must be compatible with the shape of W
    """

    def __init__(
        self,
        W,
        block_size,
        n_centroids=256,
        n_iter=20,
        eps=1e-6,
        max_tentatives=30,
        verbose=True,
    ):
        # block_size must be set before _reshape runs
        self.block_size = block_size
        super(PQ, self).__init__(
            self._reshape(W),
            n_centroids=n_centroids,
            n_iter=n_iter,
            eps=eps,
            max_tentatives=max_tentatives,
            verbose=verbose,
        )

    def _reshape(self, W):
        """
        Reshapes the matrix W to (block_size, n_blocks) as explained in step (1),
        recording the original dimensions on self for decode().
        """
        ndim = len(W.size())

        if ndim == 2:
            # fully connected: by convention the weight is out_features x in_features
            self.out_features, self.in_features = W.size()
            assert (
                self.in_features % self.block_size == 0
            ), "Linear: n_blocks must be a multiple of in_features"
            leading = self.out_features
        elif ndim == 4:
            # convolutional: reshape along the spatial dimension
            self.out_channels, self.in_channels, self.k_h, self.k_w = W.size()
            assert (
                self.in_channels * self.k_h * self.k_w
            ) % self.block_size == 0, (
                "Conv2d: n_blocks must be a multiple of in_channels * k_h * k_w"
            )
            leading = self.out_channels
        else:
            # other ranks are not supported
            raise NotImplementedError(W.size())

        # common tail: split into block_size chunks, blocks as columns
        return (
            W.reshape(leading, -1, self.block_size).permute(2, 1, 0).flatten(1, 2)
        )

    def encode(self):
        """
        Performs self.n_iter EM steps, stopping early if an empty cluster
        cannot be resolved.
        """
        self.initialize_centroids()
        for step_idx in range(self.n_iter):
            try:
                self.step(step_idx)
            except EmptyClusterResolveError:
                break

    def decode(self):
        """
        Returns the decoded full weight matrix. Must be called after
        the encode function.
        """
        if "k_h" in self.__dict__:
            # convolutional case: restore the 4D filter shape
            return (
                self.centroids[self.assignments]
                .reshape(-1, self.out_channels, self.block_size)
                .permute(1, 0, 2)
                .reshape(self.out_channels, self.in_channels, self.k_h, self.k_w)
            )
        # fully connected case
        return (
            self.centroids[self.assignments]
            .reshape(-1, self.out_features, self.block_size)
            .permute(1, 0, 2)
            .flatten(1, 2)
        )
fairseq-0.10.2/fairseq/modules/quantization/quantization_options.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+
7
def parse_config_yaml(yaml_data):
    """Turn a parsed quantization-config YAML mapping into an options dict.

    Any of the three sections ("n_centroids", "block_sizes",
    "layers_to_quantize") missing from *yaml_data* falls back to the
    defaults below.
    """
    options = {
        "n_centroids": {
            "Linear": ["in_features", {"*": 256}],
            "Embedding": ["embedding_dim", {"*": 256}],
        },
        "block_sizes": {
            "Linear": ["fuzzy_name", {"fc": 8, "attn": 4, "emb": 4}],
            "Embedding": ["fuzzy_name", {"emb": 8}],
        },
        "layers_to_quantize": [
            "decoder\\.layers\\.\\d+\\.fc[12]",
            "decoder\\.embed_tokens\\.embeddings\\.[012]\\.[01]",
            "decoder\\.layers\\.\\d+\\.self_attn\\.(k_proj|v_proj|q_proj|out_proj)",
        ],
    }

    # both per-layer sections share the same key/value normalization
    for section in ("n_centroids", "block_sizes"):
        if section in yaml_data:
            options[section] = {
                layer: convert_yaml_to_tuple(layer_data)
                for layer, layer_data in yaml_data[section].items()
            }
    if "layers_to_quantize" in yaml_data:
        options["layers_to_quantize"] = yaml_data["layers_to_quantize"]

    return options


def convert_yaml_to_tuple(yaml_dictionary):
    """Converts a yaml dictionary with two keys: `key` and `value` into a two
    argument tuple of those values."""
    return yaml_dictionary["key"], yaml_dictionary["value"]
fairseq-0.10.2/fairseq/modules/quantization/scalar/__pycache__/ops.cpython-310.pyc ADDED
Binary file (1.54 kB). View file
 
fairseq-0.10.2/fairseq/modules/quantization/scalar/modules/__pycache__/qemb.cpython-310.pyc ADDED
Binary file (3.85 kB). View file
 
fairseq-0.10.2/fairseq/modules/quantization/scalar/modules/qemb.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+ from ..ops import emulate_int
11
+
12
+
13
class IntEmbedding(nn.Module):
    """
    Quantized counterpart of the nn.Embedding module that applies QuantNoise during training.

    Args:
        - num_embeddings: number of tokens
        - embedding_dim: embedding dimension
        - p: amount of noise to inject (0 = no quantization, 1 = quantize all the weights)
        - bits: number of bits
        - method: choose among {"tensor", "histogram", "channel"}
        - update_step: recompute scale and zero_point every update_step iterations

    Remarks:
        - We use the straight-through estimator so that the gradients
          back-propagate nicely in the network, this is implemented with
          the detach() trick
        - Parameters scale and zero_point are recomputed every update_step
          forward pass to reduce the overhead
        - At test time, the weights are fully quantized
    """

    def __init__(
        self,
        num_embeddings,
        embedding_dim,
        padding_idx=None,
        max_norm=None,
        norm_type=2.0,
        scale_grad_by_freq=False,
        sparse=False,
        _weight=None,
        p=0,
        update_step=1000,
        bits=8,
        method="histogram",
    ):
        super(IntEmbedding, self).__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        if padding_idx is not None:
            if padding_idx > 0:
                assert (
                    padding_idx < self.num_embeddings
                ), "Padding_idx must be within num_embeddings"
            elif padding_idx < 0:
                assert (
                    padding_idx >= -self.num_embeddings
                ), "Padding_idx must be within num_embeddings"
                # normalize negative indices, mirroring nn.Embedding
                padding_idx = self.num_embeddings + padding_idx
        self.padding_idx = padding_idx
        self.max_norm = max_norm
        self.norm_type = norm_type
        self.scale_grad_by_freq = scale_grad_by_freq
        if _weight is None:
            self.weight = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim))
            self.reset_parameters()
        else:
            assert list(_weight.shape) == [
                num_embeddings,
                embedding_dim,
            ], "Shape of weight does not match num_embeddings and embedding_dim"
            self.weight = nn.Parameter(_weight)
        self.sparse = sparse

        # quantization parameters
        self.p = p
        self.bits = bits
        self.method = method
        self.update_step = update_step
        self.counter = 0

    def reset_parameters(self):
        """(Re)initialize the weight from N(0, 1) and zero the padding row."""
        nn.init.normal_(self.weight)
        if self.padding_idx is not None:
            with torch.no_grad():
                self.weight[self.padding_idx].fill_(0)

    def forward(self, input):
        # train with QuantNoise and evaluate the fully quantized network
        p = self.p if self.training else 1

        # invalidate scale/zero_point every update_step iterations so
        # emulate_int recomputes them (counter starts at 0, so the first
        # forward always computes them)
        if self.counter % self.update_step == 0:
            self.scale = None
            self.zero_point = None
        self.counter += 1

        # quantize weight
        weight_quantized, self.scale, self.zero_point = emulate_int(
            self.weight.detach(),
            bits=self.bits,
            method=self.method,
            scale=self.scale,
            zero_point=self.zero_point,
        )

        # mask to apply noise: entries set to 1 (prob 1 - p) keep the
        # original weight, the rest (prob p) receive quantization noise
        mask = torch.zeros_like(self.weight)
        mask.bernoulli_(1 - p)
        noise = (weight_quantized - self.weight).masked_fill(mask.bool(), 0)

        # using straight-through estimator (STE): forward uses the noised,
        # clamped weights while gradients flow through the clamp only
        clamp_low = -self.scale * self.zero_point
        clamp_high = self.scale * (2 ** self.bits - 1 - self.zero_point)
        weight = (
            torch.clamp(self.weight, clamp_low.item(), clamp_high.item())
            + noise.detach()
        )

        # return output
        output = F.embedding(
            input,
            weight,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )
        return output

    def extra_repr(self):
        s = "{num_embeddings}, {embedding_dim}"
        if self.padding_idx is not None:
            s += ", padding_idx={padding_idx}"
        if self.max_norm is not None:
            s += ", max_norm={max_norm}"
        if self.norm_type != 2:
            s += ", norm_type={norm_type}"
        if self.scale_grad_by_freq is not False:
            s += ", scale_grad_by_freq={scale_grad_by_freq}"
        if self.sparse is not False:
            s += ", sparse=True"
        # fix: the original appended this fragment without a ", " separator,
        # producing e.g. "5, 4quant_noise=0, ..." in repr() output
        s += ", quant_noise={p}, bits={bits}, method={method}"
        return s.format(**self.__dict__)
fairseq-0.10.2/fairseq/modules/sparse_transformer_sentence_encoder.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch.nn as nn
7
+ from fairseq.modules import TransformerSentenceEncoder
8
+ from fairseq.modules.sparse_transformer_sentence_encoder_layer import (
9
+ SparseTransformerSentenceEncoderLayer,
10
+ )
11
+
12
+
13
class SparseTransformerSentenceEncoder(TransformerSentenceEncoder):
    """
    Sparse implementation of the TransformerSentenceEncoder
    - see SparseMultiheadAttention

    Builds the dense encoder via the parent constructor, then swaps its
    layer stack for SparseTransformerSentenceEncoderLayer instances
    configured by `is_bidirectional`, `stride` and `expressivity`.
    """

    def __init__(
        self,
        padding_idx: int,
        vocab_size: int,
        num_encoder_layers: int = 6,
        embedding_dim: int = 768,
        ffn_embedding_dim: int = 3072,
        num_attention_heads: int = 8,
        dropout: float = 0.1,
        attention_dropout: float = 0.1,
        activation_dropout: float = 0.1,
        max_seq_len: int = 256,
        num_segments: int = 2,
        use_position_embeddings: bool = True,
        offset_positions_by_padding: bool = True,
        encoder_normalize_before: bool = False,
        apply_bert_init: bool = False,
        activation_fn: str = "relu",
        learned_pos_embedding: bool = True,
        embed_scale: float = None,
        freeze_embeddings: bool = False,
        n_trans_layers_to_freeze: int = 0,
        export: bool = False,
        is_bidirectional: bool = True,
        stride: int = 32,
        expressivity: int = 8,
    ) -> None:

        # Arguments are forwarded positionally, so this order must match
        # TransformerSentenceEncoder.__init__ exactly.
        super().__init__(
            padding_idx,
            vocab_size,
            num_encoder_layers,
            embedding_dim,
            ffn_embedding_dim,
            num_attention_heads,
            dropout,
            attention_dropout,
            activation_dropout,
            max_seq_len,
            num_segments,
            use_position_embeddings,
            offset_positions_by_padding,
            encoder_normalize_before,
            apply_bert_init,
            activation_fn,
            learned_pos_embedding,
            embed_scale,
            freeze_embeddings,
            n_trans_layers_to_freeze,
            export,
        )

        # Replace the dense layers created by the parent with sparse-attention
        # layers; the parent's original layer stack is discarded here.
        self.layers = nn.ModuleList(
            [
                SparseTransformerSentenceEncoderLayer(
                    embedding_dim=self.embedding_dim,
                    ffn_embedding_dim=ffn_embedding_dim,
                    num_attention_heads=num_attention_heads,
                    dropout=dropout,
                    attention_dropout=attention_dropout,
                    activation_dropout=activation_dropout,
                    activation_fn=activation_fn,
                    export=export,
                    is_bidirectional=is_bidirectional,
                    stride=stride,
                    expressivity=expressivity,
                )
                for _ in range(num_encoder_layers)
            ]
        )

        def freeze_module_params(m):
            # Disable gradients for every parameter of m (no-op for None).
            if m is not None:
                for p in m.parameters():
                    p.requires_grad = False

        # Re-apply freezing to the first n layers: the freezing done inside the
        # parent constructor applied to the now-replaced dense layers.
        for layer in range(n_trans_layers_to_freeze):
            freeze_module_params(self.layers[layer])