sleepyhead111 committed on
Commit
a2608e1
·
verified ·
1 Parent(s): 04827bc

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50) hide show
  1. fairseq-0.10.2/fairseq/criterions/__init__.py +38 -0
  2. fairseq-0.10.2/fairseq/criterions/__pycache__/ctc.cpython-310.pyc +0 -0
  3. fairseq-0.10.2/fairseq/criterions/__pycache__/fairseq_criterion.cpython-310.pyc +0 -0
  4. fairseq-0.10.2/fairseq/criterions/__pycache__/legacy_masked_lm.cpython-310.pyc +0 -0
  5. fairseq-0.10.2/fairseq/criterions/__pycache__/sentence_prediction.cpython-310.pyc +0 -0
  6. fairseq-0.10.2/fairseq/criterions/adaptive_loss.py +123 -0
  7. fairseq-0.10.2/fairseq/criterions/composite_loss.py +100 -0
  8. fairseq-0.10.2/fairseq/criterions/ctc.py +253 -0
  9. fairseq-0.10.2/fairseq/criterions/fairseq_criterion.py +119 -0
  10. fairseq-0.10.2/fairseq/criterions/legacy_masked_lm.py +177 -0
  11. fairseq-0.10.2/fairseq/criterions/sentence_prediction.py +99 -0
  12. fairseq-0.10.2/fairseq/criterions/sentence_ranking.py +120 -0
  13. fairseq-0.10.2/fairseq/model_parallel/models/__pycache__/transformer.cpython-310.pyc +0 -0
  14. fairseq-0.10.2/fairseq/model_parallel/models/pipeline_parallel_transformer/model.py +721 -0
  15. fairseq-0.10.2/fairseq/model_parallel/models/roberta/__init__.py +6 -0
  16. fairseq-0.10.2/fairseq/model_parallel/models/roberta/__pycache__/__init__.cpython-310.pyc +0 -0
  17. fairseq-0.10.2/fairseq/models/__pycache__/__init__.cpython-310.pyc +0 -0
  18. fairseq-0.10.2/fairseq/models/__pycache__/distributed_fairseq_model.cpython-310.pyc +0 -0
  19. fairseq-0.10.2/fairseq/models/__pycache__/fairseq_decoder.cpython-310.pyc +0 -0
  20. fairseq-0.10.2/fairseq/models/__pycache__/fairseq_encoder.cpython-310.pyc +0 -0
  21. fairseq-0.10.2/fairseq/models/__pycache__/fairseq_incremental_decoder.cpython-310.pyc +0 -0
  22. fairseq-0.10.2/fairseq/models/__pycache__/fairseq_model.cpython-310.pyc +0 -0
  23. fairseq-0.10.2/fairseq/models/__pycache__/fconv.cpython-310.pyc +0 -0
  24. fairseq-0.10.2/fairseq/models/__pycache__/fconv_lm.cpython-310.pyc +0 -0
  25. fairseq-0.10.2/fairseq/models/__pycache__/fconv_self_att.cpython-310.pyc +0 -0
  26. fairseq-0.10.2/fairseq/models/__pycache__/lightconv.cpython-310.pyc +0 -0
  27. fairseq-0.10.2/fairseq/models/__pycache__/lightconv_lm.cpython-310.pyc +0 -0
  28. fairseq-0.10.2/fairseq/models/__pycache__/lstm.cpython-310.pyc +0 -0
  29. fairseq-0.10.2/fairseq/models/__pycache__/lstm_lm.cpython-310.pyc +0 -0
  30. fairseq-0.10.2/fairseq/models/__pycache__/masked_lm.cpython-310.pyc +0 -0
  31. fairseq-0.10.2/fairseq/models/__pycache__/model_utils.cpython-310.pyc +0 -0
  32. fairseq-0.10.2/fairseq/models/__pycache__/multilingual_transformer.cpython-310.pyc +0 -0
  33. fairseq-0.10.2/fairseq/models/__pycache__/transformer.cpython-310.pyc +0 -0
  34. fairseq-0.10.2/fairseq/models/__pycache__/transformer_lm.cpython-310.pyc +0 -0
  35. fairseq-0.10.2/fairseq/models/bart/__init__.py +7 -0
  36. fairseq-0.10.2/fairseq/models/bart/hub_interface.py +201 -0
  37. fairseq-0.10.2/fairseq/models/distributed_fairseq_model.py +103 -0
  38. fairseq-0.10.2/fairseq/models/fairseq_decoder.py +90 -0
  39. fairseq-0.10.2/fairseq/models/fairseq_encoder.py +92 -0
  40. fairseq-0.10.2/fairseq/models/fairseq_model.py +556 -0
  41. fairseq-0.10.2/fairseq/models/huggingface/hf_gpt2.py +168 -0
  42. fairseq-0.10.2/fairseq/models/lightconv_lm.py +306 -0
  43. fairseq-0.10.2/fairseq/models/masked_lm.py +403 -0
  44. fairseq-0.10.2/fairseq/models/model_utils.py +92 -0
  45. fairseq-0.10.2/fairseq/models/multilingual_transformer.py +228 -0
  46. fairseq-0.10.2/fairseq/models/roberta/__pycache__/hub_interface.cpython-310.pyc +0 -0
  47. fairseq-0.10.2/fairseq/models/roberta/__pycache__/model_camembert.cpython-310.pyc +0 -0
  48. fairseq-0.10.2/fairseq/models/roberta/__pycache__/model_xlmr.cpython-310.pyc +0 -0
  49. fairseq-0.10.2/fairseq/models/roberta/model_xlmr.py +44 -0
  50. fairseq-0.10.2/fairseq/modules/__pycache__/beamable_mm.cpython-310.pyc +0 -0
fairseq-0.10.2/fairseq/criterions/__init__.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ """isort:skip_file"""
6
+
7
+ import importlib
8
+ import os
9
+ from argparse import Namespace
10
+ from typing import Union
11
+
12
+ from fairseq import registry
13
+ from fairseq.criterions.fairseq_criterion import ( # noqa
14
+ FairseqCriterion,
15
+ LegacyFairseqCriterion,
16
+ )
17
+ from omegaconf import DictConfig
18
+
19
+
20
+ (
21
+ build_criterion_,
22
+ register_criterion,
23
+ CRITERION_REGISTRY,
24
+ CRITERION_DATACLASS_REGISTRY,
25
+ ) = registry.setup_registry(
26
+ "--criterion", base_class=FairseqCriterion, default="cross_entropy"
27
+ )
28
+
29
+
30
+ def build_criterion(criterion_cfg: Union[DictConfig, Namespace], task):
31
+ return build_criterion_(criterion_cfg, task)
32
+
33
+
34
+ # automatically import any Python files in the criterions/ directory
35
+ for file in os.listdir(os.path.dirname(__file__)):
36
+ if file.endswith(".py") and not file.startswith("_"):
37
+ file_name = file[: file.find(".py")]
38
+ importlib.import_module("fairseq.criterions." + file_name)
fairseq-0.10.2/fairseq/criterions/__pycache__/ctc.cpython-310.pyc ADDED
Binary file (7.02 kB). View file
 
fairseq-0.10.2/fairseq/criterions/__pycache__/fairseq_criterion.cpython-310.pyc ADDED
Binary file (4.36 kB). View file
 
fairseq-0.10.2/fairseq/criterions/__pycache__/legacy_masked_lm.cpython-310.pyc ADDED
Binary file (5.46 kB). View file
 
fairseq-0.10.2/fairseq/criterions/__pycache__/sentence_prediction.cpython-310.pyc ADDED
Binary file (3.88 kB). View file
 
fairseq-0.10.2/fairseq/criterions/adaptive_loss.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+ from dataclasses import dataclass
8
+
9
+ import torch.nn.functional as F
10
+ from fairseq import metrics, utils
11
+ from fairseq.criterions import FairseqCriterion, register_criterion
12
+ from fairseq.dataclass import FairseqDataclass
13
+ from fairseq.dataclass.constants import DDP_BACKEND_CHOICES
14
+ from omegaconf import II
15
+
16
+
17
+ @dataclass
18
+ class AdaptiveLossConfig(FairseqDataclass):
19
+ sentence_avg: bool = II("params.optimization.sentence_avg")
20
+ ddp_backend: DDP_BACKEND_CHOICES = II("params.distributed_training.ddp_backend")
21
+
22
+
23
+ @register_criterion("adaptive_loss", dataclass=AdaptiveLossConfig)
24
+ class AdaptiveLoss(FairseqCriterion):
25
+ """This is an implementation of the loss function accompanying the adaptive softmax approximation for
26
+ graphical processing units (GPU), described in the paper "Efficient softmax approximation for GPUs"
27
+ (http://arxiv.org/abs/1609.04309)."""
28
+
29
+ def __init__(self, task, sentence_avg):
30
+ super().__init__(task)
31
+ self.sentence_avg = sentence_avg
32
+
33
+ @classmethod
34
+ def build_criterion(cls, args, task):
35
+ if getattr(args, "ddp_backend", None) == "c10d":
36
+ raise Exception(
37
+ "AdaptiveLoss is not compatible with the c10d "
38
+ "version of DistributedDataParallel. Please use "
39
+ "`--ddp-backend=no_c10d` instead."
40
+ )
41
+ return cls(task, args.sentence_avg)
42
+
43
+ def forward(self, model, sample, reduce=True):
44
+ """Compute the loss for the given sample.
45
+
46
+ Returns a tuple with three elements:
47
+ 1) the loss
48
+ 2) the sample size, which is used as the denominator for the gradient
49
+ 3) logging outputs to display while training
50
+ """
51
+
52
+ assert (
53
+ hasattr(model.decoder, "adaptive_softmax")
54
+ and model.decoder.adaptive_softmax is not None
55
+ )
56
+ adaptive_softmax = model.decoder.adaptive_softmax
57
+
58
+ net_output = model(**sample["net_input"])
59
+ orig_target = model.get_targets(sample, net_output)
60
+
61
+ nsentences = orig_target.size(0)
62
+ orig_target = orig_target.view(-1)
63
+
64
+ bsz = orig_target.size(0)
65
+
66
+ logits, target = adaptive_softmax(net_output[0], orig_target)
67
+ assert len(target) == len(logits)
68
+
69
+ loss = net_output[0].new(1 if reduce else bsz).zero_()
70
+
71
+ for i in range(len(target)):
72
+ if target[i] is not None:
73
+ assert target[i].min() >= 0 and target[i].max() <= logits[i].size(1)
74
+ loss += F.cross_entropy(
75
+ logits[i],
76
+ target[i],
77
+ ignore_index=self.padding_idx,
78
+ reduction="sum" if reduce else "none",
79
+ )
80
+
81
+ orig = utils.strip_pad(orig_target, self.padding_idx)
82
+ ntokens = orig.numel()
83
+ sample_size = sample["target"].size(0) if self.sentence_avg else ntokens
84
+ logging_output = {
85
+ "loss": loss.data,
86
+ "ntokens": ntokens,
87
+ "nsentences": nsentences,
88
+ "sample_size": sample_size,
89
+ }
90
+ return loss, sample_size, logging_output
91
+
92
+ @staticmethod
93
+ def reduce_metrics(logging_outputs) -> None:
94
+ """Aggregate logging outputs from data parallel training."""
95
+ loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs))
96
+ ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs))
97
+ sample_size = utils.item(
98
+ sum(log.get("sample_size", 0) for log in logging_outputs)
99
+ )
100
+
101
+ metrics.log_scalar(
102
+ "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
103
+ )
104
+ if sample_size != ntokens:
105
+ metrics.log_scalar(
106
+ "nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
107
+ )
108
+ metrics.log_derived(
109
+ "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
110
+ )
111
+ else:
112
+ metrics.log_derived(
113
+ "ppl", lambda meters: utils.get_perplexity(meters["loss"].avg)
114
+ )
115
+
116
+ @staticmethod
117
+ def logging_outputs_can_be_summed() -> bool:
118
+ """
119
+ Whether the logging outputs returned by `forward` can be summed
120
+ across workers prior to calling `reduce_metrics`. Setting this
121
+ to True will improves distributed training speed.
122
+ """
123
+ return True
fairseq-0.10.2/fairseq/criterions/composite_loss.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from fairseq import utils
7
+ from fairseq.criterions import LegacyFairseqCriterion, register_criterion
8
+ from torch import nn
9
+
10
+
11
+ @register_criterion("composite_loss")
12
+ class CompositeLoss(LegacyFairseqCriterion):
13
+ """This is a composite loss that, given a list of model outputs and a list of targets,
14
+ computes an average of losses for each output-target pair"""
15
+
16
+ def __init__(self, args, task):
17
+ super().__init__(args, task)
18
+ self.underlying_criterion = args.underlying_criterion
19
+
20
+ @staticmethod
21
+ def add_args(parser):
22
+ """Add criterion-specific arguments to the parser."""
23
+ # fmt: off
24
+ parser.add_argument('--underlying-criterion', type=str, metavar='VAL', required=True,
25
+ help='underlying criterion to use for the composite loss')
26
+ # fmt: on
27
+
28
+ @staticmethod
29
+ def build_underlying_criterion(args, task):
30
+ saved_criterion = args.criterion
31
+ args.criterion = args.underlying_criterion
32
+ assert saved_criterion != args.underlying_criterion
33
+ underlying_criterion = task.build_criterion(args)
34
+ args.criterion = saved_criterion
35
+ return underlying_criterion
36
+
37
+ @classmethod
38
+ def build_criterion(cls, args, task):
39
+ underlying_criterion = CompositeLoss.build_underlying_criterion(args, task)
40
+
41
+ class FakeModel(nn.Module):
42
+ def __init__(self, model, net_out, target):
43
+ super().__init__()
44
+ self.model = model
45
+ self.net_out = net_out
46
+ self.target = target
47
+
48
+ def forward(self, **unused):
49
+ return self.net_out
50
+
51
+ def get_normalized_probs(self, net_output, log_probs, sample=None):
52
+ return self.model.get_normalized_probs(
53
+ net_output, log_probs, sample=sample
54
+ )
55
+
56
+ def get_targets(self, *unused):
57
+ return self.target
58
+
59
+ @property
60
+ def decoder(self):
61
+ return self.model.decoder
62
+
63
+ class _CompositeLoss(LegacyFairseqCriterion):
64
+ def __init__(self, args, task, underlying_criterion):
65
+ super().__init__(args, task)
66
+ self.underlying_criterion = underlying_criterion
67
+
68
+ def forward(self, model, sample, reduce=True):
69
+ net_outputs = model(**sample["net_input"])
70
+ targets = sample["target"]
71
+
72
+ bsz = targets[0].size(0)
73
+ loss = net_outputs[0][0].new(1 if reduce else bsz).float().zero_()
74
+
75
+ sample_size = 0
76
+ logging_output = {}
77
+ for o, t in zip(net_outputs[0], targets):
78
+ m = FakeModel(model, (o, net_outputs[1]), t)
79
+ sample["target"] = t
80
+ l, ss, logging_output = self.underlying_criterion(m, sample, reduce)
81
+ loss += l
82
+ sample_size += ss
83
+
84
+ loss.div_(len(targets))
85
+ sample_size /= len(targets)
86
+
87
+ logging_output["loss"] = utils.item(loss.data) if reduce else loss.data
88
+ return loss, sample_size, logging_output
89
+
90
+ @staticmethod
91
+ def aggregate_logging_outputs(logging_outputs):
92
+ return underlying_criterion.__class__.aggregate_logging_outputs(
93
+ logging_outputs
94
+ )
95
+
96
+ @staticmethod
97
+ def reduce_metrics(logging_outputs) -> None:
98
+ underlying_criterion.__class__.reduce_metrics(logging_outputs)
99
+
100
+ return _CompositeLoss(args, task, underlying_criterion)
fairseq-0.10.2/fairseq/criterions/ctc.py ADDED
@@ -0,0 +1,253 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # All rights reserved.
2
+ #
3
+ # This source code is licensed under the license found in the LICENSE file in
4
+ # the root directory of this source tree. An additional grant of patent rights
5
+ # can be found in the PATENTS file in the same directory.
6
+
7
+ import math
8
+ from argparse import Namespace
9
+
10
+ import torch
11
+ import torch.nn.functional as F
12
+ from fairseq import metrics, utils
13
+ from fairseq.criterions import FairseqCriterion, register_criterion
14
+ from fairseq.data.data_utils import post_process
15
+ from fairseq.logging.meters import safe_round
16
+
17
+
18
+ @register_criterion("ctc")
19
+ class CtcCriterion(FairseqCriterion):
20
+ def __init__(self, task, wer_args, zero_infinity, sentence_avg, remove_bpe):
21
+ super().__init__(task)
22
+ self.blank_idx = task.target_dictionary.bos()
23
+ self.pad_idx = task.target_dictionary.pad()
24
+ self.eos_idx = task.target_dictionary.eos()
25
+ self.post_process = remove_bpe if remove_bpe else "letter"
26
+
27
+ if wer_args is not None:
28
+ from examples.speech_recognition.w2l_decoder import W2lKenLMDecoder
29
+
30
+ wer_compute_kenlm, wer_lexicon, lm_w, ws_w = eval(wer_args)
31
+
32
+ dec_args = Namespace()
33
+ dec_args.nbest = 1
34
+ dec_args.criterion = "ctc"
35
+ dec_args.kenlm_model = wer_compute_kenlm
36
+ dec_args.lexicon = wer_lexicon
37
+ dec_args.beam = 50
38
+ dec_args.beam_size_token = min(50, len(task.target_dictionary))
39
+ dec_args.beam_threshold = min(50, len(task.target_dictionary))
40
+ dec_args.lm_weight = lm_w
41
+ dec_args.word_score = ws_w
42
+ dec_args.unk_weight = -math.inf
43
+ dec_args.sil_weight = 0
44
+
45
+ self.w2l_decoder = W2lKenLMDecoder(dec_args, task.target_dictionary)
46
+ else:
47
+ self.w2l_decoder = None
48
+
49
+ self.zero_infinity = zero_infinity
50
+ self.sentence_avg = sentence_avg
51
+
52
+ @staticmethod
53
+ def add_args(parser):
54
+ """Add criterion-specific arguments to the parser."""
55
+ parser.add_argument(
56
+ "--zero-infinity", action="store_true", help="zero inf loss"
57
+ )
58
+ try:
59
+ parser.add_argument(
60
+ "--remove-bpe",
61
+ "--post-process",
62
+ default="letter",
63
+ help="remove BPE tokens before scoring (can be set to sentencepiece, letter, and more)",
64
+ )
65
+ except:
66
+ pass # this option might have been added from eval args
67
+ parser.add_argument(
68
+ "--wer-args",
69
+ type=str,
70
+ default=None,
71
+ help="options for wer computation on valid set using 4 gram lm. this should be a tuple of 4 elements: path to 4-gram lm, \
72
+ path to lexicon, lm score, word score",
73
+ )
74
+
75
+ def forward(self, model, sample, reduce=True):
76
+ net_output = model(**sample["net_input"])
77
+ lprobs = model.get_normalized_probs(
78
+ net_output, log_probs=True
79
+ ).contiguous() # (T, B, C) from the encoder
80
+
81
+ if "src_lengths" in sample["net_input"]:
82
+ input_lengths = sample["net_input"]["src_lengths"]
83
+ else:
84
+ non_padding_mask = ~net_output["padding_mask"]
85
+ input_lengths = non_padding_mask.long().sum(-1)
86
+
87
+ pad_mask = (sample["target"] != self.pad_idx) & (
88
+ sample["target"] != self.eos_idx
89
+ )
90
+ targets_flat = sample["target"].masked_select(pad_mask)
91
+ target_lengths = sample["target_lengths"]
92
+
93
+ with torch.backends.cudnn.flags(enabled=False):
94
+ loss = F.ctc_loss(
95
+ lprobs,
96
+ targets_flat,
97
+ input_lengths,
98
+ target_lengths,
99
+ blank=self.blank_idx,
100
+ reduction="sum",
101
+ zero_infinity=self.zero_infinity,
102
+ )
103
+
104
+ ntokens = (
105
+ sample["ntokens"] if "ntokens" in sample else target_lengths.sum().item()
106
+ )
107
+
108
+ sample_size = sample["target"].size(0) if self.sentence_avg else ntokens
109
+ logging_output = {
110
+ "loss": utils.item(loss.data), # * sample['ntokens'],
111
+ "ntokens": ntokens,
112
+ "nsentences": sample["id"].numel(),
113
+ "sample_size": sample_size,
114
+ }
115
+
116
+ if not model.training:
117
+ import editdistance
118
+
119
+ with torch.no_grad():
120
+ lprobs_t = lprobs.transpose(0, 1).float().contiguous().cpu()
121
+
122
+ c_err = 0
123
+ c_len = 0
124
+ w_errs = 0
125
+ w_len = 0
126
+ wv_errs = 0
127
+ for lp, t, inp_l in zip(
128
+ lprobs_t,
129
+ sample["target_label"]
130
+ if "target_label" in sample
131
+ else sample["target"],
132
+ input_lengths,
133
+ ):
134
+ lp = lp[:inp_l].unsqueeze(0)
135
+
136
+ decoded = None
137
+ if self.w2l_decoder is not None:
138
+ decoded = self.w2l_decoder.decode(lp)
139
+ if len(decoded) < 1:
140
+ decoded = None
141
+ else:
142
+ decoded = decoded[0]
143
+ if len(decoded) < 1:
144
+ decoded = None
145
+ else:
146
+ decoded = decoded[0]
147
+
148
+ p = (t != self.task.target_dictionary.pad()) & (
149
+ t != self.task.target_dictionary.eos()
150
+ )
151
+ targ = t[p]
152
+ targ_units = self.task.target_dictionary.string(targ)
153
+ targ_units_arr = targ.tolist()
154
+
155
+ toks = lp.argmax(dim=-1).unique_consecutive()
156
+ pred_units_arr = toks[toks != self.blank_idx].tolist()
157
+
158
+ c_err += editdistance.eval(pred_units_arr, targ_units_arr)
159
+ c_len += len(targ_units_arr)
160
+
161
+ targ_words = post_process(targ_units, self.post_process).split()
162
+
163
+ pred_units = self.task.target_dictionary.string(pred_units_arr)
164
+ pred_words_raw = post_process(pred_units, self.post_process).split()
165
+
166
+ if decoded is not None and "words" in decoded:
167
+ pred_words = decoded["words"]
168
+ w_errs += editdistance.eval(pred_words, targ_words)
169
+ wv_errs += editdistance.eval(pred_words_raw, targ_words)
170
+ else:
171
+ dist = editdistance.eval(pred_words_raw, targ_words)
172
+ w_errs += dist
173
+ wv_errs += dist
174
+
175
+ w_len += len(targ_words)
176
+
177
+ logging_output["wv_errors"] = wv_errs
178
+ logging_output["w_errors"] = w_errs
179
+ logging_output["w_total"] = w_len
180
+ logging_output["c_errors"] = c_err
181
+ logging_output["c_total"] = c_len
182
+
183
+ return loss, sample_size, logging_output
184
+
185
+ @staticmethod
186
+ def reduce_metrics(logging_outputs) -> None:
187
+ """Aggregate logging outputs from data parallel training."""
188
+
189
+ loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs))
190
+ ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs))
191
+ nsentences = utils.item(
192
+ sum(log.get("nsentences", 0) for log in logging_outputs)
193
+ )
194
+ sample_size = utils.item(
195
+ sum(log.get("sample_size", 0) for log in logging_outputs)
196
+ )
197
+
198
+ metrics.log_scalar(
199
+ "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
200
+ )
201
+ metrics.log_scalar("ntokens", ntokens)
202
+ metrics.log_scalar("nsentences", nsentences)
203
+ if sample_size != ntokens:
204
+ metrics.log_scalar(
205
+ "nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
206
+ )
207
+
208
+ c_errors = sum(log.get("c_errors", 0) for log in logging_outputs)
209
+ metrics.log_scalar("_c_errors", c_errors)
210
+ c_total = sum(log.get("c_total", 0) for log in logging_outputs)
211
+ metrics.log_scalar("_c_total", c_total)
212
+ w_errors = sum(log.get("w_errors", 0) for log in logging_outputs)
213
+ metrics.log_scalar("_w_errors", w_errors)
214
+ wv_errors = sum(log.get("wv_errors", 0) for log in logging_outputs)
215
+ metrics.log_scalar("_wv_errors", wv_errors)
216
+ w_total = sum(log.get("w_total", 0) for log in logging_outputs)
217
+ metrics.log_scalar("_w_total", w_total)
218
+
219
+ if c_total > 0:
220
+ metrics.log_derived(
221
+ "uer",
222
+ lambda meters: safe_round(
223
+ meters["_c_errors"].sum * 100.0 / meters["_c_total"].sum, 3
224
+ )
225
+ if meters["_c_total"].sum > 0
226
+ else float("nan"),
227
+ )
228
+ if w_total > 0:
229
+ metrics.log_derived(
230
+ "wer",
231
+ lambda meters: safe_round(
232
+ meters["_w_errors"].sum * 100.0 / meters["_w_total"].sum, 3
233
+ )
234
+ if meters["_w_total"].sum > 0
235
+ else float("nan"),
236
+ )
237
+ metrics.log_derived(
238
+ "raw_wer",
239
+ lambda meters: safe_round(
240
+ meters["_wv_errors"].sum * 100.0 / meters["_w_total"].sum, 3
241
+ )
242
+ if meters["_w_total"].sum > 0
243
+ else float("nan"),
244
+ )
245
+
246
+ @staticmethod
247
+ def logging_outputs_can_be_summed() -> bool:
248
+ """
249
+ Whether the logging outputs returned by `forward` can be summed
250
+ across workers prior to calling `reduce_metrics`. Setting this
251
+ to True will improves distributed training speed.
252
+ """
253
+ return True
fairseq-0.10.2/fairseq/criterions/fairseq_criterion.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import inspect
7
+ from typing import Any, Dict, List
8
+
9
+ from fairseq import metrics, utils
10
+ from fairseq.dataclass.utils import gen_parser_from_dataclass
11
+ from torch.nn.modules.loss import _Loss
12
+
13
+
14
+ class FairseqCriterion(_Loss):
15
+ def __init__(self, task):
16
+ super().__init__()
17
+ self.task = task
18
+ if hasattr(task, "target_dictionary"):
19
+ tgt_dict = task.target_dictionary
20
+ self.padding_idx = tgt_dict.pad() if tgt_dict is not None else -100
21
+
22
+ @classmethod
23
+ def add_args(cls, parser):
24
+ """Add criterion-specific arguments to the parser."""
25
+ dc = getattr(cls, "__dataclass", None)
26
+ if dc is not None:
27
+ gen_parser_from_dataclass(parser, dc())
28
+
29
+ @classmethod
30
+ def build_criterion(cls, args, task):
31
+ """Construct a criterion from command-line args."""
32
+ # Criterions can override this, but for convenience we also try
33
+ # to automatically map argparse.Namespace keys to corresponding
34
+ # arguments in the __init__.
35
+ init_args = {}
36
+ for p in inspect.signature(cls).parameters.values():
37
+ if (
38
+ p.kind == p.POSITIONAL_ONLY
39
+ or p.kind == p.VAR_POSITIONAL
40
+ or p.kind == p.VAR_KEYWORD
41
+ ):
42
+ # we haven't implemented inference for these argument types,
43
+ # but PRs welcome :)
44
+ raise NotImplementedError("{} not supported".format(p.kind))
45
+
46
+ assert p.kind in {p.POSITIONAL_OR_KEYWORD, p.KEYWORD_ONLY}
47
+
48
+ if p.name == "task":
49
+ init_args["task"] = task
50
+ elif hasattr(args, p.name):
51
+ init_args[p.name] = getattr(args, p.name)
52
+ elif p.default != p.empty:
53
+ pass # we'll use the default value
54
+ else:
55
+ raise NotImplementedError(
56
+ "Unable to infer Criterion arguments, please implement "
57
+ "{}.build_criterion".format(cls.__name__)
58
+ )
59
+ return cls(**init_args)
60
+
61
+ def forward(self, model, sample, reduce=True):
62
+ """Compute the loss for the given sample.
63
+
64
+ Returns a tuple with three elements:
65
+ 1) the loss
66
+ 2) the sample size, which is used as the denominator for the gradient
67
+ 3) logging outputs to display while training
68
+ """
69
+ raise NotImplementedError
70
+
71
+ @staticmethod
72
+ def aggregate_logging_outputs(
73
+ logging_outputs: List[Dict[str, Any]],
74
+ ) -> Dict[str, Any]:
75
+ """Aggregate logging outputs from data parallel training."""
76
+ utils.deprecation_warning(
77
+ "The aggregate_logging_outputs API is deprecated. "
78
+ "Please use the reduce_metrics API instead."
79
+ )
80
+ raise NotImplementedError
81
+
82
+ @classmethod
83
+ def reduce_metrics(cls, logging_outputs: List[Dict[str, Any]]) -> None:
84
+ """Aggregate logging outputs from data parallel training."""
85
+ utils.deprecation_warning(
86
+ "Criterions should implement the reduce_metrics API. "
87
+ "Falling back to deprecated aggregate_logging_outputs API."
88
+ )
89
+ agg_logging_outputs = cls.aggregate_logging_outputs(logging_outputs)
90
+ for k, v in agg_logging_outputs.items():
91
+ if k in {"nsentences", "ntokens", "sample_size"}:
92
+ continue
93
+ metrics.log_scalar(k, v)
94
+
95
+ @staticmethod
96
+ def logging_outputs_can_be_summed() -> bool:
97
+ """
98
+ Whether the logging outputs returned by `forward` can be summed
99
+ across workers prior to calling `reduce_metrics`. Setting this
100
+ to True will improves distributed training speed.
101
+ """
102
+ return False
103
+
104
+
105
+ class LegacyFairseqCriterion(FairseqCriterion):
106
+ def __init__(self, args, task):
107
+ super().__init__(task=task)
108
+ self.args = args
109
+
110
+ utils.deprecation_warning(
111
+ "Criterions should take explicit arguments instead of an "
112
+ "argparse.Namespace object, please update your criterion by "
113
+ "extending FairseqCriterion instead of LegacyFairseqCriterion."
114
+ )
115
+
116
+ @classmethod
117
+ def build_criterion(cls, args, task):
118
+ """Construct a criterion from command-line args."""
119
+ return cls(args, task)
fairseq-0.10.2/fairseq/criterions/legacy_masked_lm.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+
8
+ import torch
9
+ import torch.nn.functional as F
10
+ from fairseq import metrics, utils
11
+ from fairseq.criterions import FairseqCriterion, register_criterion
12
+
13
+
14
+ def compute_cross_entropy_loss(logits, targets, ignore_index=-100):
15
+ """
16
+ Function to compute the cross entropy loss. The default value of
17
+ ignore_index is the same as the default value for F.cross_entropy in
18
+ pytorch.
19
+ """
20
+ assert logits.size(0) == targets.size(
21
+ -1
22
+ ), "Logits and Targets tensor shapes don't match up"
23
+
24
+ loss = F.nll_loss(
25
+ F.log_softmax(logits, -1, dtype=torch.float32),
26
+ targets,
27
+ reduction="sum",
28
+ ignore_index=ignore_index,
29
+ )
30
+ return loss
31
+
32
+
33
+ @register_criterion("legacy_masked_lm_loss")
34
+ class LegacyMaskedLmLoss(FairseqCriterion):
35
+ """
36
+ Implementation for the loss used in masked language model (MLM) training.
37
+ This optionally also computes the next sentence prediction (NSP) loss and
38
+ adds it to the overall loss based on the specified args. There are three
39
+ cases to consider:
40
+ 1) Generic MLM training without NSP loss. In this case sentence_targets
41
+ and sentence_logits are both None.
42
+ 2) BERT training without NSP loss. In this case sentence_targets is
43
+ not None but sentence_logits is None and we should not be computing
44
+ a sentence level loss.
45
+ 3) BERT training with NSP loss. In this case both sentence_targets and
46
+ sentence_logits are not None and we should be computing a sentence
47
+ level loss. The weight of the sentence level loss is specified as
48
+ an argument.
49
+ """
50
+
51
+ def __init__(self, task, masked_lm_only, nsp_loss_weight):
52
+ super().__init__(task)
53
+ self.masked_lm_only = masked_lm_only
54
+ self.nsp_loss_weight = nsp_loss_weight
55
+
56
+ @staticmethod
57
+ def add_args(parser):
58
+ """Args for MaskedLM Loss"""
59
+ # Default for masked_lm_only is False so as to not break BERT training
60
+ parser.add_argument(
61
+ "--masked-lm-only",
62
+ default=False,
63
+ action="store_true",
64
+ help="compute MLM loss only",
65
+ )
66
+ parser.add_argument(
67
+ "--nsp-loss-weight",
68
+ default=1.0,
69
+ type=float,
70
+ help="weight for next sentence prediction" " loss (default 1)",
71
+ )
72
+
73
+ def forward(self, model, sample, reduce=True):
74
+ """Compute the loss for the given sample.
75
+ Returns a tuple with three elements:
76
+ 1) the loss
77
+ 2) the sample size, which is used as the denominator for the gradient
78
+ 3) logging outputs to display while training
79
+ """
80
+ lm_logits, output_metadata = model(**sample["net_input"])
81
+
82
+ # reshape lm_logits from (N,T,C) to (N*T,C)
83
+ lm_logits = lm_logits.view(-1, lm_logits.size(-1))
84
+ lm_targets = sample["lm_target"].view(-1)
85
+ lm_loss = compute_cross_entropy_loss(lm_logits, lm_targets, self.padding_idx)
86
+
87
+ # compute the number of tokens for which loss is computed. This is used
88
+ # to normalize the loss
89
+ ntokens = utils.strip_pad(lm_targets, self.padding_idx).numel()
90
+ loss = lm_loss / ntokens
91
+ nsentences = sample["nsentences"]
92
+ # nsentences = 0
93
+
94
+ # Compute sentence loss if masked_lm_only is False
95
+ sentence_loss = None
96
+ if not self.masked_lm_only:
97
+ sentence_logits = output_metadata["sentence_logits"]
98
+ sentence_targets = sample["sentence_target"].view(-1)
99
+ # This needs to be recomputed due to some differences between
100
+ # TokenBlock and BlockPair dataset. This can be resolved with a
101
+ # refactor of BERTModel which we will do in the future.
102
+ # TODO: Remove this after refactor of BERTModel
103
+ nsentences = sentence_targets.size(0)
104
+
105
+ # Check for logits being none which can happen when remove_heads
106
+ # is set to true in the BERT model. Ideally we should set
107
+ # masked_lm_only to true in this case, but that requires some
108
+ # refactor in the BERT model.
109
+ if sentence_logits is not None:
110
+ sentence_loss = compute_cross_entropy_loss(
111
+ sentence_logits, sentence_targets
112
+ )
113
+
114
+ loss += self.nsp_loss_weight * (sentence_loss / nsentences)
115
+
116
+ # NOTE: as we are summing up per token mlm loss and per sentence nsp loss
117
+ # we don't need to use sample_size as denominator for the gradient
118
+ # here sample_size is just used for logging
119
+ sample_size = 1
120
+ logging_output = {
121
+ "loss": utils.item(loss.data) if reduce else loss.data,
122
+ "lm_loss": utils.item(lm_loss.data) if reduce else lm_loss.data,
123
+ # sentence loss is not always computed
124
+ "sentence_loss": (
125
+ (utils.item(sentence_loss.data) if reduce else sentence_loss.data)
126
+ if sentence_loss is not None
127
+ else 0.0
128
+ ),
129
+ "ntokens": ntokens,
130
+ "nsentences": nsentences,
131
+ "sample_size": sample_size,
132
+ }
133
+ return loss, sample_size, logging_output
134
+
135
+ @staticmethod
136
+ def reduce_metrics(logging_outputs) -> None:
137
+ """Aggregate logging outputs from data parallel training."""
138
+ lm_loss_sum = sum(log.get("lm_loss", 0) for log in logging_outputs)
139
+ sentence_loss_sum = sum(log.get("sentence_loss", 0) for log in logging_outputs)
140
+ ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
141
+ nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
142
+ sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
143
+ agg_loss = sum(log.get("loss", 0) for log in logging_outputs)
144
+
145
+ metrics.log_scalar(
146
+ "loss",
147
+ agg_loss / sample_size / math.log(2) if sample_size > 0 else 0.0,
148
+ sample_size,
149
+ round=3,
150
+ )
151
+ metrics.log_scalar(
152
+ "lm_loss",
153
+ lm_loss_sum / ntokens / math.log(2) if ntokens > 0 else 0.0,
154
+ ntokens,
155
+ round=3,
156
+ )
157
+ metrics.log_scalar(
158
+ "sentence_loss",
159
+ sentence_loss_sum / nsentences / math.log(2) if nsentences > 0 else 0.0,
160
+ nsentences,
161
+ round=3,
162
+ )
163
+ metrics.log_scalar(
164
+ "nll_loss",
165
+ lm_loss_sum / ntokens / math.log(2) if ntokens > 0 else 0.0,
166
+ ntokens,
167
+ round=3,
168
+ )
169
+
170
+ @staticmethod
171
+ def logging_outputs_can_be_summed() -> bool:
172
+ """
173
+ Whether the logging outputs returned by `forward` can be summed
174
+ across workers prior to calling `reduce_metrics`. Setting this
175
+ to True will improves distributed training speed.
176
+ """
177
+ return True
fairseq-0.10.2/fairseq/criterions/sentence_prediction.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+
8
+ import torch
9
+ import torch.nn.functional as F
10
+ from fairseq import metrics, utils
11
+ from fairseq.criterions import FairseqCriterion, register_criterion
12
+
13
+
14
@register_criterion("sentence_prediction")
class SentencePredictionCriterion(FairseqCriterion):
    """Criterion for sentence-level prediction tasks.

    Computes a summed negative log-likelihood over the logits of a named
    classification head, or a summed mean-squared error when
    ``regression_target`` is set.
    """

    def __init__(self, task, classification_head_name, regression_target):
        super().__init__(task)
        self.classification_head_name = classification_head_name
        self.regression_target = regression_target

    @staticmethod
    def add_args(parser):
        # fmt: off
        parser.add_argument('--classification-head-name',
                            default='sentence_classification_head',
                            help='name of the classification head to use')
        # fmt: on

    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        assert (
            hasattr(model, "classification_heads")
            and self.classification_head_name in model.classification_heads
        ), "model must provide sentence classification head for --criterion=sentence_prediction"

        logits, _ = model(
            **sample["net_input"],
            features_only=True,
            classification_head_name=self.classification_head_name,
        )
        targets = model.get_targets(sample, [logits]).view(-1)
        sample_size = targets.numel()

        if self.regression_target:
            # Regression: summed squared error, computed in fp32.
            flat_logits = logits.view(-1).float()
            targets = targets.float()
            loss = F.mse_loss(flat_logits, targets, reduction="sum")
        else:
            # Classification: summed NLL over an fp32 log-softmax for
            # numerical stability under mixed precision.
            log_probs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
            loss = F.nll_loss(log_probs, targets, reduction="sum")

        logging_output = {
            "loss": loss.data,
            "ntokens": sample["ntokens"],
            "nsentences": sample_size,
            "sample_size": sample_size,
        }
        if not self.regression_target:
            # Track accuracy only for classification (no notion of "correct"
            # for regression outputs).
            predictions = logits.argmax(dim=1)
            logging_output["ncorrect"] = (predictions == targets).sum()

        return loss, sample_size, logging_output

    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training."""

        def total(key):
            # Sum one logged quantity across workers; missing keys count as 0.
            return sum(log.get(key, 0) for log in logging_outputs)

        loss_sum = total("loss")
        ntokens = total("ntokens")
        nsentences = total("nsentences")
        sample_size = total("sample_size")

        # Report losses in bits (divide by log 2).
        metrics.log_scalar(
            "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
        )
        if sample_size != ntokens:
            metrics.log_scalar(
                "nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
            )

        if len(logging_outputs) > 0 and "ncorrect" in logging_outputs[0]:
            ncorrect = total("ncorrect")
            metrics.log_scalar(
                "accuracy", 100.0 * ncorrect / nsentences, nsentences, round=1
            )

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        """Return True: logging outputs are additive, so workers may sum
        them before `reduce_metrics`, improving distributed training speed.
        """
        return True
fairseq-0.10.2/fairseq/criterions/sentence_ranking.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+
8
+ import torch
9
+ import torch.nn.functional as F
10
+ from fairseq import metrics, utils
11
+ from fairseq.criterions import FairseqCriterion, register_criterion
12
+
13
+
14
@register_criterion("sentence_ranking")
class SentenceRankingCriterion(FairseqCriterion):
    """Ranking criterion.

    Scores ``num_classes`` candidate inputs (stored under keys
    ``net_input1`` .. ``net_input{num_classes}``) with a shared ranking head
    and trains with cross-entropy over the concatenated per-candidate scores.
    Optionally streams predictions to ``save_predictions``.
    """

    def __init__(self, task, ranking_head_name, save_predictions, num_classes):
        super().__init__(task)
        self.ranking_head_name = ranking_head_name
        if save_predictions is not None:
            self.prediction_h = open(save_predictions, "w")
        else:
            self.prediction_h = None
        self.num_classes = num_classes

    def __del__(self):
        # Fix: use getattr — if __init__ raised before prediction_h was
        # assigned (e.g. open() failed), __del__ still runs during GC and
        # must not raise AttributeError.
        prediction_h = getattr(self, "prediction_h", None)
        if prediction_h is not None:
            prediction_h.close()

    @staticmethod
    def add_args(parser):
        # fmt: off
        parser.add_argument('--save-predictions', metavar='FILE',
                            help='file to save predictions to')
        parser.add_argument('--ranking-head-name',
                            default='sentence_classification_head',
                            help='name of the ranking head to use')
        # fmt: on

    def forward(self, model, sample, reduce=True):
        """Compute ranking loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        assert (
            hasattr(model, "classification_heads")
            and self.ranking_head_name in model.classification_heads
        ), "model must provide sentence ranking head for --criterion=sentence_ranking"

        # One forward pass per candidate, all through the same ranking head.
        scores = []
        for idx in range(self.num_classes):
            score, _ = model(
                **sample["net_input{idx}".format(idx=idx + 1)],
                classification_head_name=self.ranking_head_name,
            )
            scores.append(score)

        # (batch, num_classes) score matrix; one row per example.
        logits = torch.cat(scores, dim=1)
        sample_size = logits.size(0)

        if "target" in sample:
            targets = model.get_targets(sample, [logits]).view(-1)
            # fp32 log-softmax for numerical stability under mixed precision.
            lprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
            loss = F.nll_loss(lprobs, targets, reduction="sum")
        else:
            # Inference without gold labels: dummy zero loss that still
            # participates in autograd so callers can backprop uniformly.
            targets = None
            loss = torch.tensor(0.0, requires_grad=True)

        if self.prediction_h is not None:
            preds = logits.argmax(dim=1)
            for i, (id, pred) in enumerate(zip(sample["id"].tolist(), preds.tolist())):
                if targets is not None:
                    label = targets[i].item()
                    print("{}\t{}\t{}".format(id, pred, label), file=self.prediction_h)
                else:
                    print("{}\t{}".format(id, pred), file=self.prediction_h)

        logging_output = {
            "loss": loss.data,
            "ntokens": sample["ntokens"],
            "nsentences": sample_size,
            "sample_size": sample_size,
        }
        if targets is not None:
            logging_output["ncorrect"] = (logits.argmax(dim=1) == targets).sum()

        return loss, sample_size, logging_output

    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training."""
        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
        ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
        nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)

        # Report losses in bits (divide by log 2).
        metrics.log_scalar(
            "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
        )
        if sample_size != ntokens:
            metrics.log_scalar(
                "nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
            )

        if len(logging_outputs) > 0 and "ncorrect" in logging_outputs[0]:
            ncorrect = sum(log.get("ncorrect", 0) for log in logging_outputs)
            metrics.log_scalar(
                "accuracy", 100.0 * ncorrect / nsentences, nsentences, round=1
            )

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        """Return True: logging outputs are additive, so workers may sum
        them before `reduce_metrics`, improving distributed training speed.
        """
        return True
fairseq-0.10.2/fairseq/model_parallel/models/__pycache__/transformer.cpython-310.pyc ADDED
Binary file (4.05 kB). View file
 
fairseq-0.10.2/fairseq/model_parallel/models/pipeline_parallel_transformer/model.py ADDED
@@ -0,0 +1,721 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import logging
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from fairseq import utils
12
+ from fairseq.model_parallel.models.pipeline_parallel_transformer.layers import (
13
+ Embedding,
14
+ TransformerDecoderEmbedding,
15
+ TransformerDecoderLayer,
16
+ TransformerDecoderOutputLayer,
17
+ TransformerEncoderEmbedding,
18
+ TransformerEncoderLayer,
19
+ TransformerEncoderLayerNorm,
20
+ )
21
+ from fairseq.models import (
22
+ BaseFairseqModel,
23
+ FairseqDecoder,
24
+ FairseqEncoder,
25
+ register_model,
26
+ register_model_architecture,
27
+ )
28
+ from fairseq.models.fairseq_encoder import EncoderOut
29
+ from fairseq.models.transformer import (
30
+ base_architecture,
31
+ transformer_iwslt_de_en,
32
+ transformer_wmt_en_de_big,
33
+ )
34
+ from fairseq.modules import SinusoidalPositionalEmbedding
35
+
36
+
37
+ logger = logging.getLogger(__name__)
38
+
39
+
40
+ DEFAULT_MAX_SOURCE_POSITIONS = 1024
41
+ DEFAULT_MAX_TARGET_POSITIONS = 1024
42
+
43
+
44
@register_model("pipeline_parallel_transformer")
class PipelineParallelTransformerModel(BaseFairseqModel):
    """Transformer trained with pipeline model parallelism.

    The encoder and decoder sub-modules are flattened into a single
    ``nn.Sequential`` and wrapped in fairscale's ``Pipe``, which partitions
    them across ``devices`` according to ``balance`` and micro-batches each
    input into ``chunks`` pieces.
    """

    def __init__(self, encoder, decoder, balance, devices, chunks, checkpoint):
        try:
            from fairscale.nn import Pipe
        except ImportError:
            raise ImportError("Please install fairscale with: pip install fairscale")
        super().__init__()
        assert isinstance(encoder, FairseqEncoder)
        assert isinstance(decoder, FairseqDecoder)
        # Flatten into one list: [enc embed, enc layers..., enc layer-norm,
        # dec embed, dec layers..., dec output]. The counts are remembered so
        # prepare_for_inference_() can split the partitions back apart.
        encoder_module_list = (
            [encoder.embedding_layer]
            + list(encoder.encoder_layers)
            + [encoder.final_layer_norm]
        )
        self.num_encoder_modules = len(encoder_module_list)
        decoder_module_list = (
            [decoder.embedding_layer]
            + list(decoder.decoder_layers)
            + [decoder.decoder_output_layer]
        )
        self.num_decoder_modules = len(decoder_module_list)
        module_list = encoder_module_list + decoder_module_list
        self.devices = devices
        self.model = Pipe(
            nn.Sequential(*module_list),
            balance=balance,
            devices=devices,
            chunks=chunks,
            checkpoint=checkpoint,
        )
        self.encoder_max_positions = self.max_positions_helper(
            encoder.embedding_layer, "max_source_positions"
        )
        self.decoder_max_positions = self.max_positions_helper(
            decoder.embedding_layer, "max_target_positions"
        )
        self.adaptive_softmax = getattr(decoder, "adaptive_softmax", None)
        # Note: To be populated during inference
        self.encoder = None
        self.decoder = None

    def forward(self, src_tokens, src_lengths, prev_output_tokens):
        """Run the full pipeline (training) or the unpacked encoder/decoder
        pair (inference, after ``prepare_for_inference_()``)."""
        if self.training:
            input_lst = [src_tokens, src_lengths, prev_output_tokens]
            input = tuple(i.to(self.devices[0], non_blocking=True) for i in input_lst)
            return self.model(input)
        else:
            assert self.encoder is not None and self.decoder is not None, (
                "encoder and decoder need to be initialized by "
                + "calling the `prepare_for_inference_()` method"
            )
            # Fix: previously this called `self.encoder(input)`, but no local
            # named `input` exists in this branch — it resolved to the
            # `input()` builtin and crashed at inference time. Pass the real
            # inputs, matching TransformerEncoder.forward(src_tokens,
            # src_lengths).
            encoder_output_tuple = self.encoder(src_tokens, src_lengths)
            return self.decoder(encoder_output_tuple)

    def prepare_for_inference_(self, args):
        """Unpack the Pipe partitions back into standalone encoder/decoder
        modules so inference can run without the pipeline wrapper."""
        if self.encoder is not None and self.decoder is not None:
            logger.info("Encoder and Decoder already initialized")
            return
        encoder_module_list = []
        decoder_module_list = []
        module_count = 0
        # Partitions preserve the original flattened order, so the first
        # num_encoder_modules modules belong to the encoder.
        for partition in self.model.partitions:
            for module in partition:
                if module_count < self.num_encoder_modules:
                    encoder_module_list.append(module)
                else:
                    decoder_module_list.append(module)
                module_count += 1
        self.model = None
        self.encoder = TransformerEncoder(args, None, None, encoder_module_list)
        self.decoder = TransformerDecoder(
            args, None, None, decoder_module_list=decoder_module_list
        )

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        # fmt: off
        parser.add_argument('--activation-fn',
                            choices=utils.get_available_activation_fns(),
                            help='activation function to use')
        parser.add_argument('--dropout', type=float, metavar='D',
                            help='dropout probability')
        parser.add_argument('--attention-dropout', type=float, metavar='D',
                            help='dropout probability for attention weights')
        parser.add_argument('--activation-dropout', '--relu-dropout', type=float, metavar='D',
                            help='dropout probability after activation in FFN.')
        parser.add_argument('--encoder-embed-path', type=str, metavar='STR',
                            help='path to pre-trained encoder embedding')
        parser.add_argument('--encoder-embed-dim', type=int, metavar='N',
                            help='encoder embedding dimension')
        parser.add_argument('--encoder-ffn-embed-dim', type=int, metavar='N',
                            help='encoder embedding dimension for FFN')
        parser.add_argument('--encoder-layers', type=int, metavar='N',
                            help='num encoder layers')
        parser.add_argument('--encoder-attention-heads', type=int, metavar='N',
                            help='num encoder attention heads')
        parser.add_argument('--encoder-normalize-before', action='store_true',
                            help='apply layernorm before each encoder block')
        parser.add_argument('--encoder-learned-pos', action='store_true',
                            help='use learned positional embeddings in the encoder')
        parser.add_argument('--decoder-embed-path', type=str, metavar='STR',
                            help='path to pre-trained decoder embedding')
        parser.add_argument('--decoder-embed-dim', type=int, metavar='N',
                            help='decoder embedding dimension')
        parser.add_argument('--decoder-ffn-embed-dim', type=int, metavar='N',
                            help='decoder embedding dimension for FFN')
        parser.add_argument('--decoder-layers', type=int, metavar='N',
                            help='num decoder layers')
        parser.add_argument('--decoder-attention-heads', type=int, metavar='N',
                            help='num decoder attention heads')
        parser.add_argument('--decoder-learned-pos', action='store_true',
                            help='use learned positional embeddings in the decoder')
        parser.add_argument('--decoder-normalize-before', action='store_true',
                            help='apply layernorm before each decoder block')
        parser.add_argument('--share-decoder-input-output-embed', action='store_true',
                            help='share decoder input and output embeddings')
        parser.add_argument('--share-all-embeddings', action='store_true',
                            help='share encoder, decoder and output embeddings'
                                 ' (requires shared dictionary and embed dim)')
        parser.add_argument('--no-token-positional-embeddings', default=False, action='store_true',
                            help='if set, disables positional embeddings (outside self attention)')
        # Cleanup: a stray trailing comma after this call (which made the
        # statement a harmless tuple expression) has been removed.
        parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR',
                            help='comma separated list of adaptive softmax cutoff points. '
                                 'Must be used with adaptive_loss criterion')
        parser.add_argument('--adaptive-softmax-dropout', type=float, metavar='D',
                            help='sets adaptive softmax dropout for the tail projections')
        parser.add_argument('--num-embedding-chunks', type=int, metavar='N', default=1,
                            help='Number of embedding layer chunks (enables more even distribution'
                                 'of optimizer states across data parallel nodes'
                                 'when using optimizer state sharding and'
                                 'a big embedding vocabulary)')
        # fmt: on

    @classmethod
    def build_model_base(cls, args, task):
        """Build a new model instance."""

        # make sure all arguments are present in older models
        base_architecture(args)

        if not hasattr(args, "max_source_positions"):
            args.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS
        if not hasattr(args, "max_target_positions"):
            args.max_target_positions = DEFAULT_MAX_TARGET_POSITIONS

        src_dict, tgt_dict = task.source_dictionary, task.target_dictionary

        def build_embedding(dictionary, embed_dim, path=None, num_embed_chunks=1):
            # Either a single Embedding (optionally pre-loaded from `path`) or
            # an nn.ModuleList of num_embed_chunks equal-width chunks.
            assert embed_dim % num_embed_chunks == 0, (
                f"Number of embedding chunks = {num_embed_chunks} should be "
                + f"divisible by the embedding dimension = {embed_dim}"
            )
            assert path is None or num_embed_chunks == 1, (
                "Loading embedding from a path with number of embedding chunks > 1"
                + " is not yet supported"
            )
            num_embeddings = len(dictionary)
            padding_idx = dictionary.pad()
            # if provided, load from preloaded dictionaries
            if path:
                emb = Embedding(num_embeddings, embed_dim, padding_idx)
                embed_dict = utils.parse_embedding(path)
                utils.load_embedding(embed_dict, dictionary, emb)
            else:
                embed_chunk_dim = embed_dim // num_embed_chunks
                emb = nn.ModuleList()
                for i in range(num_embed_chunks):
                    emb.append(Embedding(num_embeddings, embed_chunk_dim, padding_idx))
            return emb

        num_embed_chunks = args.num_embedding_chunks
        if args.share_all_embeddings:
            if src_dict != tgt_dict:
                raise ValueError("--share-all-embeddings requires a joined dictionary")
            if args.encoder_embed_dim != args.decoder_embed_dim:
                raise ValueError(
                    "--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim"
                )
            if args.decoder_embed_path and (
                args.decoder_embed_path != args.encoder_embed_path
            ):
                raise ValueError(
                    "--share-all-embeddings not compatible with --decoder-embed-path"
                )
            encoder_embed_tokens = build_embedding(
                src_dict,
                args.encoder_embed_dim,
                args.encoder_embed_path,
                num_embed_chunks,
            )
            decoder_embed_tokens = encoder_embed_tokens
            args.share_decoder_input_output_embed = True
        else:
            assert args.share_decoder_input_output_embed or num_embed_chunks == 1, (
                "Not sharing decoder I/O embeddings is not yet supported with number of "
                + "embedding chunks > 1"
            )
            encoder_embed_tokens = build_embedding(
                src_dict,
                args.encoder_embed_dim,
                args.encoder_embed_path,
                num_embed_chunks,
            )
            decoder_embed_tokens = build_embedding(
                tgt_dict,
                args.decoder_embed_dim,
                args.decoder_embed_path,
                num_embed_chunks,
            )

        encoder = cls.build_encoder(args, src_dict, encoder_embed_tokens)
        decoder = cls.build_decoder(args, tgt_dict, decoder_embed_tokens)
        return (encoder, decoder)

    @classmethod
    def build_encoder(cls, args, src_dict, embed_tokens):
        return TransformerEncoder(args, src_dict, embed_tokens)

    @classmethod
    def build_decoder(cls, args, tgt_dict, embed_tokens):
        return TransformerDecoder(args, tgt_dict, embed_tokens)

    @classmethod
    def build_model(cls, args, task):
        encoder, decoder = cls.build_model_base(args, task)
        return PipelineParallelTransformerModel(
            encoder=encoder,
            decoder=decoder,
            balance=utils.eval_str_list(args.pipeline_balance, type=int),
            devices=utils.eval_str_list(args.pipeline_devices, type=int),
            chunks=args.pipeline_chunks,
            checkpoint=args.pipeline_checkpoint,
        )

    def output_layer(self, features, **kwargs):
        """Project features to the default output size (typically vocabulary size)."""
        return self.decoder.output_layer(features, **kwargs)

    def max_positions(self):
        """Maximum length supported by the model."""
        return (self.encoder_max_positions, self.decoder_max_positions)

    def max_positions_helper(
        self, embedding_layer, max_positions_field="max_source_positions"
    ):
        """Maximum input length supported by the encoder or decoder."""
        if embedding_layer.embed_positions is None:
            return getattr(embedding_layer, max_positions_field)
        return min(
            getattr(embedding_layer, max_positions_field),
            embedding_layer.embed_positions.max_positions,
        )

    def get_normalized_probs(self, net_output, log_probs, sample=None):
        """Get normalized probabilities (or log probs) from a net's output."""

        if hasattr(self, "adaptive_softmax") and self.adaptive_softmax is not None:
            if sample is not None:
                assert "target" in sample
                target = sample["target"]
            else:
                target = None
            out = self.adaptive_softmax.get_log_prob(net_output, target=target)
            return out.exp_() if not log_probs else out

        # A Pipe() module returns a tuple of tensors as the output.
        # In this case, the tuple has one element - the output tensor of logits
        logits = net_output if isinstance(net_output, torch.Tensor) else net_output[0]
        if log_probs:
            return utils.log_softmax(logits, dim=-1, onnx_trace=False)
        else:
            return utils.softmax(logits, dim=-1, onnx_trace=False)

    def max_decoder_positions(self):
        """Maximum length supported by the decoder."""
        return self.decoder_max_positions

    def load_state_dict(self, state_dict, strict=True, args=None):
        """Copies parameters and buffers from *state_dict* into this module and
        its descendants.

        Overrides the method in :class:`nn.Module`. Compared with that method
        this additionally "upgrades" *state_dicts* from old checkpoints.
        """
        self.upgrade_state_dict(state_dict)
        # Checkpoints saved from a non-pipeline transformer use
        # encoder.*/decoder.* keys and must be remapped onto the partitions.
        is_regular_transformer = not any("model.partitions" in k for k in state_dict)
        if is_regular_transformer:
            state_dict = self.convert_to_pipeline_parallel_state_dict(state_dict)
        return super().load_state_dict(state_dict, strict)

    def convert_to_pipeline_parallel_state_dict(self, state_dict):
        """Remap a regular transformer checkpoint onto
        ``model.partitions.{pid}.{mid}.*`` keys by walking the partitions in
        order and matching each module type to its source keys."""
        new_state_dict = self.state_dict()
        encoder_layer_idx = 0
        decoder_layer_idx = 0
        encoder_key_suffixes = [
            "self_attn.k_proj.weight",
            "self_attn.k_proj.bias",
            "self_attn.v_proj.weight",
            "self_attn.v_proj.bias",
            "self_attn.q_proj.weight",
            "self_attn.q_proj.bias",
            "self_attn.out_proj.weight",
            "self_attn.out_proj.bias",
            "self_attn_layer_norm.weight",
            "self_attn_layer_norm.bias",
            "fc1.weight",
            "fc1.bias",
            "fc2.weight",
            "fc2.bias",
            "final_layer_norm.weight",
            "final_layer_norm.bias",
        ]
        decoder_key_suffixes = [
            "self_attn.k_proj.weight",
            "self_attn.k_proj.bias",
            "self_attn.v_proj.weight",
            "self_attn.v_proj.bias",
            "self_attn.q_proj.weight",
            "self_attn.q_proj.bias",
            "self_attn.out_proj.weight",
            "self_attn.out_proj.bias",
            "self_attn_layer_norm.weight",
            "self_attn_layer_norm.bias",
            "encoder_attn.k_proj.weight",
            "encoder_attn.k_proj.bias",
            "encoder_attn.v_proj.weight",
            "encoder_attn.v_proj.bias",
            "encoder_attn.q_proj.weight",
            "encoder_attn.q_proj.bias",
            "encoder_attn.out_proj.weight",
            "encoder_attn.out_proj.bias",
            "encoder_attn_layer_norm.weight",
            "encoder_attn_layer_norm.bias",
            "fc1.weight",
            "fc1.bias",
            "fc2.weight",
            "fc2.bias",
            "final_layer_norm.weight",
            "final_layer_norm.bias",
        ]
        for pid, partition in enumerate(self.model.partitions):
            logger.info(f"Begin Partition {pid}")
            for mid, module in enumerate(partition):
                # fmt: off
                if isinstance(module, TransformerEncoderEmbedding):
                    new_state_dict[f'model.partitions.{pid}.{mid}.embed_tokens.weight'] = state_dict['encoder.embed_tokens.weight']
                    new_state_dict[f'model.partitions.{pid}.{mid}.embed_positions._float_tensor'] = state_dict['encoder.embed_positions._float_tensor']
                if isinstance(module, TransformerEncoderLayer):
                    for suffix in encoder_key_suffixes:
                        new_state_dict[f'model.partitions.{pid}.{mid}.{suffix}'] = state_dict[f'encoder.layers.{encoder_layer_idx}.{suffix}']
                    encoder_layer_idx += 1
                if isinstance(module, TransformerDecoderLayer):
                    for suffix in decoder_key_suffixes:
                        new_state_dict[f'model.partitions.{pid}.{mid}.{suffix}'] = state_dict[f'decoder.layers.{decoder_layer_idx}.{suffix}']
                    decoder_layer_idx += 1
                if isinstance(module, TransformerEncoderLayerNorm):
                    if 'encoder.layer_norm.weight' in state_dict:
                        new_state_dict[f'model.partitions.{pid}.{mid}.layer_norm.weight'] = state_dict['encoder.layer_norm.weight']
                        new_state_dict[f'model.partitions.{pid}.{mid}.layer_norm.bias'] = state_dict['encoder.layer_norm.bias']
                if isinstance(module, TransformerDecoderEmbedding):
                    new_state_dict[f'model.partitions.{pid}.{mid}.embed_tokens.weight'] = state_dict['decoder.embed_tokens.weight']
                    new_state_dict[f'model.partitions.{pid}.{mid}.embed_positions._float_tensor'] = state_dict['decoder.embed_positions._float_tensor']
                if isinstance(module, TransformerDecoderOutputLayer):
                    new_state_dict[f'model.partitions.{pid}.{mid}.output_projection.weight'] = state_dict['decoder.output_projection.weight']
                # fmt: on
        return new_state_dict
412
+
413
+
414
class TransformerEncoder(FairseqEncoder):
    """
    Transformer encoder consisting of *args.encoder_layers* layers. Each layer
    is a :class:`TransformerEncoderLayer`.

    Can run either as a plain sequential stack or, when
    ``args.pipeline_encoder_balance`` is set, wrapped in fairscale's ``Pipe``
    for pipeline parallelism.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        dictionary (~fairseq.data.Dictionary): encoding dictionary
        embed_tokens (torch.nn.Embedding): input embedding
        encoder_module_list (list, optional): pre-built
            [embedding, layers..., layer-norm] modules (used when unpacking a
            trained pipeline model); built from *args* when None.
    """

    def __init__(self, args, dictionary, embed_tokens, encoder_module_list=None):
        super().__init__(dictionary)
        self.register_buffer("version", torch.Tensor([3]))
        try:
            from fairscale.nn import Pipe
        except ImportError:
            raise ImportError("Please install fairscale with: pip install fairscale")
        if encoder_module_list is None:
            # Build the module chain from scratch:
            # embedding -> encoder layers -> final layer norm.
            embedding_layer = TransformerEncoderEmbedding(args, embed_tokens)
            layers = [TransformerEncoderLayer(args) for i in range(args.encoder_layers)]
            # Chunked embeddings (nn.ModuleList) contribute their summed width.
            if isinstance(embed_tokens, nn.ModuleList):
                emb_dim = sum(e.embedding_dim for e in embed_tokens)
            else:
                emb_dim = embed_tokens.embedding_dim
            final_layer_norm = TransformerEncoderLayerNorm(args, emb_dim)
            encoder_module_list = [embedding_layer] + layers + [final_layer_norm]
        self.use_pipeline = getattr(args, "pipeline_encoder_balance", None) is not None
        if self.use_pipeline:
            encoder_balance = utils.eval_str_list(
                args.pipeline_encoder_balance, type=int
            )
            encoder_devices = utils.eval_str_list(
                args.pipeline_encoder_devices, type=int
            )
            assert sum(encoder_balance) == len(encoder_module_list), (
                f"Sum of encoder_balance={encoder_balance} is not equal "
                + f"to num_encoder_modules={len(encoder_module_list)}"
            )
            self.model = Pipe(
                module=nn.Sequential(*encoder_module_list),
                balance=encoder_balance,
                devices=encoder_devices,
                chunks=args.pipeline_chunks,
                checkpoint=args.pipeline_checkpoint,
            )
        else:
            # Non-pipelined path: keep the pieces addressable for forward().
            self.embedding_layer = encoder_module_list[0]
            self.encoder_layers = nn.Sequential(*encoder_module_list[1:-1])
            self.final_layer_norm = encoder_module_list[-1]

    def forward(self, src_tokens, src_lengths):
        """
        Args:
            input_tuple(
                src_tokens (LongTensor): tokens in the source language of shape
                    `(batch, src_len)`
                src_lengths (torch.LongTensor): lengths of each source sentence of
                    shape `(batch)`
            )

        Returns:
            output_tuple(
                - **encoder_out** (Tensor): the last encoder layer's output of
                  shape `(src_len, batch, embed_dim)`
                - **encoder_padding_mask** (ByteTensor): the positions of
                  padding elements of shape `(batch, src_len)`
                - prev_output_tokens
                - **encoder_states** (List[Tensor]): all intermediate
                  hidden states of shape `(src_len, batch, embed_dim)`.
                  Only populated if *return_all_hiddens* is True.
            )
        """
        # The module chain expects a 3-tuple (it mirrors the full-model
        # pipeline input), so a dummy prev_output_tokens placeholder is added.
        dummy_prev_output_tokens = torch.zeros(
            1, dtype=src_tokens.dtype, device=src_tokens.device
        )
        input_tuple = (src_tokens, src_lengths, dummy_prev_output_tokens)
        if self.use_pipeline:
            input_tuple = tuple(i.to(self.model.devices[0]) for i in input_tuple)
            encoder_out = self.model(input_tuple)
        else:
            encoder_embed_output_tuple = self.embedding_layer(input_tuple)
            encoder_layers_output = self.encoder_layers(encoder_embed_output_tuple)
            encoder_out = self.final_layer_norm(encoder_layers_output)
        # first element is the encoder output
        # second element is the encoder padding mask
        # the remaining elements of EncoderOut are not computed by
        # the PipelineParallelTransformer
        return EncoderOut(encoder_out[0], encoder_out[1], None, None, None, None)

    def reorder_encoder_out(self, encoder_out, new_order):
        """
        Reorder encoder output according to *new_order*.

        Args:
            encoder_out: output from the ``forward()`` method
            new_order (LongTensor): desired order

        Returns:
            *encoder_out* rearranged according to *new_order*
        """
        # encoder_out is time-major (src_len, batch, dim) -> select dim 1;
        # the masks/embeddings are batch-major -> select dim 0.
        if encoder_out.encoder_out is not None:
            encoder_out = encoder_out._replace(
                encoder_out=encoder_out.encoder_out.index_select(1, new_order)
            )
        if encoder_out.encoder_padding_mask is not None:
            encoder_out = encoder_out._replace(
                encoder_padding_mask=encoder_out.encoder_padding_mask.index_select(
                    0, new_order
                )
            )
        if encoder_out.encoder_embedding is not None:
            encoder_out = encoder_out._replace(
                encoder_embedding=encoder_out.encoder_embedding.index_select(
                    0, new_order
                )
            )
        if encoder_out.encoder_states is not None:
            for idx, state in enumerate(encoder_out.encoder_states):
                encoder_out.encoder_states[idx] = state.index_select(1, new_order)
        return encoder_out

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        if self.embedding_layer.embed_positions is None:
            return self.embedding_layer.max_source_positions
        return min(
            self.embedding_layer.max_source_positions,
            self.embedding_layer.embed_positions.max_positions,
        )
+ )
544
+
545
+
546
+ class TransformerDecoder(FairseqDecoder):
547
+ """
548
+ Transformer decoder consisting of *args.decoder_layers* layers. Each layer
549
+ is a :class:`TransformerDecoderLayer`.
550
+
551
+ Args:
552
+ args (argparse.Namespace): parsed command-line arguments
553
+ dictionary (~fairseq.data.Dictionary): decoding dictionary
554
+ embed_tokens (torch.nn.Embedding): output embedding
555
+ no_encoder_attn (bool, optional): whether to attend to encoder outputs
556
+ (default: False).
557
+ """
558
+
559
+ def __init__(
560
+ self,
561
+ args,
562
+ dictionary,
563
+ embed_tokens,
564
+ no_encoder_attn=False,
565
+ decoder_module_list=None,
566
+ ):
567
+ super().__init__(dictionary)
568
+ self.register_buffer("version", torch.Tensor([3]))
569
+ try:
570
+ from fairscale.nn import Pipe
571
+ except ImportError:
572
+ raise ImportError("Please install fairscale with: pip install fairscale")
573
+ if decoder_module_list is None:
574
+ embedding_layer = TransformerDecoderEmbedding(args, embed_tokens)
575
+ layers = [
576
+ TransformerDecoderLayer(args, no_encoder_attn)
577
+ for _ in range(args.decoder_layers)
578
+ ]
579
+ decoder_output_layer = TransformerDecoderOutputLayer(
580
+ args, embed_tokens, dictionary
581
+ )
582
+ decoder_module_list = [embedding_layer] + layers + [decoder_output_layer]
583
+ self.use_pipeline = getattr(args, "pipeline_decoder_balance", None) is not None
584
+ if self.use_pipeline:
585
+ decoder_balance = utils.eval_str_list(
586
+ args.pipeline_decoder_balance, type=int
587
+ )
588
+ decoder_devices = utils.eval_str_list(
589
+ args.pipeline_decoder_devices, type=int
590
+ )
591
+ assert sum(decoder_balance) == len(decoder_module_list), (
592
+ f"Sum of decoder_balance={decoder_balance} is not equal "
593
+ + f"to num_decoder_modules={len(decoder_module_list)}"
594
+ )
595
+ self.model = Pipe(
596
+ module=nn.Sequential(*decoder_module_list),
597
+ balance=decoder_balance,
598
+ devices=decoder_devices,
599
+ chunks=args.pipeline_chunks,
600
+ checkpoint=args.pipeline_checkpoint,
601
+ )
602
+ else:
603
+ self.embedding_layer = decoder_module_list[0]
604
+ self.decoder_layers = nn.Sequential(*decoder_module_list[1:-1])
605
+ self.decoder_output_layer = decoder_module_list[-1]
606
+
607
+ def forward(
608
+ self,
609
+ prev_output_tokens,
610
+ encoder_out=None,
611
+ ):
612
+ """
613
+ Args:
614
+ prev_output_tokens (LongTensor): previous decoder outputs of shape
615
+ `(batch, tgt_len)`, for teacher forcing
616
+ encoder_out (optional): output from the encoder, used for
617
+ encoder-side attention
618
+ incremental_state (dict): dictionary used for storing state during
619
+ :ref:`Incremental decoding`
620
+ features_only (bool, optional): only return features without
621
+ applying output layer (default: False).
622
+
623
+ Returns:
624
+ tuple:
625
+ - the decoder's output of shape `(batch, tgt_len, vocab)`
626
+ - a dictionary with any model-specific outputs
627
+ """
628
+ input_tuple = (
629
+ encoder_out.encoder_out,
630
+ encoder_out.encoder_padding_mask,
631
+ prev_output_tokens,
632
+ )
633
+ if self.use_pipeline:
634
+ input_tuple = tuple(i.to(self.model.devices[0]) for i in input_tuple)
635
+ return (self.model(input_tuple),)
636
+ else:
637
+ embed_layer_output = self.embedding_layer(input_tuple)
638
+ state = self.decoder_layers(embed_layer_output)
639
+ return (self.decoder_output_layer(state),)
640
+
641
+ def output_layer(self, features, **kwargs):
642
+ """Project features to the vocabulary size."""
643
+ if self.adaptive_softmax is None:
644
+ # project back to size of vocabulary
645
+ if self.share_input_output_embed:
646
+ return F.linear(features, self.embed_tokens.weight)
647
+ else:
648
+ return F.linear(features, self.embed_out)
649
+ else:
650
+ return features
651
+
652
+ def max_positions(self):
653
+ """Maximum output length supported by the decoder."""
654
+ if self.embedding_layer.embed_positions is None:
655
+ return self.embedding_layer.max_target_positions
656
+ return min(
657
+ self.embedding_layer.max_target_positions,
658
+ self.embedding_layer.embed_positions.max_positions,
659
+ )
660
+
661
+ def buffered_future_mask(self, tensor):
662
+ dim = tensor.size(0)
663
+ if (
664
+ not hasattr(self, "_future_mask")
665
+ or self._future_mask is None
666
+ or self._future_mask.device != tensor.device
667
+ or self._future_mask.size(0) < dim
668
+ ):
669
+ self._future_mask = torch.triu(
670
+ utils.fill_with_neg_inf(tensor.new(dim, dim)), 1
671
+ )
672
+ return self._future_mask[:dim, :dim]
673
+
674
+ def upgrade_state_dict_named(self, state_dict, name):
675
+ """Upgrade a (possibly old) state dict for new versions of fairseq."""
676
+ if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
677
+ weights_key = "{}.embed_positions.weights".format(name)
678
+ if weights_key in state_dict:
679
+ del state_dict[weights_key]
680
+ state_dict[
681
+ "{}.embed_positions._float_tensor".format(name)
682
+ ] = torch.FloatTensor(1)
683
+
684
+ for i in range(len(self.layers)):
685
+ # update layer norms
686
+ layer_norm_map = {
687
+ "0": "self_attn_layer_norm",
688
+ "1": "encoder_attn_layer_norm",
689
+ "2": "final_layer_norm",
690
+ }
691
+ for old, new in layer_norm_map.items():
692
+ for m in ("weight", "bias"):
693
+ k = "{}.layers.{}.layer_norms.{}.{}".format(name, i, old, m)
694
+ if k in state_dict:
695
+ state_dict[
696
+ "{}.layers.{}.{}.{}".format(name, i, new, m)
697
+ ] = state_dict[k]
698
+ del state_dict[k]
699
+
700
+ version_key = "{}.version".format(name)
701
+ if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) <= 2:
702
+ # earlier checkpoints did not normalize after the stack of layers
703
+ self.layer_norm = None
704
+ self.normalize = False
705
+ state_dict[version_key] = torch.Tensor([1])
706
+
707
+ return state_dict
708
+
709
+
710
+ @register_model_architecture(
711
+ "pipeline_parallel_transformer", "transformer_iwslt_de_en_pipeline_parallel"
712
+ )
713
+ def transformer_iwslt_de_en_dist(args):
714
+ transformer_iwslt_de_en(args)
715
+
716
+
717
+ @register_model_architecture(
718
+ "pipeline_parallel_transformer", "transformer_wmt_en_de_big_pipeline_parallel"
719
+ )
720
+ def transformer_wmt_en_de_big_dist(args):
721
+ transformer_wmt_en_de_big(args)
fairseq-0.10.2/fairseq/model_parallel/models/roberta/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from .model import * # noqa
fairseq-0.10.2/fairseq/model_parallel/models/roberta/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (208 Bytes). View file
 
fairseq-0.10.2/fairseq/models/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (5.58 kB). View file
 
fairseq-0.10.2/fairseq/models/__pycache__/distributed_fairseq_model.cpython-310.pyc ADDED
Binary file (3.06 kB). View file
 
fairseq-0.10.2/fairseq/models/__pycache__/fairseq_decoder.cpython-310.pyc ADDED
Binary file (3.41 kB). View file
 
fairseq-0.10.2/fairseq/models/__pycache__/fairseq_encoder.cpython-310.pyc ADDED
Binary file (3.59 kB). View file
 
fairseq-0.10.2/fairseq/models/__pycache__/fairseq_incremental_decoder.cpython-310.pyc ADDED
Binary file (4.8 kB). View file
 
fairseq-0.10.2/fairseq/models/__pycache__/fairseq_model.cpython-310.pyc ADDED
Binary file (20.6 kB). View file
 
fairseq-0.10.2/fairseq/models/__pycache__/fconv.cpython-310.pyc ADDED
Binary file (19.1 kB). View file
 
fairseq-0.10.2/fairseq/models/__pycache__/fconv_lm.cpython-310.pyc ADDED
Binary file (3.77 kB). View file
 
fairseq-0.10.2/fairseq/models/__pycache__/fconv_self_att.cpython-310.pyc ADDED
Binary file (16.3 kB). View file
 
fairseq-0.10.2/fairseq/models/__pycache__/lightconv.cpython-310.pyc ADDED
Binary file (25 kB). View file
 
fairseq-0.10.2/fairseq/models/__pycache__/lightconv_lm.cpython-310.pyc ADDED
Binary file (6.99 kB). View file
 
fairseq-0.10.2/fairseq/models/__pycache__/lstm.cpython-310.pyc ADDED
Binary file (18.6 kB). View file
 
fairseq-0.10.2/fairseq/models/__pycache__/lstm_lm.cpython-310.pyc ADDED
Binary file (4.34 kB). View file
 
fairseq-0.10.2/fairseq/models/__pycache__/masked_lm.cpython-310.pyc ADDED
Binary file (10.1 kB). View file
 
fairseq-0.10.2/fairseq/models/__pycache__/model_utils.cpython-310.pyc ADDED
Binary file (2.35 kB). View file
 
fairseq-0.10.2/fairseq/models/__pycache__/multilingual_transformer.cpython-310.pyc ADDED
Binary file (6.69 kB). View file
 
fairseq-0.10.2/fairseq/models/__pycache__/transformer.cpython-310.pyc ADDED
Binary file (27.7 kB). View file
 
fairseq-0.10.2/fairseq/models/__pycache__/transformer_lm.cpython-310.pyc ADDED
Binary file (10.6 kB). View file
 
fairseq-0.10.2/fairseq/models/bart/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from .hub_interface import * # noqa
7
+ from .model import * # noqa
fairseq-0.10.2/fairseq/models/bart/hub_interface.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import copy
7
+ import logging
8
+ from typing import List
9
+
10
+ import numpy as np
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ from fairseq import utils
15
+ from fairseq.data import encoders
16
+
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+
21
+ class BARTHubInterface(nn.Module):
22
+ """A simple PyTorch Hub interface to BART.
23
+
24
+ Usage: https://github.com/pytorch/fairseq/tree/master/examples/bart
25
+ """
26
+
27
+ def __init__(self, args, task, model):
28
+ super().__init__()
29
+ self.args = args
30
+ self.task = task
31
+ self.model = model
32
+
33
+ self.bpe = encoders.build_bpe(args)
34
+
35
+ self.max_positions = min(
36
+ utils.resolve_max_positions(
37
+ self.task.max_positions(),
38
+ self.model.max_positions(),
39
+ )
40
+ )
41
+
42
+ # this is useful for determining the device
43
+ self.register_buffer("_float_tensor", torch.tensor([0], dtype=torch.float))
44
+
45
+ @property
46
+ def device(self):
47
+ return self._float_tensor.device
48
+
49
+ def encode(
50
+ self, sentence: str, *addl_sentences, no_separator=True
51
+ ) -> torch.LongTensor:
52
+ """
53
+ BPE-encode a sentence (or multiple sentences).
54
+
55
+ Every sequence begins with a beginning-of-sentence (`<s>`) symbol.
56
+ Every sentence ends with an end-of-sentence (`</s>`).
57
+
58
+ Example (single sentence): `<s> a b c </s>`
59
+ Example (sentence pair): `<s> d e f </s> 1 2 3 </s>`
60
+
61
+ The BPE encoding follows GPT-2. One subtle detail is that the GPT-2 BPE
62
+ requires leading spaces. For example::
63
+
64
+ >>> bart.encode('Hello world').tolist()
65
+ [0, 31414, 232, 2]
66
+ >>> bart.encode(' world').tolist()
67
+ [0, 232, 2]
68
+ >>> bart.encode('world').tolist()
69
+ [0, 8331, 2]
70
+ """
71
+ tokens = self.bpe.encode(sentence)
72
+ if len(tokens.split(" ")) > self.max_positions - 2:
73
+ tokens = " ".join(tokens.split(" ")[: self.max_positions - 2])
74
+ bpe_sentence = "<s> " + tokens + " </s>"
75
+ for s in addl_sentences:
76
+ bpe_sentence += " </s>" if not no_separator else ""
77
+ bpe_sentence += " " + self.bpe.encode(s) + " </s>"
78
+ tokens = self.task.source_dictionary.encode_line(bpe_sentence, append_eos=False)
79
+ return tokens.long()
80
+
81
+ def decode(self, tokens: torch.LongTensor):
82
+ assert tokens.dim() == 1
83
+ tokens = tokens.cpu().numpy()
84
+ if tokens[0] == self.task.source_dictionary.bos():
85
+ tokens = tokens[1:] # remove <s>
86
+ eos_mask = tokens == self.task.source_dictionary.eos()
87
+ doc_mask = eos_mask[1:] & eos_mask[:-1]
88
+ sentences = np.split(tokens, doc_mask.nonzero()[0] + 1)
89
+ sentences = [
90
+ self.bpe.decode(self.task.source_dictionary.string(s)) for s in sentences
91
+ ]
92
+ if len(sentences) == 1:
93
+ return sentences[0]
94
+ return sentences
95
+
96
+ def _build_sample(self, src_tokens: List[torch.LongTensor]):
97
+ # assert torch.is_tensor(src_tokens)
98
+ dataset = self.task.build_dataset_for_inference(
99
+ src_tokens,
100
+ [x.numel() for x in src_tokens],
101
+ )
102
+ sample = dataset.collater(dataset)
103
+ sample = utils.apply_to_sample(lambda tensor: tensor.to(self.device), sample)
104
+ return sample
105
+
106
+ def sample(
107
+ self, sentences: List[str], beam: int = 1, verbose: bool = False, **kwargs
108
+ ) -> str:
109
+ input = [self.encode(sentence) for sentence in sentences]
110
+ hypos = self.generate(input, beam, verbose, **kwargs)
111
+ return [self.decode(x["tokens"]) for x in hypos]
112
+
113
+ def generate(
114
+ self,
115
+ tokens: List[torch.LongTensor],
116
+ beam: int = 5,
117
+ verbose: bool = False,
118
+ **kwargs
119
+ ) -> torch.LongTensor:
120
+ sample = self._build_sample(tokens)
121
+
122
+ # build generator using current args as well as any kwargs
123
+ gen_args = copy.copy(self.args)
124
+ gen_args.beam = beam
125
+ for k, v in kwargs.items():
126
+ setattr(gen_args, k, v)
127
+ generator = self.task.build_generator([self.model], gen_args)
128
+ translations = self.task.inference_step(
129
+ generator,
130
+ [self.model],
131
+ sample,
132
+ prefix_tokens=sample["net_input"]["src_tokens"]
133
+ .new_zeros((len(tokens), 1))
134
+ .fill_(self.task.source_dictionary.bos()),
135
+ )
136
+
137
+ if verbose:
138
+ src_str_with_unk = self.string(tokens)
139
+ logger.info("S\t{}".format(src_str_with_unk))
140
+
141
+ def getarg(name, default):
142
+ return getattr(gen_args, name, getattr(self.args, name, default))
143
+
144
+ # Process top predictions
145
+ hypos = [x[0] for x in translations]
146
+ hypos = [v for _, v in sorted(zip(sample["id"].tolist(), hypos))]
147
+ return hypos
148
+
149
+ def extract_features(
150
+ self, tokens: torch.LongTensor, return_all_hiddens: bool = False
151
+ ) -> torch.Tensor:
152
+ if tokens.dim() == 1:
153
+ tokens = tokens.unsqueeze(0)
154
+ if tokens.size(-1) > min(self.model.max_positions()):
155
+ raise ValueError(
156
+ "tokens exceeds maximum length: {} > {}".format(
157
+ tokens.size(-1), self.model.max_positions()
158
+ )
159
+ )
160
+ tokens.to(device=self.device),
161
+ prev_output_tokens = tokens.clone()
162
+
163
+ prev_output_tokens[:, 0] = tokens.gather(
164
+ 1,
165
+ (tokens.ne(self.task.source_dictionary.pad()).sum(dim=1) - 1).unsqueeze(-1),
166
+ ).squeeze()
167
+
168
+ prev_output_tokens[:, 1:] = tokens[:, :-1]
169
+ features, extra = self.model(
170
+ src_tokens=tokens,
171
+ src_lengths=None,
172
+ prev_output_tokens=prev_output_tokens,
173
+ features_only=True,
174
+ return_all_hiddens=return_all_hiddens,
175
+ )
176
+ if return_all_hiddens:
177
+ # convert from T x B x C -> B x T x C
178
+ inner_states = extra["inner_states"]
179
+ return [inner_state.transpose(0, 1) for inner_state in inner_states]
180
+ else:
181
+ return features # just the last layer's features
182
+
183
+ def register_classification_head(
184
+ self, name: str, num_classes: int = None, embedding_size: int = None, **kwargs
185
+ ):
186
+ self.model.register_classification_head(
187
+ name, num_classes=num_classes, embedding_size=embedding_size, **kwargs
188
+ )
189
+
190
+ def predict(self, head: str, tokens: torch.LongTensor, return_logits: bool = False):
191
+ if tokens.dim() == 1:
192
+ tokens = tokens.unsqueeze(0)
193
+ features = self.extract_features(tokens.to(device=self.device))
194
+ sentence_representation = features[
195
+ tokens.eq(self.task.source_dictionary.eos()), :
196
+ ].view(features.size(0), -1, features.size(-1))[:, -1, :]
197
+
198
+ logits = self.model.classification_heads[head](sentence_representation)
199
+ if return_logits:
200
+ return logits
201
+ return F.log_softmax(logits, dim=-1)
fairseq-0.10.2/fairseq/models/distributed_fairseq_model.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import inspect
7
+
8
+ import torch.nn as nn
9
+ from fairseq.legacy_distributed_data_parallel import LegacyDistributedDataParallel
10
+
11
+
12
+ _GOSSIP_DISABLED = False
13
+ try:
14
+ import gossip
15
+ except ImportError:
16
+ _GOSSIP_DISABLED = True
17
+
18
+
19
+ def DistributedFairseqModel(args, model, process_group=None):
20
+ """
21
+ Wrap a *model* to support distributed data parallel training.
22
+
23
+ This is similar to the built-in DistributedDataParallel, but allows
24
+ additional configuration of the DistributedDataParallel class to
25
+ use, and also provides easier access to the wrapped model by
26
+ forwarding requests for missing attributes to the wrapped model.
27
+
28
+ Args:
29
+ args (argparse.Namespace): fairseq args
30
+ model (BaseFairseqModel): model to wrap
31
+ """
32
+ # determine which DDP class to extend
33
+ assert isinstance(model, nn.Module)
34
+ if args.distributed_wrapper == "DDP" and args.ddp_backend == "c10d":
35
+ ddp_class = nn.parallel.DistributedDataParallel
36
+ init_kwargs = dict(
37
+ module=model,
38
+ device_ids=[args.device_id],
39
+ output_device=args.device_id,
40
+ broadcast_buffers=args.broadcast_buffers,
41
+ bucket_cap_mb=args.bucket_cap_mb,
42
+ process_group=process_group,
43
+ )
44
+ # Maintain backward compatibility
45
+ if "check_reduction" in inspect.getargspec(ddp_class)[0]:
46
+ init_kwargs["check_reduction"] = True
47
+ if "find_unused_parameters" in inspect.getargspec(ddp_class)[0]:
48
+ init_kwargs["find_unused_parameters"] = args.find_unused_parameters
49
+ elif args.distributed_wrapper == "DDP" and args.ddp_backend == "no_c10d":
50
+ ddp_class = LegacyDistributedDataParallel
51
+ init_kwargs = dict(
52
+ module=model,
53
+ world_size=args.distributed_world_size,
54
+ buffer_size=2 ** 28,
55
+ process_group=process_group,
56
+ )
57
+ elif args.distributed_wrapper == "SlowMo":
58
+ if _GOSSIP_DISABLED:
59
+ raise ImportError(
60
+ "Cannot find gossip library. Please install from: "
61
+ "github.com/facebookresearch/stochastic_gradient_push"
62
+ )
63
+ ddp_class = gossip.GossipDataParallel
64
+
65
+ # The values of slowmo_momentum below were obtained by tuning on the
66
+ # En-De 16 dataset by training the transformer_wmt_en_de_large model
67
+ if args.slowmo_momentum is None:
68
+ if args.distributed_world_size <= 16:
69
+ args.slowmo_momentum = 0.0
70
+ elif args.distributed_world_size <= 32:
71
+ args.slowmo_momentum = 0.2
72
+ elif args.distributed_world_size <= 64:
73
+ args.slowmo_momentum = 0.5
74
+ else:
75
+ args.slowmo_momentum = 0.6
76
+
77
+ init_kwargs = dict(
78
+ module=model,
79
+ device_ids=[args.device_id],
80
+ output_device=args.device_id,
81
+ broadcast_buffers=args.broadcast_buffers,
82
+ nprocs_per_node=args.nprocs_per_node,
83
+ slowmo_momentum=args.slowmo_momentum,
84
+ localsgd=(args.slowmo_algorithm == "LocalSGD"),
85
+ localsgd_frequency=args.localsgd_frequency,
86
+ )
87
+ else:
88
+ raise ValueError("Unknown --ddp-backend: " + args.ddp_backend)
89
+
90
+ class _DistributedFairseqModel(ddp_class):
91
+ """Extend DistributedDataParallel to check for missing
92
+ attributes in the wrapped module."""
93
+
94
+ def __init__(self, *args, **kwargs):
95
+ super().__init__(*args, **kwargs)
96
+
97
+ def __getattr__(self, name):
98
+ wrapped_module = super().__getattr__("module")
99
+ if hasattr(wrapped_module, name):
100
+ return getattr(wrapped_module, name)
101
+ return super().__getattr__(name)
102
+
103
+ return _DistributedFairseqModel(**init_kwargs)
fairseq-0.10.2/fairseq/models/fairseq_decoder.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from typing import Dict, List, Optional, Tuple
7
+
8
+ import torch.nn as nn
9
+ from fairseq import utils
10
+ from torch import Tensor
11
+
12
+
13
+ class FairseqDecoder(nn.Module):
14
+ """Base class for decoders."""
15
+
16
+ def __init__(self, dictionary):
17
+ super().__init__()
18
+ self.dictionary = dictionary
19
+ self.onnx_trace = False
20
+
21
+ def forward(self, prev_output_tokens, encoder_out=None, **kwargs):
22
+ """
23
+ Args:
24
+ prev_output_tokens (LongTensor): shifted output tokens of shape
25
+ `(batch, tgt_len)`, for teacher forcing
26
+ encoder_out (dict, optional): output from the encoder, used for
27
+ encoder-side attention
28
+
29
+ Returns:
30
+ tuple:
31
+ - the decoder's output of shape `(batch, tgt_len, vocab)`
32
+ - a dictionary with any model-specific outputs
33
+ """
34
+ x, extra = self.extract_features(
35
+ prev_output_tokens, encoder_out=encoder_out, **kwargs
36
+ )
37
+ x = self.output_layer(x)
38
+ return x, extra
39
+
40
+ def extract_features(self, prev_output_tokens, encoder_out=None, **kwargs):
41
+ """
42
+ Returns:
43
+ tuple:
44
+ - the decoder's features of shape `(batch, tgt_len, embed_dim)`
45
+ - a dictionary with any model-specific outputs
46
+ """
47
+ raise NotImplementedError
48
+
49
+ def output_layer(self, features, **kwargs):
50
+ """
51
+ Project features to the default output size, e.g., vocabulary size.
52
+
53
+ Args:
54
+ features (Tensor): features returned by *extract_features*.
55
+ """
56
+ raise NotImplementedError
57
+
58
+ def get_normalized_probs(
59
+ self,
60
+ net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
61
+ log_probs: bool,
62
+ sample: Optional[Dict[str, Tensor]] = None,
63
+ ):
64
+ """Get normalized probabilities (or log probs) from a net's output."""
65
+
66
+ if hasattr(self, "adaptive_softmax") and self.adaptive_softmax is not None:
67
+ if sample is not None:
68
+ assert "target" in sample
69
+ target = sample["target"]
70
+ else:
71
+ target = None
72
+ out = self.adaptive_softmax.get_log_prob(net_output[0], target=target)
73
+ return out.exp_() if not log_probs else out
74
+
75
+ logits = net_output[0]
76
+ if log_probs:
77
+ return utils.log_softmax(logits, dim=-1, onnx_trace=self.onnx_trace)
78
+ else:
79
+ return utils.softmax(logits, dim=-1, onnx_trace=self.onnx_trace)
80
+
81
+ def max_positions(self):
82
+ """Maximum input length supported by the decoder."""
83
+ return 1e6 # an arbitrary large number
84
+
85
+ def upgrade_state_dict(self, state_dict):
86
+ """Upgrade a (possibly old) state dict for new versions of fairseq."""
87
+ return state_dict
88
+
89
+ def prepare_for_onnx_export_(self):
90
+ self.onnx_trace = True
fairseq-0.10.2/fairseq/models/fairseq_encoder.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from typing import Dict, List, NamedTuple, Optional
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ from torch import Tensor
11
+
12
+
13
+ EncoderOut = NamedTuple(
14
+ "EncoderOut",
15
+ [
16
+ ("encoder_out", Tensor), # T x B x C
17
+ ("encoder_padding_mask", Optional[Tensor]), # B x T
18
+ ("encoder_embedding", Optional[Tensor]), # B x T x C
19
+ ("encoder_states", Optional[List[Tensor]]), # List[T x B x C]
20
+ ("src_tokens", Optional[Tensor]), # B x T
21
+ ("src_lengths", Optional[Tensor]), # B x 1
22
+ ],
23
+ )
24
+
25
+
26
+ class FairseqEncoder(nn.Module):
27
+ """Base class for encoders."""
28
+
29
+ def __init__(self, dictionary):
30
+ super().__init__()
31
+ self.dictionary = dictionary
32
+
33
+ def forward(self, src_tokens, src_lengths=None, **kwargs):
34
+ """
35
+ Args:
36
+ src_tokens (LongTensor): tokens in the source language of shape
37
+ `(batch, src_len)`
38
+ src_lengths (LongTensor): lengths of each source sentence of shape
39
+ `(batch)`
40
+ """
41
+ raise NotImplementedError
42
+
43
+ def forward_torchscript(self, net_input: Dict[str, Tensor]):
44
+ """A TorchScript-compatible version of forward.
45
+
46
+ Encoders which use additional arguments may want to override
47
+ this method for TorchScript compatibility.
48
+ """
49
+ if torch.jit.is_scripting():
50
+ return self.forward(
51
+ src_tokens=net_input["src_tokens"],
52
+ src_lengths=net_input["src_lengths"],
53
+ )
54
+ else:
55
+ return self.forward_non_torchscript(net_input)
56
+
57
+ @torch.jit.unused
58
+ def forward_non_torchscript(self, net_input: Dict[str, Tensor]):
59
+ encoder_input = {
60
+ k: v for k, v in net_input.items() if k != "prev_output_tokens"
61
+ }
62
+ return self.forward(**encoder_input)
63
+
64
+ def reorder_encoder_out(self, encoder_out, new_order):
65
+ """
66
+ Reorder encoder output according to `new_order`.
67
+
68
+ Args:
69
+ encoder_out: output from the ``forward()`` method
70
+ new_order (LongTensor): desired order
71
+
72
+ Returns:
73
+ `encoder_out` rearranged according to `new_order`
74
+ """
75
+ raise NotImplementedError
76
+
77
+ def max_positions(self):
78
+ """Maximum input length supported by the encoder."""
79
+ return 1e6 # an arbitrary large number
80
+
81
+ def upgrade_state_dict(self, state_dict):
82
+ """Upgrade a (possibly old) state dict for new versions of fairseq."""
83
+ return state_dict
84
+
85
+ def set_num_updates(self, num_updates):
86
+ """State from trainer to pass along to model at every update."""
87
+
88
+ def _apply(m):
89
+ if hasattr(m, "set_num_updates") and m != self:
90
+ m.set_num_updates(num_updates)
91
+
92
+ self.apply(_apply)
fairseq-0.10.2/fairseq/models/fairseq_model.py ADDED
@@ -0,0 +1,556 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ """
6
+ Base classes for various fairseq models.
7
+ """
8
+
9
+ import logging
10
+ from typing import Dict, List, Optional, Tuple
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+ from fairseq import utils
16
+ # from fairseq.checkpoint_utils import prune_state_dict
17
+ from fairseq.data import Dictionary
18
+ from fairseq.dataclass.utils import gen_parser_from_dataclass
19
+ from fairseq.models import FairseqDecoder, FairseqEncoder
20
+ from torch import Tensor
21
+
22
+
23
+ logger = logging.getLogger(__name__)
24
+
25
+
26
+ class BaseFairseqModel(nn.Module):
27
+ """Base class for fairseq models."""
28
+
29
+ def __init__(self):
30
+ super().__init__()
31
+ self._is_generation_fast = False
32
+
33
+ @classmethod
34
+ def add_args(cls, parser):
35
+ """Add model-specific arguments to the parser."""
36
+ dc = getattr(cls, "__dataclass", None)
37
+ if dc is not None:
38
+ # do not set defaults so that settings defaults from various architectures still works
39
+ gen_parser_from_dataclass(parser, dc(), delete_default=True)
40
+
41
+ @classmethod
42
+ def build_model(cls, args, task):
43
+ """Build a new model instance."""
44
+ raise NotImplementedError("Model must implement the build_model method")
45
+
46
+ def get_targets(self, sample, net_output):
47
+ """Get targets from either the sample or the net's output."""
48
+ return sample["target"]
49
+
50
+ def get_normalized_probs(
51
+ self,
52
+ net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
53
+ log_probs: bool,
54
+ sample: Optional[Dict[str, Tensor]] = None,
55
+ ):
56
+ """Get normalized probabilities (or log probs) from a net's output."""
57
+ return self.get_normalized_probs_scriptable(net_output, log_probs, sample)
58
+
59
+ # TorchScript doesn't support super() method so that the scriptable Subclass
60
+ # can't access the base class model in Torchscript.
61
+ # Current workaround is to add a helper function with different name and
62
+ # call the helper function from scriptable Subclass.
63
+ def get_normalized_probs_scriptable(
64
+ self,
65
+ net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
66
+ log_probs: bool,
67
+ sample: Optional[Dict[str, Tensor]] = None,
68
+ ):
69
+ """Scriptable helper function for get_normalized_probs in ~BaseFairseqModel"""
70
+ if hasattr(self, "decoder"):
71
+ return self.decoder.get_normalized_probs(net_output, log_probs, sample)
72
+ elif torch.is_tensor(net_output):
73
+ # syntactic sugar for simple models which don't have a decoder
74
+ # (e.g., the classification tutorial)
75
+ logits = net_output.float()
76
+ if log_probs:
77
+ return F.log_softmax(logits, dim=-1)
78
+ else:
79
+ return F.softmax(logits, dim=-1)
80
+ raise NotImplementedError
81
+
82
+ def extract_features(self, *args, **kwargs):
83
+ """Similar to *forward* but only return features."""
84
+ return self(*args, **kwargs)
85
+
86
+ def max_positions(self):
87
+ """Maximum length supported by the model."""
88
+ return None
89
+
90
+ def load_state_dict(self, state_dict, strict=True, args=None):
91
+ """Copies parameters and buffers from *state_dict* into this module and
92
+ its descendants.
93
+
94
+ Overrides the method in :class:`nn.Module`. Compared with that method
95
+ this additionally "upgrades" *state_dicts* from old checkpoints.
96
+ """
97
+ self.upgrade_state_dict(state_dict)
98
+ from fairseq.checkpoint_utils import prune_state_dict
99
+ new_state_dict = prune_state_dict(state_dict, args)
100
+ return super().load_state_dict(new_state_dict, strict)
101
+
102
    def upgrade_state_dict(self, state_dict):
        """Upgrade old state dicts to work with newer code.

        Entry point that delegates to :meth:`upgrade_state_dict_named` with an
        empty prefix so the entire module tree is visited.
        """
        self.upgrade_state_dict_named(state_dict, "")
105
+
106
    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade old state dicts to work with newer code.

        Args:
            state_dict (dict): state dictionary to upgrade, in place
            name (str): the state dict key corresponding to the current module
        """
        assert state_dict is not None

        def do_upgrade(m, prefix):
            # build dotted key prefixes as we descend the module tree
            if len(prefix) > 0:
                prefix += "."

            for n, c in m.named_children():
                name = prefix + n
                # prefer the name-aware hook; fall back to the plain one
                if hasattr(c, "upgrade_state_dict_named"):
                    c.upgrade_state_dict_named(state_dict, name)
                elif hasattr(c, "upgrade_state_dict"):
                    c.upgrade_state_dict(state_dict)
                # recurse regardless, so grandchildren are also visited
                do_upgrade(c, name)

        do_upgrade(self, name)
128
+
129
+ def set_num_updates(self, num_updates):
130
+ """State from trainer to pass along to model at every update."""
131
+
132
+ def _apply(m):
133
+ if hasattr(m, "set_num_updates") and m != self:
134
+ m.set_num_updates(num_updates)
135
+
136
+ self.apply(_apply)
137
+
138
    def prepare_for_inference_(self, args):
        """Prepare model for inference.

        Translates generation-related ``args`` into keyword arguments for
        :meth:`make_generation_fast_` and applies them in place.
        """
        kwargs = {}
        # beamable_mm is disabled entirely when --no-beamable-mm is set;
        # otherwise it uses the configured beam size (default 5)
        kwargs["beamable_mm_beam_size"] = (
            None if getattr(args, "no_beamable_mm", False) else getattr(args, "beam", 5)
        )
        kwargs["need_attn"] = getattr(args, "print_alignment", False)
        # only forward retain_dropout when the caller actually set it
        if hasattr(args, "retain_dropout"):
            kwargs["retain_dropout"] = args.retain_dropout
            kwargs["retain_dropout_modules"] = getattr(
                args, "retain_dropout_modules", None
            )
        self.make_generation_fast_(**kwargs)
151
+
152
    def make_generation_fast_(self, **kwargs):
        """
        Legacy entry point to optimize model for faster generation.
        Prefer prepare_for_inference_.

        Strips weight norm, invokes submodule-specific fast paths, and
        permanently switches the model to eval mode (training afterwards
        raises).
        """
        if self._is_generation_fast:
            return  # only apply once
        self._is_generation_fast = True

        # remove weight norm from all modules in the network
        def apply_remove_weight_norm(module):
            try:
                nn.utils.remove_weight_norm(module)
            except (AttributeError, ValueError):  # this module didn't have weight norm
                return

        self.apply(apply_remove_weight_norm)

        def apply_make_generation_fast_(module, prefix):
            if len(prefix) > 0:
                prefix += "."

            base_func = BaseFairseqModel.make_generation_fast_
            for n, m in module.named_modules():
                if (
                    m != self
                    and hasattr(m, "make_generation_fast_")
                    # don't call this implementation again, e.g., if
                    # children modules also inherit from BaseFairseqModel
                    and m.make_generation_fast_.__func__ is not base_func
                ):
                    name = prefix + n
                    m.make_generation_fast_(name=name, **kwargs)

        apply_make_generation_fast_(self, "")

        # shadow nn.Module.train on this instance so that any attempt to
        # re-enter training mode fails loudly
        def train(mode=True):
            if mode:
                raise RuntimeError("cannot train after make_generation_fast")

        # this model should no longer be used for training
        self.eval()
        self.train = train
195
+
196
+ def prepare_for_onnx_export_(self, **kwargs):
197
+ """Make model exportable via ONNX trace."""
198
+ seen = set()
199
+
200
+ def apply_prepare_for_onnx_export_(module):
201
+ if (
202
+ module != self
203
+ and hasattr(module, "prepare_for_onnx_export_")
204
+ and module not in seen
205
+ ):
206
+ seen.add(module)
207
+ module.prepare_for_onnx_export_(**kwargs)
208
+
209
+ self.apply(apply_prepare_for_onnx_export_)
210
+
211
+ def prepare_for_tpu_(self, **kwargs):
212
+ """Optionally modify model for use on TPUs."""
213
+ seen = set()
214
+
215
+ def apply_prepare_for_tpu_(module):
216
+ if (
217
+ module != self
218
+ and hasattr(module, "prepare_for_tpu_")
219
+ and module not in seen
220
+ ):
221
+ seen.add(module)
222
+ module.prepare_for_tpu_(**kwargs)
223
+
224
+ self.apply(apply_prepare_for_tpu_)
225
+
226
+ @classmethod
227
+ def upgrade_args(cls, args):
228
+ if hasattr(args, "max_sentences") and not hasattr(args, "batch_size"):
229
+ args.batch_size = args.max_sentences
230
+
231
    @classmethod
    def from_pretrained(
        cls,
        model_name_or_path,
        checkpoint_file="model.pt",
        data_name_or_path=".",
        **kwargs,
    ):
        """
        Load a :class:`~fairseq.models.FairseqModel` from a pre-trained model
        file. Downloads and caches the pre-trained model file if needed.

        The base implementation returns a
        :class:`~fairseq.hub_utils.GeneratorHubInterface`, which can be used to
        generate translations or sample from language models. The underlying
        :class:`~fairseq.models.FairseqModel` can be accessed via the
        *generator.models* attribute.

        Other models may override this to implement custom hub interfaces.

        Args:
            model_name_or_path (str): either the name of a pre-trained model to
                load or a path/URL to a pre-trained model state dict
            checkpoint_file (str, optional): colon-separated list of checkpoint
                files in the model archive to ensemble (default: 'model.pt')
            data_name_or_path (str, optional): point args.data to the archive
                at the given path/URL. Can start with '.' or './' to reuse the
                model archive path.
        """
        # imported lazily, presumably to avoid a circular import — TODO confirm
        from fairseq import hub_utils

        # archive_map lets subclasses expose short model names via hub_models()
        x = hub_utils.from_pretrained(
            model_name_or_path,
            checkpoint_file,
            data_name_or_path,
            archive_map=cls.hub_models(),
            **kwargs,
        )

        # normalize legacy argument names from old checkpoints
        cls.upgrade_args(x["args"])

        logger.info(x["args"])
        return hub_utils.GeneratorHubInterface(x["args"], x["task"], x["models"])
274
+
275
+ @classmethod
276
+ def hub_models(cls):
277
+ return {}
278
+
279
+
280
class FairseqEncoderDecoderModel(BaseFairseqModel):
    """Base class for encoder-decoder models.

    Args:
        encoder (FairseqEncoder): the encoder
        decoder (FairseqDecoder): the decoder
    """

    def __init__(self, encoder, decoder):
        super().__init__()

        self.encoder = encoder
        self.decoder = decoder
        # fail fast on wrongly-typed components instead of at forward time
        assert isinstance(self.encoder, FairseqEncoder)
        assert isinstance(self.decoder, FairseqDecoder)

    def forward(self, src_tokens, src_lengths, prev_output_tokens, **kwargs):
        """
        Run the forward pass for an encoder-decoder model.

        First feed a batch of source tokens through the encoder. Then, feed the
        encoder output and previous decoder outputs (i.e., teacher forcing) to
        the decoder to produce the next outputs::

            encoder_out = self.encoder(src_tokens, src_lengths)
            return self.decoder(prev_output_tokens, encoder_out)

        Args:
            src_tokens (LongTensor): tokens in the source language of shape
                `(batch, src_len)`
            src_lengths (LongTensor): source sentence lengths of shape `(batch)`
            prev_output_tokens (LongTensor): previous decoder outputs of shape
                `(batch, tgt_len)`, for teacher forcing

        Returns:
            tuple:
                - the decoder's output of shape `(batch, tgt_len, vocab)`
                - a dictionary with any model-specific outputs
        """
        encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
        decoder_out = self.decoder(
            prev_output_tokens, encoder_out=encoder_out, **kwargs
        )
        return decoder_out

    def forward_decoder(self, prev_output_tokens, **kwargs):
        # decoder-only step, e.g. for incremental generation where the
        # encoder output is passed through **kwargs
        return self.decoder(prev_output_tokens, **kwargs)

    def extract_features(self, src_tokens, src_lengths, prev_output_tokens, **kwargs):
        """
        Similar to *forward* but only return features.

        Returns:
            tuple:
                - the decoder's features of shape `(batch, tgt_len, embed_dim)`
                - a dictionary with any model-specific outputs
        """
        encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
        features = self.decoder.extract_features(
            prev_output_tokens, encoder_out=encoder_out, **kwargs
        )
        return features

    def output_layer(self, features, **kwargs):
        """Project features to the default output size (typically vocabulary size)."""
        return self.decoder.output_layer(features, **kwargs)

    def max_positions(self):
        """Maximum length supported by the model (encoder limit, decoder limit)."""
        return (self.encoder.max_positions(), self.decoder.max_positions())

    def max_decoder_positions(self):
        """Maximum length supported by the decoder."""
        return self.decoder.max_positions()
354
+
355
+
356
class FairseqModel(FairseqEncoderDecoderModel):
    """Deprecated alias of :class:`FairseqEncoderDecoderModel`, kept so old
    code and checkpoints keep working."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # stacklevel=4 points the warning at the user's call site rather
        # than at this wrapper
        utils.deprecation_warning(
            "FairseqModel is deprecated, please use FairseqEncoderDecoderModel "
            "or BaseFairseqModel instead",
            stacklevel=4,
        )
364
+
365
+
366
class FairseqMultiModel(BaseFairseqModel):
    """Base class for combining multiple encoder-decoder models.

    Holds one :class:`FairseqEncoderDecoderModel` per key (e.g. per language
    pair) in an ``nn.ModuleDict``.
    """

    def __init__(self, encoders, decoders):
        super().__init__()
        # encoders and decoders must be keyed identically (one pair per key)
        assert encoders.keys() == decoders.keys()
        self.keys = list(encoders.keys())
        for key in self.keys:
            assert isinstance(encoders[key], FairseqEncoder)
            assert isinstance(decoders[key], FairseqDecoder)

        self.models = nn.ModuleDict(
            {
                key: FairseqEncoderDecoderModel(encoders[key], decoders[key])
                for key in self.keys
            }
        )

    @staticmethod
    def build_shared_embeddings(
        dicts: Dict[str, Dictionary],
        langs: List[str],
        embed_dim: int,
        build_embedding: callable,
        pretrained_embed_path: Optional[str] = None,
    ):
        """
        Helper function to build shared embeddings for a set of languages after
        checking that all dicts corresponding to those languages are equivalent.

        Args:
            dicts: Dict of lang_id to its corresponding Dictionary
            langs: languages that we want to share embeddings for
            embed_dim: embedding dimension
            build_embedding: callable function to actually build the embedding
            pretrained_embed_path: Optional path to load pretrained embeddings

        Raises:
            ValueError: if the languages do not share a joined dictionary
        """
        shared_dict = dicts[langs[0]]
        if any(dicts[lang] != shared_dict for lang in langs):
            raise ValueError(
                "--share-*-embeddings requires a joined dictionary: "
                "--share-encoder-embeddings requires a joined source "
                "dictionary, --share-decoder-embeddings requires a joined "
                "target dictionary, and --share-all-embeddings requires a "
                "joint source + target dictionary."
            )
        return build_embedding(shared_dict, embed_dim, pretrained_embed_path)

    def forward(self, src_tokens, src_lengths, prev_output_tokens, **kwargs):
        # subclasses decide how to route a batch to the per-key models
        raise NotImplementedError

    def max_positions(self):
        """Maximum length supported by the model, per key."""
        return {
            key: (
                self.models[key].encoder.max_positions(),
                self.models[key].decoder.max_positions(),
            )
            for key in self.keys
        }

    def max_decoder_positions(self):
        """Maximum length supported by the decoder (min over all sub-models)."""
        return min(model.decoder.max_positions() for model in self.models.values())

    @property
    def encoder(self):
        # convenience accessor: the first sub-model's encoder
        return self.models[self.keys[0]].encoder

    @property
    def decoder(self):
        # convenience accessor: the first sub-model's decoder
        return self.models[self.keys[0]].decoder

    def forward_decoder(self, prev_output_tokens, **kwargs):
        return self.decoder(prev_output_tokens, **kwargs)

    def load_state_dict(self, state_dict, strict=True, args=None):
        """Copies parameters and buffers from *state_dict* into this module and
        its descendants.

        Overrides the method in :class:`nn.Module`. Compared with that method
        this additionally "upgrades" *state_dicts* from old checkpoints.
        """
        self.upgrade_state_dict(state_dict)
        # imported lazily, presumably to avoid a circular import — TODO confirm
        from fairseq.checkpoint_utils import prune_state_dict
        new_state_dict = prune_state_dict(state_dict, args)
        return super().load_state_dict(new_state_dict, strict)
453
+
454
+
455
class FairseqLanguageModel(BaseFairseqModel):
    """Base class for decoder-only models.

    Args:
        decoder (FairseqDecoder): the decoder
    """

    def __init__(self, decoder):
        super().__init__()
        self.decoder = decoder
        assert isinstance(self.decoder, FairseqDecoder)

    def forward(self, src_tokens, **kwargs):
        """
        Run the forward pass for a decoder-only model.

        Feeds a batch of tokens through the decoder to predict the next tokens.

        Args:
            src_tokens (LongTensor): tokens on which to condition the decoder,
                of shape `(batch, tgt_len)`
            src_lengths (LongTensor): source sentence lengths of shape `(batch)`

        Returns:
            tuple:
                - the decoder's output of shape `(batch, seq_len, vocab)`
                - a dictionary with any model-specific outputs
        """
        return self.decoder(src_tokens, **kwargs)

    def forward_decoder(self, prev_output_tokens, **kwargs):
        # same as forward(); provided for interface parity with
        # encoder-decoder models
        return self.decoder(prev_output_tokens, **kwargs)

    def extract_features(self, src_tokens, **kwargs):
        """
        Similar to *forward* but only return features.

        Returns:
            tuple:
                - the decoder's features of shape `(batch, seq_len, embed_dim)`
                - a dictionary with any model-specific outputs
        """
        return self.decoder.extract_features(src_tokens, **kwargs)

    def output_layer(self, features, **kwargs):
        """Project features to the default output size (typically vocabulary size)."""
        return self.decoder.output_layer(features, **kwargs)

    def max_positions(self):
        """Maximum length supported by the model."""
        return self.decoder.max_positions()

    def max_decoder_positions(self):
        """Maximum length supported by the decoder."""
        return self.decoder.max_positions()

    @property
    def supported_targets(self):
        # language models predict the next ("future") token
        return {"future"}
514
+
515
+
516
class FairseqEncoderModel(BaseFairseqModel):
    """Base class for encoder-only models.

    Args:
        encoder (FairseqEncoder): the encoder
    """

    def __init__(self, encoder):
        super().__init__()
        self.encoder = encoder
        assert isinstance(self.encoder, FairseqEncoder)

    def forward(self, src_tokens, src_lengths, **kwargs):
        """
        Run the forward pass for a encoder-only model.

        Feeds a batch of tokens through the encoder to generate features.

        Args:
            src_tokens (LongTensor): input tokens of shape `(batch, src_len)`
            src_lengths (LongTensor): source sentence lengths of shape `(batch)`

        Returns:
            the encoder's output, typically of shape `(batch, src_len, features)`
        """
        return self.encoder(src_tokens, src_lengths, **kwargs)

    def get_normalized_probs(self, net_output, log_probs, sample=None):
        """Get normalized probabilities (or log probs) from a net's output.

        Only supports outputs where ``net_output["encoder_out"]`` is a raw
        logits tensor; anything else raises ``NotImplementedError``.
        """
        encoder_out = net_output["encoder_out"]
        if torch.is_tensor(encoder_out):
            logits = encoder_out.float()
            if log_probs:
                return F.log_softmax(logits, dim=-1)
            else:
                return F.softmax(logits, dim=-1)
        raise NotImplementedError

    def max_positions(self):
        """Maximum length supported by the model."""
        return self.encoder.max_positions()
fairseq-0.10.2/fairseq/models/huggingface/hf_gpt2.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import logging
7
+ import os
8
+ import sys
9
+ from typing import Dict, List, Optional
10
+
11
+ import torch
12
+ from fairseq.models import (
13
+ FairseqIncrementalDecoder,
14
+ FairseqLanguageModel,
15
+ register_model,
16
+ register_model_architecture,
17
+ )
18
+
19
+
20
+ logger = logging.getLogger(__name__)
21
+
22
+
23
+ DEFAULT_MAX_TARGET_POSITIONS = 1024
24
+
25
+
26
@register_model("hf_gpt2")
class HuggingFaceGPT2LanguageModel(FairseqLanguageModel):
    """Fairseq language model backed by a huggingface/transformers GPT-2
    decoder (see :class:`HuggingFaceGPT2Decoder`)."""

    def __init__(self, decoder):
        super().__init__(decoder)

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        # fmt: off
        parser.add_argument('--embed-dim', type=int, metavar='N',
                            help='embedding dimension')
        parser.add_argument('--num-attention-heads', type=int, metavar='N',
                            help='num attention heads')
        parser.add_argument('--num-layers', type=int, metavar='N',
                            help='num layers')
        parser.add_argument('--dropout', type=float, metavar='D',
                            help='dropout probability for all fully connected layers '
                                 'in the embeddings, encoder, and pooler')
        parser.add_argument('--attention-dropout', type=float, metavar='D',
                            help='dropout probability for attention weights')
        # fmt: on

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance."""
        # fill in any unset hyper-parameters with the registered defaults
        default_architecture(args)
        return cls(HuggingFaceGPT2Decoder(args, task))
53
+
54
+
55
class HuggingFaceGPT2Decoder(FairseqIncrementalDecoder):
    """Fairseq incremental decoder wrapping a huggingface/transformers
    ``GPT2LMHeadModel``.

    Args:
        args: architecture namespace (see ``default_architecture``)
        task: fairseq task providing ``target_dictionary``
    """

    def __init__(self, args, task):
        try:
            from transformers import GPT2Config, GPT2LMHeadModel
        except ImportError:
            raise ImportError(
                "\n\nPlease install huggingface/transformers with:"
                "\n\n pip install transformers"
            )

        super().__init__(task.target_dictionary)

        config = GPT2Config(
            vocab_size=len(task.target_dictionary),
            # +1 position because position ids below start at 1 (0 is
            # reserved for padding)
            n_positions=args.max_target_positions + 1,
            n_ctx=args.max_target_positions,
            n_embd=args.embed_dim,
            n_layer=args.num_layers,
            n_head=args.num_attention_heads,
            resid_pdrop=args.dropout,
            embd_pdrop=args.dropout,
            attn_pdrop=args.attention_dropout,
            layer_norm_epsilon=1e-6,
        )
        self.model = GPT2LMHeadModel(config)

        # set zero embedding for padding symbol
        self.pad_idx = task.target_dictionary.pad()
        self.model.transformer.wte.weight.data[self.pad_idx].zero_()
        self.model.transformer.wpe.weight.data[0].zero_()

    def forward(
        self,
        prev_output_tokens,
        src_lengths=None,
        incremental_state: Optional[Dict[str, List[torch.Tensor]]] = None,
        encoder_out=None,
    ):
        """Return ``(lm_logits,)`` for the given previous output tokens.

        ``src_lengths`` and ``encoder_out`` are accepted for interface
        compatibility but unused.
        """
        features = self.extract_features(prev_output_tokens, incremental_state)
        lm_logits = self.model.lm_head(features)
        return (lm_logits,)

    def extract_features(
        self,
        prev_output_tokens,
        incremental_state: Optional[Dict[str, List[torch.Tensor]]] = None,
    ):
        """Run the GPT-2 transformer and return its last hidden states.

        When ``incremental_state`` is provided, cached attention key/values
        ("past") are read before and written back after the forward pass.
        """
        if incremental_state:
            # BUGFIX: get_incremental_state takes (incremental_state, key);
            # the old call passed only the key, which can never match the
            # 3-arg set_incremental_state(...) call below.
            past = self.get_incremental_state(incremental_state, "past")
        else:
            past = None

        # don't attend to padding symbols
        attention_mask = prev_output_tokens.ne(self.pad_idx).int()

        # set position ids to exclude padding symbols (padding gets id 0)
        position_ids = attention_mask * (
            torch.arange(1, 1 + prev_output_tokens.size(1))
            .to(prev_output_tokens)
            .repeat(prev_output_tokens.size(0), 1)
        )

        # NOTE(review): the `past` kwarg was renamed `past_key_values` in
        # newer transformers releases — confirm against the pinned version
        outputs = self.model.transformer(
            input_ids=prev_output_tokens,
            past=past,
            attention_mask=attention_mask,
            position_ids=position_ids,
        )
        last_hidden_states = outputs[0]

        if incremental_state:
            # cache key/values for the next incremental step
            self.set_incremental_state(incremental_state, "past", outputs[1])

        return last_hidden_states

    def max_positions(self):
        # one position is reserved (position ids start at 1)
        return self.model.config.n_positions - 1
132
+
133
+
134
@register_model_architecture("hf_gpt2", "hf_gpt2")
def default_architecture(args):
    """Fill in any unset hf_gpt2 hyper-parameters with GPT-2 "small" defaults."""
    if getattr(args, "max_target_positions", None) is None:
        args.max_target_positions = getattr(
            args, "tokens_per_sample", DEFAULT_MAX_TARGET_POSITIONS
        )
    # (attribute, GPT-2 small default)
    for name, value in [
        ("embed_dim", 768),
        ("num_attention_heads", 12),
        ("num_layers", 12),
        ("dropout", 0.1),
        ("attention_dropout", 0.1),
    ]:
        setattr(args, name, getattr(args, name, value))
145
+
146
+
147
@register_model_architecture("hf_gpt2", "hf_gpt2_medium")
def hf_gpt2_medium(args):
    """GPT-2 "medium" preset: 1024-dim embeddings, 16 heads, 24 layers."""
    args.num_layers = getattr(args, "num_layers", 24)
    args.num_attention_heads = getattr(args, "num_attention_heads", 16)
    args.embed_dim = getattr(args, "embed_dim", 1024)
    default_architecture(args)
153
+
154
+
155
@register_model_architecture("hf_gpt2", "hf_gpt2_large")
def hf_gpt2_large(args):
    """GPT-2 "large" preset: 1280-dim embeddings, 20 heads, 36 layers."""
    args.num_layers = getattr(args, "num_layers", 36)
    args.num_attention_heads = getattr(args, "num_attention_heads", 20)
    args.embed_dim = getattr(args, "embed_dim", 1280)
    default_architecture(args)
161
+
162
+
163
@register_model_architecture("hf_gpt2", "hf_gpt2_xl")
def hf_gpt2_xl(args):
    """GPT-2 "xl" preset: 1600-dim embeddings, 25 heads, 48 layers."""
    args.num_layers = getattr(args, "num_layers", 48)
    args.num_attention_heads = getattr(args, "num_attention_heads", 25)
    args.embed_dim = getattr(args, "embed_dim", 1600)
    default_architecture(args)
fairseq-0.10.2/fairseq/models/lightconv_lm.py ADDED
@@ -0,0 +1,306 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from fairseq import utils
7
+ from fairseq.models import (
8
+ FairseqLanguageModel,
9
+ register_model,
10
+ register_model_architecture,
11
+ )
12
+ from fairseq.models.lightconv import Embedding, LightConvDecoder
13
+ from fairseq.modules import AdaptiveInput, CharacterTokenEmbedder
14
+
15
+
16
@register_model("lightconv_lm")
class LightConvLanguageModel(FairseqLanguageModel):
    """Decoder-only language model built from :class:`LightConvDecoder`
    (LightConv / DynamicConv blocks instead of self-attention)."""

    def __init__(self, decoder):
        super().__init__(decoder)

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        parser.add_argument(
            "--dropout",
            default=0.1,
            type=float,
            metavar="D",
            help="dropout probability",
        )
        parser.add_argument(
            "--attention-dropout",
            default=0.0,
            type=float,
            metavar="D",
            help="dropout probability for attention weights",
        )
        parser.add_argument(
            "--relu-dropout",
            default=0.0,
            type=float,
            metavar="D",
            help="dropout probability after ReLU in FFN",
        )
        parser.add_argument(
            "--input-dropout",
            type=float,
            metavar="D",
            help="dropout probability of the inputs",
        )
        parser.add_argument(
            "--decoder-embed-dim",
            type=int,
            metavar="N",
            help="decoder embedding dimension",
        )
        parser.add_argument(
            "--decoder-output-dim",
            type=int,
            metavar="N",
            help="decoder output dimension",
        )
        parser.add_argument(
            "--decoder-input-dim", type=int, metavar="N", help="decoder input dimension"
        )
        parser.add_argument(
            "--decoder-ffn-embed-dim",
            type=int,
            metavar="N",
            help="decoder embedding dimension for FFN",
        )
        parser.add_argument(
            "--decoder-layers", type=int, metavar="N", help="num decoder layers"
        )
        parser.add_argument(
            "--decoder-attention-heads",
            type=int,
            metavar="N",
            help="num decoder attention heads or LightConv/DynamicConv heads",
        )
        parser.add_argument(
            "--decoder-normalize-before",
            default=False,
            action="store_true",
            help="apply layernorm before each decoder block",
        )
        parser.add_argument(
            "--adaptive-softmax-cutoff",
            metavar="EXPR",
            help="comma separated list of adaptive softmax cutoff points. "
            "Must be used with adaptive_loss criterion",
        )
        parser.add_argument(
            "--adaptive-softmax-dropout",
            type=float,
            metavar="D",
            help="sets adaptive softmax dropout for the tail projections",
        )
        parser.add_argument(
            "--adaptive-softmax-factor",
            type=float,
            metavar="N",
            help="adaptive input factor",
        )
        parser.add_argument(
            "--no-token-positional-embeddings",
            default=False,
            action="store_true",
            help="if set, disables positional embeddings (outside self attention)",
        )
        parser.add_argument(
            "--share-decoder-input-output-embed",
            default=False,
            action="store_true",
            help="share decoder input and output embeddings",
        )
        parser.add_argument(
            "--character-embeddings",
            default=False,
            action="store_true",
            help="if set, uses character embedding convolutions to produce token embeddings",
        )
        parser.add_argument(
            "--character-filters",
            type=str,
            metavar="LIST",
            default="[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]",
            help="size of character embeddings",
        )
        parser.add_argument(
            "--character-embedding-dim",
            type=int,
            metavar="N",
            default=4,
            help="size of character embeddings",
        )
        parser.add_argument(
            "--char-embedder-highway-layers",
            type=int,
            metavar="N",
            default=2,
            help="number of highway layers for character token embeddder",
        )
        parser.add_argument(
            "--adaptive-input",
            default=False,
            action="store_true",
            help="if set, uses adaptive input",
        )
        parser.add_argument(
            "--adaptive-input-factor",
            type=float,
            metavar="N",
            help="adaptive input factor",
        )
        parser.add_argument(
            "--adaptive-input-cutoff",
            metavar="EXPR",
            help="comma separated list of adaptive input cutoff points.",
        )
        parser.add_argument(
            "--tie-adaptive-weights",
            action="store_true",
            help="if set, ties the weights of adaptive softmax and adaptive input",
        )
        parser.add_argument(
            "--tie-adaptive-proj",
            action="store_true",
            help="if set, ties the projection weights of adaptive softmax and adaptive input",
        )
        parser.add_argument(
            "--decoder-learned-pos",
            action="store_true",
            help="use learned positional embeddings in the decoder",
        )

        # LightConv and DynamicConv arguments
        parser.add_argument(
            "--decoder-kernel-size-list",
            type=lambda x: utils.eval_str_list(x, int),
            help='list of kernel size (default: "[3,7,15,31,31,31]")',
        )
        parser.add_argument(
            "--decoder-glu", type=utils.eval_bool, help="glu after in proj"
        )
        parser.add_argument(
            "--decoder-conv-type",
            default="dynamic",
            type=str,
            choices=["dynamic", "lightweight"],
            help="type of convolution",
        )
        parser.add_argument("--weight-softmax", default=True, type=utils.eval_bool)
        parser.add_argument(
            "--weight-dropout",
            type=float,
            metavar="D",
            help="dropout probability for conv weights",
        )

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance."""

        # make sure all arguments are present in older models
        base_lm_architecture(args)

        if getattr(args, "max_source_positions", None) is None:
            args.max_source_positions = args.tokens_per_sample
        if getattr(args, "max_target_positions", None) is None:
            args.max_target_positions = args.tokens_per_sample

        # choose the token embedding: character CNN, adaptive input, or a
        # plain lookup table
        if args.character_embeddings:
            embed_tokens = CharacterTokenEmbedder(
                task.dictionary,
                eval(args.character_filters),
                args.character_embedding_dim,
                args.decoder_embed_dim,
                args.char_embedder_highway_layers,
            )
        elif args.adaptive_input:
            embed_tokens = AdaptiveInput(
                len(task.dictionary),
                task.dictionary.pad(),
                args.decoder_input_dim,
                args.adaptive_input_factor,
                args.decoder_embed_dim,
                utils.eval_str_list(args.adaptive_input_cutoff, type=int),
            )
        else:
            embed_tokens = Embedding(
                len(task.dictionary), args.decoder_input_dim, task.dictionary.pad()
            )

        if args.tie_adaptive_weights:
            # tying adaptive softmax to adaptive input requires matching
            # factors, cutoffs and dimensions
            assert args.adaptive_input
            assert args.adaptive_input_factor == args.adaptive_softmax_factor
            assert (
                args.adaptive_softmax_cutoff == args.adaptive_input_cutoff
            ), "{} != {}".format(
                args.adaptive_softmax_cutoff, args.adaptive_input_cutoff
            )
            assert args.decoder_input_dim == args.decoder_output_dim

        decoder = LightConvDecoder(
            args,
            task.output_dictionary,
            embed_tokens,
            no_encoder_attn=True,
            final_norm=False,
        )
        # CONSISTENCY FIX: return cls(...) (like HuggingFaceGPT2LanguageModel)
        # instead of hard-coding LightConvLanguageModel, so subclasses can
        # reuse build_model
        return cls(decoder)
253
+
254
+
255
@register_model_architecture("lightconv_lm", "lightconv_lm")
def base_lm_architecture(args):
    """Fill in any unset lightconv_lm hyper-parameters with base defaults.

    Mutates *args* in place; every option already set by the user is kept.
    """
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512)
    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 2048)
    args.decoder_layers = getattr(args, "decoder_layers", 6)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
    args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
    args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
    args.adaptive_softmax_factor = getattr(args, "adaptive_softmax_factor", 4)
    args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)

    args.character_embeddings = getattr(args, "character_embeddings", False)

    # output/input/conv dims default to the embedding dim set above
    args.decoder_output_dim = getattr(
        args, "decoder_output_dim", args.decoder_embed_dim
    )
    args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)
    args.decoder_conv_dim = getattr(args, "decoder_conv_dim", args.decoder_embed_dim)

    # The model training is not stable without this
    # (note: forced on, overriding any user-provided value)
    args.decoder_normalize_before = True

    args.adaptive_input = getattr(args, "adaptive_input", False)
    args.adaptive_input_factor = getattr(args, "adaptive_input_factor", 4)
    args.adaptive_input_cutoff = getattr(args, "adaptive_input_cutoff", None)

    args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False)
    args.tie_adaptive_proj = getattr(args, "tie_adaptive_proj", False)

    args.decoder_kernel_size_list = getattr(
        args, "decoder_kernel_size_list", [3, 7, 15, 31, 31, 31]
    )
    # a single kernel size is broadcast to all decoder layers
    if len(args.decoder_kernel_size_list) == 1:
        args.decoder_kernel_size_list = (
            args.decoder_kernel_size_list * args.decoder_layers
        )
    assert (
        len(args.decoder_kernel_size_list) == args.decoder_layers
    ), "decoder_kernel_size_list doesn't match decoder_layers"
    args.decoder_glu = getattr(args, "decoder_glu", True)
    args.input_dropout = getattr(args, "input_dropout", 0.1)
    args.weight_dropout = getattr(args, "weight_dropout", args.attention_dropout)
297
+
298
+
299
@register_model_architecture("lightconv_lm", "lightconv_lm_gbw")
def lightconv_lm_gbw(args):
    """lightconv_lm preset sized for the Google Billion Words corpus."""
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512)
    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 4096)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16)
    args.dropout = getattr(args, "dropout", 0.1)
    args.attention_dropout = getattr(args, "attention_dropout", 0.1)
    base_lm_architecture(args)
fairseq-0.10.2/fairseq/models/masked_lm.py ADDED
@@ -0,0 +1,403 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import logging
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from fairseq import utils
12
+ from fairseq.models import (
13
+ FairseqEncoder,
14
+ FairseqEncoderModel,
15
+ register_model,
16
+ register_model_architecture,
17
+ )
18
+ from fairseq.modules import (
19
+ LayerNorm,
20
+ SinusoidalPositionalEmbedding,
21
+ TransformerSentenceEncoder,
22
+ )
23
+ from fairseq.modules.transformer_sentence_encoder import init_bert_params
24
+
25
+
26
+ logger = logging.getLogger(__name__)
27
+
28
+
29
@register_model("masked_lm")
class MaskedLMModel(FairseqEncoderModel):
    """
    Class for training a Masked Language Model. It also supports an
    additional sentence level prediction if the sent-loss argument is set.

    The model is encoder-only: ``forward`` delegates straight to the
    :class:`MaskedLMEncoder` built in :meth:`build_model`.
    """

    def __init__(self, args, encoder):
        super().__init__(encoder)
        self.args = args

        # if specified then apply bert initialization on the model. We need
        # to explictly call this to make sure that the output embeddings
        # and projection layers are also correctly initialized
        if getattr(args, "apply_bert_init", False):
            self.apply(init_bert_params)

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        # Arguments related to dropout
        parser.add_argument(
            "--dropout", type=float, metavar="D", help="dropout probability"
        )
        parser.add_argument(
            "--attention-dropout",
            type=float,
            metavar="D",
            help="dropout probability for" " attention weights",
        )
        parser.add_argument(
            "--act-dropout",
            type=float,
            metavar="D",
            help="dropout probability after" " activation in FFN",
        )

        # Arguments related to hidden states and self-attention
        parser.add_argument(
            "--encoder-ffn-embed-dim",
            type=int,
            metavar="N",
            help="encoder embedding dimension for FFN",
        )
        parser.add_argument(
            "--encoder-layers", type=int, metavar="N", help="num encoder layers"
        )
        parser.add_argument(
            "--encoder-attention-heads",
            type=int,
            metavar="N",
            help="num encoder attention heads",
        )

        # Arguments related to input and output embeddings
        parser.add_argument(
            "--encoder-embed-dim",
            type=int,
            metavar="N",
            help="encoder embedding dimension",
        )
        parser.add_argument(
            "--share-encoder-input-output-embed",
            action="store_true",
            help="share encoder input" " and output embeddings",
        )
        parser.add_argument(
            "--encoder-learned-pos",
            action="store_true",
            help="use learned positional embeddings in the encoder",
        )
        parser.add_argument(
            "--no-token-positional-embeddings",
            action="store_true",
            help="if set, disables positional embeddings" " (outside self attention)",
        )
        parser.add_argument(
            "--num-segment", type=int, metavar="N", help="num segment in the input"
        )
        parser.add_argument(
            "--max-positions", type=int, help="number of positional embeddings to learn"
        )

        # Arguments related to sentence level prediction
        parser.add_argument(
            "--sentence-class-num",
            type=int,
            metavar="N",
            help="number of classes for sentence task",
        )
        parser.add_argument(
            "--sent-loss",
            action="store_true",
            help="if set," " calculate sentence level predictions",
        )

        # Arguments related to parameter initialization
        parser.add_argument(
            "--apply-bert-init",
            action="store_true",
            help="use custom param initialization for BERT",
        )

        # misc params
        parser.add_argument(
            "--activation-fn",
            choices=utils.get_available_activation_fns(),
            help="activation function to use",
        )
        parser.add_argument(
            "--pooler-activation-fn",
            choices=utils.get_available_activation_fns(),
            help="Which activation function to use for pooler layer.",
        )
        parser.add_argument(
            "--encoder-normalize-before",
            action="store_true",
            help="apply layernorm before each encoder block",
        )

    def forward(self, src_tokens, segment_labels=None, **kwargs):
        # Encoder-only model: forward is a direct pass-through to the encoder.
        return self.encoder(src_tokens, segment_labels=segment_labels, **kwargs)

    def max_positions(self):
        # NOTE(review): ``self.encoder.max_positions`` is an int *attribute*
        # assigned in MaskedLMEncoder.__init__ (it shadows the encoder's
        # max_positions() method on instances), so this returns that int.
        return self.encoder.max_positions

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance."""
        # make sure all arguments are present in older models
        base_architecture(args)

        if not hasattr(args, "max_positions"):
            args.max_positions = args.tokens_per_sample

        logger.info(args)

        encoder = MaskedLMEncoder(args, task.dictionary)
        return cls(args, encoder)
168
+
169
+
170
class MaskedLMEncoder(FairseqEncoder):
    """
    Encoder for Masked Language Modelling.

    Wraps a :class:`TransformerSentenceEncoder` and adds an LM output head
    (transform + layer norm + vocabulary projection) and, optionally, a
    sentence-level classification head (``--sent-loss``).
    """

    def __init__(self, args, dictionary):
        super().__init__(dictionary)

        self.padding_idx = dictionary.pad()
        self.vocab_size = dictionary.__len__()
        # NOTE(review): int attribute; on instances this shadows the
        # max_positions() method defined below.
        self.max_positions = args.max_positions

        self.sentence_encoder = TransformerSentenceEncoder(
            padding_idx=self.padding_idx,
            vocab_size=self.vocab_size,
            num_encoder_layers=args.encoder_layers,
            embedding_dim=args.encoder_embed_dim,
            ffn_embedding_dim=args.encoder_ffn_embed_dim,
            num_attention_heads=args.encoder_attention_heads,
            dropout=args.dropout,
            attention_dropout=args.attention_dropout,
            activation_dropout=args.act_dropout,
            max_seq_len=self.max_positions,
            num_segments=args.num_segment,
            use_position_embeddings=not args.no_token_positional_embeddings,
            encoder_normalize_before=args.encoder_normalize_before,
            apply_bert_init=args.apply_bert_init,
            activation_fn=args.activation_fn,
            learned_pos_embedding=args.encoder_learned_pos,
        )

        self.share_input_output_embed = args.share_encoder_input_output_embed
        self.embed_out = None
        self.sentence_projection_layer = None
        self.sentence_out_dim = args.sentence_class_num
        self.lm_output_learned_bias = None

        # Remove head is set to true during fine-tuning
        self.load_softmax = not getattr(args, "remove_head", False)

        # Pooler head applied to the sentence representation (e.g. for NSP).
        self.masked_lm_pooler = nn.Linear(
            args.encoder_embed_dim, args.encoder_embed_dim
        )
        self.pooler_activation = utils.get_activation_fn(args.pooler_activation_fn)

        # LM head: dense transform + activation + layer norm before the
        # vocabulary projection.
        self.lm_head_transform_weight = nn.Linear(
            args.encoder_embed_dim, args.encoder_embed_dim
        )
        self.activation_fn = utils.get_activation_fn(args.activation_fn)
        self.layer_norm = LayerNorm(args.encoder_embed_dim)

        # Reassigned below when the softmax head is kept (redundant with the
        # assignment above, preserved as-is).
        self.lm_output_learned_bias = None
        if self.load_softmax:
            self.lm_output_learned_bias = nn.Parameter(torch.zeros(self.vocab_size))

            if not self.share_input_output_embed:
                self.embed_out = nn.Linear(
                    args.encoder_embed_dim, self.vocab_size, bias=False
                )

            if args.sent_loss:
                self.sentence_projection_layer = nn.Linear(
                    args.encoder_embed_dim, self.sentence_out_dim, bias=False
                )

    def forward(self, src_tokens, segment_labels=None, masked_tokens=None, **unused):
        """
        Forward pass for Masked LM encoder. This first computes the token
        embedding using the token embedding matrix, position embeddings (if
        specified) and segment embeddings (if specified).

        Here we assume that the sentence representation corresponds to the
        output of the classification_token (see bert_task or cross_lingual_lm
        task for more details).
        Args:
            - src_tokens: B x T matrix representing sentences
            - segment_labels: B x T matrix representing segment label for tokens
        Returns:
            - a tuple of the following:
                - logits for predictions in format B x T x C to be used in
                  softmax afterwards
                - a dictionary of additional data, where 'pooled_output' contains
                  the representation for classification_token and 'inner_states'
                  is a list of internal model states used to compute the
                  predictions (similar in ELMO). 'sentence_logits'
                  is the prediction logit for NSP task and is only computed if
                  this is specified in the input arguments.
        """

        inner_states, sentence_rep = self.sentence_encoder(
            src_tokens,
            segment_labels=segment_labels,
        )

        # Last layer states come back T x B x C; move batch first.
        x = inner_states[-1].transpose(0, 1)
        # project masked tokens only
        if masked_tokens is not None:
            x = x[masked_tokens, :]
        x = self.layer_norm(self.activation_fn(self.lm_head_transform_weight(x)))

        pooled_output = self.pooler_activation(self.masked_lm_pooler(sentence_rep))

        # project back to size of vocabulary
        if self.share_input_output_embed and hasattr(
            self.sentence_encoder.embed_tokens, "weight"
        ):
            x = F.linear(x, self.sentence_encoder.embed_tokens.weight)
        elif self.embed_out is not None:
            x = self.embed_out(x)
        if self.lm_output_learned_bias is not None:
            x = x + self.lm_output_learned_bias
        sentence_logits = None
        if self.sentence_projection_layer:
            sentence_logits = self.sentence_projection_layer(pooled_output)

        return x, {
            "inner_states": inner_states,
            "pooled_output": pooled_output,
            "sentence_logits": sentence_logits,
        }

    def max_positions(self):
        """Maximum output length supported by the encoder."""
        # NOTE(review): shadowed on instances by the int attribute of the
        # same name assigned in __init__; kept for API compatibility.
        return self.max_positions

    def upgrade_state_dict_named(self, state_dict, name):
        # Sinusoidal embeddings are not learned; ensure the placeholder
        # buffer exists so load_state_dict does not complain.
        if isinstance(
            self.sentence_encoder.embed_positions, SinusoidalPositionalEmbedding
        ):
            state_dict[
                name + ".sentence_encoder.embed_positions._float_tensor"
            ] = torch.FloatTensor(1)
        # When the output head was removed for fine-tuning, drop its
        # parameters from checkpoints saved with the head attached.
        if not self.load_softmax:
            for k in list(state_dict.keys()):
                if (
                    "embed_out.weight" in k
                    or "sentence_projection_layer.weight" in k
                    or "lm_output_learned_bias" in k
                ):
                    del state_dict[k]
        return state_dict
311
+
312
+
313
@register_model_architecture("masked_lm", "masked_lm")
def base_architecture(args):
    """Default hyper-parameters for the ``masked_lm`` model family.

    Each attribute is only set when the caller did not already provide it,
    so user overrides always win.
    """
    for attr, default in (
        # dropout
        ("dropout", 0.1),
        ("attention_dropout", 0.1),
        ("act_dropout", 0.0),
        # hidden states / self-attention
        ("encoder_ffn_embed_dim", 4096),
        ("encoder_layers", 6),
        ("encoder_attention_heads", 8),
        # input/output embeddings
        ("encoder_embed_dim", 1024),
        ("share_encoder_input_output_embed", False),
        ("encoder_learned_pos", False),
        ("no_token_positional_embeddings", False),
        ("num_segment", 2),
        # sentence-level prediction
        ("sentence_class_num", 2),
        ("sent_loss", False),
        # initialization
        ("apply_bert_init", False),
        # misc
        ("activation_fn", "relu"),
        ("pooler_activation_fn", "tanh"),
        ("encoder_normalize_before", False),
    ):
        setattr(args, attr, getattr(args, attr, default))
341
+
342
+
343
@register_model_architecture("masked_lm", "bert_base")
def bert_base_architecture(args):
    """BERT-Base hyper-parameters (12 layers, 768 dim, 12 heads)."""
    for attr, default in (
        ("encoder_embed_dim", 768),
        ("share_encoder_input_output_embed", True),
        ("no_token_positional_embeddings", False),
        ("encoder_learned_pos", True),
        ("num_segment", 2),
        ("encoder_layers", 12),
        ("encoder_attention_heads", 12),
        ("encoder_ffn_embed_dim", 3072),
        ("sentence_class_num", 2),
        ("sent_loss", True),
        ("apply_bert_init", True),
        ("activation_fn", "gelu"),
        ("pooler_activation_fn", "tanh"),
        ("encoder_normalize_before", True),
    ):
        setattr(args, attr, getattr(args, attr, default))
    # Fill in whatever BERT does not pin down explicitly.
    base_architecture(args)
369
+
370
+
371
@register_model_architecture("masked_lm", "bert_large")
def bert_large_architecture(args):
    """BERT-Large hyper-parameters (24 layers, 1024 dim, 16 heads)."""
    for attr, default in (
        ("encoder_embed_dim", 1024),
        ("encoder_layers", 24),
        ("encoder_attention_heads", 16),
        ("encoder_ffn_embed_dim", 4096),
    ):
        setattr(args, attr, getattr(args, attr, default))
    # Everything else follows the BERT-Base configuration.
    bert_base_architecture(args)
378
+
379
+
380
@register_model_architecture("masked_lm", "xlm_base")
def xlm_architecture(args):
    """XLM hyper-parameters: single-segment input, no sentence loss."""
    for attr, default in (
        ("encoder_embed_dim", 1024),
        ("share_encoder_input_output_embed", True),
        ("no_token_positional_embeddings", False),
        ("encoder_learned_pos", True),
        ("num_segment", 1),
        ("encoder_layers", 6),
        ("encoder_attention_heads", 8),
        ("encoder_ffn_embed_dim", 4096),
        ("sent_loss", False),
        ("activation_fn", "gelu"),
        ("encoder_normalize_before", False),
        ("pooler_activation_fn", "tanh"),
        ("apply_bert_init", True),
    ):
        setattr(args, attr, getattr(args, attr, default))
    # Remaining defaults come from the shared masked-LM base architecture.
    base_architecture(args)
fairseq-0.10.2/fairseq/models/model_utils.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from typing import List, Optional
7
+
8
+ import torch
9
+ from torch import Tensor
10
+
11
+
12
@torch.jit.script
def script_skip_tensor_list(x: List[Tensor], mask):
    """Apply a boolean ``mask`` to each tensor in ``x``.

    Rows are masked (``xi[mask]``) when the tensor's first dimension matches
    the mask length; otherwise columns are masked (``xi[:, mask]``). If
    masking a tensor selects nothing, the original tensor is kept instead.
    """
    outputs = []
    for xi in x:
        if xi.size(0) == mask.size(0):
            kept = xi[mask]
        else:
            kept = xi[:, mask]
        # Fall back to the unmasked tensor when nothing survived the mask.
        if kept.numel() == 0:
            outputs.append(xi)
        else:
            outputs.append(kept)
    return outputs
22
+
23
+
24
@torch.jit.script
def script_skip_tensor(x: Tensor, mask):
    """Mask rows of ``x`` (or columns when dim 0 does not match the mask).

    Returns ``x`` unchanged when it is empty or when the mask selects
    nothing, mirroring the list variant's fallback behavior.
    """
    # None case
    if x.size(0) == 0:
        return x
    if x.size(0) == mask.size(0):
        kept = x[mask]
    else:
        kept = x[:, mask]
    # Preserve the input when the mask removed every element.
    return x if kept.numel() == 0 else kept
34
+
35
+
36
@torch.jit.script
def expand_2d_or_3d_tensor(x, trg_dim: int, padding_idx: int):
    """
    Expand 2D/3D tensor on dim=1

    Pads ``x`` along dim=1 up to ``trg_dim`` columns, filling the new
    entries with ``padding_idx``; returns ``x`` unchanged if it is already
    ``trg_dim`` wide.
    """
    # NOTE(review): under torch.jit.script the unannotated `x` is typed as
    # Tensor, so this None branch appears unreachable — confirm intent.
    if x is None:
        return None

    assert x.dim() == 2 or x.dim() == 3
    assert trg_dim >= x.size(1), (trg_dim, x.size())
    if trg_dim == x.size(1):
        return x

    # Shape of the padding block to append on dim=1 (keep the trailing
    # feature dimension for 3D inputs).
    dims = [x.size(0), trg_dim - x.size(1)]
    if x.dim() == 3:
        dims.append(x.size(2))
    # `.to(x)` matches device and dtype of the input before concatenating.
    x = torch.cat([x, torch.zeros(dims).to(x).fill_(padding_idx)], 1)

    return x
55
+
56
+
57
@torch.jit.script
def coalesce(x: Optional[Tensor], y: Tensor) -> Tensor:
    """Return ``x`` unless it is None, in which case fall back to ``y``."""
    if x is None:
        return y
    return x
60
+
61
+
62
@torch.jit.script
def fill_tensors(
    x: Optional[Tensor], mask, y: Optional[Tensor], padding_idx: int
) -> Optional[Tensor]:
    """
    Filling tensor x with y at masked positions (dim=0).

    Rows of ``x`` where ``mask`` is True are replaced by the rows of ``y``
    (in order); ``x`` may be widened along dim=1 with ``padding_idx`` when
    ``y`` has more columns. Mutates ``x`` in place when partially filling.
    """
    # Nothing to fill when either tensor is absent or x is empty.
    if x is None or x.size()[0] == 0 or y is None:
        return x
    assert x.dim() == y.dim() and mask.size(0) == x.size(0)
    assert x.dim() == 2 or (x.dim() == 3 and x.size(2) == y.size(2))

    n_selected = mask.sum()
    if n_selected == 0:
        # Mask selects no rows: x is unchanged.
        return x
    # y must supply exactly one row per selected position.
    assert n_selected == y.size(0)
    if n_selected == x.size(0):
        # Every row is replaced, so the result is simply y.
        return y

    if x.size(1) < y.size(1):
        # y is wider: grow x along dim=1 (padded with padding_idx) first.
        x = expand_2d_or_3d_tensor(x, y.size(1), padding_idx)
        x[mask] = y
    elif x.size(1) > y.size(1):
        # y is narrower: pad the selected rows entirely, then copy y into
        # their leading columns.
        x[mask] = torch.tensor(padding_idx).type_as(x)
        if x.dim() == 2:
            x[mask, : y.size(1)] = y
        else:
            x[mask, : y.size(1), :] = y
    else:
        # Same width: straight masked row assignment.
        x[mask] = y
    return x
fairseq-0.10.2/fairseq/models/multilingual_transformer.py ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from collections import OrderedDict
7
+
8
+ from fairseq import utils
9
+ from fairseq.models import (
10
+ FairseqMultiModel,
11
+ register_model,
12
+ register_model_architecture,
13
+ )
14
+ from fairseq.models.transformer import (
15
+ Embedding,
16
+ TransformerDecoder,
17
+ TransformerEncoder,
18
+ TransformerModel,
19
+ base_architecture,
20
+ )
21
+
22
+
23
@register_model("multilingual_transformer")
class MultilingualTransformerModel(FairseqMultiModel):
    """Train Transformer models for multiple language pairs simultaneously.

    Requires `--task multilingual_translation`.

    We inherit all arguments from TransformerModel and assume that all language
    pairs use a single Transformer architecture. In addition, we provide several
    options that are specific to the multilingual setting.

    Args:
        --share-encoder-embeddings: share encoder embeddings across all source languages
        --share-decoder-embeddings: share decoder embeddings across all target languages
        --share-encoders: share all encoder params (incl. embeddings) across all source languages
        --share-decoders: share all decoder params (incl. embeddings) across all target languages
    """

    def __init__(self, encoders, decoders):
        super().__init__(encoders, decoders)

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        # Inherit the full single-pair Transformer argument set first.
        TransformerModel.add_args(parser)
        parser.add_argument(
            "--share-encoder-embeddings",
            action="store_true",
            help="share encoder embeddings across languages",
        )
        parser.add_argument(
            "--share-decoder-embeddings",
            action="store_true",
            help="share decoder embeddings across languages",
        )
        parser.add_argument(
            "--share-encoders",
            action="store_true",
            help="share encoders across languages",
        )
        parser.add_argument(
            "--share-decoders",
            action="store_true",
            help="share decoders across languages",
        )

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance.

        Constructs one encoder per source language and one decoder per
        target language (or shared ones, depending on the ``--share-*``
        flags), then maps each ``lang_pair`` to its encoder/decoder pair.
        """
        from fairseq.tasks.multilingual_translation import MultilingualTranslationTask

        assert isinstance(task, MultilingualTranslationTask)

        # make sure all arguments are present in older models
        base_multilingual_architecture(args)

        if not hasattr(args, "max_source_positions"):
            args.max_source_positions = 1024
        if not hasattr(args, "max_target_positions"):
            args.max_target_positions = 1024

        # lang_pair strings look like "src-tgt".
        src_langs = [lang_pair.split("-")[0] for lang_pair in task.model_lang_pairs]
        tgt_langs = [lang_pair.split("-")[1] for lang_pair in task.model_lang_pairs]

        # Sharing whole encoders/decoders implies sharing their embeddings.
        if args.share_encoders:
            args.share_encoder_embeddings = True
        if args.share_decoders:
            args.share_decoder_embeddings = True

        def build_embedding(dictionary, embed_dim, path=None):
            # Build one embedding table for `dictionary`, optionally
            # initialized from a pretrained embedding file.
            num_embeddings = len(dictionary)
            padding_idx = dictionary.pad()
            emb = Embedding(num_embeddings, embed_dim, padding_idx)
            # if provided, load from preloaded dictionaries
            if path:
                embed_dict = utils.parse_embedding(path)
                utils.load_embedding(embed_dict, dictionary, emb)
            return emb

        # build shared embeddings (if applicable)
        shared_encoder_embed_tokens, shared_decoder_embed_tokens = None, None
        if args.share_all_embeddings:
            if args.encoder_embed_dim != args.decoder_embed_dim:
                raise ValueError(
                    "--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim"
                )
            if args.decoder_embed_path and (
                args.decoder_embed_path != args.encoder_embed_path
            ):
                raise ValueError(
                    "--share-all-embeddings not compatible with --decoder-embed-path"
                )
            # One table over the union of all languages' vocabularies.
            shared_encoder_embed_tokens = FairseqMultiModel.build_shared_embeddings(
                dicts=task.dicts,
                langs=task.langs,
                embed_dim=args.encoder_embed_dim,
                build_embedding=build_embedding,
                pretrained_embed_path=args.encoder_embed_path,
            )
            shared_decoder_embed_tokens = shared_encoder_embed_tokens
            args.share_decoder_input_output_embed = True
        else:
            if args.share_encoder_embeddings:
                shared_encoder_embed_tokens = FairseqMultiModel.build_shared_embeddings(
                    dicts=task.dicts,
                    langs=src_langs,
                    embed_dim=args.encoder_embed_dim,
                    build_embedding=build_embedding,
                    pretrained_embed_path=args.encoder_embed_path,
                )
            if args.share_decoder_embeddings:
                shared_decoder_embed_tokens = FairseqMultiModel.build_shared_embeddings(
                    dicts=task.dicts,
                    langs=tgt_langs,
                    embed_dim=args.decoder_embed_dim,
                    build_embedding=build_embedding,
                    pretrained_embed_path=args.decoder_embed_path,
                )

        # encoders/decoders for each language
        lang_encoders, lang_decoders = {}, {}

        def get_encoder(lang):
            # Lazily build (and memoize) one encoder per source language.
            if lang not in lang_encoders:
                if shared_encoder_embed_tokens is not None:
                    encoder_embed_tokens = shared_encoder_embed_tokens
                else:
                    encoder_embed_tokens = build_embedding(
                        task.dicts[lang],
                        args.encoder_embed_dim,
                        args.encoder_embed_path,
                    )
                lang_encoders[lang] = cls._get_module_class(
                    True, args, task.dicts[lang], encoder_embed_tokens, src_langs
                )
            return lang_encoders[lang]

        def get_decoder(lang):
            # Lazily build (and memoize) one decoder per target language.
            if lang not in lang_decoders:
                if shared_decoder_embed_tokens is not None:
                    decoder_embed_tokens = shared_decoder_embed_tokens
                else:
                    decoder_embed_tokens = build_embedding(
                        task.dicts[lang],
                        args.decoder_embed_dim,
                        args.decoder_embed_path,
                    )
                lang_decoders[lang] = cls._get_module_class(
                    False, args, task.dicts[lang], decoder_embed_tokens, tgt_langs
                )
            return lang_decoders[lang]

        # shared encoders/decoders (if applicable)
        shared_encoder, shared_decoder = None, None
        if args.share_encoders:
            shared_encoder = get_encoder(src_langs[0])
        if args.share_decoders:
            shared_decoder = get_decoder(tgt_langs[0])

        encoders, decoders = OrderedDict(), OrderedDict()
        for lang_pair, src, tgt in zip(task.model_lang_pairs, src_langs, tgt_langs):
            encoders[lang_pair] = (
                shared_encoder if shared_encoder is not None else get_encoder(src)
            )
            decoders[lang_pair] = (
                shared_decoder if shared_decoder is not None else get_decoder(tgt)
            )

        return MultilingualTransformerModel(encoders, decoders)

    @classmethod
    def _get_module_class(cls, is_encoder, args, lang_dict, embed_tokens, langs):
        # NOTE(review): `langs` is accepted but unused here; subclasses may
        # rely on it when overriding this hook.
        module_class = TransformerEncoder if is_encoder else TransformerDecoder
        return module_class(args, lang_dict, embed_tokens)

    def load_state_dict(self, state_dict, strict=True, args=None):
        """Load a checkpoint, ignoring keys for language pairs this model
        does not contain (e.g. when fine-tuning on a subset of pairs)."""
        state_dict_subset = state_dict.copy()
        for k, _ in state_dict.items():
            assert k.startswith("models.")
            # Keys are "models.<lang_pair>.<param path>".
            lang_pair = k.split(".")[1]
            if lang_pair not in self.models:
                del state_dict_subset[k]
        super().load_state_dict(state_dict_subset, strict=strict, args=args)
205
+
206
+
207
@register_model_architecture("multilingual_transformer", "multilingual_transformer")
def base_multilingual_architecture(args):
    """Default hyper-parameters for the multilingual Transformer."""
    # Standard Transformer defaults first, then the multilingual flags.
    base_architecture(args)
    for attr in (
        "share_encoder_embeddings",
        "share_decoder_embeddings",
        "share_encoders",
        "share_decoders",
    ):
        setattr(args, attr, getattr(args, attr, False))
214
+
215
+
216
@register_model_architecture(
    "multilingual_transformer", "multilingual_transformer_iwslt_de_en"
)
def multilingual_transformer_iwslt_de_en(args):
    """Smaller multilingual Transformer sized for IWSLT de-en."""
    for attr, default in (
        ("encoder_embed_dim", 512),
        ("encoder_ffn_embed_dim", 1024),
        ("encoder_attention_heads", 4),
        ("encoder_layers", 6),
        ("decoder_embed_dim", 512),
        ("decoder_ffn_embed_dim", 1024),
        ("decoder_attention_heads", 4),
        ("decoder_layers", 6),
    ):
        setattr(args, attr, getattr(args, attr, default))
    base_multilingual_architecture(args)
fairseq-0.10.2/fairseq/models/roberta/__pycache__/hub_interface.cpython-310.pyc ADDED
Binary file (8.12 kB). View file
 
fairseq-0.10.2/fairseq/models/roberta/__pycache__/model_camembert.cpython-310.pyc ADDED
Binary file (1.85 kB). View file
 
fairseq-0.10.2/fairseq/models/roberta/__pycache__/model_xlmr.cpython-310.pyc ADDED
Binary file (1.37 kB). View file
 
fairseq-0.10.2/fairseq/models/roberta/model_xlmr.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ """
6
+ Unsupervised Cross-lingual Representation Learning at Scale
7
+ """
8
+
9
+ from fairseq.models import register_model
10
+
11
+ from .hub_interface import RobertaHubInterface
12
+ from .model import RobertaModel
13
+
14
+
15
@register_model("xlmr")
class XLMRModel(RobertaModel):
    """RoBERTa variant pre-trained on multilingual data (XLM-R).

    Identical architecture to RoBERTa; only the hub checkpoints and the
    default (sentencepiece) BPE differ.
    """

    @classmethod
    def hub_models(cls):
        # Pre-trained checkpoints hosted by FAIR for torch.hub / from_pretrained.
        return {
            "xlmr.base": "http://dl.fbaipublicfiles.com/fairseq/models/xlmr.base.tar.gz",
            "xlmr.large": "http://dl.fbaipublicfiles.com/fairseq/models/xlmr.large.tar.gz",
        }

    @classmethod
    def from_pretrained(
        cls,
        model_name_or_path,
        checkpoint_file="model.pt",
        data_name_or_path=".",
        bpe="sentencepiece",
        **kwargs
    ):
        """Load a pre-trained XLM-R checkpoint and wrap it in a hub interface."""
        # Imported lazily to avoid a circular import at module load time.
        from fairseq import hub_utils

        loaded = hub_utils.from_pretrained(
            model_name_or_path,
            checkpoint_file,
            data_name_or_path,
            archive_map=cls.hub_models(),
            bpe=bpe,
            load_checkpoint_heads=True,
            **kwargs,
        )
        return RobertaHubInterface(loaded["args"], loaded["task"], loaded["models"][0])
fairseq-0.10.2/fairseq/modules/__pycache__/beamable_mm.cpython-310.pyc ADDED
Binary file (1.64 kB). View file