Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- fairseq-0.10.2/fairseq/data/__pycache__/backtranslation_dataset.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/data/__pycache__/base_wrapper_dataset.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/data/__pycache__/concat_sentences_dataset.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/data/__pycache__/data_utils.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/data/__pycache__/denoising_dataset.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/data/__pycache__/fasta_dataset.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/data/__pycache__/iterators.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/data/__pycache__/list_dataset.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/data/__pycache__/lm_context_window_dataset.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/data/__pycache__/lru_cache_dataset.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/data/__pycache__/mask_tokens_dataset.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/data/__pycache__/noising.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/data/__pycache__/num_samples_dataset.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/data/__pycache__/numel_dataset.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/data/__pycache__/pad_dataset.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/data/__pycache__/plasma_utils.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/data/__pycache__/prepend_token_dataset.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/data/__pycache__/raw_label_dataset.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/data/__pycache__/replace_dataset.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/data/__pycache__/shorten_dataset.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/data/__pycache__/sort_dataset.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/data/__pycache__/strip_token_dataset.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/data/__pycache__/token_block_dataset.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/data/__pycache__/transform_eos_dataset.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/data/__pycache__/transform_eos_lang_pair_dataset.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/models/__pycache__/composite_encoder.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/models/__pycache__/transformer_align.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/models/__pycache__/transformer_from_pretrained_xlm.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/models/bart/__pycache__/__init__.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/models/bart/__pycache__/hub_interface.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/models/bart/__pycache__/model.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/models/bart/model.py +368 -0
- fairseq-0.10.2/fairseq/models/nat/__init__.py +13 -0
- fairseq-0.10.2/fairseq/models/nat/__pycache__/cmlm_transformer.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/models/nat/cmlm_transformer.py +162 -0
- fairseq-0.10.2/fairseq/models/nat/fairseq_nat_model.py +159 -0
- fairseq-0.10.2/fairseq/models/nat/insertion_transformer.py +280 -0
- fairseq-0.10.2/fairseq/models/nat/iterative_nonautoregressive_transformer.py +228 -0
- fairseq-0.10.2/fairseq/models/nat/levenshtein_utils.py +293 -0
- fairseq-0.10.2/fairseq/models/nat/nat_crf_transformer.py +121 -0
- fairseq-0.10.2/fairseq/models/nat/nonautoregressive_ensembles.py +254 -0
- fairseq-0.10.2/fairseq/models/nat/nonautoregressive_transformer.py +440 -0
- fairseq-0.10.2/fairseq/models/roberta/__init__.py +9 -0
- fairseq-0.10.2/fairseq/models/roberta/__pycache__/model.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/models/roberta/alignment_utils.py +118 -0
- fairseq-0.10.2/fairseq/models/roberta/hub_interface.py +235 -0
- fairseq-0.10.2/fairseq/models/roberta/model.py +524 -0
- fairseq-0.10.2/fairseq/models/roberta/model_camembert.py +50 -0
- fairseq-0.10.2/fairseq/models/wav2vec/__init__.py +8 -0
- fairseq-0.10.2/fairseq/models/wav2vec/__pycache__/__init__.cpython-310.pyc +0 -0
fairseq-0.10.2/fairseq/data/__pycache__/backtranslation_dataset.cpython-310.pyc
ADDED
|
Binary file (6.74 kB). View file
|
|
|
fairseq-0.10.2/fairseq/data/__pycache__/base_wrapper_dataset.cpython-310.pyc
ADDED
|
Binary file (3.23 kB). View file
|
|
|
fairseq-0.10.2/fairseq/data/__pycache__/concat_sentences_dataset.cpython-310.pyc
ADDED
|
Binary file (3.34 kB). View file
|
|
|
fairseq-0.10.2/fairseq/data/__pycache__/data_utils.cpython-310.pyc
ADDED
|
Binary file (15.6 kB). View file
|
|
|
fairseq-0.10.2/fairseq/data/__pycache__/denoising_dataset.cpython-310.pyc
ADDED
|
Binary file (10.9 kB). View file
|
|
|
fairseq-0.10.2/fairseq/data/__pycache__/fasta_dataset.cpython-310.pyc
ADDED
|
Binary file (3.83 kB). View file
|
|
|
fairseq-0.10.2/fairseq/data/__pycache__/iterators.cpython-310.pyc
ADDED
|
Binary file (18.2 kB). View file
|
|
|
fairseq-0.10.2/fairseq/data/__pycache__/list_dataset.cpython-310.pyc
ADDED
|
Binary file (1.39 kB). View file
|
|
|
fairseq-0.10.2/fairseq/data/__pycache__/lm_context_window_dataset.cpython-310.pyc
ADDED
|
Binary file (2.99 kB). View file
|
|
|
fairseq-0.10.2/fairseq/data/__pycache__/lru_cache_dataset.cpython-310.pyc
ADDED
|
Binary file (973 Bytes). View file
|
|
|
fairseq-0.10.2/fairseq/data/__pycache__/mask_tokens_dataset.cpython-310.pyc
ADDED
|
Binary file (5.1 kB). View file
|
|
|
fairseq-0.10.2/fairseq/data/__pycache__/noising.cpython-310.pyc
ADDED
|
Binary file (9.37 kB). View file
|
|
|
fairseq-0.10.2/fairseq/data/__pycache__/num_samples_dataset.cpython-310.pyc
ADDED
|
Binary file (792 Bytes). View file
|
|
|
fairseq-0.10.2/fairseq/data/__pycache__/numel_dataset.cpython-310.pyc
ADDED
|
Binary file (1.16 kB). View file
|
|
|
fairseq-0.10.2/fairseq/data/__pycache__/pad_dataset.cpython-310.pyc
ADDED
|
Binary file (1.36 kB). View file
|
|
|
fairseq-0.10.2/fairseq/data/__pycache__/plasma_utils.cpython-310.pyc
ADDED
|
Binary file (2.57 kB). View file
|
|
|
fairseq-0.10.2/fairseq/data/__pycache__/prepend_token_dataset.cpython-310.pyc
ADDED
|
Binary file (1.41 kB). View file
|
|
|
fairseq-0.10.2/fairseq/data/__pycache__/raw_label_dataset.cpython-310.pyc
ADDED
|
Binary file (1.03 kB). View file
|
|
|
fairseq-0.10.2/fairseq/data/__pycache__/replace_dataset.cpython-310.pyc
ADDED
|
Binary file (1.53 kB). View file
|
|
|
fairseq-0.10.2/fairseq/data/__pycache__/shorten_dataset.cpython-310.pyc
ADDED
|
Binary file (2.84 kB). View file
|
|
|
fairseq-0.10.2/fairseq/data/__pycache__/sort_dataset.cpython-310.pyc
ADDED
|
Binary file (1.01 kB). View file
|
|
|
fairseq-0.10.2/fairseq/data/__pycache__/strip_token_dataset.cpython-310.pyc
ADDED
|
Binary file (922 Bytes). View file
|
|
|
fairseq-0.10.2/fairseq/data/__pycache__/token_block_dataset.cpython-310.pyc
ADDED
|
Binary file (5.03 kB). View file
|
|
|
fairseq-0.10.2/fairseq/data/__pycache__/transform_eos_dataset.cpython-310.pyc
ADDED
|
Binary file (4.15 kB). View file
|
|
|
fairseq-0.10.2/fairseq/data/__pycache__/transform_eos_lang_pair_dataset.cpython-310.pyc
ADDED
|
Binary file (3.68 kB). View file
|
|
|
fairseq-0.10.2/fairseq/models/__pycache__/composite_encoder.cpython-310.pyc
ADDED
|
Binary file (2.36 kB). View file
|
|
|
fairseq-0.10.2/fairseq/models/__pycache__/transformer_align.cpython-310.pyc
ADDED
|
Binary file (3 kB). View file
|
|
|
fairseq-0.10.2/fairseq/models/__pycache__/transformer_from_pretrained_xlm.cpython-310.pyc
ADDED
|
Binary file (5.34 kB). View file
|
|
|
fairseq-0.10.2/fairseq/models/bart/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (215 Bytes). View file
|
|
|
fairseq-0.10.2/fairseq/models/bart/__pycache__/hub_interface.cpython-310.pyc
ADDED
|
Binary file (7.4 kB). View file
|
|
|
fairseq-0.10.2/fairseq/models/bart/__pycache__/model.cpython-310.pyc
ADDED
|
Binary file (9.81 kB). View file
|
|
|
fairseq-0.10.2/fairseq/models/bart/model.py
ADDED
|
@@ -0,0 +1,368 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
"""
|
| 6 |
+
BART: Denoising Sequence-to-Sequence Pre-training for
|
| 7 |
+
Natural Language Generation, Translation, and Comprehension
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import logging
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
from fairseq import utils
|
| 15 |
+
from fairseq.models import register_model, register_model_architecture
|
| 16 |
+
from fairseq.models.transformer import TransformerModel
|
| 17 |
+
from fairseq.modules.transformer_sentence_encoder import init_bert_params
|
| 18 |
+
|
| 19 |
+
from .hub_interface import BARTHubInterface
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
logger = logging.getLogger(__name__)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@register_model("bart")
|
| 26 |
+
class BARTModel(TransformerModel):
|
| 27 |
+
@classmethod
|
| 28 |
+
def hub_models(cls):
|
| 29 |
+
return {
|
| 30 |
+
"bart.base": "http://dl.fbaipublicfiles.com/fairseq/models/bart.base.tar.gz",
|
| 31 |
+
"bart.large": "http://dl.fbaipublicfiles.com/fairseq/models/bart.large.tar.gz",
|
| 32 |
+
"bart.large.mnli": "http://dl.fbaipublicfiles.com/fairseq/models/bart.large.mnli.tar.gz",
|
| 33 |
+
"bart.large.cnn": "http://dl.fbaipublicfiles.com/fairseq/models/bart.large.cnn.tar.gz",
|
| 34 |
+
"bart.large.xsum": "http://dl.fbaipublicfiles.com/fairseq/models/bart.large.xsum.tar.gz",
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
def __init__(self, args, encoder, decoder):
|
| 38 |
+
super().__init__(args, encoder, decoder)
|
| 39 |
+
|
| 40 |
+
# We follow BERT's random weight initialization
|
| 41 |
+
self.apply(init_bert_params)
|
| 42 |
+
|
| 43 |
+
self.classification_heads = nn.ModuleDict()
|
| 44 |
+
|
| 45 |
+
@staticmethod
|
| 46 |
+
def add_args(parser):
|
| 47 |
+
super(BARTModel, BARTModel).add_args(parser)
|
| 48 |
+
parser.add_argument(
|
| 49 |
+
"--pooler-dropout",
|
| 50 |
+
type=float,
|
| 51 |
+
metavar="D",
|
| 52 |
+
help="dropout probability in the masked_lm pooler layers",
|
| 53 |
+
)
|
| 54 |
+
parser.add_argument(
|
| 55 |
+
"--pooler-activation-fn",
|
| 56 |
+
choices=utils.get_available_activation_fns(),
|
| 57 |
+
help="activation function to use for pooler layer",
|
| 58 |
+
)
|
| 59 |
+
parser.add_argument(
|
| 60 |
+
"--spectral-norm-classification-head",
|
| 61 |
+
action="store_true",
|
| 62 |
+
help="Apply spectral normalization on the classification head",
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
@property
|
| 66 |
+
def supported_targets(self):
|
| 67 |
+
return {"self"}
|
| 68 |
+
|
| 69 |
+
def forward(
|
| 70 |
+
self,
|
| 71 |
+
src_tokens,
|
| 72 |
+
src_lengths,
|
| 73 |
+
prev_output_tokens,
|
| 74 |
+
features_only=False,
|
| 75 |
+
classification_head_name=None,
|
| 76 |
+
token_embeddings=None,
|
| 77 |
+
**kwargs,
|
| 78 |
+
):
|
| 79 |
+
if classification_head_name is not None:
|
| 80 |
+
features_only = True
|
| 81 |
+
|
| 82 |
+
encoder_out = self.encoder(
|
| 83 |
+
src_tokens,
|
| 84 |
+
src_lengths=src_lengths,
|
| 85 |
+
token_embeddings=token_embeddings,
|
| 86 |
+
**kwargs,
|
| 87 |
+
)
|
| 88 |
+
x, extra = self.decoder(
|
| 89 |
+
prev_output_tokens,
|
| 90 |
+
encoder_out=encoder_out,
|
| 91 |
+
features_only=features_only,
|
| 92 |
+
**kwargs,
|
| 93 |
+
)
|
| 94 |
+
|
| 95 |
+
if classification_head_name is not None:
|
| 96 |
+
sentence_representation = x[
|
| 97 |
+
src_tokens.eq(self.encoder.dictionary.eos()), :
|
| 98 |
+
].view(x.size(0), -1, x.size(-1))[:, -1, :]
|
| 99 |
+
x = self.classification_heads[classification_head_name](
|
| 100 |
+
sentence_representation
|
| 101 |
+
)
|
| 102 |
+
return x, extra
|
| 103 |
+
|
| 104 |
+
@classmethod
|
| 105 |
+
def from_pretrained(
|
| 106 |
+
cls,
|
| 107 |
+
model_name_or_path,
|
| 108 |
+
checkpoint_file="model.pt",
|
| 109 |
+
data_name_or_path=".",
|
| 110 |
+
bpe="gpt2",
|
| 111 |
+
**kwargs,
|
| 112 |
+
):
|
| 113 |
+
from fairseq import hub_utils
|
| 114 |
+
|
| 115 |
+
x = hub_utils.from_pretrained(
|
| 116 |
+
model_name_or_path,
|
| 117 |
+
checkpoint_file,
|
| 118 |
+
data_name_or_path,
|
| 119 |
+
archive_map=cls.hub_models(),
|
| 120 |
+
bpe=bpe,
|
| 121 |
+
load_checkpoint_heads=True,
|
| 122 |
+
**kwargs,
|
| 123 |
+
)
|
| 124 |
+
return BARTHubInterface(x["args"], x["task"], x["models"][0])
|
| 125 |
+
|
| 126 |
+
def register_classification_head(
|
| 127 |
+
self, name, num_classes=None, inner_dim=None, **kwargs
|
| 128 |
+
):
|
| 129 |
+
"""Register a classification head."""
|
| 130 |
+
logger.info("Registering classification head: {0}".format(name))
|
| 131 |
+
if name in self.classification_heads:
|
| 132 |
+
prev_num_classes = self.classification_heads[name].out_proj.out_features
|
| 133 |
+
prev_inner_dim = self.classification_heads[name].dense.out_features
|
| 134 |
+
if num_classes != prev_num_classes or inner_dim != prev_inner_dim:
|
| 135 |
+
logger.warning(
|
| 136 |
+
're-registering head "{}" with num_classes {} (prev: {}) '
|
| 137 |
+
"and inner_dim {} (prev: {})".format(
|
| 138 |
+
name, num_classes, prev_num_classes, inner_dim, prev_inner_dim
|
| 139 |
+
)
|
| 140 |
+
)
|
| 141 |
+
self.classification_heads[name] = BARTClassificationHead(
|
| 142 |
+
input_dim=self.args.encoder_embed_dim,
|
| 143 |
+
inner_dim=inner_dim or self.args.encoder_embed_dim,
|
| 144 |
+
num_classes=num_classes,
|
| 145 |
+
activation_fn=self.args.pooler_activation_fn,
|
| 146 |
+
pooler_dropout=self.args.pooler_dropout,
|
| 147 |
+
do_spectral_norm=self.args.spectral_norm_classification_head,
|
| 148 |
+
)
|
| 149 |
+
|
| 150 |
+
def upgrade_state_dict_named(self, state_dict, name):
|
| 151 |
+
super().upgrade_state_dict_named(state_dict, name)
|
| 152 |
+
|
| 153 |
+
prefix = name + "." if name != "" else ""
|
| 154 |
+
current_head_names = (
|
| 155 |
+
[]
|
| 156 |
+
if not hasattr(self, "classification_heads")
|
| 157 |
+
else self.classification_heads.keys()
|
| 158 |
+
)
|
| 159 |
+
|
| 160 |
+
# Handle new classification heads present in the state dict.
|
| 161 |
+
keys_to_delete = []
|
| 162 |
+
for k in state_dict.keys():
|
| 163 |
+
if not k.startswith(prefix + "classification_heads."):
|
| 164 |
+
continue
|
| 165 |
+
|
| 166 |
+
head_name = k[len(prefix + "classification_heads.") :].split(".")[0]
|
| 167 |
+
num_classes = state_dict[
|
| 168 |
+
prefix + "classification_heads." + head_name + ".out_proj.weight"
|
| 169 |
+
].size(0)
|
| 170 |
+
inner_dim = state_dict[
|
| 171 |
+
prefix + "classification_heads." + head_name + ".dense.weight"
|
| 172 |
+
].size(0)
|
| 173 |
+
|
| 174 |
+
if getattr(self.args, "load_checkpoint_heads", False):
|
| 175 |
+
if head_name not in current_head_names:
|
| 176 |
+
self.register_classification_head(head_name, num_classes, inner_dim)
|
| 177 |
+
else:
|
| 178 |
+
if head_name not in current_head_names:
|
| 179 |
+
logger.warning(
|
| 180 |
+
"deleting classification head ({}) from checkpoint "
|
| 181 |
+
"not present in current model: {}".format(head_name, k)
|
| 182 |
+
)
|
| 183 |
+
keys_to_delete.append(k)
|
| 184 |
+
elif (
|
| 185 |
+
num_classes
|
| 186 |
+
!= self.classification_heads[head_name].out_proj.out_features
|
| 187 |
+
or inner_dim
|
| 188 |
+
!= self.classification_heads[head_name].dense.out_features
|
| 189 |
+
):
|
| 190 |
+
logger.warning(
|
| 191 |
+
"deleting classification head ({}) from checkpoint "
|
| 192 |
+
"with different dimensions than current model: {}".format(
|
| 193 |
+
head_name, k
|
| 194 |
+
)
|
| 195 |
+
)
|
| 196 |
+
keys_to_delete.append(k)
|
| 197 |
+
for k in keys_to_delete:
|
| 198 |
+
del state_dict[k]
|
| 199 |
+
|
| 200 |
+
def truncate_emb(key):
|
| 201 |
+
if key in state_dict:
|
| 202 |
+
state_dict[key] = state_dict[key][:-1, :]
|
| 203 |
+
|
| 204 |
+
# When finetuning on translation task, remove last row of
|
| 205 |
+
# embedding matrix that corresponds to mask_idx token.
|
| 206 |
+
loaded_dict_size = state_dict["encoder.embed_tokens.weight"].size(0)
|
| 207 |
+
if (
|
| 208 |
+
loaded_dict_size == len(self.encoder.dictionary) + 1
|
| 209 |
+
and "<mask>" not in self.encoder.dictionary
|
| 210 |
+
):
|
| 211 |
+
truncate_emb("encoder.embed_tokens.weight")
|
| 212 |
+
truncate_emb("decoder.embed_tokens.weight")
|
| 213 |
+
truncate_emb("encoder.output_projection.weight")
|
| 214 |
+
truncate_emb("decoder.output_projection.weight")
|
| 215 |
+
|
| 216 |
+
# When continued pretraining on new set of languages for mbart,
|
| 217 |
+
# add extra lang embeddings at the end of embed_tokens.
|
| 218 |
+
# Note: newly added languages are assumed to have been added at the end.
|
| 219 |
+
if self.args.task == "multilingual_denoising" and loaded_dict_size < len(
|
| 220 |
+
self.encoder.dictionary
|
| 221 |
+
):
|
| 222 |
+
logger.info(
|
| 223 |
+
"Adding extra language embeddings not found in pretrained model for "
|
| 224 |
+
"continued pretraining of MBART on new set of languages."
|
| 225 |
+
)
|
| 226 |
+
loaded_mask_token_embedding = state_dict["encoder.embed_tokens.weight"][
|
| 227 |
+
-1, :
|
| 228 |
+
]
|
| 229 |
+
|
| 230 |
+
num_langids_to_add = len(self.encoder.dictionary) - loaded_dict_size
|
| 231 |
+
embed_dim = state_dict["encoder.embed_tokens.weight"].size(1)
|
| 232 |
+
|
| 233 |
+
new_lang_embed_to_add = torch.zeros(num_langids_to_add, embed_dim)
|
| 234 |
+
nn.init.normal_(new_lang_embed_to_add, mean=0, std=embed_dim ** -0.5)
|
| 235 |
+
new_lang_embed_to_add = new_lang_embed_to_add.to(
|
| 236 |
+
dtype=state_dict["encoder.embed_tokens.weight"].dtype,
|
| 237 |
+
)
|
| 238 |
+
|
| 239 |
+
state_dict["encoder.embed_tokens.weight"] = torch.cat(
|
| 240 |
+
[
|
| 241 |
+
state_dict["encoder.embed_tokens.weight"][
|
| 242 |
+
: loaded_dict_size - 1, :
|
| 243 |
+
],
|
| 244 |
+
new_lang_embed_to_add,
|
| 245 |
+
loaded_mask_token_embedding.unsqueeze(0),
|
| 246 |
+
]
|
| 247 |
+
)
|
| 248 |
+
state_dict["decoder.embed_tokens.weight"] = torch.cat(
|
| 249 |
+
[
|
| 250 |
+
state_dict["decoder.embed_tokens.weight"][
|
| 251 |
+
: loaded_dict_size - 1, :
|
| 252 |
+
],
|
| 253 |
+
new_lang_embed_to_add,
|
| 254 |
+
loaded_mask_token_embedding.unsqueeze(0),
|
| 255 |
+
]
|
| 256 |
+
)
|
| 257 |
+
|
| 258 |
+
# Copy any newly-added classification heads into the state dict
|
| 259 |
+
# with their current weights.
|
| 260 |
+
if hasattr(self, "classification_heads"):
|
| 261 |
+
cur_state = self.classification_heads.state_dict()
|
| 262 |
+
for k, v in cur_state.items():
|
| 263 |
+
if prefix + "classification_heads." + k not in state_dict:
|
| 264 |
+
logger.info("Overwriting", prefix + "classification_heads." + k)
|
| 265 |
+
state_dict[prefix + "classification_heads." + k] = v
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
class BARTClassificationHead(nn.Module):
|
| 269 |
+
"""Head for sentence-level classification tasks."""
|
| 270 |
+
|
| 271 |
+
def __init__(
|
| 272 |
+
self,
|
| 273 |
+
input_dim,
|
| 274 |
+
inner_dim,
|
| 275 |
+
num_classes,
|
| 276 |
+
activation_fn,
|
| 277 |
+
pooler_dropout,
|
| 278 |
+
do_spectral_norm=False,
|
| 279 |
+
):
|
| 280 |
+
super().__init__()
|
| 281 |
+
self.dense = nn.Linear(input_dim, inner_dim)
|
| 282 |
+
self.activation_fn = utils.get_activation_fn(activation_fn)
|
| 283 |
+
self.dropout = nn.Dropout(p=pooler_dropout)
|
| 284 |
+
self.out_proj = nn.Linear(inner_dim, num_classes)
|
| 285 |
+
|
| 286 |
+
if do_spectral_norm:
|
| 287 |
+
self.out_proj = torch.nn.utils.spectral_norm(self.out_proj)
|
| 288 |
+
|
| 289 |
+
def forward(self, features, **kwargs):
|
| 290 |
+
x = features
|
| 291 |
+
x = self.dropout(x)
|
| 292 |
+
x = self.dense(x)
|
| 293 |
+
x = self.activation_fn(x)
|
| 294 |
+
x = self.dropout(x)
|
| 295 |
+
x = self.out_proj(x)
|
| 296 |
+
return x
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
@register_model_architecture("bart", "bart_large")
|
| 300 |
+
def bart_large_architecture(args):
|
| 301 |
+
args.encoder_embed_path = getattr(args, "encoder_embed_path", None)
|
| 302 |
+
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
|
| 303 |
+
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4 * 1024)
|
| 304 |
+
args.encoder_layers = getattr(args, "encoder_layers", 12)
|
| 305 |
+
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
|
| 306 |
+
args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
|
| 307 |
+
args.encoder_learned_pos = getattr(args, "encoder_learned_pos", True)
|
| 308 |
+
args.decoder_embed_path = getattr(args, "decoder_embed_path", None)
|
| 309 |
+
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim)
|
| 310 |
+
args.decoder_ffn_embed_dim = getattr(
|
| 311 |
+
args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim
|
| 312 |
+
)
|
| 313 |
+
args.decoder_layers = getattr(args, "decoder_layers", 12)
|
| 314 |
+
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16)
|
| 315 |
+
args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False)
|
| 316 |
+
args.decoder_learned_pos = getattr(args, "decoder_learned_pos", True)
|
| 317 |
+
args.attention_dropout = getattr(args, "attention_dropout", 0.0)
|
| 318 |
+
args.relu_dropout = getattr(args, "relu_dropout", 0.0)
|
| 319 |
+
args.dropout = getattr(args, "dropout", 0.1)
|
| 320 |
+
args.max_target_positions = getattr(args, "max_target_positions", 1024)
|
| 321 |
+
args.max_source_positions = getattr(args, "max_source_positions", 1024)
|
| 322 |
+
args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
|
| 323 |
+
args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
|
| 324 |
+
args.share_decoder_input_output_embed = getattr(
|
| 325 |
+
args, "share_decoder_input_output_embed", True
|
| 326 |
+
)
|
| 327 |
+
args.share_all_embeddings = getattr(args, "share_all_embeddings", True)
|
| 328 |
+
|
| 329 |
+
args.decoder_output_dim = getattr(
|
| 330 |
+
args, "decoder_output_dim", args.decoder_embed_dim
|
| 331 |
+
)
|
| 332 |
+
args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)
|
| 333 |
+
|
| 334 |
+
args.no_scale_embedding = getattr(args, "no_scale_embedding", True)
|
| 335 |
+
args.layernorm_embedding = getattr(args, "layernorm_embedding", True)
|
| 336 |
+
|
| 337 |
+
args.activation_fn = getattr(args, "activation_fn", "gelu")
|
| 338 |
+
args.pooler_activation_fn = getattr(args, "pooler_activation_fn", "tanh")
|
| 339 |
+
args.pooler_dropout = getattr(args, "pooler_dropout", 0.0)
|
| 340 |
+
|
| 341 |
+
|
| 342 |
+
@register_model_architecture("bart", "bart_base")
|
| 343 |
+
def bart_base_architecture(args):
|
| 344 |
+
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 768)
|
| 345 |
+
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4 * 768)
|
| 346 |
+
args.encoder_layers = getattr(args, "encoder_layers", 6)
|
| 347 |
+
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 12)
|
| 348 |
+
args.decoder_layers = getattr(args, "decoder_layers", 6)
|
| 349 |
+
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 12)
|
| 350 |
+
bart_large_architecture(args)
|
| 351 |
+
|
| 352 |
+
|
| 353 |
+
@register_model_architecture("bart", "mbart_large")
|
| 354 |
+
def mbart_large_architecture(args):
|
| 355 |
+
args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
|
| 356 |
+
bart_large_architecture(args)
|
| 357 |
+
|
| 358 |
+
|
| 359 |
+
@register_model_architecture("bart", "mbart_base")
|
| 360 |
+
def mbart_base_architecture(args):
|
| 361 |
+
args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
|
| 362 |
+
bart_base_architecture(args)
|
| 363 |
+
|
| 364 |
+
|
| 365 |
+
@register_model_architecture("bart", "mbart_base_wmt20")
|
| 366 |
+
def mbart_base_wmt20_architecture(args):
|
| 367 |
+
args.layernorm_embedding = getattr(args, "layernorm_embedding", False)
|
| 368 |
+
mbart_base_architecture(args)
|
fairseq-0.10.2/fairseq/models/nat/__init__.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
"""isort:skip_file"""
|
| 6 |
+
|
| 7 |
+
from .fairseq_nat_model import *
|
| 8 |
+
from .nonautoregressive_transformer import *
|
| 9 |
+
from .nat_crf_transformer import *
|
| 10 |
+
from .iterative_nonautoregressive_transformer import *
|
| 11 |
+
from .cmlm_transformer import *
|
| 12 |
+
from .levenshtein_transformer import *
|
| 13 |
+
from .insertion_transformer import *
|
fairseq-0.10.2/fairseq/models/nat/__pycache__/cmlm_transformer.cpython-310.pyc
ADDED
|
Binary file (4.38 kB). View file
|
|
|
fairseq-0.10.2/fairseq/models/nat/cmlm_transformer.py
ADDED
|
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
"""
|
| 7 |
+
This file implements:
|
| 8 |
+
Ghazvininejad, Marjan, et al.
|
| 9 |
+
"Constant-time machine translation with conditional masked language models."
|
| 10 |
+
arXiv preprint arXiv:1904.09324 (2019).
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
from fairseq.models import register_model, register_model_architecture
|
| 14 |
+
from fairseq.models.nat import NATransformerModel
|
| 15 |
+
from fairseq.utils import new_arange
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def _skeptical_unmasking(output_scores, output_masks, p):
|
| 19 |
+
sorted_index = output_scores.sort(-1)[1]
|
| 20 |
+
boundary_len = (
|
| 21 |
+
(output_masks.sum(1, keepdim=True).type_as(output_scores) - 2) * p
|
| 22 |
+
).long()
|
| 23 |
+
skeptical_mask = new_arange(output_masks) < boundary_len
|
| 24 |
+
return skeptical_mask.scatter(1, sorted_index, skeptical_mask)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@register_model("cmlm_transformer")
|
| 28 |
+
class CMLMNATransformerModel(NATransformerModel):
|
| 29 |
+
@staticmethod
|
| 30 |
+
def add_args(parser):
|
| 31 |
+
NATransformerModel.add_args(parser)
|
| 32 |
+
|
| 33 |
+
def forward(
|
| 34 |
+
self, src_tokens, src_lengths, prev_output_tokens, tgt_tokens, **kwargs
|
| 35 |
+
):
|
| 36 |
+
assert not self.decoder.src_embedding_copy, "do not support embedding copy."
|
| 37 |
+
|
| 38 |
+
# encoding
|
| 39 |
+
encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
|
| 40 |
+
# length prediction
|
| 41 |
+
length_out = self.decoder.forward_length(
|
| 42 |
+
normalize=False, encoder_out=encoder_out
|
| 43 |
+
)
|
| 44 |
+
length_tgt = self.decoder.forward_length_prediction(
|
| 45 |
+
length_out, encoder_out, tgt_tokens
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
# decoding
|
| 49 |
+
word_ins_out = self.decoder(
|
| 50 |
+
normalize=False,
|
| 51 |
+
prev_output_tokens=prev_output_tokens,
|
| 52 |
+
encoder_out=encoder_out,
|
| 53 |
+
)
|
| 54 |
+
word_ins_mask = prev_output_tokens.eq(self.unk)
|
| 55 |
+
|
| 56 |
+
return {
|
| 57 |
+
"word_ins": {
|
| 58 |
+
"out": word_ins_out,
|
| 59 |
+
"tgt": tgt_tokens,
|
| 60 |
+
"mask": word_ins_mask,
|
| 61 |
+
"ls": self.args.label_smoothing,
|
| 62 |
+
"nll_loss": True,
|
| 63 |
+
},
|
| 64 |
+
"length": {
|
| 65 |
+
"out": length_out,
|
| 66 |
+
"tgt": length_tgt,
|
| 67 |
+
"factor": self.decoder.length_loss_factor,
|
| 68 |
+
},
|
| 69 |
+
}
|
| 70 |
+
|
| 71 |
+
def forward_decoder(self, decoder_out, encoder_out, decoding_format=None, **kwargs):
|
| 72 |
+
|
| 73 |
+
step = decoder_out.step
|
| 74 |
+
max_step = decoder_out.max_step
|
| 75 |
+
|
| 76 |
+
output_tokens = decoder_out.output_tokens
|
| 77 |
+
output_scores = decoder_out.output_scores
|
| 78 |
+
history = decoder_out.history
|
| 79 |
+
|
| 80 |
+
# execute the decoder
|
| 81 |
+
output_masks = output_tokens.eq(self.unk)
|
| 82 |
+
_scores, _tokens = self.decoder(
|
| 83 |
+
normalize=True,
|
| 84 |
+
prev_output_tokens=output_tokens,
|
| 85 |
+
encoder_out=encoder_out,
|
| 86 |
+
).max(-1)
|
| 87 |
+
output_tokens.masked_scatter_(output_masks, _tokens[output_masks])
|
| 88 |
+
output_scores.masked_scatter_(output_masks, _scores[output_masks])
|
| 89 |
+
|
| 90 |
+
if history is not None:
|
| 91 |
+
history.append(output_tokens.clone())
|
| 92 |
+
|
| 93 |
+
# skeptical decoding (depend on the maximum decoding steps.)
|
| 94 |
+
if (step + 1) < max_step:
|
| 95 |
+
skeptical_mask = _skeptical_unmasking(
|
| 96 |
+
output_scores, output_tokens.ne(self.pad), 1 - (step + 1) / max_step
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
+
output_tokens.masked_fill_(skeptical_mask, self.unk)
|
| 100 |
+
output_scores.masked_fill_(skeptical_mask, 0.0)
|
| 101 |
+
|
| 102 |
+
if history is not None:
|
| 103 |
+
history.append(output_tokens.clone())
|
| 104 |
+
|
| 105 |
+
return decoder_out._replace(
|
| 106 |
+
output_tokens=output_tokens,
|
| 107 |
+
output_scores=output_scores,
|
| 108 |
+
attn=None,
|
| 109 |
+
history=history,
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
@register_model_architecture("cmlm_transformer", "cmlm_transformer")
|
| 114 |
+
def cmlm_base_architecture(args):
    """Populate missing CMLM hyperparameters on ``args`` with defaults.

    Each attribute is written only when the user did not already set it, so
    explicit command-line choices are always preserved.
    """

    def _default(name, value):
        # Keep an existing user-supplied value; otherwise install the default.
        setattr(args, name, getattr(args, name, value))

    _default("encoder_embed_path", None)
    _default("encoder_embed_dim", 512)
    _default("encoder_ffn_embed_dim", 2048)
    _default("encoder_layers", 6)
    _default("encoder_attention_heads", 8)
    _default("encoder_normalize_before", False)
    _default("encoder_learned_pos", False)
    _default("decoder_embed_path", None)
    # Decoder dimensions default to the (already resolved) encoder dimensions.
    _default("decoder_embed_dim", args.encoder_embed_dim)
    _default("decoder_ffn_embed_dim", args.encoder_ffn_embed_dim)
    _default("decoder_layers", 6)
    _default("decoder_attention_heads", 8)
    _default("decoder_normalize_before", False)
    _default("decoder_learned_pos", False)
    _default("attention_dropout", 0.0)
    _default("activation_dropout", 0.0)
    _default("activation_fn", "relu")
    _default("dropout", 0.1)
    _default("adaptive_softmax_cutoff", None)
    _default("adaptive_softmax_dropout", 0)
    _default("share_decoder_input_output_embed", False)
    _default("share_all_embeddings", True)
    _default("no_token_positional_embeddings", False)
    _default("adaptive_input", False)
    _default("apply_bert_init", False)

    _default("decoder_output_dim", args.decoder_embed_dim)
    _default("decoder_input_dim", args.decoder_embed_dim)

    # --- special arguments ---
    _default("sg_length_pred", False)
    _default("pred_length_offset", False)
    _default("length_loss_factor", 0.1)
    _default("ngram_predictor", 1)
    _default("src_embedding_copy", False)
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
@register_model_architecture("cmlm_transformer", "cmlm_transformer_wmt_en_de")
def cmlm_wmt_en_de(args):
    # The WMT En-De configuration uses the base CMLM hyperparameters unchanged.
    cmlm_base_architecture(args)
|
fairseq-0.10.2/fairseq/models/nat/fairseq_nat_model.py
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import math
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
from fairseq.models.transformer import (
|
| 10 |
+
TransformerDecoder,
|
| 11 |
+
TransformerEncoder,
|
| 12 |
+
TransformerModel,
|
| 13 |
+
)
|
| 14 |
+
from fairseq.modules.transformer_sentence_encoder import init_bert_params
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def ensemble_encoder(func):
    """Decorator for an encoder ``forward`` that supports model ensembles.

    With no ensemble (or a singleton), it simply calls ``func``. Otherwise it
    runs every ensemble member's encoder and stacks the resulting fields along
    a new trailing dimension, returning a single encoder-out namedtuple.
    """

    def wrapper(self, *args, **kwargs):
        if self.ensemble_models is None or len(self.ensemble_models) == 1:
            return func(self, *args, **kwargs)
        encoder_outs = [func(model, *args, **kwargs) for model in self.ensemble_models]
        _encoder_out = encoder_outs[0]

        def stack(key):
            # Stack a field across models; keep None fields as None.
            outs = [getattr(e, key) for e in encoder_outs]
            return torch.stack(outs, -1) if outs[0] is not None else None

        return _encoder_out._replace(
            encoder_out=stack("encoder_out"),
            encoder_embedding=stack("encoder_embedding"),
            encoder_states=stack("encoder_states"),
        )

    return wrapper
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def ensemble_decoder(func):
    """Decorator for a decoder forward that supports model ensembles.

    Each ensemble member is fed its own slice ``[:, :, :, i]`` of the stacked
    encoder output (produced by ``ensemble_encoder``). Per-output combination:
    the first output, when ``normalize`` is set, is averaged in probability
    space via logsumexp; other outputs are stacked along a new last dimension;
    None outputs stay None.
    """

    def wrapper(self, normalize=False, encoder_out=None, *args, **kwargs):
        if self.ensemble_models is None or len(self.ensemble_models) == 1:
            return func(
                self, normalize=normalize, encoder_out=encoder_out, *args, **kwargs
            )

        action_outs = [
            func(
                model,
                normalize=normalize,
                encoder_out=encoder_out._replace(
                    encoder_out=encoder_out.encoder_out[:, :, :, i]
                ),
                *args,
                **kwargs
            )
            for i, model in enumerate(self.ensemble_models)
        ]

        if not isinstance(action_outs[0], tuple):  # single value: wrap so both cases look alike
            action_outs = [[a] for a in action_outs]
        else:
            action_outs = [list(a) for a in action_outs]

        ensembled_outs = []
        for i in range(len(action_outs[0])):
            if i == 0 and normalize:
                # Average normalized log-probs across models:
                # log(mean(exp(x))) = logsumexp(x) - log(K).
                ensembled_outs += [
                    torch.logsumexp(
                        torch.stack([a[i] for a in action_outs], -1), dim=-1
                    )
                    - math.log(len(self.ensemble_models))
                ]
            elif action_outs[0][i] is not None:
                ensembled_outs += [torch.stack([a[i] for a in action_outs], -1)]
            else:
                ensembled_outs += [None]

        if len(ensembled_outs) == 1:
            return ensembled_outs[0]
        return tuple(ensembled_outs)

    return wrapper
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
class FairseqNATModel(TransformerModel):
    """
    Abstract class for all nonautoregressive-based models.

    Caches the special token ids from the target dictionary and provides
    ensemble hooks; subclasses must implement ``forward``,
    ``forward_decoder`` and ``initialize_output_tokens``.
    """

    def __init__(self, args, encoder, decoder):
        super().__init__(args, encoder, decoder)
        # Special token ids are looked up once and reused throughout decoding.
        self.tgt_dict = decoder.dictionary
        self.bos = decoder.dictionary.bos()
        self.eos = decoder.dictionary.eos()
        self.pad = decoder.dictionary.pad()
        self.unk = decoder.dictionary.unk()

        self.ensemble_models = None

    @property
    def allow_length_beam(self):
        # Subclasses that predict target length may enable length beam search.
        return False

    @property
    def allow_ensemble(self):
        return True

    def enable_ensemble(self, models):
        # Hand the member encoders/decoders to the ensemble_* decorators.
        self.encoder.ensemble_models = [m.encoder for m in models]
        self.decoder.ensemble_models = [m.decoder for m in models]

    @staticmethod
    def add_args(parser):
        TransformerModel.add_args(parser)
        parser.add_argument(
            "--apply-bert-init",
            action="store_true",
            help="use custom param initialization for BERT",
        )

    @classmethod
    def build_decoder(cls, args, tgt_dict, embed_tokens):
        decoder = FairseqNATDecoder(args, tgt_dict, embed_tokens)
        if getattr(args, "apply_bert_init", False):
            decoder.apply(init_bert_params)
        return decoder

    @classmethod
    def build_encoder(cls, args, src_dict, embed_tokens):
        encoder = FairseqNATEncoder(args, src_dict, embed_tokens)
        if getattr(args, "apply_bert_init", False):
            encoder.apply(init_bert_params)
        return encoder

    def forward_encoder(self, encoder_inputs):
        return self.encoder(*encoder_inputs)

    def forward_decoder(self, *args, **kwargs):
        # BUG FIX: the original *returned* the NotImplementedError class
        # instead of raising it, silently handing callers an exception type.
        raise NotImplementedError

    def initialize_output_tokens(self, *args, **kwargs):
        raise NotImplementedError

    def forward(self, *args, **kwargs):
        raise NotImplementedError
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
class FairseqNATEncoder(TransformerEncoder):
    """Transformer encoder with ensemble support via ``ensemble_encoder``."""

    def __init__(self, args, dictionary, embed_tokens):
        super().__init__(args, dictionary, embed_tokens)
        # Set by FairseqNATModel.enable_ensemble; None means no ensemble.
        self.ensemble_models = None

    @ensemble_encoder
    def forward(self, *args, **kwargs):
        return super().forward(*args, **kwargs)
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
class FairseqNATDecoder(TransformerDecoder):
    """Transformer decoder base for NAT models; carries the ensemble hook."""

    def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False):
        super().__init__(args, dictionary, embed_tokens, no_encoder_attn)
        # Set by FairseqNATModel.enable_ensemble; None means no ensemble.
        self.ensemble_models = None
|
fairseq-0.10.2/fairseq/models/nat/insertion_transformer.py
ADDED
|
@@ -0,0 +1,280 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
from fairseq.models import register_model, register_model_architecture
|
| 10 |
+
from fairseq.models.nat import (
|
| 11 |
+
FairseqNATModel,
|
| 12 |
+
LevenshteinTransformerDecoder,
|
| 13 |
+
LevenshteinTransformerModel,
|
| 14 |
+
ensemble_decoder,
|
| 15 |
+
)
|
| 16 |
+
from fairseq.models.transformer import Linear
|
| 17 |
+
from fairseq.modules.transformer_sentence_encoder import init_bert_params
|
| 18 |
+
from fairseq.utils import new_arange
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class NegativeDistanceScore(object):
    """Soft insertion-label weights that decay with distance from the slot
    center, controlled by a temperature ``tau`` (softmax over negative
    distances). Tables for common temperatures are precomputed.
    """

    def __init__(self):
        # Cache full tables (lengths up to 49) for the temperatures used in practice.
        self.scores = {}
        for temperature in (0.5, 1.0, 2.0):
            self.scores[temperature] = self.compute_score_full(50, temperature)

    def __call__(self, i, L, tau):
        # A missing or very large temperature degenerates to a uniform weight.
        if tau is None or tau > 1000:
            return 1 / L
        cached = self.scores.get(tau)
        if cached is not None and L < cached.shape[0]:
            return cached[L - 1, i]
        # Fall back to an on-the-fly computation for uncached cases.
        return self.compute_score(L, tau)[i]

    def compute_score(self, L, tau):
        logits = -np.abs(L / 2 - np.arange(L)) / tau
        weights = np.exp(logits - logits.max())
        return weights / weights.sum()

    def compute_score_full(self, L, tau):
        # Row r holds the (normalized) weights for a label span of size r+1;
        # entries beyond the span are masked to -inf before the softmax.
        positions = np.arange(L)
        logits = -np.abs(np.arange(0, L - 1)[:, None] / 2 - positions[None, :]) / tau
        logits = np.tril(logits, 0) + np.triu(logits - float("inf"), 1)
        weights = np.exp(logits - logits.max(1, keepdims=True))
        return weights / weights.sum(1, keepdims=True)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
# Module-level singleton so the precomputed tables are built once and shared
# (used by _get_ins_targets below).
neg_scorer = NegativeDistanceScore()
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def _get_ins_targets(in_tokens, out_tokens, padding_idx, unk_idx, vocab_size, tau=None):
    """Build soft insertion targets of shape ``(B, T-1, V)``.

    Uses the compiled ``libnat`` edit-distance helper to find which words must
    be inserted into each slot of ``in_tokens`` to obtain ``out_tokens``, then
    spreads each slot's label mass with ``neg_scorer`` (temperature ``tau``).
    ``unk_idx`` is accepted for interface parity but unused here.
    """
    try:
        from fairseq import libnat
    except ImportError as e:
        import sys

        sys.stderr.write("ERROR: missing libnat. run `pip install --editable .`\n")
        raise e

    B = in_tokens.size(0)
    T = in_tokens.size(1)
    V = vocab_size

    with torch.cuda.device_of(in_tokens):
        # Strip padding before handing the sequences to libnat.
        in_tokens_list = [
            [t for t in s if t != padding_idx] for i, s in enumerate(in_tokens.tolist())
        ]
        out_tokens_list = [
            [t for t in s if t != padding_idx]
            for i, s in enumerate(out_tokens.tolist())
        ]

    full_labels = libnat.suggested_ed2_path(
        in_tokens_list, out_tokens_list, padding_idx
    )
    insert_labels = [a[:-1] for a in full_labels]

    # numericalize1
    # Scatter each (sentence i, slot j, word w) weight into a flat B*(T-1)*V
    # buffer; w + (j + i*(T-1)) * V is the flattened index of [i, j, w].
    insert_label_tensors = in_tokens.new_zeros(B * (T - 1) * V).float()
    insert_index, insert_labels = zip(
        *[
            (w + (j + i * (T - 1)) * V, neg_scorer(k, len(label), tau))
            for i, labels in enumerate(insert_labels)
            for j, label in enumerate(labels[1:-1])
            for k, w in enumerate(label)
        ]
    )  # HACK 1:-1
    insert_index, insert_labels = [
        torch.tensor(list(a), device=in_tokens.device)
        for a in [insert_index, insert_labels]
    ]
    insert_label_tensors.scatter_(0, insert_index.long(), insert_labels)
    insert_label_tensors = insert_label_tensors.view(B, T - 1, V)

    return insert_label_tensors
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def _apply_ins_words(in_tokens, in_scores, word_ins_pred, word_ins_scores, padding_idx):
    """Interleave predicted insertions between existing tokens.

    Slot j's prediction is given coordinate j + 0.5 so that, after sorting
    coordinates, it lands between tokens j and j+1. Pad-slot predictions are
    pushed to +inf so they sort to the end. NOTE: mutates ``word_ins_pred``
    and ``word_ins_scores`` in place.
    """
    padding_masks = in_tokens[:, 1:].eq(padding_idx)
    word_ins_scores.masked_fill_(padding_masks, 0.0)
    word_ins_pred.masked_fill_(padding_masks, padding_idx)

    in_coords = new_arange(in_tokens).type_as(in_scores)

    # shift all padding predictions to infinite
    out_coords = (in_coords[:, 1:] - 0.5).masked_fill(
        word_ins_pred.eq(padding_idx), float("inf")
    )
    # Sort original + inserted coordinates jointly to get the merged order.
    out_coords = torch.cat([in_coords, out_coords], 1).sort(-1)[1]
    out_tokens = torch.cat([in_tokens, word_ins_pred], 1).gather(1, out_coords)
    out_scores = torch.cat([in_scores, word_ins_scores], 1).gather(1, out_coords)
    return out_tokens, out_scores
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
@register_model("insertion_transformer")
class InsertionTransformerModel(LevenshteinTransformerModel):
    """Insertion Transformer: generates by repeatedly inserting words into
    the slots between existing tokens (training via soft insertion labels).
    """

    def __init__(self, args, encoder, decoder):
        super().__init__(args, encoder, decoder)

    @staticmethod
    def add_args(parser):
        FairseqNATModel.add_args(parser)
        # Temperature for the soft insertion-label distribution (None = uniform).
        parser.add_argument("--label-tau", default=None, type=float)

    @classmethod
    def build_decoder(cls, args, tgt_dict, embed_tokens):
        decoder = InsertionTransformerDecoder(args, tgt_dict, embed_tokens)
        if getattr(args, "apply_bert_init", False):
            decoder.apply(init_bert_params)
        return decoder

    def forward(
        self, src_tokens, src_lengths, prev_output_tokens, tgt_tokens, **kwargs
    ):
        """Compute the training loss dict (a single "word_ins" entry)."""
        assert tgt_tokens is not None, "forward function only supports training."

        # encoding
        encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)

        # generate training labels for insertion
        word_ins_out = self.decoder.forward_word_ins(
            normalize=False,
            prev_output_tokens=prev_output_tokens,
            encoder_out=encoder_out,
        )

        # Soft (B, T-1, V) targets derived from the edit-distance alignment.
        word_ins_tgt = _get_ins_targets(
            prev_output_tokens,
            tgt_tokens,
            self.pad,
            self.unk,
            len(self.tgt_dict),
            tau=self.decoder.label_tau,
        ).type_as(word_ins_out)
        word_ins_masks = prev_output_tokens[:, 1:].ne(self.pad)

        return {
            "word_ins": {
                "out": word_ins_out,
                "tgt": word_ins_tgt,
                "mask": word_ins_masks,
                "ls": self.args.label_smoothing,
                "nll_loss": True,
            }
        }

    def forward_decoder(
        self, decoder_out, encoder_out, eos_penalty=0.0, max_ratio=None, **kwargs
    ):
        """One inference step: insert the best word into every slot.

        ``max_ratio`` is accepted for interface parity but unused here.
        """
        output_tokens = decoder_out.output_tokens
        output_scores = decoder_out.output_scores
        history = decoder_out.history

        # TODO: decoding for InsertionTransformer
        word_ins_score = self.decoder.forward_word_ins(
            normalize=True, prev_output_tokens=output_tokens, encoder_out=encoder_out
        )

        if eos_penalty > 0.0:
            # Penalize predicting pad (i.e. "insert nothing") to encourage growth.
            word_ins_score[:, :, self.pad] -= eos_penalty
        word_ins_score, word_ins_pred = word_ins_score.max(-1)
        output_tokens, output_scores = _apply_ins_words(
            output_tokens, output_scores, word_ins_pred, word_ins_score, self.pad
        )

        # delete some unnecessary paddings
        cut_off = output_tokens.ne(self.pad).sum(1).max()
        output_tokens = output_tokens[:, :cut_off]
        output_scores = output_scores[:, :cut_off]

        if history is not None:
            history.append(output_tokens.clone())

        return decoder_out._replace(
            output_tokens=output_tokens,
            output_scores=output_scores,
            attn=None,
            history=history,
        )
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
class InsertionTransformerDecoder(LevenshteinTransformerDecoder):
    """Decoder that scores word insertions for every slot between tokens."""

    def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False):
        # use the TransformerDecoder's __init__ (deliberately skipping
        # LevenshteinTransformerDecoder's, whose extra heads are not needed)
        super(LevenshteinTransformerDecoder, self).__init__(
            args, dictionary, embed_tokens, no_encoder_attn=no_encoder_attn
        )

        self.dictionary = dictionary
        self.bos = dictionary.bos()
        self.unk = dictionary.unk()
        self.eos = dictionary.eos()
        # Projects the concatenation of two adjacent token features back to
        # the embedding dimension, producing one feature per slot.
        self.pool_out = Linear(self.output_embed_dim * 2, self.output_embed_dim)

        self.label_tau = getattr(args, "label_tau", None)

    @ensemble_decoder
    def forward_word_ins(self, normalize, encoder_out, prev_output_tokens):
        """Return per-slot vocabulary scores, shape (B, T-1, V)."""
        features = self.extract_features(prev_output_tokens, encoder_out=encoder_out)[0]
        # Pair neighbors: slot j sees features of tokens j and j+1.
        features = self.pool_out(
            torch.cat([features[:, :-1, :], features[:, 1:, :]], 2)
        )
        decoder_out = self.output_layer(features)
        return F.log_softmax(decoder_out, -1) if normalize else decoder_out

    def forward_mask_ins(self, *args, **kwargs):
        # Inherited Levenshtein operation; not applicable to insertion-only decoding.
        raise NotImplementedError

    def forward_word_del(self, *args, **kwargs):
        # Inherited Levenshtein operation; not applicable to insertion-only decoding.
        raise NotImplementedError
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
@register_model_architecture("insertion_transformer", "insertion_transformer")
|
| 241 |
+
def insertion_base_architecture(args):
    """Populate missing Insertion Transformer hyperparameters with defaults.

    Each attribute is written only when the user did not already set it, so
    explicit command-line choices are always preserved.
    """

    def _default(name, value):
        # Keep an existing user-supplied value; otherwise install the default.
        setattr(args, name, getattr(args, name, value))

    _default("encoder_embed_path", None)
    _default("encoder_embed_dim", 512)
    _default("encoder_ffn_embed_dim", 2048)
    _default("encoder_layers", 6)
    _default("encoder_attention_heads", 8)
    _default("encoder_normalize_before", False)
    _default("encoder_learned_pos", False)
    _default("decoder_embed_path", None)
    # Decoder dimensions default to the (already resolved) encoder dimensions.
    _default("decoder_embed_dim", args.encoder_embed_dim)
    _default("decoder_ffn_embed_dim", args.encoder_ffn_embed_dim)
    _default("decoder_layers", 6)
    _default("decoder_attention_heads", 8)
    _default("decoder_normalize_before", False)
    _default("decoder_learned_pos", False)
    _default("attention_dropout", 0.0)
    _default("activation_dropout", 0.0)
    _default("activation_fn", "relu")
    _default("dropout", 0.1)
    _default("adaptive_softmax_cutoff", None)
    _default("adaptive_softmax_dropout", 0)
    _default("share_decoder_input_output_embed", False)
    _default("share_all_embeddings", False)
    _default("no_token_positional_embeddings", False)
    _default("adaptive_input", False)
    _default("apply_bert_init", False)

    _default("decoder_output_dim", args.decoder_embed_dim)
    _default("decoder_input_dim", args.decoder_embed_dim)

    # special for insertion transformer
    _default("label_tau", None)
|
fairseq-0.10.2/fairseq/models/nat/iterative_nonautoregressive_transformer.py
ADDED
|
@@ -0,0 +1,228 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from fairseq.models import register_model, register_model_architecture
|
| 8 |
+
from fairseq.models.nat import NATransformerModel
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def _sequential_poisoning(s, V, beta=0.33, bos=2, eos=3, pad=1):
|
| 12 |
+
# s: input batch
|
| 13 |
+
# V: vocabulary size
|
| 14 |
+
rand_words = torch.randint(low=4, high=V, size=s.size(), device=s.device)
|
| 15 |
+
choices = torch.rand(size=s.size(), device=s.device)
|
| 16 |
+
choices.masked_fill_((s == pad) | (s == bos) | (s == eos), 1)
|
| 17 |
+
|
| 18 |
+
replace = choices < beta / 3
|
| 19 |
+
repeat = (choices >= beta / 3) & (choices < beta * 2 / 3)
|
| 20 |
+
swap = (choices >= beta * 2 / 3) & (choices < beta)
|
| 21 |
+
safe = choices >= beta
|
| 22 |
+
|
| 23 |
+
for i in range(s.size(1) - 1):
|
| 24 |
+
rand_word = rand_words[:, i]
|
| 25 |
+
next_word = s[:, i + 1]
|
| 26 |
+
self_word = s[:, i]
|
| 27 |
+
|
| 28 |
+
replace_i = replace[:, i]
|
| 29 |
+
swap_i = swap[:, i] & (next_word != 3)
|
| 30 |
+
repeat_i = repeat[:, i] & (next_word != 3)
|
| 31 |
+
safe_i = safe[:, i] | ((next_word == 3) & (~replace_i))
|
| 32 |
+
|
| 33 |
+
s[:, i] = (
|
| 34 |
+
self_word * (safe_i | repeat_i).long()
|
| 35 |
+
+ next_word * swap_i.long()
|
| 36 |
+
+ rand_word * replace_i.long()
|
| 37 |
+
)
|
| 38 |
+
s[:, i + 1] = (
|
| 39 |
+
next_word * (safe_i | replace_i).long()
|
| 40 |
+
+ self_word * (swap_i | repeat_i).long()
|
| 41 |
+
)
|
| 42 |
+
return s
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def gumbel_noise(input, TINY=1e-8):
    """Sample Gumbel noise with the same shape/device/dtype as ``input``.

    Draws u ~ Uniform(0, 1) and returns ``-log(TINY - log(u + TINY))``;
    ``TINY`` guards both logarithms against log(0).
    """
    uniform = input.new_zeros(*input.size()).uniform_()
    inner = -(uniform + TINY).log()
    return -(inner + TINY).log()
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
@register_model("iterative_nonautoregressive_transformer")
class IterNATransformerModel(NATransformerModel):
    """Iterative NAT: trains the decoder to refine its own predictions over
    several rounds, optionally mixing in a denoising auto-encoder (DAE) loss
    by corrupting the target and using it as the next round's input.
    """

    @staticmethod
    def add_args(parser):
        NATransformerModel.add_args(parser)
        parser.add_argument(
            "--train-step",
            type=int,
            help="number of refinement iterations during training",
        )
        parser.add_argument(
            "--dae-ratio",
            type=float,
            help="the probability of switching to the denoising auto-encoder loss",
        )
        parser.add_argument(
            "--stochastic-approx",
            action="store_true",
            help="sampling from the decoder as the inputs for next iteration",
        )

    @classmethod
    def build_model(cls, args, task):
        model = super().build_model(args, task)
        # Defaults apply when the corresponding flags were not given.
        model.train_step = getattr(args, "train_step", 4)
        model.dae_ratio = getattr(args, "dae_ratio", 0.5)
        model.stochastic_approx = getattr(args, "stochastic_approx", False)
        return model

    def forward(
        self, src_tokens, src_lengths, prev_output_tokens, tgt_tokens, **kwargs
    ):
        """Compute training losses over ``self.train_step`` refinement rounds.

        Returns a dict with a "word_ins" loss (every round's predictions
        concatenated along the batch dimension) and a "length" loss.
        """
        B, T = prev_output_tokens.size()

        # encoding
        encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)

        # length prediction
        length_out = self.decoder.forward_length(
            normalize=False, encoder_out=encoder_out
        )
        length_tgt = self.decoder.forward_length_prediction(
            length_out, encoder_out, tgt_tokens
        )

        # decoding: each round conditions on the previous round's (possibly
        # sampled or corrupted) output.
        word_ins_outs, word_ins_tgts, word_ins_masks = [], [], []
        for t in range(self.train_step):
            word_ins_out = self.decoder(
                normalize=False,
                prev_output_tokens=prev_output_tokens,
                encoder_out=encoder_out,
                step=t,
            )
            word_ins_tgt = tgt_tokens
            word_ins_mask = word_ins_tgt.ne(self.pad)

            word_ins_outs.append(word_ins_out)
            word_ins_tgts.append(word_ins_tgt)
            word_ins_masks.append(word_ins_mask)

            if t < (self.train_step - 1):
                # prediction for next iteration
                if self.stochastic_approx:
                    # Gumbel-max trick: argmax over perturbed logits samples
                    # from the softmax distribution.
                    word_ins_prediction = (
                        word_ins_out + gumbel_noise(word_ins_out)
                    ).max(-1)[1]
                else:
                    word_ins_prediction = word_ins_out.max(-1)[1]

                prev_output_tokens = prev_output_tokens.masked_scatter(
                    word_ins_mask, word_ins_prediction[word_ins_mask]
                )

                if self.dae_ratio > 0:
                    # we do not perform denoising for the first iteration
                    # FIX: renamed the misspelled local "corrputed" -> "corrupted".
                    corrupted = (
                        torch.rand(size=(B,), device=prev_output_tokens.device)
                        < self.dae_ratio
                    )
                    corrupted_tokens = _sequential_poisoning(
                        tgt_tokens[corrupted],
                        len(self.tgt_dict),
                        0.33,
                        self.bos,
                        self.eos,
                        self.pad,
                    )
                    prev_output_tokens[corrupted] = corrupted_tokens

        # concat everything
        word_ins_out = torch.cat(word_ins_outs, 0)
        word_ins_tgt = torch.cat(word_ins_tgts, 0)
        word_ins_mask = torch.cat(word_ins_masks, 0)

        return {
            "word_ins": {
                "out": word_ins_out,
                "tgt": word_ins_tgt,
                "mask": word_ins_mask,
                "ls": self.args.label_smoothing,
                "nll_loss": True,
            },
            "length": {
                "out": length_out,
                "tgt": length_tgt,
                "factor": self.decoder.length_loss_factor,
            },
        }
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
@register_model_architecture(
|
| 171 |
+
"iterative_nonautoregressive_transformer", "iterative_nonautoregressive_transformer"
|
| 172 |
+
)
|
| 173 |
+
def inat_base_architecture(args):
|
| 174 |
+
args.encoder_embed_path = getattr(args, "encoder_embed_path", None)
|
| 175 |
+
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
|
| 176 |
+
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
|
| 177 |
+
args.encoder_layers = getattr(args, "encoder_layers", 6)
|
| 178 |
+
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)
|
| 179 |
+
args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
|
| 180 |
+
args.encoder_learned_pos = getattr(args, "encoder_learned_pos", False)
|
| 181 |
+
args.decoder_embed_path = getattr(args, "decoder_embed_path", None)
|
| 182 |
+
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim)
|
| 183 |
+
args.decoder_ffn_embed_dim = getattr(
|
| 184 |
+
args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim
|
| 185 |
+
)
|
| 186 |
+
args.decoder_layers = getattr(args, "decoder_layers", 6)
|
| 187 |
+
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
|
| 188 |
+
args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False)
|
| 189 |
+
args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
|
| 190 |
+
args.attention_dropout = getattr(args, "attention_dropout", 0.0)
|
| 191 |
+
args.activation_dropout = getattr(args, "activation_dropout", 0.0)
|
| 192 |
+
args.activation_fn = getattr(args, "activation_fn", "relu")
|
| 193 |
+
args.dropout = getattr(args, "dropout", 0.1)
|
| 194 |
+
args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
|
| 195 |
+
args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
|
| 196 |
+
args.share_decoder_input_output_embed = getattr(
|
| 197 |
+
args, "share_decoder_input_output_embed", False
|
| 198 |
+
)
|
| 199 |
+
args.share_all_embeddings = getattr(args, "share_all_embeddings", False)
|
| 200 |
+
args.no_token_positional_embeddings = getattr(
|
| 201 |
+
args, "no_token_positional_embeddings", False
|
| 202 |
+
)
|
| 203 |
+
args.adaptive_input = getattr(args, "adaptive_input", False)
|
| 204 |
+
args.apply_bert_init = getattr(args, "apply_bert_init", False)
|
| 205 |
+
|
| 206 |
+
args.decoder_output_dim = getattr(
|
| 207 |
+
args, "decoder_output_dim", args.decoder_embed_dim
|
| 208 |
+
)
|
| 209 |
+
args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)
|
| 210 |
+
|
| 211 |
+
# --- special arguments ---
|
| 212 |
+
args.sg_length_pred = getattr(args, "sg_length_pred", False)
|
| 213 |
+
args.pred_length_offset = getattr(args, "pred_length_offset", False)
|
| 214 |
+
args.length_loss_factor = getattr(args, "length_loss_factor", 0.1)
|
| 215 |
+
args.ngram_predictor = getattr(args, "ngram_predictor", 1)
|
| 216 |
+
args.src_embedding_copy = getattr(args, "src_embedding_copy", False)
|
| 217 |
+
|
| 218 |
+
args.train_step = getattr(args, "train_step", 4)
|
| 219 |
+
args.dae_ratio = getattr(args, "dae_ratio", 0.5)
|
| 220 |
+
args.stochastic_approx = getattr(args, "stochastic_approx", False)
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
@register_model_architecture(
|
| 224 |
+
"iterative_nonautoregressive_transformer",
|
| 225 |
+
"iterative_nonautoregressive_transformer_wmt_en_de",
|
| 226 |
+
)
|
| 227 |
+
def iter_nat_wmt_en_de(args):
|
| 228 |
+
inat_base_architecture(args)
|
fairseq-0.10.2/fairseq/models/nat/levenshtein_utils.py
ADDED
|
@@ -0,0 +1,293 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from fairseq.utils import new_arange
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
# -------------- Helper Functions --------------------------------------------------- #
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def load_libnat():
|
| 14 |
+
try:
|
| 15 |
+
from fairseq import libnat_cuda
|
| 16 |
+
|
| 17 |
+
return libnat_cuda, True
|
| 18 |
+
|
| 19 |
+
except ImportError as e:
|
| 20 |
+
print(str(e) + "... fall back to CPU version")
|
| 21 |
+
|
| 22 |
+
try:
|
| 23 |
+
from fairseq import libnat
|
| 24 |
+
|
| 25 |
+
return libnat, False
|
| 26 |
+
|
| 27 |
+
except ImportError as e:
|
| 28 |
+
import sys
|
| 29 |
+
|
| 30 |
+
sys.stderr.write(
|
| 31 |
+
"ERROR: missing libnat_cuda. run `python setup.py build_ext --inplace`\n"
|
| 32 |
+
)
|
| 33 |
+
raise e
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def _get_ins_targets(in_tokens, out_tokens, padding_idx, unk_idx):
|
| 37 |
+
libnat, use_cuda = load_libnat()
|
| 38 |
+
|
| 39 |
+
def _get_ins_targets_cuda(in_tokens, out_tokens, padding_idx, unk_idx):
|
| 40 |
+
in_masks = in_tokens.ne(padding_idx)
|
| 41 |
+
out_masks = out_tokens.ne(padding_idx)
|
| 42 |
+
mask_ins_targets, masked_tgt_masks = libnat.generate_insertion_labels(
|
| 43 |
+
out_tokens.int(),
|
| 44 |
+
libnat.levenshtein_distance(
|
| 45 |
+
in_tokens.int(),
|
| 46 |
+
out_tokens.int(),
|
| 47 |
+
in_masks.sum(1).int(),
|
| 48 |
+
out_masks.sum(1).int(),
|
| 49 |
+
),
|
| 50 |
+
)
|
| 51 |
+
masked_tgt_masks = masked_tgt_masks.bool() & out_masks
|
| 52 |
+
mask_ins_targets = mask_ins_targets.type_as(in_tokens)[
|
| 53 |
+
:, 1 : in_masks.size(1)
|
| 54 |
+
].masked_fill_(~in_masks[:, 1:], 0)
|
| 55 |
+
masked_tgt_tokens = out_tokens.masked_fill(masked_tgt_masks, unk_idx)
|
| 56 |
+
return masked_tgt_masks, masked_tgt_tokens, mask_ins_targets
|
| 57 |
+
|
| 58 |
+
def _get_ins_targets_cpu(in_tokens, out_tokens, padding_idx, unk_idx):
|
| 59 |
+
in_seq_len, out_seq_len = in_tokens.size(1), out_tokens.size(1)
|
| 60 |
+
|
| 61 |
+
in_tokens_list = [
|
| 62 |
+
[t for t in s if t != padding_idx] for i, s in enumerate(in_tokens.tolist())
|
| 63 |
+
]
|
| 64 |
+
out_tokens_list = [
|
| 65 |
+
[t for t in s if t != padding_idx]
|
| 66 |
+
for i, s in enumerate(out_tokens.tolist())
|
| 67 |
+
]
|
| 68 |
+
|
| 69 |
+
full_labels = libnat.suggested_ed2_path(
|
| 70 |
+
in_tokens_list, out_tokens_list, padding_idx
|
| 71 |
+
)
|
| 72 |
+
mask_inputs = [
|
| 73 |
+
[len(c) if c[0] != padding_idx else 0 for c in a[:-1]] for a in full_labels
|
| 74 |
+
]
|
| 75 |
+
|
| 76 |
+
# generate labels
|
| 77 |
+
masked_tgt_masks = []
|
| 78 |
+
for mask_input in mask_inputs:
|
| 79 |
+
mask_label = []
|
| 80 |
+
for beam_size in mask_input[1:-1]: # HACK 1:-1
|
| 81 |
+
mask_label += [0] + [1 for _ in range(beam_size)]
|
| 82 |
+
masked_tgt_masks.append(
|
| 83 |
+
mask_label + [0 for _ in range(out_seq_len - len(mask_label))]
|
| 84 |
+
)
|
| 85 |
+
mask_ins_targets = [
|
| 86 |
+
mask_input[1:-1]
|
| 87 |
+
+ [0 for _ in range(in_seq_len - 1 - len(mask_input[1:-1]))]
|
| 88 |
+
for mask_input in mask_inputs
|
| 89 |
+
]
|
| 90 |
+
|
| 91 |
+
# transform to tensor
|
| 92 |
+
masked_tgt_masks = torch.tensor(
|
| 93 |
+
masked_tgt_masks, device=out_tokens.device
|
| 94 |
+
).bool()
|
| 95 |
+
mask_ins_targets = torch.tensor(mask_ins_targets, device=in_tokens.device)
|
| 96 |
+
masked_tgt_tokens = out_tokens.masked_fill(masked_tgt_masks, unk_idx)
|
| 97 |
+
return masked_tgt_masks, masked_tgt_tokens, mask_ins_targets
|
| 98 |
+
|
| 99 |
+
if use_cuda:
|
| 100 |
+
return _get_ins_targets_cuda(in_tokens, out_tokens, padding_idx, unk_idx)
|
| 101 |
+
return _get_ins_targets_cpu(in_tokens, out_tokens, padding_idx, unk_idx)
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def _get_del_targets(in_tokens, out_tokens, padding_idx):
|
| 105 |
+
libnat, use_cuda = load_libnat()
|
| 106 |
+
|
| 107 |
+
def _get_del_targets_cuda(in_tokens, out_tokens, padding_idx):
|
| 108 |
+
in_masks = in_tokens.ne(padding_idx)
|
| 109 |
+
out_masks = out_tokens.ne(padding_idx)
|
| 110 |
+
|
| 111 |
+
word_del_targets = libnat.generate_deletion_labels(
|
| 112 |
+
in_tokens.int(),
|
| 113 |
+
libnat.levenshtein_distance(
|
| 114 |
+
in_tokens.int(),
|
| 115 |
+
out_tokens.int(),
|
| 116 |
+
in_masks.sum(1).int(),
|
| 117 |
+
out_masks.sum(1).int(),
|
| 118 |
+
),
|
| 119 |
+
)
|
| 120 |
+
word_del_targets = word_del_targets.type_as(in_tokens).masked_fill_(
|
| 121 |
+
~in_masks, 0
|
| 122 |
+
)
|
| 123 |
+
return word_del_targets
|
| 124 |
+
|
| 125 |
+
def _get_del_targets_cpu(in_tokens, out_tokens, padding_idx):
|
| 126 |
+
out_seq_len = out_tokens.size(1)
|
| 127 |
+
with torch.cuda.device_of(in_tokens):
|
| 128 |
+
in_tokens_list = [
|
| 129 |
+
[t for t in s if t != padding_idx]
|
| 130 |
+
for i, s in enumerate(in_tokens.tolist())
|
| 131 |
+
]
|
| 132 |
+
out_tokens_list = [
|
| 133 |
+
[t for t in s if t != padding_idx]
|
| 134 |
+
for i, s in enumerate(out_tokens.tolist())
|
| 135 |
+
]
|
| 136 |
+
|
| 137 |
+
full_labels = libnat.suggested_ed2_path(
|
| 138 |
+
in_tokens_list, out_tokens_list, padding_idx
|
| 139 |
+
)
|
| 140 |
+
word_del_targets = [b[-1] for b in full_labels]
|
| 141 |
+
word_del_targets = [
|
| 142 |
+
labels + [0 for _ in range(out_seq_len - len(labels))]
|
| 143 |
+
for labels in word_del_targets
|
| 144 |
+
]
|
| 145 |
+
|
| 146 |
+
# transform to tensor
|
| 147 |
+
word_del_targets = torch.tensor(word_del_targets, device=out_tokens.device)
|
| 148 |
+
return word_del_targets
|
| 149 |
+
|
| 150 |
+
if use_cuda:
|
| 151 |
+
return _get_del_targets_cuda(in_tokens, out_tokens, padding_idx)
|
| 152 |
+
return _get_del_targets_cpu(in_tokens, out_tokens, padding_idx)
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def _apply_ins_masks(
|
| 156 |
+
in_tokens, in_scores, mask_ins_pred, padding_idx, unk_idx, eos_idx
|
| 157 |
+
):
|
| 158 |
+
|
| 159 |
+
in_masks = in_tokens.ne(padding_idx)
|
| 160 |
+
in_lengths = in_masks.sum(1)
|
| 161 |
+
|
| 162 |
+
# HACK: hacky way to shift all the paddings to eos first.
|
| 163 |
+
in_tokens.masked_fill_(~in_masks, eos_idx)
|
| 164 |
+
mask_ins_pred.masked_fill_(~in_masks[:, 1:], 0)
|
| 165 |
+
|
| 166 |
+
out_lengths = in_lengths + mask_ins_pred.sum(1)
|
| 167 |
+
out_max_len = out_lengths.max()
|
| 168 |
+
out_masks = new_arange(out_lengths, out_max_len)[None, :] < out_lengths[:, None]
|
| 169 |
+
|
| 170 |
+
reordering = (mask_ins_pred + in_masks[:, 1:].long()).cumsum(1)
|
| 171 |
+
out_tokens = (
|
| 172 |
+
in_tokens.new_zeros(in_tokens.size(0), out_max_len)
|
| 173 |
+
.fill_(padding_idx)
|
| 174 |
+
.masked_fill_(out_masks, unk_idx)
|
| 175 |
+
)
|
| 176 |
+
out_tokens[:, 0] = in_tokens[:, 0]
|
| 177 |
+
out_tokens.scatter_(1, reordering, in_tokens[:, 1:])
|
| 178 |
+
|
| 179 |
+
out_scores = None
|
| 180 |
+
if in_scores is not None:
|
| 181 |
+
in_scores.masked_fill_(~in_masks, 0)
|
| 182 |
+
out_scores = in_scores.new_zeros(*out_tokens.size())
|
| 183 |
+
out_scores[:, 0] = in_scores[:, 0]
|
| 184 |
+
out_scores.scatter_(1, reordering, in_scores[:, 1:])
|
| 185 |
+
|
| 186 |
+
return out_tokens, out_scores
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
def _apply_ins_words(in_tokens, in_scores, word_ins_pred, word_ins_scores, unk_idx):
|
| 190 |
+
word_ins_masks = in_tokens.eq(unk_idx)
|
| 191 |
+
out_tokens = in_tokens.masked_scatter(word_ins_masks, word_ins_pred[word_ins_masks])
|
| 192 |
+
|
| 193 |
+
if in_scores is not None:
|
| 194 |
+
out_scores = in_scores.masked_scatter(
|
| 195 |
+
word_ins_masks, word_ins_scores[word_ins_masks]
|
| 196 |
+
)
|
| 197 |
+
else:
|
| 198 |
+
out_scores = None
|
| 199 |
+
|
| 200 |
+
return out_tokens, out_scores
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def _apply_del_words(
|
| 204 |
+
in_tokens, in_scores, in_attn, word_del_pred, padding_idx, bos_idx, eos_idx
|
| 205 |
+
):
|
| 206 |
+
# apply deletion to a tensor
|
| 207 |
+
in_masks = in_tokens.ne(padding_idx)
|
| 208 |
+
bos_eos_masks = in_tokens.eq(bos_idx) | in_tokens.eq(eos_idx)
|
| 209 |
+
|
| 210 |
+
max_len = in_tokens.size(1)
|
| 211 |
+
word_del_pred.masked_fill_(~in_masks, 1)
|
| 212 |
+
word_del_pred.masked_fill_(bos_eos_masks, 0)
|
| 213 |
+
|
| 214 |
+
reordering = new_arange(in_tokens).masked_fill_(word_del_pred, max_len).sort(1)[1]
|
| 215 |
+
|
| 216 |
+
out_tokens = in_tokens.masked_fill(word_del_pred, padding_idx).gather(1, reordering)
|
| 217 |
+
|
| 218 |
+
out_scores = None
|
| 219 |
+
if in_scores is not None:
|
| 220 |
+
out_scores = in_scores.masked_fill(word_del_pred, 0).gather(1, reordering)
|
| 221 |
+
|
| 222 |
+
out_attn = None
|
| 223 |
+
if in_attn is not None:
|
| 224 |
+
_mask = word_del_pred[:, :, None].expand_as(in_attn)
|
| 225 |
+
_reordering = reordering[:, :, None].expand_as(in_attn)
|
| 226 |
+
out_attn = in_attn.masked_fill(_mask, 0.0).gather(1, _reordering)
|
| 227 |
+
|
| 228 |
+
return out_tokens, out_scores, out_attn
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
def _skip(x, mask):
|
| 232 |
+
"""
|
| 233 |
+
Getting sliced (dim=0) tensor by mask. Supporting tensor and list/dict of tensors.
|
| 234 |
+
"""
|
| 235 |
+
if isinstance(x, int):
|
| 236 |
+
return x
|
| 237 |
+
|
| 238 |
+
if x is None:
|
| 239 |
+
return None
|
| 240 |
+
|
| 241 |
+
if isinstance(x, torch.Tensor):
|
| 242 |
+
if x.size(0) == mask.size(0):
|
| 243 |
+
return x[mask]
|
| 244 |
+
elif x.size(1) == mask.size(0):
|
| 245 |
+
return x[:, mask]
|
| 246 |
+
|
| 247 |
+
if isinstance(x, list):
|
| 248 |
+
return [_skip(x_i, mask) for x_i in x]
|
| 249 |
+
|
| 250 |
+
if isinstance(x, dict):
|
| 251 |
+
return {k: _skip(v, mask) for k, v in x.items()}
|
| 252 |
+
|
| 253 |
+
raise NotImplementedError
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
def _skip_encoder_out(encoder, encoder_out, mask):
|
| 257 |
+
if not mask.any():
|
| 258 |
+
return encoder_out
|
| 259 |
+
else:
|
| 260 |
+
return encoder.reorder_encoder_out(
|
| 261 |
+
encoder_out, mask.nonzero(as_tuple=False).squeeze()
|
| 262 |
+
)
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
def _fill(x, mask, y, padding_idx):
|
| 266 |
+
"""
|
| 267 |
+
Filling tensor x with y at masked positions (dim=0).
|
| 268 |
+
"""
|
| 269 |
+
if x is None:
|
| 270 |
+
return y
|
| 271 |
+
assert x.dim() == y.dim() and mask.size(0) == x.size(0)
|
| 272 |
+
assert x.dim() == 2 or (x.dim() == 3 and x.size(2) == y.size(2))
|
| 273 |
+
n_selected = mask.sum()
|
| 274 |
+
assert n_selected == y.size(0)
|
| 275 |
+
|
| 276 |
+
if n_selected == x.size(0):
|
| 277 |
+
return y
|
| 278 |
+
|
| 279 |
+
if x.size(1) < y.size(1):
|
| 280 |
+
dims = [x.size(0), y.size(1) - x.size(1)]
|
| 281 |
+
if x.dim() == 3:
|
| 282 |
+
dims.append(x.size(2))
|
| 283 |
+
x = torch.cat([x, x.new_zeros(*dims).fill_(padding_idx)], 1)
|
| 284 |
+
x[mask] = y
|
| 285 |
+
elif x.size(1) > y.size(1):
|
| 286 |
+
x[mask] = padding_idx
|
| 287 |
+
if x.dim() == 2:
|
| 288 |
+
x[mask, : y.size(1)] = y
|
| 289 |
+
else:
|
| 290 |
+
x[mask, : y.size(1), :] = y
|
| 291 |
+
else:
|
| 292 |
+
x[mask] = y
|
| 293 |
+
return x
|
fairseq-0.10.2/fairseq/models/nat/nat_crf_transformer.py
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
from fairseq.models import register_model, register_model_architecture
|
| 8 |
+
from fairseq.models.nat import NATransformerModel, base_architecture
|
| 9 |
+
from fairseq.modules import DynamicCRF
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@register_model("nacrf_transformer")
|
| 13 |
+
class NACRFTransformerModel(NATransformerModel):
|
| 14 |
+
def __init__(self, args, encoder, decoder):
|
| 15 |
+
super().__init__(args, encoder, decoder)
|
| 16 |
+
self.crf_layer = DynamicCRF(
|
| 17 |
+
num_embedding=len(self.tgt_dict),
|
| 18 |
+
low_rank=args.crf_lowrank_approx,
|
| 19 |
+
beam_size=args.crf_beam_approx,
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
@property
|
| 23 |
+
def allow_ensemble(self):
|
| 24 |
+
return False
|
| 25 |
+
|
| 26 |
+
@staticmethod
|
| 27 |
+
def add_args(parser):
|
| 28 |
+
NATransformerModel.add_args(parser)
|
| 29 |
+
parser.add_argument(
|
| 30 |
+
"--crf-lowrank-approx",
|
| 31 |
+
type=int,
|
| 32 |
+
help="the dimension of low-rank approximation of transition",
|
| 33 |
+
)
|
| 34 |
+
parser.add_argument(
|
| 35 |
+
"--crf-beam-approx",
|
| 36 |
+
type=int,
|
| 37 |
+
help="the beam size for apporixmating the normalizing factor",
|
| 38 |
+
)
|
| 39 |
+
parser.add_argument(
|
| 40 |
+
"--word-ins-loss-factor",
|
| 41 |
+
type=float,
|
| 42 |
+
help="weights on NAT loss used to co-training with CRF loss.",
|
| 43 |
+
)
|
| 44 |
+
|
| 45 |
+
def forward(
|
| 46 |
+
self, src_tokens, src_lengths, prev_output_tokens, tgt_tokens, **kwargs
|
| 47 |
+
):
|
| 48 |
+
# encoding
|
| 49 |
+
encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
|
| 50 |
+
|
| 51 |
+
# length prediction
|
| 52 |
+
length_out = self.decoder.forward_length(
|
| 53 |
+
normalize=False, encoder_out=encoder_out
|
| 54 |
+
)
|
| 55 |
+
length_tgt = self.decoder.forward_length_prediction(
|
| 56 |
+
length_out, encoder_out, tgt_tokens
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
# decoding
|
| 60 |
+
word_ins_out = self.decoder(
|
| 61 |
+
normalize=False,
|
| 62 |
+
prev_output_tokens=prev_output_tokens,
|
| 63 |
+
encoder_out=encoder_out,
|
| 64 |
+
)
|
| 65 |
+
word_ins_tgt, word_ins_mask = tgt_tokens, tgt_tokens.ne(self.pad)
|
| 66 |
+
|
| 67 |
+
# compute the log-likelihood of CRF
|
| 68 |
+
crf_nll = -self.crf_layer(word_ins_out, word_ins_tgt, word_ins_mask)
|
| 69 |
+
crf_nll = (crf_nll / word_ins_mask.type_as(crf_nll).sum(-1)).mean()
|
| 70 |
+
|
| 71 |
+
return {
|
| 72 |
+
"word_ins": {
|
| 73 |
+
"out": word_ins_out,
|
| 74 |
+
"tgt": word_ins_tgt,
|
| 75 |
+
"mask": word_ins_mask,
|
| 76 |
+
"ls": self.args.label_smoothing,
|
| 77 |
+
"nll_loss": True,
|
| 78 |
+
"factor": self.args.word_ins_loss_factor,
|
| 79 |
+
},
|
| 80 |
+
"word_crf": {"loss": crf_nll},
|
| 81 |
+
"length": {
|
| 82 |
+
"out": length_out,
|
| 83 |
+
"tgt": length_tgt,
|
| 84 |
+
"factor": self.decoder.length_loss_factor,
|
| 85 |
+
},
|
| 86 |
+
}
|
| 87 |
+
|
| 88 |
+
def forward_decoder(self, decoder_out, encoder_out, decoding_format=None, **kwargs):
|
| 89 |
+
output_tokens = decoder_out.output_tokens
|
| 90 |
+
output_scores = decoder_out.output_scores
|
| 91 |
+
history = decoder_out.history
|
| 92 |
+
|
| 93 |
+
# execute the decoder and get emission scores
|
| 94 |
+
output_masks = output_tokens.ne(self.pad)
|
| 95 |
+
word_ins_out = self.decoder(
|
| 96 |
+
normalize=False, prev_output_tokens=output_tokens, encoder_out=encoder_out
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
+
# run viterbi decoding through CRF
|
| 100 |
+
_scores, _tokens = self.crf_layer.forward_decoder(word_ins_out, output_masks)
|
| 101 |
+
output_tokens.masked_scatter_(output_masks, _tokens[output_masks])
|
| 102 |
+
output_scores.masked_scatter_(output_masks, _scores[output_masks])
|
| 103 |
+
if history is not None:
|
| 104 |
+
history.append(output_tokens.clone())
|
| 105 |
+
|
| 106 |
+
return decoder_out._replace(
|
| 107 |
+
output_tokens=output_tokens,
|
| 108 |
+
output_scores=output_scores,
|
| 109 |
+
attn=None,
|
| 110 |
+
history=history,
|
| 111 |
+
)
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
@register_model_architecture("nacrf_transformer", "nacrf_transformer")
|
| 115 |
+
def nacrf_base_architecture(args):
|
| 116 |
+
args.crf_lowrank_approx = getattr(args, "crf_lowrank_approx", 32)
|
| 117 |
+
args.crf_beam_approx = getattr(args, "crf_beam_approx", 64)
|
| 118 |
+
args.word_ins_loss_factor = getattr(args, "word_ins_loss_factor", 0.5)
|
| 119 |
+
args.encoder_normalize_before = getattr(args, "encoder_normalize_before", True)
|
| 120 |
+
args.decoder_normalize_before = getattr(args, "decoder_normalize_before", True)
|
| 121 |
+
base_architecture(args)
|
fairseq-0.10.2/fairseq/models/nat/nonautoregressive_ensembles.py
ADDED
|
@@ -0,0 +1,254 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import math
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
from fairseq.models.nat import (
|
| 11 |
+
_apply_del_words,
|
| 12 |
+
_apply_ins_masks,
|
| 13 |
+
_apply_ins_words,
|
| 14 |
+
_fill,
|
| 15 |
+
_skip,
|
| 16 |
+
_skip_encoder_out,
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class _EnsembleModelEncoder(object):
|
| 21 |
+
def __init__(self, models):
|
| 22 |
+
self.models = models
|
| 23 |
+
|
| 24 |
+
def reorder_encoder_out(self, encoder_outs, new_order):
|
| 25 |
+
encoder_outs = [
|
| 26 |
+
model.encoder.reorder_encoder_out(encoder_out, new_order)
|
| 27 |
+
for model, encoder_out in zip(self.models, encoder_outs)
|
| 28 |
+
]
|
| 29 |
+
return encoder_outs
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class BasicEnsembleModel(torch.nn.Module):
|
| 33 |
+
"""A wrapper around an ensemble of models."""
|
| 34 |
+
|
| 35 |
+
def __init__(self, models):
|
| 36 |
+
super().__init__()
|
| 37 |
+
self.models = torch.nn.ModuleList(models)
|
| 38 |
+
self.bos = self.models[0].decoder.dictionary.bos()
|
| 39 |
+
self.eos = self.models[0].decoder.dictionary.eos()
|
| 40 |
+
self.pad = self.models[0].decoder.dictionary.pad()
|
| 41 |
+
self.unk = self.models[0].decoder.dictionary.unk()
|
| 42 |
+
self.encoder = _EnsembleModelEncoder(self.models)
|
| 43 |
+
|
| 44 |
+
def has_encoder(self):
|
| 45 |
+
return hasattr(self.models[0], "encoder")
|
| 46 |
+
|
| 47 |
+
def max_decoder_positions(self):
|
| 48 |
+
return min(m.max_decoder_positions() for m in self.models)
|
| 49 |
+
|
| 50 |
+
@torch.no_grad()
|
| 51 |
+
def forward_encoder(self, encoder_input):
|
| 52 |
+
if not self.has_encoder():
|
| 53 |
+
return None
|
| 54 |
+
return [model.forward_encoder(encoder_input) for model in self.models]
|
| 55 |
+
|
| 56 |
+
@torch.no_grad()
|
| 57 |
+
def forward_decoder(self, *inputs):
|
| 58 |
+
raise NotImplementedError
|
| 59 |
+
|
| 60 |
+
def initialize_output_tokens(self, *inputs):
|
| 61 |
+
raise NotImplementedError
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class EnsembleLevT(BasicEnsembleModel):
|
| 65 |
+
"""A wrapper around an ensemble of models."""
|
| 66 |
+
|
| 67 |
+
def __init__(self, models):
|
| 68 |
+
super().__init__(models)
|
| 69 |
+
|
| 70 |
+
@torch.no_grad()
|
| 71 |
+
def forward_decoder(
|
| 72 |
+
self, decoder_out, encoder_outs, eos_penalty=0.0, max_ratio=None, **kwargs
|
| 73 |
+
):
|
| 74 |
+
# LevT ensembling
|
| 75 |
+
# A pipeline of three steps: deletion, placeholder, and word insertion.
|
| 76 |
+
# We need to average scores in each step in a pipeline way because of dependence.
|
| 77 |
+
# deletion
|
| 78 |
+
output_tokens = decoder_out.output_tokens
|
| 79 |
+
output_scores = decoder_out.output_scores
|
| 80 |
+
attn = decoder_out.attn
|
| 81 |
+
|
| 82 |
+
bsz = output_tokens.size(0)
|
| 83 |
+
if max_ratio is None:
|
| 84 |
+
max_lens = output_tokens.new().fill_(255)
|
| 85 |
+
else:
|
| 86 |
+
if encoder_outs[0].encoder_padding_mask is None:
|
| 87 |
+
src_lens = (
|
| 88 |
+
encoder_outs[0]
|
| 89 |
+
.encoder_out.new(bsz)
|
| 90 |
+
.fill_(encoder_outs[0].encoder_out.size(1))
|
| 91 |
+
)
|
| 92 |
+
else:
|
| 93 |
+
src_lens = (~encoder_outs[0].encoder_padding_mask).sum(1)
|
| 94 |
+
max_lens = (src_lens * max_ratio).clamp(min=10).long()
|
| 95 |
+
|
| 96 |
+
# delete words
|
| 97 |
+
# do not delete tokens if it is <s> </s>
|
| 98 |
+
can_del_word = output_tokens.ne(self.pad).sum(1) > 2
|
| 99 |
+
if can_del_word.sum() != 0: # we cannot delete, skip
|
| 100 |
+
output_tokens, output_scores, attn = self.forward_word_del(
|
| 101 |
+
encoder_outs,
|
| 102 |
+
output_tokens,
|
| 103 |
+
output_scores,
|
| 104 |
+
attn,
|
| 105 |
+
can_del_word,
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
# insert placeholders
|
| 109 |
+
can_ins_mask = output_tokens.ne(self.pad).sum(1) < max_lens
|
| 110 |
+
if can_ins_mask.sum() != 0:
|
| 111 |
+
output_tokens, output_scores = self.forward_mask_ins(
|
| 112 |
+
encoder_outs,
|
| 113 |
+
output_tokens,
|
| 114 |
+
output_scores,
|
| 115 |
+
can_ins_mask,
|
| 116 |
+
eos_penalty,
|
| 117 |
+
max_lens,
|
| 118 |
+
)
|
| 119 |
+
|
| 120 |
+
# insert words
|
| 121 |
+
can_ins_word = output_tokens.eq(self.unk).sum(1) > 0
|
| 122 |
+
if can_ins_word.sum() != 0:
|
| 123 |
+
output_tokens, output_scores, attn = self.forward_word_ins(
|
| 124 |
+
encoder_outs,
|
| 125 |
+
output_tokens,
|
| 126 |
+
output_scores,
|
| 127 |
+
attn,
|
| 128 |
+
can_ins_word,
|
| 129 |
+
)
|
| 130 |
+
|
| 131 |
+
# delete some unnecessary paddings
|
| 132 |
+
cut_off = output_tokens.ne(self.pad).sum(1).max()
|
| 133 |
+
output_tokens = output_tokens[:, :cut_off]
|
| 134 |
+
output_scores = output_scores[:, :cut_off]
|
| 135 |
+
attn = None if attn is None else attn[:, :cut_off, :]
|
| 136 |
+
return decoder_out._replace(
|
| 137 |
+
output_tokens=output_tokens,
|
| 138 |
+
output_scores=output_scores,
|
| 139 |
+
attn=attn,
|
| 140 |
+
history=None,
|
| 141 |
+
)
|
| 142 |
+
|
| 143 |
+
def forward_word_del(
|
| 144 |
+
self, encoder_outs, output_tokens, output_scores, attn, can_del_word
|
| 145 |
+
):
|
| 146 |
+
word_del_score_avg = []
|
| 147 |
+
word_del_attn_avg = []
|
| 148 |
+
for model, encoder_out in zip(self.models, encoder_outs):
|
| 149 |
+
word_del_out, word_del_attn = model.decoder.forward_word_del(
|
| 150 |
+
_skip(output_tokens, can_del_word),
|
| 151 |
+
_skip_encoder_out(model.encoder, encoder_out, can_del_word),
|
| 152 |
+
)
|
| 153 |
+
word_del_score = F.log_softmax(word_del_out, 2)
|
| 154 |
+
word_del_score_avg.append(word_del_score)
|
| 155 |
+
word_del_attn_avg.append(word_del_attn)
|
| 156 |
+
word_del_score_avg = torch.logsumexp(
|
| 157 |
+
torch.stack(word_del_score_avg, dim=0), dim=0
|
| 158 |
+
) - math.log(len(self.models))
|
| 159 |
+
word_del_pred = word_del_score_avg.max(-1)[1].bool()
|
| 160 |
+
if word_del_attn_avg[0] is not None:
|
| 161 |
+
word_del_attn_avg = torch.stack(word_del_attn_avg, dim=0) / len(self.models)
|
| 162 |
+
else:
|
| 163 |
+
word_del_attn_avg = None
|
| 164 |
+
|
| 165 |
+
_tokens, _scores, _attn = _apply_del_words(
|
| 166 |
+
output_tokens[can_del_word],
|
| 167 |
+
output_scores[can_del_word],
|
| 168 |
+
word_del_attn_avg,
|
| 169 |
+
word_del_pred,
|
| 170 |
+
self.pad,
|
| 171 |
+
self.bos,
|
| 172 |
+
self.eos,
|
| 173 |
+
)
|
| 174 |
+
output_tokens = _fill(output_tokens, can_del_word, _tokens, self.pad)
|
| 175 |
+
output_scores = _fill(output_scores, can_del_word, _scores, 0)
|
| 176 |
+
attn = _fill(attn, can_del_word, _attn, 0.0)
|
| 177 |
+
return output_tokens, output_scores, attn
|
| 178 |
+
|
| 179 |
+
def forward_mask_ins(
|
| 180 |
+
self,
|
| 181 |
+
encoder_outs,
|
| 182 |
+
output_tokens,
|
| 183 |
+
output_scores,
|
| 184 |
+
can_ins_mask,
|
| 185 |
+
eos_penalty,
|
| 186 |
+
max_lens,
|
| 187 |
+
):
|
| 188 |
+
mask_ins_score_avg = []
|
| 189 |
+
for model, encoder_out in zip(self.models, encoder_outs):
|
| 190 |
+
mask_ins_out, _ = model.decoder.forward_mask_ins(
|
| 191 |
+
_skip(output_tokens, can_ins_mask),
|
| 192 |
+
_skip_encoder_out(model.encoder, encoder_out, can_ins_mask),
|
| 193 |
+
)
|
| 194 |
+
mask_ins_score = F.log_softmax(mask_ins_out, 2)
|
| 195 |
+
if eos_penalty > 0.0:
|
| 196 |
+
mask_ins_score[:, :, 0] -= eos_penalty
|
| 197 |
+
mask_ins_score_avg.append(mask_ins_score)
|
| 198 |
+
mask_ins_score_avg = torch.logsumexp(
|
| 199 |
+
torch.stack(mask_ins_score_avg, dim=0), dim=0
|
| 200 |
+
) - math.log(len(self.models))
|
| 201 |
+
mask_ins_pred = mask_ins_score_avg.max(-1)[1]
|
| 202 |
+
mask_ins_pred = torch.min(
|
| 203 |
+
mask_ins_pred, max_lens[can_ins_mask, None].expand_as(mask_ins_pred)
|
| 204 |
+
)
|
| 205 |
+
_tokens, _scores = _apply_ins_masks(
|
| 206 |
+
output_tokens[can_ins_mask],
|
| 207 |
+
output_scores[can_ins_mask],
|
| 208 |
+
mask_ins_pred,
|
| 209 |
+
self.pad,
|
| 210 |
+
self.unk,
|
| 211 |
+
self.eos,
|
| 212 |
+
)
|
| 213 |
+
output_tokens = _fill(output_tokens, can_ins_mask, _tokens, self.pad)
|
| 214 |
+
output_scores = _fill(output_scores, can_ins_mask, _scores, 0)
|
| 215 |
+
return output_tokens, output_scores
|
| 216 |
+
|
| 217 |
+
def forward_word_ins(
|
| 218 |
+
self, encoder_outs, output_tokens, output_scores, attn, can_ins_word
|
| 219 |
+
):
|
| 220 |
+
word_ins_score_avg = []
|
| 221 |
+
word_ins_attn_avg = []
|
| 222 |
+
for model, encoder_out in zip(self.models, encoder_outs):
|
| 223 |
+
word_ins_out, word_ins_attn = model.decoder.forward_word_ins(
|
| 224 |
+
_skip(output_tokens, can_ins_word),
|
| 225 |
+
_skip_encoder_out(model.encoder, encoder_out, can_ins_word),
|
| 226 |
+
)
|
| 227 |
+
word_ins_score = F.log_softmax(word_ins_out, 2)
|
| 228 |
+
word_ins_score_avg.append(word_ins_score)
|
| 229 |
+
word_ins_attn_avg.append(word_ins_attn)
|
| 230 |
+
word_ins_score_avg = torch.logsumexp(
|
| 231 |
+
torch.stack(word_ins_score_avg, dim=0), dim=0
|
| 232 |
+
) - math.log(len(self.models))
|
| 233 |
+
if word_ins_attn_avg[0] is not None:
|
| 234 |
+
word_ins_attn_avg = torch.stack(word_ins_attn_avg, dim=0) / len(self.models)
|
| 235 |
+
else:
|
| 236 |
+
word_ins_attn_avg = None
|
| 237 |
+
word_ins_score_max, word_ins_pred = word_ins_score_avg.max(-1)
|
| 238 |
+
|
| 239 |
+
_tokens, _scores = _apply_ins_words(
|
| 240 |
+
output_tokens[can_ins_word],
|
| 241 |
+
output_scores[can_ins_word],
|
| 242 |
+
word_ins_pred,
|
| 243 |
+
word_ins_score_max,
|
| 244 |
+
self.unk,
|
| 245 |
+
)
|
| 246 |
+
|
| 247 |
+
output_tokens = _fill(output_tokens, can_ins_word, _tokens, self.pad)
|
| 248 |
+
output_scores = _fill(output_scores, can_ins_word, _scores, 0)
|
| 249 |
+
attn = _fill(attn, can_ins_word, word_ins_attn, 0.0)
|
| 250 |
+
return output_tokens, output_scores, attn
|
| 251 |
+
|
| 252 |
+
def initialize_output_tokens(self, encoder_outs, src_tokens):
|
| 253 |
+
# LevT doesn't do length prediction.
|
| 254 |
+
return self.models[0].initialize_output_tokens(encoder_outs[0], src_tokens)
|
fairseq-0.10.2/fairseq/models/nat/nonautoregressive_transformer.py
ADDED
|
@@ -0,0 +1,440 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
from fairseq import utils
|
| 9 |
+
from fairseq.iterative_refinement_generator import DecoderOut
|
| 10 |
+
from fairseq.models import register_model, register_model_architecture
|
| 11 |
+
from fairseq.models.nat import FairseqNATDecoder, FairseqNATModel, ensemble_decoder
|
| 12 |
+
from fairseq.models.transformer import Embedding
|
| 13 |
+
from fairseq.modules.transformer_sentence_encoder import init_bert_params
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def _mean_pooling(enc_feats, src_masks):
|
| 17 |
+
# enc_feats: T x B x C
|
| 18 |
+
# src_masks: B x T or None
|
| 19 |
+
if src_masks is None:
|
| 20 |
+
enc_feats = enc_feats.mean(0)
|
| 21 |
+
else:
|
| 22 |
+
src_masks = (~src_masks).transpose(0, 1).type_as(enc_feats)
|
| 23 |
+
enc_feats = (
|
| 24 |
+
(enc_feats / src_masks.sum(0)[None, :, None]) * src_masks[:, :, None]
|
| 25 |
+
).sum(0)
|
| 26 |
+
return enc_feats
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def _argmax(x, dim):
|
| 30 |
+
return (x == x.max(dim, keepdim=True)[0]).type_as(x)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def _uniform_assignment(src_lens, trg_lens):
    """Map each target position to a source position by uniform spacing.

    Used for copying source embeddings into the decoder input: target
    position j of sentence b is mapped to round(step_b * j) where
    step_b = (src_len_b - 1) / (trg_len_b - 1).

    Args:
        src_lens: source lengths, shape (B,)
        trg_lens: target lengths, shape (B,)

    Returns:
        LongTensor of shape (B, max_trg_len) of source indices.
    """
    longest_tgt = trg_lens.max()
    step = (src_lens.float() - 1) / (trg_lens.float() - 1)  # per-sentence step size
    positions = utils.new_arange(trg_lens, longest_tgt).float()  # 0..max_trg_len-1
    mapped = step[:, None] * positions[None, :]  # batch_size X max_trg_len
    return torch.round(mapped).long().detach()
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
@register_model("nonautoregressive_transformer")
class NATransformerModel(FairseqNATModel):
    """Vanilla non-autoregressive Transformer.

    Predicts the target length from the encoder output, then emits all
    target tokens in parallel (single decoding step, optionally refined
    by the iterative refinement generator).
    """

    @property
    def allow_length_beam(self):
        # the generator may expand each hypothesis into several length candidates
        return True

    @staticmethod
    def add_args(parser):
        FairseqNATModel.add_args(parser)

        # length prediction
        parser.add_argument(
            "--src-embedding-copy",
            action="store_true",
            help="copy encoder word embeddings as the initial input of the decoder",
        )
        parser.add_argument(
            "--pred-length-offset",
            action="store_true",
            help="predicting the length difference between the target and source sentences",
        )
        parser.add_argument(
            "--sg-length-pred",
            action="store_true",
            help="stop the gradients back-propagated from the length predictor",
        )
        parser.add_argument(
            "--length-loss-factor",
            type=float,
            help="weights on the length prediction loss",
        )

    @classmethod
    def build_decoder(cls, args, tgt_dict, embed_tokens):
        """Build the NAT decoder, optionally applying BERT-style init."""
        decoder = NATransformerDecoder(args, tgt_dict, embed_tokens)
        if getattr(args, "apply_bert_init", False):
            decoder.apply(init_bert_params)
        return decoder

    def forward(
        self, src_tokens, src_lengths, prev_output_tokens, tgt_tokens, **kwargs
    ):
        """Training forward pass.

        Returns a dict of named losses consumed by the NAT criterion:
        ``word_ins`` (token prediction) and ``length`` (length prediction,
        down-weighted by ``length_loss_factor``).
        """
        # encoding
        encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)

        # length prediction
        length_out = self.decoder.forward_length(
            normalize=False, encoder_out=encoder_out
        )
        length_tgt = self.decoder.forward_length_prediction(
            length_out, encoder_out, tgt_tokens
        )

        # decoding
        word_ins_out = self.decoder(
            normalize=False,
            prev_output_tokens=prev_output_tokens,
            encoder_out=encoder_out,
        )

        return {
            "word_ins": {
                "out": word_ins_out,
                "tgt": tgt_tokens,
                "mask": tgt_tokens.ne(self.pad),
                "ls": self.args.label_smoothing,
                "nll_loss": True,
            },
            "length": {
                "out": length_out,
                "tgt": length_tgt,
                "factor": self.decoder.length_loss_factor,
            },
        }

    def forward_decoder(self, decoder_out, encoder_out, decoding_format=None, **kwargs):
        """One refinement step at inference: re-predict all non-pad tokens.

        ``decoding_format`` is accepted for interface compatibility with the
        iterative refinement generator but unused here.
        """
        step = decoder_out.step
        output_tokens = decoder_out.output_tokens
        output_scores = decoder_out.output_scores
        history = decoder_out.history

        # execute the decoder
        output_masks = output_tokens.ne(self.pad)
        _scores, _tokens = self.decoder(
            normalize=True,
            prev_output_tokens=output_tokens,
            encoder_out=encoder_out,
            step=step,
        ).max(-1)

        # overwrite only non-pad positions, in place
        output_tokens.masked_scatter_(output_masks, _tokens[output_masks])
        output_scores.masked_scatter_(output_masks, _scores[output_masks])
        if history is not None:
            history.append(output_tokens.clone())

        return decoder_out._replace(
            output_tokens=output_tokens,
            output_scores=output_scores,
            attn=None,
            history=history,
        )

    def initialize_output_tokens(self, encoder_out, src_tokens):
        """Create the initial decoder state: <bos> <unk>... <eos> per sentence,
        sized by the predicted target length (minimum 2)."""
        # length prediction
        length_tgt = self.decoder.forward_length_prediction(
            self.decoder.forward_length(normalize=True, encoder_out=encoder_out),
            encoder_out=encoder_out,
        )

        max_length = length_tgt.clamp_(min=2).max()
        idx_length = utils.new_arange(src_tokens, max_length)

        initial_output_tokens = src_tokens.new_zeros(
            src_tokens.size(0), max_length
        ).fill_(self.pad)
        # positions before each sentence's predicted length become <unk>
        initial_output_tokens.masked_fill_(
            idx_length[None, :] < length_tgt[:, None], self.unk
        )
        initial_output_tokens[:, 0] = self.bos
        initial_output_tokens.scatter_(1, length_tgt[:, None] - 1, self.eos)

        initial_output_scores = initial_output_tokens.new_zeros(
            *initial_output_tokens.size()
        ).type_as(encoder_out.encoder_out)

        return DecoderOut(
            output_tokens=initial_output_tokens,
            output_scores=initial_output_scores,
            attn=None,
            step=0,
            max_step=0,
            history=None,
        )

    def regenerate_length_beam(self, decoder_out, beam_size):
        """Expand each hypothesis into *beam_size* length candidates
        centered on its current length (length beam search)."""
        output_tokens = decoder_out.output_tokens
        length_tgt = output_tokens.ne(self.pad).sum(1)
        # lengths in [len - beam//2, len - beam//2 + beam - 1]
        length_tgt = (
            length_tgt[:, None]
            + utils.new_arange(length_tgt, 1, beam_size)
            - beam_size // 2
        )
        length_tgt = length_tgt.view(-1).clamp_(min=2)
        max_length = length_tgt.max()
        idx_length = utils.new_arange(length_tgt, max_length)

        initial_output_tokens = output_tokens.new_zeros(
            length_tgt.size(0), max_length
        ).fill_(self.pad)
        initial_output_tokens.masked_fill_(
            idx_length[None, :] < length_tgt[:, None], self.unk
        )
        initial_output_tokens[:, 0] = self.bos
        initial_output_tokens.scatter_(1, length_tgt[:, None] - 1, self.eos)

        initial_output_scores = initial_output_tokens.new_zeros(
            *initial_output_tokens.size()
        ).type_as(decoder_out.output_scores)

        return decoder_out._replace(
            output_tokens=initial_output_tokens, output_scores=initial_output_scores
        )
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
class NATransformerDecoder(FairseqNATDecoder):
    """Decoder for the non-autoregressive Transformer.

    Differences from the autoregressive decoder: full (non-causal)
    self-attention over all positions, a length-prediction head
    (``embed_length``, capped at 256 classes), and optional uniform
    copying of source embeddings as the decoder input.
    """

    def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False):
        super().__init__(
            args, dictionary, embed_tokens, no_encoder_attn=no_encoder_attn
        )
        self.dictionary = dictionary
        self.bos = dictionary.bos()
        self.unk = dictionary.unk()
        self.eos = dictionary.eos()

        self.encoder_embed_dim = args.encoder_embed_dim
        self.sg_length_pred = getattr(args, "sg_length_pred", False)
        self.pred_length_offset = getattr(args, "pred_length_offset", False)
        self.length_loss_factor = getattr(args, "length_loss_factor", 0.1)
        self.src_embedding_copy = getattr(args, "src_embedding_copy", False)
        # length head: one embedding row per possible length (or offset) in [0, 255]
        self.embed_length = Embedding(256, self.encoder_embed_dim, None)

    @ensemble_decoder
    def forward(self, normalize, encoder_out, prev_output_tokens, step=0, **unused):
        """Predict output-token logits (or log-probs when *normalize*).

        Source-embedding copy is only applied on the first refinement step.
        """
        features, _ = self.extract_features(
            prev_output_tokens,
            encoder_out=encoder_out,
            embedding_copy=(step == 0) & self.src_embedding_copy,
        )
        decoder_out = self.output_layer(features)
        return F.log_softmax(decoder_out, -1) if normalize else decoder_out

    @ensemble_decoder
    def forward_length(self, normalize, encoder_out):
        """Predict target-length logits from mean-pooled encoder features."""
        enc_feats = encoder_out.encoder_out  # T x B x C
        src_masks = encoder_out.encoder_padding_mask  # B x T or None
        enc_feats = _mean_pooling(enc_feats, src_masks)
        if self.sg_length_pred:
            # keep the length loss from updating the encoder
            enc_feats = enc_feats.detach()
        length_out = F.linear(enc_feats, self.embed_length.weight)
        return F.log_softmax(length_out, -1) if normalize else length_out

    def extract_features(
        self,
        prev_output_tokens,
        encoder_out=None,
        early_exit=None,
        embedding_copy=False,
        **unused
    ):
        """
        Similar to *forward* but only return features.

        Inputs:
            prev_output_tokens: Tensor(B, T)
            encoder_out: a dictionary of hidden states and masks
            early_exit: stop after this many decoder layers (None = all)
            embedding_copy: initialize the decoder input by uniformly
                copying source embeddings instead of embedding
                *prev_output_tokens*

        Returns:
            tuple:
                - the decoder's features of shape `(batch, tgt_len, embed_dim)`
                - a dictionary with any model-specific outputs
            the LevenshteinTransformer decoder has full-attention to all generated tokens
        """
        # embedding
        if embedding_copy:
            src_embd = encoder_out.encoder_embedding
            src_mask = encoder_out.encoder_padding_mask
            # invert to "real token" mask; all-True when there is no padding
            src_mask = (
                ~src_mask
                if src_mask is not None
                else prev_output_tokens.new_ones(*src_embd.size()[:2]).bool()
            )

            x, decoder_padding_mask = self.forward_embedding(
                prev_output_tokens,
                self.forward_copying_source(
                    src_embd, src_mask, prev_output_tokens.ne(self.padding_idx)
                ),
            )

        else:

            x, decoder_padding_mask = self.forward_embedding(prev_output_tokens)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)
        attn = None
        inner_states = [x]

        # decoder layers
        for i, layer in enumerate(self.layers):

            # early exit from the decoder.
            if (early_exit is not None) and (i >= early_exit):
                break

            x, attn, _ = layer(
                x,
                encoder_out.encoder_out if encoder_out is not None else None,
                encoder_out.encoder_padding_mask if encoder_out is not None else None,
                self_attn_mask=None,  # no causal mask: NAT attends to all positions
                self_attn_padding_mask=decoder_padding_mask,
            )
            inner_states.append(x)

        if self.layer_norm:
            x = self.layer_norm(x)

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        if self.project_out_dim is not None:
            x = self.project_out_dim(x)

        return x, {"attn": attn, "inner_states": inner_states}

    def forward_embedding(self, prev_output_tokens, states=None):
        """Embed tokens (or use precomputed *states*) plus positions.

        Returns (embeddings B x T x C, padding mask B x T).
        """
        # embed positions
        positions = (
            self.embed_positions(prev_output_tokens)
            if self.embed_positions is not None
            else None
        )

        # embed tokens and positions
        if states is None:
            x = self.embed_scale * self.embed_tokens(prev_output_tokens)
            if self.project_in_dim is not None:
                x = self.project_in_dim(x)
        else:
            # e.g. embeddings copied from the source (see forward_copying_source)
            x = states

        if positions is not None:
            x += positions
        x = self.dropout_module(x)
        decoder_padding_mask = prev_output_tokens.eq(self.padding_idx)
        return x, decoder_padding_mask

    def forward_copying_source(self, src_embeds, src_masks, tgt_masks):
        """Copy source embeddings to target positions by uniform assignment.

        Each real target position gathers the source embedding at the index
        chosen by :func:`_uniform_assignment`; padded target positions
        gather index 0 (later masked out by the padding mask).
        """
        length_sources = src_masks.sum(1)
        length_targets = tgt_masks.sum(1)
        mapped_inputs = _uniform_assignment(length_sources, length_targets).masked_fill(
            ~tgt_masks, 0
        )
        copied_embedding = torch.gather(
            src_embeds,
            1,
            mapped_inputs.unsqueeze(-1).expand(
                *mapped_inputs.size(), src_embeds.size(-1)
            ),
        )
        return copied_embedding

    def forward_length_prediction(self, length_out, encoder_out, tgt_tokens=None):
        """Compute the length target (training) or prediction (inference).

        With ``pred_length_offset`` the head models the length *difference*
        to the source, shifted by +128 so it fits the [0, 255] class range.
        """
        enc_feats = encoder_out.encoder_out  # T x B x C
        src_masks = encoder_out.encoder_padding_mask  # B x T or None
        if self.pred_length_offset:
            if src_masks is None:
                # no padding: every source sentence has the full length T
                src_lengs = enc_feats.new_ones(enc_feats.size(1)).fill_(
                    enc_feats.size(0)
                )
            else:
                src_lengs = (~src_masks).transpose(0, 1).type_as(enc_feats).sum(0)
            src_lengs = src_lengs.long()

        if tgt_tokens is not None:
            # obtain the length target
            tgt_lengs = tgt_tokens.ne(self.padding_idx).sum(1).long()
            if self.pred_length_offset:
                length_tgt = tgt_lengs - src_lengs + 128
            else:
                length_tgt = tgt_lengs
            length_tgt = length_tgt.clamp(min=0, max=255)

        else:
            # predict the length target (greedy for now)
            # TODO: implementing length-beam
            pred_lengs = length_out.max(-1)[1]
            if self.pred_length_offset:
                length_tgt = pred_lengs - 128 + src_lengs
            else:
                length_tgt = pred_lengs

        return length_tgt
|
| 386 |
+
|
| 387 |
+
|
| 388 |
+
@register_model_architecture(
    "nonautoregressive_transformer", "nonautoregressive_transformer"
)
def base_architecture(args):
    """Fill in default hyperparameters for any option the user did not set.

    Order matters: several decoder defaults are derived from the encoder
    values assigned earlier in this function.
    """
    args.encoder_embed_path = getattr(args, "encoder_embed_path", None)
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
    args.encoder_layers = getattr(args, "encoder_layers", 6)
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)
    args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
    args.encoder_learned_pos = getattr(args, "encoder_learned_pos", False)
    args.decoder_embed_path = getattr(args, "decoder_embed_path", None)
    # decoder dimensions default to the (possibly user-set) encoder dimensions
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim)
    args.decoder_ffn_embed_dim = getattr(
        args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim
    )
    args.decoder_layers = getattr(args, "decoder_layers", 6)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
    args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False)
    args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
    args.attention_dropout = getattr(args, "attention_dropout", 0.0)
    args.activation_dropout = getattr(args, "activation_dropout", 0.0)
    args.activation_fn = getattr(args, "activation_fn", "relu")
    args.dropout = getattr(args, "dropout", 0.1)
    args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
    args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
    args.share_decoder_input_output_embed = getattr(
        args, "share_decoder_input_output_embed", False
    )
    args.share_all_embeddings = getattr(args, "share_all_embeddings", False)
    args.no_token_positional_embeddings = getattr(
        args, "no_token_positional_embeddings", False
    )
    args.adaptive_input = getattr(args, "adaptive_input", False)
    args.apply_bert_init = getattr(args, "apply_bert_init", False)

    args.decoder_output_dim = getattr(
        args, "decoder_output_dim", args.decoder_embed_dim
    )
    args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)

    # --- special arguments ---
    # NAT-specific options (see NATransformerModel.add_args)
    args.sg_length_pred = getattr(args, "sg_length_pred", False)
    args.pred_length_offset = getattr(args, "pred_length_offset", False)
    args.length_loss_factor = getattr(args, "length_loss_factor", 0.1)
    args.src_embedding_copy = getattr(args, "src_embedding_copy", False)
|
| 434 |
+
|
| 435 |
+
|
| 436 |
+
@register_model_architecture(
    "nonautoregressive_transformer", "nonautoregressive_transformer_wmt_en_de"
)
def nonautoregressive_transformer_wmt_en_de(args):
    # The WMT En-De configuration uses the base architecture unchanged.
    base_architecture(args)
|
fairseq-0.10.2/fairseq/models/roberta/__init__.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
from .hub_interface import * # noqa
|
| 7 |
+
from .model import * # noqa
|
| 8 |
+
from .model_camembert import * # noqa
|
| 9 |
+
from .model_xlmr import * # noqa
|
fairseq-0.10.2/fairseq/models/roberta/__pycache__/model.cpython-310.pyc
ADDED
|
Binary file (14 kB). View file
|
|
|
fairseq-0.10.2/fairseq/models/roberta/alignment_utils.py
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
from collections import Counter
|
| 7 |
+
from typing import List
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def align_bpe_to_words(roberta, bpe_tokens: torch.LongTensor, other_tokens: List[str]):
    """
    Helper to align GPT-2 BPE to other tokenization formats (e.g., spaCy).

    Args:
        roberta (RobertaHubInterface): RoBERTa instance
        bpe_tokens (torch.LongTensor): GPT-2 BPE tokens of shape `(T_bpe)`
        other_tokens (List[str]): other tokens of shape `(T_words)`

    Returns:
        List[str]: mapping from *other_tokens* to corresponding *bpe_tokens*.
    """
    assert bpe_tokens.dim() == 1
    # sequence must start with <s> (index 0 in the RoBERTa dictionary)
    assert bpe_tokens[0] == 0

    def clean(text):
        return text.strip()

    # remove whitespaces to simplify alignment
    bpe_tokens = [roberta.task.source_dictionary.string([x]) for x in bpe_tokens]
    bpe_tokens = [
        clean(roberta.bpe.decode(x) if x not in {"<s>", ""} else x) for x in bpe_tokens
    ]
    other_tokens = [clean(str(o)) for o in other_tokens]

    # strip leading <s>
    bpe_tokens = bpe_tokens[1:]
    # both tokenizations must cover the same underlying character sequence
    assert "".join(bpe_tokens) == "".join(other_tokens)

    # create alignment from every word to a list of BPE tokens
    alignment = []
    # enumerate from 1 because index 0 (<s>) was stripped above; skip empty pieces
    bpe_toks = filter(lambda item: item[1] != "", enumerate(bpe_tokens, start=1))
    j, bpe_tok = next(bpe_toks)
    for other_tok in other_tokens:
        bpe_indices = []
        # consume BPE pieces until the current word is fully covered
        while True:
            if other_tok.startswith(bpe_tok):
                # BPE piece is a prefix of (or equals) the remaining word
                bpe_indices.append(j)
                other_tok = other_tok[len(bpe_tok) :]
                try:
                    j, bpe_tok = next(bpe_toks)
                except StopIteration:
                    j, bpe_tok = None, None
            elif bpe_tok.startswith(other_tok):
                # other_tok spans multiple BPE tokens
                bpe_indices.append(j)
                bpe_tok = bpe_tok[len(other_tok) :]
                other_tok = ""
            else:
                raise Exception('Cannot align "{}" and "{}"'.format(other_tok, bpe_tok))
            if other_tok == "":
                break
        assert len(bpe_indices) > 0
        alignment.append(bpe_indices)
    assert len(alignment) == len(other_tokens)

    return alignment
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def align_features_to_words(roberta, features, alignment):
    """
    Align given features to words.

    Args:
        roberta (RobertaHubInterface): RoBERTa instance
        features (torch.Tensor): features to align of shape `(T_bpe x C)`
        alignment: alignment between BPE tokens and words returned by
            func:`align_bpe_to_words`.
    """
    assert features.dim() == 2

    # count how many words each BPE position contributes to, so that a
    # position shared by several words is split evenly among them
    usage = Counter(idx for word_bpe in alignment for idx in word_bpe)
    assert usage[0] == 0  # <s> shouldn't be aligned
    scale = features.new([usage.get(idx, 1) for idx in range(len(features))])
    scaled = features / scale.unsqueeze(-1)

    # one row for <s>, one per word, then any trailing (unaligned) positions
    rows = [scaled[0]]
    last_used = -1
    for word_bpe in alignment:
        rows.append(scaled[word_bpe].sum(dim=0))
        last_used = max(last_used, *word_bpe)
    rows.extend(scaled[idx] for idx in range(last_used + 1, len(features)))
    aligned = torch.stack(rows)
    # total mass must be preserved by the redistribution
    assert torch.all(torch.abs(aligned.sum(dim=0) - features.sum(dim=0)) < 1e-4)
    return aligned
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def spacy_nlp():
    """Lazily build and cache a bare spaCy English pipeline.

    The instance is memoized on the function object itself, so repeated
    calls return the same ``English`` object.

    Raises:
        ImportError: if spaCy is not installed.
    """
    if getattr(spacy_nlp, "_nlp", None) is None:
        try:
            from spacy.lang.en import English

            spacy_nlp._nlp = English()
        except ImportError:
            raise ImportError("Please install spacy with: pip install spacy")
    return spacy_nlp._nlp
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def spacy_tokenizer():
    """Lazily build and cache a spaCy English tokenizer.

    Built from the shared pipeline returned by :func:`spacy_nlp` and
    memoized on the function object.

    Raises:
        ImportError: if spaCy is not installed.
    """
    if getattr(spacy_tokenizer, "_tokenizer", None) is None:
        try:
            nlp = spacy_nlp()
            spacy_tokenizer._tokenizer = nlp.Defaults.create_tokenizer(nlp)
        except ImportError:
            raise ImportError("Please install spacy with: pip install spacy")
    return spacy_tokenizer._tokenizer
|
fairseq-0.10.2/fairseq/models/roberta/hub_interface.py
ADDED
|
@@ -0,0 +1,235 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
from fairseq import utils
|
| 11 |
+
from fairseq.data import encoders
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class RobertaHubInterface(nn.Module):
|
| 15 |
+
"""A simple PyTorch Hub interface to RoBERTa.
|
| 16 |
+
|
| 17 |
+
Usage: https://github.com/pytorch/fairseq/tree/master/examples/roberta
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
    def __init__(self, args, task, model):
        """Wrap a trained RoBERTa *model* with its *task* and *args*.

        Args:
            args: the namespace the model was trained with (also used to
                build the BPE encoder)
            task: the fairseq task, providing the source dictionary
            model: the trained RoBERTa model
        """
        super().__init__()
        self.args = args
        self.task = task
        self.model = model

        self.bpe = encoders.build_bpe(args)

        # this is useful for determining the device
        self.register_buffer("_float_tensor", torch.tensor([0], dtype=torch.float))
|
| 30 |
+
|
| 31 |
+
    @property
    def device(self):
        """Device the interface currently lives on (tracks ``.to(...)`` moves
        via the registered ``_float_tensor`` buffer)."""
        return self._float_tensor.device
|
| 34 |
+
|
| 35 |
+
def encode(
|
| 36 |
+
self, sentence: str, *addl_sentences, no_separator=False
|
| 37 |
+
) -> torch.LongTensor:
|
| 38 |
+
"""
|
| 39 |
+
BPE-encode a sentence (or multiple sentences).
|
| 40 |
+
|
| 41 |
+
Every sequence begins with a beginning-of-sentence (`<s>`) symbol.
|
| 42 |
+
Every sentence ends with an end-of-sentence (`</s>`) and we use an
|
| 43 |
+
extra end-of-sentence (`</s>`) as a separator.
|
| 44 |
+
|
| 45 |
+
Example (single sentence): `<s> a b c </s>`
|
| 46 |
+
Example (sentence pair): `<s> d e f </s> </s> 1 2 3 </s>`
|
| 47 |
+
|
| 48 |
+
The BPE encoding follows GPT-2. One subtle detail is that the GPT-2 BPE
|
| 49 |
+
requires leading spaces. For example::
|
| 50 |
+
|
| 51 |
+
>>> roberta.encode('Hello world').tolist()
|
| 52 |
+
[0, 31414, 232, 2]
|
| 53 |
+
>>> roberta.encode(' world').tolist()
|
| 54 |
+
[0, 232, 2]
|
| 55 |
+
>>> roberta.encode('world').tolist()
|
| 56 |
+
[0, 8331, 2]
|
| 57 |
+
"""
|
| 58 |
+
bpe_sentence = "<s> " + self.bpe.encode(sentence) + " </s>"
|
| 59 |
+
for s in addl_sentences:
|
| 60 |
+
bpe_sentence += " </s>" if not no_separator else ""
|
| 61 |
+
bpe_sentence += " " + self.bpe.encode(s) + " </s>"
|
| 62 |
+
tokens = self.task.source_dictionary.encode_line(
|
| 63 |
+
bpe_sentence, append_eos=False, add_if_not_exist=False
|
| 64 |
+
)
|
| 65 |
+
return tokens.long()
|
| 66 |
+
|
| 67 |
+
def decode(self, tokens: torch.LongTensor):
|
| 68 |
+
assert tokens.dim() == 1
|
| 69 |
+
tokens = tokens.numpy()
|
| 70 |
+
if tokens[0] == self.task.source_dictionary.bos():
|
| 71 |
+
tokens = tokens[1:] # remove <s>
|
| 72 |
+
eos_mask = tokens == self.task.source_dictionary.eos()
|
| 73 |
+
doc_mask = eos_mask[1:] & eos_mask[:-1]
|
| 74 |
+
sentences = np.split(tokens, doc_mask.nonzero()[0] + 1)
|
| 75 |
+
sentences = [
|
| 76 |
+
self.bpe.decode(self.task.source_dictionary.string(s)) for s in sentences
|
| 77 |
+
]
|
| 78 |
+
if len(sentences) == 1:
|
| 79 |
+
return sentences[0]
|
| 80 |
+
return sentences
|
| 81 |
+
|
| 82 |
+
def extract_features(
|
| 83 |
+
self, tokens: torch.LongTensor, return_all_hiddens: bool = False
|
| 84 |
+
) -> torch.Tensor:
|
| 85 |
+
if tokens.dim() == 1:
|
| 86 |
+
tokens = tokens.unsqueeze(0)
|
| 87 |
+
if tokens.size(-1) > self.model.max_positions():
|
| 88 |
+
raise ValueError(
|
| 89 |
+
"tokens exceeds maximum length: {} > {}".format(
|
| 90 |
+
tokens.size(-1), self.model.max_positions()
|
| 91 |
+
)
|
| 92 |
+
)
|
| 93 |
+
features, extra = self.model(
|
| 94 |
+
tokens.to(device=self.device),
|
| 95 |
+
features_only=True,
|
| 96 |
+
return_all_hiddens=return_all_hiddens,
|
| 97 |
+
)
|
| 98 |
+
if return_all_hiddens:
|
| 99 |
+
# convert from T x B x C -> B x T x C
|
| 100 |
+
inner_states = extra["inner_states"]
|
| 101 |
+
return [inner_state.transpose(0, 1) for inner_state in inner_states]
|
| 102 |
+
else:
|
| 103 |
+
return features # just the last layer's features
|
| 104 |
+
|
| 105 |
+
def register_classification_head(
|
| 106 |
+
self, name: str, num_classes: int = None, embedding_size: int = None, **kwargs
|
| 107 |
+
):
|
| 108 |
+
self.model.register_classification_head(
|
| 109 |
+
name, num_classes=num_classes, embedding_size=embedding_size, **kwargs
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
def predict(self, head: str, tokens: torch.LongTensor, return_logits: bool = False):
|
| 113 |
+
features = self.extract_features(tokens.to(device=self.device))
|
| 114 |
+
logits = self.model.classification_heads[head](features)
|
| 115 |
+
if return_logits:
|
| 116 |
+
return logits
|
| 117 |
+
return F.log_softmax(logits, dim=-1)
|
| 118 |
+
|
| 119 |
+
def extract_features_aligned_to_words(
|
| 120 |
+
self, sentence: str, return_all_hiddens: bool = False
|
| 121 |
+
) -> torch.Tensor:
|
| 122 |
+
"""Extract RoBERTa features, aligned to spaCy's word-level tokenizer."""
|
| 123 |
+
from fairseq.models.roberta import alignment_utils
|
| 124 |
+
from spacy.tokens import Doc
|
| 125 |
+
|
| 126 |
+
nlp = alignment_utils.spacy_nlp()
|
| 127 |
+
tokenizer = alignment_utils.spacy_tokenizer()
|
| 128 |
+
|
| 129 |
+
# tokenize both with GPT-2 BPE and spaCy
|
| 130 |
+
bpe_toks = self.encode(sentence)
|
| 131 |
+
spacy_toks = tokenizer(sentence)
|
| 132 |
+
spacy_toks_ws = [t.text_with_ws for t in tokenizer(sentence)]
|
| 133 |
+
alignment = alignment_utils.align_bpe_to_words(self, bpe_toks, spacy_toks_ws)
|
| 134 |
+
|
| 135 |
+
# extract features and align them
|
| 136 |
+
features = self.extract_features(
|
| 137 |
+
bpe_toks, return_all_hiddens=return_all_hiddens
|
| 138 |
+
)
|
| 139 |
+
features = features.squeeze(0)
|
| 140 |
+
aligned_feats = alignment_utils.align_features_to_words(
|
| 141 |
+
self, features, alignment
|
| 142 |
+
)
|
| 143 |
+
|
| 144 |
+
# wrap in spaCy Doc
|
| 145 |
+
doc = Doc(
|
| 146 |
+
nlp.vocab,
|
| 147 |
+
words=["<s>"] + [x.text for x in spacy_toks] + ["</s>"],
|
| 148 |
+
spaces=[True]
|
| 149 |
+
+ [x.endswith(" ") for x in spacy_toks_ws[:-1]]
|
| 150 |
+
+ [True, False],
|
| 151 |
+
)
|
| 152 |
+
assert len(doc) == aligned_feats.size(0)
|
| 153 |
+
doc.user_token_hooks["vector"] = lambda token: aligned_feats[token.i]
|
| 154 |
+
return doc
|
| 155 |
+
|
| 156 |
+
def fill_mask(self, masked_input: str, topk: int = 5):
|
| 157 |
+
masked_token = "<mask>"
|
| 158 |
+
assert (
|
| 159 |
+
masked_token in masked_input and masked_input.count(masked_token) == 1
|
| 160 |
+
), "Please add one {0} token for the input, eg: 'He is a {0} guy'".format(
|
| 161 |
+
masked_token
|
| 162 |
+
)
|
| 163 |
+
|
| 164 |
+
text_spans = masked_input.split(masked_token)
|
| 165 |
+
text_spans_bpe = (
|
| 166 |
+
(" {0} ".format(masked_token))
|
| 167 |
+
.join([self.bpe.encode(text_span.rstrip()) for text_span in text_spans])
|
| 168 |
+
.strip()
|
| 169 |
+
)
|
| 170 |
+
tokens = self.task.source_dictionary.encode_line(
|
| 171 |
+
"<s> " + text_spans_bpe + " </s>",
|
| 172 |
+
append_eos=False,
|
| 173 |
+
add_if_not_exist=False,
|
| 174 |
+
)
|
| 175 |
+
|
| 176 |
+
masked_index = (tokens == self.task.mask_idx).nonzero()
|
| 177 |
+
if tokens.dim() == 1:
|
| 178 |
+
tokens = tokens.unsqueeze(0)
|
| 179 |
+
|
| 180 |
+
with utils.model_eval(self.model):
|
| 181 |
+
features, extra = self.model(
|
| 182 |
+
tokens.long().to(device=self.device),
|
| 183 |
+
features_only=False,
|
| 184 |
+
return_all_hiddens=False,
|
| 185 |
+
)
|
| 186 |
+
logits = features[0, masked_index, :].squeeze()
|
| 187 |
+
prob = logits.softmax(dim=0)
|
| 188 |
+
values, index = prob.topk(k=topk, dim=0)
|
| 189 |
+
topk_predicted_token_bpe = self.task.source_dictionary.string(index)
|
| 190 |
+
|
| 191 |
+
topk_filled_outputs = []
|
| 192 |
+
for index, predicted_token_bpe in enumerate(
|
| 193 |
+
topk_predicted_token_bpe.split(" ")
|
| 194 |
+
):
|
| 195 |
+
predicted_token = self.bpe.decode(predicted_token_bpe)
|
| 196 |
+
# Quick hack to fix https://github.com/pytorch/fairseq/issues/1306
|
| 197 |
+
if predicted_token_bpe.startswith("\u2581"):
|
| 198 |
+
predicted_token = " " + predicted_token
|
| 199 |
+
if " {0}".format(masked_token) in masked_input:
|
| 200 |
+
topk_filled_outputs.append(
|
| 201 |
+
(
|
| 202 |
+
masked_input.replace(
|
| 203 |
+
" {0}".format(masked_token), predicted_token
|
| 204 |
+
),
|
| 205 |
+
values[index].item(),
|
| 206 |
+
predicted_token,
|
| 207 |
+
)
|
| 208 |
+
)
|
| 209 |
+
else:
|
| 210 |
+
topk_filled_outputs.append(
|
| 211 |
+
(
|
| 212 |
+
masked_input.replace(masked_token, predicted_token),
|
| 213 |
+
values[index].item(),
|
| 214 |
+
predicted_token,
|
| 215 |
+
)
|
| 216 |
+
)
|
| 217 |
+
return topk_filled_outputs
|
| 218 |
+
|
| 219 |
+
def disambiguate_pronoun(self, sentence: str) -> bool:
|
| 220 |
+
"""
|
| 221 |
+
Usage::
|
| 222 |
+
|
| 223 |
+
>>> disambiguate_pronoun('The _trophy_ would not fit in the brown suitcase because [it] was too big.')
|
| 224 |
+
True
|
| 225 |
+
|
| 226 |
+
>>> disambiguate_pronoun('The trophy would not fit in the brown suitcase because [it] was too big.')
|
| 227 |
+
'The trophy'
|
| 228 |
+
"""
|
| 229 |
+
assert hasattr(
|
| 230 |
+
self.task, "disambiguate_pronoun"
|
| 231 |
+
), "roberta.disambiguate_pronoun() requires a model trained with the WSC task."
|
| 232 |
+
with utils.model_eval(self.model):
|
| 233 |
+
return self.task.disambiguate_pronoun(
|
| 234 |
+
self.model, sentence, use_cuda=self.device.type == "cuda"
|
| 235 |
+
)
|
fairseq-0.10.2/fairseq/models/roberta/model.py
ADDED
|
@@ -0,0 +1,524 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
"""
|
| 6 |
+
RoBERTa: A Robustly Optimized BERT Pretraining Approach.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import logging
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
import torch.nn.functional as F
|
| 14 |
+
from fairseq import utils
|
| 15 |
+
from fairseq.models import (
|
| 16 |
+
FairseqEncoder,
|
| 17 |
+
FairseqEncoderModel,
|
| 18 |
+
register_model,
|
| 19 |
+
register_model_architecture,
|
| 20 |
+
)
|
| 21 |
+
from fairseq.modules import LayerNorm, TransformerSentenceEncoder
|
| 22 |
+
from fairseq.modules.quant_noise import quant_noise as apply_quant_noise_
|
| 23 |
+
from fairseq.modules.transformer_sentence_encoder import init_bert_params
|
| 24 |
+
|
| 25 |
+
from .hub_interface import RobertaHubInterface
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
logger = logging.getLogger(__name__)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
@register_model("roberta")
class RobertaModel(FairseqEncoderModel):
    """RoBERTa model: a :class:`RobertaEncoder` plus optional classification heads."""

    @classmethod
    def hub_models(cls):
        # Shorthand names resolvable via torch.hub / from_pretrained.
        return {
            "roberta.base": "http://dl.fbaipublicfiles.com/fairseq/models/roberta.base.tar.gz",
            "roberta.large": "http://dl.fbaipublicfiles.com/fairseq/models/roberta.large.tar.gz",
            "roberta.large.mnli": "http://dl.fbaipublicfiles.com/fairseq/models/roberta.large.mnli.tar.gz",
            "roberta.large.wsc": "http://dl.fbaipublicfiles.com/fairseq/models/roberta.large.wsc.tar.gz",
        }

    def __init__(self, args, encoder):
        super().__init__(encoder)
        self.args = args
        # We follow BERT's random weight initialization
        self.apply(init_bert_params)
        self.classification_heads = nn.ModuleDict()

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        parser.add_argument(
            "--encoder-layers", type=int, metavar="L", help="num encoder layers"
        )
        parser.add_argument(
            "--encoder-embed-dim", type=int, metavar="H",
            help="encoder embedding dimension",
        )
        parser.add_argument(
            "--encoder-ffn-embed-dim", type=int, metavar="F",
            help="encoder embedding dimension for FFN",
        )
        parser.add_argument(
            "--encoder-attention-heads", type=int, metavar="A",
            help="num encoder attention heads",
        )
        parser.add_argument(
            "--activation-fn", choices=utils.get_available_activation_fns(),
            help="activation function to use",
        )
        parser.add_argument(
            "--pooler-activation-fn", choices=utils.get_available_activation_fns(),
            help="activation function to use for pooler layer",
        )
        parser.add_argument(
            "--encoder-normalize-before", action="store_true",
            help="apply layernorm before each encoder block",
        )
        parser.add_argument(
            "--dropout", type=float, metavar="D", help="dropout probability"
        )
        parser.add_argument(
            "--attention-dropout", type=float, metavar="D",
            help="dropout probability for attention weights",
        )
        parser.add_argument(
            "--activation-dropout", type=float, metavar="D",
            help="dropout probability after activation in FFN",
        )
        parser.add_argument(
            "--pooler-dropout", type=float, metavar="D",
            help="dropout probability in the masked_lm pooler layers",
        )
        parser.add_argument(
            "--max-positions", type=int, help="number of positional embeddings to learn"
        )
        parser.add_argument(
            "--load-checkpoint-heads", action="store_true",
            help="(re-)register and load heads when loading checkpoints",
        )
        # args for "Reducing Transformer Depth on Demand with Structured Dropout" (Fan et al., 2019)
        parser.add_argument(
            "--encoder-layerdrop", type=float, metavar="D", default=0,
            help="LayerDrop probability for encoder",
        )
        parser.add_argument(
            "--encoder-layers-to-keep", default=None,
            help="which layers to *keep* when pruning as a comma-separated list",
        )
        # args for Training with Quantization Noise for Extreme Model Compression ({Fan*, Stock*} et al., 2020)
        parser.add_argument(
            "--quant-noise-pq", type=float, metavar="D", default=0,
            help="iterative PQ quantization noise at training time",
        )
        parser.add_argument(
            "--quant-noise-pq-block-size", type=int, metavar="D", default=8,
            help="block size of quantization noise at training time",
        )
        parser.add_argument(
            "--quant-noise-scalar", type=float, metavar="D", default=0,
            help="scalar quantization noise and scalar quantization at training time",
        )
        parser.add_argument(
            "--untie-weights-roberta", action="store_true",
            help="Untie weights between embeddings and classifiers in RoBERTa",
        )
        parser.add_argument(
            "--spectral-norm-classification-head", action="store_true", default=False,
            help="Apply spectral normalization on the classification head",
        )

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance."""
        # make sure all arguments are present
        base_architecture(args)

        if not hasattr(args, "max_positions"):
            args.max_positions = args.tokens_per_sample

        return cls(args, RobertaEncoder(args, task.source_dictionary))

    def forward(
        self,
        src_tokens,
        features_only=False,
        return_all_hiddens=False,
        classification_head_name=None,
        **kwargs
    ):
        """Run the encoder; optionally apply a named classification head."""
        # A classification head consumes features, so skip the LM output layer.
        if classification_head_name is not None:
            features_only = True

        x, extra = self.encoder(src_tokens, features_only, return_all_hiddens, **kwargs)

        if classification_head_name is not None:
            x = self.classification_heads[classification_head_name](x)
        return x, extra

    def get_normalized_probs(self, net_output, log_probs, sample=None):
        """Get normalized probabilities (or log probs) from a net's output."""
        logits = net_output[0].float()
        return F.log_softmax(logits, dim=-1) if log_probs else F.softmax(logits, dim=-1)

    def register_classification_head(
        self, name, num_classes=None, inner_dim=None, **kwargs
    ):
        """Register a classification head."""
        if name in self.classification_heads:
            prev_num_classes = self.classification_heads[name].out_proj.out_features
            prev_inner_dim = self.classification_heads[name].dense.out_features
            # Re-registering with different sizes is suspicious; warn but proceed.
            if num_classes != prev_num_classes or inner_dim != prev_inner_dim:
                logger.warning(
                    're-registering head "{}" with num_classes {} (prev: {}) '
                    "and inner_dim {} (prev: {})".format(
                        name, num_classes, prev_num_classes, inner_dim, prev_inner_dim
                    )
                )
        self.classification_heads[name] = RobertaClassificationHead(
            input_dim=self.args.encoder_embed_dim,
            inner_dim=inner_dim or self.args.encoder_embed_dim,
            num_classes=num_classes,
            activation_fn=self.args.pooler_activation_fn,
            pooler_dropout=self.args.pooler_dropout,
            q_noise=self.args.quant_noise_pq,
            qn_block_size=self.args.quant_noise_pq_block_size,
            do_spectral_norm=self.args.spectral_norm_classification_head,
        )

    @property
    def supported_targets(self):
        return {"self"}

    @classmethod
    def from_pretrained(
        cls,
        model_name_or_path,
        checkpoint_file="model.pt",
        data_name_or_path=".",
        bpe="gpt2",
        **kwargs
    ):
        """Load a pretrained checkpoint and wrap it in a hub interface."""
        from fairseq import hub_utils

        loaded = hub_utils.from_pretrained(
            model_name_or_path,
            checkpoint_file,
            data_name_or_path,
            archive_map=cls.hub_models(),
            bpe=bpe,
            load_checkpoint_heads=True,
            **kwargs,
        )
        cls.upgrade_args(loaded["args"])

        logger.info(loaded["args"])
        return RobertaHubInterface(loaded["args"], loaded["task"], loaded["models"][0])

    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade old checkpoints in place: rename decoder->encoder and
        reconcile classification heads between checkpoint and model."""
        prefix = name + "." if name != "" else ""

        # rename decoder -> encoder before upgrading children modules
        old_decoder_prefix = prefix + "decoder"
        for key in list(state_dict.keys()):
            if key.startswith(old_decoder_prefix):
                renamed = prefix + "encoder" + key[len(old_decoder_prefix):]
                state_dict[renamed] = state_dict.pop(key)

        # upgrade children modules
        super().upgrade_state_dict_named(state_dict, name)

        # Handle new classification heads present in the state dict.
        if hasattr(self, "classification_heads"):
            current_head_names = self.classification_heads.keys()
        else:
            current_head_names = []
        head_prefix = prefix + "classification_heads."
        stale_keys = []
        for key in state_dict.keys():
            if not key.startswith(head_prefix):
                continue

            head_name = key[len(head_prefix):].split(".")[0]
            num_classes = state_dict[
                head_prefix + head_name + ".out_proj.weight"
            ].size(0)
            inner_dim = state_dict[
                head_prefix + head_name + ".dense.weight"
            ].size(0)

            if getattr(self.args, "load_checkpoint_heads", False):
                if head_name not in current_head_names:
                    self.register_classification_head(head_name, num_classes, inner_dim)
            else:
                if head_name not in current_head_names:
                    logger.warning(
                        "deleting classification head ({}) from checkpoint "
                        "not present in current model: {}".format(head_name, key)
                    )
                    stale_keys.append(key)
                elif (
                    num_classes
                    != self.classification_heads[head_name].out_proj.out_features
                    or inner_dim
                    != self.classification_heads[head_name].dense.out_features
                ):
                    logger.warning(
                        "deleting classification head ({}) from checkpoint "
                        "with different dimensions than current model: {}".format(
                            head_name, key
                        )
                    )
                    stale_keys.append(key)
        for key in stale_keys:
            del state_dict[key]

        # Copy any newly-added classification heads into the state dict
        # with their current weights.
        if hasattr(self, "classification_heads"):
            for key, value in self.classification_heads.state_dict().items():
                if head_prefix + key not in state_dict:
                    logger.info("Overwriting " + head_prefix + key)
                    state_dict[head_prefix + key] = value
|
| 324 |
+
|
| 325 |
+
|
| 326 |
+
class RobertaLMHead(nn.Module):
    """Head for masked language modeling."""

    def __init__(self, embed_dim, output_dim, activation_fn, weight=None):
        """
        Args:
            embed_dim: dimensionality of the incoming features.
            output_dim: vocabulary size of the output projection.
            activation_fn: name of the activation function to apply.
            weight: optional projection matrix (e.g. tied input embeddings);
                a fresh one is created when omitted.
        """
        super().__init__()
        self.dense = nn.Linear(embed_dim, embed_dim)
        self.activation_fn = utils.get_activation_fn(activation_fn)
        self.layer_norm = LayerNorm(embed_dim)

        if weight is None:
            weight = nn.Linear(embed_dim, output_dim, bias=False).weight
        self.weight = weight
        self.bias = nn.Parameter(torch.zeros(output_dim))

    def forward(self, features, masked_tokens=None, **kwargs):
        # Restricting the projection to masked positions during training
        # saves both memory and computation.
        if masked_tokens is not None:
            features = features[masked_tokens, :]

        hidden = self.layer_norm(self.activation_fn(self.dense(features)))
        # project back to size of vocabulary, with bias
        return F.linear(hidden, self.weight) + self.bias
|
| 352 |
+
|
| 353 |
+
|
| 354 |
+
class RobertaClassificationHead(nn.Module):
    """Head for sentence-level classification tasks."""

    def __init__(
        self,
        input_dim,
        inner_dim,
        num_classes,
        activation_fn,
        pooler_dropout,
        q_noise=0,
        qn_block_size=8,
        do_spectral_norm=False,
    ):
        """
        Args:
            input_dim: size of the incoming sentence features.
            inner_dim: size of the hidden pooler layer.
            num_classes: number of output classes.
            activation_fn: name of the pooler activation function.
            pooler_dropout: dropout applied before each linear layer.
            q_noise: quantization-noise rate for the output projection.
            qn_block_size: block size used by quantization noise.
            do_spectral_norm: wrap the output projection in spectral norm.
        """
        super().__init__()
        self.dense = nn.Linear(input_dim, inner_dim)
        self.activation_fn = utils.get_activation_fn(activation_fn)
        self.dropout = nn.Dropout(p=pooler_dropout)
        self.out_proj = apply_quant_noise_(
            nn.Linear(inner_dim, num_classes), q_noise, qn_block_size
        )
        if do_spectral_norm:
            # Spectral norm on a quant-noise-wrapped projection is untested.
            if q_noise != 0:
                raise NotImplementedError(
                    "Attempting to use Spectral Normalization with Quant Noise. This is not officially supported"
                )
            self.out_proj = torch.nn.utils.spectral_norm(self.out_proj)

    def forward(self, features, **kwargs):
        pooled = features[:, 0, :]  # take <s> token (equiv. to [CLS])
        pooled = self.dense(self.dropout(pooled))
        pooled = self.activation_fn(pooled)
        return self.out_proj(self.dropout(pooled))
|
| 390 |
+
|
| 391 |
+
|
| 392 |
+
class RobertaEncoder(FairseqEncoder):
    """RoBERTa encoder: a Transformer sentence encoder plus a masked-LM head."""

    def __init__(self, args, dictionary):
        super().__init__(dictionary)
        self.args = args

        # Layer pruning: keeping a subset of layers redefines the layer count.
        if args.encoder_layers_to_keep:
            args.encoder_layers = len(args.encoder_layers_to_keep.split(","))

        self.sentence_encoder = TransformerSentenceEncoder(
            padding_idx=dictionary.pad(),
            vocab_size=len(dictionary),
            num_encoder_layers=args.encoder_layers,
            embedding_dim=args.encoder_embed_dim,
            ffn_embedding_dim=args.encoder_ffn_embed_dim,
            num_attention_heads=args.encoder_attention_heads,
            dropout=args.dropout,
            attention_dropout=args.attention_dropout,
            activation_dropout=args.activation_dropout,
            layerdrop=args.encoder_layerdrop,
            max_seq_len=args.max_positions,
            num_segments=0,
            encoder_normalize_before=True,
            apply_bert_init=True,
            activation_fn=args.activation_fn,
            q_noise=args.quant_noise_pq,
            qn_block_size=args.quant_noise_pq_block_size,
        )
        args.untie_weights_roberta = getattr(args, "untie_weights_roberta", False)

        # Tie the LM projection to the input embeddings unless untied.
        tied_weight = (
            None
            if args.untie_weights_roberta
            else self.sentence_encoder.embed_tokens.weight
        )
        self.lm_head = RobertaLMHead(
            embed_dim=args.encoder_embed_dim,
            output_dim=len(dictionary),
            activation_fn=args.activation_fn,
            weight=tied_weight,
        )

    def forward(
        self,
        src_tokens,
        features_only=False,
        return_all_hiddens=False,
        masked_tokens=None,
        **unused
    ):
        """
        Args:
            src_tokens (LongTensor): input tokens of shape `(batch, src_len)`
            features_only (bool, optional): skip LM head and just return
                features. If True, the output will be of shape
                `(batch, src_len, embed_dim)`.
            return_all_hiddens (bool, optional): also return all of the
                intermediate hidden states (default: False).

        Returns:
            tuple:
                - the LM output of shape `(batch, src_len, vocab)`
                - a dictionary of additional data, where 'inner_states'
                  is a list of hidden states. Note that the hidden
                  states have shape `(src_len, batch, vocab)`.
        """
        x, extra = self.extract_features(
            src_tokens, return_all_hiddens=return_all_hiddens
        )
        if not features_only:
            x = self.output_layer(x, masked_tokens=masked_tokens)
        return x, extra

    def extract_features(self, src_tokens, return_all_hiddens=False, **kwargs):
        """Run the sentence encoder and return `(features, extra)`."""
        inner_states, _ = self.sentence_encoder(
            src_tokens,
            last_state_only=not return_all_hiddens,
            token_embeddings=kwargs.get("token_embeddings", None),
        )
        # Last hidden state, converted T x B x C -> B x T x C.
        features = inner_states[-1].transpose(0, 1)
        return features, {"inner_states": inner_states if return_all_hiddens else None}

    def output_layer(self, features, masked_tokens=None, **unused):
        """Project features to vocabulary logits via the LM head."""
        return self.lm_head(features, masked_tokens)

    def max_positions(self):
        """Maximum output length supported by the encoder."""
        return self.args.max_positions
|
| 480 |
+
|
| 481 |
+
|
| 482 |
+
@register_model_architecture("roberta", "roberta")
def base_architecture(args):
    """Fill in RoBERTa-base defaults for any option not already set on ``args``."""
    args.encoder_layers = getattr(args, "encoder_layers", 12)
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 768)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 3072)
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 12)

    args.activation_fn = getattr(args, "activation_fn", "gelu")
    args.pooler_activation_fn = getattr(args, "pooler_activation_fn", "tanh")

    args.dropout = getattr(args, "dropout", 0.1)
    args.attention_dropout = getattr(args, "attention_dropout", 0.1)
    args.activation_dropout = getattr(args, "activation_dropout", 0.0)
    args.pooler_dropout = getattr(args, "pooler_dropout", 0.0)
    args.encoder_layers_to_keep = getattr(args, "encoder_layers_to_keep", None)
    args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0.0)
    # BUG FIX: this getattr previously read a misspelled attribute
    # ("spectral_nrom_classification_head"), which never exists, so the value
    # parsed from --spectral-norm-classification-head was always clobbered
    # with False.  A duplicated encoder_layerdrop assignment was also removed.
    args.spectral_norm_classification_head = getattr(
        args, "spectral_norm_classification_head", False
    )
|
| 502 |
+
|
| 503 |
+
|
| 504 |
+
@register_model_architecture("roberta", "roberta_base")
def roberta_base_architecture(args):
    """``roberta_base`` simply uses the defaults from :func:`base_architecture`."""
    base_architecture(args)
|
| 507 |
+
|
| 508 |
+
|
| 509 |
+
@register_model_architecture("roberta", "roberta_large")
def roberta_large_architecture(args):
    """RoBERTa-large: 24 layers, 1024-dim embeddings, 16 attention heads."""
    args.encoder_layers = getattr(args, "encoder_layers", 24)
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4096)
    # Remaining options fall back to the base defaults.
    base_architecture(args)
|
| 516 |
+
|
| 517 |
+
|
| 518 |
+
@register_model_architecture("roberta", "xlm")
def xlm_architecture(args):
    """XLM defaults: 16 layers, 1280 dim, 5120 FFN (4x), 16 heads.

    User-provided values on ``args`` take precedence; remaining defaults
    come from :func:`base_architecture`.
    """
    xlm_defaults = (
        ("encoder_layers", 16),
        ("encoder_embed_dim", 1280),
        ("encoder_ffn_embed_dim", 1280 * 4),
        ("encoder_attention_heads", 16),
    )
    for key, default in xlm_defaults:
        setattr(args, key, getattr(args, key, default))
    base_architecture(args)
|
fairseq-0.10.2/fairseq/models/roberta/model_camembert.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
"""
|
| 6 |
+
CamemBERT: a Tasty French Language Model
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from fairseq.models import register_model
|
| 10 |
+
|
| 11 |
+
from .hub_interface import RobertaHubInterface
|
| 12 |
+
from .model import RobertaModel
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@register_model("camembert")
class CamembertModel(RobertaModel):
    """RoBERTa variant registered under the "camembert" model name.

    Inherits everything from :class:`RobertaModel`; only the hub archive
    map and the default BPE (sentencepiece) differ.
    """

    @classmethod
    def hub_models(cls):
        # Map hub aliases to downloadable checkpoint archives. Several
        # aliases intentionally resolve to the same base checkpoint.
        root = "http://dl.fbaipublicfiles.com/fairseq/models/"
        archives = {
            "camembert": "camembert-base",
            "camembert.v0": "camembert-base",
            "camembert-base": "camembert-base",
            "camembert-large": "camembert-large",
            "camembert-base-ccnet": "camembert-base-ccnet",
            "camembert-base-ccnet-4gb": "camembert-base-ccnet-4gb",
            "camembert-base-wikipedia-4gb": "camembert-base-wikipedia-4gb",
            "camembert-base-oscar-4gb": "camembert-base-oscar-4gb",
        }
        return {alias: root + name + ".tar.gz" for alias, name in archives.items()}

    @classmethod
    def from_pretrained(
        cls,
        model_name_or_path,
        checkpoint_file="model.pt",
        data_name_or_path=".",
        bpe="sentencepiece",
        **kwargs
    ):
        """Load a pre-trained CamemBERT checkpoint wrapped in a hub interface.

        Args:
            model_name_or_path: hub alias (see :meth:`hub_models`) or a
                local path to a checkpoint directory.
            checkpoint_file: checkpoint filename within the archive/path.
            data_name_or_path: where to find the data directory.
            bpe: tokenizer type; CamemBERT uses sentencepiece by default.

        Returns:
            A :class:`RobertaHubInterface` around the loaded model.
        """
        from fairseq import hub_utils

        loaded = hub_utils.from_pretrained(
            model_name_or_path,
            checkpoint_file,
            data_name_or_path,
            archive_map=cls.hub_models(),
            bpe=bpe,
            load_checkpoint_heads=True,
            **kwargs,
        )
        return RobertaHubInterface(loaded["args"], loaded["task"], loaded["models"][0])
|
fairseq-0.10.2/fairseq/models/wav2vec/__init__.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
from .wav2vec import * # noqa
|
| 7 |
+
from .wav2vec2 import * # noqa
|
| 8 |
+
from .wav2vec2_asr import * # noqa
|
fairseq-0.10.2/fairseq/models/wav2vec/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (239 Bytes). View file
|
|
|