Add files using upload-large-folder tool
Browse files. This view is limited to 50 files because it contains too many changes.
See raw diff
- fairseq-0.10.2/fairseq/__pycache__/pdb.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/hub_utils.py +294 -0
- fairseq-0.10.2/fairseq/iterative_refinement_generator.py +359 -0
- fairseq-0.10.2/fairseq/legacy_distributed_data_parallel.py +171 -0
- fairseq-0.10.2/fairseq/model_parallel/__pycache__/megatron_trainer.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/model_parallel/criterions/__pycache__/vocab_parallel_cross_entropy.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/model_parallel/megatron_trainer.py +66 -0
- fairseq-0.10.2/fairseq/model_parallel/models/__pycache__/__init__.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/model_parallel/models/pipeline_parallel_transformer/__init__.py +6 -0
- fairseq-0.10.2/fairseq/model_parallel/models/pipeline_parallel_transformer/__pycache__/__init__.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/model_parallel/models/pipeline_parallel_transformer/layers.py +600 -0
- fairseq-0.10.2/fairseq/model_parallel/models/roberta/model.py +287 -0
- fairseq-0.10.2/fairseq/model_parallel/modules/__pycache__/transformer_sentence_encoder_layer.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/model_parallel/modules/transformer_sentence_encoder_layer.py +77 -0
- fairseq-0.10.2/fairseq/modules/__init__.py +76 -0
- fairseq-0.10.2/fairseq/modules/__pycache__/adaptive_softmax.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/modules/__pycache__/dynamic_convolution.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/modules/__pycache__/grad_multiply.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/modules/__pycache__/gumbel_vector_quantizer.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/modules/__pycache__/kmeans_vector_quantizer.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/modules/__pycache__/learned_positional_embedding.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/modules/__pycache__/lightweight_convolution.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/modules/__pycache__/transformer_sentence_encoder_layer.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/modules/__pycache__/vggblock.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/modules/adaptive_softmax.py +268 -0
- fairseq-0.10.2/fairseq/modules/beamable_mm.py +49 -0
- fairseq-0.10.2/fairseq/modules/character_token_embedder.py +214 -0
- fairseq-0.10.2/fairseq/modules/cross_entropy.py +59 -0
- fairseq-0.10.2/fairseq/modules/fp32_group_norm.py +25 -0
- fairseq-0.10.2/fairseq/modules/lightconv_layer/lightconv_cuda_kernel.cu +375 -0
- fairseq-0.10.2/fairseq/modules/linearized_convolution.py +104 -0
- fairseq-0.10.2/fairseq/modules/multihead_attention.py +488 -0
- fairseq-0.10.2/fairseq/modules/positional_embedding.py +35 -0
- fairseq-0.10.2/fairseq/modules/quant_noise.py +107 -0
- fairseq-0.10.2/fairseq/modules/quantization/__pycache__/quantization_options.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/modules/quantization/pq/__pycache__/__init__.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/modules/quantization/pq/__pycache__/pq.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/modules/quantization/pq/__pycache__/utils.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/modules/quantization/pq/modules/__init__.py +8 -0
- fairseq-0.10.2/fairseq/modules/quantization/pq/modules/__pycache__/__init__.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/modules/quantization/pq/modules/__pycache__/qconv.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/modules/quantization/pq/modules/__pycache__/qemb.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/modules/quantization/pq/modules/qemb.py +107 -0
- fairseq-0.10.2/fairseq/modules/quantization/pq/modules/qlinear.py +71 -0
- fairseq-0.10.2/fairseq/modules/quantization/pq/pq.py +128 -0
- fairseq-0.10.2/fairseq/modules/quantization/quantization_options.py +44 -0
- fairseq-0.10.2/fairseq/modules/quantization/scalar/__pycache__/ops.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/modules/quantization/scalar/modules/__pycache__/qemb.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq/modules/quantization/scalar/modules/qemb.py +147 -0
- fairseq-0.10.2/fairseq/modules/sparse_transformer_sentence_encoder.py +96 -0
fairseq-0.10.2/fairseq/__pycache__/pdb.cpython-310.pyc
ADDED
|
Binary file (1.33 kB). View file
|
|
|
fairseq-0.10.2/fairseq/hub_utils.py
ADDED
|
@@ -0,0 +1,294 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3 -u
|
| 2 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the MIT license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import argparse
|
| 8 |
+
import copy
|
| 9 |
+
import logging
|
| 10 |
+
import os
|
| 11 |
+
from typing import Any, Dict, Iterator, List, Tuple
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
from fairseq import utils
|
| 15 |
+
from fairseq.data import encoders
|
| 16 |
+
from torch import nn
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
logger = logging.getLogger(__name__)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def from_pretrained(
    model_name_or_path,
    checkpoint_file="model.pt",
    data_name_or_path=".",
    archive_map=None,
    **kwargs
):
    """Load a pre-trained model ensemble together with its task and args.

    Args:
        model_name_or_path: model name (resolved via *archive_map*), local
            path, or URL of the model archive
        checkpoint_file: checkpoint filename(s) inside the archive, separated
            by ``os.pathsep`` for an ensemble
        data_name_or_path: dataset location; a leading ``.`` is resolved
            relative to the extracted model archive
        archive_map: optional mapping of shorthand names to paths/URLs or to
            dicts of default overrides
        **kwargs: argument overrides forwarded to checkpoint loading

    Returns:
        dict with keys ``"args"``, ``"task"`` and ``"models"``.
    """
    from fairseq import checkpoint_utils, file_utils

    if archive_map is not None:
        # Resolve shorthand names (e.g. 'transformer.wmt14.en-fr') via the map.
        if model_name_or_path in archive_map:
            model_name_or_path = archive_map[model_name_or_path]
        if data_name_or_path is not None and data_name_or_path in archive_map:
            data_name_or_path = archive_map[data_name_or_path]

        # An archive_map entry may itself be a dict carrying default arg
        # overrides (e.g. tokenizer, bpe) alongside the actual 'path'.
        if isinstance(model_name_or_path, dict):
            for key, value in model_name_or_path.items():
                if key == "checkpoint_file":
                    checkpoint_file = value
                elif key != "path" and key not in kwargs:
                    # explicit caller kwargs take precedence over archive defaults
                    kwargs[key] = value
            model_name_or_path = model_name_or_path["path"]

    model_path = file_utils.load_archive_file(model_name_or_path)

    # Convenience hack: resolve data and BPE resources relative to the archive.
    if data_name_or_path.startswith("."):
        kwargs["data"] = os.path.abspath(os.path.join(model_path, data_name_or_path))
    else:
        kwargs["data"] = file_utils.load_archive_file(data_name_or_path)
    for filename, arg_name in (
        ("code", "bpe_codes"),
        ("bpecodes", "bpe_codes"),
        ("sentencepiece.bpe.model", "sentencepiece_model"),
    ):
        candidate = os.path.join(model_path, filename)
        if os.path.exists(candidate):
            kwargs[arg_name] = candidate

    if "user_dir" in kwargs:
        utils.import_user_module(argparse.Namespace(user_dir=kwargs["user_dir"]))

    models, args, task = checkpoint_utils.load_model_ensemble_and_task(
        [os.path.join(model_path, cpt) for cpt in checkpoint_file.split(os.pathsep)],
        arg_overrides=kwargs,
    )

    return {
        "args": args,
        "task": task,
        "models": models,
    }
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
class GeneratorHubInterface(nn.Module):
    """
    PyTorch Hub interface for generating sequences from a pre-trained
    translation or language model.
    """

    def __init__(self, args, task, models):
        super().__init__()
        self.args = args
        self.task = task
        self.models = nn.ModuleList(models)
        self.src_dict = task.source_dictionary
        self.tgt_dict = task.target_dictionary

        # Put every model into inference mode (fuse weights, drop training-only
        # state) before any generation happens.
        for model in self.models:
            model.prepare_for_inference_(args)

        # Alignment dictionary for unknown-word replacement: None disables the
        # feature entirely; an empty dict enables it without a mapping file.
        self.align_dict = utils.load_align_dict(getattr(args, "replace_unk", None))

        self.tokenizer = encoders.build_tokenizer(args)
        self.bpe = encoders.build_bpe(args)

        self.max_positions = utils.resolve_max_positions(
            self.task.max_positions(), *[model.max_positions() for model in models]
        )

        # Dummy buffer whose .device tracks wherever this module is moved.
        self.register_buffer("_float_tensor", torch.tensor([0], dtype=torch.float))

    @property
    def device(self):
        # The registered buffer follows .to()/.cuda(), so its device is ours.
        return self._float_tensor.device

    def translate(
        self, sentences: List[str], beam: int = 5, verbose: bool = False, **kwargs
    ) -> List[str]:
        """Translate *sentences*; thin alias of :meth:`sample` with beam=5."""
        return self.sample(sentences, beam, verbose, **kwargs)

    def sample(
        self, sentences: List[str], beam: int = 1, verbose: bool = False, **kwargs
    ) -> List[str]:
        """Encode, generate and decode the best hypothesis per sentence.

        Accepts a single string (returns a string) or a list of strings
        (returns a list of strings).
        """
        if isinstance(sentences, str):
            return self.sample([sentences], beam=beam, verbose=verbose, **kwargs)[0]
        encoded = [self.encode(text) for text in sentences]
        batched_hypos = self.generate(encoded, beam, verbose, **kwargs)
        return [self.decode(hypos[0]["tokens"]) for hypos in batched_hypos]

    def score(self, sentences: List[str], **kwargs):
        """Score each sentence against itself (reference scoring).

        NOTE: this doesn't support translation tasks currently.
        """
        if isinstance(sentences, str):
            return self.score([sentences], **kwargs)[0]
        encoded = [self.encode(text) for text in sentences]
        return [
            hypos[0]
            for hypos in self.generate(encoded, score_reference=True, **kwargs)
        ]

    def generate(
        self,
        tokenized_sentences: List[torch.LongTensor],
        beam: int = 5,
        verbose: bool = False,
        skip_invalid_size_inputs=False,
        inference_step_args=None,
        **kwargs
    ) -> List[List[Dict[str, torch.Tensor]]]:
        """Generate hypotheses for already-binarized sentences.

        Extra **kwargs are applied as attribute overrides on a copy of
        ``self.args`` before the generator is built.
        """
        # A single unbatched tensor: wrap it, recurse, and unwrap the result.
        if torch.is_tensor(tokenized_sentences) and tokenized_sentences.dim() == 1:
            return self.generate(
                tokenized_sentences.unsqueeze(0), beam=beam, verbose=verbose, **kwargs
            )[0]

        # Overlay per-call overrides on a shallow copy of the stored args.
        gen_args = copy.copy(self.args)
        gen_args.beam = beam
        for name, value in kwargs.items():
            setattr(gen_args, name, value)
        generator = self.task.build_generator(self.models, gen_args)

        inference_step_args = inference_step_args or {}
        results = []
        for batch in self._build_batches(tokenized_sentences, skip_invalid_size_inputs):
            batch = utils.apply_to_sample(lambda t: t.to(self.device), batch)
            translations = self.task.inference_step(
                generator, self.models, batch, **inference_step_args
            )
            for sample_id, hypos in zip(batch["id"].tolist(), translations):
                results.append((sample_id, hypos))

        # Batching may have reordered the sentences; restore the input order.
        outputs = [hypos for _, hypos in sorted(results, key=lambda x: x[0])]

        if verbose:

            def getarg(name, default):
                # per-call override first, then stored args, then the default
                return getattr(gen_args, name, getattr(self.args, name, default))

            for source_tokens, target_hypotheses in zip(tokenized_sentences, outputs):
                src_str_with_unk = self.string(source_tokens)
                logger.info("S\t{}".format(src_str_with_unk))
                for hypo in target_hypotheses:
                    hypo_str = self.decode(hypo["tokens"])
                    logger.info("H\t{}\t{}".format(hypo["score"], hypo_str))
                    logger.info(
                        "P\t{}".format(
                            " ".join(
                                "{:.4f}".format(p)
                                for p in hypo["positional_scores"].tolist()
                            )
                        )
                    )
                    if hypo["alignment"] is not None and getarg(
                        "print_alignment", False
                    ):
                        logger.info(
                            "A\t{}".format(
                                " ".join(
                                    "{}-{}".format(src_idx, tgt_idx)
                                    for src_idx, tgt_idx in hypo["alignment"]
                                )
                            )
                        )
        return outputs

    def encode(self, sentence: str) -> torch.LongTensor:
        """Pipeline: tokenize -> apply BPE -> binarize."""
        return self.binarize(self.apply_bpe(self.tokenize(sentence)))

    def decode(self, tokens: torch.LongTensor) -> str:
        """Inverse of :meth:`encode`: string -> remove BPE -> detokenize."""
        return self.detokenize(self.remove_bpe(self.string(tokens)))

    def tokenize(self, sentence: str) -> str:
        # Pass-through when no tokenizer was configured.
        return sentence if self.tokenizer is None else self.tokenizer.encode(sentence)

    def detokenize(self, sentence: str) -> str:
        return sentence if self.tokenizer is None else self.tokenizer.decode(sentence)

    def apply_bpe(self, sentence: str) -> str:
        # Pass-through when no BPE model was configured.
        return sentence if self.bpe is None else self.bpe.encode(sentence)

    def remove_bpe(self, sentence: str) -> str:
        return sentence if self.bpe is None else self.bpe.decode(sentence)

    def binarize(self, sentence: str) -> torch.LongTensor:
        # add_if_not_exist=False: unknown tokens map to <unk>, never extend the dict.
        return self.src_dict.encode_line(sentence, add_if_not_exist=False).long()

    def string(self, tokens: torch.LongTensor) -> str:
        return self.tgt_dict.string(tokens)

    def _build_batches(
        self, tokens: List[List[int]], skip_invalid_size_inputs: bool
    ) -> Iterator[Dict[str, Any]]:
        """Yield inference mini-batches over *tokens* in a fixed (unshuffled) order."""
        lengths = torch.LongTensor([t.numel() for t in tokens])
        return self.task.get_batch_iterator(
            dataset=self.task.build_dataset_for_inference(tokens, lengths),
            max_tokens=self.args.max_tokens,
            max_sentences=self.args.batch_size,
            max_positions=self.max_positions,
            ignore_invalid_inputs=skip_invalid_size_inputs,
            disable_iterator_cache=True,
        ).next_epoch_itr(shuffle=False)
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
class BPEHubInterface(object):
    """PyTorch Hub interface for Byte-Pair Encoding (BPE)."""

    def __init__(self, bpe, **kwargs):
        super().__init__()
        ns = argparse.Namespace(bpe=bpe, **kwargs)
        self.bpe = encoders.build_bpe(ns)
        # build_bpe returns None for unknown schemes; fail loudly here rather
        # than on the first encode() call.
        assert self.bpe is not None

    def encode(self, sentence: str) -> str:
        """Apply BPE segmentation to *sentence*."""
        return self.bpe.encode(sentence)

    def decode(self, sentence: str) -> str:
        """Undo BPE segmentation on *sentence*."""
        return self.bpe.decode(sentence)
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
class TokenizerHubInterface(object):
    """PyTorch Hub interface for tokenization."""

    def __init__(self, tokenizer, **kwargs):
        super().__init__()
        ns = argparse.Namespace(tokenizer=tokenizer, **kwargs)
        self.tokenizer = encoders.build_tokenizer(ns)
        # build_tokenizer returns None for unknown schemes; fail loudly here
        # rather than on the first encode() call.
        assert self.tokenizer is not None

    def encode(self, sentence: str) -> str:
        """Tokenize *sentence*."""
        return self.tokenizer.encode(sentence)

    def decode(self, sentence: str) -> str:
        """Detokenize *sentence*."""
        return self.tokenizer.decode(sentence)
|
fairseq-0.10.2/fairseq/iterative_refinement_generator.py
ADDED
|
@@ -0,0 +1,359 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
from collections import namedtuple
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch
|
| 10 |
+
from fairseq import utils
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
# Immutable decoder state threaded between refinement iterations.
DecoderOut = namedtuple(
    "IterativeRefinementDecoderOut",
    [
        "output_tokens",
        "output_scores",
        "attn",
        "step",
        "max_step",
        "history",
    ],
)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class IterativeRefinementGenerator(object):
    """Sequence generator for non-autoregressive models that repeatedly
    refines an initial output instead of decoding left-to-right."""

    def __init__(
        self,
        tgt_dict,
        models=None,
        eos_penalty=0.0,
        max_iter=10,
        max_ratio=2,
        beam_size=1,
        decoding_format=None,
        retain_dropout=False,
        adaptive=True,
        retain_history=False,
        reranking=False,
    ):
        """
        Generates translations based on iterative refinement.

        Args:
            tgt_dict: target dictionary
            models: list of models; with ``reranking`` the last one is the reranker
            eos_penalty: if > 0.0, penalizes early stopping during decoding
            max_iter: maximum number of refinement iterations
            max_ratio: generate sequences of maximum length ax, where x is the source length
            beam_size: length-beam size (number of candidate lengths per sentence)
            decoding_format: decoding mode in {'unigram', 'ensemble', 'vote', 'dp', 'bs'}
            retain_dropout: keep dropout enabled during inference
            adaptive: decode with early stop (terminate a sentence once its
                output stops changing)
            retain_history: keep the intermediate outputs of every iteration
            reranking: rescore length-beam candidates with the last model
        """
        self.bos = tgt_dict.bos()
        self.pad = tgt_dict.pad()
        self.unk = tgt_dict.unk()
        self.eos = tgt_dict.eos()
        self.vocab_size = len(tgt_dict)
        self.eos_penalty = eos_penalty
        self.max_iter = max_iter
        self.max_ratio = max_ratio
        self.beam_size = beam_size
        self.reranking = reranking
        self.decoding_format = decoding_format
        self.retain_dropout = retain_dropout
        self.retain_history = retain_history
        self.adaptive = adaptive
        self.models = models

    def generate_batched_itr(
        self,
        data_itr,
        maxlen_a=None,
        maxlen_b=None,
        cuda=False,
        timer=None,
        prefix_size=0,
    ):
        """Iterate over a batched dataset and yield individual translations.

        Args:
            maxlen_a/b: generate sequences of maximum length ax + b,
                where x is the source sentence length.
            cuda: use GPU for generation
            timer: StopwatchMeter for timing generations.
            prefix_size: number of target tokens to force as a decoding prefix

        Yields:
            tuples of (sample id, source tokens, reference tokens, hypotheses).
        """

        for sample in data_itr:
            if "net_input" not in sample:
                continue
            if timer is not None:
                timer.start()
            with torch.no_grad():
                hypos = self.generate(
                    self.models,
                    sample,
                    prefix_tokens=sample["target"][:, :prefix_size]
                    if prefix_size > 0
                    else None,
                )
            if timer is not None:
                # timing is normalized by the number of target tokens
                timer.stop(sample["ntokens"])
            for i, id in enumerate(sample["id"]):
                # remove padding
                src = utils.strip_pad(sample["net_input"]["src_tokens"][i, :], self.pad)
                ref = utils.strip_pad(sample["target"][i, :], self.pad)
                yield id, src, ref, hypos[i]

    @torch.no_grad()
    def generate(self, models, sample, prefix_tokens=None, constraints=None):
        """Run iterative refinement on *sample* and return, per sentence, a
        list with the single finalized hypothesis dict."""
        if constraints is not None:
            raise NotImplementedError(
                "Constrained decoding with the IterativeRefinementGenerator is not supported"
            )

        # TODO: iterative refinement generator does not support ensemble for now.
        if not self.retain_dropout:
            for model in models:
                model.eval()

        model, reranker = models[0], None
        if self.reranking:
            assert len(models) > 1, "Assuming the last checkpoint is the reranker"
            assert (
                self.beam_size > 1
            ), "Reranking requires multiple translation for each example"

            reranker = models[-1]
            models = models[:-1]

        # ensembling is delegated to the model itself when it supports it
        if len(models) > 1 and hasattr(model, "enable_ensemble"):
            assert model.allow_ensemble, "{} does not support ensembling".format(
                model.__class__.__name__
            )
            model.enable_ensemble(models)

        # TODO: better encoder inputs?
        src_tokens = sample["net_input"]["src_tokens"]
        src_lengths = sample["net_input"]["src_lengths"]
        bsz, src_len = src_tokens.size()

        # initialize: encode once, then seed the decoder output (e.g. all-unk
        # tokens of a predicted length) that refinement will improve.
        encoder_out = model.forward_encoder([src_tokens, src_lengths])
        prev_decoder_out = model.initialize_output_tokens(encoder_out, src_tokens)

        if self.beam_size > 1:
            assert (
                model.allow_length_beam
            ), "{} does not support decoding with length beam.".format(
                model.__class__.__name__
            )

            # regenerate data based on length-beam: each sentence is expanded
            # into beam_size candidates that differ only in target length
            length_beam_order = (
                utils.new_arange(src_tokens, self.beam_size, bsz).t().reshape(-1)
            )
            encoder_out = model.encoder.reorder_encoder_out(
                encoder_out, length_beam_order
            )
            prev_decoder_out = model.regenerate_length_beam(
                prev_decoder_out, self.beam_size
            )
            bsz = bsz * self.beam_size

        # maps positions in the (shrinking) active batch back to sentence ids
        sent_idxs = torch.arange(bsz)
        prev_output_tokens = prev_decoder_out.output_tokens.clone()

        if self.retain_history:
            prev_decoder_out = prev_decoder_out._replace(history=[prev_output_tokens])

        finalized = [[] for _ in range(bsz)]

        def is_a_loop(x, y, s, a):
            # Detect sentences whose output no longer changes between
            # iterations; pads the shorter of x/y so they can be compared.
            b, l_x, l_y = x.size(0), x.size(1), y.size(1)
            if l_x > l_y:
                y = torch.cat([y, x.new_zeros(b, l_x - l_y).fill_(self.pad)], 1)
                s = torch.cat([s, s.new_zeros(b, l_x - l_y)], 1)
                if a is not None:
                    a = torch.cat([a, a.new_zeros(b, l_x - l_y, a.size(2))], 1)
            elif l_x < l_y:
                x = torch.cat([x, y.new_zeros(b, l_y - l_x).fill_(self.pad)], 1)
            return (x == y).all(1), y, s, a

        def finalized_hypos(step, prev_out_token, prev_out_score, prev_out_attn):
            # Strip padding and package one hypothesis into the standard
            # fairseq hypo dict (score is the mean positional score).
            cutoff = prev_out_token.ne(self.pad)
            tokens = prev_out_token[cutoff]
            if prev_out_score is None:
                scores, score = None, None
            else:
                scores = prev_out_score[cutoff]
                score = scores.mean()

            if prev_out_attn is None:
                hypo_attn, alignment = None, None
            else:
                hypo_attn = prev_out_attn[cutoff]
                alignment = hypo_attn.max(dim=1)[1]
            return {
                "steps": step,
                "tokens": tokens,
                "positional_scores": scores,
                "score": score,
                "hypo_attn": hypo_attn,
                "alignment": alignment,
            }

        for step in range(self.max_iter + 1):

            decoder_options = {
                "eos_penalty": self.eos_penalty,
                "max_ratio": self.max_ratio,
                "decoding_format": self.decoding_format,
            }
            prev_decoder_out = prev_decoder_out._replace(
                step=step,
                max_step=self.max_iter + 1,
            )

            decoder_out = model.forward_decoder(
                prev_decoder_out, encoder_out, **decoder_options
            )

            if self.adaptive:
                # terminate if there is a loop (output unchanged since the
                # previous iteration)
                terminated, out_tokens, out_scores, out_attn = is_a_loop(
                    prev_output_tokens,
                    decoder_out.output_tokens,
                    decoder_out.output_scores,
                    decoder_out.attn,
                )
                decoder_out = decoder_out._replace(
                    output_tokens=out_tokens,
                    output_scores=out_scores,
                    attn=out_attn,
                )

            else:
                terminated = decoder_out.output_tokens.new_zeros(
                    decoder_out.output_tokens.size(0)
                ).bool()

            if step == self.max_iter:  # reach last iteration, terminate
                terminated.fill_(1)

            # collect finalized sentences
            finalized_idxs = sent_idxs[terminated]
            finalized_tokens = decoder_out.output_tokens[terminated]
            finalized_scores = decoder_out.output_scores[terminated]
            finalized_attn = (
                None
                if (decoder_out.attn is None or decoder_out.attn.size(0) == 0)
                else decoder_out.attn[terminated]
            )

            if self.retain_history:
                finalized_history_tokens = [h[terminated] for h in decoder_out.history]

            for i in range(finalized_idxs.size(0)):
                finalized[finalized_idxs[i]] = [
                    finalized_hypos(
                        step,
                        finalized_tokens[i],
                        finalized_scores[i],
                        None if finalized_attn is None else finalized_attn[i],
                    )
                ]

                if self.retain_history:
                    finalized[finalized_idxs[i]][0]["history"] = []
                    for j in range(len(finalized_history_tokens)):
                        finalized[finalized_idxs[i]][0]["history"].append(
                            finalized_hypos(
                                step, finalized_history_tokens[j][i], None, None
                            )
                        )

            # check if all terminated
            if terminated.sum() == terminated.size(0):
                break

            # for next step: shrink the batch to the still-active sentences
            not_terminated = ~terminated
            prev_decoder_out = decoder_out._replace(
                output_tokens=decoder_out.output_tokens[not_terminated],
                output_scores=decoder_out.output_scores[not_terminated],
                attn=decoder_out.attn[not_terminated]
                if (decoder_out.attn is not None and decoder_out.attn.size(0) > 0)
                else None,
                history=[h[not_terminated] for h in decoder_out.history]
                if decoder_out.history is not None
                else None,
            )
            encoder_out = model.encoder.reorder_encoder_out(
                encoder_out, not_terminated.nonzero(as_tuple=False).squeeze()
            )
            sent_idxs = sent_idxs[not_terminated]
            prev_output_tokens = prev_decoder_out.output_tokens.clone()

        if self.beam_size > 1:
            if reranker is not None:
                finalized = self.rerank(
                    reranker, finalized, [src_tokens, src_lengths], self.beam_size
                )

            # aggregate information from length beam: keep, per sentence, the
            # candidate with the best score among its beam_size lengths
            finalized = [
                finalized[
                    np.argmax(
                        [
                            finalized[self.beam_size * i + j][0]["score"]
                            for j in range(self.beam_size)
                        ]
                    )
                    + self.beam_size * i
                ]
                for i in range(len(finalized) // self.beam_size)
            ]

        return finalized

    def rerank(self, reranker, finalized, encoder_input, beam_size):
        """Rescore finalized hypotheses with an autoregressive *reranker*
        model, overwriting each hypothesis's "score" in place."""

        def rebuild_batch(finalized):
            # Pack the variable-length finalized token sequences into one
            # padded (num_hypos x max_len) tensor.
            finalized_tokens = [f[0]["tokens"] for f in finalized]
            finalized_maxlen = max(f.size(0) for f in finalized_tokens)
            final_output_tokens = (
                finalized_tokens[0]
                .new_zeros(len(finalized_tokens), finalized_maxlen)
                .fill_(self.pad)
            )
            for i, f in enumerate(finalized_tokens):
                final_output_tokens[i, : f.size(0)] = f
            return final_output_tokens

        final_output_tokens = rebuild_batch(finalized)
        final_output_tokens[
            :, 0
        ] = self.eos  # autoregressive model assumes starting with EOS

        reranker_encoder_out = reranker.encoder(*encoder_input)
        # expand the reranker's encoder output to cover all beam candidates
        length_beam_order = (
            utils.new_arange(
                final_output_tokens, beam_size, reranker_encoder_out.encoder_out.size(1)
            )
            .t()
            .reshape(-1)
        )
        reranker_encoder_out = reranker.encoder.reorder_encoder_out(
            reranker_encoder_out, length_beam_order
        )
        reranking_scores = reranker.get_normalized_probs(
            reranker.decoder(final_output_tokens[:, :-1], reranker_encoder_out),
            True,
            None,
        )
        # log-prob of each actually-generated token, padding masked out;
        # final score is the length-normalized sum
        reranking_scores = reranking_scores.gather(2, final_output_tokens[:, 1:, None])
        reranking_masks = final_output_tokens[:, 1:].ne(self.pad)
        reranking_scores = (
            reranking_scores[:, :, 0].masked_fill_(~reranking_masks, 0).sum(1)
        )
        reranking_scores = reranking_scores / reranking_masks.sum(1).type_as(
            reranking_scores
        )

        for i in range(len(finalized)):
            finalized[i][0]["score"] = reranking_scores[i]

        return finalized
|
fairseq-0.10.2/fairseq/legacy_distributed_data_parallel.py
ADDED
|
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
"""
|
| 7 |
+
A modified version of the legacy DistributedDataParallel module that uses c10d
|
| 8 |
+
communication primitives. This version is simpler than the latest PyTorch
|
| 9 |
+
version and is useful for debugging. Notably it does not overlap gradient
|
| 10 |
+
communication with the backward pass, which makes it slower but more robust
|
| 11 |
+
than the PyTorch version.
|
| 12 |
+
|
| 13 |
+
This version also supports the *no_sync* context manager, which allows faster
|
| 14 |
+
training with `--update-freq`.
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
import copy
|
| 18 |
+
from collections import OrderedDict
|
| 19 |
+
from contextlib import contextmanager
|
| 20 |
+
|
| 21 |
+
import torch
|
| 22 |
+
from torch import nn
|
| 23 |
+
from torch.autograd import Variable
|
| 24 |
+
|
| 25 |
+
from . import distributed_utils
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class LegacyDistributedDataParallel(nn.Module):
    """Implements distributed data parallelism at the module level.

    A simplified version of :class:`torch.nn.parallel.DistributedDataParallel`.
    This version uses a c10d process group for communication and does not
    broadcast buffers.

    Args:
        module (~torch.nn.Module): module to be parallelized
        world_size (int): number of parallel workers
        process_group (optional): the c10d process group to be used for
            distributed data all-reduction. If None, the default process group
            will be used.
        buffer_size (int, optional): number of elements to buffer before
            performing all-reduce (default: 256M).
    """

    def __init__(self, module, world_size, process_group=None, buffer_size=2 ** 28):
        super().__init__()

        self.module = module
        self.world_size = world_size
        self.process_group = process_group

        # Never use a bigger buffer than the number of model params
        self.buffer_size = min(buffer_size, sum(p.numel() for p in module.parameters()))
        # Flat communication buffer; allocated lazily on the first all_reduce
        # so it gets the same device/dtype as the model parameters.
        self.buffer = None

        # We can also forcibly accumulate grads locally and only do the
        # all-reduce at some later time
        self.accumulate_grads = False

        # make per-device lists of parameters
        paramlists = OrderedDict()
        for param in self.module.parameters():
            device = param.device
            if paramlists.get(device) is None:
                paramlists[device] = []
            paramlists[device] += [param]
        self.per_device_params = list(paramlists.values())

    def __getstate__(self):
        # Default pickling of the instance dict; kept explicit for clarity.
        attrs = copy.copy(self.__dict__)
        return attrs

    def __setstate__(self, state):
        super().__setstate__(state)

    @contextmanager
    def no_sync(self):
        """A context manager to disable gradient synchronization."""
        old_accumulate_grads = self.accumulate_grads
        self.accumulate_grads = True
        yield
        self.accumulate_grads = old_accumulate_grads

    def forward(self, *inputs, **kwargs):
        # Pure pass-through; no communication happens on forward.
        return self.module(*inputs, **kwargs)

    def all_reduce(self):
        """
        This function must be called explicitly after backward to reduce
        gradients. There is no automatic hook like c10d.
        """

        def all_reduce_params(params):
            # Flatten the grads of `params` into one buffer, all-reduce it in a
            # single collective, then scatter the averaged values back.
            # NOTE: `buffer` may end up aliasing `p.grad.data` in the
            # single-param fast path below, so the copy-back loop must run
            # unconditionally.
            buffer = self.buffer
            nonzero_buffer = False
            if len(params) > 1:
                offset = 0
                for p in params:
                    sz = p.numel()
                    if p.grad is not None:
                        buffer[offset : offset + sz].copy_(p.grad.data.view(-1))
                        nonzero_buffer = True
                    else:
                        buffer[offset : offset + sz].zero_()
                    offset += sz
            else:
                # we only have a single grad to all-reduce
                p = params[0]
                if p.grad is not None:
                    buffer = p.grad.data
                    nonzero_buffer = True
                elif p.numel() <= self.buffer.numel():
                    buffer = buffer[: p.numel()]
                    buffer.zero_()
                else:
                    buffer = torch.zeros_like(p)

            if nonzero_buffer:
                # Pre-divide so the all-reduce (a sum) yields the average.
                buffer.div_(self.world_size)

            distributed_utils.all_reduce(buffer, self.process_group)

            # copy all-reduced grads back into their original place
            offset = 0
            for p in params:
                sz = p.numel()
                if p.grad is not None:
                    p.grad.data.copy_(buffer[offset : offset + sz].view_as(p))
                else:
                    p.grad = buffer[offset : offset + sz].view_as(p).clone()
                offset += sz

        def reduction_fn():
            # This function only needs to be called once
            if self.accumulate_grads:
                # Inside a no_sync() block: skip communication entirely.
                return

            if self.buffer is None:
                self.buffer = next(self.module.parameters()).new(self.buffer_size)

            for params in self.per_device_params:
                # All-reduce the gradients in buckets
                offset = 0
                buffered_params = []
                for param in params:
                    if not param.requires_grad:
                        continue
                    if param.grad is None:
                        param.grad = torch.zeros_like(param)
                    if param.grad.requires_grad:
                        raise RuntimeError(
                            "DistributedDataParallel only works "
                            "with gradients that don't require "
                            "grad"
                        )
                    sz = param.numel()
                    if sz > self.buffer.numel():
                        # all-reduce big params directly
                        all_reduce_params([param])
                    else:
                        # Flush the current bucket first if this param
                        # would overflow the buffer.
                        if offset + sz > self.buffer.numel():
                            all_reduce_params(buffered_params)
                            offset = 0
                            buffered_params.clear()
                        buffered_params.append(param)
                        offset += sz

                if len(buffered_params) > 0:
                    all_reduce_params(buffered_params)

        reduction_fn()
|
fairseq-0.10.2/fairseq/model_parallel/__pycache__/megatron_trainer.cpython-310.pyc
ADDED
|
Binary file (2.41 kB). View file
|
|
|
fairseq-0.10.2/fairseq/model_parallel/criterions/__pycache__/vocab_parallel_cross_entropy.cpython-310.pyc
ADDED
|
Binary file (3.5 kB). View file
|
|
|
fairseq-0.10.2/fairseq/model_parallel/megatron_trainer.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
"""
|
| 7 |
+
Train a network across multiple GPUs.
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from fairseq import distributed_utils
|
| 11 |
+
from fairseq.trainer import Trainer
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
try:
|
| 15 |
+
from fairseq.model_parallel.megatron.mpu import (
|
| 16 |
+
get_data_parallel_group,
|
| 17 |
+
get_data_parallel_rank,
|
| 18 |
+
get_data_parallel_world_size,
|
| 19 |
+
get_model_parallel_group,
|
| 20 |
+
get_model_parallel_src_rank,
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
has_megatron_submodule = True
|
| 24 |
+
except (ImportError, ModuleNotFoundError):
|
| 25 |
+
has_megatron_submodule = False
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class MegatronTrainer(Trainer):
    """Trainer for combined model-parallel + data-parallel training.

    Overrides the data-parallel topology accessors so that gradient
    synchronization happens over Megatron's data-parallel process group
    rather than the default (global) group.
    """

    def __init__(self, args, task, model, criterion):
        if not has_megatron_submodule:
            raise ImportError(
                "\n\nPlease install the megatron submodule:"
                "\n\n  git submodule update --init "
                "fairseq/model_parallel/megatron"
            )
        super().__init__(args, task, model, criterion)

    @property
    def data_parallel_world_size(self):
        """Number of workers in the data-parallel group."""
        return get_data_parallel_world_size()

    @property
    def data_parallel_process_group(self):
        """The c10d process group used for data-parallel collectives."""
        return get_data_parallel_group()

    @property
    def data_parallel_rank(self):
        """This worker's rank within the data-parallel group."""
        return get_data_parallel_rank()

    @property
    def is_data_parallel_master(self):
        """Whether this worker's model-parallel group is rooted at global rank 0."""
        return get_model_parallel_src_rank() == 0

    def clip_grad_norm(self, clip_norm):
        def _combine_model_parallel_norms(norm):
            # Each model-parallel worker holds only a shard of the parameters,
            # so the global norm is the root of the summed squared shard norms.
            squared = norm ** 2
            distributed_utils.all_reduce(squared, group=get_model_parallel_group())
            return squared ** 0.5

        return self.optimizer.clip_grad_norm(
            clip_norm,
            aggregate_norm_fn=_combine_model_parallel_norms,
        )
|
fairseq-0.10.2/fairseq/model_parallel/models/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (544 Bytes). View file
|
|
|
fairseq-0.10.2/fairseq/model_parallel/models/pipeline_parallel_transformer/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
from .model import * # noqa
|
fairseq-0.10.2/fairseq/model_parallel/models/pipeline_parallel_transformer/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (230 Bytes). View file
|
|
|
fairseq-0.10.2/fairseq/model_parallel/models/pipeline_parallel_transformer/layers.py
ADDED
|
@@ -0,0 +1,600 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import math
|
| 7 |
+
from collections import namedtuple
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
from fairseq import options, utils
|
| 13 |
+
from fairseq.modules import (
|
| 14 |
+
AdaptiveSoftmax,
|
| 15 |
+
LayerNorm,
|
| 16 |
+
MultiheadAttention,
|
| 17 |
+
PositionalEmbedding,
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
# Immutable container for everything the encoder stage hands downstream.
EncoderOut = namedtuple(
    "TransformerEncoderOut",
    [
        "encoder_out",  # T x B x C
        "encoder_padding_mask",  # B x T
        "encoder_embedding",  # B x T x C
        "encoder_states",  # List[T x B x C]
    ],
)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class TransformerEncoderEmbedding(nn.Module):
    """Encoder input stage: token embedding plus optional positional embedding.

    Packaged as a standalone pipeline stage: it consumes and produces the
    ``(src_tokens/features, padding_mask, prev_output_tokens)`` tuple used to
    thread state through the pipeline-parallel model.
    """

    def __init__(self, args, embed_tokens):
        super().__init__()
        self.dropout = args.dropout
        self.max_source_positions = args.max_source_positions
        self.embed_tokens = embed_tokens
        if isinstance(embed_tokens, nn.ModuleList):
            # Factored embeddings: dims concatenate; padding idx comes from
            # the first part.
            self.padding_idx = embed_tokens[0].padding_idx
            embed_dim = sum(part.embedding_dim for part in embed_tokens)
        else:
            self.padding_idx = embed_tokens.padding_idx
            embed_dim = embed_tokens.embedding_dim
        self.embed_scale = math.sqrt(embed_dim)
        if args.no_token_positional_embeddings:
            self.embed_positions = None
        else:
            self.embed_positions = PositionalEmbedding(
                args.max_source_positions,
                embed_dim,
                self.padding_idx,
                learned=args.encoder_learned_pos,
            )
        self.layernorm_embedding = (
            LayerNorm(embed_dim)
            if getattr(args, "layernorm_embedding", False)
            else None
        )

    def forward(self, input):
        # input[0]: src_tokens (B x T); input[2]: prev_output_tokens, passed
        # through untouched for the decoder stages.
        src_tokens = input[0]
        prev_output_tokens = input[2]
        if isinstance(self.embed_tokens, nn.ModuleList):
            embedded = torch.cat(
                [part(src_tokens) for part in self.embed_tokens], dim=-1
            )
        else:
            embedded = self.embed_tokens(src_tokens)
        x = self.embed_scale * embedded
        if self.embed_positions is not None:
            x = x + self.embed_positions(src_tokens)
        if self.layernorm_embedding:
            x = self.layernorm_embedding(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        # compute padding mask
        encoder_padding_mask = src_tokens.eq(self.padding_idx)
        return (x, encoder_padding_mask, prev_output_tokens)
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
class TransformerEncoderLayerNorm(nn.Module):
    """Optional final layer norm after the full stack of encoder layers.

    Active only when ``args.encoder_normalize_before`` is set; otherwise
    activations pass through unchanged. The (mask, prev_output_tokens)
    halves of the pipeline tuple are always forwarded untouched.
    """

    def __init__(self, args, embed_dim):
        super().__init__()
        self.layer_norm = (
            LayerNorm(embed_dim) if args.encoder_normalize_before else None
        )

    def forward(self, input):
        features = input[0]
        pad_mask = input[1]
        prev_tokens = input[2]
        if self.layer_norm:
            features = self.layer_norm(features)
        # keeping track of the incremental_state is not supported yet
        return (features, pad_mask, prev_tokens)
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
class TransformerDecoderEmbedding(nn.Module):
    """ Decoder Embedding + Positional Embedding

    Pipeline stage that embeds ``prev_output_tokens``. Accepts either the
    3-tuple produced by the encoder stages (MT setup), a shorter tuple, or
    a bare tensor (LM setup), and returns accordingly (see forward).
    """

    def __init__(self, args, embed_tokens):
        super().__init__()
        self.dropout = args.dropout
        self.share_input_output_embed = args.share_decoder_input_output_embed
        # Factored embeddings (nn.ModuleList) concatenate along the feature dim.
        input_embed_dim = (
            sum(e.embedding_dim for e in embed_tokens)
            if isinstance(embed_tokens, nn.ModuleList)
            else embed_tokens.embedding_dim
        )
        embed_dim = args.decoder_embed_dim
        self.output_embed_dim = args.decoder_output_dim

        padding_idx = (
            embed_tokens[0].padding_idx
            if isinstance(embed_tokens, nn.ModuleList)
            else embed_tokens.padding_idx
        )
        self.max_target_positions = args.max_target_positions

        self.embed_tokens = embed_tokens
        self.embed_scale = math.sqrt(embed_dim)  # todo: try with input_embed_dim

        # Projects the (possibly wider/narrower) embedding into the model dim.
        self.project_in_dim = (
            Linear(input_embed_dim, embed_dim, bias=False)
            if embed_dim != input_embed_dim
            else None
        )

        self.embed_positions = (
            PositionalEmbedding(
                args.max_target_positions,
                embed_dim,
                padding_idx,
                learned=args.decoder_learned_pos,
            )
            if not args.no_token_positional_embeddings
            else None
        )

    def forward(self, input):
        # mt_task tracks whether encoder state was provided and must be
        # threaded through to the decoder layers.
        mt_task = False
        if isinstance(input, tuple):
            if len(input) == 3:
                encoder_out = input[0]
                encoder_padding_mask = input[1]
                prev_output_tokens = input[2]
                incremental_state = None  # Hardcoding to avoid passing of None objects
                mt_task = True
            else:
                # HACK for now, need to fix (TODO sidgoyal)
                prev_output_tokens = input[0]
                # discard "src_lengths"
                encoder_out = None
                encoder_padding_mask = None
                incremental_state = None

        else:
            prev_output_tokens = input
            encoder_out = None
            encoder_padding_mask = None
            incremental_state = None

        positions = (
            self.embed_positions(
                prev_output_tokens,
                incremental_state=incremental_state,
            )
            if self.embed_positions is not None
            else None
        )

        # NOTE: incremental_state is hardcoded to None above, so this branch
        # (single-step decoding) is currently never taken.
        if incremental_state is not None:
            prev_output_tokens = prev_output_tokens[:, -1:]
            if positions is not None:
                positions = positions[:, -1:]

        # embed tokens and positions

        if isinstance(self.embed_tokens, nn.ModuleList):
            x_embed_list = []
            for embed_tokens_part in self.embed_tokens:
                x_embed_list.append(embed_tokens_part(prev_output_tokens))

            x = self.embed_scale * torch.cat(x_embed_list, dim=-1)
        else:
            x = self.embed_scale * self.embed_tokens(prev_output_tokens)

        if self.project_in_dim is not None:
            x = self.project_in_dim(x)

        if positions is not None:
            x += positions
        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)
        if mt_task:
            return (x, encoder_out, encoder_padding_mask)
        return x
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
class TransformerDecoderOutputLayer(nn.Module):
    """Final decoder stage: optional layer norm, optional output projection,
    and projection of features to vocabulary logits (shared-embedding,
    separate output matrix, or adaptive softmax — decided by ``args``).
    """

    def __init__(self, args, embed_tokens, dictionary):
        super().__init__()
        self.share_input_output_embed = args.share_decoder_input_output_embed
        self.embed_tokens = embed_tokens
        self.output_embed_dim = args.decoder_output_dim
        embed_dim = args.decoder_embed_dim

        self.project_out_dim = (
            Linear(embed_dim, self.output_embed_dim, bias=False)
            if embed_dim != self.output_embed_dim and not args.tie_adaptive_weights
            else None
        )
        self.adaptive_softmax = None
        if args.adaptive_softmax_cutoff is not None:
            # Adaptive softmax is incompatible with factored embeddings.
            assert not isinstance(embed_tokens, nn.ModuleList)
            self.adaptive_softmax = AdaptiveSoftmax(
                len(dictionary),
                self.output_embed_dim,
                options.eval_str_list(args.adaptive_softmax_cutoff, type=int),
                dropout=args.adaptive_softmax_dropout,
                adaptive_inputs=embed_tokens if args.tie_adaptive_weights else None,
                factor=args.adaptive_softmax_factor,
                tie_proj=args.tie_adaptive_proj,
            )
        elif not self.share_input_output_embed:
            # Independent output matrix: self.embed_tokens is REPLACED by a
            # raw Parameter (vocab x output_embed_dim), not an nn.Embedding.
            self.embed_tokens = nn.Parameter(
                torch.Tensor(len(dictionary), self.output_embed_dim)
            )
            nn.init.normal_(
                self.embed_tokens, mean=0, std=self.output_embed_dim ** -0.5
            )

        if args.decoder_normalize_before and not getattr(
            args, "no_decoder_final_norm", False
        ):
            self.layer_norm = LayerNorm(embed_dim)
        else:
            self.layer_norm = None

    def forward(self, input, apply_final_proj=True):
        # Accept either the pipeline tuple (features first) or a bare tensor.
        if isinstance(input, tuple):
            x = input[0]
        else:
            x = input

        if self.layer_norm:
            x = self.layer_norm(x)

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        if self.project_out_dim is not None:
            x = self.project_out_dim(x)
        if apply_final_proj:
            x = self.output_layer(x)
        return x

    def output_layer(self, features, **kwargs):
        """Project features to the vocabulary size."""
        if self.adaptive_softmax is None:
            # project back to size of vocabulary
            if self.share_input_output_embed:
                if isinstance(self.embed_tokens, nn.ModuleList):
                    # Factored case: each part projects its own slice of the
                    # feature dim; logits are summed across parts.
                    output = None
                    for i, emb in enumerate(self.embed_tokens):
                        sidx = i * emb.embedding_dim
                        eidx = (i + 1) * emb.embedding_dim
                        if output is None:
                            output = F.linear(features[:, :, sidx:eidx], emb.weight)
                        else:
                            output += F.linear(features[:, :, sidx:eidx], emb.weight)

                    return output
                else:
                    return F.linear(features, self.embed_tokens.weight)
            else:
                # Here self.embed_tokens is the raw Parameter from __init__.
                return F.linear(features, self.embed_tokens)
        else:
            # Adaptive softmax computes log-probs in the criterion; pass
            # features through unchanged.
            return features
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
class TransformerEncoderLayer(nn.Module):
|
| 298 |
+
"""Encoder layer block.
|
| 299 |
+
In the original paper each operation (multi-head attention or FFN) is
|
| 300 |
+
postprocessed with: `dropout -> add residual -> layernorm`. In the
|
| 301 |
+
tensor2tensor code they suggest that learning is more robust when
|
| 302 |
+
preprocessing each layer with layernorm and postprocessing with:
|
| 303 |
+
`dropout -> add residual`. We default to the approach in the paper, but the
|
| 304 |
+
tensor2tensor approach can be enabled by setting
|
| 305 |
+
*args.encoder_normalize_before* to ``True``.
|
| 306 |
+
|
| 307 |
+
Args:
|
| 308 |
+
args (argparse.Namespace): parsed command-line arguments
|
| 309 |
+
"""
|
| 310 |
+
|
| 311 |
+
def __init__(self, args):
|
| 312 |
+
super().__init__()
|
| 313 |
+
self.embed_dim = args.encoder_embed_dim
|
| 314 |
+
self.self_attn = MultiheadAttention(
|
| 315 |
+
self.embed_dim,
|
| 316 |
+
args.encoder_attention_heads,
|
| 317 |
+
dropout=args.attention_dropout,
|
| 318 |
+
self_attention=True,
|
| 319 |
+
)
|
| 320 |
+
self.self_attn_layer_norm = LayerNorm(self.embed_dim)
|
| 321 |
+
self.dropout = args.dropout
|
| 322 |
+
self.activation_fn = utils.get_activation_fn(
|
| 323 |
+
activation=getattr(args, "activation_fn", "relu")
|
| 324 |
+
)
|
| 325 |
+
self.activation_dropout = getattr(args, "activation_dropout", 0)
|
| 326 |
+
if self.activation_dropout == 0:
|
| 327 |
+
# for backwards compatibility with models that use args.relu_dropout
|
| 328 |
+
self.activation_dropout = getattr(args, "relu_dropout", 0)
|
| 329 |
+
self.normalize_before = args.encoder_normalize_before
|
| 330 |
+
self.fc1 = Linear(self.embed_dim, args.encoder_ffn_embed_dim)
|
| 331 |
+
self.fc2 = Linear(args.encoder_ffn_embed_dim, self.embed_dim)
|
| 332 |
+
self.final_layer_norm = LayerNorm(self.embed_dim)
|
| 333 |
+
|
| 334 |
+
def upgrade_state_dict_named(self, state_dict, name):
|
| 335 |
+
"""
|
| 336 |
+
Rename layer norm states from `...layer_norms.0.weight` to
|
| 337 |
+
`...self_attn_layer_norm.weight` and `...layer_norms.1.weight` to
|
| 338 |
+
`...final_layer_norm.weight`
|
| 339 |
+
"""
|
| 340 |
+
layer_norm_map = {"0": "self_attn_layer_norm", "1": "final_layer_norm"}
|
| 341 |
+
for old, new in layer_norm_map.items():
|
| 342 |
+
for m in ("weight", "bias"):
|
| 343 |
+
k = "{}.layer_norms.{}.{}".format(name, old, m)
|
| 344 |
+
if k in state_dict:
|
| 345 |
+
state_dict["{}.{}.{}".format(name, new, m)] = state_dict[k]
|
| 346 |
+
del state_dict[k]
|
| 347 |
+
|
| 348 |
+
def forward(self, input):
|
| 349 |
+
"""
|
| 350 |
+
Args:
|
| 351 |
+
input (Tuple):
|
| 352 |
+
input[0] (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
|
| 353 |
+
input[1] (ByteTensor/FloatTensor): encoder padding mask -
|
| 354 |
+
binary ByteTensor of shape `(batch, src_len)` where padding elements
|
| 355 |
+
are indicated by ``1``.
|
| 356 |
+
input[2] (LongTensor): previous decoder outputs of shape
|
| 357 |
+
`(batch, tgt_len)`, for teacher forcing)
|
| 358 |
+
Returns:
|
| 359 |
+
output (Tuple):
|
| 360 |
+
output[0] (Tensor): encoded output of shape `(batch, src_len, embed_dim)`
|
| 361 |
+
output[1] (ByteTensor/FloatTensor): encoder padding mask
|
| 362 |
+
output[2] (LongTensor): previous decoder outputs
|
| 363 |
+
"""
|
| 364 |
+
x = input[0]
|
| 365 |
+
encoder_padding_mask = input[1]
|
| 366 |
+
prev_output_tokens = input[2]
|
| 367 |
+
residual = x
|
| 368 |
+
x = self.maybe_layer_norm(self.self_attn_layer_norm, x, before=True)
|
| 369 |
+
x, _ = self.self_attn(
|
| 370 |
+
query=x, key=x, value=x, key_padding_mask=encoder_padding_mask
|
| 371 |
+
)
|
| 372 |
+
x = F.dropout(x, p=self.dropout, training=self.training)
|
| 373 |
+
x = residual + x
|
| 374 |
+
x = self.maybe_layer_norm(self.self_attn_layer_norm, x, after=True)
|
| 375 |
+
|
| 376 |
+
residual = x
|
| 377 |
+
x = self.maybe_layer_norm(self.final_layer_norm, x, before=True)
|
| 378 |
+
x = self.activation_fn(self.fc1(x))
|
| 379 |
+
x = F.dropout(x, p=self.activation_dropout, training=self.training)
|
| 380 |
+
x = self.fc2(x)
|
| 381 |
+
x = F.dropout(x, p=self.dropout, training=self.training)
|
| 382 |
+
x = residual + x
|
| 383 |
+
x = self.maybe_layer_norm(self.final_layer_norm, x, after=True)
|
| 384 |
+
return (x, encoder_padding_mask, prev_output_tokens)
|
| 385 |
+
|
| 386 |
+
def maybe_layer_norm(self, layer_norm, x, before=False, after=False):
    """Conditionally apply *layer_norm* to *x*.

    Exactly one of *before*/*after* must be set. With pre-norm
    (``self.normalize_before`` True) the norm fires at the "before" call
    site; with post-norm it fires at the "after" call site. Otherwise *x*
    is returned untouched.
    """
    assert before ^ after
    # Pre-norm layers normalize before the sub-block, post-norm after it.
    should_apply = self.normalize_before if before else not self.normalize_before
    return layer_norm(x) if should_apply else x
|
| 392 |
+
|
| 393 |
+
|
| 394 |
+
class TransformerDecoderLayer(nn.Module):
    """Decoder layer block.

    In the original paper each operation (multi-head attention, encoder
    attention or FFN) is postprocessed with: `dropout -> add residual ->
    layernorm`. In the tensor2tensor code they suggest that learning is more
    robust when preprocessing each layer with layernorm and postprocessing with:
    `dropout -> add residual`. We default to the approach in the paper, but the
    tensor2tensor approach can be enabled by setting
    *args.decoder_normalize_before* to ``True``.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        no_encoder_attn (bool, optional): whether to attend to encoder outputs
            (default: False).
    """

    def __init__(
        self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False
    ):
        super().__init__()
        self.embed_dim = args.decoder_embed_dim
        self.self_attn = MultiheadAttention(
            embed_dim=self.embed_dim,
            num_heads=args.decoder_attention_heads,
            dropout=args.attention_dropout,
            add_bias_kv=add_bias_kv,
            add_zero_attn=add_zero_attn,
            self_attention=True,
        )
        self.dropout = args.dropout
        self.activation_fn = utils.get_activation_fn(
            activation=getattr(args, "activation_fn", "relu")
        )
        self.activation_dropout = getattr(args, "activation_dropout", 0)
        if self.activation_dropout == 0:
            # for backwards compatibility with models that use args.relu_dropout
            self.activation_dropout = getattr(args, "relu_dropout", 0)
        self.normalize_before = args.decoder_normalize_before

        # use layerNorm rather than FusedLayerNorm for exporting.
        # char_inputs can be used to determine this.
        # TODO remove this once we update apex with the fix
        export = getattr(args, "char_inputs", False)
        self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=export)

        if no_encoder_attn:
            # Decoder-only (LM) mode: no cross-attention modules at all.
            self.encoder_attn = None
            self.encoder_attn_layer_norm = None
        else:
            self.encoder_attn = MultiheadAttention(
                self.embed_dim,
                args.decoder_attention_heads,
                kdim=getattr(args, "encoder_embed_dim", None),
                vdim=getattr(args, "encoder_embed_dim", None),
                dropout=args.attention_dropout,
                encoder_decoder_attention=True,
            )
            self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, export=export)

        self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim)
        self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim)

        self.final_layer_norm = LayerNorm(self.embed_dim, export=export)
        # need_attn controls whether cross-attention weights are returned at
        # eval time (see forward); generation may flip it off for speed.
        self.need_attn = True

        self.onnx_trace = False

    def prepare_for_onnx_export_(self):
        # Flag consulted by submodules to pick ONNX-safe code paths.
        self.onnx_trace = True

    def forward(self, input):
        """
        Args:
            input (Tuple):
                input[0] (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
                input[1] (Tensor): encoder output of shape `(batch, src_len, embed_dim)`
                input[2] (ByteTensor/FloatTensor): encoder padding mask -
                    binary ByteTensor of shape `(batch, src_len)` where padding elements
                    are indicated by ``1``.

        Returns:
            output (Tuple):
                output[0] (Tensor): decoded output, same layout as input[0]
                    (no transpose is applied in this layer)
                output[1] (Tensor): encoder output, passed through
                output[2] (ByteTensor/FloatTensor): encoder padding mask, passed through

            When *input* is a bare tensor (LM mode, no encoder state), only the
            decoded tensor is returned instead of a tuple.
        """
        # Note: incremental state is not yet supported
        mt_task = False
        if isinstance(input, tuple):
            # Machine-translation path: tuple carries encoder state through
            # the pipeline.
            x = input[0]
            encoder_out = input[1]
            encoder_padding_mask = input[2]
            incremental_state = None
            mt_task = True
        else:
            # LM path: bare tensor, no encoder to attend to.
            x = input
            encoder_out = None
            encoder_padding_mask = None
            incremental_state = None

        # Without cached incremental state we must mask future positions for
        # causal self-attention.
        if incremental_state is None:
            self_attn_mask = self.buffered_future_mask(x)
        else:
            self_attn_mask = None

        # TODO: add back prev_self_attn_state, prev_attn_state,
        # self_attn_padding_mask
        prev_self_attn_state = None
        prev_attn_state = None
        self_attn_padding_mask = None

        # --- causal self-attention sub-block ---
        residual = x
        x = self.maybe_layer_norm(self.self_attn_layer_norm, x, before=True)
        if prev_self_attn_state is not None:
            # Dead branch while prev_self_attn_state is hard-coded to None
            # above; kept for when incremental decoding is restored.
            if incremental_state is None:
                incremental_state = {}
            prev_key, prev_value = prev_self_attn_state
            saved_state = {"prev_key": prev_key, "prev_value": prev_value}
            self.self_attn._set_input_buffer(incremental_state, saved_state)
        x, attn = self.self_attn(
            query=x,
            key=x,
            value=x,
            key_padding_mask=self_attn_padding_mask,
            incremental_state=incremental_state,
            need_weights=False,
            attn_mask=self_attn_mask,
        )
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(self.self_attn_layer_norm, x, after=True)

        # --- encoder-decoder cross-attention sub-block (MT mode only) ---
        if self.encoder_attn is not None:
            residual = x
            x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, before=True)
            if prev_attn_state is not None:
                # Dead branch while prev_attn_state is hard-coded to None.
                if incremental_state is None:
                    incremental_state = {}
                prev_key, prev_value = prev_attn_state
                saved_state = {"prev_key": prev_key, "prev_value": prev_value}
                self.encoder_attn._set_input_buffer(incremental_state, saved_state)
            x, attn = self.encoder_attn(
                query=x,
                key=encoder_out,
                value=encoder_out,
                key_padding_mask=encoder_padding_mask,
                incremental_state=incremental_state,
                static_kv=True,
                need_weights=(not self.training and self.need_attn),
            )
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = residual + x
            x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, after=True)

        # --- position-wise feed-forward sub-block ---
        residual = x
        x = self.maybe_layer_norm(self.final_layer_norm, x, before=True)
        x = self.activation_fn(self.fc1(x))
        x = F.dropout(x, p=self.activation_dropout, training=self.training)
        x = self.fc2(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(self.final_layer_norm, x, after=True)

        if mt_task:
            # Re-pack encoder state so the next pipeline stage sees the same
            # tuple shape it received.
            return (x, encoder_out, encoder_padding_mask)
        return x

    def buffered_future_mask(self, tensor):
        """Return a cached upper-triangular -inf mask of size (dim, dim).

        The mask is lazily (re)built when missing, on a different device, or
        too small for the current sequence length, then sliced to size.
        """
        dim = tensor.size(0)
        if (
            not hasattr(self, "_future_mask")
            or self._future_mask is None
            or self._future_mask.device != tensor.device
        ):
            self._future_mask = torch.triu(
                utils.fill_with_neg_inf(tensor.new(dim, dim)), 1
            )
        if self._future_mask.size(0) < dim:
            # Grow the cached mask in place rather than reallocating.
            self._future_mask = torch.triu(
                utils.fill_with_neg_inf(self._future_mask.resize_(dim, dim)), 1
            )
        return self._future_mask[:dim, :dim]

    def maybe_layer_norm(self, layer_norm, x, before=False, after=False):
        """Apply *layer_norm* at the call site matching the pre/post-norm mode."""
        assert before ^ after
        # XOR: pre-norm fires on before=True, post-norm on after=True.
        if after ^ self.normalize_before:
            return layer_norm(x)
        else:
            return x

    def make_generation_fast_(self, need_attn=False, **kwargs):
        # Called by generation code to toggle returning cross-attn weights.
        self.need_attn = need_attn
| 587 |
+
|
| 588 |
+
def Embedding(num_embeddings, embedding_dim, padding_idx):
    """Build an ``nn.Embedding`` with transformer-style initialization.

    Weights are drawn from N(0, embed_dim**-0.5) and the padding row is
    zeroed out.
    """
    table = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
    init_std = embedding_dim ** -0.5
    nn.init.normal_(table.weight, mean=0, std=init_std)
    # Keep the padding embedding at exactly zero.
    nn.init.constant_(table.weight[padding_idx], 0)
    return table
|
| 593 |
+
|
| 594 |
+
|
| 595 |
+
def Linear(in_features, out_features, bias=True):
    """Build an ``nn.Linear`` with Xavier-uniform weights and zero bias."""
    layer = nn.Linear(in_features, out_features, bias)
    nn.init.xavier_uniform_(layer.weight)
    if bias:
        nn.init.constant_(layer.bias, 0.0)
    return layer
|
fairseq-0.10.2/fairseq/model_parallel/models/roberta/model.py
ADDED
|
@@ -0,0 +1,287 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
"""
|
| 6 |
+
RoBERTa: A Robustly Optimized BERT Pretraining Approach.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import logging
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
import torch.nn.functional as F
|
| 14 |
+
from fairseq import utils
|
| 15 |
+
from fairseq.model_parallel.modules import ModelParallelTransformerSentenceEncoder
|
| 16 |
+
from fairseq.models import FairseqEncoder, register_model, register_model_architecture
|
| 17 |
+
from fairseq.models.roberta import (
|
| 18 |
+
RobertaClassificationHead,
|
| 19 |
+
RobertaEncoder,
|
| 20 |
+
RobertaLMHead,
|
| 21 |
+
RobertaModel,
|
| 22 |
+
)
|
| 23 |
+
from fairseq.modules import LayerNorm, TransformerSentenceEncoder
|
| 24 |
+
from fairseq.modules.transformer_sentence_encoder import init_bert_params
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
try:
|
| 28 |
+
from fairseq.model_parallel.megatron.mpu import (
|
| 29 |
+
copy_to_model_parallel_region,
|
| 30 |
+
gather_from_model_parallel_region,
|
| 31 |
+
ColumnParallelLinear,
|
| 32 |
+
RowParallelLinear,
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
has_megatron_submodule = True
|
| 36 |
+
except (ImportError, ModuleNotFoundError):
|
| 37 |
+
has_megatron_submodule = False
|
| 38 |
+
|
| 39 |
+
logger = logging.getLogger(__name__)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
@register_model("model_parallel_roberta")
class ModelParallelRobertaModel(RobertaModel):
    """RoBERTa with Megatron-style (intra-layer) model parallelism.

    Reuses :class:`RobertaModel` but swaps in model-parallel encoder and
    classification-head implementations.
    """

    def __init__(self, args, encoder):
        super().__init__(args, encoder)

        # Replace any heads the parent registered so that all heads created
        # through this model are the ModelParallel* variants.
        self.classification_heads = nn.ModuleDict()

    @staticmethod
    def add_args(parser):
        # Inherit all of RoBERTa's command-line arguments unchanged.
        super(ModelParallelRobertaModel, ModelParallelRobertaModel).add_args(parser)

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance."""

        # make sure all arguments are present
        base_architecture(args)

        # Pad vocabularies so embedding rows split evenly across
        # model-parallel ranks (multiple of model_parallel_size * 8).
        task.source_dictionary.pad_to_multiple_(args.model_parallel_size * 8)
        task.target_dictionary.pad_to_multiple_(args.model_parallel_size * 8)

        if not hasattr(args, "max_positions"):
            args.max_positions = args.tokens_per_sample

        # Weight untying is incompatible with the parallel LM head, which
        # projects against the (sharded) embedding matrix.
        if getattr(args, "untie_weights_roberta", False):
            raise NotImplementedError(
                "--untie-weights-roberta is not supported in model parallel mode"
            )

        encoder = ModelParallelRobertaEncoder(args, task.source_dictionary)
        return cls(args, encoder)

    def forward(
        self,
        src_tokens,
        features_only=False,
        return_all_hiddens=False,
        classification_head_name=None,
        **kwargs
    ):
        # A classification head consumes raw features, so the LM output
        # layer must be skipped.
        if classification_head_name is not None:
            features_only = True

        x, extra = self.encoder(src_tokens, features_only, return_all_hiddens, **kwargs)

        if classification_head_name is not None:
            x = self.classification_heads[classification_head_name](x)
        return x, extra

    def register_classification_head(
        self, name, num_classes=None, inner_dim=None, **kwargs
    ):
        """Register a classification head."""
        if name in self.classification_heads:
            # Warn (but proceed) when re-registering with different sizes;
            # the existing head is replaced below either way.
            prev_num_classes = self.classification_heads[name].out_proj.out_features
            prev_inner_dim = self.classification_heads[name].dense.out_features
            if num_classes != prev_num_classes or inner_dim != prev_inner_dim:
                logger.warning(
                    're-registering head "{}" with num_classes {} (prev: {}) '
                    "and inner_dim {} (prev: {})".format(
                        name, num_classes, prev_num_classes, inner_dim, prev_inner_dim
                    )
                )
        self.classification_heads[name] = ModelParallelRobertaClassificationHead(
            self.args.encoder_embed_dim,
            inner_dim or self.args.encoder_embed_dim,
            num_classes,
            self.args.pooler_activation_fn,
            self.args.pooler_dropout,
        )
|
| 113 |
+
|
| 114 |
+
class ModelParallelRobertaLMHead(nn.Module):
    """Head for masked language modeling.

    The dense projection is column-parallel; the final vocabulary projection
    uses the (typically tied) embedding weight, with activations copied into
    and gathered out of the model-parallel region around it.
    """

    def __init__(self, embed_dim, output_dim, activation_fn, weight=None):
        super().__init__()
        # gather_output=True so activation + layer norm below operate on the
        # full (non-sharded) hidden dimension.
        self.dense = ColumnParallelLinear(embed_dim, embed_dim, gather_output=True)
        self.activation_fn = utils.get_activation_fn(activation_fn)
        self.layer_norm = LayerNorm(embed_dim)

        if weight is None:
            weight = nn.Linear(embed_dim, output_dim, bias=False).weight
        # Typically the token-embedding matrix, tying input/output embeddings.
        self.weight = weight
        self.bias = nn.Parameter(torch.zeros(output_dim))

    def forward(self, features, masked_tokens=None, **kwargs):
        # Only project the unmasked tokens while training,
        # saves both memory and computation
        if masked_tokens is not None:
            features = features[masked_tokens, :]

        x = self.dense(features)
        x = self.activation_fn(x)
        x = self.layer_norm(x)

        x = copy_to_model_parallel_region(x)
        # project back to size of vocabulary with bias
        x = F.linear(x, self.weight)
        x = gather_from_model_parallel_region(x).contiguous()
        # Bias is added after the gather so it is applied exactly once over
        # the full vocabulary dimension.
        x = x + self.bias
        return x
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
class ModelParallelRobertaClassificationHead(nn.Module):
    """Head for sentence-level classification tasks."""

    def __init__(
        self, input_dim, inner_dim, num_classes, activation_fn, pooler_dropout
    ):
        super().__init__()
        # Column-parallel pooler projection; gather_output=True yields the
        # full inner dimension for the activation and output projection.
        self.dense = ColumnParallelLinear(input_dim, inner_dim, gather_output=True)
        self.activation_fn = utils.get_activation_fn(activation_fn)
        self.dropout = nn.Dropout(p=pooler_dropout)
        self.out_proj = nn.Linear(inner_dim, num_classes)

    def forward(self, features, **kwargs):
        # Pool by taking the first position (<s>, equivalent to [CLS]).
        pooled = features[:, 0, :]
        pooled = self.dropout(pooled)
        pooled = self.activation_fn(self.dense(pooled))
        pooled = self.dropout(pooled)
        return self.out_proj(pooled)
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
class ModelParallelRobertaEncoder(FairseqEncoder):
    """RoBERTa encoder.

    Implements the :class:`~fairseq.models.FairseqDecoder` interface required
    by :class:`~fairseq.models.FairseqLanguageModel`.
    """

    def __init__(self, args, dictionary):
        super().__init__(dictionary)
        self.args = args

        # RoBERTa is a sentence encoder model, so users will intuitively trim
        # encoder layers. However, the implementation uses the fairseq decoder,
        # so we fix here.
        if args.encoder_layers_to_keep:
            args.encoder_layers = len(args.encoder_layers_to_keep.split(","))
            args.decoder_layers_to_keep = args.encoder_layers_to_keep
            args.encoder_layers_to_keep = None

        self.sentence_encoder = ModelParallelTransformerSentenceEncoder(
            padding_idx=dictionary.pad(),
            vocab_size=len(dictionary),
            num_encoder_layers=args.encoder_layers,
            embedding_dim=args.encoder_embed_dim,
            ffn_embedding_dim=args.encoder_ffn_embed_dim,
            num_attention_heads=args.encoder_attention_heads,
            dropout=args.dropout,
            attention_dropout=args.attention_dropout,
            activation_dropout=args.activation_dropout,
            layerdrop=args.encoder_layerdrop,
            max_seq_len=args.max_positions,
            num_segments=0,
            encoder_normalize_before=False,
            apply_bert_init=False,
            activation_fn=args.activation_fn,
        )
        # The LM head ties its output projection to the token embeddings.
        self.lm_head = ModelParallelRobertaLMHead(
            embed_dim=args.encoder_embed_dim,
            output_dim=len(dictionary),
            activation_fn=args.activation_fn,
            weight=self.sentence_encoder.embed_tokens.weight,
        )

    def forward(
        self,
        src_tokens,
        features_only=False,
        return_all_hiddens=False,
        masked_tokens=None,
        **unused
    ):
        """
        Args:
            src_tokens (LongTensor): input tokens of shape `(batch, src_len)`
            features_only (bool, optional): skip LM head and just return
                features. If True, the output will be of shape
                `(batch, src_len, embed_dim)`.
            return_all_hiddens (bool, optional): also return all of the
                intermediate hidden states (default: False).
            masked_tokens (optional): mask selecting positions to score with
                the LM head (forwarded to :meth:`output_layer`).

        Returns:
            tuple:
                - the LM output of shape `(batch, src_len, vocab)`
                - a dictionary of additional data, where 'inner_states'
                  is a list of hidden states. Note that the hidden
                  states have shape `(src_len, batch, embed_dim)`.
        """
        x, extra = self.extract_features(
            src_tokens, return_all_hiddens=return_all_hiddens
        )
        if not features_only:
            x = self.output_layer(x, masked_tokens=masked_tokens)
        return x, extra

    def extract_features(self, src_tokens, return_all_hiddens=False, **unused):
        # Run the sentence encoder; only keep intermediate states when the
        # caller asked for them.
        inner_states, _ = self.sentence_encoder(
            src_tokens,
            last_state_only=not return_all_hiddens,
        )
        features = inner_states[-1].transpose(0, 1)  # T x B x C -> B x T x C
        return features, {"inner_states": inner_states if return_all_hiddens else None}

    def output_layer(self, features, masked_tokens=None, **unused):
        # Project features to vocabulary logits with the (tied) LM head.
        return self.lm_head(features, masked_tokens)

    def max_positions(self):
        """Maximum output length supported by the encoder."""
        return self.args.max_positions
|
| 257 |
+
|
| 258 |
+
@register_model_architecture("model_parallel_roberta", "model_parallel_roberta")
def base_architecture(args):
    """Fill in RoBERTa-base-sized defaults for any hyperparameter not set on *args*."""
    defaults = {
        "encoder_layers": 12,
        "encoder_embed_dim": 768,
        "encoder_ffn_embed_dim": 3072,
        "encoder_attention_heads": 12,
        "activation_fn": "gelu",
        "pooler_activation_fn": "tanh",
        "dropout": 0.1,
        "attention_dropout": 0.1,
        "activation_dropout": 0.0,
        "pooler_dropout": 0.0,
        "encoder_layers_to_keep": None,
        "encoder_layerdrop": 0.0,
    }
    # Equivalent to `args.x = getattr(args, "x", default)` for each entry:
    # existing values (including explicit None) are kept as-is.
    for name, default in defaults.items():
        setattr(args, name, getattr(args, name, default))
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
@register_model_architecture("model_parallel_roberta", "model_parallel_roberta_base")
def roberta_base_architecture(args):
    """Alias architecture: RoBERTa-base is exactly the default configuration."""
    base_architecture(args)
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
@register_model_architecture("model_parallel_roberta", "model_parallel_roberta_large")
def roberta_large_architecture(args):
    """RoBERTa-large: 24 layers, 1024 hidden, 4096 FFN, 16 attention heads."""
    overrides = {
        "encoder_layers": 24,
        "encoder_embed_dim": 1024,
        "encoder_ffn_embed_dim": 4096,
        "encoder_attention_heads": 16,
    }
    # Only fill attributes that are not already set, matching
    # `getattr(args, name, default)` semantics.
    for name, default in overrides.items():
        setattr(args, name, getattr(args, name, default))
    base_architecture(args)
|
fairseq-0.10.2/fairseq/model_parallel/modules/__pycache__/transformer_sentence_encoder_layer.cpython-310.pyc
ADDED
|
Binary file (2.46 kB). View file
|
|
|
fairseq-0.10.2/fairseq/model_parallel/modules/transformer_sentence_encoder_layer.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
from fairseq import utils
|
| 9 |
+
from fairseq.model_parallel.modules import ModelParallelMultiheadAttention
|
| 10 |
+
from fairseq.modules import TransformerSentenceEncoderLayer
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
try:
|
| 14 |
+
from fairseq.model_parallel.megatron.mpu import (
|
| 15 |
+
ColumnParallelLinear,
|
| 16 |
+
RowParallelLinear,
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
has_megatron_submodule = True
|
| 20 |
+
except (ImportError, ModuleNotFoundError):
|
| 21 |
+
has_megatron_submodule = False
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class ModelParallelTransformerSentenceEncoderLayer(TransformerSentenceEncoderLayer):
    """
    Implements a Model Parallel Transformer Encoder Layer used in
    BERT/XLM style pre-trained models.
    """

    def build_fc1(self, input_dim, output_dim, **unused):
        # Column-parallel first FFN projection; output stays sharded
        # (gather_output=False) so fc2 can consume it directly.
        return ColumnParallelLinear(input_dim, output_dim, gather_output=False)

    def build_fc2(self, input_dim, output_dim, **unused):
        # Row-parallel second projection; input_is_parallel=True matches
        # fc1's non-gathered output.
        return RowParallelLinear(input_dim, output_dim, input_is_parallel=True)

    def build_self_attention(
        self,
        embed_dim,
        num_attention_heads,
        dropout,
        **kwargs,
    ):
        # Attention heads are split across model-parallel ranks.
        return ModelParallelMultiheadAttention(
            embed_dim, num_attention_heads, dropout=dropout, self_attention=True
        )

    def forward(
        self,
        x: torch.Tensor,
        self_attn_mask: torch.Tensor = None,
        self_attn_padding_mask: torch.Tensor = None,
    ):
        """
        LayerNorm is applied either before or after the self-attention/ffn
        modules similar to the original Transformer implementation.

        NOTE(review): this override always applies pre-norm (layer norm
        before each sub-block), with no normalize_before switch — confirm
        this is the intended behavior for all configurations.
        """
        # --- self-attention sub-block (pre-norm) ---
        residual = x
        x = self.self_attn_layer_norm(x)
        x, attn = self.self_attn(
            query=x,
            key=x,
            value=x,
            key_padding_mask=self_attn_padding_mask,
            need_weights=False,
            attn_mask=self_attn_mask,
        )
        x = self.dropout_module(x)
        x = residual + x

        # --- feed-forward sub-block (pre-norm) ---
        residual = x
        x = self.final_layer_norm(x)
        x = self.activation_fn(self.fc1(x))
        x = self.activation_dropout_module(x)
        x = self.fc2(x)
        x = self.dropout_module(x)
        x = residual + x
        # attn is always None here since need_weights=False above.
        return x, None
|
fairseq-0.10.2/fairseq/modules/__init__.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
"""isort:skip_file"""
|
| 6 |
+
|
| 7 |
+
from .adaptive_input import AdaptiveInput
|
| 8 |
+
from .adaptive_softmax import AdaptiveSoftmax
|
| 9 |
+
from .beamable_mm import BeamableMM
|
| 10 |
+
from .character_token_embedder import CharacterTokenEmbedder
|
| 11 |
+
from .conv_tbc import ConvTBC
|
| 12 |
+
from .cross_entropy import cross_entropy
|
| 13 |
+
from .downsampled_multihead_attention import DownsampledMultiHeadAttention
|
| 14 |
+
from .dynamic_convolution import DynamicConv, DynamicConv1dTBC
|
| 15 |
+
from .dynamic_crf_layer import DynamicCRF
|
| 16 |
+
from .fairseq_dropout import FairseqDropout
|
| 17 |
+
from .fp32_group_norm import Fp32GroupNorm
|
| 18 |
+
from .gelu import gelu, gelu_accurate
|
| 19 |
+
from .grad_multiply import GradMultiply
|
| 20 |
+
from .gumbel_vector_quantizer import GumbelVectorQuantizer
|
| 21 |
+
from .kmeans_vector_quantizer import KmeansVectorQuantizer
|
| 22 |
+
from .layer_drop import LayerDropModuleList
|
| 23 |
+
from .layer_norm import Fp32LayerNorm, LayerNorm
|
| 24 |
+
from .learned_positional_embedding import LearnedPositionalEmbedding
|
| 25 |
+
from .lightweight_convolution import LightweightConv, LightweightConv1dTBC
|
| 26 |
+
from .linearized_convolution import LinearizedConvolution
|
| 27 |
+
from .multihead_attention import MultiheadAttention
|
| 28 |
+
from .positional_embedding import PositionalEmbedding
|
| 29 |
+
from .same_pad import SamePad
|
| 30 |
+
from .scalar_bias import ScalarBias
|
| 31 |
+
from .sinusoidal_positional_embedding import SinusoidalPositionalEmbedding
|
| 32 |
+
from .transformer_sentence_encoder_layer import TransformerSentenceEncoderLayer
|
| 33 |
+
from .transformer_sentence_encoder import TransformerSentenceEncoder
|
| 34 |
+
from .transpose_last import TransposeLast
|
| 35 |
+
from .unfold import unfold1d
|
| 36 |
+
from .transformer_layer import TransformerDecoderLayer, TransformerEncoderLayer
|
| 37 |
+
from .vggblock import VGGBlock
|
| 38 |
+
|
| 39 |
+
__all__ = [
|
| 40 |
+
"AdaptiveInput",
|
| 41 |
+
"AdaptiveSoftmax",
|
| 42 |
+
"BeamableMM",
|
| 43 |
+
"CharacterTokenEmbedder",
|
| 44 |
+
"ConvTBC",
|
| 45 |
+
"cross_entropy",
|
| 46 |
+
"DownsampledMultiHeadAttention",
|
| 47 |
+
"DynamicConv1dTBC",
|
| 48 |
+
"DynamicConv",
|
| 49 |
+
"DynamicCRF",
|
| 50 |
+
"FairseqDropout",
|
| 51 |
+
"Fp32GroupNorm",
|
| 52 |
+
"Fp32LayerNorm",
|
| 53 |
+
"gelu",
|
| 54 |
+
"gelu_accurate",
|
| 55 |
+
"GradMultiply",
|
| 56 |
+
"GumbelVectorQuantizer",
|
| 57 |
+
"KmeansVectorQuantizer",
|
| 58 |
+
"LayerDropModuleList",
|
| 59 |
+
"LayerNorm",
|
| 60 |
+
"LearnedPositionalEmbedding",
|
| 61 |
+
"LightweightConv1dTBC",
|
| 62 |
+
"LightweightConv",
|
| 63 |
+
"LinearizedConvolution",
|
| 64 |
+
"MultiheadAttention",
|
| 65 |
+
"PositionalEmbedding",
|
| 66 |
+
"SamePad",
|
| 67 |
+
"ScalarBias",
|
| 68 |
+
"SinusoidalPositionalEmbedding",
|
| 69 |
+
"TransformerSentenceEncoderLayer",
|
| 70 |
+
"TransformerSentenceEncoder",
|
| 71 |
+
"TransformerDecoderLayer",
|
| 72 |
+
"TransformerEncoderLayer",
|
| 73 |
+
"TransposeLast",
|
| 74 |
+
"VGGBlock",
|
| 75 |
+
"unfold1d",
|
| 76 |
+
]
|
fairseq-0.10.2/fairseq/modules/__pycache__/adaptive_softmax.cpython-310.pyc
ADDED
|
Binary file (6.81 kB). View file
|
|
|
fairseq-0.10.2/fairseq/modules/__pycache__/dynamic_convolution.cpython-310.pyc
ADDED
|
Binary file (8.24 kB). View file
|
|
|
fairseq-0.10.2/fairseq/modules/__pycache__/grad_multiply.cpython-310.pyc
ADDED
|
Binary file (703 Bytes). View file
|
|
|
fairseq-0.10.2/fairseq/modules/__pycache__/gumbel_vector_quantizer.cpython-310.pyc
ADDED
|
Binary file (5.99 kB). View file
|
|
|
fairseq-0.10.2/fairseq/modules/__pycache__/kmeans_vector_quantizer.cpython-310.pyc
ADDED
|
Binary file (3.86 kB). View file
|
|
|
fairseq-0.10.2/fairseq/modules/__pycache__/learned_positional_embedding.cpython-310.pyc
ADDED
|
Binary file (2.01 kB). View file
|
|
|
fairseq-0.10.2/fairseq/modules/__pycache__/lightweight_convolution.cpython-310.pyc
ADDED
|
Binary file (8.87 kB). View file
|
|
|
fairseq-0.10.2/fairseq/modules/__pycache__/transformer_sentence_encoder_layer.cpython-310.pyc
ADDED
|
Binary file (3.29 kB). View file
|
|
|
fairseq-0.10.2/fairseq/modules/__pycache__/vggblock.cpython-310.pyc
ADDED
|
Binary file (3.49 kB). View file
|
|
|
fairseq-0.10.2/fairseq/modules/adaptive_softmax.py
ADDED
|
@@ -0,0 +1,268 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import functools
|
| 7 |
+
import operator
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
from fairseq.modules.fairseq_dropout import FairseqDropout
|
| 12 |
+
from fairseq.modules.quant_noise import quant_noise
|
| 13 |
+
from torch import nn
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class TiedLinear(nn.Module):
    """Linear projection that reuses an externally owned weight tensor.

    Used to tie an output projection to embedding weights; ``transpose``
    selects whether the shared weight is applied as-is or transposed.
    """

    def __init__(self, weight, transpose):
        super().__init__()
        self.weight = weight
        self.transpose = transpose

    def forward(self, input):
        weight = self.weight
        if self.transpose:
            weight = weight.t()
        return F.linear(input, weight)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class TiedHeadModule(nn.Module):
    """Adaptive-softmax head whose word projection shares embedding weights.

    For each input vector this produces scores for the ``num_words`` head
    words (through the tied embedding matrix) concatenated with scores for
    the ``num_classes`` tail clusters.
    """

    def __init__(self, weights, input_dim, num_classes, q_noise, qn_block_size):
        super().__init__()
        tied_emb, _ = weights
        self.num_words, emb_dim = tied_emb.size()

        self.word_proj = quant_noise(
            TiedLinear(tied_emb, transpose=False), q_noise, qn_block_size
        )
        if input_dim != emb_dim:
            # Bridge the dimensionality gap before applying the tied weights.
            self.word_proj = nn.Sequential(
                quant_noise(
                    nn.Linear(input_dim, emb_dim, bias=False), q_noise, qn_block_size
                ),
                self.word_proj,
            )

        self.class_proj = quant_noise(
            nn.Linear(input_dim, num_classes, bias=False), q_noise, qn_block_size
        )
        self.out_dim = self.num_words + num_classes

        # Dummy buffer so new output tensors follow the module's device/dtype.
        self.register_buffer("_float_tensor", torch.FloatTensor(1))

    def forward(self, input):
        # Flatten all leading dimensions into a single batch dimension.
        flat_sz = functools.reduce(operator.mul, input.shape[:-1], 1)
        flat = input.view(flat_sz, -1)
        out = self._float_tensor.new(flat_sz, self.out_dim)
        out[:, : self.num_words] = self.word_proj(flat)
        out[:, self.num_words :] = self.class_proj(flat)
        return out
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class AdaptiveSoftmax(nn.Module):
    """
    This is an implementation of the efficient softmax approximation for
    graphical processing units (GPU), described in the paper "Efficient softmax
    approximation for GPUs" (http://arxiv.org/abs/1609.04309).

    The vocabulary is partitioned by ``cutoff`` into a frequent "head" band
    plus progressively smaller tail clusters; the head output also carries one
    logit per tail cluster (the cluster prior).
    """

    def __init__(
        self,
        vocab_size,
        input_dim,
        cutoff,
        dropout,
        factor=4.0,
        adaptive_inputs=None,
        tie_proj=False,
        q_noise=0,
        qn_block_size=8,
    ):
        super().__init__()

        # The last cutoff must equal the vocab size; append it if missing.
        if vocab_size > cutoff[-1]:
            cutoff = cutoff + [vocab_size]
        else:
            assert (
                vocab_size == cutoff[-1]
            ), "cannot specify cutoff larger than vocab size"

        # Head emits cutoff[0] word logits plus one logit per tail cluster.
        output_dim = cutoff[0] + len(cutoff) - 1

        self.vocab_size = vocab_size
        self.cutoff = cutoff
        self.dropout_module = FairseqDropout(
            dropout, module_name=self.__class__.__name__
        )
        self.input_dim = input_dim
        self.factor = factor
        self.q_noise = q_noise
        self.qn_block_size = qn_block_size

        self.lsm = nn.LogSoftmax(dim=1)

        # When adaptive input embeddings are provided, tie the head to them.
        if adaptive_inputs is not None:
            self.head = TiedHeadModule(
                adaptive_inputs.weights_for_band(0),
                input_dim,
                len(cutoff) - 1,
                self.q_noise,
                self.qn_block_size,
            )
        else:
            self.head = quant_noise(
                nn.Linear(input_dim, output_dim, bias=False),
                self.q_noise,
                self.qn_block_size,
            )

        self._make_tail(adaptive_inputs, tie_proj)

        def init_weights(m):
            # Tied modules keep their shared (already initialized) weights.
            if (
                hasattr(m, "weight")
                and not isinstance(m, TiedLinear)
                and not isinstance(m, TiedHeadModule)
            ):
                nn.init.xavier_uniform_(m.weight)

        self.apply(init_weights)

        self.register_buffer("version", torch.LongTensor([1]))

    def _make_tail(self, adaptive_inputs=None, tie_proj=False):
        """Build one (down-projection, dropout, output) stack per tail cluster."""
        self.tail = nn.ModuleList()
        for i in range(len(self.cutoff) - 1):
            # Each deeper cluster uses a smaller hidden dimension.
            dim = int(self.input_dim // self.factor ** (i + 1))

            tied_emb, tied_proj = (
                adaptive_inputs.weights_for_band(i + 1)
                if adaptive_inputs is not None
                else (None, None)
            )

            if tied_proj is not None:
                if tie_proj:
                    # Share the adaptive-input projection weights (transposed).
                    proj = quant_noise(
                        TiedLinear(tied_proj, transpose=True),
                        self.q_noise,
                        self.qn_block_size,
                    )
                else:
                    # Same shape as the tied projection, but fresh weights.
                    proj = quant_noise(
                        nn.Linear(tied_proj.size(0), tied_proj.size(1), bias=False),
                        self.q_noise,
                        self.qn_block_size,
                    )
            else:
                proj = quant_noise(
                    nn.Linear(self.input_dim, dim, bias=False),
                    self.q_noise,
                    self.qn_block_size,
                )

            if tied_emb is None:
                out_proj = nn.Linear(
                    dim, self.cutoff[i + 1] - self.cutoff[i], bias=False
                )
            else:
                out_proj = TiedLinear(tied_emb, transpose=False)

            m = nn.Sequential(
                proj,
                nn.Dropout(self.dropout_module.p),
                quant_noise(out_proj, self.q_noise, self.qn_block_size),
            )

            self.tail.append(m)

    def upgrade_state_dict_named(self, state_dict, name):
        # Checkpoints saved before the "version" buffer existed cannot be loaded.
        version_name = name + ".version"
        if version_name not in state_dict:
            raise Exception("This version of the model is no longer supported")

    def adapt_target(self, target):
        """
        In order to be efficient, the AdaptiveSoftMax does not compute the
        scores for all the word of the vocabulary for all the examples. It is
        thus necessary to call the method adapt_target of the AdaptiveSoftMax
        layer inside each forward pass.

        Returns ``(new_target, target_idxs)`` where ``new_target[0]`` is the
        head target (tail words remapped to their cluster id) and, per
        cluster, ``target_idxs[i]`` holds the flat positions falling in that
        cluster (or None) with ``new_target[i + 1]`` the within-cluster ids.
        """

        target = target.view(-1)
        new_target = [target.clone()]
        target_idxs = []

        for i in range(len(self.cutoff) - 1):
            # Positions whose target word lies in tail cluster i.
            mask = target.ge(self.cutoff[i]).mul(target.lt(self.cutoff[i + 1]))
            # In the head, the whole cluster is represented by one symbol.
            new_target[0][mask] = self.cutoff[0] + i

            if mask.any():
                target_idxs.append(mask.nonzero(as_tuple=False).squeeze(1))
                new_target.append(target[mask].add(-self.cutoff[i]))
            else:
                target_idxs.append(None)
                new_target.append(None)

        return new_target, target_idxs

    def forward(self, input, target):
        """
        Args:
            input: (b x t x d)
            target: (b x t)
        Returns:
            2 lists: output for each cutoff section and new targets by cut off
        """

        input = input.contiguous().view(-1, input.size(-1))
        input = self.dropout_module(input)

        new_target, target_idxs = self.adapt_target(target)
        output = [self.head(input)]

        # Only run each tail cluster on the positions that actually need it.
        for i in range(len(target_idxs)):
            if target_idxs[i] is not None:
                output.append(self.tail[i](input.index_select(0, target_idxs[i])))
            else:
                output.append(None)

        return output, new_target

    def get_log_prob(self, input, target):
        """
        Computes the log probabilities for all the words of the vocabulary,
        given a 2D tensor of hidden vectors.
        """

        bsz, length, dim = input.size()
        input = input.contiguous().view(-1, dim)

        if target is not None:
            _, target_idxs = self.adapt_target(target)
        else:
            target_idxs = None

        head_y = self.head(input)
        log_probs = head_y.new_zeros(input.size(0), self.vocab_size)

        head_sz = self.cutoff[0] + len(self.tail)
        log_probs[:, :head_sz] = self.lsm(head_y)
        # Per-cluster prior log-probabilities, taken from the head output.
        tail_priors = log_probs[:, self.cutoff[0] : head_sz].clone()

        for i in range(len(self.tail)):
            start = self.cutoff[i]
            end = self.cutoff[i + 1]

            if target_idxs is None:
                # No target: score every position against every cluster.
                tail_out = log_probs[:, start:end]
                tail_out.copy_(self.tail[i](input))
                log_probs[:, start:end] = self.lsm(tail_out).add_(
                    tail_priors[:, i, None]
                )
            elif target_idxs[i] is not None:
                # Score only the positions whose target is in this cluster.
                idxs = target_idxs[i]
                tail_out = log_probs[idxs, start:end]
                tail_out.copy_(self.tail[i](input[idxs]))
                log_probs[idxs, start:end] = self.lsm(tail_out).add_(
                    tail_priors[idxs, i, None]
                )

        log_probs = log_probs.view(bsz, length, -1)
        return log_probs
|
fairseq-0.10.2/fairseq/modules/beamable_mm.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class BeamableMM(nn.Module):
    """This module provides an optimized MM for beam decoding with attention.

    It leverages the fact that the source-side of the input is replicated
    ``beam`` times while the target-side has width one, replacing the inputs
    {(bsz x 1 x nhu), (bsz x sz2 x nhu)} with the smaller
    {(bsz/beam x beam x nhu), (bsz/beam x sz2 x nhu)}.
    """

    def __init__(self, beam_size=None):
        super(BeamableMM, self).__init__()
        self.beam_size = beam_size

    def forward(self, input1, input2):
        fast_path = (
            not self.training  # test mode only
            and self.beam_size is not None  # beam size must be configured
            and input1.dim() == 3  # only support batched input
            and input1.size(1) == 1  # single time step update
        )
        if not fast_path:
            return input1.bmm(input2)

        bsz, beam = input1.size(0), self.beam_size

        # bsz x 1 x nhu --> bsz/beam x beam x nhu
        input1 = input1[:, 0, :].unfold(0, beam, beam).transpose(2, 1)

        # bsz x sz2 x nhu --> bsz/beam x sz2 x nhu (beams share their source)
        input2 = input2.unfold(0, beam, beam)[:, :, :, 0]

        if input1.size(0) == 1:
            # A single group: the non-batched mm avoids bmm overhead.
            output = torch.mm(input1[0, :, :], input2[0, :, :])
        else:
            output = input1.bmm(input2)
        return output.view(bsz, 1, -1)

    def set_beam_size(self, beam_size):
        self.beam_size = beam_size
|
fairseq-0.10.2/fairseq/modules/character_token_embedder.py
ADDED
|
@@ -0,0 +1,214 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import logging
|
| 7 |
+
from typing import List, Tuple
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
from fairseq.data import Dictionary
|
| 12 |
+
from torch import nn
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
CHAR_PAD_IDX = 0
|
| 16 |
+
CHAR_EOS_IDX = 257
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
logger = logging.getLogger(__name__)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class CharacterTokenEmbedder(torch.nn.Module):
    """Builds word embeddings from character (byte) convolutions.

    Each word is represented by its byte sequence (shifted by +1 so 0 can be
    the pad byte), embedded per character, run through a bank of 1D
    convolutions with max-over-time pooling, optionally a Highway network,
    and projected to ``word_embed_dim``. Special symbols (eos/unk/pad) get
    dedicated learned embeddings.
    """

    def __init__(
        self,
        vocab: Dictionary,
        filters: List[Tuple[int, int]],  # list of (kernel_width, out_channels)
        char_embed_dim: int,
        word_embed_dim: int,
        highway_layers: int,
        max_char_len: int = 50,
        char_inputs: bool = False,  # inputs are raw char indices, not word ids
    ):
        super(CharacterTokenEmbedder, self).__init__()

        self.onnx_trace = False
        self.embedding_dim = word_embed_dim
        self.max_char_len = max_char_len
        # 257 rows: pad (0) + 256 possible byte values shifted by one.
        self.char_embeddings = nn.Embedding(257, char_embed_dim, padding_idx=0)
        # Two learned symbol embeddings: row 0 = eos, row 1 = unk.
        self.symbol_embeddings = nn.Parameter(torch.FloatTensor(2, word_embed_dim))
        self.eos_idx, self.unk_idx = 0, 1
        self.char_inputs = char_inputs

        self.convolutions = nn.ModuleList()
        for width, out_c in filters:
            self.convolutions.append(
                nn.Conv1d(char_embed_dim, out_c, kernel_size=width)
            )

        # Pooled conv outputs are concatenated along the channel dimension.
        last_dim = sum(f[1] for f in filters)

        self.highway = Highway(last_dim, highway_layers) if highway_layers > 0 else None

        self.projection = nn.Linear(last_dim, word_embed_dim)

        assert (
            vocab is not None or char_inputs
        ), "vocab must be set if not using char inputs"
        self.vocab = None
        if vocab is not None:
            self.set_vocab(vocab, max_char_len)

        self.reset_parameters()

    def prepare_for_onnx_export_(self):
        self.onnx_trace = True

    def set_vocab(self, vocab, max_char_len):
        """Precompute a (vocab_size x max_char_len) word->char-index table."""
        word_to_char = torch.LongTensor(len(vocab), max_char_len)

        truncated = 0
        for i in range(len(vocab)):
            if i < vocab.nspecial:
                # Special symbols get an all-pad character row.
                char_idxs = [0] * max_char_len
            else:
                chars = vocab[i].encode()
                # +1 for padding
                char_idxs = [c + 1 for c in chars] + [0] * (max_char_len - len(chars))
            if len(char_idxs) > max_char_len:
                truncated += 1
                char_idxs = char_idxs[:max_char_len]
            word_to_char[i] = torch.LongTensor(char_idxs)

        if truncated > 0:
            logger.info(
                "truncated {} words longer than {} characters".format(
                    truncated, max_char_len
                )
            )

        self.vocab = vocab
        self.word_to_char = word_to_char

    @property
    def padding_idx(self):
        return Dictionary().pad() if self.vocab is None else self.vocab.pad()

    def reset_parameters(self):
        nn.init.xavier_normal_(self.char_embeddings.weight)
        nn.init.xavier_normal_(self.symbol_embeddings)
        nn.init.xavier_uniform_(self.projection.weight)

        nn.init.constant_(
            self.char_embeddings.weight[self.char_embeddings.padding_idx], 0.0
        )
        nn.init.constant_(self.projection.bias, 0.0)

    def forward(
        self,
        input: torch.Tensor,  # word ids, or char indices when char_inputs is set
    ):
        if self.char_inputs:
            chars = input.view(-1, self.max_char_len)
            # First char being pad/eos marks the whole position as pad/eos.
            pads = chars[:, 0].eq(CHAR_PAD_IDX)
            eos = chars[:, 0].eq(CHAR_EOS_IDX)
            if eos.any():
                if self.onnx_trace:
                    chars = torch.where(eos.unsqueeze(1), chars.new_zeros(1), chars)
                else:
                    # Zero out eos rows so 257 never reaches the embedding.
                    chars[eos] = 0

            unk = None
        else:
            flat_words = input.view(-1)
            chars = self.word_to_char[flat_words.type_as(self.word_to_char)].type_as(
                input
            )
            pads = flat_words.eq(self.vocab.pad())
            eos = flat_words.eq(self.vocab.eos())
            unk = flat_words.eq(self.vocab.unk())

        word_embs = self._convolve(chars)
        # ONNX tracing cannot use boolean-mask assignment, hence torch.where.
        if self.onnx_trace:
            if pads.any():
                word_embs = torch.where(
                    pads.unsqueeze(1), word_embs.new_zeros(1), word_embs
                )
            if eos.any():
                word_embs = torch.where(
                    eos.unsqueeze(1), self.symbol_embeddings[self.eos_idx], word_embs
                )
            if unk is not None and unk.any():
                word_embs = torch.where(
                    unk.unsqueeze(1), self.symbol_embeddings[self.unk_idx], word_embs
                )
        else:
            if pads.any():
                word_embs[pads] = 0
            if eos.any():
                word_embs[eos] = self.symbol_embeddings[self.eos_idx]
            if unk is not None and unk.any():
                word_embs[unk] = self.symbol_embeddings[self.unk_idx]

        return word_embs.view(input.size()[:2] + (-1,))

    def _convolve(
        self,
        char_idxs: torch.Tensor,
    ):
        """Embed chars and apply the conv bank with max-over-time pooling."""
        char_embs = self.char_embeddings(char_idxs)
        char_embs = char_embs.transpose(1, 2)  # BTC -> BCT

        conv_result = []

        for conv in self.convolutions:
            x = conv(char_embs)
            x, _ = torch.max(x, -1)
            x = F.relu(x)
            conv_result.append(x)

        x = torch.cat(conv_result, dim=-1)

        if self.highway is not None:
            x = self.highway(x)
        x = self.projection(x)

        return x
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
class Highway(torch.nn.Module):
    """
    A `Highway layer <https://arxiv.org/abs/1505.00387>`_.
    Adopted from the AllenNLP implementation.

    Each layer computes ``g * x + (1 - g) * relu(W1 x + b1)`` where the
    carry gate ``g = sigmoid(W2 x + b2)``; both halves come from a single
    Linear producing ``2 * input_dim`` outputs.
    """

    def __init__(self, input_dim: int, num_layers: int = 1):
        super(Highway, self).__init__()
        self.input_dim = input_dim
        self.layers = nn.ModuleList(
            [nn.Linear(input_dim, input_dim * 2) for _ in range(num_layers)]
        )
        self.activation = nn.ReLU()
        self.reset_parameters()

    def reset_parameters(self):
        for layer in self.layers:
            # Bias the gate (second half of the bias vector) positive so the
            # layer initially just carries its input forward, per AllenNLP.
            nn.init.constant_(layer.bias[self.input_dim :], 1)
            nn.init.constant_(layer.bias[: self.input_dim], 0)
            nn.init.xavier_normal_(layer.weight)

    def forward(self, x: torch.Tensor):
        for layer in self.layers:
            transformed, gate = layer(x).chunk(2, dim=-1)
            transformed = self.activation(transformed)
            gate = torch.sigmoid(gate)
            x = gate * x + (gate.new_tensor([1]) - gate) * transformed
        return x
|
fairseq-0.10.2/fairseq/modules/cross_entropy.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import logging
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
logger = logging.getLogger(__name__)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def _cross_entropy_pytorch(logits, target, ignore_index=None, reduction="mean"):
    """Reference cross-entropy: fp32 log-softmax followed by NLL loss."""
    lprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
    return F.nll_loss(
        lprobs,
        target,
        ignore_index=ignore_index,
        reduction=reduction,
    )


try:
    import xentropy_cuda
    from apex.contrib import xentropy

    logger.info("using fused cross entropy")

    def cross_entropy(logits, target, ignore_index=-100, reduction="mean"):
        """Fused (apex) cross entropy on GPU, PyTorch fallback on CPU."""
        if logits.device == torch.device("cpu"):
            return _cross_entropy_pytorch(logits, target, ignore_index, reduction)
        half_to_float = logits.dtype == torch.half
        losses = xentropy.SoftmaxCrossEntropyLoss.apply(
            logits,
            target,
            0.0,
            ignore_index,
            half_to_float,
        )
        if reduction == "sum":
            return losses.sum()
        if reduction == "mean":
            if ignore_index >= 0:
                # Average over non-ignored targets only.
                return losses.sum() / target.ne(ignore_index).sum()
            return losses.mean()
        if reduction == "none":
            return losses
        raise NotImplementedError


except ImportError:

    def cross_entropy(logits, target, ignore_index=-100, reduction="mean"):
        """Plain PyTorch cross entropy (apex fused kernel unavailable)."""
        return _cross_entropy_pytorch(logits, target, ignore_index, reduction)
|
fairseq-0.10.2/fairseq/modules/fp32_group_norm.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
"""
|
| 6 |
+
Layer norm done in fp32 (for fp16 training)
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class Fp32GroupNorm(nn.GroupNorm):
    """GroupNorm computed in fp32 regardless of input dtype (fp16-safe).

    Inputs (and affine parameters, when present) are upcast to float for the
    normalization and the result is cast back to the input dtype.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, input):
        weight = None if self.weight is None else self.weight.float()
        bias = None if self.bias is None else self.bias.float()
        normed = F.group_norm(
            input.float(), self.num_groups, weight, bias, self.eps
        )
        return normed.type_as(input)
|
fairseq-0.10.2/fairseq/modules/lightconv_layer/lightconv_cuda_kernel.cu
ADDED
|
@@ -0,0 +1,375 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/**
|
| 2 |
+
* Copyright (c) Facebook, Inc. and its affiliates.
|
| 3 |
+
*
|
| 4 |
+
* This source code is licensed under the MIT license found in the
|
| 5 |
+
* LICENSE file in the root directory of this source tree.
|
| 6 |
+
*/
|
| 7 |
+
|
| 8 |
+
#include "lightconv_cuda.cuh"
|
| 9 |
+
#include "lightconv_cuda_forward.cu"
|
| 10 |
+
#include "lightconv_cuda_backward.cu"
|
| 11 |
+
#include "../cuda_utils.cu"
|
| 12 |
+
|
| 13 |
+
// Forward pass of the lightweight (depthwise, weight-shared) 1D convolution.
// Template params: FS = filter size, SB = tile width (== blockDim.x),
// padding_l = left padding. Grid is (minibatch, numFeatures); each block
// convolves one (batch, feature) row, processing the sequence in SB-wide
// tiles staged through shared memory.
template<int FS, int SB, int padding_l, typename scalar_t>
__global__
void lightconv_forward_kernel(const scalar_t* input,
                              const scalar_t* filters,
                              int minibatch, int sequenceLength,
                              int numFeatures, int numFiltersInBlock,
                              scalar_t* output) {

  const int tid = threadIdx.x;
  const int batchIdx = blockIdx.x;
  const int featureIdx = blockIdx.y;
  // Consecutive features share one filter (numFiltersInBlock-way sharing).
  const int filterIdx = featureIdx / numFiltersInBlock;

  const int IOOffset = numFeatures * sequenceLength * batchIdx + featureIdx * sequenceLength;
  const scalar_t* inputFeature = &input[IOOffset];
  scalar_t* outputFeature = &output[IOOffset];
  const scalar_t* inputFilter = &filters[filterIdx * FS];

  assert(blockDim.x == SB);

  // Stage the filter taps into registers once per block.
  scalar_t filter[FS];
  #pragma unroll
  for (int i = 0; i < FS; ++i) {
    filter[i] = inputFilter[i];
  }

  // Tile buffer plus FS halo elements for the convolution window.
  __shared__ scalar_t temp[SB + FS];
  zeroSharedMem<FS, SB, padding_l>(temp);

  const int numIterations = divUp<int, int>(sequenceLength, SB);

  for (int i = 0; i < numIterations; ++i) {
    // Read input into shared memory
    const int inputOffset = i * SB;

    load_input_to_shared<FS, SB, padding_l>(inputFeature, inputOffset, sequenceLength,
                                            i, numIterations, (numIterations == 1), temp);

    __syncthreads();

    // Each thread computes one output element as a dot product over the taps.
    scalar_t out = 0;
    #pragma unroll
    for (int j = 0; j < FS; ++j) {
      out += filter[j] * temp[tid + j];
    }

    // Write output
    const int outputOffset = inputOffset;
    if ((outputOffset + tid) < sequenceLength) {
      outputFeature[outputOffset + tid] = out;
    }

    // Keep the tile buffer stable until all threads have consumed it.
    __syncthreads();
  }
}
|
| 68 |
+
|
| 69 |
+
// Gradient of the lightweight convolution w.r.t. its input. Structurally
// identical to the forward kernel: correlating the incoming gradient with
// the reversed filter (and mirrored padding) yields the input gradient.
template<int FS, int SB, int padding_l, typename scalar_t>
__global__
void lightconv_grad_wrt_input_kernel(
    const scalar_t* input,
    const scalar_t* filters,
    int minibatch,
    int sequenceLength,
    int numFeatures,
    int numFiltersInBlock,
    scalar_t* output) {

  // input grad kernel is similar to forward kernel
  const int tid = threadIdx.x;
  const int batchIdx = blockIdx.x;
  const int featureIdx = blockIdx.y;
  const int filterIdx = featureIdx / numFiltersInBlock;

  const int IOOffset = numFeatures * sequenceLength * batchIdx + featureIdx * sequenceLength;
  const scalar_t* inputFeature = &input[IOOffset];
  scalar_t* outputFeature = &output[IOOffset];
  const scalar_t* inputFilter = &filters[filterIdx * FS];

  assert(blockDim.x == SB);

  scalar_t filter[FS];

  // The only change is loading the filter in reverse
  #pragma unroll
  for (int i = 0; i < FS; ++i) {
    filter[i] = inputFilter[FS - i - 1];
  }

  __shared__ scalar_t temp[SB + FS];
  // Mirror the padding for the reversed (correlation) direction.
  const int padding = FS - padding_l - 1;
  zeroSharedMem<FS, SB, padding>(temp);

  __syncthreads();

  const int numIterations = divUp<int, int>(sequenceLength, SB);

  for (int i = 0; i < numIterations; ++i) {
    // Read input into shared memory
    const int inputOffset = i * SB;

    load_input_to_shared<FS, SB, padding>(inputFeature, inputOffset, sequenceLength,
                                          i, numIterations, false, temp);

    __syncthreads();

    scalar_t out = 0;
    #pragma unroll
    for (int j = 0; j < FS; ++j) {
      out += filter[j] * temp[tid + j];
    }

    // Write output
    const int outputOffset = inputOffset;
    if ((outputOffset + tid) < sequenceLength) {
      outputFeature[outputOffset + tid] = out;
    }

    __syncthreads();
  }
}
|
| 133 |
+
|
| 134 |
+
// This is by far the most expensive kernel in terms of time taken.
|
| 135 |
+
// Can be 16x slower than the forward or grad_wrt_input when filter size is 31
|
| 136 |
+
template<int FS, int SB, int padding_l, typename scalar_t>
|
| 137 |
+
__global__
|
| 138 |
+
void lightconv_grad_wrt_weights_firstpass_short_kernel(
|
| 139 |
+
const scalar_t* input,
|
| 140 |
+
const scalar_t* gradInput,
|
| 141 |
+
int minibatch,
|
| 142 |
+
int sequenceLength,
|
| 143 |
+
int numFeatures,
|
| 144 |
+
int numFiltersInBlock,
|
| 145 |
+
int numHeads,
|
| 146 |
+
float* output) {
|
| 147 |
+
|
| 148 |
+
const int tid = threadIdx.x;
|
| 149 |
+
const int batchIdx = blockIdx.x;
|
| 150 |
+
const int filterIdx = blockIdx.y;
|
| 151 |
+
|
| 152 |
+
const int numIterations = divUp<int, int>(sequenceLength, SB);
|
| 153 |
+
|
| 154 |
+
float* tempOutputGradWeight = &output[filterIdx * FS * minibatch];
|
| 155 |
+
|
| 156 |
+
assert(blockDim.x == SB);
|
| 157 |
+
|
| 158 |
+
__shared__ scalar_t tempInput[SB + FS];
|
| 159 |
+
__shared__ scalar_t tempGradInput[SB + FS];
|
| 160 |
+
|
| 161 |
+
// local weight accumulation
|
| 162 |
+
float accumWeights[FS];
|
| 163 |
+
|
| 164 |
+
// Initialize memory
|
| 165 |
+
for (int i = 0; i < FS; ++i) {
|
| 166 |
+
accumWeights[i] = float(0.0);
|
| 167 |
+
}
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
// loop over each sequence within filterblock
|
| 171 |
+
for (int idxInFilterBlock = 0; idxInFilterBlock < numFiltersInBlock; ++idxInFilterBlock) {
|
| 172 |
+
|
| 173 |
+
const int featureOffset = batchIdx * numFeatures * sequenceLength + (filterIdx * numFiltersInBlock + idxInFilterBlock) * sequenceLength;
|
| 174 |
+
const scalar_t* inputFeature = &input[featureOffset];
|
| 175 |
+
const scalar_t* gradInputFeature = &gradInput[featureOffset];
|
| 176 |
+
|
| 177 |
+
zeroSharedMem<FS, SB, padding_l>(tempInput);
|
| 178 |
+
zeroSharedMem<FS, SB, (FS/2)>(tempGradInput);
|
| 179 |
+
__syncthreads();
|
| 180 |
+
|
| 181 |
+
for (int i = 0; i < numIterations; ++i) {
|
| 182 |
+
|
| 183 |
+
const int inputOffset = i * SB;
|
| 184 |
+
|
| 185 |
+
load_input_to_shared<FS, SB, padding_l>(inputFeature, inputOffset, sequenceLength,
|
| 186 |
+
i, numIterations, false, tempInput);
|
| 187 |
+
load_input_to_shared<FS, SB, (FS/2)>(gradInputFeature, inputOffset, sequenceLength,
|
| 188 |
+
i, numIterations, false, tempGradInput);
|
| 189 |
+
|
| 190 |
+
__syncthreads();
|
| 191 |
+
|
| 192 |
+
const int gradIndex = (FS/2) + tid;
|
| 193 |
+
scalar_t tempGrad = tempGradInput[gradIndex];
|
| 194 |
+
|
| 195 |
+
#pragma unroll
|
| 196 |
+
for (int j = 0; j < FS; j++) {
|
| 197 |
+
const int inputIndex = tid + j;
|
| 198 |
+
accumWeights[j] += tempInput[inputIndex] * tempGrad;
|
| 199 |
+
}
|
| 200 |
+
|
| 201 |
+
__syncthreads();
|
| 202 |
+
|
| 203 |
+
}
|
| 204 |
+
|
| 205 |
+
}
|
| 206 |
+
|
| 207 |
+
// Row-major sum
|
| 208 |
+
for (int filterWeightIdx = 0; filterWeightIdx < FS; ++filterWeightIdx) {
|
| 209 |
+
|
| 210 |
+
float temp;
|
| 211 |
+
if (tid < sequenceLength) {
|
| 212 |
+
temp = accumWeights[filterWeightIdx];
|
| 213 |
+
} else {
|
| 214 |
+
temp = float(0.0);
|
| 215 |
+
}
|
| 216 |
+
|
| 217 |
+
const int outputOffset = filterWeightIdx * minibatch + batchIdx;
|
| 218 |
+
|
| 219 |
+
temp = blockReduce(temp);
|
| 220 |
+
|
| 221 |
+
if (tid == 0) {
|
| 222 |
+
tempOutputGradWeight[outputOffset] = temp;
|
| 223 |
+
}
|
| 224 |
+
}
|
| 225 |
+
}
|
| 226 |
+
|
| 227 |
+
template<int FS, int SB, typename scalar_t>
|
| 228 |
+
__global__
|
| 229 |
+
void lightconv_grad_wrt_weights_secondpass_short_kernel(
|
| 230 |
+
const float* input,
|
| 231 |
+
const int minibatch,
|
| 232 |
+
const int numFiltersInBlock,
|
| 233 |
+
scalar_t* output) {
|
| 234 |
+
|
| 235 |
+
assert(blockDim.x == SB);
|
| 236 |
+
|
| 237 |
+
const int tid = threadIdx.x;
|
| 238 |
+
|
| 239 |
+
const int filterIdx = blockIdx.x;
|
| 240 |
+
const int filterWeightIdx = blockIdx.y;
|
| 241 |
+
|
| 242 |
+
const int inputOffset = filterIdx * FS * minibatch +
|
| 243 |
+
filterWeightIdx * minibatch;
|
| 244 |
+
const float* tempInput = &input[inputOffset];
|
| 245 |
+
|
| 246 |
+
// read into shared memory for reduction
|
| 247 |
+
int readIndex = tid;
|
| 248 |
+
|
| 249 |
+
float sum = 0.0;
|
| 250 |
+
while (readIndex < minibatch) {
|
| 251 |
+
sum += tempInput[readIndex];
|
| 252 |
+
readIndex += SB;
|
| 253 |
+
}
|
| 254 |
+
|
| 255 |
+
float temp = blockReduce(sum);
|
| 256 |
+
|
| 257 |
+
if (tid == 0) {
|
| 258 |
+
output[blockIdx.x * FS + blockIdx.y] = temp;
|
| 259 |
+
}
|
| 260 |
+
}
|
| 261 |
+
|
| 262 |
+
// This is by far the most expensive kernel in terms of time taken.
|
| 263 |
+
// Can be 16x slower than the forward or grad_wrt_input when filter size is 31
|
| 264 |
+
template<int FS, int SB, int padding_l, typename scalar_t>
|
| 265 |
+
__global__
|
| 266 |
+
void lightconv_grad_wrt_weights_firstpass_kernel(
|
| 267 |
+
const scalar_t* input,
|
| 268 |
+
const scalar_t* gradInput,
|
| 269 |
+
int minibatch,
|
| 270 |
+
int sequenceLength,
|
| 271 |
+
int numFeatures,
|
| 272 |
+
int numFiltersInBlock,
|
| 273 |
+
float* output) {
|
| 274 |
+
|
| 275 |
+
assert(blockDim.x == SB);
|
| 276 |
+
|
| 277 |
+
const int tid = threadIdx.x;
|
| 278 |
+
const int batchIdx = blockIdx.x;
|
| 279 |
+
const int featureIdx = blockIdx.y;
|
| 280 |
+
const int filterIdx = featureIdx / numFiltersInBlock;
|
| 281 |
+
const int idxInFilterBlock = featureIdx % numFiltersInBlock;
|
| 282 |
+
|
| 283 |
+
const int numIterations = divUp<int, int>(sequenceLength, SB);
|
| 284 |
+
|
| 285 |
+
float temp;
|
| 286 |
+
|
| 287 |
+
__shared__ scalar_t tempInput[SB + FS];
|
| 288 |
+
__shared__ scalar_t tempGradInput[SB + FS];
|
| 289 |
+
zeroSharedMem<FS, SB, padding_l>(tempInput);
|
| 290 |
+
zeroSharedMem<FS, SB, (FS/2)>(tempGradInput);
|
| 291 |
+
__syncthreads();
|
| 292 |
+
|
| 293 |
+
float accumWeights[FS];
|
| 294 |
+
|
| 295 |
+
for (int i = 0; i < FS; ++i) {
|
| 296 |
+
accumWeights[i] = float(0.0);
|
| 297 |
+
}
|
| 298 |
+
|
| 299 |
+
const int IOOffset = batchIdx * numFeatures * sequenceLength + featureIdx * sequenceLength;
|
| 300 |
+
const scalar_t* inputFeature = &input[IOOffset];
|
| 301 |
+
const scalar_t* gradInputFeature = &gradInput[IOOffset];
|
| 302 |
+
float* tempOutputGradWeight = &output[filterIdx * FS * minibatch * numFiltersInBlock];
|
| 303 |
+
|
| 304 |
+
for (int i = 0; i < numIterations; ++i) {
|
| 305 |
+
const int inputOffset = i * SB;
|
| 306 |
+
|
| 307 |
+
load_input_to_shared<FS, SB, padding_l>(inputFeature, inputOffset, sequenceLength,
|
| 308 |
+
i, numIterations, false, tempInput);
|
| 309 |
+
load_input_to_shared<FS, SB, (FS/2)>(gradInputFeature, inputOffset, sequenceLength,
|
| 310 |
+
i, numIterations, false, tempGradInput);
|
| 311 |
+
__syncthreads();
|
| 312 |
+
|
| 313 |
+
#pragma unroll
|
| 314 |
+
for (int j = 0; j < FS; ++j) {
|
| 315 |
+
accumWeights[j] += tempInput[tid + j] * tempGradInput[tid + (FS/2)];
|
| 316 |
+
}
|
| 317 |
+
|
| 318 |
+
__syncthreads();
|
| 319 |
+
}
|
| 320 |
+
|
| 321 |
+
// Row-major sum
|
| 322 |
+
for (int filterWeightIdx = 0; filterWeightIdx < FS; ++filterWeightIdx) {
|
| 323 |
+
|
| 324 |
+
// Write to shared memory before reduction
|
| 325 |
+
if (tid < sequenceLength) {
|
| 326 |
+
temp = accumWeights[filterWeightIdx];
|
| 327 |
+
} else {
|
| 328 |
+
temp = float(0.0);
|
| 329 |
+
}
|
| 330 |
+
|
| 331 |
+
temp = blockReduce(temp);
|
| 332 |
+
|
| 333 |
+
const int outputOffset = filterWeightIdx * minibatch * numFiltersInBlock +
|
| 334 |
+
batchIdx * numFiltersInBlock +
|
| 335 |
+
idxInFilterBlock;
|
| 336 |
+
|
| 337 |
+
if (tid == 0) {
|
| 338 |
+
tempOutputGradWeight[outputOffset] = temp;
|
| 339 |
+
}
|
| 340 |
+
}
|
| 341 |
+
}
|
| 342 |
+
|
| 343 |
+
template<int FS, int SB, typename scalar_t>
|
| 344 |
+
__global__
|
| 345 |
+
void lightconv_grad_wrt_weights_secondpass_kernel(
|
| 346 |
+
const float* input,
|
| 347 |
+
const int minibatch,
|
| 348 |
+
const int numFiltersInBlock,
|
| 349 |
+
scalar_t* output) {
|
| 350 |
+
|
| 351 |
+
assert(blockDim.x == SB);
|
| 352 |
+
const int tid = threadIdx.x;
|
| 353 |
+
|
| 354 |
+
// What is the id within a minibatch
|
| 355 |
+
const int filterIdx = blockIdx.x;
|
| 356 |
+
const int filterWeightIdx = blockIdx.y;
|
| 357 |
+
|
| 358 |
+
const int inputOffset = filterIdx * FS * minibatch * numFiltersInBlock +
|
| 359 |
+
filterWeightIdx * minibatch * numFiltersInBlock;
|
| 360 |
+
const float* tempInput = &input[inputOffset];
|
| 361 |
+
|
| 362 |
+
int readIndex = tid;
|
| 363 |
+
|
| 364 |
+
float sum = float(0.0);
|
| 365 |
+
while (readIndex < (minibatch * numFiltersInBlock)) {
|
| 366 |
+
sum += tempInput[readIndex];
|
| 367 |
+
readIndex += SB;
|
| 368 |
+
}
|
| 369 |
+
|
| 370 |
+
float temp = blockReduce(sum);
|
| 371 |
+
|
| 372 |
+
if (tid == 0) {
|
| 373 |
+
output[blockIdx.x * FS + blockIdx.y] = temp;
|
| 374 |
+
}
|
| 375 |
+
}
|
fairseq-0.10.2/fairseq/modules/linearized_convolution.py
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
from fairseq import utils
|
| 9 |
+
from fairseq.incremental_decoding_utils import with_incremental_state
|
| 10 |
+
|
| 11 |
+
from .conv_tbc import ConvTBC
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@with_incremental_state
|
| 15 |
+
class LinearizedConvolution(ConvTBC):
|
| 16 |
+
"""An optimized version of nn.Conv1d.
|
| 17 |
+
|
| 18 |
+
At training time, this module uses ConvTBC, which is an optimized version
|
| 19 |
+
of Conv1d. At inference time, it optimizes incremental generation (i.e.,
|
| 20 |
+
one time step at a time) by replacing the convolutions with linear layers.
|
| 21 |
+
Note that the input order changes from training to inference.
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
def __init__(self, in_channels, out_channels, kernel_size, **kwargs):
|
| 25 |
+
super().__init__(in_channels, out_channels, kernel_size, **kwargs)
|
| 26 |
+
self._linearized_weight = None
|
| 27 |
+
self.register_backward_hook(self._clear_linearized_weight)
|
| 28 |
+
|
| 29 |
+
def state_dict(self, destination=None, prefix="", keep_vars=False):
|
| 30 |
+
state = ConvTBC.state_dict(self, destination, prefix, keep_vars=keep_vars)
|
| 31 |
+
# don't store redundant _linearized_weight in checkpoints
|
| 32 |
+
if prefix + "_linearized_weight" in state:
|
| 33 |
+
del state[prefix + "_linearized_weight"]
|
| 34 |
+
return state
|
| 35 |
+
|
| 36 |
+
def upgrade_state_dict_named(self, state_dict, name):
|
| 37 |
+
prefix = name + "." if name != "" else ""
|
| 38 |
+
if prefix + "_linearized_weight" in state_dict:
|
| 39 |
+
del state_dict[prefix + "_linearized_weight"]
|
| 40 |
+
|
| 41 |
+
def forward(self, input, incremental_state=None):
|
| 42 |
+
"""
|
| 43 |
+
Args:
|
| 44 |
+
incremental_state: Used to buffer signal; if not None, then input is
|
| 45 |
+
expected to contain a single frame. If the input order changes
|
| 46 |
+
between time steps, call reorder_incremental_state.
|
| 47 |
+
Input:
|
| 48 |
+
Time x Batch x Channel during training
|
| 49 |
+
Batch x Time x Channel during inference
|
| 50 |
+
"""
|
| 51 |
+
if incremental_state is None:
|
| 52 |
+
output = super().forward(input)
|
| 53 |
+
if self.kernel_size[0] > 1 and self.padding[0] > 0:
|
| 54 |
+
# remove future timesteps added by padding
|
| 55 |
+
output = output[: -self.padding[0], :, :]
|
| 56 |
+
return output
|
| 57 |
+
|
| 58 |
+
# reshape weight
|
| 59 |
+
weight = self._get_linearized_weight()
|
| 60 |
+
kw = self.kernel_size[0]
|
| 61 |
+
|
| 62 |
+
bsz = input.size(0) # input: bsz x len x dim
|
| 63 |
+
if kw > 1:
|
| 64 |
+
input = input.data
|
| 65 |
+
input_buffer = self._get_input_buffer(incremental_state)
|
| 66 |
+
if input_buffer is None:
|
| 67 |
+
input_buffer = input.new(bsz, kw, input.size(2)).zero_()
|
| 68 |
+
self._set_input_buffer(incremental_state, input_buffer)
|
| 69 |
+
else:
|
| 70 |
+
# shift buffer
|
| 71 |
+
input_buffer[:, :-1, :] = input_buffer[:, 1:, :].clone()
|
| 72 |
+
# append next input
|
| 73 |
+
input_buffer[:, -1, :] = input[:, -1, :]
|
| 74 |
+
input = input_buffer
|
| 75 |
+
with torch.no_grad():
|
| 76 |
+
output = F.linear(input.view(bsz, -1), weight, self.bias)
|
| 77 |
+
return output.view(bsz, 1, -1)
|
| 78 |
+
|
| 79 |
+
def reorder_incremental_state(self, incremental_state, new_order):
|
| 80 |
+
input_buffer = self._get_input_buffer(incremental_state)
|
| 81 |
+
if input_buffer is not None:
|
| 82 |
+
input_buffer = input_buffer.index_select(0, new_order)
|
| 83 |
+
self._set_input_buffer(incremental_state, input_buffer)
|
| 84 |
+
|
| 85 |
+
def _get_input_buffer(self, incremental_state):
|
| 86 |
+
return utils.get_incremental_state(self, incremental_state, "input_buffer")
|
| 87 |
+
|
| 88 |
+
def _set_input_buffer(self, incremental_state, new_buffer):
|
| 89 |
+
return utils.set_incremental_state(
|
| 90 |
+
self, incremental_state, "input_buffer", new_buffer
|
| 91 |
+
)
|
| 92 |
+
|
| 93 |
+
def _get_linearized_weight(self):
|
| 94 |
+
if self._linearized_weight is None:
|
| 95 |
+
kw = self.kernel_size[0]
|
| 96 |
+
weight = self.weight.transpose(2, 1).transpose(1, 0).contiguous()
|
| 97 |
+
assert weight.size() == (self.out_channels, kw, self.in_channels)
|
| 98 |
+
self._linearized_weight = torch.nn.Parameter(
|
| 99 |
+
weight.view(self.out_channels, -1)
|
| 100 |
+
)
|
| 101 |
+
return self._linearized_weight
|
| 102 |
+
|
| 103 |
+
def _clear_linearized_weight(self, *args):
|
| 104 |
+
self._linearized_weight = None
|
fairseq-0.10.2/fairseq/modules/multihead_attention.py
ADDED
|
@@ -0,0 +1,488 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import math
|
| 7 |
+
from typing import Dict, Optional, Tuple
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
from fairseq import utils
|
| 12 |
+
from fairseq.incremental_decoding_utils import with_incremental_state
|
| 13 |
+
from fairseq.modules.fairseq_dropout import FairseqDropout
|
| 14 |
+
from fairseq.modules.quant_noise import quant_noise
|
| 15 |
+
from torch import Tensor, nn
|
| 16 |
+
from torch.nn import Parameter
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
@with_incremental_state
|
| 20 |
+
class MultiheadAttention(nn.Module):
|
| 21 |
+
"""Multi-headed attention.
|
| 22 |
+
|
| 23 |
+
See "Attention Is All You Need" for more details.
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
def __init__(
|
| 27 |
+
self,
|
| 28 |
+
embed_dim,
|
| 29 |
+
num_heads,
|
| 30 |
+
kdim=None,
|
| 31 |
+
vdim=None,
|
| 32 |
+
dropout=0.0,
|
| 33 |
+
bias=True,
|
| 34 |
+
add_bias_kv=False,
|
| 35 |
+
add_zero_attn=False,
|
| 36 |
+
self_attention=False,
|
| 37 |
+
encoder_decoder_attention=False,
|
| 38 |
+
q_noise=0.0,
|
| 39 |
+
qn_block_size=8,
|
| 40 |
+
):
|
| 41 |
+
super().__init__()
|
| 42 |
+
self.embed_dim = embed_dim
|
| 43 |
+
self.kdim = kdim if kdim is not None else embed_dim
|
| 44 |
+
self.vdim = vdim if vdim is not None else embed_dim
|
| 45 |
+
self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
|
| 46 |
+
|
| 47 |
+
self.num_heads = num_heads
|
| 48 |
+
self.dropout_module = FairseqDropout(
|
| 49 |
+
dropout, module_name=self.__class__.__name__
|
| 50 |
+
)
|
| 51 |
+
|
| 52 |
+
self.head_dim = embed_dim // num_heads
|
| 53 |
+
assert (
|
| 54 |
+
self.head_dim * num_heads == self.embed_dim
|
| 55 |
+
), "embed_dim must be divisible by num_heads"
|
| 56 |
+
self.scaling = self.head_dim ** -0.5
|
| 57 |
+
|
| 58 |
+
self.self_attention = self_attention
|
| 59 |
+
self.encoder_decoder_attention = encoder_decoder_attention
|
| 60 |
+
|
| 61 |
+
assert not self.self_attention or self.qkv_same_dim, (
|
| 62 |
+
"Self-attention requires query, key and " "value to be of the same size"
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
self.k_proj = quant_noise(
|
| 66 |
+
nn.Linear(self.kdim, embed_dim, bias=bias), q_noise, qn_block_size
|
| 67 |
+
)
|
| 68 |
+
self.v_proj = quant_noise(
|
| 69 |
+
nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size
|
| 70 |
+
)
|
| 71 |
+
self.q_proj = quant_noise(
|
| 72 |
+
nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
self.out_proj = quant_noise(
|
| 76 |
+
nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
if add_bias_kv:
|
| 80 |
+
self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
|
| 81 |
+
self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
|
| 82 |
+
else:
|
| 83 |
+
self.bias_k = self.bias_v = None
|
| 84 |
+
|
| 85 |
+
self.add_zero_attn = add_zero_attn
|
| 86 |
+
|
| 87 |
+
self.reset_parameters()
|
| 88 |
+
|
| 89 |
+
self.onnx_trace = False
|
| 90 |
+
self.tpu = False
|
| 91 |
+
|
| 92 |
+
def prepare_for_onnx_export_(self):
|
| 93 |
+
self.onnx_trace = True
|
| 94 |
+
|
| 95 |
+
def prepare_for_tpu_(self, **kwargs):
|
| 96 |
+
self.tpu = True
|
| 97 |
+
|
| 98 |
+
def reset_parameters(self):
|
| 99 |
+
if self.qkv_same_dim:
|
| 100 |
+
# Empirically observed the convergence to be much better with
|
| 101 |
+
# the scaled initialization
|
| 102 |
+
nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
|
| 103 |
+
nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
|
| 104 |
+
nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
|
| 105 |
+
else:
|
| 106 |
+
nn.init.xavier_uniform_(self.k_proj.weight)
|
| 107 |
+
nn.init.xavier_uniform_(self.v_proj.weight)
|
| 108 |
+
nn.init.xavier_uniform_(self.q_proj.weight)
|
| 109 |
+
|
| 110 |
+
nn.init.xavier_uniform_(self.out_proj.weight)
|
| 111 |
+
if self.out_proj.bias is not None:
|
| 112 |
+
nn.init.constant_(self.out_proj.bias, 0.0)
|
| 113 |
+
if self.bias_k is not None:
|
| 114 |
+
nn.init.xavier_normal_(self.bias_k)
|
| 115 |
+
if self.bias_v is not None:
|
| 116 |
+
nn.init.xavier_normal_(self.bias_v)
|
| 117 |
+
|
| 118 |
+
def forward(
|
| 119 |
+
self,
|
| 120 |
+
query,
|
| 121 |
+
key: Optional[Tensor],
|
| 122 |
+
value: Optional[Tensor],
|
| 123 |
+
key_padding_mask: Optional[Tensor] = None,
|
| 124 |
+
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
|
| 125 |
+
need_weights: bool = True,
|
| 126 |
+
static_kv: bool = False,
|
| 127 |
+
attn_mask: Optional[Tensor] = None,
|
| 128 |
+
before_softmax: bool = False,
|
| 129 |
+
need_head_weights: bool = False,
|
| 130 |
+
) -> Tuple[Tensor, Optional[Tensor]]:
|
| 131 |
+
"""Input shape: Time x Batch x Channel
|
| 132 |
+
|
| 133 |
+
Args:
|
| 134 |
+
key_padding_mask (ByteTensor, optional): mask to exclude
|
| 135 |
+
keys that are pads, of shape `(batch, src_len)`, where
|
| 136 |
+
padding elements are indicated by 1s.
|
| 137 |
+
need_weights (bool, optional): return the attention weights,
|
| 138 |
+
averaged over heads (default: False).
|
| 139 |
+
attn_mask (ByteTensor, optional): typically used to
|
| 140 |
+
implement causal attention, where the mask prevents the
|
| 141 |
+
attention from looking forward in time (default: None).
|
| 142 |
+
before_softmax (bool, optional): return the raw attention
|
| 143 |
+
weights and values before the attention softmax.
|
| 144 |
+
need_head_weights (bool, optional): return the attention
|
| 145 |
+
weights for each head. Implies *need_weights*. Default:
|
| 146 |
+
return the average attention weights over all heads.
|
| 147 |
+
"""
|
| 148 |
+
if need_head_weights:
|
| 149 |
+
need_weights = True
|
| 150 |
+
|
| 151 |
+
tgt_len, bsz, embed_dim = query.size()
|
| 152 |
+
assert embed_dim == self.embed_dim
|
| 153 |
+
assert list(query.size()) == [tgt_len, bsz, embed_dim]
|
| 154 |
+
|
| 155 |
+
if (
|
| 156 |
+
not self.onnx_trace
|
| 157 |
+
and not self.tpu # don't use PyTorch version on TPUs
|
| 158 |
+
and incremental_state is None
|
| 159 |
+
and not static_kv
|
| 160 |
+
# A workaround for quantization to work. Otherwise JIT compilation
|
| 161 |
+
# treats bias in linear module as method.
|
| 162 |
+
and not torch.jit.is_scripting()
|
| 163 |
+
):
|
| 164 |
+
assert key is not None and value is not None
|
| 165 |
+
return F.multi_head_attention_forward(
|
| 166 |
+
query,
|
| 167 |
+
key,
|
| 168 |
+
value,
|
| 169 |
+
self.embed_dim,
|
| 170 |
+
self.num_heads,
|
| 171 |
+
torch.empty([0]),
|
| 172 |
+
torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
|
| 173 |
+
self.bias_k,
|
| 174 |
+
self.bias_v,
|
| 175 |
+
self.add_zero_attn,
|
| 176 |
+
self.dropout_module.p,
|
| 177 |
+
self.out_proj.weight,
|
| 178 |
+
self.out_proj.bias,
|
| 179 |
+
self.training or self.dropout_module.apply_during_inference,
|
| 180 |
+
key_padding_mask,
|
| 181 |
+
need_weights,
|
| 182 |
+
attn_mask,
|
| 183 |
+
use_separate_proj_weight=True,
|
| 184 |
+
q_proj_weight=self.q_proj.weight,
|
| 185 |
+
k_proj_weight=self.k_proj.weight,
|
| 186 |
+
v_proj_weight=self.v_proj.weight,
|
| 187 |
+
)
|
| 188 |
+
|
| 189 |
+
if incremental_state is not None:
|
| 190 |
+
saved_state = self._get_input_buffer(incremental_state)
|
| 191 |
+
if saved_state is not None and "prev_key" in saved_state:
|
| 192 |
+
# previous time steps are cached - no need to recompute
|
| 193 |
+
# key and value if they are static
|
| 194 |
+
if static_kv:
|
| 195 |
+
assert self.encoder_decoder_attention and not self.self_attention
|
| 196 |
+
key = value = None
|
| 197 |
+
else:
|
| 198 |
+
saved_state = None
|
| 199 |
+
|
| 200 |
+
if self.self_attention:
|
| 201 |
+
q = self.q_proj(query)
|
| 202 |
+
k = self.k_proj(query)
|
| 203 |
+
v = self.v_proj(query)
|
| 204 |
+
elif self.encoder_decoder_attention:
|
| 205 |
+
# encoder-decoder attention
|
| 206 |
+
q = self.q_proj(query)
|
| 207 |
+
if key is None:
|
| 208 |
+
assert value is None
|
| 209 |
+
k = v = None
|
| 210 |
+
else:
|
| 211 |
+
k = self.k_proj(key)
|
| 212 |
+
v = self.v_proj(key)
|
| 213 |
+
|
| 214 |
+
else:
|
| 215 |
+
assert key is not None and value is not None
|
| 216 |
+
q = self.q_proj(query)
|
| 217 |
+
k = self.k_proj(key)
|
| 218 |
+
v = self.v_proj(value)
|
| 219 |
+
q *= self.scaling
|
| 220 |
+
|
| 221 |
+
if self.bias_k is not None:
|
| 222 |
+
assert self.bias_v is not None
|
| 223 |
+
k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
|
| 224 |
+
v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
|
| 225 |
+
if attn_mask is not None:
|
| 226 |
+
attn_mask = torch.cat(
|
| 227 |
+
[attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
|
| 228 |
+
)
|
| 229 |
+
if key_padding_mask is not None:
|
| 230 |
+
key_padding_mask = torch.cat(
|
| 231 |
+
[
|
| 232 |
+
key_padding_mask,
|
| 233 |
+
key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
|
| 234 |
+
],
|
| 235 |
+
dim=1,
|
| 236 |
+
)
|
| 237 |
+
|
| 238 |
+
q = (
|
| 239 |
+
q.contiguous()
|
| 240 |
+
.view(tgt_len, bsz * self.num_heads, self.head_dim)
|
| 241 |
+
.transpose(0, 1)
|
| 242 |
+
)
|
| 243 |
+
if k is not None:
|
| 244 |
+
k = (
|
| 245 |
+
k.contiguous()
|
| 246 |
+
.view(-1, bsz * self.num_heads, self.head_dim)
|
| 247 |
+
.transpose(0, 1)
|
| 248 |
+
)
|
| 249 |
+
if v is not None:
|
| 250 |
+
v = (
|
| 251 |
+
v.contiguous()
|
| 252 |
+
.view(-1, bsz * self.num_heads, self.head_dim)
|
| 253 |
+
.transpose(0, 1)
|
| 254 |
+
)
|
| 255 |
+
|
| 256 |
+
if saved_state is not None:
|
| 257 |
+
# saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
|
| 258 |
+
if "prev_key" in saved_state:
|
| 259 |
+
_prev_key = saved_state["prev_key"]
|
| 260 |
+
assert _prev_key is not None
|
| 261 |
+
prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
|
| 262 |
+
if static_kv:
|
| 263 |
+
k = prev_key
|
| 264 |
+
else:
|
| 265 |
+
assert k is not None
|
| 266 |
+
k = torch.cat([prev_key, k], dim=1)
|
| 267 |
+
if "prev_value" in saved_state:
|
| 268 |
+
_prev_value = saved_state["prev_value"]
|
| 269 |
+
assert _prev_value is not None
|
| 270 |
+
prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
|
| 271 |
+
if static_kv:
|
| 272 |
+
v = prev_value
|
| 273 |
+
else:
|
| 274 |
+
assert v is not None
|
| 275 |
+
v = torch.cat([prev_value, v], dim=1)
|
| 276 |
+
prev_key_padding_mask: Optional[Tensor] = None
|
| 277 |
+
if "prev_key_padding_mask" in saved_state:
|
| 278 |
+
prev_key_padding_mask = saved_state["prev_key_padding_mask"]
|
| 279 |
+
assert k is not None and v is not None
|
| 280 |
+
key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
|
| 281 |
+
key_padding_mask=key_padding_mask,
|
| 282 |
+
prev_key_padding_mask=prev_key_padding_mask,
|
| 283 |
+
batch_size=bsz,
|
| 284 |
+
src_len=k.size(1),
|
| 285 |
+
static_kv=static_kv,
|
| 286 |
+
)
|
| 287 |
+
|
| 288 |
+
saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
|
| 289 |
+
saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
|
| 290 |
+
saved_state["prev_key_padding_mask"] = key_padding_mask
|
| 291 |
+
# In this branch incremental_state is never None
|
| 292 |
+
assert incremental_state is not None
|
| 293 |
+
incremental_state = self._set_input_buffer(incremental_state, saved_state)
|
| 294 |
+
assert k is not None
|
| 295 |
+
src_len = k.size(1)
|
| 296 |
+
|
| 297 |
+
# This is part of a workaround to get around fork/join parallelism
|
| 298 |
+
# not supporting Optional types.
|
| 299 |
+
if key_padding_mask is not None and key_padding_mask.dim() == 0:
|
| 300 |
+
key_padding_mask = None
|
| 301 |
+
|
| 302 |
+
if key_padding_mask is not None:
|
| 303 |
+
assert key_padding_mask.size(0) == bsz
|
| 304 |
+
assert key_padding_mask.size(1) == src_len
|
| 305 |
+
|
| 306 |
+
if self.add_zero_attn:
|
| 307 |
+
assert v is not None
|
| 308 |
+
src_len += 1
|
| 309 |
+
k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
|
| 310 |
+
v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
|
| 311 |
+
if attn_mask is not None:
|
| 312 |
+
attn_mask = torch.cat(
|
| 313 |
+
[attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
|
| 314 |
+
)
|
| 315 |
+
if key_padding_mask is not None:
|
| 316 |
+
key_padding_mask = torch.cat(
|
| 317 |
+
[
|
| 318 |
+
key_padding_mask,
|
| 319 |
+
torch.zeros(key_padding_mask.size(0), 1).type_as(
|
| 320 |
+
key_padding_mask
|
| 321 |
+
),
|
| 322 |
+
],
|
| 323 |
+
dim=1,
|
| 324 |
+
)
|
| 325 |
+
|
| 326 |
+
attn_weights = torch.bmm(q, k.transpose(1, 2))
|
| 327 |
+
attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
|
| 328 |
+
|
| 329 |
+
assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
|
| 330 |
+
|
| 331 |
+
if attn_mask is not None:
|
| 332 |
+
attn_mask = attn_mask.unsqueeze(0)
|
| 333 |
+
if self.onnx_trace:
|
| 334 |
+
attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1)
|
| 335 |
+
attn_weights += attn_mask
|
| 336 |
+
|
| 337 |
+
if key_padding_mask is not None:
|
| 338 |
+
# don't attend to padding symbols
|
| 339 |
+
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
| 340 |
+
if not self.tpu:
|
| 341 |
+
attn_weights = attn_weights.masked_fill(
|
| 342 |
+
key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
|
| 343 |
+
float("-inf"),
|
| 344 |
+
)
|
| 345 |
+
else:
|
| 346 |
+
attn_weights = attn_weights.transpose(0, 2)
|
| 347 |
+
attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
|
| 348 |
+
attn_weights = attn_weights.transpose(0, 2)
|
| 349 |
+
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
| 350 |
+
|
| 351 |
+
if before_softmax:
|
| 352 |
+
return attn_weights, v
|
| 353 |
+
|
| 354 |
+
attn_weights_float = utils.softmax(
|
| 355 |
+
attn_weights, dim=-1, onnx_trace=self.onnx_trace
|
| 356 |
+
)
|
| 357 |
+
attn_weights = attn_weights_float.type_as(attn_weights)
|
| 358 |
+
attn_probs = self.dropout_module(attn_weights)
|
| 359 |
+
|
| 360 |
+
assert v is not None
|
| 361 |
+
attn = torch.bmm(attn_probs, v)
|
| 362 |
+
assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
|
| 363 |
+
if self.onnx_trace and attn.size(1) == 1:
|
| 364 |
+
# when ONNX tracing a single decoder step (sequence length == 1)
|
| 365 |
+
# the transpose is a no-op copy before view, thus unnecessary
|
| 366 |
+
attn = attn.contiguous().view(tgt_len, bsz, embed_dim)
|
| 367 |
+
else:
|
| 368 |
+
attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
|
| 369 |
+
attn = self.out_proj(attn)
|
| 370 |
+
attn_weights: Optional[Tensor] = None
|
| 371 |
+
if need_weights:
|
| 372 |
+
attn_weights = attn_weights_float.view(
|
| 373 |
+
bsz, self.num_heads, tgt_len, src_len
|
| 374 |
+
).transpose(1, 0)
|
| 375 |
+
if not need_head_weights:
|
| 376 |
+
# average attention weights over heads
|
| 377 |
+
attn_weights = attn_weights.mean(dim=0)
|
| 378 |
+
|
| 379 |
+
return attn, attn_weights
|
| 380 |
+
|
| 381 |
+
    @staticmethod
    def _append_prev_key_padding_mask(
        key_padding_mask: Optional[Tensor],
        prev_key_padding_mask: Optional[Tensor],
        batch_size: int,
        src_len: int,
        static_kv: bool,
    ) -> Optional[Tensor]:
        """Merge the current key padding mask with the mask cached from
        previous incremental-decoding steps.

        Returns a float mask of shape (batch_size, src_len) covering all
        source positions, or ``None`` when both inputs are ``None``.
        """
        # saved key padding masks have shape (bsz, seq_len)
        if prev_key_padding_mask is not None and static_kv:
            # static key/value: the cached mask already covers the full
            # source sequence, so reuse it unchanged
            new_key_padding_mask = prev_key_padding_mask
        elif prev_key_padding_mask is not None and key_padding_mask is not None:
            # .float() so masks of differing dtypes (bool/uint8) concatenate
            new_key_padding_mask = torch.cat(
                [prev_key_padding_mask.float(), key_padding_mask.float()], dim=1
            )
            # During incremental decoding, as the padding token enters and
            # leaves the frame, there will be a time when prev or current
            # is None
        elif prev_key_padding_mask is not None:
            # pad the cached mask with zeros ("not padded") out to src_len
            filler = torch.zeros(
                (batch_size, src_len - prev_key_padding_mask.size(1)),
                device=prev_key_padding_mask.device,
            )
            new_key_padding_mask = torch.cat(
                [prev_key_padding_mask.float(), filler.float()], dim=1
            )
        elif key_padding_mask is not None:
            # no cached mask: treat the earlier positions as not padded
            filler = torch.zeros(
                (batch_size, src_len - key_padding_mask.size(1)),
                device=key_padding_mask.device,
            )
            new_key_padding_mask = torch.cat(
                [filler.float(), key_padding_mask.float()], dim=1
            )
        else:
            # both inputs are None here, so this returns None
            new_key_padding_mask = prev_key_padding_mask
        return new_key_padding_mask
|
| 418 |
+
|
| 419 |
+
    @torch.jit.export
    def reorder_incremental_state(
        self,
        incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
        new_order: Tensor,
    ):
        """Reorder buffered internal state (for incremental generation).

        Selects the batch/beam entries of every cached tensor along dim 0
        according to ``new_order`` and writes the buffer back.
        """
        input_buffer = self._get_input_buffer(incremental_state)
        if input_buffer is not None:
            for k in input_buffer.keys():
                input_buffer_k = input_buffer[k]
                if input_buffer_k is not None:
                    # NOTE(review): for encoder-decoder attention, a batch
                    # dim already equal to new_order's size is taken to mean
                    # the buffer needs no reordering this step — presumably
                    # because the static encoder states were reordered
                    # already; confirm against the calling generator
                    if self.encoder_decoder_attention and input_buffer_k.size(
                        0
                    ) == new_order.size(0):
                        break
                    input_buffer[k] = input_buffer_k.index_select(0, new_order)
            incremental_state = self._set_input_buffer(incremental_state, input_buffer)
        return incremental_state
|
| 438 |
+
|
| 439 |
+
    def _get_input_buffer(
        self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
    ) -> Dict[str, Optional[Tensor]]:
        """Return the cached "attn_state" buffer, or an empty dict when
        nothing has been cached yet."""
        result = self.get_incremental_state(incremental_state, "attn_state")
        if result is not None:
            return result
        else:
            # explicitly annotated empty dict (keeps the return type
            # consistent for TorchScript)
            empty_result: Dict[str, Optional[Tensor]] = {}
            return empty_result
|
| 448 |
+
|
| 449 |
+
    def _set_input_buffer(
        self,
        incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
        buffer: Dict[str, Optional[Tensor]],
    ):
        """Store ``buffer`` as this module's cached "attn_state" inside
        ``incremental_state``."""
        return self.set_incremental_state(incremental_state, "attn_state", buffer)
|
| 455 |
+
|
| 456 |
+
    def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int):
        """Hook for subclasses to sparsify attention weights; the base
        implementation is a no-op and returns the weights unchanged."""
        return attn_weights
|
| 458 |
+
|
| 459 |
+
    def upgrade_state_dict_named(self, state_dict, name):
        """Migrate old checkpoints that stored a single fused
        ``in_proj_weight``/``in_proj_bias`` into the current separate
        q/k/v projection keys.

        Args:
            state_dict: checkpoint state dict; modified in place
            name: module path prefix of this attention layer in state_dict
        """
        prefix = name + "." if name != "" else ""
        items_to_add = {}
        keys_to_remove = []
        for k in state_dict.keys():
            if k.endswith(prefix + "in_proj_weight"):
                # in_proj_weight used to be q + k + v with same dimensions
                dim = int(state_dict[k].shape[0] / 3)
                items_to_add[prefix + "q_proj.weight"] = state_dict[k][:dim]
                items_to_add[prefix + "k_proj.weight"] = state_dict[k][dim : 2 * dim]
                items_to_add[prefix + "v_proj.weight"] = state_dict[k][2 * dim :]

                keys_to_remove.append(k)

                k_bias = prefix + "in_proj_bias"
                if k_bias in state_dict.keys():
                    # same per-projection size as the fused weight above
                    dim = int(state_dict[k].shape[0] / 3)
                    items_to_add[prefix + "q_proj.bias"] = state_dict[k_bias][:dim]
                    items_to_add[prefix + "k_proj.bias"] = state_dict[k_bias][
                        dim : 2 * dim
                    ]
                    items_to_add[prefix + "v_proj.bias"] = state_dict[k_bias][2 * dim :]

                    keys_to_remove.append(prefix + "in_proj_bias")

        # apply removals/additions only after the iteration so the dict is
        # not mutated while being traversed
        for k in keys_to_remove:
            del state_dict[k]

        for key, value in items_to_add.items():
            state_dict[key] = value
|
fairseq-0.10.2/fairseq/modules/positional_embedding.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
|
| 8 |
+
from .learned_positional_embedding import LearnedPositionalEmbedding
|
| 9 |
+
from .sinusoidal_positional_embedding import SinusoidalPositionalEmbedding
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def PositionalEmbedding(
    num_embeddings: int,
    embedding_dim: int,
    padding_idx: int,
    learned: bool = False,
):
    """Build a positional embedding module.

    Args:
        num_embeddings: maximum number of positions to represent
        embedding_dim: dimensionality of each position vector
        padding_idx: index reserved for padding, or None when the input
            has no padding symbol
        learned: if True return a trainable ``LearnedPositionalEmbedding``,
            otherwise a fixed ``SinusoidalPositionalEmbedding``

    Returns:
        an ``nn.Module`` mapping position indices to embeddings
    """
    if learned:
        # if padding_idx is specified then offset the embedding ids by
        # this index and adjust num_embeddings appropriately
        # TODO: The right place for this offset would be inside
        # LearnedPositionalEmbedding. Move this there for a cleaner implementation.
        if padding_idx is not None:
            num_embeddings = num_embeddings + padding_idx + 1
        m = LearnedPositionalEmbedding(num_embeddings, embedding_dim, padding_idx)
        nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
        if padding_idx is not None:
            nn.init.constant_(m.weight[padding_idx], 0)
    else:
        # BUGFIX: the learned branch above treats padding_idx as optional,
        # but this arithmetic previously raised TypeError when
        # padding_idx was None; guard it the same way.
        init_size = (
            num_embeddings + padding_idx + 1
            if padding_idx is not None
            else num_embeddings
        )
        m = SinusoidalPositionalEmbedding(
            embedding_dim,
            padding_idx,
            init_size=init_size,
        )
    return m
|
fairseq-0.10.2/fairseq/modules/quant_noise.py
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def quant_noise(module, p, block_size):
    """Wrap *module* so quantization noise is injected into its weights
    while training, as described in "Training with Quantization Noise for
    Extreme Model Compression" (iterative Product Quantization, iPQ).

    Args:
        module: an ``nn.Linear``, ``nn.Embedding`` or ``nn.Conv2d``
        p: probability of dropping each weight block (0 disables wrapping)
        block_size: block size used by the subsequent iPQ quantization

    Remarks:
        - The module's weight shape must be divisible into block_size blocks
        - The simplest form of noise from the paper is implemented here:
          whole blocks are randomly dropped and the survivors rescaled

    Returns:
        the same module, with a forward pre-hook registered when ``p > 0``
    """
    # without noise there is nothing to register
    if p <= 0:
        return module

    # only these module types are supported
    assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d))

    # convolutions carry a 4D weight; Linear/Embedding weights are 2D
    is_conv = module.weight.ndim == 4

    # validate that the weight can be split into blocks of block_size
    if not is_conv:
        assert (
            module.weight.size(1) % block_size == 0
        ), "Input features must be a multiple of block sizes"
    elif module.kernel_size == (1, 1):
        # 1x1 convolutions are treated like linear layers over channels
        assert (
            module.in_channels % block_size == 0
        ), "Input channels must be a multiple of block sizes"
    else:
        # regular convolutions: blocks span the spatial kernel
        k = module.kernel_size[0] * module.kernel_size[1]
        assert k % block_size == 0, "Kernel size must be a multiple of block size"

    def _quant_noise_hook(mod, input):
        # noise is a training-only regularizer; do nothing in eval mode
        if not mod.training:
            return

        weight = mod.weight
        if not is_conv:
            n_in = weight.size(1)
            n_out = weight.size(0)
            # one Bernoulli draw per (row, block); a 1 marks a dropped block
            mask = torch.zeros(
                n_in // block_size * n_out, device=weight.device
            )
            mask.bernoulli_(p)
            mask = mask.repeat_interleave(block_size, -1).view(-1, n_in)
        elif mod.kernel_size == (1, 1):
            mask = torch.zeros(
                int(mod.in_channels // block_size * mod.out_channels),
                device=weight.device,
            )
            mask.bernoulli_(p)
            mask = mask.repeat_interleave(block_size, -1).view(-1, mod.in_channels)
        else:
            # drop entire spatial filters per (out_channel, in_channel) pair
            mask = torch.zeros(
                weight.size(0), weight.size(1), device=weight.device
            )
            mask.bernoulli_(p)
            mask = (
                mask.unsqueeze(2)
                .unsqueeze(3)
                .repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1])
            )

        # x.bool() is not currently supported in TorchScript
        mask = mask.to(torch.bool)
        # rescale survivors so the expected weight magnitude is unchanged
        s = 1 / (1 - p)
        mod.weight.data = s * weight.masked_fill(mask, 0)

    module.register_forward_pre_hook(_quant_noise_hook)
    return module
|
fairseq-0.10.2/fairseq/modules/quantization/__pycache__/quantization_options.cpython-310.pyc
ADDED
|
Binary file (1.34 kB). View file
|
|
|
fairseq-0.10.2/fairseq/modules/quantization/pq/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (247 Bytes). View file
|
|
|
fairseq-0.10.2/fairseq/modules/quantization/pq/__pycache__/pq.cpython-310.pyc
ADDED
|
Binary file (3.55 kB). View file
|
|
|
fairseq-0.10.2/fairseq/modules/quantization/pq/__pycache__/utils.cpython-310.pyc
ADDED
|
Binary file (9.88 kB). View file
|
|
|
fairseq-0.10.2/fairseq/modules/quantization/pq/modules/__init__.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
from .qconv import PQConv2d # NOQA
|
| 7 |
+
from .qemb import PQEmbedding # NOQA
|
| 8 |
+
from .qlinear import PQLinear # NOQA
|
fairseq-0.10.2/fairseq/modules/quantization/pq/modules/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (306 Bytes). View file
|
|
|
fairseq-0.10.2/fairseq/modules/quantization/pq/modules/__pycache__/qconv.cpython-310.pyc
ADDED
|
Binary file (3.84 kB). View file
|
|
|
fairseq-0.10.2/fairseq/modules/quantization/pq/modules/__pycache__/qemb.cpython-310.pyc
ADDED
|
Binary file (3.1 kB). View file
|
|
|
fairseq-0.10.2/fairseq/modules/quantization/pq/modules/qemb.py
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class PQEmbedding(nn.Module):
    """Product-quantized drop-in replacement for ``nn.Embedding``.

    Only the centroids and their assignments are stored; the full embedding
    matrix is rebuilt on the fly at every forward pass.

    Args:
        - centroids: centroids of size n_centroids x block_size
        - assignments: centroid index per (block, embedding) pair, of
          length num_embeddings * n_blocks
        - num_embeddings, embedding_dim, padding_idx, ...: the standard
          ``nn.Embedding`` arguments (see its official documentation)

    Remarks:
        - Performance tests on GPU show this is about 10% slower than the
          non-quantized nn.Embedding module for a standard training loop.
    """

    def __init__(
        self,
        centroids,
        assignments,
        num_embeddings,
        embedding_dim,
        padding_idx=None,
        max_norm=None,
        norm_type=2.0,
        scale_grad_by_freq=False,
        sparse=False,
        _weight=None,
    ):
        super(PQEmbedding, self).__init__()
        self.block_size = centroids.size(1)
        self.n_centroids = centroids.size(0)
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        # normalize/validate padding_idx exactly like nn.Embedding does
        if padding_idx is not None:
            if padding_idx > 0:
                assert (
                    padding_idx < self.num_embeddings
                ), "Padding_idx must be within num_embeddings"
            elif padding_idx < 0:
                assert (
                    padding_idx >= -self.num_embeddings
                ), "Padding_idx must be within num_embeddings"
                padding_idx = self.num_embeddings + padding_idx
        self.padding_idx = padding_idx
        self.max_norm = max_norm
        self.norm_type = norm_type
        self.scale_grad_by_freq = scale_grad_by_freq
        self.sparse = sparse
        # check compatibility: both divisions must be exact
        if self.embedding_dim % self.block_size != 0:
            raise ValueError("Wrong PQ sizes")
        if len(assignments) % self.num_embeddings != 0:
            raise ValueError("Wrong PQ sizes")
        # centroids train; assignments/counts are fixed buffers
        self.centroids = nn.Parameter(centroids, requires_grad=True)
        self.register_buffer("assignments", assignments)
        self.register_buffer("counts", torch.bincount(assignments).type_as(centroids))

    @property
    def weight(self):
        """Reconstruct the full (num_embeddings, embedding_dim) matrix."""
        blocks = self.centroids[self.assignments]
        per_block = blocks.reshape(-1, self.num_embeddings, self.block_size)
        return per_block.permute(1, 0, 2).flatten(1, 2)

    def forward(self, input):
        """Look up embeddings using the reconstructed weight matrix."""
        return F.embedding(
            input,
            self.weight,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )

    def extra_repr(self):
        """Human-readable summary, mirroring nn.Embedding's repr plus
        the PQ-specific fields."""
        parts = ["{num_embeddings}, {embedding_dim}"]
        if self.padding_idx is not None:
            parts.append(", padding_idx={padding_idx}")
        if self.max_norm is not None:
            parts.append(", max_norm={max_norm}")
        if self.norm_type != 2:
            parts.append(", norm_type={norm_type}")
        if self.scale_grad_by_freq is not False:
            parts.append(", scale_grad_by_freq={scale_grad_by_freq}")
        if self.sparse is not False:
            parts.append(", sparse=True")
        parts.append(", n_centroids={n_centroids}, block_size={block_size}")
        return "".join(parts).format(**self.__dict__)
|
fairseq-0.10.2/fairseq/modules/quantization/pq/modules/qlinear.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class PQLinear(nn.Module):
    """Product-quantized drop-in replacement for ``nn.Linear``.

    Stores the centroids, the assignments and the non-quantized bias;
    the full weight matrix is rebuilt on the fly at every forward pass.

    Args:
        - centroids: centroids of size n_centroids x block_size
        - assignments: centroid index per (block, row) pair, of length
          self.out_features x n_blocks
        - bias: the non-quantized bias (or None)

    Remarks:
        - See the official nn.Linear documentation for the remaining
          arguments and module behavior
        - Performance tests on GPU show this is about 15% slower than the
          non-quantized nn.Linear module for a standard training loop.
    """

    def __init__(self, centroids, assignments, bias, in_features, out_features):
        super(PQLinear, self).__init__()
        self.block_size = centroids.size(1)
        self.n_centroids = centroids.size(0)
        self.in_features = in_features
        self.out_features = out_features
        # check compatibility: both divisions must be exact
        if self.in_features % self.block_size != 0:
            raise ValueError("Wrong PQ sizes")
        if len(assignments) % self.out_features != 0:
            raise ValueError("Wrong PQ sizes")
        # centroids train; assignments/counts are fixed buffers
        self.centroids = nn.Parameter(centroids, requires_grad=True)
        self.register_buffer("assignments", assignments)
        self.register_buffer("counts", torch.bincount(assignments).type_as(centroids))
        if bias is None:
            self.register_parameter("bias", None)
        else:
            self.bias = nn.Parameter(bias)

    @property
    def weight(self):
        """Reconstruct the full (out_features, in_features) weight matrix."""
        blocks = self.centroids[self.assignments]
        per_block = blocks.reshape(-1, self.out_features, self.block_size)
        return per_block.permute(1, 0, 2).flatten(1, 2)

    def forward(self, x):
        """Standard linear transform using the reconstructed weight."""
        return F.linear(x, self.weight, self.bias)

    def extra_repr(self):
        return f"in_features={self.in_features},\
            out_features={self.out_features},\
            n_centroids={self.n_centroids},\
            block_size={self.block_size},\
            bias={self.bias is not None}"
|
fairseq-0.10.2/fairseq/modules/quantization/pq/pq.py
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
from .em import EM, EmptyClusterResolveError
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class PQ(EM):
    """
    Quantizes the layer weights W with the standard Product Quantization
    technique. This learns a codebook of codewords or centroids of size
    block_size from W. For further reference on using PQ to quantize
    neural networks, see "And the Bit Goes Down: Revisiting the Quantization
    of Neural Networks", Stock et al., ICLR 2020.

    PQ is performed in two steps:
    (1) The matrix W (weights or fully-connected or convolutional layer)
        is reshaped to (block_size, -1).
            - If W is fully-connected (2D), its columns are split into
              blocks of size block_size.
            - If W is convolutional (4D), its filters are split along the
              spatial dimension.
    (2) We apply the standard EM/k-means algorithm to the resulting reshaped matrix.

    Args:
        - W: weight matrix to quantize of size (in_features x out_features)
        - block_size: size of the blocks (subvectors)
        - n_centroids: number of centroids
        - n_iter: number of k-means iterations
        - eps: for cluster reassignment when an empty cluster is found
        - max_tentatives for cluster reassignment when an empty cluster is found
        - verbose: print information after each iteration

    Remarks:
        - block_size be compatible with the shape of W
    """

    def __init__(
        self,
        W,
        block_size,
        n_centroids=256,
        n_iter=20,
        eps=1e-6,
        max_tentatives=30,
        verbose=True,
    ):
        # block_size must be set before _reshape reads it
        self.block_size = block_size
        W_reshaped = self._reshape(W)
        super(PQ, self).__init__(
            W_reshaped,
            n_centroids=n_centroids,
            n_iter=n_iter,
            eps=eps,
            max_tentatives=max_tentatives,
            verbose=verbose,
        )

    def _reshape(self, W):
        """
        Reshapes the matrix W as expained in step (1).

        Also records the original dimensions on self so decode() can
        invert the reshape later.
        """

        # fully connected: by convention the weight has size out_features x in_features
        if len(W.size()) == 2:
            self.out_features, self.in_features = W.size()
            assert (
                self.in_features % self.block_size == 0
            ), "Linear: n_blocks must be a multiple of in_features"
            # (out, n_blocks, block) -> (block, n_blocks, out) -> (block, n_blocks*out)
            return (
                W.reshape(self.out_features, -1, self.block_size)
                .permute(2, 1, 0)
                .flatten(1, 2)
            )

        # convolutional: we reshape along the spatial dimension
        elif len(W.size()) == 4:
            self.out_channels, self.in_channels, self.k_h, self.k_w = W.size()
            assert (
                self.in_channels * self.k_h * self.k_w
            ) % self.block_size == 0, (
                "Conv2d: n_blocks must be a multiple of in_channels * k_h * k_w"
            )
            return (
                W.reshape(self.out_channels, -1, self.block_size)
                .permute(2, 1, 0)
                .flatten(1, 2)
            )
        # not implemented
        else:
            raise NotImplementedError(W.size())

    def encode(self):
        """
        Performs self.n_iter EM steps.
        """

        self.initialize_centroids()
        for i in range(self.n_iter):
            try:
                self.step(i)
            except EmptyClusterResolveError:
                # stop iterating once an empty cluster cannot be resolved
                break

    def decode(self):
        """
        Returns the encoded full weight matrix. Must be called after
        the encode function.
        """

        # fully connected case
        # (k_h is only set by _reshape for 4D weights, so its absence
        # identifies the linear case)
        if "k_h" not in self.__dict__:
            return (
                self.centroids[self.assignments]
                .reshape(-1, self.out_features, self.block_size)
                .permute(1, 0, 2)
                .flatten(1, 2)
            )

        # convolutional case
        else:
            return (
                self.centroids[self.assignments]
                .reshape(-1, self.out_channels, self.block_size)
                .permute(1, 0, 2)
                .reshape(self.out_channels, self.in_channels, self.k_h, self.k_w)
            )
|
fairseq-0.10.2/fairseq/modules/quantization/quantization_options.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def parse_config_yaml(yaml_data):
    """Translate a parsed quantization-config YAML mapping into the options
    dict consumed by the quantizer.

    Each of the sections "n_centroids", "block_sizes" and
    "layers_to_quantize" present in *yaml_data* overrides the corresponding
    default below; missing sections keep their defaults.
    """
    # defaults, used for any section absent from the YAML
    quantization_options = {
        "n_centroids": {
            "Linear": ["in_features", {"*": 256}],
            "Embedding": ["embedding_dim", {"*": 256}],
        },
        "block_sizes": {
            "Linear": ["fuzzy_name", {"fc": 8, "attn": 4, "emb": 4}],
            "Embedding": ["fuzzy_name", {"emb": 8}],
        },
        "layers_to_quantize": [
            "decoder\\.layers\\.\\d+\\.fc[12]",
            "decoder\\.embed_tokens\\.embeddings\\.[012]\\.[01]",
            "decoder\\.layers\\.\\d+\\.self_attn\\.(k_proj|v_proj|q_proj|out_proj)",
        ],
    }

    # both per-layer sections share the same {layer: (key, value)} shape
    for section in ("n_centroids", "block_sizes"):
        if section in yaml_data:
            quantization_options[section] = {
                layer: convert_yaml_to_tuple(layer_data)
                for layer, layer_data in yaml_data[section].items()
            }
    if "layers_to_quantize" in yaml_data:
        quantization_options["layers_to_quantize"] = yaml_data["layers_to_quantize"]

    return quantization_options
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def convert_yaml_to_tuple(yaml_dictionary):
    """Converts a yaml dictionary with two keys: `key` and `value` into a two
    argument tuple of those values."""
    return tuple(yaml_dictionary[field] for field in ("key", "value"))
|
fairseq-0.10.2/fairseq/modules/quantization/scalar/__pycache__/ops.cpython-310.pyc
ADDED
|
Binary file (1.54 kB). View file
|
|
|
fairseq-0.10.2/fairseq/modules/quantization/scalar/modules/__pycache__/qemb.cpython-310.pyc
ADDED
|
Binary file (3.85 kB). View file
|
|
|
fairseq-0.10.2/fairseq/modules/quantization/scalar/modules/qemb.py
ADDED
|
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
|
| 10 |
+
from ..ops import emulate_int
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class IntEmbedding(nn.Module):
    """
    Quantized counterpart of the nn.Embedding module that applies QuantNoise during training.

    Args:
        - num_embeddings: number of tokens
        - embedding_dim: embedding dimension
        - p: amount of noise to inject (0 = no quantization, 1 = quantize all the weights)
        - bits: number of bits
        - method: choose among {"tensor", "histogram", "channel"}
        - update_step: recompute scale and zero_point every update_steps iterations

    Remarks:
        - We use the straight-through estimator so that the gradients
          back-propagate nicely in the network, this is implemented with
          the detach() trick
        - Parameters scale and zero_point are recomputed every update_step
          forward pass to reduce the overhead
        - At test time, the weights are fully quantized
    """

    def __init__(
        self,
        num_embeddings,
        embedding_dim,
        padding_idx=None,
        max_norm=None,
        norm_type=2.0,
        scale_grad_by_freq=False,
        sparse=False,
        _weight=None,
        p=0,
        update_step=1000,
        bits=8,
        method="histogram",
    ):
        super(IntEmbedding, self).__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        if padding_idx is not None:
            if padding_idx > 0:
                assert (
                    padding_idx < self.num_embeddings
                ), "Padding_idx must be within num_embeddings"
            elif padding_idx < 0:
                assert (
                    padding_idx >= -self.num_embeddings
                ), "Padding_idx must be within num_embeddings"
                # normalize a negative index to its positive equivalent
                padding_idx = self.num_embeddings + padding_idx
        self.padding_idx = padding_idx
        self.max_norm = max_norm
        self.norm_type = norm_type
        self.scale_grad_by_freq = scale_grad_by_freq
        if _weight is None:
            self.weight = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim))
            self.reset_parameters()
        else:
            assert list(_weight.shape) == [
                num_embeddings,
                embedding_dim,
            ], "Shape of weight does not match num_embeddings and embedding_dim"
            self.weight = nn.Parameter(_weight)
        self.sparse = sparse

        # quantization parameters
        self.p = p
        self.bits = bits
        self.method = method
        self.update_step = update_step
        # forward-pass counter used to refresh scale/zero_point periodically
        self.counter = 0

    def reset_parameters(self):
        """Initialize weights ~ N(0, 1) and zero out the padding row, if any."""
        nn.init.normal_(self.weight)
        if self.padding_idx is not None:
            with torch.no_grad():
                self.weight[self.padding_idx].fill_(0)

    def forward(self, input):
        # train with QuantNoise and evaluate the fully quantized network
        p = self.p if self.training else 1

        # update quantization parameters every `update_step` iterations
        # (setting them to None makes emulate_int recompute them)
        if self.counter % self.update_step == 0:
            self.scale = None
            self.zero_point = None
        self.counter += 1

        # quantize weight
        weight_quantized, self.scale, self.zero_point = emulate_int(
            self.weight.detach(),
            bits=self.bits,
            method=self.method,
            scale=self.scale,
            zero_point=self.zero_point,
        )

        # mask to apply noise: entries kept at their float value with prob 1-p
        mask = torch.zeros_like(self.weight)
        mask.bernoulli_(1 - p)
        noise = (weight_quantized - self.weight).masked_fill(mask.bool(), 0)

        # using straight-through estimator (STE): gradients flow through the
        # clamped float weights while the forward value is (partly) quantized
        clamp_low = -self.scale * self.zero_point
        clamp_high = self.scale * (2 ** self.bits - 1 - self.zero_point)
        weight = (
            torch.clamp(self.weight, clamp_low.item(), clamp_high.item())
            + noise.detach()
        )

        # return output
        output = F.embedding(
            input,
            weight,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )
        return output

    def extra_repr(self):
        """Standard embedding repr followed by the quantization settings."""
        s = "{num_embeddings}, {embedding_dim}"
        if self.padding_idx is not None:
            s += ", padding_idx={padding_idx}"
        if self.max_norm is not None:
            s += ", max_norm={max_norm}"
        if self.norm_type != 2:
            s += ", norm_type={norm_type}"
        if self.scale_grad_by_freq is not False:
            s += ", scale_grad_by_freq={scale_grad_by_freq}"
        if self.sparse is not False:
            s += ", sparse=True"
        # BUG FIX: the quantization fields were appended without a ", "
        # separator, yielding reprs like "10, 4quant_noise=0, ...".
        s += ", quant_noise={p}, bits={bits}, method={method}"
        return s.format(**self.__dict__)
|
fairseq-0.10.2/fairseq/modules/sparse_transformer_sentence_encoder.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
from fairseq.modules import TransformerSentenceEncoder
|
| 8 |
+
from fairseq.modules.sparse_transformer_sentence_encoder_layer import (
|
| 9 |
+
SparseTransformerSentenceEncoderLayer,
|
| 10 |
+
)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class SparseTransformerSentenceEncoder(TransformerSentenceEncoder):
    """
    TransformerSentenceEncoder variant whose encoder layers use sparse
    (fixed-pattern) self-attention - see SparseMultiheadAttention.
    """

    def __init__(
        self,
        padding_idx: int,
        vocab_size: int,
        num_encoder_layers: int = 6,
        embedding_dim: int = 768,
        ffn_embedding_dim: int = 3072,
        num_attention_heads: int = 8,
        dropout: float = 0.1,
        attention_dropout: float = 0.1,
        activation_dropout: float = 0.1,
        max_seq_len: int = 256,
        num_segments: int = 2,
        use_position_embeddings: bool = True,
        offset_positions_by_padding: bool = True,
        encoder_normalize_before: bool = False,
        apply_bert_init: bool = False,
        activation_fn: str = "relu",
        learned_pos_embedding: bool = True,
        embed_scale: float = None,
        freeze_embeddings: bool = False,
        n_trans_layers_to_freeze: int = 0,
        export: bool = False,
        is_bidirectional: bool = True,
        stride: int = 32,
        expressivity: int = 8,
    ) -> None:

        # Let the dense parent class set up embeddings, positional encodings
        # and its own (dense) layer stack; that stack is swapped out below.
        super().__init__(
            padding_idx,
            vocab_size,
            num_encoder_layers,
            embedding_dim,
            ffn_embedding_dim,
            num_attention_heads,
            dropout,
            attention_dropout,
            activation_dropout,
            max_seq_len,
            num_segments,
            use_position_embeddings,
            offset_positions_by_padding,
            encoder_normalize_before,
            apply_bert_init,
            activation_fn,
            learned_pos_embedding,
            embed_scale,
            freeze_embeddings,
            n_trans_layers_to_freeze,
            export,
        )

        def make_sparse_layer():
            # Every layer shares the same hyper-parameters.
            return SparseTransformerSentenceEncoderLayer(
                embedding_dim=self.embedding_dim,
                ffn_embedding_dim=ffn_embedding_dim,
                num_attention_heads=num_attention_heads,
                dropout=dropout,
                attention_dropout=attention_dropout,
                activation_dropout=activation_dropout,
                activation_fn=activation_fn,
                export=export,
                is_bidirectional=is_bidirectional,
                stride=stride,
                expressivity=expressivity,
            )

        # Replace the dense layers with sparse-attention ones.
        self.layers = nn.ModuleList(
            [make_sparse_layer() for _ in range(num_encoder_layers)]
        )

        def disable_grad(module):
            # Freeze all parameters of the given (sub-)module.
            if module is not None:
                for param in module.parameters():
                    param.requires_grad = False

        # Re-apply layer freezing, since the frozen dense layers were replaced.
        for layer_idx in range(n_trans_layers_to_freeze):
            disable_grad(self.layers[layer_idx])
|