Upload 551 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +2 -0
- fairseq/__init__.py +26 -0
- fairseq/__pycache__/__init__.cpython-310.pyc +0 -0
- fairseq/__pycache__/binarizer.cpython-310.pyc +0 -0
- fairseq/__pycache__/checkpoint_utils.cpython-310.pyc +0 -0
- fairseq/__pycache__/distributed_utils.cpython-310.pyc +0 -0
- fairseq/__pycache__/file_io.cpython-310.pyc +0 -0
- fairseq/__pycache__/file_utils.cpython-310.pyc +0 -0
- fairseq/__pycache__/incremental_decoding_utils.cpython-310.pyc +0 -0
- fairseq/__pycache__/iterative_refinement_generator.cpython-310.pyc +0 -0
- fairseq/__pycache__/legacy_distributed_data_parallel.cpython-310.pyc +0 -0
- fairseq/__pycache__/options.cpython-310.pyc +0 -0
- fairseq/__pycache__/pdb.cpython-310.pyc +0 -0
- fairseq/__pycache__/registry.cpython-310.pyc +0 -0
- fairseq/__pycache__/search.cpython-310.pyc +0 -0
- fairseq/__pycache__/sequence_generator.cpython-310.pyc +0 -0
- fairseq/__pycache__/tokenizer.cpython-310.pyc +0 -0
- fairseq/__pycache__/utils.cpython-310.pyc +0 -0
- fairseq/benchmark/__init__.py +12 -0
- fairseq/benchmark/__pycache__/__init__.cpython-310.pyc +0 -0
- fairseq/benchmark/__pycache__/dummy_lm.cpython-310.pyc +0 -0
- fairseq/benchmark/__pycache__/dummy_masked_lm.cpython-310.pyc +0 -0
- fairseq/benchmark/__pycache__/dummy_model.cpython-310.pyc +0 -0
- fairseq/benchmark/__pycache__/dummy_mt.cpython-310.pyc +0 -0
- fairseq/benchmark/dummy_lm.py +118 -0
- fairseq/benchmark/dummy_masked_lm.py +127 -0
- fairseq/benchmark/dummy_model.py +95 -0
- fairseq/benchmark/dummy_mt.py +120 -0
- fairseq/binarizer.py +104 -0
- fairseq/checkpoint_utils.py +522 -0
- fairseq/clib/libbleu/libbleu.cpp +141 -0
- fairseq/clib/libbleu/module.cpp +37 -0
- fairseq/clib/libnat/edit_dist.cpp +231 -0
- fairseq/clib/libnat_cuda/binding.cpp +60 -0
- fairseq/clib/libnat_cuda/edit_dist.cu +332 -0
- fairseq/clib/libnat_cuda/edit_dist.h +25 -0
- fairseq/criterions/__init__.py +24 -0
- fairseq/criterions/__pycache__/__init__.cpython-310.pyc +0 -0
- fairseq/criterions/__pycache__/adaptive_loss.cpython-310.pyc +0 -0
- fairseq/criterions/__pycache__/composite_loss.cpython-310.pyc +0 -0
- fairseq/criterions/__pycache__/cross_entropy.cpython-310.pyc +0 -0
- fairseq/criterions/__pycache__/ctc.cpython-310.pyc +0 -0
- fairseq/criterions/__pycache__/fairseq_criterion.cpython-310.pyc +0 -0
- fairseq/criterions/__pycache__/label_smoothed_cross_entropy.cpython-310.pyc +0 -0
- fairseq/criterions/__pycache__/label_smoothed_cross_entropy_with_alignment.cpython-310.pyc +0 -0
- fairseq/criterions/__pycache__/legacy_masked_lm.cpython-310.pyc +0 -0
- fairseq/criterions/__pycache__/masked_lm.cpython-310.pyc +0 -0
- fairseq/criterions/__pycache__/nat_loss.cpython-310.pyc +0 -0
- fairseq/criterions/__pycache__/sentence_prediction.cpython-310.pyc +0 -0
- fairseq/criterions/__pycache__/sentence_ranking.cpython-310.pyc +0 -0
.gitattributes
CHANGED
|
@@ -40,3 +40,5 @@ fairseq/data/data_utils_fast.cpython-38-darwin.so filter=lfs diff=lfs merge=lfs
|
|
| 40 |
fairseq/data/token_block_utils_fast.cpython-310-darwin.so filter=lfs diff=lfs merge=lfs -text
|
| 41 |
fairseq/data/token_block_utils_fast.cpython-36m-darwin.so filter=lfs diff=lfs merge=lfs -text
|
| 42 |
fairseq/data/token_block_utils_fast.cpython-38-darwin.so filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
| 40 |
fairseq/data/token_block_utils_fast.cpython-310-darwin.so filter=lfs diff=lfs merge=lfs -text
|
| 41 |
fairseq/data/token_block_utils_fast.cpython-36m-darwin.so filter=lfs diff=lfs merge=lfs -text
|
| 42 |
fairseq/data/token_block_utils_fast.cpython-38-darwin.so filter=lfs diff=lfs merge=lfs -text
|
| 43 |
+
fairseq/data/data_utils_fast.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
|
| 44 |
+
fairseq/data/token_block_utils_fast.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
|
fairseq/__init__.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
__all__ = ['pdb']
|
| 7 |
+
__version__ = '0.9.0'
|
| 8 |
+
|
| 9 |
+
import sys
|
| 10 |
+
|
| 11 |
+
# backwards compatibility to support `from fairseq.meters import AverageMeter`
|
| 12 |
+
from fairseq.logging import meters, metrics, progress_bar # noqa
|
| 13 |
+
sys.modules['fairseq.meters'] = meters
|
| 14 |
+
sys.modules['fairseq.metrics'] = metrics
|
| 15 |
+
sys.modules['fairseq.progress_bar'] = progress_bar
|
| 16 |
+
|
| 17 |
+
import fairseq.criterions # noqa
|
| 18 |
+
import fairseq.models # noqa
|
| 19 |
+
import fairseq.modules # noqa
|
| 20 |
+
import fairseq.optim # noqa
|
| 21 |
+
import fairseq.optim.lr_scheduler # noqa
|
| 22 |
+
import fairseq.pdb # noqa
|
| 23 |
+
import fairseq.tasks # noqa
|
| 24 |
+
|
| 25 |
+
import fairseq.benchmark # noqa
|
| 26 |
+
import fairseq.model_parallel # noqa
|
fairseq/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (653 Bytes). View file
|
|
|
fairseq/__pycache__/binarizer.cpython-310.pyc
ADDED
|
Binary file (3.09 kB). View file
|
|
|
fairseq/__pycache__/checkpoint_utils.cpython-310.pyc
ADDED
|
Binary file (13.1 kB). View file
|
|
|
fairseq/__pycache__/distributed_utils.cpython-310.pyc
ADDED
|
Binary file (9.11 kB). View file
|
|
|
fairseq/__pycache__/file_io.cpython-310.pyc
ADDED
|
Binary file (2.84 kB). View file
|
|
|
fairseq/__pycache__/file_utils.cpython-310.pyc
ADDED
|
Binary file (8.66 kB). View file
|
|
|
fairseq/__pycache__/incremental_decoding_utils.cpython-310.pyc
ADDED
|
Binary file (2.19 kB). View file
|
|
|
fairseq/__pycache__/iterative_refinement_generator.cpython-310.pyc
ADDED
|
Binary file (8.44 kB). View file
|
|
|
fairseq/__pycache__/legacy_distributed_data_parallel.cpython-310.pyc
ADDED
|
Binary file (4.99 kB). View file
|
|
|
fairseq/__pycache__/options.cpython-310.pyc
ADDED
|
Binary file (23.6 kB). View file
|
|
|
fairseq/__pycache__/pdb.cpython-310.pyc
ADDED
|
Binary file (1.31 kB). View file
|
|
|
fairseq/__pycache__/registry.cpython-310.pyc
ADDED
|
Binary file (2.05 kB). View file
|
|
|
fairseq/__pycache__/search.cpython-310.pyc
ADDED
|
Binary file (9.99 kB). View file
|
|
|
fairseq/__pycache__/sequence_generator.cpython-310.pyc
ADDED
|
Binary file (24.4 kB). View file
|
|
|
fairseq/__pycache__/tokenizer.cpython-310.pyc
ADDED
|
Binary file (359 Bytes). View file
|
|
|
fairseq/__pycache__/utils.cpython-310.pyc
ADDED
|
Binary file (18.1 kB). View file
|
|
|
fairseq/benchmark/__init__.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
# import models/tasks to register them
|
| 7 |
+
from . import ( # noqa
|
| 8 |
+
dummy_lm,
|
| 9 |
+
dummy_masked_lm,
|
| 10 |
+
dummy_model,
|
| 11 |
+
dummy_mt,
|
| 12 |
+
)
|
fairseq/benchmark/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (247 Bytes). View file
|
|
|
fairseq/benchmark/__pycache__/dummy_lm.cpython-310.pyc
ADDED
|
Binary file (4.4 kB). View file
|
|
|
fairseq/benchmark/__pycache__/dummy_masked_lm.cpython-310.pyc
ADDED
|
Binary file (4.61 kB). View file
|
|
|
fairseq/benchmark/__pycache__/dummy_model.cpython-310.pyc
ADDED
|
Binary file (3.4 kB). View file
|
|
|
fairseq/benchmark/__pycache__/dummy_mt.cpython-310.pyc
ADDED
|
Binary file (4.45 kB). View file
|
|
|
fairseq/benchmark/dummy_lm.py
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import logging
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
from fairseq.data import Dictionary, FairseqDataset
|
| 12 |
+
from fairseq.tasks import FairseqTask, register_task
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
logger = logging.getLogger(__name__)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@register_task('dummy_lm')
|
| 19 |
+
class DummyLMTask(FairseqTask):
|
| 20 |
+
|
| 21 |
+
@staticmethod
|
| 22 |
+
def add_args(parser):
|
| 23 |
+
"""Add task-specific arguments to the parser."""
|
| 24 |
+
parser.add_argument('--dict-size', default=49996, type=int)
|
| 25 |
+
parser.add_argument('--dataset-size', default=100000, type=int)
|
| 26 |
+
parser.add_argument('--tokens-per-sample', default=512, type=int,
|
| 27 |
+
help='max number of total tokens over all segments '
|
| 28 |
+
'per sample for BERT dataset')
|
| 29 |
+
|
| 30 |
+
def __init__(self, args, dictionary):
|
| 31 |
+
super().__init__(args)
|
| 32 |
+
self.dictionary = dictionary
|
| 33 |
+
self.seed = args.seed
|
| 34 |
+
|
| 35 |
+
dictionary.pad_to_multiple_(8) # often faster if divisible by 8
|
| 36 |
+
|
| 37 |
+
seq = torch.arange(args.tokens_per_sample + 1) + dictionary.pad() + 1
|
| 38 |
+
|
| 39 |
+
self.dummy_src = seq[:-1]
|
| 40 |
+
self.dummy_tgt = seq[1:]
|
| 41 |
+
|
| 42 |
+
@classmethod
|
| 43 |
+
def setup_task(cls, args, **kwargs):
|
| 44 |
+
"""Setup the task. """
|
| 45 |
+
dictionary = Dictionary()
|
| 46 |
+
for i in range(args.dict_size):
|
| 47 |
+
dictionary.add_symbol('word{}'.format(i))
|
| 48 |
+
logger.info('dictionary: {} types'.format(len(dictionary)))
|
| 49 |
+
return cls(args, dictionary)
|
| 50 |
+
|
| 51 |
+
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
|
| 52 |
+
"""Load a given dataset split.
|
| 53 |
+
Args:
|
| 54 |
+
split (str): name of the split (e.g., train, valid, test)
|
| 55 |
+
"""
|
| 56 |
+
if self.args.max_sentences is not None:
|
| 57 |
+
bsz = self.args.max_sentences
|
| 58 |
+
else:
|
| 59 |
+
bsz = max(1, self.args.max_tokens // self.args.tokens_per_sample)
|
| 60 |
+
self.datasets[split] = DummyDataset(
|
| 61 |
+
{
|
| 62 |
+
'id': 1,
|
| 63 |
+
'net_input': {
|
| 64 |
+
'src_tokens': torch.stack([self.dummy_src for _ in range(bsz)]),
|
| 65 |
+
'src_lengths': torch.full(
|
| 66 |
+
(bsz, ), self.args.tokens_per_sample, dtype=torch.long
|
| 67 |
+
),
|
| 68 |
+
},
|
| 69 |
+
'target': torch.stack([self.dummy_tgt for _ in range(bsz)]),
|
| 70 |
+
'nsentences': bsz,
|
| 71 |
+
'ntokens': bsz * self.args.tokens_per_sample,
|
| 72 |
+
},
|
| 73 |
+
num_items=self.args.dataset_size,
|
| 74 |
+
item_size=self.args.tokens_per_sample,
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
@property
|
| 78 |
+
def source_dictionary(self):
|
| 79 |
+
return self.dictionary
|
| 80 |
+
|
| 81 |
+
@property
|
| 82 |
+
def target_dictionary(self):
|
| 83 |
+
return self.dictionary
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
class DummyDataset(FairseqDataset):
|
| 87 |
+
|
| 88 |
+
def __init__(self, batch, num_items, item_size):
|
| 89 |
+
super().__init__()
|
| 90 |
+
self.batch = batch
|
| 91 |
+
self.num_items = num_items
|
| 92 |
+
self.item_size = item_size
|
| 93 |
+
|
| 94 |
+
def __getitem__(self, index):
|
| 95 |
+
return index
|
| 96 |
+
|
| 97 |
+
def __len__(self):
|
| 98 |
+
return self.num_items
|
| 99 |
+
|
| 100 |
+
def collater(self, samples):
|
| 101 |
+
return self.batch
|
| 102 |
+
|
| 103 |
+
@property
|
| 104 |
+
def sizes(self):
|
| 105 |
+
return np.array([self.item_size] * self.num_items)
|
| 106 |
+
|
| 107 |
+
def num_tokens(self, index):
|
| 108 |
+
return self.item_size
|
| 109 |
+
|
| 110 |
+
def size(self, index):
|
| 111 |
+
return self.item_size
|
| 112 |
+
|
| 113 |
+
def ordered_indices(self):
|
| 114 |
+
return np.arange(self.num_items)
|
| 115 |
+
|
| 116 |
+
@property
|
| 117 |
+
def supports_prefetch(self):
|
| 118 |
+
return False
|
fairseq/benchmark/dummy_masked_lm.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import logging
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
from fairseq.data import Dictionary, FairseqDataset
|
| 12 |
+
from fairseq.tasks import FairseqTask, register_task
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
logger = logging.getLogger(__name__)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@register_task('dummy_masked_lm')
|
| 19 |
+
class DummyMaskedLMTask(FairseqTask):
|
| 20 |
+
|
| 21 |
+
@staticmethod
|
| 22 |
+
def add_args(parser):
|
| 23 |
+
"""Add task-specific arguments to the parser."""
|
| 24 |
+
parser.add_argument('--dict-size', default=49995, type=int)
|
| 25 |
+
parser.add_argument('--dataset-size', default=100000, type=int)
|
| 26 |
+
parser.add_argument('--tokens-per-sample', default=512, type=int,
|
| 27 |
+
help='max number of total tokens over all segments '
|
| 28 |
+
'per sample for BERT dataset')
|
| 29 |
+
|
| 30 |
+
def __init__(self, args, dictionary):
|
| 31 |
+
super().__init__(args)
|
| 32 |
+
self.dictionary = dictionary
|
| 33 |
+
self.seed = args.seed
|
| 34 |
+
|
| 35 |
+
# add mask token
|
| 36 |
+
self.mask_idx = dictionary.add_symbol('<mask>')
|
| 37 |
+
dictionary.pad_to_multiple_(8) # often faster if divisible by 8
|
| 38 |
+
|
| 39 |
+
mask_idx = 0
|
| 40 |
+
pad_idx = 1
|
| 41 |
+
seq = torch.arange(args.tokens_per_sample) + pad_idx + 1
|
| 42 |
+
mask = torch.arange(2, args.tokens_per_sample, 7) # ~15%
|
| 43 |
+
src = seq.clone()
|
| 44 |
+
src[mask] = mask_idx
|
| 45 |
+
tgt = torch.full_like(seq, pad_idx)
|
| 46 |
+
tgt[mask] = seq[mask]
|
| 47 |
+
|
| 48 |
+
self.dummy_src = src
|
| 49 |
+
self.dummy_tgt = tgt
|
| 50 |
+
|
| 51 |
+
@classmethod
|
| 52 |
+
def setup_task(cls, args, **kwargs):
|
| 53 |
+
"""Setup the task. """
|
| 54 |
+
dictionary = Dictionary()
|
| 55 |
+
for i in range(args.dict_size):
|
| 56 |
+
dictionary.add_symbol('word{}'.format(i))
|
| 57 |
+
logger.info('dictionary: {} types'.format(len(dictionary)))
|
| 58 |
+
return cls(args, dictionary)
|
| 59 |
+
|
| 60 |
+
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
|
| 61 |
+
"""Load a given dataset split.
|
| 62 |
+
Args:
|
| 63 |
+
split (str): name of the split (e.g., train, valid, test)
|
| 64 |
+
"""
|
| 65 |
+
if self.args.max_sentences is not None:
|
| 66 |
+
bsz = self.args.max_sentences
|
| 67 |
+
else:
|
| 68 |
+
bsz = max(1, self.args.max_tokens // self.args.tokens_per_sample)
|
| 69 |
+
self.datasets[split] = DummyDataset(
|
| 70 |
+
{
|
| 71 |
+
'id': 1,
|
| 72 |
+
'net_input': {
|
| 73 |
+
'src_tokens': torch.stack([self.dummy_src for _ in range(bsz)]),
|
| 74 |
+
'src_lengths': torch.full(
|
| 75 |
+
(bsz, ), self.args.tokens_per_sample, dtype=torch.long
|
| 76 |
+
),
|
| 77 |
+
},
|
| 78 |
+
'target': torch.stack([self.dummy_tgt for _ in range(bsz)]),
|
| 79 |
+
'nsentences': bsz,
|
| 80 |
+
'ntokens': bsz * self.args.tokens_per_sample,
|
| 81 |
+
},
|
| 82 |
+
num_items=self.args.dataset_size,
|
| 83 |
+
item_size=self.args.tokens_per_sample,
|
| 84 |
+
)
|
| 85 |
+
|
| 86 |
+
@property
|
| 87 |
+
def source_dictionary(self):
|
| 88 |
+
return self.dictionary
|
| 89 |
+
|
| 90 |
+
@property
|
| 91 |
+
def target_dictionary(self):
|
| 92 |
+
return self.dictionary
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
class DummyDataset(FairseqDataset):
|
| 96 |
+
|
| 97 |
+
def __init__(self, batch, num_items, item_size):
|
| 98 |
+
super().__init__()
|
| 99 |
+
self.batch = batch
|
| 100 |
+
self.num_items = num_items
|
| 101 |
+
self.item_size = item_size
|
| 102 |
+
|
| 103 |
+
def __getitem__(self, index):
|
| 104 |
+
return index
|
| 105 |
+
|
| 106 |
+
def __len__(self):
|
| 107 |
+
return self.num_items
|
| 108 |
+
|
| 109 |
+
def collater(self, samples):
|
| 110 |
+
return self.batch
|
| 111 |
+
|
| 112 |
+
@property
|
| 113 |
+
def sizes(self):
|
| 114 |
+
return np.array([self.item_size] * self.num_items)
|
| 115 |
+
|
| 116 |
+
def num_tokens(self, index):
|
| 117 |
+
return self.item_size
|
| 118 |
+
|
| 119 |
+
def size(self, index):
|
| 120 |
+
return self.item_size
|
| 121 |
+
|
| 122 |
+
def ordered_indices(self):
|
| 123 |
+
return np.arange(self.num_items)
|
| 124 |
+
|
| 125 |
+
@property
|
| 126 |
+
def supports_prefetch(self):
|
| 127 |
+
return False
|
fairseq/benchmark/dummy_model.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
|
| 9 |
+
from fairseq.data import Dictionary
|
| 10 |
+
from fairseq.models import (
|
| 11 |
+
FairseqDecoder,
|
| 12 |
+
FairseqLanguageModel,
|
| 13 |
+
register_model,
|
| 14 |
+
register_model_architecture,
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@register_model('dummy_model')
|
| 19 |
+
class DummyModel(FairseqLanguageModel):
|
| 20 |
+
|
| 21 |
+
def __init__(self, args, encoder):
|
| 22 |
+
super().__init__(encoder)
|
| 23 |
+
self.args = args
|
| 24 |
+
|
| 25 |
+
@staticmethod
|
| 26 |
+
def add_args(parser):
|
| 27 |
+
parser.add_argument('--num-layers', type=int, default=24)
|
| 28 |
+
parser.add_argument('--embed-dim', type=int, default=1024)
|
| 29 |
+
|
| 30 |
+
@classmethod
|
| 31 |
+
def build_model(cls, args, task):
|
| 32 |
+
encoder = DummyEncoder(
|
| 33 |
+
num_embed=len(task.target_dictionary),
|
| 34 |
+
embed_dim=args.embed_dim,
|
| 35 |
+
num_layers=args.num_layers,
|
| 36 |
+
)
|
| 37 |
+
return cls(args, encoder)
|
| 38 |
+
|
| 39 |
+
def forward(self, src_tokens, masked_tokens=None, **kwargs):
|
| 40 |
+
return self.decoder(src_tokens, masked_tokens=masked_tokens)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class DummyEncoder(FairseqDecoder):
|
| 44 |
+
|
| 45 |
+
def __init__(self, num_embed=50000, embed_dim=1024, num_layers=24):
|
| 46 |
+
super().__init__(Dictionary())
|
| 47 |
+
self.embed = nn.Embedding(
|
| 48 |
+
num_embeddings=num_embed, embedding_dim=embed_dim, padding_idx=0
|
| 49 |
+
)
|
| 50 |
+
self.layers_a = nn.ModuleList([
|
| 51 |
+
nn.Sequential(
|
| 52 |
+
nn.LayerNorm(embed_dim),
|
| 53 |
+
nn.Linear(embed_dim, 3*embed_dim), # q, k, v input projection
|
| 54 |
+
nn.Linear(3*embed_dim, embed_dim), # skip self-attention
|
| 55 |
+
nn.Linear(embed_dim, embed_dim), # output projection
|
| 56 |
+
nn.Dropout(),
|
| 57 |
+
)
|
| 58 |
+
for i in range(num_layers)
|
| 59 |
+
])
|
| 60 |
+
self.layers_b = nn.ModuleList([
|
| 61 |
+
nn.Sequential(
|
| 62 |
+
nn.LayerNorm(embed_dim),
|
| 63 |
+
nn.Linear(embed_dim, 4*embed_dim), # FFN
|
| 64 |
+
nn.ReLU(),
|
| 65 |
+
nn.Linear(4*embed_dim, embed_dim), # FFN
|
| 66 |
+
nn.Dropout(0.1),
|
| 67 |
+
)
|
| 68 |
+
for i in range(num_layers)
|
| 69 |
+
])
|
| 70 |
+
self.out_proj = nn.Linear(embed_dim, num_embed)
|
| 71 |
+
|
| 72 |
+
def forward(self, tokens, masked_tokens=None):
|
| 73 |
+
x = self.embed(tokens)
|
| 74 |
+
for layer_a, layer_b in zip(self.layers_a, self.layers_b):
|
| 75 |
+
x = x + layer_a(x)
|
| 76 |
+
x = x + layer_b(x)
|
| 77 |
+
x = self.out_proj(x)
|
| 78 |
+
if masked_tokens is not None:
|
| 79 |
+
x = x[masked_tokens]
|
| 80 |
+
return (x,)
|
| 81 |
+
|
| 82 |
+
def max_positions(self):
|
| 83 |
+
return 1024
|
| 84 |
+
|
| 85 |
+
def get_normalized_probs(self, net_output, log_probs, sample=None):
|
| 86 |
+
logits = net_output[0].float()
|
| 87 |
+
if log_probs:
|
| 88 |
+
return F.log_softmax(logits, dim=-1)
|
| 89 |
+
else:
|
| 90 |
+
return F.softmax(logits, dim=-1)
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
@register_model_architecture('dummy_model', 'dummy_model')
|
| 94 |
+
def base_architecture(args):
|
| 95 |
+
pass
|
fairseq/benchmark/dummy_mt.py
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import logging
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
from fairseq.data import Dictionary, FairseqDataset
|
| 12 |
+
from fairseq.tasks import FairseqTask, register_task
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
logger = logging.getLogger(__name__)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@register_task('dummy_mt')
|
| 19 |
+
class DummyMTTask(FairseqTask):
|
| 20 |
+
|
| 21 |
+
@staticmethod
|
| 22 |
+
def add_args(parser):
|
| 23 |
+
"""Add task-specific arguments to the parser."""
|
| 24 |
+
parser.add_argument('--dict-size', default=49996, type=int)
|
| 25 |
+
parser.add_argument('--dataset-size', default=100000, type=int)
|
| 26 |
+
parser.add_argument('--tokens-per-sample', default=512, type=int,
|
| 27 |
+
help='max number of total tokens over all segments '
|
| 28 |
+
'per sample for BERT dataset')
|
| 29 |
+
|
| 30 |
+
def __init__(self, args, dictionary):
|
| 31 |
+
super().__init__(args)
|
| 32 |
+
self.dictionary = dictionary
|
| 33 |
+
self.seed = args.seed
|
| 34 |
+
|
| 35 |
+
dictionary.pad_to_multiple_(8) # often faster if divisible by 8
|
| 36 |
+
|
| 37 |
+
seq = torch.arange(args.tokens_per_sample + 1) + dictionary.pad() + 1
|
| 38 |
+
|
| 39 |
+
self.dummy_src = seq[:-1]
|
| 40 |
+
self.dummy_tgt = seq[1:]
|
| 41 |
+
|
| 42 |
+
@classmethod
|
| 43 |
+
def setup_task(cls, args, **kwargs):
|
| 44 |
+
"""Setup the task. """
|
| 45 |
+
dictionary = Dictionary()
|
| 46 |
+
for i in range(args.dict_size):
|
| 47 |
+
dictionary.add_symbol('word{}'.format(i))
|
| 48 |
+
logger.info('dictionary: {} types'.format(len(dictionary)))
|
| 49 |
+
return cls(args, dictionary)
|
| 50 |
+
|
| 51 |
+
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
|
| 52 |
+
"""Load a given dataset split.
|
| 53 |
+
Args:
|
| 54 |
+
split (str): name of the split (e.g., train, valid, test)
|
| 55 |
+
"""
|
| 56 |
+
if self.args.max_sentences is not None:
|
| 57 |
+
bsz = self.args.max_sentences
|
| 58 |
+
else:
|
| 59 |
+
bsz = max(1, self.args.max_tokens // self.args.tokens_per_sample)
|
| 60 |
+
tgt = torch.stack([self.dummy_tgt for _ in range(bsz)])
|
| 61 |
+
self.datasets[split] = DummyDataset(
|
| 62 |
+
{
|
| 63 |
+
'id': 1,
|
| 64 |
+
'net_input': {
|
| 65 |
+
'src_tokens': torch.stack([self.dummy_src for _ in range(bsz)]),
|
| 66 |
+
'src_lengths': torch.full(
|
| 67 |
+
(bsz, ), self.args.tokens_per_sample, dtype=torch.long
|
| 68 |
+
),
|
| 69 |
+
'prev_output_tokens': tgt.clone(),
|
| 70 |
+
},
|
| 71 |
+
'target': tgt,
|
| 72 |
+
'nsentences': bsz,
|
| 73 |
+
'ntokens': bsz * self.args.tokens_per_sample,
|
| 74 |
+
},
|
| 75 |
+
num_items=self.args.dataset_size,
|
| 76 |
+
item_size=self.args.tokens_per_sample,
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
@property
|
| 80 |
+
def source_dictionary(self):
|
| 81 |
+
return self.dictionary
|
| 82 |
+
|
| 83 |
+
@property
|
| 84 |
+
def target_dictionary(self):
|
| 85 |
+
return self.dictionary
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
class DummyDataset(FairseqDataset):
|
| 89 |
+
|
| 90 |
+
def __init__(self, batch, num_items, item_size):
|
| 91 |
+
super().__init__()
|
| 92 |
+
self.batch = batch
|
| 93 |
+
self.num_items = num_items
|
| 94 |
+
self.item_size = item_size
|
| 95 |
+
|
| 96 |
+
def __getitem__(self, index):
|
| 97 |
+
return index
|
| 98 |
+
|
| 99 |
+
def __len__(self):
|
| 100 |
+
return self.num_items
|
| 101 |
+
|
| 102 |
+
def collater(self, samples):
|
| 103 |
+
return self.batch
|
| 104 |
+
|
| 105 |
+
@property
|
| 106 |
+
def sizes(self):
|
| 107 |
+
return np.array([self.item_size] * self.num_items)
|
| 108 |
+
|
| 109 |
+
def num_tokens(self, index):
|
| 110 |
+
return self.item_size
|
| 111 |
+
|
| 112 |
+
def size(self, index):
|
| 113 |
+
return self.item_size
|
| 114 |
+
|
| 115 |
+
def ordered_indices(self):
|
| 116 |
+
return np.arange(self.num_items)
|
| 117 |
+
|
| 118 |
+
@property
|
| 119 |
+
def supports_prefetch(self):
|
| 120 |
+
return False
|
fairseq/binarizer.py
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
from collections import Counter
|
| 8 |
+
|
| 9 |
+
from fairseq.tokenizer import tokenize_line
|
| 10 |
+
import torch
|
| 11 |
+
from fairseq.file_io import PathManager
|
| 12 |
+
|
| 13 |
+
def safe_readline(f):
|
| 14 |
+
pos = f.tell()
|
| 15 |
+
while True:
|
| 16 |
+
try:
|
| 17 |
+
return f.readline()
|
| 18 |
+
except UnicodeDecodeError:
|
| 19 |
+
pos -= 1
|
| 20 |
+
f.seek(pos) # search where this character begins
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class Binarizer:
|
| 24 |
+
@staticmethod
|
| 25 |
+
def binarize(
|
| 26 |
+
filename,
|
| 27 |
+
dict,
|
| 28 |
+
consumer,
|
| 29 |
+
tokenize=tokenize_line,
|
| 30 |
+
append_eos=True,
|
| 31 |
+
reverse_order=False,
|
| 32 |
+
offset=0,
|
| 33 |
+
end=-1,
|
| 34 |
+
already_numberized=False,
|
| 35 |
+
):
|
| 36 |
+
nseq, ntok = 0, 0
|
| 37 |
+
replaced = Counter()
|
| 38 |
+
|
| 39 |
+
def replaced_consumer(word, idx):
|
| 40 |
+
if idx == dict.unk_index and word != dict.unk_word:
|
| 41 |
+
replaced.update([word])
|
| 42 |
+
|
| 43 |
+
with open(PathManager.get_local_path(filename), "r", encoding="utf-8") as f:
|
| 44 |
+
f.seek(offset)
|
| 45 |
+
# next(f) breaks f.tell(), hence readline() must be used
|
| 46 |
+
line = safe_readline(f)
|
| 47 |
+
while line:
|
| 48 |
+
if end > 0 and f.tell() > end:
|
| 49 |
+
break
|
| 50 |
+
if already_numberized:
|
| 51 |
+
id_strings = line.strip().split()
|
| 52 |
+
id_list = [int(id_string) for id_string in id_strings]
|
| 53 |
+
if reverse_order:
|
| 54 |
+
id_list.reverse()
|
| 55 |
+
if append_eos:
|
| 56 |
+
id_list.append(dict.eos())
|
| 57 |
+
ids = torch.IntTensor(id_list)
|
| 58 |
+
else:
|
| 59 |
+
ids = dict.encode_line(
|
| 60 |
+
line=line,
|
| 61 |
+
line_tokenizer=tokenize,
|
| 62 |
+
add_if_not_exist=False,
|
| 63 |
+
consumer=replaced_consumer,
|
| 64 |
+
append_eos=append_eos,
|
| 65 |
+
reverse_order=reverse_order,
|
| 66 |
+
)
|
| 67 |
+
nseq += 1
|
| 68 |
+
ntok += len(ids)
|
| 69 |
+
consumer(ids)
|
| 70 |
+
line = f.readline()
|
| 71 |
+
return {
|
| 72 |
+
"nseq": nseq,
|
| 73 |
+
"nunk": sum(replaced.values()),
|
| 74 |
+
"ntok": ntok,
|
| 75 |
+
"replaced": replaced,
|
| 76 |
+
}
|
| 77 |
+
|
| 78 |
+
@staticmethod
|
| 79 |
+
def binarize_alignments(filename, alignment_parser, consumer, offset=0, end=-1):
|
| 80 |
+
nseq = 0
|
| 81 |
+
|
| 82 |
+
with open(PathManager.get_local_path(filename), "r") as f:
|
| 83 |
+
f.seek(offset)
|
| 84 |
+
line = safe_readline(f)
|
| 85 |
+
while line:
|
| 86 |
+
if end > 0 and f.tell() > end:
|
| 87 |
+
break
|
| 88 |
+
ids = alignment_parser(line)
|
| 89 |
+
nseq += 1
|
| 90 |
+
consumer(ids)
|
| 91 |
+
line = f.readline()
|
| 92 |
+
return {"nseq": nseq}
|
| 93 |
+
|
| 94 |
+
@staticmethod
|
| 95 |
+
def find_offsets(filename, num_chunks):
|
| 96 |
+
with open(PathManager.get_local_path(filename), "r", encoding="utf-8") as f:
|
| 97 |
+
size = os.fstat(f.fileno()).st_size
|
| 98 |
+
chunk_size = size // num_chunks
|
| 99 |
+
offsets = [0 for _ in range(num_chunks + 1)]
|
| 100 |
+
for i in range(1, num_chunks):
|
| 101 |
+
f.seek(chunk_size * i)
|
| 102 |
+
safe_readline(f)
|
| 103 |
+
offsets[i] = f.tell()
|
| 104 |
+
return offsets
|
fairseq/checkpoint_utils.py
ADDED
|
@@ -0,0 +1,522 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import collections
|
| 7 |
+
import logging
|
| 8 |
+
import os
|
| 9 |
+
import re
|
| 10 |
+
import traceback
|
| 11 |
+
from collections import OrderedDict
|
| 12 |
+
from typing import Union
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
from fairseq.file_io import PathManager
|
| 16 |
+
from fairseq.models import FairseqDecoder, FairseqEncoder
|
| 17 |
+
from torch.serialization import default_restore_location
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
logger = logging.getLogger(__name__)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def save_checkpoint(args, trainer, epoch_itr, val_loss):
    """Save one or more training checkpoints and prune stale ones.

    Which checkpoint filenames get written is driven by flags on *args*
    (epoch/update intervals, best-metric tracking, keep-* settings).  The
    first selected filename is written once and the remaining ones are
    copies of it.  The best validation score seen so far is tracked on the
    function attribute ``save_checkpoint.best``.
    """
    # NOTE(review): distributed_utils appears unused in this function
    from fairseq import distributed_utils, meters

    # only one worker should attempt to create the required dir
    if args.distributed_rank == 0:
        os.makedirs(args.save_dir, exist_ok=True)

    # update the running best metric; on the very first call val_loss
    # itself becomes the initial "best"
    prev_best = getattr(save_checkpoint, "best", val_loss)
    if val_loss is not None:
        best_function = max if args.maximize_best_checkpoint_metric else min
        save_checkpoint.best = best_function(val_loss, prev_best)

    # only the data-parallel master actually writes checkpoints
    if args.no_save or not trainer.is_data_parallel_master:
        return

    def is_better(a, b):
        # ties count as "better" so equal scores refresh checkpoint_best
        return a >= b if args.maximize_best_checkpoint_metric else a <= b

    write_timer = meters.StopwatchMeter()
    write_timer.start()

    epoch = epoch_itr.epoch
    end_of_epoch = epoch_itr.end_of_epoch()
    updates = trainer.get_num_updates()

    suffix = getattr(args, "checkpoint_suffix", "")
    # candidate filename -> whether it should be written on this call
    checkpoint_conds = collections.OrderedDict()
    checkpoint_conds["checkpoint{}{}.pt".format(epoch, suffix)] = (
        end_of_epoch
        and not args.no_epoch_checkpoints
        and epoch % args.save_interval == 0
    )
    checkpoint_conds["checkpoint_{}_{}{}.pt".format(epoch, updates, suffix)] = (
        not end_of_epoch
        and args.save_interval_updates > 0
        and updates % args.save_interval_updates == 0
    )
    checkpoint_conds["checkpoint_best{}.pt".format(suffix)] = val_loss is not None and (
        not hasattr(save_checkpoint, "best")
        or is_better(val_loss, save_checkpoint.best)
    )
    if val_loss is not None and args.keep_best_checkpoints > 0:
        # metric-stamped copy; pruned below so only the N best remain
        checkpoint_conds["checkpoint.best_{}_{:.2f}.pt".format(
            args.best_checkpoint_metric, val_loss)] = (
            not hasattr(save_checkpoint, "best")
            or is_better(val_loss, save_checkpoint.best)
        )
    checkpoint_conds["checkpoint_last{}.pt".format(suffix)] = not args.no_last_checkpoints

    extra_state = {"train_iterator": epoch_itr.state_dict(), "val_loss": val_loss}
    if hasattr(save_checkpoint, "best"):
        extra_state.update({"best": save_checkpoint.best})

    checkpoints = [
        os.path.join(args.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond
    ]
    if len(checkpoints) > 0:
        # write once, then copy to the other selected filenames
        trainer.save_checkpoint(checkpoints[0], extra_state)
        for cp in checkpoints[1:]:
            PathManager.copy(checkpoints[0], cp, overwrite=True)

        write_timer.stop()
        logger.info(
            "saved checkpoint {} (epoch {} @ {} updates, score {}) (writing took {} seconds)".format(
                checkpoints[0], epoch, updates, val_loss, write_timer.sum
            )
        )

    if not end_of_epoch and args.keep_interval_updates > 0:
        # remove old checkpoints; checkpoints are sorted in descending order
        checkpoints = checkpoint_paths(
            args.save_dir, pattern=r"checkpoint_\d+_(\d+)\.pt"
        )
        for old_chk in checkpoints[args.keep_interval_updates :]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)

    if args.keep_last_epochs > 0:
        # remove old epoch checkpoints; checkpoints are sorted in descending order
        checkpoints = checkpoint_paths(args.save_dir, pattern=r"checkpoint(\d+)\.pt")
        for old_chk in checkpoints[args.keep_last_epochs :]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)

    if args.keep_best_checkpoints > 0:
        # only keep the best N checkpoints according to validation metric
        checkpoints = checkpoint_paths(
            args.save_dir, pattern=r"checkpoint\.best_{}_(\d+\.?\d*)\.pt".format(args.best_checkpoint_metric))
        if not args.maximize_best_checkpoint_metric:
            # checkpoint_paths sorts descending by score; for minimized
            # metrics the best (smallest) scores come last, so reverse
            # before slicing off the tail
            checkpoints = checkpoints[::-1]
        for old_chk in checkpoints[args.keep_best_checkpoints:]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def load_checkpoint(args, trainer, **passthrough_args):
    """
    Load a checkpoint and restore the training iterator.

    *passthrough_args* will be passed through to
    ``trainer.get_train_iterator``.

    Returns:
        tuple of ``(extra_state, epoch_itr)``: the loaded checkpoint's
        extra state (or None if nothing was loaded) and the restored (or
        fresh) training iterator.
    """
    reset_optimizer = args.reset_optimizer
    reset_lr_scheduler = args.reset_lr_scheduler
    # SECURITY NOTE: eval() of a string — acceptable only because the value
    # comes from the local command line, never from untrusted input
    optimizer_overrides = eval(args.optimizer_overrides)
    reset_meters = args.reset_meters
    reset_dataloader = args.reset_dataloader

    # --finetune-from-model implies a clean training state, so it cannot be
    # combined with the explicit reset flags
    if getattr(args, 'finetune_from_model', None) is not None \
       and (reset_optimizer or reset_lr_scheduler or reset_meters or reset_dataloader):
        raise ValueError("--finetune-from-model can not be set together with either --reset-optimizer"
                         " or reset_lr_scheduler or reset_meters or reset_dataloader")

    suffix = getattr(args, "checkpoint_suffix", "")
    if args.restore_file == "checkpoint_last.pt":  # default value of restore_file is 'checkpoint_last.pt'
        checkpoint_path = os.path.join(args.save_dir, "checkpoint_last{}.pt".format(suffix))
        first_launch = not PathManager.exists(checkpoint_path)
        if getattr(args, 'finetune_from_model', None) is not None and first_launch:
            # if there is no last checkpoint to restore, start the finetune from pretrained model
            # else just use usual logic to load checkpoint, e.g. restart from last checkpoint and etc.
            if PathManager.exists(args.finetune_from_model):
                checkpoint_path = args.finetune_from_model
                reset_optimizer = True
                reset_lr_scheduler = True
                reset_meters = True
                reset_dataloader = True
                logger.info(f'loading pretrained model from {checkpoint_path}: '
                            'optimizer, lr scheduler, meters, dataloader will be reset')
            else:
                # fixed typo in option name ("--funetune" -> "--finetune")
                raise ValueError(f'--finetune-from-model {args.finetune_from_model} does not exist')
    elif getattr(args, "model_parallel_size", 1) > 1:
        # model-parallel shards carry a per-rank suffix in the filename
        checkpoint_path = args.restore_file.replace(".pt", suffix + ".pt")
    else:
        checkpoint_path = args.restore_file

    if args.restore_file != "checkpoint_last.pt" and getattr(args, 'finetune_from_model', None):
        raise ValueError(
            '--finetune-from-model and --restore-file (non-default value) '
            'can not be specified together: ' + str(args))

    extra_state = trainer.load_checkpoint(
        checkpoint_path,
        reset_optimizer,
        reset_lr_scheduler,
        optimizer_overrides,
        reset_meters=reset_meters,
    )

    # carry over the best validation score unless the caller asked to reset
    if (
        extra_state is not None
        and "best" in extra_state
        and not reset_optimizer
        and not reset_meters
    ):
        save_checkpoint.best = extra_state["best"]

    if extra_state is not None and not reset_dataloader:
        # restore iterator from checkpoint
        itr_state = extra_state["train_iterator"]
        epoch_itr = trainer.get_train_iterator(
            epoch=itr_state["epoch"], load_dataset=True, **passthrough_args
        )
        epoch_itr.load_state_dict(itr_state)
    else:
        epoch_itr = trainer.get_train_iterator(
            epoch=1, load_dataset=True, **passthrough_args
        )

    trainer.lr_step(epoch_itr.epoch)

    return extra_state, epoch_itr
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
def load_checkpoint_to_cpu(path, arg_overrides=None):
    """Loads a checkpoint to CPU (with upgrading for backward compatibility).

    Args:
        path (str): path to the checkpoint file
        arg_overrides (Dict[str, Any], optional): attribute overrides to
            apply to the checkpoint's stored ``args`` namespace before
            upgrading

    Returns:
        the (upgraded) checkpoint state dict
    """
    # SECURITY NOTE: torch.load unpickles arbitrary objects — only load
    # checkpoints from trusted sources.
    with PathManager.open(path, "rb") as f:
        state = torch.load(
            f, map_location=lambda s, l: default_restore_location(s, "cpu")
        )

    args = state["args"]
    if arg_overrides is not None:
        for arg_name, arg_val in arg_overrides.items():
            setattr(args, arg_name, arg_val)
    # back-fill fields missing from checkpoints written by older fairseq
    state = _upgrade_state_dict(state)
    return state
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
def load_model_ensemble(filenames, arg_overrides=None, task=None, strict=True, suffix=''):
    """Load an ensemble of models from a list of checkpoint files.

    Thin wrapper around :func:`load_model_ensemble_and_task` that drops
    the task from the return value.

    Args:
        filenames (List[str]): checkpoint files to load
        arg_overrides (Dict[str,Any], optional): override model args that
            were used during model training
        task (fairseq.tasks.FairseqTask, optional): task to use for loading
    """
    models, model_args, _ = load_model_ensemble_and_task(
        filenames, arg_overrides, task, strict, suffix,
    )
    return models, model_args
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
def load_model_ensemble_and_task(filenames, arg_overrides=None, task=None, strict=True, suffix=''):
    """Load an ensemble of models and return ``(models, args, task)``.

    ``args`` are the (possibly overridden) training args from the last
    checkpoint loaded; when *task* is not given it is built from the first
    checkpoint's args.
    """
    from fairseq import tasks

    models = []
    for ckpt_path in filenames:
        # honor per-shard checkpoint suffixes (e.g. model parallelism)
        ckpt_path = ckpt_path.replace(".pt", suffix + ".pt")
        if not PathManager.exists(ckpt_path):
            raise IOError("Model file not found: {}".format(ckpt_path))
        state = load_checkpoint_to_cpu(ckpt_path, arg_overrides)

        args = state["args"]
        if task is None:
            # lazily build the task from the first checkpoint's args
            task = tasks.setup_task(args)

        # build a fresh model for this ensemble member and load its weights
        model = task.build_model(args)
        model.load_state_dict(state["model"], strict=strict, args=args)
        models.append(model)
    return models, args, task
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
def checkpoint_paths(path, pattern=r"checkpoint(\d+)\.pt"):
    """Retrieve all checkpoints found in the `path` directory.

    Checkpoints are identified by matching the filename against *pattern*.
    If the pattern contains a group, results are sorted by the first group
    (interpreted as a float) in descending order; otherwise by position in
    the directory listing, also descending.
    """
    matcher = re.compile(pattern)

    matched = []
    for position, fname in enumerate(os.listdir(path)):
        hit = matcher.fullmatch(fname)
        if hit is None:
            continue
        sort_key = float(hit.group(1)) if hit.groups() else position
        matched.append((sort_key, hit.group(0)))
    matched.sort(reverse=True)
    return [os.path.join(path, name) for _, name in matched]
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
def torch_persistent_save(*args, **kwargs):
    """Call ``torch.save(*args, **kwargs)``, retrying up to 3 times.

    Transient I/O failures (e.g. flaky network filesystems) are retried
    silently.  If the final attempt also fails, the traceback is logged and
    the exception is re-raised, so callers cannot mistake a missing
    checkpoint for a successful save (previously the failure was only
    logged and ``None`` was returned).
    """
    for attempt in range(3):
        try:
            return torch.save(*args, **kwargs)
        except Exception:
            if attempt == 2:
                # out of retries: surface the failure instead of silently
                # dropping the checkpoint
                logger.error(traceback.format_exc())
                raise
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
def save_state(
    filename,
    args,
    model_state_dict,
    criterion,
    optimizer,
    lr_scheduler,
    num_updates,
    optim_history=None,
    extra_state=None,
):
    """Assemble a full training-state dict and write it to *filename*.

    The saved dict holds the training args, model weights, the optimizer
    history (with a new entry appended for the current state), optional
    criterion/optimizer state, and any caller-supplied extra state.  All
    tensors are moved to CPU before serialization.
    """
    from fairseq import utils

    # snapshot describing the current optimization state
    current_optim_entry = {
        "criterion_name": criterion.__class__.__name__,
        "optimizer_name": optimizer.__class__.__name__,
        "lr_scheduler_state": lr_scheduler.state_dict(),
        "num_updates": num_updates,
    }
    state_dict = {
        "args": args,
        "model": model_state_dict or {},
        # append the current snapshot to any prior history
        "optimizer_history": (optim_history or []) + [current_optim_entry],
        "extra_state": extra_state if extra_state is not None else {},
    }
    # criteria without parameters have nothing worth saving
    if utils.has_parameters(criterion):
        state_dict["criterion"] = criterion.state_dict()
    if not args.no_save_optimizer_state:
        state_dict["last_optimizer_state"] = optimizer.state_dict()

    # move tensors off the GPU so checkpoints load anywhere
    state_dict = utils.move_to_cpu(state_dict)

    with PathManager.open(filename, "wb") as f:
        torch_persistent_save(state_dict, f)
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
def _upgrade_state_dict(state):
    """Helper for upgrading old model checkpoints.

    Mutates *state* in place (and returns it) so checkpoints written by
    older fairseq versions match the current layout: optimizer history,
    the ``extra_state`` sub-dict, iterator state, and missing default args
    are all back-filled.
    """
    from fairseq import models, registry, tasks

    # add optimizer_history
    if "optimizer_history" not in state:
        state["optimizer_history"] = [
            {"criterion_name": "CrossEntropyCriterion", "best_loss": state["best_loss"]}
        ]
        state["last_optimizer_state"] = state["optimizer"]
        del state["optimizer"]
        del state["best_loss"]
    # move extra_state into sub-dictionary
    if "epoch" in state and "extra_state" not in state:
        state["extra_state"] = {
            "epoch": state["epoch"],
            "batch_offset": state["batch_offset"],
            "val_loss": state["val_loss"],
        }
        del state["epoch"]
        del state["batch_offset"]
        del state["val_loss"]
    # reduce optimizer history's memory usage (only keep the last state)
    if "optimizer" in state["optimizer_history"][-1]:
        state["last_optimizer_state"] = state["optimizer_history"][-1]["optimizer"]
        for optim_hist in state["optimizer_history"]:
            del optim_hist["optimizer"]
    # record the optimizer class name
    if "optimizer_name" not in state["optimizer_history"][-1]:
        state["optimizer_history"][-1]["optimizer_name"] = "FairseqNAG"
    # move best_loss into lr_scheduler_state
    if "lr_scheduler_state" not in state["optimizer_history"][-1]:
        state["optimizer_history"][-1]["lr_scheduler_state"] = {
            "best": state["optimizer_history"][-1]["best_loss"]
        }
        del state["optimizer_history"][-1]["best_loss"]
    # keep track of number of updates
    if "num_updates" not in state["optimizer_history"][-1]:
        state["optimizer_history"][-1]["num_updates"] = 0
    # old model checkpoints may not have separate source/target positions
    if hasattr(state["args"], "max_positions") and not hasattr(
        state["args"], "max_source_positions"
    ):
        state["args"].max_source_positions = state["args"].max_positions
        state["args"].max_target_positions = state["args"].max_positions
    # use stateful training data iterator
    if "train_iterator" not in state["extra_state"]:
        state["extra_state"]["train_iterator"] = {
            "epoch": state["extra_state"]["epoch"],
            "iterations_in_epoch": state["extra_state"].get("batch_offset", 0),
        }
    # default to translation task
    if not hasattr(state["args"], "task"):
        state["args"].task = "translation"
    # --raw-text and --lazy-load are deprecated
    if getattr(state["args"], "raw_text", False):
        state["args"].dataset_impl = "raw"
    elif getattr(state["args"], "lazy_load", False):
        state["args"].dataset_impl = "lazy"
    # epochs start at 1
    if state["extra_state"]["train_iterator"] is not None:
        state["extra_state"]["train_iterator"]["epoch"] = max(
            state["extra_state"]["train_iterator"].get("epoch", 1),
            1,
        )

    # set any missing default values in the task, model or other registries
    registry.set_defaults(state["args"], tasks.TASK_REGISTRY[state["args"].task])
    registry.set_defaults(state["args"], models.ARCH_MODEL_REGISTRY[state["args"].arch])
    for registry_name, REGISTRY in registry.REGISTRIES.items():
        choice = getattr(state["args"], registry_name, None)
        if choice is not None:
            cls = REGISTRY["registry"][choice]
            registry.set_defaults(state["args"], cls)

    return state
|
| 394 |
+
|
| 395 |
+
|
| 396 |
+
def prune_state_dict(state_dict, args):
    """Prune the given state_dict if desired for LayerDrop
    (https://arxiv.org/abs/1909.11556).

    Training with LayerDrop allows models to be robust to pruning at inference
    time. This function prunes state_dict to allow smaller models to be loaded
    from a larger model and re-maps the existing state_dict for this to occur.

    It's called by functions that load models from checkpoints and does not
    need to be called directly.

    Returns a new state dict in which only the layers listed in
    ``args.encoder_layers_to_keep`` / ``args.decoder_layers_to_keep``
    survive, renumbered to be contiguous from 0.
    """
    if not args or args.arch == "ptt_transformer":
        # args should not be none, but don't crash if it is.
        return state_dict

    encoder_layers_to_keep = (
        args.encoder_layers_to_keep if "encoder_layers_to_keep" in vars(args) else None
    )
    decoder_layers_to_keep = (
        args.decoder_layers_to_keep if "decoder_layers_to_keep" in vars(args) else None
    )

    # nothing to prune -> return the input unchanged
    if not encoder_layers_to_keep and not decoder_layers_to_keep:
        return state_dict

    # apply pruning
    logger.info(
        "Pruning model to specified layer configuration - this works best if the model was trained with LayerDrop"
    )

    def create_pruning_pass(layers_to_keep, layer_name):
        # build an old-layer-index -> new-layer-index mapping for the
        # layers that survive (comma-separated list of indices)
        keep_layers = sorted(
            [int(layer_string) for layer_string in layers_to_keep.split(",")]
        )
        mapping_dict = {}
        for i in range(len(keep_layers)):
            mapping_dict[str(keep_layers[i])] = str(i)

        regex = re.compile(r"^{layer}.*\.layers\.(\d+)".format(layer=layer_name))
        return {"substitution_regex": regex, "mapping_dict": mapping_dict}

    pruning_passes = []
    if encoder_layers_to_keep:
        pruning_passes.append(create_pruning_pass(encoder_layers_to_keep, "encoder"))
    if decoder_layers_to_keep:
        pruning_passes.append(create_pruning_pass(decoder_layers_to_keep, "decoder"))

    new_state_dict = {}
    for layer_name in state_dict.keys():
        match = re.search(r"\.layers\.(\d+)\.", layer_name)
        # if layer has no number in it, it is a supporting layer, such as an
        # embedding
        if not match:
            new_state_dict[layer_name] = state_dict[layer_name]
            continue

        # otherwise, layer should be pruned.
        original_layer_number = match.group(1)
        # figure out which mapping dict to replace from
        for pruning_pass in pruning_passes:
            if original_layer_number in pruning_pass["mapping_dict"] and pruning_pass[
                "substitution_regex"
            ].search(layer_name):
                new_layer_number = pruning_pass["mapping_dict"][original_layer_number]
                substitution_match = pruning_pass["substitution_regex"].search(
                    layer_name
                )
                # splice the new layer index into the key in place of the
                # old one, e.g. encoder.layers.3.* -> encoder.layers.1.*
                new_state_key = (
                    layer_name[: substitution_match.start(1)]
                    + new_layer_number
                    + layer_name[substitution_match.end(1) :]
                )
                new_state_dict[new_state_key] = state_dict[layer_name]

    # Since layers are now pruned, *_layers_to_keep are no longer needed.
    # This is more of "It would make it work fix" rather than a proper fix.
    if "encoder_layers_to_keep" in vars(args):
        args.encoder_layers_to_keep = None
    if "decoder_layers_to_keep" in vars(args):
        args.decoder_layers_to_keep = None

    return new_state_dict
|
| 478 |
+
|
| 479 |
+
|
| 480 |
+
def load_pretrained_component_from_model(
    component: Union[FairseqEncoder, FairseqDecoder], checkpoint: str
):
    """
    Load a pretrained FairseqEncoder or FairseqDecoder from `checkpoint` into
    the provided `component` object.  A failure in ``load_state_dict``
    usually means the architecture of `component` does not match the one
    stored in the checkpoint.
    """
    if not PathManager.exists(checkpoint):
        raise IOError("Model file not found: {}".format(checkpoint))
    state = load_checkpoint_to_cpu(checkpoint)

    # keys in the checkpoint are namespaced as "encoder.*" / "decoder.*"
    if isinstance(component, FairseqEncoder):
        prefix = "encoder"
    elif isinstance(component, FairseqDecoder):
        prefix = "decoder"
    else:
        raise ValueError(
            "component to load must be either a FairseqEncoder or "
            "FairseqDecoder. Loading other component types are not supported."
        )

    # strip the "<prefix>." namespace, e.g.
    # encoder.input_layers.0.0.weight --> input_layers.0.0.weight
    stripped = OrderedDict(
        (key[len(prefix) + 1 :], value)
        for key, value in state["model"].items()
        if key.startswith(prefix)
    )
    component.load_state_dict(stripped, strict=True)
    return component
|
| 509 |
+
|
| 510 |
+
|
| 511 |
+
def verify_checkpoint_directory(save_dir: str) -> None:
    """Ensure *save_dir* exists and is writable.

    Creates the directory if missing, then probes writability by creating
    and deleting a throwaway ``dummy`` file.  Re-raises the underlying
    :class:`OSError` if the directory cannot be written to.
    """
    if not os.path.exists(save_dir):
        os.makedirs(save_dir, exist_ok=True)
    probe_path = os.path.join(save_dir, "dummy")
    try:
        open(probe_path, "w").close()
    except OSError as e:
        logger.warning("Unable to access checkpoint save directory: {}".format(save_dir))
        raise e
    os.remove(probe_path)
|
fairseq/clib/libbleu/libbleu.cpp
ADDED
|
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/**
|
| 2 |
+
* Copyright 2017-present, Facebook, Inc.
|
| 3 |
+
* All rights reserved.
|
| 4 |
+
*
|
| 5 |
+
* This source code is licensed under the license found in the
|
| 6 |
+
* LICENSE file in the root directory of this source tree.
|
| 7 |
+
*/
|
| 8 |
+
|
| 9 |
+
#include <map>
|
| 10 |
+
#include <array>
|
| 11 |
+
#include <cstring>
|
| 12 |
+
#include <cstdio>
|
| 13 |
+
|
| 14 |
+
// Corpus-level BLEU sufficient statistics: accumulated (trimmed)
// reference and prediction lengths plus matched / total n-gram counts
// for n = 1..4.
typedef struct
{
    size_t reflen;   // total reference length after pad/eos trimming
    size_t predlen;  // total prediction length after pad/eos trimming
    size_t match1;   // matched unigrams
    size_t count1;   // total predicted unigrams
    size_t match2;
    size_t count2;
    size_t match3;
    size_t count3;
    size_t match4;
    size_t count4;
} bleu_stat;
|
| 27 |
+
|
| 28 |
+
// Left trim: strip leading pad tokens by advancing the sentence pointer
// and shrinking the reported length, both in place.
void bleu_ltrim(size_t* len, int** sent, int pad) {
  size_t skipped = 0;
  while (skipped < *len && (*sent)[skipped] == pad) {
    skipped++;
  }
  *sent += skipped;
  *len -= skipped;
}
|
| 38 |
+
|
| 39 |
+
// Right trim: drop trailing eos/pad tokens by shrinking *len in place.
// The first token is never trimmed (the scan stops at index 0), matching
// the original behavior.  A zero-length guard is added: previously
// `*len - 1` underflowed for empty input (reachable via bleu_trim when
// the whole sentence is padding), causing an out-of-bounds read.
void bleu_rtrim(size_t* len, int** sent, int pad, int eos) {
  if (*len == 0) { return; }  // guard against size_t underflow below
  size_t end = *len - 1;
  while (end > 0) {
    if (*(*sent + end) != eos && *(*sent + end) != pad) { break; }
    end--;
  }
  *len = end + 1;
}
|
| 48 |
+
|
| 49 |
+
// Left and right trim: strip leading pad tokens, then trailing pad/eos
// tokens, adjusting *len and *sent in place.
void bleu_trim(size_t* len, int** sent, int pad, int eos) {
  bleu_ltrim(len, sent, pad);
  bleu_rtrim(len, sent, pad, eos);
}
|
| 54 |
+
|
| 55 |
+
// FNV-1a hash over the raw bytes of `len` ints; used to key n-grams in
// the counting maps.
size_t bleu_hash(int len, int* data) {
  const size_t kFnvOffsetBasis = 14695981039346656037ul;
  const size_t kFnvPrime = 0x100000001b3;
  const char* bytes = (const char*) data;
  const size_t num_bytes = sizeof(int) * len;

  size_t h = kFnvOffsetBasis;
  for (size_t i = 0; i < num_bytes; ++i) {
    h ^= bytes[i];
    h *= kFnvPrime;
  }
  return h;
}
|
| 68 |
+
|
| 69 |
+
// Accumulate n-gram precision statistics for one sentence pair:
// *ntotal gains the number of n-grams in the prediction, and *nmatch
// gains min(pred_count, ref_count) for each distinct n-gram, i.e.
// matches are clipped to how often the n-gram occurs in the reference.
// N-grams are compared via bleu_hash of their raw bytes.
void bleu_addngram(
  size_t *ntotal, size_t *nmatch, size_t n,
  size_t reflen, int* ref, size_t predlen, int* pred) {

  // fewer than n predicted tokens -> no n-grams at all
  if (predlen < n) { return; }

  // number of n-gram windows in the prediction
  predlen = predlen - n + 1;
  (*ntotal) += predlen;

  if (reflen < n) { return; }

  reflen = reflen - n + 1;

  // count each predicted n-gram by its hash
  std::map<size_t, size_t> count;
  while (predlen > 0) {
    size_t w = bleu_hash(n, pred++);
    count[w]++;
    predlen--;
  }

  // consume matching reference n-grams, decrementing so repeated
  // predictions only match as often as they appear in the reference
  while (reflen > 0) {
    size_t w = bleu_hash(n, ref++);
    if (count[w] > 0) {
      (*nmatch)++;
      count[w] -=1;
    }
    reflen--;
  }
}
|
| 98 |
+
|
| 99 |
+
// C-linkage entry points; these are loaded from Python via ctypes, so
// they must keep unmangled names (and be dllexport'ed on Windows).
extern "C" {

#ifdef _WIN64
__declspec(dllexport)
#endif
// Reset all BLEU statistics to zero.
void bleu_zero_init(bleu_stat* stat) {
  std::memset(stat, 0, sizeof(bleu_stat));
}

#ifdef _WIN64
__declspec(dllexport)
#endif
// Smoothed ("add-one") initialization: counts and matches for n >= 2
// start at 1 so higher-order precisions are never 0/0 on short corpora;
// unigram stats stay at zero.  (The count1/match1 assignments are
// redundant after bleu_zero_init but kept for explicitness.)
void bleu_one_init(bleu_stat* stat) {
  bleu_zero_init(stat);
  stat->count1 = 0;
  stat->count2 = 1;
  stat->count3 = 1;
  stat->count4 = 1;
  stat->match1 = 0;
  stat->match2 = 1;
  stat->match3 = 1;
  stat->match4 = 1;
}

#ifdef _WIN64
__declspec(dllexport)
#endif
// Accumulate statistics for one (reference, prediction) pair: both
// sequences are trimmed of pad/eos first, then their lengths and
// 1..4-gram counts are added to *stat.
void bleu_add(
  bleu_stat* stat,
  size_t reflen, int* ref, size_t predlen, int* pred, int pad, int eos) {

  bleu_trim(&reflen, &ref, pad, eos);
  bleu_trim(&predlen, &pred, pad, eos);
  stat->reflen += reflen;
  stat->predlen += predlen;

  bleu_addngram(&stat->count1, &stat->match1, 1, reflen, ref, predlen, pred);
  bleu_addngram(&stat->count2, &stat->match2, 2, reflen, ref, predlen, pred);
  bleu_addngram(&stat->count3, &stat->match3, 3, reflen, ref, predlen, pred);
  bleu_addngram(&stat->count4, &stat->match4, 4, reflen, ref, predlen, pred);
}

}
|
fairseq/clib/libbleu/module.cpp
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/**
|
| 2 |
+
* Copyright 2017-present, Facebook, Inc.
|
| 3 |
+
* All rights reserved.
|
| 4 |
+
*
|
| 5 |
+
* This source code is licensed under the license found in the
|
| 6 |
+
* LICENSE file in the root directory of this source tree.
|
| 7 |
+
*/
|
| 8 |
+
|
| 9 |
+
#include <Python.h>
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
// Method table is empty (sentinel only): this module exposes no
// Python-callable functions here — presumably the exported bleu_* C symbols
// are looked up directly (e.g. via ctypes); confirm against the Python
// wrapper that loads this extension.
static PyMethodDef method_def[] = {
  {NULL, NULL, 0, NULL}  // sentinel terminator
};

static struct PyModuleDef module_def = {
  PyModuleDef_HEAD_INIT,
  "libbleu",   /* name of module */
  NULL,        /* module documentation, may be NULL */
  -1,          /* size of per-interpreter state of the module,
                  or -1 if the module keeps state in global variables. */
  method_def
};
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
// Module entry point invoked by the CPython import machinery.
// NOTE(review): the PY_MAJOR_VERSION == 2 branch looks dead — under
// Python 2, PyMODINIT_FUNC is void and PyModule_Create/PyModuleDef do not
// exist, so this translation unit would not compile there.  Confirm before
// relying on Python 2 support.
#if PY_MAJOR_VERSION == 2
PyMODINIT_FUNC init_libbleu()
#else
PyMODINIT_FUNC PyInit_libbleu()
#endif
{
  PyObject *m = PyModule_Create(&module_def);
  if (!m) {
    return NULL;  // creation failed; an exception is already set
  }
  return m;
}
|
fairseq/clib/libnat/edit_dist.cpp
ADDED
|
@@ -0,0 +1,231 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/**
|
| 2 |
+
* Copyright 2017-present, Facebook, Inc.
|
| 3 |
+
* All rights reserved.
|
| 4 |
+
*
|
| 5 |
+
* This source code is licensed under the license found in the
|
| 6 |
+
* LICENSE file in the root directory of this source tree.
|
| 7 |
+
*/
|
| 8 |
+
|
| 9 |
+
#include <torch/torch.h> // @manual=//caffe2:torch_extension
|
| 10 |
+
#include <pybind11/detail/common.h>
|
| 11 |
+
#include <pybind11/pybind11.h>
|
| 12 |
+
#include <vector>
|
| 13 |
+
#include <algorithm>
|
| 14 |
+
#include <cstdint>
|
| 15 |
+
#include <iosfwd>
|
| 16 |
+
#include <memory>
|
| 17 |
+
#include <new>
|
| 18 |
+
#include <string>
|
| 19 |
+
#include <utility>
|
| 20 |
+
|
| 21 |
+
using namespace ::std;
|
| 22 |
+
|
| 23 |
+
// Full O(|x|*|y|) edit-distance DP table between token sequences x and y.
// Insertions and deletions cost 1; substitutions cost 2 (so the distance
// equals |x| + |y| - 2 * LCS(x, y)).  Returns the entire
// (|x|+1) x (|y|+1) table so callers can backtrack an edit script.
vector<vector<uint32_t>> edit_distance2_with_dp(
    vector<uint32_t>& x,
    vector<uint32_t>& y) {
  const uint32_t rows = x.size() + 1;
  const uint32_t cols = y.size() + 1;
  vector<vector<uint32_t>> table(rows, vector<uint32_t>(cols));

  // Base cases: distance to/from the empty prefix.
  for (uint32_t r = 0; r < rows; r++) {
    table[r][0] = r;
  }
  for (uint32_t c = 0; c < cols; c++) {
    table[0][c] = c;
  }

  // Fill row by row: each cell is the cheapest of delete, insert, or
  // (possibly free) substitution.
  for (uint32_t r = 1; r < rows; r++) {
    for (uint32_t c = 1; c < cols; c++) {
      const uint32_t sub_cost = (x.at(r - 1) == y.at(c - 1)) ? 0 : 2;
      table[r][c] = min(min(table[r - 1][c], table[r][c - 1]) + 1,
                        table[r - 1][c - 1] + sub_cost);
    }
  }
  return table;
}
|
| 44 |
+
|
| 45 |
+
// Backtrack an edit script from the DP table produced by
// edit_distance2_with_dp, then bucket it into per-source-position
// sequences.  Returns x.size() + 2 sequences: slots 0..x.size() hold the
// tokens inserted before each source position, and the final slot holds a
// 0/1 flag per consumed source token (1 = delete, 0 = keep).  Every empty
// slot is padded with terminal_symbol.
vector<vector<uint32_t>> edit_distance2_backtracking(
    vector<vector<uint32_t>>& d,
    vector<uint32_t>& x,
    vector<uint32_t>& y,
    uint32_t terminal_symbol) {
  vector<uint32_t> seq;
  vector<vector<uint32_t>> edit_seqs(x.size() + 2, vector<uint32_t>());
  /*
  edit_seqs:
  0~x.size() cell is the insertion sequences
  last cell is the delete sequence
  */

  if (x.size() == 0) {
    // Empty source: all of y is one insertion block at slot 0.
    edit_seqs.at(0) = y;
    return edit_seqs;
  }

  uint32_t i = d.size() - 1;
  uint32_t j = d.at(0).size() - 1;

  // Walk from the bottom-right corner of the table back to (0, 0),
  // recording (op, token) pairs in reverse.  Ops: 1 = insert, 2 = delete,
  // 3 = keep.  NOTE(review): i and j are unsigned, so (i >= 0) && (j >= 0)
  // is always true — the loop terminates only via the break below.
  while ((i >= 0) && (j >= 0)) {
    if ((i == 0) && (j == 0)) {
      break;
    }

    if ((j > 0) && (d.at(i).at(j - 1) < d.at(i).at(j))) {
      seq.push_back(1); // insert
      seq.push_back(y.at(j - 1));
      j--;
    } else if ((i > 0) && (d.at(i - 1).at(j) < d.at(i).at(j))) {
      seq.push_back(2); // delete
      seq.push_back(x.at(i - 1));
      i--;
    } else {
      seq.push_back(3); // keep
      seq.push_back(x.at(i - 1));
      i--;
      j--;
    }
  }

  // Replay the script forward (seq is stored reversed).  s tracks the
  // current source slot: it advances on every op except a repeated insert,
  // so consecutive inserts land in the same slot.
  uint32_t prev_op, op, s, word;
  prev_op = 0, s = 0;
  for (uint32_t k = 0; k < seq.size() / 2; k++) {
    op = seq.at(seq.size() - 2 * k - 2);
    word = seq.at(seq.size() - 2 * k - 1);
    if (prev_op != 1) {
      s++;
    }
    if (op == 1) // insert
    {
      edit_seqs.at(s - 1).push_back(word);
    } else if (op == 2) // delete
    {
      edit_seqs.at(x.size() + 1).push_back(1);
    } else {
      edit_seqs.at(x.size() + 1).push_back(0);
    }

    prev_op = op;
  }

  // Pad every empty slot so each returned sequence is non-empty.
  for (uint32_t k = 0; k < edit_seqs.size(); k++) {
    if (edit_seqs[k].size() == 0) {
      edit_seqs[k].push_back(terminal_symbol);
    }
  }
  return edit_seqs;
}
|
| 115 |
+
|
| 116 |
+
// Variant of edit_distance2_backtracking that encodes deletions inline:
// a deleted source token is represented by pushing deletion_symbol into
// its slot instead of keeping a separate 0/1 delete sequence, so only
// x.size() + 1 slots are returned.  Empty slots are padded with
// terminal_symbol.
vector<vector<uint32_t>> edit_distance2_backtracking_with_delete(
    vector<vector<uint32_t>>& d,
    vector<uint32_t>& x,
    vector<uint32_t>& y,
    uint32_t terminal_symbol,
    uint32_t deletion_symbol) {
  vector<uint32_t> seq;
  vector<vector<uint32_t>> edit_seqs(x.size() + 1, vector<uint32_t>());
  /*
  edit_seqs:
  0~x.size() cell is the insertion sequences
  last cell is the delete sequence
  */

  if (x.size() == 0) {
    // Empty source: all of y is one insertion block at slot 0.
    edit_seqs.at(0) = y;
    return edit_seqs;
  }

  uint32_t i = d.size() - 1;
  uint32_t j = d.at(0).size() - 1;

  // Reverse walk from the bottom-right corner to (0, 0); ops are
  // 1 = insert, 2 = delete, 3 = keep.  NOTE(review): i and j are unsigned,
  // so the loop condition is always true; termination is via the break.
  while ((i >= 0) && (j >= 0)) {
    if ((i == 0) && (j == 0)) {
      break;
    }

    if ((j > 0) && (d.at(i).at(j - 1) < d.at(i).at(j))) {
      seq.push_back(1); // insert
      seq.push_back(y.at(j - 1));
      j--;
    } else if ((i > 0) && (d.at(i - 1).at(j) < d.at(i).at(j))) {
      seq.push_back(2); // delete
      seq.push_back(x.at(i - 1));
      i--;
    } else {
      seq.push_back(3); // keep
      seq.push_back(x.at(i - 1));
      i--;
      j--;
    }
  }

  // Replay forward; s advances on every op except a repeated insert.
  uint32_t prev_op, op, s, word;
  prev_op = 0, s = 0;
  for (uint32_t k = 0; k < seq.size() / 2; k++) {
    op = seq.at(seq.size() - 2 * k - 2);
    word = seq.at(seq.size() - 2 * k - 1);
    if (prev_op != 1) {
      s++;
    }
    if (op == 1) // insert
    {
      edit_seqs.at(s - 1).push_back(word);
    } else if (op == 2) // delete
    {
      edit_seqs.at(s - 1).push_back(deletion_symbol);
    }

    prev_op = op;
  }

  // Pad every empty slot so each returned sequence is non-empty.
  for (uint32_t k = 0; k < edit_seqs.size(); k++) {
    if (edit_seqs.at(k).size() == 0) {
      edit_seqs.at(k).push_back(terminal_symbol);
    }
  }
  return edit_seqs;
}
|
| 185 |
+
|
| 186 |
+
vector<uint32_t> compute_ed2(
|
| 187 |
+
vector<vector<uint32_t>>& xs,
|
| 188 |
+
vector<vector<uint32_t>>& ys) {
|
| 189 |
+
vector<uint32_t> distances(xs.size());
|
| 190 |
+
for (uint32_t i = 0; i < xs.size(); i++) {
|
| 191 |
+
vector<vector<uint32_t>> d = edit_distance2_with_dp(xs.at(i), ys.at(i));
|
| 192 |
+
distances.at(i) = d.at(xs.at(i).size()).at(ys.at(i).size());
|
| 193 |
+
}
|
| 194 |
+
return distances;
|
| 195 |
+
}
|
| 196 |
+
|
| 197 |
+
vector<vector<vector<uint32_t>>> suggested_ed2_path(
|
| 198 |
+
vector<vector<uint32_t>>& xs,
|
| 199 |
+
vector<vector<uint32_t>>& ys,
|
| 200 |
+
uint32_t terminal_symbol) {
|
| 201 |
+
vector<vector<vector<uint32_t>>> seq(xs.size());
|
| 202 |
+
for (uint32_t i = 0; i < xs.size(); i++) {
|
| 203 |
+
vector<vector<uint32_t>> d = edit_distance2_with_dp(xs.at(i), ys.at(i));
|
| 204 |
+
seq.at(i) =
|
| 205 |
+
edit_distance2_backtracking(d, xs.at(i), ys.at(i), terminal_symbol);
|
| 206 |
+
}
|
| 207 |
+
return seq;
|
| 208 |
+
}
|
| 209 |
+
|
| 210 |
+
vector<vector<vector<uint32_t>>> suggested_ed2_path_with_delete(
|
| 211 |
+
vector<vector<uint32_t>>& xs,
|
| 212 |
+
vector<vector<uint32_t>>& ys,
|
| 213 |
+
uint32_t terminal_symbol,
|
| 214 |
+
uint32_t deletion_symbol) {
|
| 215 |
+
vector<vector<vector<uint32_t>>> seq(xs.size());
|
| 216 |
+
for (uint32_t i = 0; i < xs.size(); i++) {
|
| 217 |
+
vector<vector<uint32_t>> d = edit_distance2_with_dp(xs.at(i), ys.at(i));
|
| 218 |
+
seq.at(i) = edit_distance2_backtracking_with_delete(
|
| 219 |
+
d, xs.at(i), ys.at(i), terminal_symbol, deletion_symbol);
|
| 220 |
+
}
|
| 221 |
+
return seq;
|
| 222 |
+
}
|
| 223 |
+
|
| 224 |
+
// Python bindings: expose the batched edit-distance helpers as the
// `libnat` extension module.
PYBIND11_MODULE(libnat, m) {
  m.def("compute_ed2", &compute_ed2, "compute_ed2");
  m.def("suggested_ed2_path", &suggested_ed2_path, "suggested_ed2_path");
  m.def(
      "suggested_ed2_path_with_delete",
      &suggested_ed2_path_with_delete,
      "suggested_ed2_path_with_delete");
}
|
fairseq/clib/libnat_cuda/binding.cpp
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/**
|
| 2 |
+
* Copyright 2017-present, Facebook, Inc.
|
| 3 |
+
* All rights reserved.
|
| 4 |
+
*
|
| 5 |
+
* This source code is licensed under the license found in the
|
| 6 |
+
* LICENSE file in the root directory of this source tree.
|
| 7 |
+
*/
|
| 8 |
+
|
| 9 |
+
/*
|
| 10 |
+
This code is partially adapted from https://github.com/1ytic/pytorch-edit-distance
|
| 11 |
+
*/
|
| 12 |
+
|
| 13 |
+
#include "edit_dist.h"
|
| 14 |
+
#include <torch/types.h>
|
| 15 |
+
|
| 16 |
+
#ifndef TORCH_CHECK
|
| 17 |
+
#define TORCH_CHECK AT_CHECK
|
| 18 |
+
#endif
|
| 19 |
+
|
| 20 |
+
#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
|
| 21 |
+
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
|
| 22 |
+
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
// CPU-side entry point: validate that every tensor is a contiguous CUDA
// tensor, then dispatch to the CUDA implementation, which returns the
// per-pair edit-operation tensor.
torch::Tensor LevenshteinDistance(
    torch::Tensor source,
    torch::Tensor target,
    torch::Tensor source_length,
    torch::Tensor target_length) {

  CHECK_INPUT(source);
  CHECK_INPUT(target);
  CHECK_INPUT(source_length);
  CHECK_INPUT(target_length);
  return LevenshteinDistanceCuda(source, target, source_length, target_length);
}
|
| 37 |
+
|
| 38 |
+
// Validate inputs (contiguous CUDA tensors) and dispatch to the CUDA
// kernel that turns an edit-operation tensor into per-source-token
// deletion labels.
torch::Tensor GenerateDeletionLabel(
    torch::Tensor source,
    torch::Tensor operations) {

  CHECK_INPUT(source);
  CHECK_INPUT(operations);
  return GenerateDeletionLabelCuda(source, operations);
}
|
| 46 |
+
|
| 47 |
+
// Validate inputs (contiguous CUDA tensors) and dispatch to the CUDA
// kernel producing per-target-token insertion labels and masks as a
// (labels, masks) pair.
std::pair<torch::Tensor, torch::Tensor> GenerateInsertionLabel(
    torch::Tensor target,
    torch::Tensor operations) {

  CHECK_INPUT(target);
  CHECK_INPUT(operations);
  return GenerateInsertionLabelCuda(target, operations);
}
|
| 55 |
+
|
| 56 |
+
// Torch extension bindings: module name is supplied by the build system
// via TORCH_EXTENSION_NAME.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("levenshtein_distance", &LevenshteinDistance, "Levenshtein distance");
  m.def("generate_deletion_labels", &GenerateDeletionLabel, "Generate Deletion Label");
  m.def("generate_insertion_labels", &GenerateInsertionLabel, "Generate Insertion Label");
}
|
fairseq/clib/libnat_cuda/edit_dist.cu
ADDED
|
@@ -0,0 +1,332 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/**
|
| 2 |
+
* Copyright 2017-present, Facebook, Inc.
|
| 3 |
+
* All rights reserved.
|
| 4 |
+
*
|
| 5 |
+
* This source code is licensed under the license found in the
|
| 6 |
+
* LICENSE file in the root directory of this source tree.
|
| 7 |
+
*/
|
| 8 |
+
|
| 9 |
+
#include "edit_dist.h"
|
| 10 |
+
#include <THC/THC.h>
|
| 11 |
+
#include <cuda.h>
|
| 12 |
+
#include <cuda_runtime.h>
|
| 13 |
+
#include <device_launch_parameters.h>
|
| 14 |
+
#include <utility> // std::pair
|
| 15 |
+
|
| 16 |
+
// One single-threaded block per batch element (blockIdx.x selects the row).
// Converts an edit-operation sequence (0 = pad/end, 1 = insert, 2 = delete,
// 3 = keep) into per-source-token labels: 3 - op gives 1 for a deleted
// token and 0 for a kept one.  Insertions consume no source token and are
// skipped.
template <typename scalar_t>
__global__ void generate_deletion_label_kernel(
    const scalar_t* __restrict__ source,
    const size_t source_size,
    const size_t operation_size,
    int* __restrict__ operations,
    int* __restrict__ labels) {

  const int index = blockIdx.x;
  const int offset = index * operation_size;
  const int offset_label = index * source_size;

  // Zero-initialize this row of labels.
  for (int i = 0; i < source_size; i++) {
    labels[offset_label + i] = 0;
  }

  int k = 0;  // index of the next source token to label
  for (int i = 0; i < operation_size; i++){
    if (operations[offset + i] == 0){
      break;  // 0 marks the end of the operation list (padding)
    } else if (operations[offset + i] == 1){
      continue;  // insertion: no corresponding source token
    } else {
      labels[offset_label + k] = 3 - operations[offset + i];  // 2->1, 3->0
      k++;
    }
  }
}
|
| 44 |
+
|
| 45 |
+
// One single-threaded block per batch element.  From an edit-operation
// sequence (0 = pad/end, 1 = insert, 2 = delete, 3 = keep), produce for
// each target token: labels[k] = number of consecutive insertions
// immediately preceding kept token k, and masks[m] = 1 for inserted
// target positions, 0 for kept ones.  Deletions consume no target token.
// NOTE(review): the loop runs to operation_size - 1, presumably skipping a
// trailing op — confirm against the caller's operation layout.
template <typename scalar_t>
__global__ void generate_insertion_label_kernel(
    const scalar_t* __restrict__ target,
    const size_t target_size,
    const size_t operation_size,
    int* __restrict__ operations,
    int* __restrict__ labels,
    int* __restrict__ masks) {

  const int index = blockIdx.x;
  const int offset = index * operation_size;
  const int offset_label = index * target_size;

  int k = 0;  // index of the next kept-token label slot
  int u = 0;  // length of the current run of insertions
  int m = 0;  // current target position for masks

  // Zero-initialize this row of labels and masks.
  for (int i = 0; i < target_size; i++) {
    labels[offset_label + i] = 0;
    masks[offset_label + i] = 0;
  }

  for (int i = 0; i < operation_size-1; i++){
    if (operations[offset + i] == 0){
      break;  // end of the operation list
    } else if (operations[offset + i] == 2){
      continue;  // deletion: no corresponding target token
    } else if (operations[offset + i] == 1){
      masks[offset_label + m] = 1;
      u++; m++;
    } else {
      labels[offset_label + k] = u;  // insertions accumulated before this kept token
      masks[offset_label + m] = 0;
      k++; m++;
      u = 0;
    }
  }
}
|
| 83 |
+
|
| 84 |
+
// One single-threaded block per (source, target) pair.  Computes the full
// edit-distance DP table (insert/delete cost 1, substitution cost 2) in
// the caller-provided global scratch buffer errors_curr, then backtracks
// an operation sequence (1 = insert, 2 = delete, 3 = keep).  Operations
// are first written right-aligned, then shifted to the front with zero
// padding on the right.
template <typename scalar_t>
__global__ void levenshtein_distance_kernel(
    const scalar_t* __restrict__ source,
    const scalar_t* __restrict__ target,
    const int* __restrict__ source_length,
    const int* __restrict__ target_length,
    const size_t source_size,
    const size_t target_size,
    int* __restrict__ operations,
    int* __restrict__ errors_curr) {

  const int index = blockIdx.x;
  const int offset = index * (source_size + target_size);       // this row's slice of `operations`
  const int d = index * (source_size + 1) * (target_size + 1);  // this row's slice of the DP scratch
  const int t = target_size + 1;                                // DP table row stride

  // Flatten (i, j) into the scratch buffer / operation buffer.
  auto err_idx = [d, t](int i, int j) { return d + i * t + j; };
  auto opt_idx = [offset](int k) { return offset + k; };

  const int hyp_len = source_length[index];
  const int ref_len = target_length[index];
  const scalar_t* hyp_begin = source + index * source_size;
  const scalar_t* ref_begin = target + index * target_size;

  // dynamic programming
  for (int i = 0; i <= hyp_len; i++){
    errors_curr[err_idx(i, 0)] = i;
  }
  for (int j = 0; j <= ref_len; j++){
    errors_curr[err_idx(0, j)] = j;
  }
  for (int i = 1; i <= hyp_len; i++){
    for (int j = 1; j <= ref_len; j++){
      errors_curr[err_idx(i, j)] = min(
          min(
              errors_curr[err_idx(i-1, j)],
              errors_curr[err_idx(i, j-1)]
          ) + 1,
          errors_curr[err_idx(i-1, j-1)] + 2 * (
              *(hyp_begin+i-1) == *(ref_begin+j-1) ? 0 : 1
          )
      );
    }
  }

  // back-tracing
  int i = hyp_len;
  int j = ref_len;
  int o = hyp_len + ref_len;

  // Clear the whole (padded) operation row first.
  for (int k = 0; k < source_size + target_size; k++) {
    operations[opt_idx(k)] = 0;
  }

  // Walk from (hyp_len, ref_len) back to (0, 0), filling operations from
  // the right end.  The loop exits only via the break (i, j are ints here).
  while ((i >= 0) && (j >= 0)) {
    if ((i == 0) && (j == 0)) {
      break;
    }

    if ((j > 0) && (errors_curr[err_idx(i, j-1)] < errors_curr[err_idx(i, j)])) {
      o--; operations[opt_idx(o)] = 1; j--; // insertion
    } else if ((i > 0) && (errors_curr[err_idx(i-1, j)] < errors_curr[err_idx(i, j)])) {
      o--; operations[opt_idx(o)] = 2; i--; // deletion
    } else {
      o--; operations[opt_idx(o)] = 3; i--; j--; // do nothing
    }
  }

  // moving to the left: left-align the op sequence and zero-pad the tail
  for (int k = 0; k < hyp_len + ref_len; k++) {
    if (k + o < hyp_len + ref_len){
      operations[opt_idx(k)] = operations[opt_idx(k+o)];
    } else{
      operations[opt_idx(k)] = 0; // padding
    }
  }

}
|
| 162 |
+
|
| 163 |
+
// Same algorithm as levenshtein_distance_kernel, but the DP table lives in
// dynamically-allocated shared memory (as 16-bit values) instead of a
// global scratch tensor, so no per-batch offset is needed in err_idx.  The
// launch must pass (source_size+1)*(target_size+1)*sizeof(short) bytes of
// dynamic shared memory.
template <typename scalar_t>
__global__ void faster_levenshtein_distance_kernel(
    const scalar_t* __restrict__ source,
    const scalar_t* __restrict__ target,
    const int* __restrict__ source_length,
    const int* __restrict__ target_length,
    const size_t source_size,
    const size_t target_size,
    int* __restrict__ operations) {

  extern __shared__ short errors[];  // dynamic shared memory DP table
  auto errors_curr = errors;

  const int index = blockIdx.x;
  const int offset = index * (source_size + target_size);  // this row's slice of `operations`
  const int t = target_size + 1;                           // DP table row stride

  auto err_idx = [t](int i, int j) { return i * t + j; };
  auto opt_idx = [offset](int k) { return offset + k; };

  const int hyp_len = source_length[index];
  const int ref_len = target_length[index];
  const scalar_t* hyp_begin = source + index * source_size;
  const scalar_t* ref_begin = target + index * target_size;

  // dynamic programming
  for (int i = 0; i <= hyp_len; i++){
    errors_curr[err_idx(i, 0)] = i;
  }
  for (int j = 0; j <= ref_len; j++){
    errors_curr[err_idx(0, j)] = j;
  }
  for (int i = 1; i <= hyp_len; i++){
    for (int j = 1; j <= ref_len; j++){
      errors_curr[err_idx(i, j)] = min(
          min(
              errors_curr[err_idx(i-1, j)],
              errors_curr[err_idx(i, j-1)]
          ) + 1,
          errors_curr[err_idx(i-1, j-1)] + 2 * (
              *(hyp_begin+i-1) == *(ref_begin+j-1) ? 0 : 1
          )
      );
    }
  }

  // back-tracing
  int i = hyp_len;
  int j = ref_len;
  int o = hyp_len + ref_len;

  // Clear the whole (padded) operation row first.
  for (int k = 0; k < source_size + target_size; k++) {
    operations[opt_idx(k)] = 0;
  }

  // Walk back to (0, 0), filling operations from the right end
  // (1 = insert, 2 = delete, 3 = keep).
  while ((i >= 0) && (j >= 0)) {
    if ((i == 0) && (j == 0)) {
      break;
    }

    if ((j > 0) && (errors_curr[err_idx(i, j-1)] < errors_curr[err_idx(i, j)])) {
      o--; operations[opt_idx(o)] = 1; j--; // insertion
    } else if ((i > 0) && (errors_curr[err_idx(i-1, j)] < errors_curr[err_idx(i, j)])) {
      o--; operations[opt_idx(o)] = 2; i--; // deletion
    } else {
      o--; operations[opt_idx(o)] = 3; i--; j--; // do nothing
    }
  }

  // moving to the left: left-align the op sequence and zero-pad the tail
  for (int k = 0; k < hyp_len + ref_len; k++) {
    if (k + o < hyp_len + ref_len){
      operations[opt_idx(k)] = operations[opt_idx(k+o)];
    } else{
      operations[opt_idx(k)] = 0; // padding
    }
  }

}
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
// Host-side launcher: allocate an int label tensor of the same
// (batch, source_len) shape and launch one single-thread block per batch
// row on the current CUDA stream for the source's device.
// NOTE(review): Tensor::data<T>() is deprecated in newer PyTorch in favor
// of data_ptr<T>() — confirm against the pinned torch version.
torch::Tensor GenerateDeletionLabelCuda(
    torch::Tensor source,
    torch::Tensor operations) {

  const auto batch_size = source.size(0);
  at::TensorOptions options(source.device());
  options = options.dtype(at::ScalarType::Int);
  auto labels = torch::empty({batch_size, source.size(1)}, options);
  auto stream = at::cuda::getCurrentCUDAStream(source.device().index());

  AT_DISPATCH_ALL_TYPES(source.scalar_type(), "generate_deletion_labels", ([&] {
    generate_deletion_label_kernel<scalar_t><<<batch_size, 1, 0, stream>>>(
        source.data<scalar_t>(),
        source.size(1),
        operations.size(1),
        operations.data<int>(),
        labels.data<int>());
  }));

  return labels;
}
|
| 265 |
+
|
| 266 |
+
// Host-side launcher for the insertion-label kernel: allocates int label
// and mask tensors shaped (batch, target_len) and launches one
// single-thread block per batch row; returns (labels, masks).
// NOTE(review): Tensor::data<T>() is deprecated in newer PyTorch in favor
// of data_ptr<T>() — confirm against the pinned torch version.
std::pair<torch::Tensor, torch::Tensor> GenerateInsertionLabelCuda(
    torch::Tensor target,
    torch::Tensor operations) {

  const auto batch_size = target.size(0);
  at::TensorOptions options(target.device());
  options = options.dtype(at::ScalarType::Int);
  auto labels = torch::empty({batch_size, target.size(1)}, options);
  auto masks = torch::empty({batch_size, target.size(1)}, options);
  auto stream = at::cuda::getCurrentCUDAStream(target.device().index());

  AT_DISPATCH_ALL_TYPES(target.scalar_type(), "generate_insertion_labels", ([&] {
    generate_insertion_label_kernel<scalar_t><<<batch_size, 1, 0, stream>>>(
        target.data<scalar_t>(),
        target.size(1),
        operations.size(1),
        operations.data<int>(),
        labels.data<int>(),
        masks.data<int>());
  }));

  return std::make_pair(labels, masks);
}
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
// Host-side launcher for the edit-distance kernels.  Allocates the
// (batch, src_len + tgt_len) int operations tensor and picks a kernel by
// DP-table size: if the table fits under ~40 KB it uses the
// shared-memory variant with that much dynamic shared memory per block;
// otherwise it falls back to the global-memory variant with an explicit
// scratch tensor.  Returns the operations tensor.
// NOTE(review): 40000 is a conservative bound on per-block dynamic shared
// memory (48 KB on most architectures) — confirm for the target GPUs.
torch::Tensor LevenshteinDistanceCuda(
    torch::Tensor source,
    torch::Tensor target,
    torch::Tensor source_length,
    torch::Tensor target_length) {

  const auto batch_size = source.size(0);
  const auto shared_size = (source.size(1) + 1) * (target.size(1) + 1) * sizeof(short);

  at::TensorOptions options(source.device());
  options = options.dtype(at::ScalarType::Int);
  auto operations = torch::empty({batch_size, source.size(1) + target.size(1)}, options);
  auto stream = at::cuda::getCurrentCUDAStream(source.device().index());

  if (shared_size > 40000) {
    // DP table too large for shared memory: give each row a global scratch.
    auto distances = torch::empty({batch_size, (source.size(1) + 1) * (target.size(1) + 1)}, options);
    AT_DISPATCH_ALL_TYPES(source.scalar_type(), "levenshtein_distance", ([&] {
      levenshtein_distance_kernel<scalar_t><<<batch_size, 1, 0, stream>>>(
          source.data<scalar_t>(),
          target.data<scalar_t>(),
          source_length.data<int>(),
          target_length.data<int>(),
          source.size(1),
          target.size(1),
          operations.data<int>(),
          distances.data<int>());
    }));
  } else {
    AT_DISPATCH_ALL_TYPES(source.scalar_type(), "faster_levenshtein_distance", ([&] {
      faster_levenshtein_distance_kernel<scalar_t><<<batch_size, 1, shared_size, stream>>>(
          source.data<scalar_t>(),
          target.data<scalar_t>(),
          source_length.data<int>(),
          target_length.data<int>(),
          source.size(1),
          target.size(1),
          operations.data<int>());
    }));
  }

  return operations;
}
|
fairseq/clib/libnat_cuda/edit_dist.h
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/**
|
| 2 |
+
* Copyright 2017-present, Facebook, Inc.
|
| 3 |
+
* All rights reserved.
|
| 4 |
+
*
|
| 5 |
+
* This source code is licensed under the license found in the
|
| 6 |
+
* LICENSE file in the root directory of this source tree.
|
| 7 |
+
*/
|
| 8 |
+
|
| 9 |
+
#pragma once
|
| 10 |
+
|
| 11 |
+
#include <torch/extension.h>
|
| 12 |
+
|
| 13 |
+
// Batched Levenshtein edit operations between source and target token
// tensors; returns one left-aligned, zero-padded edit-operation sequence
// (1 = insert, 2 = delete, 3 = keep) per batch element.
torch::Tensor LevenshteinDistanceCuda(
    torch::Tensor source,
    torch::Tensor target,
    torch::Tensor source_length,
    torch::Tensor target_length);

// Per-source-token deletion labels derived from an edit-operation tensor.
torch::Tensor GenerateDeletionLabelCuda(
    torch::Tensor source,
    torch::Tensor operations);

// Per-target-token insertion labels and masks derived from an
// edit-operation tensor, returned as (labels, masks).  Parameter renamed
// from `source` to `target` to match the definition in edit_dist.cu and
// the call site in binding.cpp (declaration-only change, no ABI impact).
std::pair<torch::Tensor, torch::Tensor> GenerateInsertionLabelCuda(
    torch::Tensor target,
    torch::Tensor operations);
|
fairseq/criterions/__init__.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import importlib
|
| 7 |
+
import os
|
| 8 |
+
|
| 9 |
+
from fairseq import registry
|
| 10 |
+
from fairseq.criterions.fairseq_criterion import FairseqCriterion, LegacyFairseqCriterion
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
# Set up the criterion registry: exposes ``build_criterion`` /
# ``register_criterion`` and wires the ``--criterion`` CLI flag to the
# registered classes, defaulting to cross entropy.
build_criterion, register_criterion, CRITERION_REGISTRY = registry.setup_registry(
    '--criterion',
    base_class=FairseqCriterion,
    default='cross_entropy',
)


# automatically import any Python files in the criterions/ directory so
# that their @register_criterion decorators run at package import time
for file in os.listdir(os.path.dirname(__file__)):
    if file.endswith('.py') and not file.startswith('_'):
        # Strip the trailing '.py' suffix.  A negative slice (rather than
        # file[:file.find('.py')]) avoids truncating module names that
        # happen to contain '.py' before the extension.
        module = file[:-len('.py')]
        importlib.import_module('fairseq.criterions.' + module)
|
fairseq/criterions/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (697 Bytes). View file
|
|
|
fairseq/criterions/__pycache__/adaptive_loss.cpython-310.pyc
ADDED
|
Binary file (4.1 kB). View file
|
|
|
fairseq/criterions/__pycache__/composite_loss.cpython-310.pyc
ADDED
|
Binary file (4.36 kB). View file
|
|
|
fairseq/criterions/__pycache__/cross_entropy.cpython-310.pyc
ADDED
|
Binary file (3.4 kB). View file
|
|
|
fairseq/criterions/__pycache__/ctc.cpython-310.pyc
ADDED
|
Binary file (6.99 kB). View file
|
|
|
fairseq/criterions/__pycache__/fairseq_criterion.cpython-310.pyc
ADDED
|
Binary file (4.19 kB). View file
|
|
|
fairseq/criterions/__pycache__/label_smoothed_cross_entropy.cpython-310.pyc
ADDED
|
Binary file (4.39 kB). View file
|
|
|
fairseq/criterions/__pycache__/label_smoothed_cross_entropy_with_alignment.cpython-310.pyc
ADDED
|
Binary file (4.55 kB). View file
|
|
|
fairseq/criterions/__pycache__/legacy_masked_lm.cpython-310.pyc
ADDED
|
Binary file (5.36 kB). View file
|
|
|
fairseq/criterions/__pycache__/masked_lm.cpython-310.pyc
ADDED
|
Binary file (3.11 kB). View file
|
|
|
fairseq/criterions/__pycache__/nat_loss.cpython-310.pyc
ADDED
|
Binary file (5.92 kB). View file
|
|
|
fairseq/criterions/__pycache__/sentence_prediction.cpython-310.pyc
ADDED
|
Binary file (3.85 kB). View file
|
|
|
fairseq/criterions/__pycache__/sentence_ranking.cpython-310.pyc
ADDED
|
Binary file (4.5 kB). View file
|
|
|