Add files using upload-large-folder tool
Browse files. This view is limited to 50 files because it contains too many changes.
See raw diff
- fairseq-0.10.2/fairseq/criterions/__init__.py +38 -0
- fairseq-0.10.2/fairseq/criterions/__pycache__/ctc.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/criterions/__pycache__/fairseq_criterion.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/criterions/__pycache__/legacy_masked_lm.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/criterions/__pycache__/sentence_prediction.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/criterions/adaptive_loss.py +123 -0
- fairseq-0.10.2/fairseq/criterions/composite_loss.py +100 -0
- fairseq-0.10.2/fairseq/criterions/ctc.py +253 -0
- fairseq-0.10.2/fairseq/criterions/fairseq_criterion.py +119 -0
- fairseq-0.10.2/fairseq/criterions/legacy_masked_lm.py +177 -0
- fairseq-0.10.2/fairseq/criterions/sentence_prediction.py +99 -0
- fairseq-0.10.2/fairseq/criterions/sentence_ranking.py +120 -0
- fairseq-0.10.2/fairseq/model_parallel/models/__pycache__/transformer.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/model_parallel/models/pipeline_parallel_transformer/model.py +721 -0
- fairseq-0.10.2/fairseq/model_parallel/models/roberta/__init__.py +6 -0
- fairseq-0.10.2/fairseq/model_parallel/models/roberta/__pycache__/__init__.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/models/__pycache__/__init__.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/models/__pycache__/distributed_fairseq_model.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/models/__pycache__/fairseq_decoder.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/models/__pycache__/fairseq_encoder.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/models/__pycache__/fairseq_incremental_decoder.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/models/__pycache__/fairseq_model.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/models/__pycache__/fconv.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/models/__pycache__/fconv_lm.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/models/__pycache__/fconv_self_att.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/models/__pycache__/lightconv.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/models/__pycache__/lightconv_lm.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/models/__pycache__/lstm.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/models/__pycache__/lstm_lm.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/models/__pycache__/masked_lm.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/models/__pycache__/model_utils.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/models/__pycache__/multilingual_transformer.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/models/__pycache__/transformer.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/models/__pycache__/transformer_lm.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/models/bart/__init__.py +7 -0
- fairseq-0.10.2/fairseq/models/bart/hub_interface.py +201 -0
- fairseq-0.10.2/fairseq/models/distributed_fairseq_model.py +103 -0
- fairseq-0.10.2/fairseq/models/fairseq_decoder.py +90 -0
- fairseq-0.10.2/fairseq/models/fairseq_encoder.py +92 -0
- fairseq-0.10.2/fairseq/models/fairseq_model.py +556 -0
- fairseq-0.10.2/fairseq/models/huggingface/hf_gpt2.py +168 -0
- fairseq-0.10.2/fairseq/models/lightconv_lm.py +306 -0
- fairseq-0.10.2/fairseq/models/masked_lm.py +403 -0
- fairseq-0.10.2/fairseq/models/model_utils.py +92 -0
- fairseq-0.10.2/fairseq/models/multilingual_transformer.py +228 -0
- fairseq-0.10.2/fairseq/models/roberta/__pycache__/hub_interface.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/models/roberta/__pycache__/model_camembert.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/models/roberta/__pycache__/model_xlmr.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/models/roberta/model_xlmr.py +44 -0
- fairseq-0.10.2/fairseq/modules/__pycache__/beamable_mm.cpython-310.pyc +0 -0
fairseq-0.10.2/fairseq/criterions/__init__.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
"""isort:skip_file"""
|
| 6 |
+
|
| 7 |
+
import importlib
|
| 8 |
+
import os
|
| 9 |
+
from argparse import Namespace
|
| 10 |
+
from typing import Union
|
| 11 |
+
|
| 12 |
+
from fairseq import registry
|
| 13 |
+
from fairseq.criterions.fairseq_criterion import ( # noqa
|
| 14 |
+
FairseqCriterion,
|
| 15 |
+
LegacyFairseqCriterion,
|
| 16 |
+
)
|
| 17 |
+
from omegaconf import DictConfig
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
# Set up the criterion registry keyed off the `--criterion` CLI flag.
# registry.setup_registry returns, in order:
#   - build_criterion_: internal factory (wrapped by build_criterion below),
#   - register_criterion: decorator used by the criterion modules in this package,
#   - CRITERION_REGISTRY / CRITERION_DATACLASS_REGISTRY: name -> class/dataclass maps.
(
    build_criterion_,
    register_criterion,
    CRITERION_REGISTRY,
    CRITERION_DATACLASS_REGISTRY,
) = registry.setup_registry(
    "--criterion", base_class=FairseqCriterion, default="cross_entropy"
)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def build_criterion(criterion_cfg: Union[DictConfig, Namespace], task):
    """Build a criterion instance from a config (DictConfig or legacy
    argparse Namespace) and the given task, by delegating to the
    registry-provided factory."""
    return build_criterion_(criterion_cfg, task)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
# automatically import any Python files in the criterions/ directory
# (skipping private/underscore-prefixed files such as __init__.py).
# Importing each module triggers its @register_criterion decorators,
# which is how the registry above gets populated at package-load time.
for file in os.listdir(os.path.dirname(__file__)):
    if file.endswith(".py") and not file.startswith("_"):
        # NOTE: uses the FIRST ".py" occurrence, not a suffix strip.
        file_name = file[: file.find(".py")]
        importlib.import_module("fairseq.criterions." + file_name)
|
fairseq-0.10.2/fairseq/criterions/__pycache__/ctc.cpython-310.pyc
ADDED
|
Binary file (7.02 kB). View file
|
|
|
fairseq-0.10.2/fairseq/criterions/__pycache__/fairseq_criterion.cpython-310.pyc
ADDED
|
Binary file (4.36 kB). View file
|
|
|
fairseq-0.10.2/fairseq/criterions/__pycache__/legacy_masked_lm.cpython-310.pyc
ADDED
|
Binary file (5.46 kB). View file
|
|
|
fairseq-0.10.2/fairseq/criterions/__pycache__/sentence_prediction.cpython-310.pyc
ADDED
|
Binary file (3.88 kB). View file
|
|
|
fairseq-0.10.2/fairseq/criterions/adaptive_loss.py
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import math
|
| 7 |
+
from dataclasses import dataclass
|
| 8 |
+
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
from fairseq import metrics, utils
|
| 11 |
+
from fairseq.criterions import FairseqCriterion, register_criterion
|
| 12 |
+
from fairseq.dataclass import FairseqDataclass
|
| 13 |
+
from fairseq.dataclass.constants import DDP_BACKEND_CHOICES
|
| 14 |
+
from omegaconf import II
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
@dataclass
class AdaptiveLossConfig(FairseqDataclass):
    """Config for AdaptiveLoss; both fields are interpolated (omegaconf II)
    from other parts of the top-level config rather than set directly."""

    # pulled from the optimization section of the global config
    sentence_avg: bool = II("params.optimization.sentence_avg")
    # pulled from the distributed-training section; checked in
    # AdaptiveLoss.build_criterion to reject the c10d backend
    ddp_backend: DDP_BACKEND_CHOICES = II("params.distributed_training.ddp_backend")
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
@register_criterion("adaptive_loss", dataclass=AdaptiveLossConfig)
class AdaptiveLoss(FairseqCriterion):
    """This is an implementation of the loss function accompanying the adaptive softmax approximation for
    graphical processing units (GPU), described in the paper "Efficient softmax approximation for GPUs"
    (http://arxiv.org/abs/1609.04309)."""

    def __init__(self, task, sentence_avg):
        super().__init__(task)
        # if True, normalize the gradient by number of sentences instead of tokens
        self.sentence_avg = sentence_avg

    @classmethod
    def build_criterion(cls, args, task):
        # the c10d DDP backend cannot handle the unused-parameter pattern
        # produced by adaptive softmax, so reject it up front
        if getattr(args, "ddp_backend", None) == "c10d":
            raise Exception(
                "AdaptiveLoss is not compatible with the c10d "
                "version of DistributedDataParallel. Please use "
                "`--ddp-backend=no_c10d` instead."
            )
        return cls(task, args.sentence_avg)

    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """

        # this criterion only makes sense for decoders with adaptive softmax
        assert (
            hasattr(model.decoder, "adaptive_softmax")
            and model.decoder.adaptive_softmax is not None
        )
        adaptive_softmax = model.decoder.adaptive_softmax

        net_output = model(**sample["net_input"])
        orig_target = model.get_targets(sample, net_output)

        nsentences = orig_target.size(0)
        orig_target = orig_target.view(-1)

        bsz = orig_target.size(0)

        # adaptive_softmax splits the batch into per-cluster (head/tail)
        # logits with remapped, cluster-relative target indices
        logits, target = adaptive_softmax(net_output[0], orig_target)
        assert len(target) == len(logits)

        loss = net_output[0].new(1 if reduce else bsz).zero_()

        for i in range(len(target)):
            if target[i] is not None:
                # valid class indices for cross_entropy are [0, C); the
                # original check used `<=`, which failed to reject an
                # out-of-range index exactly equal to logits[i].size(1)
                assert target[i].min() >= 0 and target[i].max() < logits[i].size(1)
                loss += F.cross_entropy(
                    logits[i],
                    target[i],
                    ignore_index=self.padding_idx,
                    reduction="sum" if reduce else "none",
                )

        # token count excludes padding positions
        orig = utils.strip_pad(orig_target, self.padding_idx)
        ntokens = orig.numel()
        sample_size = sample["target"].size(0) if self.sentence_avg else ntokens
        logging_output = {
            "loss": loss.data,
            "ntokens": ntokens,
            "nsentences": nsentences,
            "sample_size": sample_size,
        }
        return loss, sample_size, logging_output

    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training."""
        loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs))
        ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs))
        sample_size = utils.item(
            sum(log.get("sample_size", 0) for log in logging_outputs)
        )

        # losses are logged in base-2 (bits) via the log(2) division
        metrics.log_scalar(
            "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
        )
        if sample_size != ntokens:
            # sentence-averaged training: report a separate per-token nll
            metrics.log_scalar(
                "nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
            )
            metrics.log_derived(
                "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
            )
        else:
            metrics.log_derived(
                "ppl", lambda meters: utils.get_perplexity(meters["loss"].avg)
            )

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True will improves distributed training speed.
        """
        return True
|
fairseq-0.10.2/fairseq/criterions/composite_loss.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
from fairseq import utils
|
| 7 |
+
from fairseq.criterions import LegacyFairseqCriterion, register_criterion
|
| 8 |
+
from torch import nn
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@register_criterion("composite_loss")
class CompositeLoss(LegacyFairseqCriterion):
    """This is a composite loss that, given a list of model outputs and a list of targets,
    computes an average of losses for each output-target pair"""

    def __init__(self, args, task):
        super().__init__(args, task)
        # name of the wrapped criterion (string from --underlying-criterion)
        self.underlying_criterion = args.underlying_criterion

    @staticmethod
    def add_args(parser):
        """Add criterion-specific arguments to the parser."""
        # fmt: off
        parser.add_argument('--underlying-criterion', type=str, metavar='VAL', required=True,
                            help='underlying criterion to use for the composite loss')
        # fmt: on

    @staticmethod
    def build_underlying_criterion(args, task):
        # temporarily swap args.criterion so task.build_criterion builds the
        # underlying criterion instead of recursing into composite_loss;
        # restored before returning
        saved_criterion = args.criterion
        args.criterion = args.underlying_criterion
        assert saved_criterion != args.underlying_criterion
        underlying_criterion = task.build_criterion(args)
        args.criterion = saved_criterion
        return underlying_criterion

    @classmethod
    def build_criterion(cls, args, task):
        underlying_criterion = CompositeLoss.build_underlying_criterion(args, task)

        class FakeModel(nn.Module):
            """Thin shim presenting a single (output, target) pair to the
            underlying criterion while delegating probability normalization
            and the decoder attribute to the real model."""

            def __init__(self, model, net_out, target):
                super().__init__()
                self.model = model
                self.net_out = net_out
                self.target = target

            def forward(self, **unused):
                # return the pre-computed output instead of re-running the model
                return self.net_out

            def get_normalized_probs(self, net_output, log_probs, sample=None):
                return self.model.get_normalized_probs(
                    net_output, log_probs, sample=sample
                )

            def get_targets(self, *unused):
                return self.target

            @property
            def decoder(self):
                return self.model.decoder

        class _CompositeLoss(LegacyFairseqCriterion):
            def __init__(self, args, task, underlying_criterion):
                super().__init__(args, task)
                self.underlying_criterion = underlying_criterion

            def forward(self, model, sample, reduce=True):
                # net_outputs[0] is expected to be a list of outputs (one per
                # target); net_outputs[1] is shared extra state passed to each
                # FakeModel. -- inferred from the zip below; confirm with callers
                net_outputs = model(**sample["net_input"])
                targets = sample["target"]

                bsz = targets[0].size(0)
                loss = net_outputs[0][0].new(1 if reduce else bsz).float().zero_()

                sample_size = 0
                logging_output = {}
                for o, t in zip(net_outputs[0], targets):
                    m = FakeModel(model, (o, net_outputs[1]), t)
                    # NOTE(review): sample["target"] is overwritten here and
                    # never restored, so after forward() it holds the LAST
                    # per-pair target rather than the original list — verify
                    # callers do not reuse the sample afterwards
                    sample["target"] = t
                    l, ss, logging_output = self.underlying_criterion(m, sample, reduce)
                    loss += l
                    sample_size += ss

                # average the accumulated loss/size over the number of pairs
                loss.div_(len(targets))
                sample_size /= len(targets)

                logging_output["loss"] = utils.item(loss.data) if reduce else loss.data
                return loss, sample_size, logging_output

            # Both staticmethods below close over `underlying_criterion`
            # from the enclosing build_criterion scope (they do not use
            # self.underlying_criterion).
            @staticmethod
            def aggregate_logging_outputs(logging_outputs):
                return underlying_criterion.__class__.aggregate_logging_outputs(
                    logging_outputs
                )

            @staticmethod
            def reduce_metrics(logging_outputs) -> None:
                underlying_criterion.__class__.reduce_metrics(logging_outputs)

        return _CompositeLoss(args, task, underlying_criterion)
|
fairseq-0.10.2/fairseq/criterions/ctc.py
ADDED
|
@@ -0,0 +1,253 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the license found in the LICENSE file in
|
| 4 |
+
# the root directory of this source tree. An additional grant of patent rights
|
| 5 |
+
# can be found in the PATENTS file in the same directory.
|
| 6 |
+
|
| 7 |
+
import math
|
| 8 |
+
from argparse import Namespace
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
from fairseq import metrics, utils
|
| 13 |
+
from fairseq.criterions import FairseqCriterion, register_criterion
|
| 14 |
+
from fairseq.data.data_utils import post_process
|
| 15 |
+
from fairseq.logging.meters import safe_round
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@register_criterion("ctc")
class CtcCriterion(FairseqCriterion):
    """CTC loss for speech recognition, with optional KenLM-based WER
    computation on the validation set."""

    def __init__(self, task, wer_args, zero_infinity, sentence_avg, remove_bpe):
        super().__init__(task)
        # NOTE(review): the dictionary's bos symbol is repurposed as the CTC
        # blank index — assumes the target dictionary reserves bos for blank;
        # confirm against the task's dictionary construction
        self.blank_idx = task.target_dictionary.bos()
        self.pad_idx = task.target_dictionary.pad()
        self.eos_idx = task.target_dictionary.eos()
        # post-processing mode for scoring (e.g. "letter"); falls back to
        # "letter" when --remove-bpe/--post-process is unset or empty
        self.post_process = remove_bpe if remove_bpe else "letter"

        if wer_args is not None:
            from examples.speech_recognition.w2l_decoder import W2lKenLMDecoder

            # NOTE(review): eval() on a CLI-provided string — expects a
            # 4-tuple literal "(lm_path, lexicon_path, lm_weight, word_score)";
            # unsafe on untrusted input, consider ast.literal_eval
            wer_compute_kenlm, wer_lexicon, lm_w, ws_w = eval(wer_args)

            # assemble decoder config for the KenLM beam-search decoder
            dec_args = Namespace()
            dec_args.nbest = 1
            dec_args.criterion = "ctc"
            dec_args.kenlm_model = wer_compute_kenlm
            dec_args.lexicon = wer_lexicon
            dec_args.beam = 50
            dec_args.beam_size_token = min(50, len(task.target_dictionary))
            dec_args.beam_threshold = min(50, len(task.target_dictionary))
            dec_args.lm_weight = lm_w
            dec_args.word_score = ws_w
            dec_args.unk_weight = -math.inf
            dec_args.sil_weight = 0

            self.w2l_decoder = W2lKenLMDecoder(dec_args, task.target_dictionary)
        else:
            self.w2l_decoder = None

        self.zero_infinity = zero_infinity
        self.sentence_avg = sentence_avg

    @staticmethod
    def add_args(parser):
        """Add criterion-specific arguments to the parser."""
        parser.add_argument(
            "--zero-infinity", action="store_true", help="zero inf loss"
        )
        try:
            parser.add_argument(
                "--remove-bpe",
                "--post-process",
                default="letter",
                help="remove BPE tokens before scoring (can be set to sentencepiece, letter, and more)",
            )
        # NOTE(review): bare except — intended to swallow the argparse
        # conflict when the option was already added from eval args, but it
        # also hides any other error; argparse.ArgumentError would be narrower
        except:
            pass  # this option might have been added from eval args
        parser.add_argument(
            "--wer-args",
            type=str,
            default=None,
            help="options for wer computation on valid set using 4 gram lm. this should be a tuple of 4 elements: path to 4-gram lm, \
            path to lexicon, lm score, word score",
        )

    def forward(self, model, sample, reduce=True):
        """Compute the CTC loss; in eval mode additionally accumulate
        character/word edit distances for UER/WER logging."""
        net_output = model(**sample["net_input"])
        lprobs = model.get_normalized_probs(
            net_output, log_probs=True
        ).contiguous()  # (T, B, C) from the encoder

        if "src_lengths" in sample["net_input"]:
            input_lengths = sample["net_input"]["src_lengths"]
        else:
            # derive per-utterance input lengths from the encoder padding mask
            non_padding_mask = ~net_output["padding_mask"]
            input_lengths = non_padding_mask.long().sum(-1)

        # flatten targets, dropping pad and eos positions, as F.ctc_loss
        # expects a 1-D concatenated target tensor
        pad_mask = (sample["target"] != self.pad_idx) & (
            sample["target"] != self.eos_idx
        )
        targets_flat = sample["target"].masked_select(pad_mask)
        target_lengths = sample["target_lengths"]

        # cudnn's CTC path is disabled here; only the native implementation
        # is used under this flags() context
        with torch.backends.cudnn.flags(enabled=False):
            loss = F.ctc_loss(
                lprobs,
                targets_flat,
                input_lengths,
                target_lengths,
                blank=self.blank_idx,
                reduction="sum",
                zero_infinity=self.zero_infinity,
            )

        ntokens = (
            sample["ntokens"] if "ntokens" in sample else target_lengths.sum().item()
        )

        sample_size = sample["target"].size(0) if self.sentence_avg else ntokens
        logging_output = {
            "loss": utils.item(loss.data),  # * sample['ntokens'],
            "ntokens": ntokens,
            "nsentences": sample["id"].numel(),
            "sample_size": sample_size,
        }

        if not model.training:
            import editdistance

            with torch.no_grad():
                # (B, T, C) on CPU for per-utterance greedy/beam decoding
                lprobs_t = lprobs.transpose(0, 1).float().contiguous().cpu()

                c_err = 0
                c_len = 0
                w_errs = 0
                w_len = 0
                wv_errs = 0
                for lp, t, inp_l in zip(
                    lprobs_t,
                    sample["target_label"]
                    if "target_label" in sample
                    else sample["target"],
                    input_lengths,
                ):
                    # trim padding frames for this utterance
                    lp = lp[:inp_l].unsqueeze(0)

                    decoded = None
                    if self.w2l_decoder is not None:
                        # unwrap nbest list-of-lists; empty results -> None
                        decoded = self.w2l_decoder.decode(lp)
                        if len(decoded) < 1:
                            decoded = None
                        else:
                            decoded = decoded[0]
                            if len(decoded) < 1:
                                decoded = None
                            else:
                                decoded = decoded[0]

                    # reference units with pad/eos removed
                    p = (t != self.task.target_dictionary.pad()) & (
                        t != self.task.target_dictionary.eos()
                    )
                    targ = t[p]
                    targ_units = self.task.target_dictionary.string(targ)
                    targ_units_arr = targ.tolist()

                    # greedy CTC collapse: argmax, merge repeats, drop blanks
                    toks = lp.argmax(dim=-1).unique_consecutive()
                    pred_units_arr = toks[toks != self.blank_idx].tolist()

                    c_err += editdistance.eval(pred_units_arr, targ_units_arr)
                    c_len += len(targ_units_arr)

                    targ_words = post_process(targ_units, self.post_process).split()

                    pred_units = self.task.target_dictionary.string(pred_units_arr)
                    pred_words_raw = post_process(pred_units, self.post_process).split()

                    if decoded is not None and "words" in decoded:
                        # LM-decoded words for WER; raw greedy words for raw WER
                        pred_words = decoded["words"]
                        w_errs += editdistance.eval(pred_words, targ_words)
                        wv_errs += editdistance.eval(pred_words_raw, targ_words)
                    else:
                        # no LM decode available: both WER variants use greedy
                        dist = editdistance.eval(pred_words_raw, targ_words)
                        w_errs += dist
                        wv_errs += dist

                    w_len += len(targ_words)

                logging_output["wv_errors"] = wv_errs
                logging_output["w_errors"] = w_errs
                logging_output["w_total"] = w_len
                logging_output["c_errors"] = c_err
                logging_output["c_total"] = c_len

        return loss, sample_size, logging_output

    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training."""

        loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs))
        ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs))
        nsentences = utils.item(
            sum(log.get("nsentences", 0) for log in logging_outputs)
        )
        sample_size = utils.item(
            sum(log.get("sample_size", 0) for log in logging_outputs)
        )

        # loss reported in base 2 (bits) via the log(2) division
        metrics.log_scalar(
            "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
        )
        metrics.log_scalar("ntokens", ntokens)
        metrics.log_scalar("nsentences", nsentences)
        if sample_size != ntokens:
            metrics.log_scalar(
                "nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
            )

        # underscore-prefixed scalars are logged only to feed the derived
        # uer/wer/raw_wer metrics below
        c_errors = sum(log.get("c_errors", 0) for log in logging_outputs)
        metrics.log_scalar("_c_errors", c_errors)
        c_total = sum(log.get("c_total", 0) for log in logging_outputs)
        metrics.log_scalar("_c_total", c_total)
        w_errors = sum(log.get("w_errors", 0) for log in logging_outputs)
        metrics.log_scalar("_w_errors", w_errors)
        wv_errors = sum(log.get("wv_errors", 0) for log in logging_outputs)
        metrics.log_scalar("_wv_errors", wv_errors)
        w_total = sum(log.get("w_total", 0) for log in logging_outputs)
        metrics.log_scalar("_w_total", w_total)

        if c_total > 0:
            metrics.log_derived(
                "uer",
                lambda meters: safe_round(
                    meters["_c_errors"].sum * 100.0 / meters["_c_total"].sum, 3
                )
                if meters["_c_total"].sum > 0
                else float("nan"),
            )
        if w_total > 0:
            metrics.log_derived(
                "wer",
                lambda meters: safe_round(
                    meters["_w_errors"].sum * 100.0 / meters["_w_total"].sum, 3
                )
                if meters["_w_total"].sum > 0
                else float("nan"),
            )
            metrics.log_derived(
                "raw_wer",
                lambda meters: safe_round(
                    meters["_wv_errors"].sum * 100.0 / meters["_w_total"].sum, 3
                )
                if meters["_w_total"].sum > 0
                else float("nan"),
            )

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True will improves distributed training speed.
        """
        return True
|
fairseq-0.10.2/fairseq/criterions/fairseq_criterion.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import inspect
|
| 7 |
+
from typing import Any, Dict, List
|
| 8 |
+
|
| 9 |
+
from fairseq import metrics, utils
|
| 10 |
+
from fairseq.dataclass.utils import gen_parser_from_dataclass
|
| 11 |
+
from torch.nn.modules.loss import _Loss
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class FairseqCriterion(_Loss):
|
| 15 |
+
def __init__(self, task):
|
| 16 |
+
super().__init__()
|
| 17 |
+
self.task = task
|
| 18 |
+
if hasattr(task, "target_dictionary"):
|
| 19 |
+
tgt_dict = task.target_dictionary
|
| 20 |
+
self.padding_idx = tgt_dict.pad() if tgt_dict is not None else -100
|
| 21 |
+
|
| 22 |
+
@classmethod
|
| 23 |
+
def add_args(cls, parser):
|
| 24 |
+
"""Add criterion-specific arguments to the parser."""
|
| 25 |
+
dc = getattr(cls, "__dataclass", None)
|
| 26 |
+
if dc is not None:
|
| 27 |
+
gen_parser_from_dataclass(parser, dc())
|
| 28 |
+
|
| 29 |
+
@classmethod
|
| 30 |
+
def build_criterion(cls, args, task):
|
| 31 |
+
"""Construct a criterion from command-line args."""
|
| 32 |
+
# Criterions can override this, but for convenience we also try
|
| 33 |
+
# to automatically map argparse.Namespace keys to corresponding
|
| 34 |
+
# arguments in the __init__.
|
| 35 |
+
init_args = {}
|
| 36 |
+
for p in inspect.signature(cls).parameters.values():
|
| 37 |
+
if (
|
| 38 |
+
p.kind == p.POSITIONAL_ONLY
|
| 39 |
+
or p.kind == p.VAR_POSITIONAL
|
| 40 |
+
or p.kind == p.VAR_KEYWORD
|
| 41 |
+
):
|
| 42 |
+
# we haven't implemented inference for these argument types,
|
| 43 |
+
# but PRs welcome :)
|
| 44 |
+
raise NotImplementedError("{} not supported".format(p.kind))
|
| 45 |
+
|
| 46 |
+
assert p.kind in {p.POSITIONAL_OR_KEYWORD, p.KEYWORD_ONLY}
|
| 47 |
+
|
| 48 |
+
if p.name == "task":
|
| 49 |
+
init_args["task"] = task
|
| 50 |
+
elif hasattr(args, p.name):
|
| 51 |
+
init_args[p.name] = getattr(args, p.name)
|
| 52 |
+
elif p.default != p.empty:
|
| 53 |
+
pass # we'll use the default value
|
| 54 |
+
else:
|
| 55 |
+
raise NotImplementedError(
|
| 56 |
+
"Unable to infer Criterion arguments, please implement "
|
| 57 |
+
"{}.build_criterion".format(cls.__name__)
|
| 58 |
+
)
|
| 59 |
+
return cls(**init_args)
|
| 60 |
+
|
| 61 |
+
def forward(self, model, sample, reduce=True):
|
| 62 |
+
"""Compute the loss for the given sample.
|
| 63 |
+
|
| 64 |
+
Returns a tuple with three elements:
|
| 65 |
+
1) the loss
|
| 66 |
+
2) the sample size, which is used as the denominator for the gradient
|
| 67 |
+
3) logging outputs to display while training
|
| 68 |
+
"""
|
| 69 |
+
raise NotImplementedError
|
| 70 |
+
|
| 71 |
+
@staticmethod
|
| 72 |
+
def aggregate_logging_outputs(
|
| 73 |
+
logging_outputs: List[Dict[str, Any]],
|
| 74 |
+
) -> Dict[str, Any]:
|
| 75 |
+
"""Aggregate logging outputs from data parallel training."""
|
| 76 |
+
utils.deprecation_warning(
|
| 77 |
+
"The aggregate_logging_outputs API is deprecated. "
|
| 78 |
+
"Please use the reduce_metrics API instead."
|
| 79 |
+
)
|
| 80 |
+
raise NotImplementedError
|
| 81 |
+
|
| 82 |
+
@classmethod
|
| 83 |
+
def reduce_metrics(cls, logging_outputs: List[Dict[str, Any]]) -> None:
|
| 84 |
+
"""Aggregate logging outputs from data parallel training."""
|
| 85 |
+
utils.deprecation_warning(
|
| 86 |
+
"Criterions should implement the reduce_metrics API. "
|
| 87 |
+
"Falling back to deprecated aggregate_logging_outputs API."
|
| 88 |
+
)
|
| 89 |
+
agg_logging_outputs = cls.aggregate_logging_outputs(logging_outputs)
|
| 90 |
+
for k, v in agg_logging_outputs.items():
|
| 91 |
+
if k in {"nsentences", "ntokens", "sample_size"}:
|
| 92 |
+
continue
|
| 93 |
+
metrics.log_scalar(k, v)
|
| 94 |
+
|
| 95 |
+
    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True will improve distributed training speed.
        """
        # Conservative default: subclasses opt in explicitly.
        return False
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
class LegacyFairseqCriterion(FairseqCriterion):
    """Backwards-compatibility base for criterions that still take a whole
    ``argparse.Namespace`` instead of explicit constructor arguments."""

    def __init__(self, args, task):
        super().__init__(task=task)
        # Kept for subclasses that read options off the namespace directly.
        self.args = args

        utils.deprecation_warning(
            "Criterions should take explicit arguments instead of an "
            "argparse.Namespace object, please update your criterion by "
            "extending FairseqCriterion instead of LegacyFairseqCriterion."
        )

    @classmethod
    def build_criterion(cls, args, task):
        """Construct a criterion from command-line args."""
        # Legacy path: pass the raw namespace straight through.
        return cls(args, task)
|
fairseq-0.10.2/fairseq/criterions/legacy_masked_lm.py
ADDED
|
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import math
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
from fairseq import metrics, utils
|
| 11 |
+
from fairseq.criterions import FairseqCriterion, register_criterion
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def compute_cross_entropy_loss(logits, targets, ignore_index=-100):
    """Sum-reduced cross entropy between ``logits`` and integer ``targets``.

    The default value of ``ignore_index`` is the same as the default for
    ``F.cross_entropy`` in pytorch; positions whose target equals it
    contribute nothing to the sum.
    """
    assert logits.size(0) == targets.size(
        -1
    ), "Logits and Targets tensor shapes don't match up"

    # Log-softmax in float32 for numerical stability, then summed NLL.
    log_probs = F.log_softmax(logits, -1, dtype=torch.float32)
    return F.nll_loss(
        log_probs,
        targets,
        reduction="sum",
        ignore_index=ignore_index,
    )
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
@register_criterion("legacy_masked_lm_loss")
class LegacyMaskedLmLoss(FairseqCriterion):
    """
    Implementation for the loss used in masked language model (MLM) training.
    This optionally also computes the next sentence prediction (NSP) loss and
    adds it to the overall loss based on the specified args. There are three
    cases to consider:
    1) Generic MLM training without NSP loss. In this case sentence_targets
       and sentence_logits are both None.
    2) BERT training without NSP loss. In this case sentence_targets is
       not None but sentence_logits is None and we should not be computing
       a sentence level loss.
    3) BERT training with NSP loss. In this case both sentence_targets and
       sentence_logits are not None and we should be computing a sentence
       level loss. The weight of the sentence level loss is specified as
       an argument.
    """

    def __init__(self, task, masked_lm_only, nsp_loss_weight):
        super().__init__(task)
        # When True the NSP term is skipped entirely.
        self.masked_lm_only = masked_lm_only
        # Multiplier applied to the per-sentence NSP loss before adding it
        # to the per-token MLM loss.
        self.nsp_loss_weight = nsp_loss_weight

    @staticmethod
    def add_args(parser):
        """Args for MaskedLM Loss"""
        # Default for masked_lm_only is False so as to not break BERT training
        parser.add_argument(
            "--masked-lm-only",
            default=False,
            action="store_true",
            help="compute MLM loss only",
        )
        parser.add_argument(
            "--nsp-loss-weight",
            default=1.0,
            type=float,
            help="weight for next sentence prediction" " loss (default 1)",
        )

    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.
        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        lm_logits, output_metadata = model(**sample["net_input"])

        # reshape lm_logits from (N,T,C) to (N*T,C)
        lm_logits = lm_logits.view(-1, lm_logits.size(-1))
        lm_targets = sample["lm_target"].view(-1)
        # padding_idx comes from the FairseqCriterion base (task dictionary).
        lm_loss = compute_cross_entropy_loss(lm_logits, lm_targets, self.padding_idx)

        # compute the number of tokens for which loss is computed. This is used
        # to normalize the loss
        ntokens = utils.strip_pad(lm_targets, self.padding_idx).numel()
        loss = lm_loss / ntokens
        nsentences = sample["nsentences"]
        # nsentences = 0

        # Compute sentence loss if masked_lm_only is False
        sentence_loss = None
        if not self.masked_lm_only:
            sentence_logits = output_metadata["sentence_logits"]
            sentence_targets = sample["sentence_target"].view(-1)
            # This needs to be recomputed due to some differences between
            # TokenBlock and BlockPair dataset. This can be resolved with a
            # refactor of BERTModel which we will do in the future.
            # TODO: Remove this after refactor of BERTModel
            nsentences = sentence_targets.size(0)

            # Check for logits being none which can happen when remove_heads
            # is set to true in the BERT model. Ideally we should set
            # masked_lm_only to true in this case, but that requires some
            # refactor in the BERT model.
            if sentence_logits is not None:
                sentence_loss = compute_cross_entropy_loss(
                    sentence_logits, sentence_targets
                )

                loss += self.nsp_loss_weight * (sentence_loss / nsentences)

        # NOTE: as we are summing up per token mlm loss and per sentence nsp loss
        # we don't need to use sample_size as denominator for the gradient
        # here sample_size is just used for logging
        sample_size = 1
        logging_output = {
            "loss": utils.item(loss.data) if reduce else loss.data,
            "lm_loss": utils.item(lm_loss.data) if reduce else lm_loss.data,
            # sentence loss is not always computed
            "sentence_loss": (
                (utils.item(sentence_loss.data) if reduce else sentence_loss.data)
                if sentence_loss is not None
                else 0.0
            ),
            "ntokens": ntokens,
            "nsentences": nsentences,
            "sample_size": sample_size,
        }
        return loss, sample_size, logging_output

    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training."""
        lm_loss_sum = sum(log.get("lm_loss", 0) for log in logging_outputs)
        sentence_loss_sum = sum(log.get("sentence_loss", 0) for log in logging_outputs)
        ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
        nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
        agg_loss = sum(log.get("loss", 0) for log in logging_outputs)

        # Dividing by math.log(2) reports losses in bits (base 2).
        metrics.log_scalar(
            "loss",
            agg_loss / sample_size / math.log(2) if sample_size > 0 else 0.0,
            sample_size,
            round=3,
        )
        metrics.log_scalar(
            "lm_loss",
            lm_loss_sum / ntokens / math.log(2) if ntokens > 0 else 0.0,
            ntokens,
            round=3,
        )
        metrics.log_scalar(
            "sentence_loss",
            sentence_loss_sum / nsentences / math.log(2) if nsentences > 0 else 0.0,
            nsentences,
            round=3,
        )
        # nll_loss mirrors lm_loss so downstream tooling that expects a
        # "nll_loss" key (e.g. perplexity reporting) keeps working.
        metrics.log_scalar(
            "nll_loss",
            lm_loss_sum / ntokens / math.log(2) if ntokens > 0 else 0.0,
            ntokens,
            round=3,
        )

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True will improve distributed training speed.
        """
        return True
|
fairseq-0.10.2/fairseq/criterions/sentence_prediction.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import math
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
from fairseq import metrics, utils
|
| 11 |
+
from fairseq.criterions import FairseqCriterion, register_criterion
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@register_criterion("sentence_prediction")
class SentencePredictionCriterion(FairseqCriterion):
    """Sentence-level classification or regression loss.

    Scores the sample through the named classification head and uses
    summed NLL for classification, or summed MSE when
    ``regression_target`` is set.
    """

    def __init__(self, task, classification_head_name, regression_target):
        super().__init__(task)
        # Name of the head on the model to run (must exist on the model).
        self.classification_head_name = classification_head_name
        # When True, targets are treated as continuous and MSE is used.
        self.regression_target = regression_target

    @staticmethod
    def add_args(parser):
        # fmt: off
        parser.add_argument('--classification-head-name',
                            default='sentence_classification_head',
                            help='name of the classification head to use')
        # fmt: on

    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        assert (
            hasattr(model, "classification_heads")
            and self.classification_head_name in model.classification_heads
        ), "model must provide sentence classification head for --criterion=sentence_prediction"

        logits, _ = model(
            **sample["net_input"],
            features_only=True,
            classification_head_name=self.classification_head_name,
        )
        targets = model.get_targets(sample, [logits]).view(-1)
        # One target per sentence, so sample_size doubles as nsentences.
        sample_size = targets.numel()

        if not self.regression_target:
            # Classification: summed NLL over float32 log-softmax.
            lprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
            loss = F.nll_loss(lprobs, targets, reduction="sum")
        else:
            # Regression: summed MSE against float targets.
            logits = logits.view(-1).float()
            targets = targets.float()
            loss = F.mse_loss(logits, targets, reduction="sum")

        logging_output = {
            "loss": loss.data,
            "ntokens": sample["ntokens"],
            "nsentences": sample_size,
            "sample_size": sample_size,
        }
        if not self.regression_target:
            # Accuracy is only meaningful for classification targets.
            preds = logits.argmax(dim=1)
            logging_output["ncorrect"] = (preds == targets).sum()

        return loss, sample_size, logging_output

    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training."""
        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
        ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
        nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)

        # Dividing by math.log(2) reports the loss in bits (base 2).
        metrics.log_scalar(
            "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
        )
        if sample_size != ntokens:
            metrics.log_scalar(
                "nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
            )

        # "ncorrect" is only present for classification runs.
        if len(logging_outputs) > 0 and "ncorrect" in logging_outputs[0]:
            ncorrect = sum(log.get("ncorrect", 0) for log in logging_outputs)
            metrics.log_scalar(
                "accuracy", 100.0 * ncorrect / nsentences, nsentences, round=1
            )

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True will improve distributed training speed.
        """
        return True
|
fairseq-0.10.2/fairseq/criterions/sentence_ranking.py
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import math
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
from fairseq import metrics, utils
|
| 11 |
+
from fairseq.criterions import FairseqCriterion, register_criterion
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@register_criterion("sentence_ranking")
class SentenceRankingCriterion(FairseqCriterion):
    """Ranking loss over ``num_classes`` candidate inputs.

    Each candidate is scored independently through the ranking head; the
    concatenated scores act as logits over candidates and are trained with
    summed NLL against the index of the correct candidate.
    """

    def __init__(self, task, ranking_head_name, save_predictions, num_classes):
        super().__init__(task)
        # Name of the classification head used to score each candidate.
        self.ranking_head_name = ranking_head_name
        if save_predictions is not None:
            # Handle for dumping per-example predictions; closed in __del__.
            self.prediction_h = open(save_predictions, "w")
        else:
            self.prediction_h = None
        self.num_classes = num_classes

    def __del__(self):
        # Best-effort close of the predictions file.
        if self.prediction_h is not None:
            self.prediction_h.close()

    @staticmethod
    def add_args(parser):
        # fmt: off
        parser.add_argument('--save-predictions', metavar='FILE',
                            help='file to save predictions to')
        parser.add_argument('--ranking-head-name',
                            default='sentence_classification_head',
                            help='name of the ranking head to use')
        # fmt: on

    def forward(self, model, sample, reduce=True):
        """Compute ranking loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        assert (
            hasattr(model, "classification_heads")
            and self.ranking_head_name in model.classification_heads
        ), "model must provide sentence ranking head for --criterion=sentence_ranking"

        scores = []
        # Candidates are stored under keys "net_input1" .. "net_input<num_classes>".
        for idx in range(self.num_classes):
            score, _ = model(
                **sample["net_input{idx}".format(idx=idx + 1)],
                classification_head_name=self.ranking_head_name,
            )
            scores.append(score)

        logits = torch.cat(scores, dim=1)
        sample_size = logits.size(0)

        if "target" in sample:
            targets = model.get_targets(sample, [logits]).view(-1)
            lprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
            loss = F.nll_loss(lprobs, targets, reduction="sum")
        else:
            # No labels (e.g. test-set inference): zero loss that still
            # participates in autograd.
            targets = None
            loss = torch.tensor(0.0, requires_grad=True)

        if self.prediction_h is not None:
            preds = logits.argmax(dim=1)
            for i, (id, pred) in enumerate(zip(sample["id"].tolist(), preds.tolist())):
                if targets is not None:
                    label = targets[i].item()
                    print("{}\t{}\t{}".format(id, pred, label), file=self.prediction_h)
                else:
                    print("{}\t{}".format(id, pred), file=self.prediction_h)

        logging_output = {
            "loss": loss.data,
            "ntokens": sample["ntokens"],
            "nsentences": sample_size,
            "sample_size": sample_size,
        }
        if targets is not None:
            logging_output["ncorrect"] = (logits.argmax(dim=1) == targets).sum()

        return loss, sample_size, logging_output

    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training."""
        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
        ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
        nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)

        # Dividing by math.log(2) reports the loss in bits (base 2).
        metrics.log_scalar(
            "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
        )
        if sample_size != ntokens:
            metrics.log_scalar(
                "nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
            )

        # "ncorrect" is only present when targets were available.
        if len(logging_outputs) > 0 and "ncorrect" in logging_outputs[0]:
            ncorrect = sum(log.get("ncorrect", 0) for log in logging_outputs)
            metrics.log_scalar(
                "accuracy", 100.0 * ncorrect / nsentences, nsentences, round=1
            )

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True will improve distributed training speed.
        """
        return True
|
fairseq-0.10.2/fairseq/model_parallel/models/__pycache__/transformer.cpython-310.pyc
ADDED
|
Binary file (4.05 kB). View file
|
|
|
fairseq-0.10.2/fairseq/model_parallel/models/pipeline_parallel_transformer/model.py
ADDED
|
@@ -0,0 +1,721 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import logging
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
from fairseq import utils
|
| 12 |
+
from fairseq.model_parallel.models.pipeline_parallel_transformer.layers import (
|
| 13 |
+
Embedding,
|
| 14 |
+
TransformerDecoderEmbedding,
|
| 15 |
+
TransformerDecoderLayer,
|
| 16 |
+
TransformerDecoderOutputLayer,
|
| 17 |
+
TransformerEncoderEmbedding,
|
| 18 |
+
TransformerEncoderLayer,
|
| 19 |
+
TransformerEncoderLayerNorm,
|
| 20 |
+
)
|
| 21 |
+
from fairseq.models import (
|
| 22 |
+
BaseFairseqModel,
|
| 23 |
+
FairseqDecoder,
|
| 24 |
+
FairseqEncoder,
|
| 25 |
+
register_model,
|
| 26 |
+
register_model_architecture,
|
| 27 |
+
)
|
| 28 |
+
from fairseq.models.fairseq_encoder import EncoderOut
|
| 29 |
+
from fairseq.models.transformer import (
|
| 30 |
+
base_architecture,
|
| 31 |
+
transformer_iwslt_de_en,
|
| 32 |
+
transformer_wmt_en_de_big,
|
| 33 |
+
)
|
| 34 |
+
from fairseq.modules import SinusoidalPositionalEmbedding
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
logger = logging.getLogger(__name__)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
DEFAULT_MAX_SOURCE_POSITIONS = 1024
|
| 41 |
+
DEFAULT_MAX_TARGET_POSITIONS = 1024
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
@register_model("pipeline_parallel_transformer")
|
| 45 |
+
class PipelineParallelTransformerModel(BaseFairseqModel):
|
| 46 |
+
    def __init__(self, encoder, decoder, balance, devices, chunks, checkpoint):
        """Wrap an encoder/decoder pair in a fairscale ``Pipe``.

        The encoder and decoder are flattened into one nn.Sequential
        (embedding -> layers -> output), which Pipe then partitions across
        ``devices`` according to ``balance``.
        """
        try:
            from fairscale.nn import Pipe
        except ImportError:
            raise ImportError("Please install fairscale with: pip install fairscale")
        super().__init__()
        assert isinstance(encoder, FairseqEncoder)
        assert isinstance(decoder, FairseqDecoder)
        # Flatten the encoder into an ordered module list for Pipe.
        encoder_module_list = (
            [encoder.embedding_layer]
            + list(encoder.encoder_layers)
            + [encoder.final_layer_norm]
        )
        # Remembered so prepare_for_inference_() can split the partitions
        # back into encoder vs. decoder modules.
        self.num_encoder_modules = len(encoder_module_list)
        decoder_module_list = (
            [decoder.embedding_layer]
            + list(decoder.decoder_layers)
            + [decoder.decoder_output_layer]
        )
        self.num_decoder_modules = len(decoder_module_list)
        module_list = encoder_module_list + decoder_module_list
        self.devices = devices
        self.model = Pipe(
            nn.Sequential(*module_list),
            balance=balance,
            devices=devices,
            chunks=chunks,
            checkpoint=checkpoint,
        )
        self.encoder_max_positions = self.max_positions_helper(
            encoder.embedding_layer, "max_source_positions"
        )
        self.decoder_max_positions = self.max_positions_helper(
            decoder.embedding_layer, "max_target_positions"
        )
        self.adaptive_softmax = getattr(decoder, "adaptive_softmax", None)
        # Note: To be populated during inference
        self.encoder = None
        self.decoder = None
|
| 85 |
+
|
| 86 |
+
    def forward(self, src_tokens, src_lengths, prev_output_tokens):
        """Run the model.

        Training: moves all inputs to the first pipeline device and feeds
        them through the Pipe. Inference: requires
        ``prepare_for_inference_()`` to have rebuilt ``self.encoder`` /
        ``self.decoder`` first.
        """
        if self.training:
            input_lst = [src_tokens, src_lengths, prev_output_tokens]
            input = tuple(i.to(self.devices[0], non_blocking=True) for i in input_lst)
            return self.model(input)
        else:
            assert self.encoder is not None and self.decoder is not None, (
                "encoder and decoder need to be initialized by "
                + "calling the `prepare_for_inference_()` method"
            )
            # NOTE(review): `input` is only assigned in the training branch
            # above; in this branch the name resolves to the `input` builtin.
            # This looks like a bug — presumably src_tokens/src_lengths were
            # intended — confirm against upstream fairseq before relying on
            # the eval path.
            encoder_output_tuple = self.encoder(input)
            return self.decoder(encoder_output_tuple)
|
| 98 |
+
|
| 99 |
+
    def prepare_for_inference_(self, args):
        """Split the trained Pipe back into plain encoder/decoder modules.

        Walks the pipeline partitions in order: the first
        ``num_encoder_modules`` modules belong to the encoder, the rest to
        the decoder. Releases ``self.model`` afterwards.
        """
        if self.encoder is not None and self.decoder is not None:
            logger.info("Encoder and Decoder already initialized")
            return
        encoder_module_list = []
        decoder_module_list = []
        module_count = 0
        for partition in self.model.partitions:
            for module in partition:
                if module_count < self.num_encoder_modules:
                    encoder_module_list.append(module)
                else:
                    decoder_module_list.append(module)
                module_count += 1
        # Drop the Pipe wrapper; inference uses the rebuilt modules below.
        self.model = None
        self.encoder = TransformerEncoder(args, None, None, encoder_module_list)
        self.decoder = TransformerDecoder(
            args, None, None, decoder_module_list=decoder_module_list
        )
|
| 118 |
+
|
| 119 |
+
@staticmethod
def add_args(parser):
    """Add model-specific arguments to the parser.

    All options mirror the regular transformer architecture plus
    ``--num-embedding-chunks`` for sharded embedding layers.
    """
    # fmt: off
    parser.add_argument('--activation-fn',
                        choices=utils.get_available_activation_fns(),
                        help='activation function to use')
    parser.add_argument('--dropout', type=float, metavar='D',
                        help='dropout probability')
    parser.add_argument('--attention-dropout', type=float, metavar='D',
                        help='dropout probability for attention weights')
    parser.add_argument('--activation-dropout', '--relu-dropout', type=float, metavar='D',
                        help='dropout probability after activation in FFN.')
    parser.add_argument('--encoder-embed-path', type=str, metavar='STR',
                        help='path to pre-trained encoder embedding')
    parser.add_argument('--encoder-embed-dim', type=int, metavar='N',
                        help='encoder embedding dimension')
    parser.add_argument('--encoder-ffn-embed-dim', type=int, metavar='N',
                        help='encoder embedding dimension for FFN')
    parser.add_argument('--encoder-layers', type=int, metavar='N',
                        help='num encoder layers')
    parser.add_argument('--encoder-attention-heads', type=int, metavar='N',
                        help='num encoder attention heads')
    parser.add_argument('--encoder-normalize-before', action='store_true',
                        help='apply layernorm before each encoder block')
    parser.add_argument('--encoder-learned-pos', action='store_true',
                        help='use learned positional embeddings in the encoder')
    parser.add_argument('--decoder-embed-path', type=str, metavar='STR',
                        help='path to pre-trained decoder embedding')
    parser.add_argument('--decoder-embed-dim', type=int, metavar='N',
                        help='decoder embedding dimension')
    parser.add_argument('--decoder-ffn-embed-dim', type=int, metavar='N',
                        help='decoder embedding dimension for FFN')
    parser.add_argument('--decoder-layers', type=int, metavar='N',
                        help='num decoder layers')
    parser.add_argument('--decoder-attention-heads', type=int, metavar='N',
                        help='num decoder attention heads')
    parser.add_argument('--decoder-learned-pos', action='store_true',
                        help='use learned positional embeddings in the decoder')
    parser.add_argument('--decoder-normalize-before', action='store_true',
                        help='apply layernorm before each decoder block')
    parser.add_argument('--share-decoder-input-output-embed', action='store_true',
                        help='share decoder input and output embeddings')
    parser.add_argument('--share-all-embeddings', action='store_true',
                        help='share encoder, decoder and output embeddings'
                             ' (requires shared dictionary and embed dim)')
    parser.add_argument('--no-token-positional-embeddings', default=False, action='store_true',
                        help='if set, disables positional embeddings (outside self attention)')
    # NOTE: the original source had a stray trailing comma after this call,
    # turning the statement into a discarded one-element tuple; removed.
    parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR',
                        help='comma separated list of adaptive softmax cutoff points. '
                             'Must be used with adaptive_loss criterion')
    parser.add_argument('--adaptive-softmax-dropout', type=float, metavar='D',
                        help='sets adaptive softmax dropout for the tail projections')
    parser.add_argument('--num-embedding-chunks', type=int, metavar='N', default=1,
                        help='Number of embedding layer chunks (enables more even distribution'
                             'of optimizer states across data parallel nodes'
                             'when using optimizer state sharding and'
                             'a big embedding vocabulary)')
    # fmt: on
|
| 178 |
+
|
| 179 |
+
@classmethod
def build_model_base(cls, args, task):
    """Build the encoder and decoder for a new model instance.

    Returns:
        tuple: ``(encoder, decoder)`` built from *args* and the task's
        source/target dictionaries.
    """

    # make sure all arguments are present in older models
    base_architecture(args)

    if not hasattr(args, "max_source_positions"):
        args.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS
    if not hasattr(args, "max_target_positions"):
        args.max_target_positions = DEFAULT_MAX_TARGET_POSITIONS

    src_dict, tgt_dict = task.source_dictionary, task.target_dictionary

    def build_embedding(dictionary, embed_dim, path=None, num_embed_chunks=1):
        # Either a single Embedding (optionally preloaded from *path*) or,
        # when chunked, a ModuleList of equal-width embedding slices.
        # BUG FIX: the original assertion message stated the relation
        # backwards ("chunks divisible by dim"); the check itself requires
        # embed_dim to be divisible by num_embed_chunks.
        assert embed_dim % num_embed_chunks == 0, (
            f"Embedding dimension = {embed_dim} should be "
            + f"divisible by the number of embedding chunks = {num_embed_chunks}"
        )
        assert path is None or num_embed_chunks == 1, (
            "Loading embedding from a path with number of embedding chunks > 1"
            + " is not yet supported"
        )
        num_embeddings = len(dictionary)
        padding_idx = dictionary.pad()
        # if provided, load from preloaded dictionaries
        if path:
            emb = Embedding(num_embeddings, embed_dim, padding_idx)
            embed_dict = utils.parse_embedding(path)
            utils.load_embedding(embed_dict, dictionary, emb)
        else:
            embed_chunk_dim = embed_dim // num_embed_chunks
            emb = nn.ModuleList()
            for i in range(num_embed_chunks):
                emb.append(Embedding(num_embeddings, embed_chunk_dim, padding_idx))
        return emb

    num_embed_chunks = args.num_embedding_chunks
    if args.share_all_embeddings:
        # Shared embeddings require identical dictionaries, matching
        # dimensions, and no decoder-specific pretrained embedding path.
        if src_dict != tgt_dict:
            raise ValueError("--share-all-embeddings requires a joined dictionary")
        if args.encoder_embed_dim != args.decoder_embed_dim:
            raise ValueError(
                "--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim"
            )
        if args.decoder_embed_path and (
            args.decoder_embed_path != args.encoder_embed_path
        ):
            raise ValueError(
                "--share-all-embeddings not compatible with --decoder-embed-path"
            )
        encoder_embed_tokens = build_embedding(
            src_dict,
            args.encoder_embed_dim,
            args.encoder_embed_path,
            num_embed_chunks,
        )
        decoder_embed_tokens = encoder_embed_tokens
        args.share_decoder_input_output_embed = True
    else:
        assert args.share_decoder_input_output_embed or num_embed_chunks == 1, (
            "Not sharing decoder I/O embeddings is not yet supported with number of "
            + "embedding chunks > 1"
        )
        encoder_embed_tokens = build_embedding(
            src_dict,
            args.encoder_embed_dim,
            args.encoder_embed_path,
            num_embed_chunks,
        )
        decoder_embed_tokens = build_embedding(
            tgt_dict,
            args.decoder_embed_dim,
            args.decoder_embed_path,
            num_embed_chunks,
        )

    encoder = cls.build_encoder(args, src_dict, encoder_embed_tokens)
    decoder = cls.build_decoder(args, tgt_dict, decoder_embed_tokens)
    return (encoder, decoder)
|
| 259 |
+
|
| 260 |
+
@classmethod
def build_encoder(cls, args, src_dict, embed_tokens):
    """Construct the transformer encoder for this model."""
    encoder = TransformerEncoder(args, src_dict, embed_tokens)
    return encoder
|
| 263 |
+
|
| 264 |
+
@classmethod
def build_decoder(cls, args, tgt_dict, embed_tokens):
    """Construct the transformer decoder for this model."""
    decoder = TransformerDecoder(args, tgt_dict, embed_tokens)
    return decoder
|
| 267 |
+
|
| 268 |
+
@classmethod
def build_model(cls, args, task):
    """Build a pipeline-parallel model from *args* and *task*.

    The pipeline balance/device assignments are parsed from their
    string-list command-line representations.
    """
    encoder, decoder = cls.build_model_base(args, task)
    balance = utils.eval_str_list(args.pipeline_balance, type=int)
    devices = utils.eval_str_list(args.pipeline_devices, type=int)
    return PipelineParallelTransformerModel(
        encoder=encoder,
        decoder=decoder,
        balance=balance,
        devices=devices,
        chunks=args.pipeline_chunks,
        checkpoint=args.pipeline_checkpoint,
    )
|
| 279 |
+
|
| 280 |
+
def output_layer(self, features, **kwargs):
    """Project features to the default output size (typically vocabulary
    size) by delegating to the decoder's output layer."""
    project = self.decoder.output_layer
    return project(features, **kwargs)
|
| 283 |
+
|
| 284 |
+
def max_positions(self):
    """Maximum (source, target) lengths supported by the model."""
    source_limit = self.encoder_max_positions
    target_limit = self.decoder_max_positions
    return (source_limit, target_limit)
|
| 287 |
+
|
| 288 |
+
def max_positions_helper(
    self, embedding_layer, max_positions_field="max_source_positions"
):
    """Return the maximum input length supported by *embedding_layer*.

    Reads the configured limit from ``max_positions_field`` and, when the
    layer carries positional embeddings, caps it by their own maximum.
    """
    configured_limit = getattr(embedding_layer, max_positions_field)
    positions = embedding_layer.embed_positions
    if positions is None:
        return configured_limit
    return min(configured_limit, positions.max_positions)
|
| 298 |
+
|
| 299 |
+
def get_normalized_probs(self, net_output, log_probs, sample=None):
    """Get normalized probabilities (or log probs) from a net's output.

    When an adaptive softmax is attached, its log-prob path is used
    (optionally restricted to the sample's target); otherwise a plain
    (log-)softmax over the logits is returned.
    """
    adaptive_softmax = getattr(self, "adaptive_softmax", None)
    if adaptive_softmax is not None:
        target = None
        if sample is not None:
            assert "target" in sample
            target = sample["target"]
        log_prob = adaptive_softmax.get_log_prob(net_output, target=target)
        if log_probs:
            return log_prob
        return log_prob.exp_()

    # A Pipe() module returns a tuple of tensors as the output; in that
    # case the first element is the tensor of logits.
    if isinstance(net_output, torch.Tensor):
        logits = net_output
    else:
        logits = net_output[0]
    if log_probs:
        return utils.log_softmax(logits, dim=-1, onnx_trace=False)
    return utils.softmax(logits, dim=-1, onnx_trace=False)
|
| 318 |
+
|
| 319 |
+
def max_decoder_positions(self):
    """Maximum output length supported by the decoder."""
    limit = self.decoder_max_positions
    return limit
|
| 322 |
+
|
| 323 |
+
def load_state_dict(self, state_dict, strict=True, args=None):
    """Copies parameters and buffers from *state_dict* into this module and
    its descendants.

    Overrides the method in :class:`nn.Module`. Compared with that method
    this additionally "upgrades" *state_dicts* from old checkpoints.

    Checkpoints saved by a regular (non-pipelined) transformer are detected
    by the absence of any ``model.partitions`` keys and converted into the
    pipeline-parallel key layout before loading.
    """
    self.upgrade_state_dict(state_dict)
    is_regular_transformer = not any("model.partitions" in k for k in state_dict)
    if is_regular_transformer:
        state_dict = self.convert_to_pipeline_parallel_state_dict(state_dict)
    return super().load_state_dict(state_dict, strict)
|
| 335 |
+
|
| 336 |
+
def convert_to_pipeline_parallel_state_dict(self, state_dict):
    """Remap a regular transformer checkpoint to the pipeline key layout.

    Walks ``self.model.partitions`` in order and, for each module, copies
    the matching ``encoder.*`` / ``decoder.*`` tensors from *state_dict*
    into ``model.partitions.{pid}.{mid}.*`` keys of a fresh state dict.
    Relies on the partitions preserving the original module order, so the
    running ``encoder_layer_idx`` / ``decoder_layer_idx`` counters line up
    with the flat ``encoder.layers.N`` / ``decoder.layers.N`` keys.
    """
    new_state_dict = self.state_dict()
    encoder_layer_idx = 0
    decoder_layer_idx = 0
    # Per-layer parameter names for a TransformerEncoderLayer.
    encoder_key_suffixes = [
        "self_attn.k_proj.weight",
        "self_attn.k_proj.bias",
        "self_attn.v_proj.weight",
        "self_attn.v_proj.bias",
        "self_attn.q_proj.weight",
        "self_attn.q_proj.bias",
        "self_attn.out_proj.weight",
        "self_attn.out_proj.bias",
        "self_attn_layer_norm.weight",
        "self_attn_layer_norm.bias",
        "fc1.weight",
        "fc1.bias",
        "fc2.weight",
        "fc2.bias",
        "final_layer_norm.weight",
        "final_layer_norm.bias",
    ]
    # Per-layer parameter names for a TransformerDecoderLayer (adds the
    # encoder-attention block).
    decoder_key_suffixes = [
        "self_attn.k_proj.weight",
        "self_attn.k_proj.bias",
        "self_attn.v_proj.weight",
        "self_attn.v_proj.bias",
        "self_attn.q_proj.weight",
        "self_attn.q_proj.bias",
        "self_attn.out_proj.weight",
        "self_attn.out_proj.bias",
        "self_attn_layer_norm.weight",
        "self_attn_layer_norm.bias",
        "encoder_attn.k_proj.weight",
        "encoder_attn.k_proj.bias",
        "encoder_attn.v_proj.weight",
        "encoder_attn.v_proj.bias",
        "encoder_attn.q_proj.weight",
        "encoder_attn.q_proj.bias",
        "encoder_attn.out_proj.weight",
        "encoder_attn.out_proj.bias",
        "encoder_attn_layer_norm.weight",
        "encoder_attn_layer_norm.bias",
        "fc1.weight",
        "fc1.bias",
        "fc2.weight",
        "fc2.bias",
        "final_layer_norm.weight",
        "final_layer_norm.bias",
    ]
    for pid, partition in enumerate(self.model.partitions):
        logger.info(f"Begin Partition {pid}")
        for mid, module in enumerate(partition):
            # fmt: off
            # Each isinstance branch handles one pipeline module type;
            # embedding/output modules map single keys, layer modules map
            # a whole suffix list and advance the running layer index.
            if isinstance(module, TransformerEncoderEmbedding):
                new_state_dict[f'model.partitions.{pid}.{mid}.embed_tokens.weight'] = state_dict['encoder.embed_tokens.weight']
                new_state_dict[f'model.partitions.{pid}.{mid}.embed_positions._float_tensor'] = state_dict['encoder.embed_positions._float_tensor']
            if isinstance(module, TransformerEncoderLayer):
                for suffix in encoder_key_suffixes:
                    new_state_dict[f'model.partitions.{pid}.{mid}.{suffix}'] = state_dict[f'encoder.layers.{encoder_layer_idx}.{suffix}']
                encoder_layer_idx += 1
            if isinstance(module, TransformerDecoderLayer):
                for suffix in decoder_key_suffixes:
                    new_state_dict[f'model.partitions.{pid}.{mid}.{suffix}'] = state_dict[f'decoder.layers.{decoder_layer_idx}.{suffix}']
                decoder_layer_idx += 1
            if isinstance(module, TransformerEncoderLayerNorm):
                # Final encoder layer norm only exists in some checkpoints.
                if 'encoder.layer_norm.weight' in state_dict:
                    new_state_dict[f'model.partitions.{pid}.{mid}.layer_norm.weight'] = state_dict['encoder.layer_norm.weight']
                    new_state_dict[f'model.partitions.{pid}.{mid}.layer_norm.bias'] = state_dict['encoder.layer_norm.bias']
            if isinstance(module, TransformerDecoderEmbedding):
                new_state_dict[f'model.partitions.{pid}.{mid}.embed_tokens.weight'] = state_dict['decoder.embed_tokens.weight']
                new_state_dict[f'model.partitions.{pid}.{mid}.embed_positions._float_tensor'] = state_dict['decoder.embed_positions._float_tensor']
            if isinstance(module, TransformerDecoderOutputLayer):
                new_state_dict[f'model.partitions.{pid}.{mid}.output_projection.weight'] = state_dict['decoder.output_projection.weight']
            # fmt: on
    return new_state_dict
|
| 412 |
+
|
| 413 |
+
|
| 414 |
+
class TransformerEncoder(FairseqEncoder):
    """
    Transformer encoder consisting of *args.encoder_layers* layers. Each layer
    is a :class:`TransformerEncoderLayer`.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        dictionary (~fairseq.data.Dictionary): encoding dictionary
        embed_tokens (torch.nn.Embedding): input embedding
        encoder_module_list (list, optional): pre-built pipeline modules
            (embedding, layers, final layer norm); built from *args* when
            omitted
    """

    def __init__(self, args, dictionary, embed_tokens, encoder_module_list=None):
        super().__init__(dictionary)
        self.register_buffer("version", torch.Tensor([3]))
        # fairscale is an optional dependency needed only for pipelining.
        try:
            from fairscale.nn import Pipe
        except ImportError:
            raise ImportError("Please install fairscale with: pip install fairscale")
        if encoder_module_list is None:
            embedding_layer = TransformerEncoderEmbedding(args, embed_tokens)
            layers = [TransformerEncoderLayer(args) for i in range(args.encoder_layers)]
            # Chunked embeddings are a ModuleList; their total width is the
            # sum of the chunk widths.
            if isinstance(embed_tokens, nn.ModuleList):
                emb_dim = sum(e.embedding_dim for e in embed_tokens)
            else:
                emb_dim = embed_tokens.embedding_dim
            final_layer_norm = TransformerEncoderLayerNorm(args, emb_dim)
            encoder_module_list = [embedding_layer] + layers + [final_layer_norm]
        # Pipelined execution is enabled only when an explicit encoder
        # balance is configured.
        self.use_pipeline = getattr(args, "pipeline_encoder_balance", None) is not None
        if self.use_pipeline:
            encoder_balance = utils.eval_str_list(
                args.pipeline_encoder_balance, type=int
            )
            encoder_devices = utils.eval_str_list(
                args.pipeline_encoder_devices, type=int
            )
            assert sum(encoder_balance) == len(encoder_module_list), (
                f"Sum of encoder_balance={encoder_balance} is not equal "
                + f"to num_encoder_modules={len(encoder_module_list)}"
            )
            self.model = Pipe(
                module=nn.Sequential(*encoder_module_list),
                balance=encoder_balance,
                devices=encoder_devices,
                chunks=args.pipeline_chunks,
                checkpoint=args.pipeline_checkpoint,
            )
        else:
            # Non-pipelined path keeps the modules directly for a plain
            # sequential forward.
            self.embedding_layer = encoder_module_list[0]
            self.encoder_layers = nn.Sequential(*encoder_module_list[1:-1])
            self.final_layer_norm = encoder_module_list[-1]

    def forward(self, src_tokens, src_lengths):
        """
        Args:
            input_tuple(
                src_tokens (LongTensor): tokens in the source language of shape
                    `(batch, src_len)`
                src_lengths (torch.LongTensor): lengths of each source sentence of
                    shape `(batch)`
            )

        Returns:
            output_tuple(
                - **encoder_out** (Tensor): the last encoder layer's output of
                    shape `(src_len, batch, embed_dim)`
                - **encoder_padding_mask** (ByteTensor): the positions of
                    padding elements of shape `(batch, src_len)`
                - prev_output_tokens
                - **encoder_states** (List[Tensor]): all intermediate
                    hidden states of shape `(src_len, batch, embed_dim)`.
                    Only populated if *return_all_hiddens* is True.
            )
        """
        # The pipeline modules expect a 3-tuple; pad with a dummy
        # prev_output_tokens tensor since the encoder does not use it.
        dummy_prev_output_tokens = torch.zeros(
            1, dtype=src_tokens.dtype, device=src_tokens.device
        )
        input_tuple = (src_tokens, src_lengths, dummy_prev_output_tokens)
        if self.use_pipeline:
            input_tuple = tuple(i.to(self.model.devices[0]) for i in input_tuple)
            encoder_out = self.model(input_tuple)
        else:
            encoder_embed_output_tuple = self.embedding_layer(input_tuple)
            encoder_layers_output = self.encoder_layers(encoder_embed_output_tuple)
            encoder_out = self.final_layer_norm(encoder_layers_output)
        # first element is the encoder output
        # second element is the encoder padding mask
        # the remaining elements of EncoderOut are not computed by
        # the PipelineParallelTransformer
        return EncoderOut(encoder_out[0], encoder_out[1], None, None, None, None)

    def reorder_encoder_out(self, encoder_out, new_order):
        """
        Reorder encoder output according to *new_order*.

        Args:
            encoder_out: output from the ``forward()`` method
            new_order (LongTensor): desired order

        Returns:
            *encoder_out* rearranged according to *new_order*
        """
        # encoder_out is time-major (dim 1 is batch); the padding mask and
        # embedding are batch-major (dim 0 is batch).
        if encoder_out.encoder_out is not None:
            encoder_out = encoder_out._replace(
                encoder_out=encoder_out.encoder_out.index_select(1, new_order)
            )
        if encoder_out.encoder_padding_mask is not None:
            encoder_out = encoder_out._replace(
                encoder_padding_mask=encoder_out.encoder_padding_mask.index_select(
                    0, new_order
                )
            )
        if encoder_out.encoder_embedding is not None:
            encoder_out = encoder_out._replace(
                encoder_embedding=encoder_out.encoder_embedding.index_select(
                    0, new_order
                )
            )
        if encoder_out.encoder_states is not None:
            for idx, state in enumerate(encoder_out.encoder_states):
                encoder_out.encoder_states[idx] = state.index_select(1, new_order)
        return encoder_out

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        # NOTE(review): only valid in the non-pipelined configuration —
        # self.embedding_layer is not set when use_pipeline is True.
        if self.embedding_layer.embed_positions is None:
            return self.embedding_layer.max_source_positions
        return min(
            self.embedding_layer.max_source_positions,
            self.embedding_layer.embed_positions.max_positions,
        )
|
| 544 |
+
|
| 545 |
+
|
| 546 |
+
class TransformerDecoder(FairseqDecoder):
    """
    Transformer decoder consisting of *args.decoder_layers* layers. Each layer
    is a :class:`TransformerDecoderLayer`.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        dictionary (~fairseq.data.Dictionary): decoding dictionary
        embed_tokens (torch.nn.Embedding): output embedding
        no_encoder_attn (bool, optional): whether to attend to encoder outputs
            (default: False).
        decoder_module_list (list, optional): pre-built pipeline modules
            (embedding, layers, output layer); built from *args* when omitted
    """

    def __init__(
        self,
        args,
        dictionary,
        embed_tokens,
        no_encoder_attn=False,
        decoder_module_list=None,
    ):
        super().__init__(dictionary)
        self.register_buffer("version", torch.Tensor([3]))
        # fairscale is an optional dependency needed only for pipelining.
        try:
            from fairscale.nn import Pipe
        except ImportError:
            raise ImportError("Please install fairscale with: pip install fairscale")
        if decoder_module_list is None:
            embedding_layer = TransformerDecoderEmbedding(args, embed_tokens)
            layers = [
                TransformerDecoderLayer(args, no_encoder_attn)
                for _ in range(args.decoder_layers)
            ]
            decoder_output_layer = TransformerDecoderOutputLayer(
                args, embed_tokens, dictionary
            )
            decoder_module_list = [embedding_layer] + layers + [decoder_output_layer]
        # Pipelined execution is enabled only when an explicit decoder
        # balance is configured.
        self.use_pipeline = getattr(args, "pipeline_decoder_balance", None) is not None
        if self.use_pipeline:
            decoder_balance = utils.eval_str_list(
                args.pipeline_decoder_balance, type=int
            )
            decoder_devices = utils.eval_str_list(
                args.pipeline_decoder_devices, type=int
            )
            assert sum(decoder_balance) == len(decoder_module_list), (
                f"Sum of decoder_balance={decoder_balance} is not equal "
                + f"to num_decoder_modules={len(decoder_module_list)}"
            )
            self.model = Pipe(
                module=nn.Sequential(*decoder_module_list),
                balance=decoder_balance,
                devices=decoder_devices,
                chunks=args.pipeline_chunks,
                checkpoint=args.pipeline_checkpoint,
            )
        else:
            # Non-pipelined path keeps the modules directly for a plain
            # sequential forward.
            self.embedding_layer = decoder_module_list[0]
            self.decoder_layers = nn.Sequential(*decoder_module_list[1:-1])
            self.decoder_output_layer = decoder_module_list[-1]

    def forward(
        self,
        prev_output_tokens,
        encoder_out=None,
    ):
        """
        Args:
            prev_output_tokens (LongTensor): previous decoder outputs of shape
                `(batch, tgt_len)`, for teacher forcing
            encoder_out (optional): output from the encoder, used for
                encoder-side attention
            incremental_state (dict): dictionary used for storing state during
                :ref:`Incremental decoding`
            features_only (bool, optional): only return features without
                applying output layer (default: False).

        Returns:
            tuple:
                - the decoder's output of shape `(batch, tgt_len, vocab)`
                - a dictionary with any model-specific outputs
        """
        # NOTE(review): encoder_out=None would fail on the attribute
        # accesses below — callers must always pass the encoder output.
        input_tuple = (
            encoder_out.encoder_out,
            encoder_out.encoder_padding_mask,
            prev_output_tokens,
        )
        if self.use_pipeline:
            input_tuple = tuple(i.to(self.model.devices[0]) for i in input_tuple)
            return (self.model(input_tuple),)
        else:
            embed_layer_output = self.embedding_layer(input_tuple)
            state = self.decoder_layers(embed_layer_output)
            return (self.decoder_output_layer(state),)

    def output_layer(self, features, **kwargs):
        """Project features to the vocabulary size."""
        # NOTE(review): self.adaptive_softmax, self.share_input_output_embed,
        # self.embed_tokens and self.embed_out are never assigned in this
        # class's __init__ — confirm whether this method is actually
        # reachable, or whether those attributes are set externally.
        if self.adaptive_softmax is None:
            # project back to size of vocabulary
            if self.share_input_output_embed:
                return F.linear(features, self.embed_tokens.weight)
            else:
                return F.linear(features, self.embed_out)
        else:
            return features

    def max_positions(self):
        """Maximum output length supported by the decoder."""
        # NOTE(review): only valid in the non-pipelined configuration —
        # self.embedding_layer is not set when use_pipeline is True.
        if self.embedding_layer.embed_positions is None:
            return self.embedding_layer.max_target_positions
        return min(
            self.embedding_layer.max_target_positions,
            self.embedding_layer.embed_positions.max_positions,
        )

    def buffered_future_mask(self, tensor):
        # Lazily (re)build the upper-triangular -inf mask used for causal
        # self-attention; cached and reused while it is large enough and on
        # the right device.
        dim = tensor.size(0)
        if (
            not hasattr(self, "_future_mask")
            or self._future_mask is None
            or self._future_mask.device != tensor.device
            or self._future_mask.size(0) < dim
        ):
            self._future_mask = torch.triu(
                utils.fill_with_neg_inf(tensor.new(dim, dim)), 1
            )
        return self._future_mask[:dim, :dim]

    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade a (possibly old) state dict for new versions of fairseq."""
        # NOTE(review): this method references self.embed_positions and
        # self.layers, which this pipeline decoder does not define (it uses
        # embedding_layer / decoder_layers) — it appears copied from the
        # regular TransformerDecoder; confirm it is never invoked here.
        if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
            weights_key = "{}.embed_positions.weights".format(name)
            if weights_key in state_dict:
                del state_dict[weights_key]
            state_dict[
                "{}.embed_positions._float_tensor".format(name)
            ] = torch.FloatTensor(1)

        for i in range(len(self.layers)):
            # update layer norms
            layer_norm_map = {
                "0": "self_attn_layer_norm",
                "1": "encoder_attn_layer_norm",
                "2": "final_layer_norm",
            }
            for old, new in layer_norm_map.items():
                for m in ("weight", "bias"):
                    k = "{}.layers.{}.layer_norms.{}.{}".format(name, i, old, m)
                    if k in state_dict:
                        state_dict[
                            "{}.layers.{}.{}.{}".format(name, i, new, m)
                        ] = state_dict[k]
                        del state_dict[k]

        version_key = "{}.version".format(name)
        if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) <= 2:
            # earlier checkpoints did not normalize after the stack of layers
            self.layer_norm = None
            self.normalize = False
            state_dict[version_key] = torch.Tensor([1])

        return state_dict
|
| 708 |
+
|
| 709 |
+
|
| 710 |
+
@register_model_architecture(
    "pipeline_parallel_transformer", "transformer_iwslt_de_en_pipeline_parallel"
)
def transformer_iwslt_de_en_dist(args):
    """Pipeline-parallel registration of the ``transformer_iwslt_de_en``
    architecture; delegates all hyperparameter defaults to it."""
    transformer_iwslt_de_en(args)
|
| 715 |
+
|
| 716 |
+
|
| 717 |
+
@register_model_architecture(
    "pipeline_parallel_transformer", "transformer_wmt_en_de_big_pipeline_parallel"
)
def transformer_wmt_en_de_big_dist(args):
    """Pipeline-parallel registration of the ``transformer_wmt_en_de_big``
    architecture; delegates all hyperparameter defaults to it."""
    transformer_wmt_en_de_big(args)
|
fairseq-0.10.2/fairseq/model_parallel/models/roberta/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
from .model import * # noqa
|
fairseq-0.10.2/fairseq/model_parallel/models/roberta/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (208 Bytes). View file
|
|
|
fairseq-0.10.2/fairseq/models/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (5.58 kB). View file
|
|
|
fairseq-0.10.2/fairseq/models/__pycache__/distributed_fairseq_model.cpython-310.pyc
ADDED
|
Binary file (3.06 kB). View file
|
|
|
fairseq-0.10.2/fairseq/models/__pycache__/fairseq_decoder.cpython-310.pyc
ADDED
|
Binary file (3.41 kB). View file
|
|
|
fairseq-0.10.2/fairseq/models/__pycache__/fairseq_encoder.cpython-310.pyc
ADDED
|
Binary file (3.59 kB). View file
|
|
|
fairseq-0.10.2/fairseq/models/__pycache__/fairseq_incremental_decoder.cpython-310.pyc
ADDED
|
Binary file (4.8 kB). View file
|
|
|
fairseq-0.10.2/fairseq/models/__pycache__/fairseq_model.cpython-310.pyc
ADDED
|
Binary file (20.6 kB). View file
|
|
|
fairseq-0.10.2/fairseq/models/__pycache__/fconv.cpython-310.pyc
ADDED
|
Binary file (19.1 kB). View file
|
|
|
fairseq-0.10.2/fairseq/models/__pycache__/fconv_lm.cpython-310.pyc
ADDED
|
Binary file (3.77 kB). View file
|
|
|
fairseq-0.10.2/fairseq/models/__pycache__/fconv_self_att.cpython-310.pyc
ADDED
|
Binary file (16.3 kB). View file
|
|
|
fairseq-0.10.2/fairseq/models/__pycache__/lightconv.cpython-310.pyc
ADDED
|
Binary file (25 kB). View file
|
|
|
fairseq-0.10.2/fairseq/models/__pycache__/lightconv_lm.cpython-310.pyc
ADDED
|
Binary file (6.99 kB). View file
|
|
|
fairseq-0.10.2/fairseq/models/__pycache__/lstm.cpython-310.pyc
ADDED
|
Binary file (18.6 kB). View file
|
|
|
fairseq-0.10.2/fairseq/models/__pycache__/lstm_lm.cpython-310.pyc
ADDED
|
Binary file (4.34 kB). View file
|
|
|
fairseq-0.10.2/fairseq/models/__pycache__/masked_lm.cpython-310.pyc
ADDED
|
Binary file (10.1 kB). View file
|
|
|
fairseq-0.10.2/fairseq/models/__pycache__/model_utils.cpython-310.pyc
ADDED
|
Binary file (2.35 kB). View file
|
|
|
fairseq-0.10.2/fairseq/models/__pycache__/multilingual_transformer.cpython-310.pyc
ADDED
|
Binary file (6.69 kB). View file
|
|
|
fairseq-0.10.2/fairseq/models/__pycache__/transformer.cpython-310.pyc
ADDED
|
Binary file (27.7 kB). View file
|
|
|
fairseq-0.10.2/fairseq/models/__pycache__/transformer_lm.cpython-310.pyc
ADDED
|
Binary file (10.6 kB). View file
|
|
|
fairseq-0.10.2/fairseq/models/bart/__init__.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
from .hub_interface import * # noqa
|
| 7 |
+
from .model import * # noqa
|
fairseq-0.10.2/fairseq/models/bart/hub_interface.py
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import copy
|
| 7 |
+
import logging
|
| 8 |
+
from typing import List
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
import torch.nn.functional as F
|
| 14 |
+
from fairseq import utils
|
| 15 |
+
from fairseq.data import encoders
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
logger = logging.getLogger(__name__)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class BARTHubInterface(nn.Module):
|
| 22 |
+
"""A simple PyTorch Hub interface to BART.
|
| 23 |
+
|
| 24 |
+
Usage: https://github.com/pytorch/fairseq/tree/master/examples/bart
|
| 25 |
+
"""
|
| 26 |
+
|
| 27 |
+
def __init__(self, args, task, model):
|
| 28 |
+
super().__init__()
|
| 29 |
+
self.args = args
|
| 30 |
+
self.task = task
|
| 31 |
+
self.model = model
|
| 32 |
+
|
| 33 |
+
self.bpe = encoders.build_bpe(args)
|
| 34 |
+
|
| 35 |
+
self.max_positions = min(
|
| 36 |
+
utils.resolve_max_positions(
|
| 37 |
+
self.task.max_positions(),
|
| 38 |
+
self.model.max_positions(),
|
| 39 |
+
)
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
# this is useful for determining the device
|
| 43 |
+
self.register_buffer("_float_tensor", torch.tensor([0], dtype=torch.float))
|
| 44 |
+
|
| 45 |
+
@property
|
| 46 |
+
def device(self):
|
| 47 |
+
return self._float_tensor.device
|
| 48 |
+
|
| 49 |
+
def encode(
|
| 50 |
+
self, sentence: str, *addl_sentences, no_separator=True
|
| 51 |
+
) -> torch.LongTensor:
|
| 52 |
+
"""
|
| 53 |
+
BPE-encode a sentence (or multiple sentences).
|
| 54 |
+
|
| 55 |
+
Every sequence begins with a beginning-of-sentence (`<s>`) symbol.
|
| 56 |
+
Every sentence ends with an end-of-sentence (`</s>`).
|
| 57 |
+
|
| 58 |
+
Example (single sentence): `<s> a b c </s>`
|
| 59 |
+
Example (sentence pair): `<s> d e f </s> 1 2 3 </s>`
|
| 60 |
+
|
| 61 |
+
The BPE encoding follows GPT-2. One subtle detail is that the GPT-2 BPE
|
| 62 |
+
requires leading spaces. For example::
|
| 63 |
+
|
| 64 |
+
>>> bart.encode('Hello world').tolist()
|
| 65 |
+
[0, 31414, 232, 2]
|
| 66 |
+
>>> bart.encode(' world').tolist()
|
| 67 |
+
[0, 232, 2]
|
| 68 |
+
>>> bart.encode('world').tolist()
|
| 69 |
+
[0, 8331, 2]
|
| 70 |
+
"""
|
| 71 |
+
tokens = self.bpe.encode(sentence)
|
| 72 |
+
if len(tokens.split(" ")) > self.max_positions - 2:
|
| 73 |
+
tokens = " ".join(tokens.split(" ")[: self.max_positions - 2])
|
| 74 |
+
bpe_sentence = "<s> " + tokens + " </s>"
|
| 75 |
+
for s in addl_sentences:
|
| 76 |
+
bpe_sentence += " </s>" if not no_separator else ""
|
| 77 |
+
bpe_sentence += " " + self.bpe.encode(s) + " </s>"
|
| 78 |
+
tokens = self.task.source_dictionary.encode_line(bpe_sentence, append_eos=False)
|
| 79 |
+
return tokens.long()
|
| 80 |
+
|
| 81 |
+
def decode(self, tokens: torch.LongTensor):
|
| 82 |
+
assert tokens.dim() == 1
|
| 83 |
+
tokens = tokens.cpu().numpy()
|
| 84 |
+
if tokens[0] == self.task.source_dictionary.bos():
|
| 85 |
+
tokens = tokens[1:] # remove <s>
|
| 86 |
+
eos_mask = tokens == self.task.source_dictionary.eos()
|
| 87 |
+
doc_mask = eos_mask[1:] & eos_mask[:-1]
|
| 88 |
+
sentences = np.split(tokens, doc_mask.nonzero()[0] + 1)
|
| 89 |
+
sentences = [
|
| 90 |
+
self.bpe.decode(self.task.source_dictionary.string(s)) for s in sentences
|
| 91 |
+
]
|
| 92 |
+
if len(sentences) == 1:
|
| 93 |
+
return sentences[0]
|
| 94 |
+
return sentences
|
| 95 |
+
|
| 96 |
+
def _build_sample(self, src_tokens: List[torch.LongTensor]):
|
| 97 |
+
# assert torch.is_tensor(src_tokens)
|
| 98 |
+
dataset = self.task.build_dataset_for_inference(
|
| 99 |
+
src_tokens,
|
| 100 |
+
[x.numel() for x in src_tokens],
|
| 101 |
+
)
|
| 102 |
+
sample = dataset.collater(dataset)
|
| 103 |
+
sample = utils.apply_to_sample(lambda tensor: tensor.to(self.device), sample)
|
| 104 |
+
return sample
|
| 105 |
+
|
| 106 |
+
def sample(
|
| 107 |
+
self, sentences: List[str], beam: int = 1, verbose: bool = False, **kwargs
|
| 108 |
+
) -> str:
|
| 109 |
+
input = [self.encode(sentence) for sentence in sentences]
|
| 110 |
+
hypos = self.generate(input, beam, verbose, **kwargs)
|
| 111 |
+
return [self.decode(x["tokens"]) for x in hypos]
|
| 112 |
+
|
| 113 |
+
def generate(
|
| 114 |
+
self,
|
| 115 |
+
tokens: List[torch.LongTensor],
|
| 116 |
+
beam: int = 5,
|
| 117 |
+
verbose: bool = False,
|
| 118 |
+
**kwargs
|
| 119 |
+
) -> torch.LongTensor:
|
| 120 |
+
sample = self._build_sample(tokens)
|
| 121 |
+
|
| 122 |
+
# build generator using current args as well as any kwargs
|
| 123 |
+
gen_args = copy.copy(self.args)
|
| 124 |
+
gen_args.beam = beam
|
| 125 |
+
for k, v in kwargs.items():
|
| 126 |
+
setattr(gen_args, k, v)
|
| 127 |
+
generator = self.task.build_generator([self.model], gen_args)
|
| 128 |
+
translations = self.task.inference_step(
|
| 129 |
+
generator,
|
| 130 |
+
[self.model],
|
| 131 |
+
sample,
|
| 132 |
+
prefix_tokens=sample["net_input"]["src_tokens"]
|
| 133 |
+
.new_zeros((len(tokens), 1))
|
| 134 |
+
.fill_(self.task.source_dictionary.bos()),
|
| 135 |
+
)
|
| 136 |
+
|
| 137 |
+
if verbose:
|
| 138 |
+
src_str_with_unk = self.string(tokens)
|
| 139 |
+
logger.info("S\t{}".format(src_str_with_unk))
|
| 140 |
+
|
| 141 |
+
def getarg(name, default):
|
| 142 |
+
return getattr(gen_args, name, getattr(self.args, name, default))
|
| 143 |
+
|
| 144 |
+
# Process top predictions
|
| 145 |
+
hypos = [x[0] for x in translations]
|
| 146 |
+
hypos = [v for _, v in sorted(zip(sample["id"].tolist(), hypos))]
|
| 147 |
+
return hypos
|
| 148 |
+
|
| 149 |
+
def extract_features(
|
| 150 |
+
self, tokens: torch.LongTensor, return_all_hiddens: bool = False
|
| 151 |
+
) -> torch.Tensor:
|
| 152 |
+
if tokens.dim() == 1:
|
| 153 |
+
tokens = tokens.unsqueeze(0)
|
| 154 |
+
if tokens.size(-1) > min(self.model.max_positions()):
|
| 155 |
+
raise ValueError(
|
| 156 |
+
"tokens exceeds maximum length: {} > {}".format(
|
| 157 |
+
tokens.size(-1), self.model.max_positions()
|
| 158 |
+
)
|
| 159 |
+
)
|
| 160 |
+
tokens.to(device=self.device),
|
| 161 |
+
prev_output_tokens = tokens.clone()
|
| 162 |
+
|
| 163 |
+
prev_output_tokens[:, 0] = tokens.gather(
|
| 164 |
+
1,
|
| 165 |
+
(tokens.ne(self.task.source_dictionary.pad()).sum(dim=1) - 1).unsqueeze(-1),
|
| 166 |
+
).squeeze()
|
| 167 |
+
|
| 168 |
+
prev_output_tokens[:, 1:] = tokens[:, :-1]
|
| 169 |
+
features, extra = self.model(
|
| 170 |
+
src_tokens=tokens,
|
| 171 |
+
src_lengths=None,
|
| 172 |
+
prev_output_tokens=prev_output_tokens,
|
| 173 |
+
features_only=True,
|
| 174 |
+
return_all_hiddens=return_all_hiddens,
|
| 175 |
+
)
|
| 176 |
+
if return_all_hiddens:
|
| 177 |
+
# convert from T x B x C -> B x T x C
|
| 178 |
+
inner_states = extra["inner_states"]
|
| 179 |
+
return [inner_state.transpose(0, 1) for inner_state in inner_states]
|
| 180 |
+
else:
|
| 181 |
+
return features # just the last layer's features
|
| 182 |
+
|
| 183 |
+
def register_classification_head(
|
| 184 |
+
self, name: str, num_classes: int = None, embedding_size: int = None, **kwargs
|
| 185 |
+
):
|
| 186 |
+
self.model.register_classification_head(
|
| 187 |
+
name, num_classes=num_classes, embedding_size=embedding_size, **kwargs
|
| 188 |
+
)
|
| 189 |
+
|
| 190 |
+
def predict(self, head: str, tokens: torch.LongTensor, return_logits: bool = False):
|
| 191 |
+
if tokens.dim() == 1:
|
| 192 |
+
tokens = tokens.unsqueeze(0)
|
| 193 |
+
features = self.extract_features(tokens.to(device=self.device))
|
| 194 |
+
sentence_representation = features[
|
| 195 |
+
tokens.eq(self.task.source_dictionary.eos()), :
|
| 196 |
+
].view(features.size(0), -1, features.size(-1))[:, -1, :]
|
| 197 |
+
|
| 198 |
+
logits = self.model.classification_heads[head](sentence_representation)
|
| 199 |
+
if return_logits:
|
| 200 |
+
return logits
|
| 201 |
+
return F.log_softmax(logits, dim=-1)
|
fairseq-0.10.2/fairseq/models/distributed_fairseq_model.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import inspect
|
| 7 |
+
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
from fairseq.legacy_distributed_data_parallel import LegacyDistributedDataParallel
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
_GOSSIP_DISABLED = False
|
| 13 |
+
try:
|
| 14 |
+
import gossip
|
| 15 |
+
except ImportError:
|
| 16 |
+
_GOSSIP_DISABLED = True
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def DistributedFairseqModel(args, model, process_group=None):
|
| 20 |
+
"""
|
| 21 |
+
Wrap a *model* to support distributed data parallel training.
|
| 22 |
+
|
| 23 |
+
This is similar to the built-in DistributedDataParallel, but allows
|
| 24 |
+
additional configuration of the DistributedDataParallel class to
|
| 25 |
+
use, and also provides easier access to the wrapped model by
|
| 26 |
+
forwarding requests for missing attributes to the wrapped model.
|
| 27 |
+
|
| 28 |
+
Args:
|
| 29 |
+
args (argparse.Namespace): fairseq args
|
| 30 |
+
model (BaseFairseqModel): model to wrap
|
| 31 |
+
"""
|
| 32 |
+
# determine which DDP class to extend
|
| 33 |
+
assert isinstance(model, nn.Module)
|
| 34 |
+
if args.distributed_wrapper == "DDP" and args.ddp_backend == "c10d":
|
| 35 |
+
ddp_class = nn.parallel.DistributedDataParallel
|
| 36 |
+
init_kwargs = dict(
|
| 37 |
+
module=model,
|
| 38 |
+
device_ids=[args.device_id],
|
| 39 |
+
output_device=args.device_id,
|
| 40 |
+
broadcast_buffers=args.broadcast_buffers,
|
| 41 |
+
bucket_cap_mb=args.bucket_cap_mb,
|
| 42 |
+
process_group=process_group,
|
| 43 |
+
)
|
| 44 |
+
# Maintain backward compatibility
|
| 45 |
+
if "check_reduction" in inspect.getargspec(ddp_class)[0]:
|
| 46 |
+
init_kwargs["check_reduction"] = True
|
| 47 |
+
if "find_unused_parameters" in inspect.getargspec(ddp_class)[0]:
|
| 48 |
+
init_kwargs["find_unused_parameters"] = args.find_unused_parameters
|
| 49 |
+
elif args.distributed_wrapper == "DDP" and args.ddp_backend == "no_c10d":
|
| 50 |
+
ddp_class = LegacyDistributedDataParallel
|
| 51 |
+
init_kwargs = dict(
|
| 52 |
+
module=model,
|
| 53 |
+
world_size=args.distributed_world_size,
|
| 54 |
+
buffer_size=2 ** 28,
|
| 55 |
+
process_group=process_group,
|
| 56 |
+
)
|
| 57 |
+
elif args.distributed_wrapper == "SlowMo":
|
| 58 |
+
if _GOSSIP_DISABLED:
|
| 59 |
+
raise ImportError(
|
| 60 |
+
"Cannot find gossip library. Please install from: "
|
| 61 |
+
"github.com/facebookresearch/stochastic_gradient_push"
|
| 62 |
+
)
|
| 63 |
+
ddp_class = gossip.GossipDataParallel
|
| 64 |
+
|
| 65 |
+
# The values of slowmo_momentum below were obtained by tuning on the
|
| 66 |
+
# En-De 16 dataset by training the transformer_wmt_en_de_large model
|
| 67 |
+
if args.slowmo_momentum is None:
|
| 68 |
+
if args.distributed_world_size <= 16:
|
| 69 |
+
args.slowmo_momentum = 0.0
|
| 70 |
+
elif args.distributed_world_size <= 32:
|
| 71 |
+
args.slowmo_momentum = 0.2
|
| 72 |
+
elif args.distributed_world_size <= 64:
|
| 73 |
+
args.slowmo_momentum = 0.5
|
| 74 |
+
else:
|
| 75 |
+
args.slowmo_momentum = 0.6
|
| 76 |
+
|
| 77 |
+
init_kwargs = dict(
|
| 78 |
+
module=model,
|
| 79 |
+
device_ids=[args.device_id],
|
| 80 |
+
output_device=args.device_id,
|
| 81 |
+
broadcast_buffers=args.broadcast_buffers,
|
| 82 |
+
nprocs_per_node=args.nprocs_per_node,
|
| 83 |
+
slowmo_momentum=args.slowmo_momentum,
|
| 84 |
+
localsgd=(args.slowmo_algorithm == "LocalSGD"),
|
| 85 |
+
localsgd_frequency=args.localsgd_frequency,
|
| 86 |
+
)
|
| 87 |
+
else:
|
| 88 |
+
raise ValueError("Unknown --ddp-backend: " + args.ddp_backend)
|
| 89 |
+
|
| 90 |
+
class _DistributedFairseqModel(ddp_class):
|
| 91 |
+
"""Extend DistributedDataParallel to check for missing
|
| 92 |
+
attributes in the wrapped module."""
|
| 93 |
+
|
| 94 |
+
def __init__(self, *args, **kwargs):
|
| 95 |
+
super().__init__(*args, **kwargs)
|
| 96 |
+
|
| 97 |
+
def __getattr__(self, name):
|
| 98 |
+
wrapped_module = super().__getattr__("module")
|
| 99 |
+
if hasattr(wrapped_module, name):
|
| 100 |
+
return getattr(wrapped_module, name)
|
| 101 |
+
return super().__getattr__(name)
|
| 102 |
+
|
| 103 |
+
return _DistributedFairseqModel(**init_kwargs)
|
fairseq-0.10.2/fairseq/models/fairseq_decoder.py
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
from typing import Dict, List, Optional, Tuple
|
| 7 |
+
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
from fairseq import utils
|
| 10 |
+
from torch import Tensor
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class FairseqDecoder(nn.Module):
|
| 14 |
+
"""Base class for decoders."""
|
| 15 |
+
|
| 16 |
+
def __init__(self, dictionary):
|
| 17 |
+
super().__init__()
|
| 18 |
+
self.dictionary = dictionary
|
| 19 |
+
self.onnx_trace = False
|
| 20 |
+
|
| 21 |
+
def forward(self, prev_output_tokens, encoder_out=None, **kwargs):
|
| 22 |
+
"""
|
| 23 |
+
Args:
|
| 24 |
+
prev_output_tokens (LongTensor): shifted output tokens of shape
|
| 25 |
+
`(batch, tgt_len)`, for teacher forcing
|
| 26 |
+
encoder_out (dict, optional): output from the encoder, used for
|
| 27 |
+
encoder-side attention
|
| 28 |
+
|
| 29 |
+
Returns:
|
| 30 |
+
tuple:
|
| 31 |
+
- the decoder's output of shape `(batch, tgt_len, vocab)`
|
| 32 |
+
- a dictionary with any model-specific outputs
|
| 33 |
+
"""
|
| 34 |
+
x, extra = self.extract_features(
|
| 35 |
+
prev_output_tokens, encoder_out=encoder_out, **kwargs
|
| 36 |
+
)
|
| 37 |
+
x = self.output_layer(x)
|
| 38 |
+
return x, extra
|
| 39 |
+
|
| 40 |
+
def extract_features(self, prev_output_tokens, encoder_out=None, **kwargs):
|
| 41 |
+
"""
|
| 42 |
+
Returns:
|
| 43 |
+
tuple:
|
| 44 |
+
- the decoder's features of shape `(batch, tgt_len, embed_dim)`
|
| 45 |
+
- a dictionary with any model-specific outputs
|
| 46 |
+
"""
|
| 47 |
+
raise NotImplementedError
|
| 48 |
+
|
| 49 |
+
def output_layer(self, features, **kwargs):
|
| 50 |
+
"""
|
| 51 |
+
Project features to the default output size, e.g., vocabulary size.
|
| 52 |
+
|
| 53 |
+
Args:
|
| 54 |
+
features (Tensor): features returned by *extract_features*.
|
| 55 |
+
"""
|
| 56 |
+
raise NotImplementedError
|
| 57 |
+
|
| 58 |
+
def get_normalized_probs(
|
| 59 |
+
self,
|
| 60 |
+
net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
|
| 61 |
+
log_probs: bool,
|
| 62 |
+
sample: Optional[Dict[str, Tensor]] = None,
|
| 63 |
+
):
|
| 64 |
+
"""Get normalized probabilities (or log probs) from a net's output."""
|
| 65 |
+
|
| 66 |
+
if hasattr(self, "adaptive_softmax") and self.adaptive_softmax is not None:
|
| 67 |
+
if sample is not None:
|
| 68 |
+
assert "target" in sample
|
| 69 |
+
target = sample["target"]
|
| 70 |
+
else:
|
| 71 |
+
target = None
|
| 72 |
+
out = self.adaptive_softmax.get_log_prob(net_output[0], target=target)
|
| 73 |
+
return out.exp_() if not log_probs else out
|
| 74 |
+
|
| 75 |
+
logits = net_output[0]
|
| 76 |
+
if log_probs:
|
| 77 |
+
return utils.log_softmax(logits, dim=-1, onnx_trace=self.onnx_trace)
|
| 78 |
+
else:
|
| 79 |
+
return utils.softmax(logits, dim=-1, onnx_trace=self.onnx_trace)
|
| 80 |
+
|
| 81 |
+
def max_positions(self):
|
| 82 |
+
"""Maximum input length supported by the decoder."""
|
| 83 |
+
return 1e6 # an arbitrary large number
|
| 84 |
+
|
| 85 |
+
def upgrade_state_dict(self, state_dict):
|
| 86 |
+
"""Upgrade a (possibly old) state dict for new versions of fairseq."""
|
| 87 |
+
return state_dict
|
| 88 |
+
|
| 89 |
+
def prepare_for_onnx_export_(self):
|
| 90 |
+
self.onnx_trace = True
|
fairseq-0.10.2/fairseq/models/fairseq_encoder.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
from typing import Dict, List, NamedTuple, Optional
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
from torch import Tensor
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
EncoderOut = NamedTuple(
|
| 14 |
+
"EncoderOut",
|
| 15 |
+
[
|
| 16 |
+
("encoder_out", Tensor), # T x B x C
|
| 17 |
+
("encoder_padding_mask", Optional[Tensor]), # B x T
|
| 18 |
+
("encoder_embedding", Optional[Tensor]), # B x T x C
|
| 19 |
+
("encoder_states", Optional[List[Tensor]]), # List[T x B x C]
|
| 20 |
+
("src_tokens", Optional[Tensor]), # B x T
|
| 21 |
+
("src_lengths", Optional[Tensor]), # B x 1
|
| 22 |
+
],
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class FairseqEncoder(nn.Module):
|
| 27 |
+
"""Base class for encoders."""
|
| 28 |
+
|
| 29 |
+
def __init__(self, dictionary):
|
| 30 |
+
super().__init__()
|
| 31 |
+
self.dictionary = dictionary
|
| 32 |
+
|
| 33 |
+
def forward(self, src_tokens, src_lengths=None, **kwargs):
|
| 34 |
+
"""
|
| 35 |
+
Args:
|
| 36 |
+
src_tokens (LongTensor): tokens in the source language of shape
|
| 37 |
+
`(batch, src_len)`
|
| 38 |
+
src_lengths (LongTensor): lengths of each source sentence of shape
|
| 39 |
+
`(batch)`
|
| 40 |
+
"""
|
| 41 |
+
raise NotImplementedError
|
| 42 |
+
|
| 43 |
+
def forward_torchscript(self, net_input: Dict[str, Tensor]):
|
| 44 |
+
"""A TorchScript-compatible version of forward.
|
| 45 |
+
|
| 46 |
+
Encoders which use additional arguments may want to override
|
| 47 |
+
this method for TorchScript compatibility.
|
| 48 |
+
"""
|
| 49 |
+
if torch.jit.is_scripting():
|
| 50 |
+
return self.forward(
|
| 51 |
+
src_tokens=net_input["src_tokens"],
|
| 52 |
+
src_lengths=net_input["src_lengths"],
|
| 53 |
+
)
|
| 54 |
+
else:
|
| 55 |
+
return self.forward_non_torchscript(net_input)
|
| 56 |
+
|
| 57 |
+
@torch.jit.unused
|
| 58 |
+
def forward_non_torchscript(self, net_input: Dict[str, Tensor]):
|
| 59 |
+
encoder_input = {
|
| 60 |
+
k: v for k, v in net_input.items() if k != "prev_output_tokens"
|
| 61 |
+
}
|
| 62 |
+
return self.forward(**encoder_input)
|
| 63 |
+
|
| 64 |
+
def reorder_encoder_out(self, encoder_out, new_order):
|
| 65 |
+
"""
|
| 66 |
+
Reorder encoder output according to `new_order`.
|
| 67 |
+
|
| 68 |
+
Args:
|
| 69 |
+
encoder_out: output from the ``forward()`` method
|
| 70 |
+
new_order (LongTensor): desired order
|
| 71 |
+
|
| 72 |
+
Returns:
|
| 73 |
+
`encoder_out` rearranged according to `new_order`
|
| 74 |
+
"""
|
| 75 |
+
raise NotImplementedError
|
| 76 |
+
|
| 77 |
+
def max_positions(self):
|
| 78 |
+
"""Maximum input length supported by the encoder."""
|
| 79 |
+
return 1e6 # an arbitrary large number
|
| 80 |
+
|
| 81 |
+
def upgrade_state_dict(self, state_dict):
|
| 82 |
+
"""Upgrade a (possibly old) state dict for new versions of fairseq."""
|
| 83 |
+
return state_dict
|
| 84 |
+
|
| 85 |
+
def set_num_updates(self, num_updates):
|
| 86 |
+
"""State from trainer to pass along to model at every update."""
|
| 87 |
+
|
| 88 |
+
def _apply(m):
|
| 89 |
+
if hasattr(m, "set_num_updates") and m != self:
|
| 90 |
+
m.set_num_updates(num_updates)
|
| 91 |
+
|
| 92 |
+
self.apply(_apply)
|
fairseq-0.10.2/fairseq/models/fairseq_model.py
ADDED
|
@@ -0,0 +1,556 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
"""
|
| 6 |
+
Base classes for various fairseq models.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import logging
|
| 10 |
+
from typing import Dict, List, Optional, Tuple
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
import torch.nn.functional as F
|
| 15 |
+
from fairseq import utils
|
| 16 |
+
# from fairseq.checkpoint_utils import prune_state_dict
|
| 17 |
+
from fairseq.data import Dictionary
|
| 18 |
+
from fairseq.dataclass.utils import gen_parser_from_dataclass
|
| 19 |
+
from fairseq.models import FairseqDecoder, FairseqEncoder
|
| 20 |
+
from torch import Tensor
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
logger = logging.getLogger(__name__)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class BaseFairseqModel(nn.Module):
|
| 27 |
+
"""Base class for fairseq models."""
|
| 28 |
+
|
| 29 |
+
def __init__(self):
|
| 30 |
+
super().__init__()
|
| 31 |
+
self._is_generation_fast = False
|
| 32 |
+
|
| 33 |
+
@classmethod
|
| 34 |
+
def add_args(cls, parser):
|
| 35 |
+
"""Add model-specific arguments to the parser."""
|
| 36 |
+
dc = getattr(cls, "__dataclass", None)
|
| 37 |
+
if dc is not None:
|
| 38 |
+
# do not set defaults so that settings defaults from various architectures still works
|
| 39 |
+
gen_parser_from_dataclass(parser, dc(), delete_default=True)
|
| 40 |
+
|
| 41 |
+
@classmethod
|
| 42 |
+
def build_model(cls, args, task):
|
| 43 |
+
"""Build a new model instance."""
|
| 44 |
+
raise NotImplementedError("Model must implement the build_model method")
|
| 45 |
+
|
| 46 |
+
def get_targets(self, sample, net_output):
|
| 47 |
+
"""Get targets from either the sample or the net's output."""
|
| 48 |
+
return sample["target"]
|
| 49 |
+
|
| 50 |
+
def get_normalized_probs(
|
| 51 |
+
self,
|
| 52 |
+
net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
|
| 53 |
+
log_probs: bool,
|
| 54 |
+
sample: Optional[Dict[str, Tensor]] = None,
|
| 55 |
+
):
|
| 56 |
+
"""Get normalized probabilities (or log probs) from a net's output."""
|
| 57 |
+
return self.get_normalized_probs_scriptable(net_output, log_probs, sample)
|
| 58 |
+
|
| 59 |
+
# TorchScript doesn't support super() method so that the scriptable Subclass
|
| 60 |
+
# can't access the base class model in Torchscript.
|
| 61 |
+
# Current workaround is to add a helper function with different name and
|
| 62 |
+
# call the helper function from scriptable Subclass.
|
| 63 |
+
def get_normalized_probs_scriptable(
|
| 64 |
+
self,
|
| 65 |
+
net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
|
| 66 |
+
log_probs: bool,
|
| 67 |
+
sample: Optional[Dict[str, Tensor]] = None,
|
| 68 |
+
):
|
| 69 |
+
"""Scriptable helper function for get_normalized_probs in ~BaseFairseqModel"""
|
| 70 |
+
if hasattr(self, "decoder"):
|
| 71 |
+
return self.decoder.get_normalized_probs(net_output, log_probs, sample)
|
| 72 |
+
elif torch.is_tensor(net_output):
|
| 73 |
+
# syntactic sugar for simple models which don't have a decoder
|
| 74 |
+
# (e.g., the classification tutorial)
|
| 75 |
+
logits = net_output.float()
|
| 76 |
+
if log_probs:
|
| 77 |
+
return F.log_softmax(logits, dim=-1)
|
| 78 |
+
else:
|
| 79 |
+
return F.softmax(logits, dim=-1)
|
| 80 |
+
raise NotImplementedError
|
| 81 |
+
|
| 82 |
+
def extract_features(self, *args, **kwargs):
|
| 83 |
+
"""Similar to *forward* but only return features."""
|
| 84 |
+
return self(*args, **kwargs)
|
| 85 |
+
|
| 86 |
+
def max_positions(self):
|
| 87 |
+
"""Maximum length supported by the model."""
|
| 88 |
+
return None
|
| 89 |
+
|
| 90 |
+
def load_state_dict(self, state_dict, strict=True, args=None):
|
| 91 |
+
"""Copies parameters and buffers from *state_dict* into this module and
|
| 92 |
+
its descendants.
|
| 93 |
+
|
| 94 |
+
Overrides the method in :class:`nn.Module`. Compared with that method
|
| 95 |
+
this additionally "upgrades" *state_dicts* from old checkpoints.
|
| 96 |
+
"""
|
| 97 |
+
self.upgrade_state_dict(state_dict)
|
| 98 |
+
from fairseq.checkpoint_utils import prune_state_dict
|
| 99 |
+
new_state_dict = prune_state_dict(state_dict, args)
|
| 100 |
+
return super().load_state_dict(new_state_dict, strict)
|
| 101 |
+
|
| 102 |
+
def upgrade_state_dict(self, state_dict):
|
| 103 |
+
"""Upgrade old state dicts to work with newer code."""
|
| 104 |
+
self.upgrade_state_dict_named(state_dict, "")
|
| 105 |
+
|
| 106 |
+
def upgrade_state_dict_named(self, state_dict, name):
|
| 107 |
+
"""Upgrade old state dicts to work with newer code.
|
| 108 |
+
|
| 109 |
+
Args:
|
| 110 |
+
state_dict (dict): state dictionary to upgrade, in place
|
| 111 |
+
name (str): the state dict key corresponding to the current module
|
| 112 |
+
"""
|
| 113 |
+
assert state_dict is not None
|
| 114 |
+
|
| 115 |
+
def do_upgrade(m, prefix):
|
| 116 |
+
if len(prefix) > 0:
|
| 117 |
+
prefix += "."
|
| 118 |
+
|
| 119 |
+
for n, c in m.named_children():
|
| 120 |
+
name = prefix + n
|
| 121 |
+
if hasattr(c, "upgrade_state_dict_named"):
|
| 122 |
+
c.upgrade_state_dict_named(state_dict, name)
|
| 123 |
+
elif hasattr(c, "upgrade_state_dict"):
|
| 124 |
+
c.upgrade_state_dict(state_dict)
|
| 125 |
+
do_upgrade(c, name)
|
| 126 |
+
|
| 127 |
+
do_upgrade(self, name)
|
| 128 |
+
|
| 129 |
+
def set_num_updates(self, num_updates):
|
| 130 |
+
"""State from trainer to pass along to model at every update."""
|
| 131 |
+
|
| 132 |
+
def _apply(m):
|
| 133 |
+
if hasattr(m, "set_num_updates") and m != self:
|
| 134 |
+
m.set_num_updates(num_updates)
|
| 135 |
+
|
| 136 |
+
self.apply(_apply)
|
| 137 |
+
|
| 138 |
+
def prepare_for_inference_(self, args):
|
| 139 |
+
"""Prepare model for inference."""
|
| 140 |
+
kwargs = {}
|
| 141 |
+
kwargs["beamable_mm_beam_size"] = (
|
| 142 |
+
None if getattr(args, "no_beamable_mm", False) else getattr(args, "beam", 5)
|
| 143 |
+
)
|
| 144 |
+
kwargs["need_attn"] = getattr(args, "print_alignment", False)
|
| 145 |
+
if hasattr(args, "retain_dropout"):
|
| 146 |
+
kwargs["retain_dropout"] = args.retain_dropout
|
| 147 |
+
kwargs["retain_dropout_modules"] = getattr(
|
| 148 |
+
args, "retain_dropout_modules", None
|
| 149 |
+
)
|
| 150 |
+
self.make_generation_fast_(**kwargs)
|
| 151 |
+
|
| 152 |
+
def make_generation_fast_(self, **kwargs):
    """
    Legacy entry point to optimize model for faster generation.
    Prefer prepare_for_inference_.
    """
    # Idempotent: repeated calls are no-ops.
    if self._is_generation_fast:
        return  # only apply once
    self._is_generation_fast = True

    # remove weight norm from all modules in the network
    def apply_remove_weight_norm(module):
        try:
            nn.utils.remove_weight_norm(module)
        except (AttributeError, ValueError):  # this module didn't have weight norm
            return

    self.apply(apply_remove_weight_norm)

    def apply_make_generation_fast_(module, prefix):
        if len(prefix) > 0:
            prefix += "."

        # Compare against the unbound base implementation so we only invoke
        # genuine overrides.
        base_func = BaseFairseqModel.make_generation_fast_
        for n, m in module.named_modules():
            if (
                m != self
                and hasattr(m, "make_generation_fast_")
                # don't call this implementation again, e.g., if
                # children modules also inherit from BaseFairseqModel
                and m.make_generation_fast_.__func__ is not base_func
            ):
                name = prefix + n
                m.make_generation_fast_(name=name, **kwargs)

    apply_make_generation_fast_(self, "")

    # NOTE: deliberately a plain function (no ``self``) — it is assigned to
    # the instance below, replacing nn.Module.train, so any attempt to put
    # the model back into training mode raises.
    def train(mode=True):
        if mode:
            raise RuntimeError("cannot train after make_generation_fast")

    # this model should no longer be used for training
    self.eval()
    self.train = train
|
| 195 |
+
|
| 196 |
+
def prepare_for_onnx_export_(self, **kwargs):
    """Make model exportable via ONNX trace."""
    visited = set()

    def _visit(module):
        # The root is skipped (it is the module running this hook) and each
        # submodule is prepared at most once.
        if module == self or module in visited:
            return
        if hasattr(module, "prepare_for_onnx_export_"):
            visited.add(module)
            module.prepare_for_onnx_export_(**kwargs)

    self.apply(_visit)
|
| 210 |
+
|
| 211 |
+
def prepare_for_tpu_(self, **kwargs):
    """Optionally modify model for use on TPUs."""
    visited = set()

    def _visit(module):
        # Same traversal discipline as the ONNX hook: skip the root, touch
        # each submodule once.
        if module == self or module in visited:
            return
        if hasattr(module, "prepare_for_tpu_"):
            visited.add(module)
            module.prepare_for_tpu_(**kwargs)

    self.apply(_visit)
|
| 225 |
+
|
| 226 |
+
@classmethod
def upgrade_args(cls, args):
    """Map legacy argument names onto their modern equivalents, in place."""
    # Older checkpoints carry "max_sentences"; copy it to "batch_size"
    # unless the namespace already has the newer name.
    has_legacy = hasattr(args, "max_sentences")
    if has_legacy and not hasattr(args, "batch_size"):
        args.batch_size = args.max_sentences
|
| 230 |
+
|
| 231 |
+
@classmethod
def from_pretrained(
    cls,
    model_name_or_path,
    checkpoint_file="model.pt",
    data_name_or_path=".",
    **kwargs,
):
    """
    Load a :class:`~fairseq.models.FairseqModel` from a pre-trained model
    file. Downloads and caches the pre-trained model file if needed.

    The base implementation returns a
    :class:`~fairseq.hub_utils.GeneratorHubInterface`, which can be used to
    generate translations or sample from language models. The underlying
    :class:`~fairseq.models.FairseqModel` can be accessed via the
    *generator.models* attribute.

    Other models may override this to implement custom hub interfaces.

    Args:
        model_name_or_path (str): either the name of a pre-trained model to
            load or a path/URL to a pre-trained model state dict
        checkpoint_file (str, optional): colon-separated list of checkpoint
            files in the model archive to ensemble (default: 'model.pt')
        data_name_or_path (str, optional): point args.data to the archive
            at the given path/URL. Can start with '.' or './' to reuse the
            model archive path.
    """
    # Local import — presumably to avoid a circular import between this
    # module and hub_utils; confirm before hoisting to file level.
    from fairseq import hub_utils

    x = hub_utils.from_pretrained(
        model_name_or_path,
        checkpoint_file,
        data_name_or_path,
        archive_map=cls.hub_models(),
        **kwargs,
    )

    # Normalize legacy argument names from older checkpoints.
    cls.upgrade_args(x["args"])

    logger.info(x["args"])
    return hub_utils.GeneratorHubInterface(x["args"], x["task"], x["models"])
|
| 274 |
+
|
| 275 |
+
@classmethod
def hub_models(cls):
    """Archive map used by :func:`from_pretrained`; subclasses override
    this to advertise downloadable checkpoints."""
    return dict()
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
class FairseqEncoderDecoderModel(BaseFairseqModel):
    """Base class for encoder-decoder models.

    Args:
        encoder (FairseqEncoder): the encoder
        decoder (FairseqDecoder): the decoder
    """

    def __init__(self, encoder, decoder):
        super().__init__()

        self.encoder = encoder
        self.decoder = decoder
        assert isinstance(self.encoder, FairseqEncoder)
        assert isinstance(self.decoder, FairseqDecoder)

    def forward(self, src_tokens, src_lengths, prev_output_tokens, **kwargs):
        """
        Run the forward pass for an encoder-decoder model.

        The source batch is first encoded; the encoder state and the
        previous decoder outputs (i.e., teacher forcing) then drive the
        decoder to produce the next outputs::

            encoder_out = self.encoder(src_tokens, src_lengths)
            return self.decoder(prev_output_tokens, encoder_out)

        Args:
            src_tokens (LongTensor): tokens in the source language of shape
                `(batch, src_len)`
            src_lengths (LongTensor): source sentence lengths of shape `(batch)`
            prev_output_tokens (LongTensor): previous decoder outputs of shape
                `(batch, tgt_len)`, for teacher forcing

        Returns:
            tuple:
                - the decoder's output of shape `(batch, tgt_len, vocab)`
                - a dictionary with any model-specific outputs
        """
        encoded = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
        return self.decoder(prev_output_tokens, encoder_out=encoded, **kwargs)

    def forward_decoder(self, prev_output_tokens, **kwargs):
        return self.decoder(prev_output_tokens, **kwargs)

    def extract_features(self, src_tokens, src_lengths, prev_output_tokens, **kwargs):
        """
        Like :func:`forward`, but returns the decoder's features instead of
        projected logits.

        Returns:
            tuple:
                - the decoder's features of shape `(batch, tgt_len, embed_dim)`
                - a dictionary with any model-specific outputs
        """
        encoded = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
        return self.decoder.extract_features(
            prev_output_tokens, encoder_out=encoded, **kwargs
        )

    def output_layer(self, features, **kwargs):
        """Project features to the default output size (typically vocabulary size)."""
        return self.decoder.output_layer(features, **kwargs)

    def max_positions(self):
        """Maximum length supported by the model, as (encoder, decoder) limits."""
        return (self.encoder.max_positions(), self.decoder.max_positions())

    def max_decoder_positions(self):
        """Maximum length supported by the decoder."""
        return self.decoder.max_positions()
|
| 354 |
+
|
| 355 |
+
|
| 356 |
+
class FairseqModel(FairseqEncoderDecoderModel):
    """Deprecated alias for :class:`FairseqEncoderDecoderModel`."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # stacklevel=4 makes the warning point further up the call stack,
        # at the user's code rather than at this wrapper.
        utils.deprecation_warning(
            "FairseqModel is deprecated, please use FairseqEncoderDecoderModel "
            "or BaseFairseqModel instead",
            stacklevel=4,
        )
|
| 364 |
+
|
| 365 |
+
|
| 366 |
+
class FairseqMultiModel(BaseFairseqModel):
    """Base class for combining multiple encoder-decoder models."""

    def __init__(self, encoders, decoders):
        super().__init__()
        # encoders/decoders are dicts sharing the same keys (e.g. language
        # pairs); the key sets must match exactly.
        assert encoders.keys() == decoders.keys()
        self.keys = list(encoders.keys())
        for key in self.keys:
            assert isinstance(encoders[key], FairseqEncoder)
            assert isinstance(decoders[key], FairseqDecoder)

        # One full encoder-decoder model per key, registered as submodules
        # via ModuleDict so parameters are tracked.
        self.models = nn.ModuleDict(
            {
                key: FairseqEncoderDecoderModel(encoders[key], decoders[key])
                for key in self.keys
            }
        )

    @staticmethod
    def build_shared_embeddings(
        dicts: Dict[str, Dictionary],
        langs: List[str],
        embed_dim: int,
        build_embedding: callable,
        pretrained_embed_path: Optional[str] = None,
    ):
        """
        Helper function to build shared embeddings for a set of languages after
        checking that all dicts corresponding to those languages are equivalent.

        Args:
            dicts: Dict of lang_id to its corresponding Dictionary
            langs: languages that we want to share embeddings for
            embed_dim: embedding dimension
            build_embedding: callable function to actually build the embedding
            pretrained_embed_path: Optional path to load pretrained embeddings
        """
        # All languages must share one dictionary for a shared embedding
        # matrix to be meaningful.
        shared_dict = dicts[langs[0]]
        if any(dicts[lang] != shared_dict for lang in langs):
            raise ValueError(
                "--share-*-embeddings requires a joined dictionary: "
                "--share-encoder-embeddings requires a joined source "
                "dictionary, --share-decoder-embeddings requires a joined "
                "target dictionary, and --share-all-embeddings requires a "
                "joint source + target dictionary."
            )
        return build_embedding(shared_dict, embed_dim, pretrained_embed_path)

    def forward(self, src_tokens, src_lengths, prev_output_tokens, **kwargs):
        # Subclasses decide how a batch is routed across self.models.
        raise NotImplementedError

    def max_positions(self):
        """Maximum length supported by the model, per key as an
        (encoder, decoder) pair."""
        return {
            key: (
                self.models[key].encoder.max_positions(),
                self.models[key].decoder.max_positions(),
            )
            for key in self.keys
        }

    def max_decoder_positions(self):
        """Maximum length supported by the decoder."""
        # The tightest limit across all sub-decoders.
        return min(model.decoder.max_positions() for model in self.models.values())

    @property
    def encoder(self):
        # Convenience accessor: the first model's encoder.
        return self.models[self.keys[0]].encoder

    @property
    def decoder(self):
        # Convenience accessor: the first model's decoder.
        return self.models[self.keys[0]].decoder

    def forward_decoder(self, prev_output_tokens, **kwargs):
        return self.decoder(prev_output_tokens, **kwargs)

    def load_state_dict(self, state_dict, strict=True, args=None):
        """Copies parameters and buffers from *state_dict* into this module and
        its descendants.

        Overrides the method in :class:`nn.Module`. Compared with that method
        this additionally "upgrades" *state_dicts* from old checkpoints.
        """
        self.upgrade_state_dict(state_dict)
        # Local import — presumably to avoid a circular import with
        # checkpoint_utils; confirm before hoisting.
        from fairseq.checkpoint_utils import prune_state_dict
        new_state_dict = prune_state_dict(state_dict, args)
        return super().load_state_dict(new_state_dict, strict)
|
| 453 |
+
|
| 454 |
+
|
| 455 |
+
class FairseqLanguageModel(BaseFairseqModel):
    """Base class for decoder-only models.

    Args:
        decoder (FairseqDecoder): the decoder
    """

    def __init__(self, decoder):
        super().__init__()
        self.decoder = decoder
        assert isinstance(self.decoder, FairseqDecoder)

    def forward(self, src_tokens, **kwargs):
        """
        Run the forward pass for a decoder-only model: feed a batch of
        tokens through the decoder to predict the next tokens.

        Args:
            src_tokens (LongTensor): tokens on which to condition the decoder,
                of shape `(batch, tgt_len)`
            src_lengths (LongTensor): source sentence lengths of shape `(batch)`

        Returns:
            tuple:
                - the decoder's output of shape `(batch, seq_len, vocab)`
                - a dictionary with any model-specific outputs
        """
        return self.decoder(src_tokens, **kwargs)

    def forward_decoder(self, prev_output_tokens, **kwargs):
        return self.decoder(prev_output_tokens, **kwargs)

    def extract_features(self, src_tokens, **kwargs):
        """
        Like :func:`forward`, but returns features instead of logits.

        Returns:
            tuple:
                - the decoder's features of shape `(batch, seq_len, embed_dim)`
                - a dictionary with any model-specific outputs
        """
        return self.decoder.extract_features(src_tokens, **kwargs)

    def output_layer(self, features, **kwargs):
        """Project features to the default output size (typically vocabulary size)."""
        return self.decoder.output_layer(features, **kwargs)

    def max_positions(self):
        """Maximum length supported by the model."""
        return self.decoder.max_positions()

    def max_decoder_positions(self):
        """Maximum length supported by the decoder."""
        return self.decoder.max_positions()

    @property
    def supported_targets(self):
        return {"future"}
|
| 514 |
+
|
| 515 |
+
|
| 516 |
+
class FairseqEncoderModel(BaseFairseqModel):
    """Base class for encoder-only models.

    Args:
        encoder (FairseqEncoder): the encoder
    """

    def __init__(self, encoder):
        super().__init__()
        self.encoder = encoder
        assert isinstance(self.encoder, FairseqEncoder)

    def forward(self, src_tokens, src_lengths, **kwargs):
        """
        Run the forward pass for an encoder-only model: feed a batch of
        tokens through the encoder to generate features.

        Args:
            src_tokens (LongTensor): input tokens of shape `(batch, src_len)`
            src_lengths (LongTensor): source sentence lengths of shape `(batch)`

        Returns:
            the encoder's output, typically of shape `(batch, src_len, features)`
        """
        return self.encoder(src_tokens, src_lengths, **kwargs)

    def get_normalized_probs(self, net_output, log_probs, sample=None):
        """Get normalized probabilities (or log probs) from a net's output."""
        encoder_out = net_output["encoder_out"]
        if not torch.is_tensor(encoder_out):
            # No generic way to normalize a non-tensor encoder output.
            raise NotImplementedError
        # Normalize in float32 for numerical stability.
        logits = encoder_out.float()
        if log_probs:
            return F.log_softmax(logits, dim=-1)
        return F.softmax(logits, dim=-1)

    def max_positions(self):
        """Maximum length supported by the model."""
        return self.encoder.max_positions()
|
fairseq-0.10.2/fairseq/models/huggingface/hf_gpt2.py
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import logging
|
| 7 |
+
import os
|
| 8 |
+
import sys
|
| 9 |
+
from typing import Dict, List, Optional
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
from fairseq.models import (
|
| 13 |
+
FairseqIncrementalDecoder,
|
| 14 |
+
FairseqLanguageModel,
|
| 15 |
+
register_model,
|
| 16 |
+
register_model_architecture,
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
logger = logging.getLogger(__name__)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
DEFAULT_MAX_TARGET_POSITIONS = 1024
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
@register_model("hf_gpt2")
class HuggingFaceGPT2LanguageModel(FairseqLanguageModel):
    """Decoder-only language model wrapping huggingface/transformers' GPT-2."""

    def __init__(self, decoder):
        super().__init__(decoder)

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        # fmt: off
        parser.add_argument('--embed-dim', type=int, metavar='N',
                            help='embedding dimension')
        parser.add_argument('--num-attention-heads', type=int, metavar='N',
                            help='num attention heads')
        parser.add_argument('--num-layers', type=int, metavar='N',
                            help='num layers')
        parser.add_argument('--dropout', type=float, metavar='D',
                            help='dropout probability for all fully connected layers '
                                 'in the embeddings, encoder, and pooler')
        parser.add_argument('--attention-dropout', type=float, metavar='D',
                            help='dropout probability for attention weights')
        # fmt: on

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance."""
        # Fill in any missing hyper-parameters before constructing the decoder.
        default_architecture(args)
        return cls(HuggingFaceGPT2Decoder(args, task))
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class HuggingFaceGPT2Decoder(FairseqIncrementalDecoder):
    """Incremental decoder backed by a huggingface/transformers GPT2LMHeadModel."""

    def __init__(self, args, task):
        # Imported lazily so fairseq works without transformers installed.
        try:
            from transformers import GPT2Config, GPT2LMHeadModel
        except ImportError:
            raise ImportError(
                "\n\nPlease install huggingface/transformers with:"
                "\n\n pip install transformers"
            )

        super().__init__(task.target_dictionary)

        config = GPT2Config(
            vocab_size=len(task.target_dictionary),
            # +1 position: padding tokens are assigned position 0 (see
            # extract_features), pushing real tokens into 1..max_target_positions.
            n_positions=args.max_target_positions + 1,
            n_ctx=args.max_target_positions,
            n_embd=args.embed_dim,
            n_layer=args.num_layers,
            n_head=args.num_attention_heads,
            resid_pdrop=args.dropout,
            embd_pdrop=args.dropout,
            attn_pdrop=args.attention_dropout,
            layer_norm_epsilon=1e-6,
        )
        self.model = GPT2LMHeadModel(config)

        # set zero embedding for padding symbol
        self.pad_idx = task.target_dictionary.pad()
        self.model.transformer.wte.weight.data[self.pad_idx].zero_()
        # Position 0 is reserved for padding; zero its embedding row too.
        self.model.transformer.wpe.weight.data[0].zero_()

    def forward(
        self,
        prev_output_tokens,
        src_lengths=None,  # unused here
        incremental_state: Optional[Dict[str, List[torch.Tensor]]] = None,
        encoder_out=None,  # unused here; this is a decoder-only model
    ):
        features = self.extract_features(prev_output_tokens, incremental_state)
        lm_logits = self.model.lm_head(features)
        # Tuple return mirrors the fairseq decoder convention of
        # (logits, ...extras); no extras are produced here.
        return (lm_logits,)

    def extract_features(
        self,
        prev_output_tokens,
        incremental_state: Optional[Dict[str, List[torch.Tensor]]] = None,
    ):
        if incremental_state:
            # NOTE(review): this getter is called with only the key, while
            # set_incremental_state below also receives incremental_state —
            # confirm this matches the FairseqIncrementalDecoder API in use.
            past = self.get_incremental_state("past")
        else:
            past = None

        # don't attend to padding symbols
        attention_mask = prev_output_tokens.ne(self.pad_idx).int()

        # set position ids to exclude padding symbols: real tokens receive
        # positions 1..seq_len, padding collapses to position 0 (whose
        # embedding row was zeroed in __init__)
        position_ids = attention_mask * (
            torch.arange(1, 1 + prev_output_tokens.size(1))
            .to(prev_output_tokens)
            .repeat(prev_output_tokens.size(0), 1)
        )

        # NOTE(review): the ``past`` kwarg was renamed to ``past_key_values``
        # in newer transformers releases — confirm the pinned version.
        outputs = self.model.transformer(
            input_ids=prev_output_tokens,
            past=past,
            attention_mask=attention_mask,
            position_ids=position_ids,
        )
        last_hidden_states = outputs[0]

        if incremental_state:
            # Cache key/value states for the next incremental decoding step.
            self.set_incremental_state(incremental_state, "past", outputs[1])

        return last_hidden_states

    def max_positions(self):
        # One position is reserved for padding (position 0).
        return self.model.config.n_positions - 1
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
@register_model_architecture("hf_gpt2", "hf_gpt2")
def default_architecture(args):
    """Fill in GPT-2 "small" hyper-parameters for any option left unset."""
    if getattr(args, "max_target_positions", None) is None:
        args.max_target_positions = getattr(
            args, "tokens_per_sample", DEFAULT_MAX_TARGET_POSITIONS
        )
    # Apply each default only when the user did not set the option.
    defaults = {
        "embed_dim": 768,
        "num_attention_heads": 12,
        "num_layers": 12,
        "dropout": 0.1,
        "attention_dropout": 0.1,
    }
    for option, value in defaults.items():
        setattr(args, option, getattr(args, option, value))
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
@register_model_architecture("hf_gpt2", "hf_gpt2_medium")
def hf_gpt2_medium(args):
    """GPT-2 "medium" preset: 1024-dim embeddings, 16 heads, 24 layers."""
    for option, value in (
        ("embed_dim", 1024),
        ("num_attention_heads", 16),
        ("num_layers", 24),
    ):
        setattr(args, option, getattr(args, option, value))
    default_architecture(args)
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
@register_model_architecture("hf_gpt2", "hf_gpt2_large")
def hf_gpt2_large(args):
    """GPT-2 "large" preset: 1280-dim embeddings, 20 heads, 36 layers."""
    for option, value in (
        ("embed_dim", 1280),
        ("num_attention_heads", 20),
        ("num_layers", 36),
    ):
        setattr(args, option, getattr(args, option, value))
    default_architecture(args)
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
@register_model_architecture("hf_gpt2", "hf_gpt2_xl")
def hf_gpt2_xl(args):
    """GPT-2 "xl" preset: 1600-dim embeddings, 25 heads, 48 layers."""
    for option, value in (
        ("embed_dim", 1600),
        ("num_attention_heads", 25),
        ("num_layers", 48),
    ):
        setattr(args, option, getattr(args, option, value))
    default_architecture(args)
|
fairseq-0.10.2/fairseq/models/lightconv_lm.py
ADDED
|
@@ -0,0 +1,306 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
from fairseq import utils
|
| 7 |
+
from fairseq.models import (
|
| 8 |
+
FairseqLanguageModel,
|
| 9 |
+
register_model,
|
| 10 |
+
register_model_architecture,
|
| 11 |
+
)
|
| 12 |
+
from fairseq.models.lightconv import Embedding, LightConvDecoder
|
| 13 |
+
from fairseq.modules import AdaptiveInput, CharacterTokenEmbedder
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
@register_model("lightconv_lm")
|
| 17 |
+
class LightConvLanguageModel(FairseqLanguageModel):
|
| 18 |
+
def __init__(self, decoder):
|
| 19 |
+
super().__init__(decoder)
|
| 20 |
+
|
| 21 |
+
@staticmethod
|
| 22 |
+
def add_args(parser):
|
| 23 |
+
"""Add model-specific arguments to the parser."""
|
| 24 |
+
parser.add_argument(
|
| 25 |
+
"--dropout",
|
| 26 |
+
default=0.1,
|
| 27 |
+
type=float,
|
| 28 |
+
metavar="D",
|
| 29 |
+
help="dropout probability",
|
| 30 |
+
)
|
| 31 |
+
parser.add_argument(
|
| 32 |
+
"--attention-dropout",
|
| 33 |
+
default=0.0,
|
| 34 |
+
type=float,
|
| 35 |
+
metavar="D",
|
| 36 |
+
help="dropout probability for attention weights",
|
| 37 |
+
)
|
| 38 |
+
parser.add_argument(
|
| 39 |
+
"--relu-dropout",
|
| 40 |
+
default=0.0,
|
| 41 |
+
type=float,
|
| 42 |
+
metavar="D",
|
| 43 |
+
help="dropout probability after ReLU in FFN",
|
| 44 |
+
)
|
| 45 |
+
parser.add_argument(
|
| 46 |
+
"--input-dropout",
|
| 47 |
+
type=float,
|
| 48 |
+
metavar="D",
|
| 49 |
+
help="dropout probability of the inputs",
|
| 50 |
+
)
|
| 51 |
+
parser.add_argument(
|
| 52 |
+
"--decoder-embed-dim",
|
| 53 |
+
type=int,
|
| 54 |
+
metavar="N",
|
| 55 |
+
help="decoder embedding dimension",
|
| 56 |
+
)
|
| 57 |
+
parser.add_argument(
|
| 58 |
+
"--decoder-output-dim",
|
| 59 |
+
type=int,
|
| 60 |
+
metavar="N",
|
| 61 |
+
help="decoder output dimension",
|
| 62 |
+
)
|
| 63 |
+
parser.add_argument(
|
| 64 |
+
"--decoder-input-dim", type=int, metavar="N", help="decoder input dimension"
|
| 65 |
+
)
|
| 66 |
+
parser.add_argument(
|
| 67 |
+
"--decoder-ffn-embed-dim",
|
| 68 |
+
type=int,
|
| 69 |
+
metavar="N",
|
| 70 |
+
help="decoder embedding dimension for FFN",
|
| 71 |
+
)
|
| 72 |
+
parser.add_argument(
|
| 73 |
+
"--decoder-layers", type=int, metavar="N", help="num decoder layers"
|
| 74 |
+
)
|
| 75 |
+
parser.add_argument(
|
| 76 |
+
"--decoder-attention-heads",
|
| 77 |
+
type=int,
|
| 78 |
+
metavar="N",
|
| 79 |
+
help="num decoder attention heads or LightConv/DynamicConv heads",
|
| 80 |
+
)
|
| 81 |
+
parser.add_argument(
|
| 82 |
+
"--decoder-normalize-before",
|
| 83 |
+
default=False,
|
| 84 |
+
action="store_true",
|
| 85 |
+
help="apply layernorm before each decoder block",
|
| 86 |
+
)
|
| 87 |
+
parser.add_argument(
|
| 88 |
+
"--adaptive-softmax-cutoff",
|
| 89 |
+
metavar="EXPR",
|
| 90 |
+
help="comma separated list of adaptive softmax cutoff points. "
|
| 91 |
+
"Must be used with adaptive_loss criterion",
|
| 92 |
+
)
|
| 93 |
+
parser.add_argument(
|
| 94 |
+
"--adaptive-softmax-dropout",
|
| 95 |
+
type=float,
|
| 96 |
+
metavar="D",
|
| 97 |
+
help="sets adaptive softmax dropout for the tail projections",
|
| 98 |
+
)
|
| 99 |
+
parser.add_argument(
|
| 100 |
+
"--adaptive-softmax-factor",
|
| 101 |
+
type=float,
|
| 102 |
+
metavar="N",
|
| 103 |
+
help="adaptive input factor",
|
| 104 |
+
)
|
| 105 |
+
parser.add_argument(
|
| 106 |
+
"--no-token-positional-embeddings",
|
| 107 |
+
default=False,
|
| 108 |
+
action="store_true",
|
| 109 |
+
help="if set, disables positional embeddings (outside self attention)",
|
| 110 |
+
)
|
| 111 |
+
parser.add_argument(
|
| 112 |
+
"--share-decoder-input-output-embed",
|
| 113 |
+
default=False,
|
| 114 |
+
action="store_true",
|
| 115 |
+
help="share decoder input and output embeddings",
|
| 116 |
+
)
|
| 117 |
+
parser.add_argument(
|
| 118 |
+
"--character-embeddings",
|
| 119 |
+
default=False,
|
| 120 |
+
action="store_true",
|
| 121 |
+
help="if set, uses character embedding convolutions to produce token embeddings",
|
| 122 |
+
)
|
| 123 |
+
parser.add_argument(
|
| 124 |
+
"--character-filters",
|
| 125 |
+
type=str,
|
| 126 |
+
metavar="LIST",
|
| 127 |
+
default="[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]",
|
| 128 |
+
help="size of character embeddings",
|
| 129 |
+
)
|
| 130 |
+
parser.add_argument(
|
| 131 |
+
"--character-embedding-dim",
|
| 132 |
+
type=int,
|
| 133 |
+
metavar="N",
|
| 134 |
+
default=4,
|
| 135 |
+
help="size of character embeddings",
|
| 136 |
+
)
|
| 137 |
+
parser.add_argument(
|
| 138 |
+
"--char-embedder-highway-layers",
|
| 139 |
+
type=int,
|
| 140 |
+
metavar="N",
|
| 141 |
+
default=2,
|
| 142 |
+
help="number of highway layers for character token embeddder",
|
| 143 |
+
)
|
| 144 |
+
parser.add_argument(
|
| 145 |
+
"--adaptive-input",
|
| 146 |
+
default=False,
|
| 147 |
+
action="store_true",
|
| 148 |
+
help="if set, uses adaptive input",
|
| 149 |
+
)
|
| 150 |
+
parser.add_argument(
|
| 151 |
+
"--adaptive-input-factor",
|
| 152 |
+
type=float,
|
| 153 |
+
metavar="N",
|
| 154 |
+
help="adaptive input factor",
|
| 155 |
+
)
|
| 156 |
+
parser.add_argument(
|
| 157 |
+
"--adaptive-input-cutoff",
|
| 158 |
+
metavar="EXPR",
|
| 159 |
+
help="comma separated list of adaptive input cutoff points.",
|
| 160 |
+
)
|
| 161 |
+
parser.add_argument(
|
| 162 |
+
"--tie-adaptive-weights",
|
| 163 |
+
action="store_true",
|
| 164 |
+
help="if set, ties the weights of adaptive softmax and adaptive input",
|
| 165 |
+
)
|
| 166 |
+
parser.add_argument(
|
| 167 |
+
"--tie-adaptive-proj",
|
| 168 |
+
action="store_true",
|
| 169 |
+
help="if set, ties the projection weights of adaptive softmax and adaptive input",
|
| 170 |
+
)
|
| 171 |
+
parser.add_argument(
|
| 172 |
+
"--decoder-learned-pos",
|
| 173 |
+
action="store_true",
|
| 174 |
+
help="use learned positional embeddings in the decoder",
|
| 175 |
+
)
|
| 176 |
+
|
| 177 |
+
"""LightConv and DynamicConv arguments"""
|
| 178 |
+
parser.add_argument(
|
| 179 |
+
"--decoder-kernel-size-list",
|
| 180 |
+
type=lambda x: utils.eval_str_list(x, int),
|
| 181 |
+
help='list of kernel size (default: "[3,7,15,31,31,31]")',
|
| 182 |
+
)
|
| 183 |
+
parser.add_argument(
|
| 184 |
+
"--decoder-glu", type=utils.eval_bool, help="glu after in proj"
|
| 185 |
+
)
|
| 186 |
+
parser.add_argument(
|
| 187 |
+
"--decoder-conv-type",
|
| 188 |
+
default="dynamic",
|
| 189 |
+
type=str,
|
| 190 |
+
choices=["dynamic", "lightweight"],
|
| 191 |
+
help="type of convolution",
|
| 192 |
+
)
|
| 193 |
+
parser.add_argument("--weight-softmax", default=True, type=utils.eval_bool)
|
| 194 |
+
parser.add_argument(
|
| 195 |
+
"--weight-dropout",
|
| 196 |
+
type=float,
|
| 197 |
+
metavar="D",
|
| 198 |
+
help="dropout probability for conv weights",
|
| 199 |
+
)
|
| 200 |
+
|
| 201 |
+
@classmethod
|
| 202 |
+
def build_model(cls, args, task):
|
| 203 |
+
"""Build a new model instance."""
|
| 204 |
+
|
| 205 |
+
# make sure all arguments are present in older models
|
| 206 |
+
base_lm_architecture(args)
|
| 207 |
+
|
| 208 |
+
if getattr(args, "max_source_positions", None) is None:
|
| 209 |
+
args.max_source_positions = args.tokens_per_sample
|
| 210 |
+
if getattr(args, "max_target_positions", None) is None:
|
| 211 |
+
args.max_target_positions = args.tokens_per_sample
|
| 212 |
+
|
| 213 |
+
if args.character_embeddings:
|
| 214 |
+
embed_tokens = CharacterTokenEmbedder(
|
| 215 |
+
task.dictionary,
|
| 216 |
+
eval(args.character_filters),
|
| 217 |
+
args.character_embedding_dim,
|
| 218 |
+
args.decoder_embed_dim,
|
| 219 |
+
args.char_embedder_highway_layers,
|
| 220 |
+
)
|
| 221 |
+
elif args.adaptive_input:
|
| 222 |
+
embed_tokens = AdaptiveInput(
|
| 223 |
+
len(task.dictionary),
|
| 224 |
+
task.dictionary.pad(),
|
| 225 |
+
args.decoder_input_dim,
|
| 226 |
+
args.adaptive_input_factor,
|
| 227 |
+
args.decoder_embed_dim,
|
| 228 |
+
utils.eval_str_list(args.adaptive_input_cutoff, type=int),
|
| 229 |
+
)
|
| 230 |
+
else:
|
| 231 |
+
embed_tokens = Embedding(
|
| 232 |
+
len(task.dictionary), args.decoder_input_dim, task.dictionary.pad()
|
| 233 |
+
)
|
| 234 |
+
|
| 235 |
+
if args.tie_adaptive_weights:
|
| 236 |
+
assert args.adaptive_input
|
| 237 |
+
assert args.adaptive_input_factor == args.adaptive_softmax_factor
|
| 238 |
+
assert (
|
| 239 |
+
args.adaptive_softmax_cutoff == args.adaptive_input_cutoff
|
| 240 |
+
), "{} != {}".format(
|
| 241 |
+
args.adaptive_softmax_cutoff, args.adaptive_input_cutoff
|
| 242 |
+
)
|
| 243 |
+
assert args.decoder_input_dim == args.decoder_output_dim
|
| 244 |
+
|
| 245 |
+
decoder = LightConvDecoder(
|
| 246 |
+
args,
|
| 247 |
+
task.output_dictionary,
|
| 248 |
+
embed_tokens,
|
| 249 |
+
no_encoder_attn=True,
|
| 250 |
+
final_norm=False,
|
| 251 |
+
)
|
| 252 |
+
return LightConvLanguageModel(decoder)
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
@register_model_architecture("lightconv_lm", "lightconv_lm")
|
| 256 |
+
def base_lm_architecture(args):
    """Populate missing hyperparameters on ``args`` with the base LightConv
    language-model defaults.

    Every attribute is filled in only when the caller has not already set it,
    except ``decoder_normalize_before``, which is forced to True.
    """

    def _default(name, value):
        # Keep any value already present on args; only fill in gaps.
        setattr(args, name, getattr(args, name, value))

    _default("decoder_embed_dim", 512)
    _default("decoder_ffn_embed_dim", 2048)
    _default("decoder_layers", 6)
    _default("decoder_attention_heads", 8)
    _default("adaptive_softmax_cutoff", None)
    _default("adaptive_softmax_dropout", 0)
    _default("adaptive_softmax_factor", 4)
    _default("decoder_learned_pos", False)

    _default("character_embeddings", False)

    # Output/input/conv dims default to the (possibly just filled-in)
    # embedding dim.
    _default("decoder_output_dim", args.decoder_embed_dim)
    _default("decoder_input_dim", args.decoder_embed_dim)
    _default("decoder_conv_dim", args.decoder_embed_dim)

    # The model training is not stable without this, so force it on
    # regardless of what the caller passed.
    args.decoder_normalize_before = True

    _default("adaptive_input", False)
    _default("adaptive_input_factor", 4)
    _default("adaptive_input_cutoff", None)

    _default("tie_adaptive_weights", False)
    _default("tie_adaptive_proj", False)

    _default("decoder_kernel_size_list", [3, 7, 15, 31, 31, 31])
    # A single kernel size is broadcast across every decoder layer.
    if len(args.decoder_kernel_size_list) == 1:
        args.decoder_kernel_size_list = (
            args.decoder_kernel_size_list * args.decoder_layers
        )
    assert (
        len(args.decoder_kernel_size_list) == args.decoder_layers
    ), "decoder_kernel_size_list doesn't match decoder_layers"
    _default("decoder_glu", True)
    _default("input_dropout", 0.1)
    # NOTE(review): assumes args.attention_dropout is already set by the
    # caller (e.g. the gbw preset) — AttributeError otherwise, as before.
    _default("weight_dropout", args.attention_dropout)
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
@register_model_architecture("lightconv_lm", "lightconv_lm_gbw")
|
| 300 |
+
def lightconv_lm_gbw(args):
    """Google Billion Words preset: wider FFN and more attention heads on
    top of the base LightConv LM architecture."""
    presets = (
        ("decoder_embed_dim", 512),
        ("dropout", 0.1),
        ("attention_dropout", 0.1),
        ("decoder_ffn_embed_dim", 4096),
        ("decoder_attention_heads", 16),
    )
    for name, value in presets:
        # Only fill in attributes the caller has not set explicitly.
        setattr(args, name, getattr(args, name, value))
    base_lm_architecture(args)
|
fairseq-0.10.2/fairseq/models/masked_lm.py
ADDED
|
@@ -0,0 +1,403 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import logging
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
from fairseq import utils
|
| 12 |
+
from fairseq.models import (
|
| 13 |
+
FairseqEncoder,
|
| 14 |
+
FairseqEncoderModel,
|
| 15 |
+
register_model,
|
| 16 |
+
register_model_architecture,
|
| 17 |
+
)
|
| 18 |
+
from fairseq.modules import (
|
| 19 |
+
LayerNorm,
|
| 20 |
+
SinusoidalPositionalEmbedding,
|
| 21 |
+
TransformerSentenceEncoder,
|
| 22 |
+
)
|
| 23 |
+
from fairseq.modules.transformer_sentence_encoder import init_bert_params
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
logger = logging.getLogger(__name__)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
@register_model("masked_lm")
class MaskedLMModel(FairseqEncoderModel):
    """
    Class for training a Masked Language Model. It also supports an
    additional sentence level prediction if the sent-loss argument is set.
    """

    def __init__(self, args, encoder):
        """``args``: parsed argument namespace; ``encoder``: a MaskedLMEncoder."""
        super().__init__(encoder)
        self.args = args

        # if specified then apply bert initialization on the model. We need
        # to explicitly call this to make sure that the output embeddings
        # and projection layers are also correctly initialized
        if getattr(args, "apply_bert_init", False):
            self.apply(init_bert_params)

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        # Arguments related to dropout
        parser.add_argument(
            "--dropout", type=float, metavar="D", help="dropout probability"
        )
        parser.add_argument(
            "--attention-dropout",
            type=float,
            metavar="D",
            help="dropout probability for" " attention weights",
        )
        parser.add_argument(
            "--act-dropout",
            type=float,
            metavar="D",
            help="dropout probability after" " activation in FFN",
        )

        # Arguments related to hidden states and self-attention
        parser.add_argument(
            "--encoder-ffn-embed-dim",
            type=int,
            metavar="N",
            help="encoder embedding dimension for FFN",
        )
        parser.add_argument(
            "--encoder-layers", type=int, metavar="N", help="num encoder layers"
        )
        parser.add_argument(
            "--encoder-attention-heads",
            type=int,
            metavar="N",
            help="num encoder attention heads",
        )

        # Arguments related to input and output embeddings
        parser.add_argument(
            "--encoder-embed-dim",
            type=int,
            metavar="N",
            help="encoder embedding dimension",
        )
        parser.add_argument(
            "--share-encoder-input-output-embed",
            action="store_true",
            help="share encoder input" " and output embeddings",
        )
        parser.add_argument(
            "--encoder-learned-pos",
            action="store_true",
            help="use learned positional embeddings in the encoder",
        )
        parser.add_argument(
            "--no-token-positional-embeddings",
            action="store_true",
            help="if set, disables positional embeddings" " (outside self attention)",
        )
        parser.add_argument(
            "--num-segment", type=int, metavar="N", help="num segment in the input"
        )
        parser.add_argument(
            "--max-positions", type=int, help="number of positional embeddings to learn"
        )

        # Arguments related to sentence level prediction
        parser.add_argument(
            "--sentence-class-num",
            type=int,
            metavar="N",
            help="number of classes for sentence task",
        )
        parser.add_argument(
            "--sent-loss",
            action="store_true",
            help="if set," " calculate sentence level predictions",
        )

        # Arguments related to parameter initialization
        parser.add_argument(
            "--apply-bert-init",
            action="store_true",
            help="use custom param initialization for BERT",
        )

        # misc params
        parser.add_argument(
            "--activation-fn",
            choices=utils.get_available_activation_fns(),
            help="activation function to use",
        )
        parser.add_argument(
            "--pooler-activation-fn",
            choices=utils.get_available_activation_fns(),
            help="Which activation function to use for pooler layer.",
        )
        parser.add_argument(
            "--encoder-normalize-before",
            action="store_true",
            help="apply layernorm before each encoder block",
        )

    def forward(self, src_tokens, segment_labels=None, **kwargs):
        """Delegate to the wrapped encoder; see MaskedLMEncoder.forward for
        the returned (logits, extras) pair."""
        return self.encoder(src_tokens, segment_labels=segment_labels, **kwargs)

    def max_positions(self):
        """Maximum input length supported by the model.

        NOTE: ``self.encoder.max_positions`` is an *int attribute* assigned
        in MaskedLMEncoder.__init__ (it shadows the encoder's method of the
        same name), so this returns a plain integer, not a bound method.
        """
        return self.encoder.max_positions

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance."""
        # make sure all arguments are present in older models
        base_architecture(args)

        # Fall back to the training sample length when --max-positions was
        # not supplied at all.
        if not hasattr(args, "max_positions"):
            args.max_positions = args.tokens_per_sample

        logger.info(args)

        encoder = MaskedLMEncoder(args, task.dictionary)
        return cls(args, encoder)
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
class MaskedLMEncoder(FairseqEncoder):
    """
    Encoder for Masked Language Modelling.
    """

    def __init__(self, args, dictionary):
        super().__init__(dictionary)

        self.padding_idx = dictionary.pad()
        self.vocab_size = dictionary.__len__()
        # NOTE: this int attribute shadows the max_positions() method defined
        # below on instances of this class.
        self.max_positions = args.max_positions

        # Transformer stack producing per-token states and a sentence
        # representation.
        self.sentence_encoder = TransformerSentenceEncoder(
            padding_idx=self.padding_idx,
            vocab_size=self.vocab_size,
            num_encoder_layers=args.encoder_layers,
            embedding_dim=args.encoder_embed_dim,
            ffn_embedding_dim=args.encoder_ffn_embed_dim,
            num_attention_heads=args.encoder_attention_heads,
            dropout=args.dropout,
            attention_dropout=args.attention_dropout,
            activation_dropout=args.act_dropout,
            max_seq_len=self.max_positions,
            num_segments=args.num_segment,
            use_position_embeddings=not args.no_token_positional_embeddings,
            encoder_normalize_before=args.encoder_normalize_before,
            apply_bert_init=args.apply_bert_init,
            activation_fn=args.activation_fn,
            learned_pos_embedding=args.encoder_learned_pos,
        )

        self.share_input_output_embed = args.share_encoder_input_output_embed
        self.embed_out = None
        self.sentence_projection_layer = None
        self.sentence_out_dim = args.sentence_class_num
        self.lm_output_learned_bias = None

        # Remove head is set to true during fine-tuning
        self.load_softmax = not getattr(args, "remove_head", False)

        # Pooler that maps the sentence representation for sentence-level
        # prediction tasks.
        self.masked_lm_pooler = nn.Linear(
            args.encoder_embed_dim, args.encoder_embed_dim
        )
        self.pooler_activation = utils.get_activation_fn(args.pooler_activation_fn)

        # LM head: linear -> activation -> layer norm, before the vocabulary
        # projection applied in forward().
        self.lm_head_transform_weight = nn.Linear(
            args.encoder_embed_dim, args.encoder_embed_dim
        )
        self.activation_fn = utils.get_activation_fn(args.activation_fn)
        self.layer_norm = LayerNorm(args.encoder_embed_dim)

        self.lm_output_learned_bias = None
        if self.load_softmax:
            # Learned per-token output bias added to the vocabulary logits.
            self.lm_output_learned_bias = nn.Parameter(torch.zeros(self.vocab_size))

            if not self.share_input_output_embed:
                self.embed_out = nn.Linear(
                    args.encoder_embed_dim, self.vocab_size, bias=False
                )

            if args.sent_loss:
                self.sentence_projection_layer = nn.Linear(
                    args.encoder_embed_dim, self.sentence_out_dim, bias=False
                )

    def forward(self, src_tokens, segment_labels=None, masked_tokens=None, **unused):
        """
        Forward pass for Masked LM encoder. This first computes the token
        embedding using the token embedding matrix, position embeddings (if
        specified) and segment embeddings (if specified).

        Here we assume that the sentence representation corresponds to the
        output of the classification_token (see bert_task or cross_lingual_lm
        task for more details).
        Args:
            - src_tokens: B x T matrix representing sentences
            - segment_labels: B x T matrix representing segment label for tokens
            - masked_tokens: optional boolean mask; when given, only the
              masked positions are projected to the vocabulary
        Returns:
            - a tuple of the following:
                - logits for predictions in format B x T x C to be used in
                  softmax afterwards
                - a dictionary of additional data, where 'pooled_output' contains
                  the representation for classification_token and 'inner_states'
                  is a list of internal model states used to compute the
                  predictions (similar in ELMO). 'sentence_logits'
                  is the prediction logit for NSP task and is only computed if
                  this is specified in the input arguments.
        """

        inner_states, sentence_rep = self.sentence_encoder(
            src_tokens,
            segment_labels=segment_labels,
        )

        # assumes inner_states entries are time-major (T x B x C); transpose
        # to batch-first — TODO confirm against TransformerSentenceEncoder
        x = inner_states[-1].transpose(0, 1)
        # project masked tokens only
        if masked_tokens is not None:
            x = x[masked_tokens, :]
        x = self.layer_norm(self.activation_fn(self.lm_head_transform_weight(x)))

        pooled_output = self.pooler_activation(self.masked_lm_pooler(sentence_rep))

        # project back to size of vocabulary
        if self.share_input_output_embed and hasattr(
            self.sentence_encoder.embed_tokens, "weight"
        ):
            # Tied output projection reuses the input embedding matrix.
            x = F.linear(x, self.sentence_encoder.embed_tokens.weight)
        elif self.embed_out is not None:
            x = self.embed_out(x)
        if self.lm_output_learned_bias is not None:
            x = x + self.lm_output_learned_bias
        sentence_logits = None
        if self.sentence_projection_layer:
            sentence_logits = self.sentence_projection_layer(pooled_output)

        return x, {
            "inner_states": inner_states,
            "pooled_output": pooled_output,
            "sentence_logits": sentence_logits,
        }

    def max_positions(self):
        """Maximum output length supported by the encoder."""
        # NOTE: dead in practice — __init__ assigns an int to
        # self.max_positions, which shadows this method on instances.
        return self.max_positions

    def upgrade_state_dict_named(self, state_dict, name):
        """Adjust an incoming ``state_dict`` for this module.

        Sinusoidal position-embedding buffers are replaced with a 1-element
        placeholder (presumably because they are deterministic and rebuilt at
        runtime — confirm against SinusoidalPositionalEmbedding); when the LM
        head was removed for fine-tuning, its parameters are dropped.
        """
        if isinstance(
            self.sentence_encoder.embed_positions, SinusoidalPositionalEmbedding
        ):
            state_dict[
                name + ".sentence_encoder.embed_positions._float_tensor"
            ] = torch.FloatTensor(1)
        if not self.load_softmax:
            # Fine-tuning checkpoints must not carry the pre-training head.
            for k in list(state_dict.keys()):
                if (
                    "embed_out.weight" in k
                    or "sentence_projection_layer.weight" in k
                    or "lm_output_learned_bias" in k
                ):
                    del state_dict[k]
        return state_dict
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
@register_model_architecture("masked_lm", "masked_lm")
|
| 314 |
+
def base_architecture(args):
    """Fill in any unset masked-LM hyperparameters with the base defaults.

    Invoked from ``build_model`` so checkpoints saved before an option
    existed still load with a complete ``args`` namespace; values already
    present on ``args`` are left untouched.
    """
    defaults = (
        ("dropout", 0.1),
        ("attention_dropout", 0.1),
        ("act_dropout", 0.0),
        ("encoder_ffn_embed_dim", 4096),
        ("encoder_layers", 6),
        ("encoder_attention_heads", 8),
        ("encoder_embed_dim", 1024),
        ("share_encoder_input_output_embed", False),
        ("encoder_learned_pos", False),
        ("no_token_positional_embeddings", False),
        ("num_segment", 2),
        ("sentence_class_num", 2),
        ("sent_loss", False),
        ("apply_bert_init", False),
        ("activation_fn", "relu"),
        ("pooler_activation_fn", "tanh"),
        ("encoder_normalize_before", False),
    )
    for name, value in defaults:
        setattr(args, name, getattr(args, name, value))
|
| 341 |
+
|
| 342 |
+
|
| 343 |
+
@register_model_architecture("masked_lm", "bert_base")
|
| 344 |
+
def bert_base_architecture(args):
    """BERT-Base preset (12 layers, 768 dims, 12 heads, GELU, NSP loss),
    then delegate the remaining defaults to :func:`base_architecture`."""
    overrides = {
        "encoder_embed_dim": 768,
        "share_encoder_input_output_embed": True,
        "no_token_positional_embeddings": False,
        "encoder_learned_pos": True,
        "num_segment": 2,
        "encoder_layers": 12,
        "encoder_attention_heads": 12,
        "encoder_ffn_embed_dim": 3072,
        "sentence_class_num": 2,
        "sent_loss": True,
        "apply_bert_init": True,
        "activation_fn": "gelu",
        "pooler_activation_fn": "tanh",
        "encoder_normalize_before": True,
    }
    for name, value in overrides.items():
        # Explicit user settings win over the preset.
        setattr(args, name, getattr(args, name, value))
    base_architecture(args)
|
| 369 |
+
|
| 370 |
+
|
| 371 |
+
@register_model_architecture("masked_lm", "bert_large")
|
| 372 |
+
def bert_large_architecture(args):
    """BERT-Large preset (24 layers, 1024 dims, 16 heads) layered on top of
    the BERT-Base defaults."""
    for name, value in (
        ("encoder_embed_dim", 1024),
        ("encoder_layers", 24),
        ("encoder_attention_heads", 16),
        ("encoder_ffn_embed_dim", 4096),
    ):
        setattr(args, name, getattr(args, name, value))
    bert_base_architecture(args)
|
| 378 |
+
|
| 379 |
+
|
| 380 |
+
@register_model_architecture("masked_lm", "xlm_base")
|
| 381 |
+
def xlm_architecture(args):
    """XLM preset: single segment, learned positions, GELU activations,
    tied input/output embeddings; remaining defaults from
    :func:`base_architecture`."""
    for name, value in (
        ("encoder_embed_dim", 1024),
        ("share_encoder_input_output_embed", True),
        ("no_token_positional_embeddings", False),
        ("encoder_learned_pos", True),
        ("num_segment", 1),
        ("encoder_layers", 6),
        ("encoder_attention_heads", 8),
        ("encoder_ffn_embed_dim", 4096),
        ("sent_loss", False),
        ("activation_fn", "gelu"),
        ("encoder_normalize_before", False),
        ("pooler_activation_fn", "tanh"),
        ("apply_bert_init", True),
    ):
        setattr(args, name, getattr(args, name, value))
    base_architecture(args)
|
fairseq-0.10.2/fairseq/models/model_utils.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
from typing import List, Optional
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
from torch import Tensor
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@torch.jit.script
|
| 13 |
+
def script_skip_tensor_list(x: List[Tensor], mask):
    """Apply a boolean ``mask`` to every tensor in ``x``.

    Tensors whose dim-0 size matches ``mask`` are indexed along dim 0,
    otherwise along dim 1.  If the selection for a tensor is empty, the
    original (unmasked) tensor is kept instead.

    Kept within the TorchScript subset: the ``@torch.jit.script`` decorator
    above compiles this at import time.
    """
    masked: List[Tensor] = []
    for xi in x:
        if xi.size(0) == mask.size(0):
            masked.append(xi[mask])
        else:
            masked.append(xi[:, mask])
    outputs: List[Tensor] = []
    for i, t in enumerate(masked):
        # An empty selection falls back to the unmasked input tensor.
        outputs.append(t if t.numel() != 0 else x[i])
    return outputs
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@torch.jit.script
|
| 25 |
+
def script_skip_tensor(x: Tensor, mask):
    """Mask ``x`` along dim 0 (or dim 1 when dim 0 doesn't match ``mask``).

    Returns ``x`` unchanged when it is empty or when the mask selects
    nothing.  TorchScript-compatible (compiled via the decorator above).
    """
    # Empty input: nothing to mask.
    if x.size(0) == 0:
        return x
    if x.size(0) == mask.size(0):
        selected = x[mask]
    else:
        selected = x[:, mask]
    # Keep the original tensor when the mask selects nothing.
    if selected.numel() == 0:
        return x
    return selected
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
@torch.jit.script
|
| 37 |
+
def expand_2d_or_3d_tensor(x, trg_dim: int, padding_idx: int):
    """
    Expand 2D/3D tensor on dim=1

    Grows ``x`` along dim 1 up to ``trg_dim`` by appending entries filled
    with ``padding_idx``; returns ``x`` untouched when it is already wide
    enough, and None for a None input.
    """
    if x is None:
        return None

    assert x.dim() == 2 or x.dim() == 3
    assert trg_dim >= x.size(1), (trg_dim, x.size())
    if trg_dim == x.size(1):
        return x

    pad_shape = [x.size(0), trg_dim - x.size(1)]
    if x.dim() == 3:
        pad_shape.append(x.size(2))
    # Build the padding block, match dtype/device of x, then append on dim 1.
    filler = torch.zeros(pad_shape).to(x).fill_(padding_idx)
    return torch.cat([x, filler], 1)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
@torch.jit.script
|
| 58 |
+
def coalesce(x: Optional[Tensor], y: Tensor) -> Tensor:
    """Return ``x`` unless it is None, in which case fall back to ``y``."""
    if x is None:
        return y
    return x
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
@torch.jit.script
|
| 63 |
+
def fill_tensors(
    x: Optional[Tensor], mask, y: Optional[Tensor], padding_idx: int
) -> Optional[Tensor]:
    """
    Filling tensor x with y at masked positions (dim=0).

    ``y`` must contain exactly one row per True entry in ``mask``.  When the
    row widths differ, the narrower side is padded with ``padding_idx`` so
    the assignment lines up.  Mutates ``x`` in place (unless it is widened).
    """
    if x is None or x.size()[0] == 0 or y is None:
        return x
    assert x.dim() == y.dim() and mask.size(0) == x.size(0)
    assert x.dim() == 2 or (x.dim() == 3 and x.size(2) == y.size(2))

    num_masked = mask.sum()
    # Nothing selected: x is already the answer.
    if num_masked == 0:
        return x
    assert num_masked == y.size(0)
    # Everything selected: y replaces x wholesale.
    if num_masked == x.size(0):
        return y

    wide = y.size(1)
    if x.size(1) == wide:
        x[mask] = y
    elif x.size(1) < wide:
        # Widen x with padding so the masked rows can absorb y.
        x = expand_2d_or_3d_tensor(x, wide, padding_idx)
        x[mask] = y
    else:
        # Blank the masked rows, then copy y into their prefix.
        x[mask] = torch.tensor(padding_idx).type_as(x)
        if x.dim() == 2:
            x[mask, :wide] = y
        else:
            x[mask, :wide, :] = y
    return x
|
fairseq-0.10.2/fairseq/models/multilingual_transformer.py
ADDED
|
@@ -0,0 +1,228 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
from collections import OrderedDict
|
| 7 |
+
|
| 8 |
+
from fairseq import utils
|
| 9 |
+
from fairseq.models import (
|
| 10 |
+
FairseqMultiModel,
|
| 11 |
+
register_model,
|
| 12 |
+
register_model_architecture,
|
| 13 |
+
)
|
| 14 |
+
from fairseq.models.transformer import (
|
| 15 |
+
Embedding,
|
| 16 |
+
TransformerDecoder,
|
| 17 |
+
TransformerEncoder,
|
| 18 |
+
TransformerModel,
|
| 19 |
+
base_architecture,
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
@register_model("multilingual_transformer")
class MultilingualTransformerModel(FairseqMultiModel):
    """Train Transformer models for multiple language pairs simultaneously.

    Requires `--task multilingual_translation`.

    We inherit all arguments from TransformerModel and assume that all language
    pairs use a single Transformer architecture. In addition, we provide several
    options that are specific to the multilingual setting.

    Args:
        --share-encoder-embeddings: share encoder embeddings across all source languages
        --share-decoder-embeddings: share decoder embeddings across all target languages
        --share-encoders: share all encoder params (incl. embeddings) across all source languages
        --share-decoders: share all decoder params (incl. embeddings) across all target languages
    """

    def __init__(self, encoders, decoders):
        # encoders/decoders are OrderedDicts keyed by "src-tgt" lang-pair
        # strings (see build_model); FairseqMultiModel pairs them up.
        super().__init__(encoders, decoders)

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        # Start from the full single-pair Transformer argument set, then add
        # the multilingual sharing flags.
        TransformerModel.add_args(parser)
        parser.add_argument(
            "--share-encoder-embeddings",
            action="store_true",
            help="share encoder embeddings across languages",
        )
        parser.add_argument(
            "--share-decoder-embeddings",
            action="store_true",
            help="share decoder embeddings across languages",
        )
        parser.add_argument(
            "--share-encoders",
            action="store_true",
            help="share encoders across languages",
        )
        parser.add_argument(
            "--share-decoders",
            action="store_true",
            help="share decoders across languages",
        )

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance."""
        # Imported locally, presumably to avoid a circular import at module
        # load time (tasks import models).
        from fairseq.tasks.multilingual_translation import MultilingualTranslationTask

        assert isinstance(task, MultilingualTranslationTask)

        # make sure all arguments are present in older models
        base_multilingual_architecture(args)

        if not hasattr(args, "max_source_positions"):
            args.max_source_positions = 1024
        if not hasattr(args, "max_target_positions"):
            args.max_target_positions = 1024

        # Language pairs are "src-tgt" strings; split into parallel lists of
        # source and target language codes.
        src_langs = [lang_pair.split("-")[0] for lang_pair in task.model_lang_pairs]
        tgt_langs = [lang_pair.split("-")[1] for lang_pair in task.model_lang_pairs]

        # Sharing a whole encoder/decoder implies sharing its embeddings too.
        if args.share_encoders:
            args.share_encoder_embeddings = True
        if args.share_decoders:
            args.share_decoder_embeddings = True

        def build_embedding(dictionary, embed_dim, path=None):
            # Build one embedding table sized to `dictionary`; optionally
            # initialize it from a pretrained-embedding file at `path`.
            num_embeddings = len(dictionary)
            padding_idx = dictionary.pad()
            emb = Embedding(num_embeddings, embed_dim, padding_idx)
            # if provided, load from preloaded dictionaries
            if path:
                embed_dict = utils.parse_embedding(path)
                utils.load_embedding(embed_dict, dictionary, emb)
            return emb

        # build shared embeddings (if applicable)
        shared_encoder_embed_tokens, shared_decoder_embed_tokens = None, None
        if args.share_all_embeddings:
            if args.encoder_embed_dim != args.decoder_embed_dim:
                raise ValueError(
                    "--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim"
                )
            if args.decoder_embed_path and (
                args.decoder_embed_path != args.encoder_embed_path
            ):
                raise ValueError(
                    "--share-all-embeddings not compatible with --decoder-embed-path"
                )
            # One table covering *all* languages, used by both encoders and
            # decoders (same object is assigned to both shared_* variables).
            shared_encoder_embed_tokens = FairseqMultiModel.build_shared_embeddings(
                dicts=task.dicts,
                langs=task.langs,
                embed_dim=args.encoder_embed_dim,
                build_embedding=build_embedding,
                pretrained_embed_path=args.encoder_embed_path,
            )
            shared_decoder_embed_tokens = shared_encoder_embed_tokens
            args.share_decoder_input_output_embed = True
        else:
            if args.share_encoder_embeddings:
                # One table shared across source languages only.
                shared_encoder_embed_tokens = FairseqMultiModel.build_shared_embeddings(
                    dicts=task.dicts,
                    langs=src_langs,
                    embed_dim=args.encoder_embed_dim,
                    build_embedding=build_embedding,
                    pretrained_embed_path=args.encoder_embed_path,
                )
            if args.share_decoder_embeddings:
                # One table shared across target languages only.
                shared_decoder_embed_tokens = FairseqMultiModel.build_shared_embeddings(
                    dicts=task.dicts,
                    langs=tgt_langs,
                    embed_dim=args.decoder_embed_dim,
                    build_embedding=build_embedding,
                    pretrained_embed_path=args.decoder_embed_path,
                )

        # encoders/decoders for each language
        lang_encoders, lang_decoders = {}, {}

        def get_encoder(lang):
            # Memoize one encoder per source language so a language appearing
            # in several pairs reuses the same module.
            if lang not in lang_encoders:
                if shared_encoder_embed_tokens is not None:
                    encoder_embed_tokens = shared_encoder_embed_tokens
                else:
                    encoder_embed_tokens = build_embedding(
                        task.dicts[lang],
                        args.encoder_embed_dim,
                        args.encoder_embed_path,
                    )
                lang_encoders[lang] = cls._get_module_class(
                    True, args, task.dicts[lang], encoder_embed_tokens, src_langs
                )
            return lang_encoders[lang]

        def get_decoder(lang):
            # Memoize one decoder per target language (mirrors get_encoder).
            if lang not in lang_decoders:
                if shared_decoder_embed_tokens is not None:
                    decoder_embed_tokens = shared_decoder_embed_tokens
                else:
                    decoder_embed_tokens = build_embedding(
                        task.dicts[lang],
                        args.decoder_embed_dim,
                        args.decoder_embed_path,
                    )
                lang_decoders[lang] = cls._get_module_class(
                    False, args, task.dicts[lang], decoder_embed_tokens, tgt_langs
                )
            return lang_decoders[lang]

        # shared encoders/decoders (if applicable)
        shared_encoder, shared_decoder = None, None
        if args.share_encoders:
            shared_encoder = get_encoder(src_langs[0])
        if args.share_decoders:
            shared_decoder = get_decoder(tgt_langs[0])

        # Assemble one (encoder, decoder) pair per language pair, reusing the
        # single shared module when full sharing is enabled.
        encoders, decoders = OrderedDict(), OrderedDict()
        for lang_pair, src, tgt in zip(task.model_lang_pairs, src_langs, tgt_langs):
            encoders[lang_pair] = (
                shared_encoder if shared_encoder is not None else get_encoder(src)
            )
            decoders[lang_pair] = (
                shared_decoder if shared_decoder is not None else get_decoder(tgt)
            )

        return MultilingualTransformerModel(encoders, decoders)

    @classmethod
    def _get_module_class(cls, is_encoder, args, lang_dict, embed_tokens, langs):
        # `langs` is unused here; presumably kept so subclasses can override
        # with language-aware modules — confirm against subclass usage.
        module_class = TransformerEncoder if is_encoder else TransformerDecoder
        return module_class(args, lang_dict, embed_tokens)

    def load_state_dict(self, state_dict, strict=True, args=None):
        """Load a checkpoint, silently dropping parameters for language pairs
        that this model instance was not built with."""
        state_dict_subset = state_dict.copy()
        for k, _ in state_dict.items():
            # All keys are namespaced as "models.<lang_pair>.<param_path>".
            assert k.startswith("models.")
            lang_pair = k.split(".")[1]
            if lang_pair not in self.models:
                del state_dict_subset[k]
        super().load_state_dict(state_dict_subset, strict=strict, args=args)
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
@register_model_architecture("multilingual_transformer", "multilingual_transformer")
def base_multilingual_architecture(args):
    """Default architecture: single-pair Transformer defaults plus the
    multilingual sharing flags, all off unless already set on ``args``."""
    base_architecture(args)
    for flag in (
        "share_encoder_embeddings",
        "share_decoder_embeddings",
        "share_encoders",
        "share_decoders",
    ):
        # Only fill in flags that are missing (e.g. checkpoints from older
        # code); never override an explicitly provided value.
        setattr(args, flag, getattr(args, flag, False))
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
@register_model_architecture(
    "multilingual_transformer", "multilingual_transformer_iwslt_de_en"
)
def multilingual_transformer_iwslt_de_en(args):
    """Smaller architecture preset (512-dim, 4 heads, 1024 FFN) sized for the
    IWSLT de-en benchmark; any value already on ``args`` wins."""
    presets = {
        "encoder_embed_dim": 512,
        "encoder_ffn_embed_dim": 1024,
        "encoder_attention_heads": 4,
        "encoder_layers": 6,
        "decoder_embed_dim": 512,
        "decoder_ffn_embed_dim": 1024,
        "decoder_attention_heads": 4,
        "decoder_layers": 6,
    }
    for name, default in presets.items():
        setattr(args, name, getattr(args, name, default))
    # Fill in everything else (including the sharing flags) afterwards, as in
    # the original: explicit args above still take precedence.
    base_multilingual_architecture(args)
|
fairseq-0.10.2/fairseq/models/roberta/__pycache__/hub_interface.cpython-310.pyc
ADDED
|
Binary file (8.12 kB). View file
|
|
|
fairseq-0.10.2/fairseq/models/roberta/__pycache__/model_camembert.cpython-310.pyc
ADDED
|
Binary file (1.85 kB). View file
|
|
|
fairseq-0.10.2/fairseq/models/roberta/__pycache__/model_xlmr.cpython-310.pyc
ADDED
|
Binary file (1.37 kB). View file
|
|
|
fairseq-0.10.2/fairseq/models/roberta/model_xlmr.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
"""
|
| 6 |
+
Unsupervised Cross-lingual Representation Learning at Scale
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from fairseq.models import register_model
|
| 10 |
+
|
| 11 |
+
from .hub_interface import RobertaHubInterface
|
| 12 |
+
from .model import RobertaModel
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@register_model("xlmr")
class XLMRModel(RobertaModel):
    """XLM-R model (see: Unsupervised Cross-lingual Representation Learning
    at Scale). Identical to RoBERTa apart from the pretrained checkpoints it
    advertises and a sentencepiece BPE default."""

    @classmethod
    def hub_models(cls):
        """Map hub names to downloadable pretrained checkpoint archives."""
        return {
            "xlmr.base": "http://dl.fbaipublicfiles.com/fairseq/models/xlmr.base.tar.gz",
            "xlmr.large": "http://dl.fbaipublicfiles.com/fairseq/models/xlmr.large.tar.gz",
        }

    @classmethod
    def from_pretrained(
        cls,
        model_name_or_path,
        checkpoint_file="model.pt",
        data_name_or_path=".",
        bpe="sentencepiece",
        **kwargs
    ):
        """Load a pretrained XLM-R checkpoint and wrap it in a hub interface.

        ``model_name_or_path`` may be a key from :meth:`hub_models` or a local
        path; remaining arguments are forwarded to ``hub_utils.from_pretrained``.
        """
        # Local import, presumably to avoid a circular dependency between the
        # models package and hub_utils.
        from fairseq import hub_utils

        loaded = hub_utils.from_pretrained(
            model_name_or_path,
            checkpoint_file,
            data_name_or_path,
            archive_map=cls.hub_models(),
            bpe=bpe,
            load_checkpoint_heads=True,
            **kwargs,
        )
        # Return the single loaded model wrapped for interactive use.
        return RobertaHubInterface(loaded["args"], loaded["task"], loaded["models"][0])
|
fairseq-0.10.2/fairseq/modules/__pycache__/beamable_mm.cpython-310.pyc
ADDED
|
Binary file (1.64 kB). View file
|
|
|