julse committed on
Commit
be611b4
·
verified ·
1 Parent(s): a844cff

Upload 551 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +2 -0
  2. fairseq/__init__.py +26 -0
  3. fairseq/__pycache__/__init__.cpython-310.pyc +0 -0
  4. fairseq/__pycache__/binarizer.cpython-310.pyc +0 -0
  5. fairseq/__pycache__/checkpoint_utils.cpython-310.pyc +0 -0
  6. fairseq/__pycache__/distributed_utils.cpython-310.pyc +0 -0
  7. fairseq/__pycache__/file_io.cpython-310.pyc +0 -0
  8. fairseq/__pycache__/file_utils.cpython-310.pyc +0 -0
  9. fairseq/__pycache__/incremental_decoding_utils.cpython-310.pyc +0 -0
  10. fairseq/__pycache__/iterative_refinement_generator.cpython-310.pyc +0 -0
  11. fairseq/__pycache__/legacy_distributed_data_parallel.cpython-310.pyc +0 -0
  12. fairseq/__pycache__/options.cpython-310.pyc +0 -0
  13. fairseq/__pycache__/pdb.cpython-310.pyc +0 -0
  14. fairseq/__pycache__/registry.cpython-310.pyc +0 -0
  15. fairseq/__pycache__/search.cpython-310.pyc +0 -0
  16. fairseq/__pycache__/sequence_generator.cpython-310.pyc +0 -0
  17. fairseq/__pycache__/tokenizer.cpython-310.pyc +0 -0
  18. fairseq/__pycache__/utils.cpython-310.pyc +0 -0
  19. fairseq/benchmark/__init__.py +12 -0
  20. fairseq/benchmark/__pycache__/__init__.cpython-310.pyc +0 -0
  21. fairseq/benchmark/__pycache__/dummy_lm.cpython-310.pyc +0 -0
  22. fairseq/benchmark/__pycache__/dummy_masked_lm.cpython-310.pyc +0 -0
  23. fairseq/benchmark/__pycache__/dummy_model.cpython-310.pyc +0 -0
  24. fairseq/benchmark/__pycache__/dummy_mt.cpython-310.pyc +0 -0
  25. fairseq/benchmark/dummy_lm.py +118 -0
  26. fairseq/benchmark/dummy_masked_lm.py +127 -0
  27. fairseq/benchmark/dummy_model.py +95 -0
  28. fairseq/benchmark/dummy_mt.py +120 -0
  29. fairseq/binarizer.py +104 -0
  30. fairseq/checkpoint_utils.py +522 -0
  31. fairseq/clib/libbleu/libbleu.cpp +141 -0
  32. fairseq/clib/libbleu/module.cpp +37 -0
  33. fairseq/clib/libnat/edit_dist.cpp +231 -0
  34. fairseq/clib/libnat_cuda/binding.cpp +60 -0
  35. fairseq/clib/libnat_cuda/edit_dist.cu +332 -0
  36. fairseq/clib/libnat_cuda/edit_dist.h +25 -0
  37. fairseq/criterions/__init__.py +24 -0
  38. fairseq/criterions/__pycache__/__init__.cpython-310.pyc +0 -0
  39. fairseq/criterions/__pycache__/adaptive_loss.cpython-310.pyc +0 -0
  40. fairseq/criterions/__pycache__/composite_loss.cpython-310.pyc +0 -0
  41. fairseq/criterions/__pycache__/cross_entropy.cpython-310.pyc +0 -0
  42. fairseq/criterions/__pycache__/ctc.cpython-310.pyc +0 -0
  43. fairseq/criterions/__pycache__/fairseq_criterion.cpython-310.pyc +0 -0
  44. fairseq/criterions/__pycache__/label_smoothed_cross_entropy.cpython-310.pyc +0 -0
  45. fairseq/criterions/__pycache__/label_smoothed_cross_entropy_with_alignment.cpython-310.pyc +0 -0
  46. fairseq/criterions/__pycache__/legacy_masked_lm.cpython-310.pyc +0 -0
  47. fairseq/criterions/__pycache__/masked_lm.cpython-310.pyc +0 -0
  48. fairseq/criterions/__pycache__/nat_loss.cpython-310.pyc +0 -0
  49. fairseq/criterions/__pycache__/sentence_prediction.cpython-310.pyc +0 -0
  50. fairseq/criterions/__pycache__/sentence_ranking.cpython-310.pyc +0 -0
.gitattributes CHANGED
@@ -40,3 +40,5 @@ fairseq/data/data_utils_fast.cpython-38-darwin.so filter=lfs diff=lfs merge=lfs
40
  fairseq/data/token_block_utils_fast.cpython-310-darwin.so filter=lfs diff=lfs merge=lfs -text
41
  fairseq/data/token_block_utils_fast.cpython-36m-darwin.so filter=lfs diff=lfs merge=lfs -text
42
  fairseq/data/token_block_utils_fast.cpython-38-darwin.so filter=lfs diff=lfs merge=lfs -text
 
 
 
40
  fairseq/data/token_block_utils_fast.cpython-310-darwin.so filter=lfs diff=lfs merge=lfs -text
41
  fairseq/data/token_block_utils_fast.cpython-36m-darwin.so filter=lfs diff=lfs merge=lfs -text
42
  fairseq/data/token_block_utils_fast.cpython-38-darwin.so filter=lfs diff=lfs merge=lfs -text
43
+ fairseq/data/data_utils_fast.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
44
+ fairseq/data/token_block_utils_fast.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
fairseq/__init__.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

__all__ = ['pdb']
__version__ = '0.9.0'

import sys

# backwards compatibility to support `from fairseq.meters import AverageMeter`
# by aliasing the relocated fairseq.logging submodules under their old names
from fairseq.logging import meters, metrics, progress_bar  # noqa
sys.modules['fairseq.meters'] = meters
sys.modules['fairseq.metrics'] = metrics
sys.modules['fairseq.progress_bar'] = progress_bar

# these sub-packages are imported for their import-time side effects
# (presumably populating the @register_* registries — see e.g.
# fairseq/benchmark/__init__.py's "import ... to register them" note)
import fairseq.criterions  # noqa
import fairseq.models  # noqa
import fairseq.modules  # noqa
import fairseq.optim  # noqa
import fairseq.optim.lr_scheduler  # noqa
import fairseq.pdb  # noqa
import fairseq.tasks  # noqa

import fairseq.benchmark  # noqa
import fairseq.model_parallel  # noqa
fairseq/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (653 Bytes). View file
 
fairseq/__pycache__/binarizer.cpython-310.pyc ADDED
Binary file (3.09 kB). View file
 
fairseq/__pycache__/checkpoint_utils.cpython-310.pyc ADDED
Binary file (13.1 kB). View file
 
fairseq/__pycache__/distributed_utils.cpython-310.pyc ADDED
Binary file (9.11 kB). View file
 
fairseq/__pycache__/file_io.cpython-310.pyc ADDED
Binary file (2.84 kB). View file
 
fairseq/__pycache__/file_utils.cpython-310.pyc ADDED
Binary file (8.66 kB). View file
 
fairseq/__pycache__/incremental_decoding_utils.cpython-310.pyc ADDED
Binary file (2.19 kB). View file
 
fairseq/__pycache__/iterative_refinement_generator.cpython-310.pyc ADDED
Binary file (8.44 kB). View file
 
fairseq/__pycache__/legacy_distributed_data_parallel.cpython-310.pyc ADDED
Binary file (4.99 kB). View file
 
fairseq/__pycache__/options.cpython-310.pyc ADDED
Binary file (23.6 kB). View file
 
fairseq/__pycache__/pdb.cpython-310.pyc ADDED
Binary file (1.31 kB). View file
 
fairseq/__pycache__/registry.cpython-310.pyc ADDED
Binary file (2.05 kB). View file
 
fairseq/__pycache__/search.cpython-310.pyc ADDED
Binary file (9.99 kB). View file
 
fairseq/__pycache__/sequence_generator.cpython-310.pyc ADDED
Binary file (24.4 kB). View file
 
fairseq/__pycache__/tokenizer.cpython-310.pyc ADDED
Binary file (359 Bytes). View file
 
fairseq/__pycache__/utils.cpython-310.pyc ADDED
Binary file (18.1 kB). View file
 
fairseq/benchmark/__init__.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# import models/tasks to register them; the imports run the @register_task /
# @register_model decorators in each module as a side effect
from . import (  # noqa
    dummy_lm,
    dummy_masked_lm,
    dummy_model,
    dummy_mt,
)
fairseq/benchmark/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (247 Bytes). View file
 
fairseq/benchmark/__pycache__/dummy_lm.cpython-310.pyc ADDED
Binary file (4.4 kB). View file
 
fairseq/benchmark/__pycache__/dummy_masked_lm.cpython-310.pyc ADDED
Binary file (4.61 kB). View file
 
fairseq/benchmark/__pycache__/dummy_model.cpython-310.pyc ADDED
Binary file (3.4 kB). View file
 
fairseq/benchmark/__pycache__/dummy_mt.cpython-310.pyc ADDED
Binary file (4.45 kB). View file
 
fairseq/benchmark/dummy_lm.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import logging
7
+
8
+ import numpy as np
9
+ import torch
10
+
11
+ from fairseq.data import Dictionary, FairseqDataset
12
+ from fairseq.tasks import FairseqTask, register_task
13
+
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
@register_task('dummy_lm')
class DummyLMTask(FairseqTask):
    """Benchmark task that serves a single fixed, synthetic LM batch.

    Every split yields the same pre-built batch, so throughput numbers
    measure the model/trainer rather than data loading.
    """

    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser."""
        parser.add_argument('--dict-size', default=49996, type=int)
        parser.add_argument('--dataset-size', default=100000, type=int)
        parser.add_argument('--tokens-per-sample', default=512, type=int,
                            help='max number of total tokens over all segments '
                                 'per sample for BERT dataset')

    def __init__(self, args, dictionary):
        super().__init__(args)
        self.dictionary = dictionary
        self.seed = args.seed

        dictionary.pad_to_multiple_(8)  # often faster if divisible by 8

        # a ramp of consecutive ids just past the pad index; source and
        # target are the same ramp shifted by one position (LM objective)
        token_ramp = torch.arange(args.tokens_per_sample + 1) + dictionary.pad() + 1
        self.dummy_src = token_ramp[:-1]
        self.dummy_tgt = token_ramp[1:]

    @classmethod
    def setup_task(cls, args, **kwargs):
        """Setup the task. """
        dictionary = Dictionary()
        for idx in range(args.dict_size):
            dictionary.add_symbol('word{}'.format(idx))
        logger.info('dictionary: {} types'.format(len(dictionary)))
        return cls(args, dictionary)

    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.
        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        if self.args.max_sentences is not None:
            bsz = self.args.max_sentences
        else:
            bsz = max(1, self.args.max_tokens // self.args.tokens_per_sample)
        fixed_batch = {
            'id': 1,
            'net_input': {
                'src_tokens': torch.stack([self.dummy_src] * bsz),
                'src_lengths': torch.full(
                    (bsz, ), self.args.tokens_per_sample, dtype=torch.long
                ),
            },
            'target': torch.stack([self.dummy_tgt] * bsz),
            'nsentences': bsz,
            'ntokens': bsz * self.args.tokens_per_sample,
        }
        self.datasets[split] = DummyDataset(
            fixed_batch,
            num_items=self.args.dataset_size,
            item_size=self.args.tokens_per_sample,
        )

    @property
    def source_dictionary(self):
        return self.dictionary

    @property
    def target_dictionary(self):
        return self.dictionary
84
+
85
+
86
class DummyDataset(FairseqDataset):
    """Dataset of `num_items` placeholder items that always collates to the
    same pre-built batch."""

    def __init__(self, batch, num_items, item_size):
        super().__init__()
        self.batch = batch          # constant batch returned by collater()
        self.num_items = num_items  # nominal dataset length
        self.item_size = item_size  # nominal token count of every item

    def __getitem__(self, index):
        # items carry no data of their own; only the index is returned
        return index

    def __len__(self):
        return self.num_items

    def collater(self, samples):
        # ignore the sampled indices and hand back the fixed batch
        return self.batch

    @property
    def sizes(self):
        return np.full(self.num_items, self.item_size)

    def num_tokens(self, index):
        return self.item_size

    def size(self, index):
        return self.item_size

    def ordered_indices(self):
        # no length-based sorting needed: every item has the same size
        return np.arange(self.num_items)

    @property
    def supports_prefetch(self):
        return False
fairseq/benchmark/dummy_masked_lm.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import logging
7
+
8
+ import numpy as np
9
+ import torch
10
+
11
+ from fairseq.data import Dictionary, FairseqDataset
12
+ from fairseq.tasks import FairseqTask, register_task
13
+
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
@register_task('dummy_masked_lm')
class DummyMaskedLMTask(FairseqTask):
    """Benchmark task that serves a single fixed, synthetic masked-LM batch.

    Every split yields the same pre-built batch, so throughput numbers
    measure the model/trainer rather than data loading.
    """

    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser."""
        parser.add_argument('--dict-size', default=49995, type=int)
        parser.add_argument('--dataset-size', default=100000, type=int)
        parser.add_argument('--tokens-per-sample', default=512, type=int,
                            help='max number of total tokens over all segments '
                                 'per sample for BERT dataset')

    def __init__(self, args, dictionary):
        super().__init__(args)
        self.dictionary = dictionary
        self.seed = args.seed

        # add mask token
        self.mask_idx = dictionary.add_symbol('<mask>')
        dictionary.pad_to_multiple_(8)  # often faster if divisible by 8

        # hard-coded ids for the synthetic batch (independent of the
        # dictionary's actual mask/pad indices — this is dummy data)
        mask_idx = 0
        pad_idx = 1
        token_ramp = torch.arange(args.tokens_per_sample) + pad_idx + 1
        masked_positions = torch.arange(2, args.tokens_per_sample, 7)  # ~15%
        masked_input = token_ramp.clone()
        masked_input[masked_positions] = mask_idx
        labels = torch.full_like(token_ramp, pad_idx)
        labels[masked_positions] = token_ramp[masked_positions]

        self.dummy_src = masked_input
        self.dummy_tgt = labels

    @classmethod
    def setup_task(cls, args, **kwargs):
        """Setup the task. """
        dictionary = Dictionary()
        for idx in range(args.dict_size):
            dictionary.add_symbol('word{}'.format(idx))
        logger.info('dictionary: {} types'.format(len(dictionary)))
        return cls(args, dictionary)

    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.
        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        if self.args.max_sentences is not None:
            bsz = self.args.max_sentences
        else:
            bsz = max(1, self.args.max_tokens // self.args.tokens_per_sample)
        fixed_batch = {
            'id': 1,
            'net_input': {
                'src_tokens': torch.stack([self.dummy_src] * bsz),
                'src_lengths': torch.full(
                    (bsz, ), self.args.tokens_per_sample, dtype=torch.long
                ),
            },
            'target': torch.stack([self.dummy_tgt] * bsz),
            'nsentences': bsz,
            'ntokens': bsz * self.args.tokens_per_sample,
        }
        self.datasets[split] = DummyDataset(
            fixed_batch,
            num_items=self.args.dataset_size,
            item_size=self.args.tokens_per_sample,
        )

    @property
    def source_dictionary(self):
        return self.dictionary

    @property
    def target_dictionary(self):
        return self.dictionary
93
+
94
+
95
class DummyDataset(FairseqDataset):
    """Dataset of `num_items` placeholder items that always collates to the
    same pre-built batch."""

    def __init__(self, batch, num_items, item_size):
        super().__init__()
        self.batch = batch          # constant batch returned by collater()
        self.num_items = num_items  # nominal dataset length
        self.item_size = item_size  # nominal token count of every item

    def __getitem__(self, index):
        # items carry no data of their own; only the index is returned
        return index

    def __len__(self):
        return self.num_items

    def collater(self, samples):
        # ignore the sampled indices and hand back the fixed batch
        return self.batch

    @property
    def sizes(self):
        return np.full(self.num_items, self.item_size)

    def num_tokens(self, index):
        return self.item_size

    def size(self, index):
        return self.item_size

    def ordered_indices(self):
        # no length-based sorting needed: every item has the same size
        return np.arange(self.num_items)

    @property
    def supports_prefetch(self):
        return False
fairseq/benchmark/dummy_model.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+ from fairseq.data import Dictionary
10
+ from fairseq.models import (
11
+ FairseqDecoder,
12
+ FairseqLanguageModel,
13
+ register_model,
14
+ register_model_architecture,
15
+ )
16
+
17
+
18
@register_model('dummy_model')
class DummyModel(FairseqLanguageModel):
    """Toy language model used to benchmark the training loop."""

    def __init__(self, args, encoder):
        super().__init__(encoder)
        self.args = args

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        parser.add_argument('--num-layers', type=int, default=24)
        parser.add_argument('--embed-dim', type=int, default=1024)

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance sized to the task's vocabulary."""
        return cls(
            args,
            DummyEncoder(
                num_embed=len(task.target_dictionary),
                embed_dim=args.embed_dim,
                num_layers=args.num_layers,
            ),
        )

    def forward(self, src_tokens, masked_tokens=None, **kwargs):
        # NOTE: the "encoder" is stored as .decoder by FairseqLanguageModel
        return self.decoder(src_tokens, masked_tokens=masked_tokens)
41
+
42
+
43
class DummyEncoder(FairseqDecoder):
    """Residual stack of linear blocks that mimics a transformer's cost
    (the attention math itself is replaced by plain linear layers)."""

    def __init__(self, num_embed=50000, embed_dim=1024, num_layers=24):
        super().__init__(Dictionary())
        self.embed = nn.Embedding(
            num_embeddings=num_embed, embedding_dim=embed_dim, padding_idx=0
        )

        def make_attn_block():
            # same parameter count as self-attention projections
            return nn.Sequential(
                nn.LayerNorm(embed_dim),
                nn.Linear(embed_dim, 3 * embed_dim),  # q, k, v input projection
                nn.Linear(3 * embed_dim, embed_dim),  # skip self-attention
                nn.Linear(embed_dim, embed_dim),      # output projection
                nn.Dropout(),
            )

        def make_ffn_block():
            return nn.Sequential(
                nn.LayerNorm(embed_dim),
                nn.Linear(embed_dim, 4 * embed_dim),  # FFN
                nn.ReLU(),
                nn.Linear(4 * embed_dim, embed_dim),  # FFN
                nn.Dropout(0.1),
            )

        # build all attention blocks first, then all FFN blocks, so the
        # parameter creation order matches the original implementation
        self.layers_a = nn.ModuleList([make_attn_block() for _ in range(num_layers)])
        self.layers_b = nn.ModuleList([make_ffn_block() for _ in range(num_layers)])
        self.out_proj = nn.Linear(embed_dim, num_embed)

    def forward(self, tokens, masked_tokens=None):
        x = self.embed(tokens)
        for attn, ffn in zip(self.layers_a, self.layers_b):
            x = x + attn(x)
            x = x + ffn(x)
        x = self.out_proj(x)
        if masked_tokens is not None:
            # keep only the positions selected for prediction
            x = x[masked_tokens]
        return (x,)

    def max_positions(self):
        return 1024

    def get_normalized_probs(self, net_output, log_probs, sample=None):
        """Normalize the model output into (log-)probabilities."""
        logits = net_output[0].float()
        if log_probs:
            return F.log_softmax(logits, dim=-1)
        return F.softmax(logits, dim=-1)
91
+
92
+
93
@register_model_architecture('dummy_model', 'dummy_model')
def base_architecture(args):
    """Default architecture: no overrides, CLI defaults are used as-is."""
    pass
fairseq/benchmark/dummy_mt.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import logging
7
+
8
+ import numpy as np
9
+ import torch
10
+
11
+ from fairseq.data import Dictionary, FairseqDataset
12
+ from fairseq.tasks import FairseqTask, register_task
13
+
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
@register_task('dummy_mt')
class DummyMTTask(FairseqTask):
    """Benchmark task that serves a single fixed, synthetic seq2seq batch.

    Every split yields the same pre-built batch, so throughput numbers
    measure the model/trainer rather than data loading.
    """

    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser."""
        parser.add_argument('--dict-size', default=49996, type=int)
        parser.add_argument('--dataset-size', default=100000, type=int)
        parser.add_argument('--tokens-per-sample', default=512, type=int,
                            help='max number of total tokens over all segments '
                                 'per sample for BERT dataset')

    def __init__(self, args, dictionary):
        super().__init__(args)
        self.dictionary = dictionary
        self.seed = args.seed

        dictionary.pad_to_multiple_(8)  # often faster if divisible by 8

        # a ramp of consecutive ids just past the pad index; source and
        # target are the same ramp shifted by one position
        token_ramp = torch.arange(args.tokens_per_sample + 1) + dictionary.pad() + 1
        self.dummy_src = token_ramp[:-1]
        self.dummy_tgt = token_ramp[1:]

    @classmethod
    def setup_task(cls, args, **kwargs):
        """Setup the task. """
        dictionary = Dictionary()
        for idx in range(args.dict_size):
            dictionary.add_symbol('word{}'.format(idx))
        logger.info('dictionary: {} types'.format(len(dictionary)))
        return cls(args, dictionary)

    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.
        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        if self.args.max_sentences is not None:
            bsz = self.args.max_sentences
        else:
            bsz = max(1, self.args.max_tokens // self.args.tokens_per_sample)
        target_batch = torch.stack([self.dummy_tgt] * bsz)
        fixed_batch = {
            'id': 1,
            'net_input': {
                'src_tokens': torch.stack([self.dummy_src] * bsz),
                'src_lengths': torch.full(
                    (bsz, ), self.args.tokens_per_sample, dtype=torch.long
                ),
                # decoder input gets its own copy of the target tensor
                'prev_output_tokens': target_batch.clone(),
            },
            'target': target_batch,
            'nsentences': bsz,
            'ntokens': bsz * self.args.tokens_per_sample,
        }
        self.datasets[split] = DummyDataset(
            fixed_batch,
            num_items=self.args.dataset_size,
            item_size=self.args.tokens_per_sample,
        )

    @property
    def source_dictionary(self):
        return self.dictionary

    @property
    def target_dictionary(self):
        return self.dictionary
+ return self.dictionary
86
+
87
+
88
class DummyDataset(FairseqDataset):
    """Dataset of `num_items` placeholder items that always collates to the
    same pre-built batch."""

    def __init__(self, batch, num_items, item_size):
        super().__init__()
        self.batch = batch          # constant batch returned by collater()
        self.num_items = num_items  # nominal dataset length
        self.item_size = item_size  # nominal token count of every item

    def __getitem__(self, index):
        # items carry no data of their own; only the index is returned
        return index

    def __len__(self):
        return self.num_items

    def collater(self, samples):
        # ignore the sampled indices and hand back the fixed batch
        return self.batch

    @property
    def sizes(self):
        return np.full(self.num_items, self.item_size)

    def num_tokens(self, index):
        return self.item_size

    def size(self, index):
        return self.item_size

    def ordered_indices(self):
        # no length-based sorting needed: every item has the same size
        return np.arange(self.num_items)

    @property
    def supports_prefetch(self):
        return False
+ return False
fairseq/binarizer.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import os
7
+ from collections import Counter
8
+
9
+ from fairseq.tokenizer import tokenize_line
10
+ import torch
11
+ from fairseq.file_io import PathManager
12
+
13
def safe_readline(f):
    """Read one line from *f*, recovering if the current offset falls in
    the middle of a multi-byte character."""
    start = f.tell()
    while True:
        try:
            return f.readline()
        except UnicodeDecodeError:
            # the seek landed mid-character; back up one byte and retry
            start -= 1
            f.seek(start)
21
+
22
+
23
class Binarizer:
    """Static helpers that turn text files into streams of id tensors."""

    @staticmethod
    def binarize(
        filename,
        dict,
        consumer,
        tokenize=tokenize_line,
        append_eos=True,
        reverse_order=False,
        offset=0,
        end=-1,
        already_numberized=False,
    ):
        """Encode each line of *filename* and pass the resulting tensor to
        *consumer*.

        Only the byte range [offset, end) is processed (end <= 0 means read
        to EOF), which lets multiple workers binarize disjoint chunks of one
        file (see find_offsets). Returns a stats dict with keys "nseq",
        "nunk", "ntok" and "replaced" (Counter of words mapped to <unk>).
        """
        nseq, ntok = 0, 0
        replaced = Counter()

        def replaced_consumer(word, idx):
            # count words that mapped to <unk>, excluding the literal unk word
            if idx == dict.unk_index and word != dict.unk_word:
                replaced.update([word])

        with open(PathManager.get_local_path(filename), "r", encoding="utf-8") as f:
            f.seek(offset)
            # next(f) breaks f.tell(), hence readline() must be used
            line = safe_readline(f)
            while line:
                if end > 0 and f.tell() > end:
                    break
                if already_numberized:
                    # the line already contains whitespace-separated int ids
                    id_strings = line.strip().split()
                    id_list = [int(id_string) for id_string in id_strings]
                    if reverse_order:
                        id_list.reverse()
                    if append_eos:
                        id_list.append(dict.eos())
                    ids = torch.IntTensor(id_list)
                else:
                    ids = dict.encode_line(
                        line=line,
                        line_tokenizer=tokenize,
                        add_if_not_exist=False,
                        consumer=replaced_consumer,
                        append_eos=append_eos,
                        reverse_order=reverse_order,
                    )
                nseq += 1
                ntok += len(ids)
                consumer(ids)
                line = f.readline()
        return {
            "nseq": nseq,
            "nunk": sum(replaced.values()),
            "ntok": ntok,
            "replaced": replaced,
        }

    @staticmethod
    def binarize_alignments(filename, alignment_parser, consumer, offset=0, end=-1):
        """Parse each line of *filename* with *alignment_parser* and pass the
        result to *consumer*; same [offset, end) chunking as binarize().
        Returns {"nseq": number of lines processed}."""
        nseq = 0

        with open(PathManager.get_local_path(filename), "r") as f:
            f.seek(offset)
            line = safe_readline(f)
            while line:
                if end > 0 and f.tell() > end:
                    break
                ids = alignment_parser(line)
                nseq += 1
                consumer(ids)
                line = f.readline()
        return {"nseq": nseq}

    @staticmethod
    def find_offsets(filename, num_chunks):
        """Return num_chunks + 1 byte offsets that split *filename* into
        roughly equal, line-aligned chunks.

        offsets[0] is 0 and the final entry stays 0, which binarize()
        treats as "read to EOF" (its end check only fires for end > 0).
        """
        with open(PathManager.get_local_path(filename), "r", encoding="utf-8") as f:
            size = os.fstat(f.fileno()).st_size
            chunk_size = size // num_chunks
            offsets = [0 for _ in range(num_chunks + 1)]
            for i in range(1, num_chunks):
                f.seek(chunk_size * i)
                # advance to the next line boundary so chunks never split a line
                safe_readline(f)
                offsets[i] = f.tell()
            return offsets
fairseq/checkpoint_utils.py ADDED
@@ -0,0 +1,522 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import collections
7
+ import logging
8
+ import os
9
+ import re
10
+ import traceback
11
+ from collections import OrderedDict
12
+ from typing import Union
13
+
14
+ import torch
15
+ from fairseq.file_io import PathManager
16
+ from fairseq.models import FairseqDecoder, FairseqEncoder
17
+ from torch.serialization import default_restore_location
18
+
19
+
20
+ logger = logging.getLogger(__name__)
21
+
22
+
23
def save_checkpoint(args, trainer, epoch_itr, val_loss):
    """Save checkpoint files according to the configured schedule.

    Tracks the best validation loss seen so far as a function attribute
    (``save_checkpoint.best``), writes one real checkpoint and copies it to
    every other filename whose condition is met, then prunes old interval /
    epoch / non-best checkpoints per the ``keep_*`` arguments.
    """
    # NOTE(review): distributed_utils appears unused in this function body
    from fairseq import distributed_utils, meters

    # only one worker should attempt to create the required dir
    if args.distributed_rank == 0:
        os.makedirs(args.save_dir, exist_ok=True)

    # update the running best; defaults to val_loss on the first call
    prev_best = getattr(save_checkpoint, "best", val_loss)
    if val_loss is not None:
        best_function = max if args.maximize_best_checkpoint_metric else min
        save_checkpoint.best = best_function(val_loss, prev_best)

    # only the data-parallel master writes checkpoints
    if args.no_save or not trainer.is_data_parallel_master:
        return

    def is_better(a, b):
        # ties count as "better" so an equal score still refreshes the file
        return a >= b if args.maximize_best_checkpoint_metric else a <= b

    write_timer = meters.StopwatchMeter()
    write_timer.start()

    epoch = epoch_itr.epoch
    end_of_epoch = epoch_itr.end_of_epoch()
    updates = trainer.get_num_updates()

    # map checkpoint filename -> whether it should be written this call
    suffix = getattr(args, "checkpoint_suffix", "")
    checkpoint_conds = collections.OrderedDict()
    checkpoint_conds["checkpoint{}{}.pt".format(epoch, suffix)] = (
        end_of_epoch
        and not args.no_epoch_checkpoints
        and epoch % args.save_interval == 0
    )
    checkpoint_conds["checkpoint_{}_{}{}.pt".format(epoch, updates, suffix)] = (
        not end_of_epoch
        and args.save_interval_updates > 0
        and updates % args.save_interval_updates == 0
    )
    checkpoint_conds["checkpoint_best{}.pt".format(suffix)] = val_loss is not None and (
        not hasattr(save_checkpoint, "best")
        or is_better(val_loss, save_checkpoint.best)
    )
    if val_loss is not None and args.keep_best_checkpoints > 0:
        checkpoint_conds["checkpoint.best_{}_{:.2f}.pt".format(
            args.best_checkpoint_metric, val_loss)] = (
            not hasattr(save_checkpoint, "best")
            or is_better(val_loss, save_checkpoint.best)
        )
    checkpoint_conds["checkpoint_last{}.pt".format(suffix)] = not args.no_last_checkpoints

    extra_state = {"train_iterator": epoch_itr.state_dict(), "val_loss": val_loss}
    if hasattr(save_checkpoint, "best"):
        extra_state.update({"best": save_checkpoint.best})

    checkpoints = [
        os.path.join(args.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond
    ]
    if len(checkpoints) > 0:
        # write once, then copy to the remaining target filenames
        trainer.save_checkpoint(checkpoints[0], extra_state)
        for cp in checkpoints[1:]:
            PathManager.copy(checkpoints[0], cp, overwrite=True)

        write_timer.stop()
        logger.info(
            "saved checkpoint {} (epoch {} @ {} updates, score {}) (writing took {} seconds)".format(
                checkpoints[0], epoch, updates, val_loss, write_timer.sum
            )
        )

    if not end_of_epoch and args.keep_interval_updates > 0:
        # remove old checkpoints; checkpoints are sorted in descending order
        checkpoints = checkpoint_paths(
            args.save_dir, pattern=r"checkpoint_\d+_(\d+)\.pt"
        )
        for old_chk in checkpoints[args.keep_interval_updates :]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)

    if args.keep_last_epochs > 0:
        # remove old epoch checkpoints; checkpoints are sorted in descending order
        checkpoints = checkpoint_paths(args.save_dir, pattern=r"checkpoint(\d+)\.pt")
        for old_chk in checkpoints[args.keep_last_epochs :]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)

    if args.keep_best_checkpoints > 0:
        # only keep the best N checkpoints according to validation metric
        checkpoints = checkpoint_paths(
            args.save_dir, pattern=r"checkpoint\.best_{}_(\d+\.?\d*)\.pt".format(args.best_checkpoint_metric))
        if not args.maximize_best_checkpoint_metric:
            checkpoints = checkpoints[::-1]
        for old_chk in checkpoints[args.keep_best_checkpoints:]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)
+ os.remove(old_chk)
116
+
117
+
118
def load_checkpoint(args, trainer, **passthrough_args):
    """
    Load a checkpoint and restore the training iterator.

    *passthrough_args* will be passed through to
    ``trainer.get_train_iterator``.

    Returns ``(extra_state, epoch_itr)`` where *extra_state* is whatever the
    trainer recovered from the checkpoint (or None on a fresh start).

    Raises:
        ValueError: if --finetune-from-model conflicts with the reset flags
            or with a non-default --restore-file, or points at a missing file.
    """
    import ast  # local import: avoids touching the file's top-level imports

    reset_optimizer = args.reset_optimizer
    reset_lr_scheduler = args.reset_lr_scheduler
    # parse --optimizer-overrides as a Python literal (a dict such as "{}");
    # unlike eval(), ast.literal_eval cannot execute arbitrary code passed
    # on the command line
    optimizer_overrides = ast.literal_eval(args.optimizer_overrides)
    reset_meters = args.reset_meters
    reset_dataloader = args.reset_dataloader

    if getattr(args, 'finetune_from_model', None) is not None \
       and (reset_optimizer or reset_lr_scheduler or reset_meters or reset_dataloader):
        raise ValueError("--finetune-from-model can not be set together with either --reset-optimizer"
                         " or reset_lr_scheduler or reset_meters or reset_dataloader")

    suffix = getattr(args, "checkpoint_suffix", "")
    if args.restore_file == "checkpoint_last.pt":  # default value of restore_file is 'checkpoint_last.pt'
        checkpoint_path = os.path.join(args.save_dir, "checkpoint_last{}.pt".format(suffix))
        first_launch = not PathManager.exists(checkpoint_path)
        if getattr(args, 'finetune_from_model', None) is not None and first_launch:
            # if there is no last checkpoint to restore, start the finetune from pretrained model
            # else just use usual logic to load checkpoint, e.g. restart from last checkpoint and etc.
            if PathManager.exists(args.finetune_from_model):
                checkpoint_path = args.finetune_from_model
                reset_optimizer = True
                reset_lr_scheduler = True
                reset_meters = True
                reset_dataloader = True
                logger.info(f'loading pretrained model from {checkpoint_path}: '
                            'optimizer, lr scheduler, meters, dataloader will be reset')
            else:
                # fixed typo: previously said "--funetune-from-model"
                raise ValueError(f'--finetune-from-model {args.finetune_from_model} does not exist')
    elif getattr(args, "model_parallel_size", 1) > 1:
        # one checkpoint file per model-parallel rank, distinguished by suffix
        checkpoint_path = args.restore_file.replace(".pt", suffix + ".pt")
    else:
        checkpoint_path = args.restore_file

    if args.restore_file != "checkpoint_last.pt" and getattr(args, 'finetune_from_model', None):
        raise ValueError(
            '--finetune-from-model and --restore-file (non-default value) '
            'can not be specified together: ' + str(args))

    extra_state = trainer.load_checkpoint(
        checkpoint_path,
        reset_optimizer,
        reset_lr_scheduler,
        optimizer_overrides,
        reset_meters=reset_meters,
    )

    if (
        extra_state is not None
        and "best" in extra_state
        and not reset_optimizer
        and not reset_meters
    ):
        # carry the best validation score forward for save_checkpoint()
        save_checkpoint.best = extra_state["best"]

    if extra_state is not None and not reset_dataloader:
        # restore iterator from checkpoint
        itr_state = extra_state["train_iterator"]
        epoch_itr = trainer.get_train_iterator(
            epoch=itr_state["epoch"], load_dataset=True, **passthrough_args
        )
        epoch_itr.load_state_dict(itr_state)
    else:
        epoch_itr = trainer.get_train_iterator(
            epoch=1, load_dataset=True, **passthrough_args
        )

    trainer.lr_step(epoch_itr.epoch)

    return extra_state, epoch_itr
194
+
195
+
196
def load_checkpoint_to_cpu(path, arg_overrides=None):
    """Loads a checkpoint to CPU (with upgrading for backward compatibility).

    Args:
        path (str): checkpoint file, opened through PathManager.
        arg_overrides (dict, optional): attribute overrides applied onto the
            checkpoint's stored ``args`` namespace before upgrading.

    Returns:
        dict: the (upgraded) checkpoint state dictionary.
    """
    with PathManager.open(path, "rb") as stream:
        # remap every storage onto CPU regardless of where it was saved
        state = torch.load(
            stream,
            map_location=lambda storage, _loc: default_restore_location(storage, "cpu"),
        )

    checkpoint_args = state["args"]
    if arg_overrides is not None:
        for name, value in arg_overrides.items():
            setattr(checkpoint_args, name, value)

    return _upgrade_state_dict(state)
209
+
210
+
211
def load_model_ensemble(filenames, arg_overrides=None, task=None, strict=True, suffix=''):
    """Loads an ensemble of models.

    Args:
        filenames (List[str]): checkpoint files to load
        arg_overrides (Dict[str,Any], optional): override model args that
            were used during model training
        task (fairseq.tasks.FairseqTask, optional): task to use for loading

    Returns:
        tuple: (list of loaded models, shared model args)
    """
    # delegate to the *_and_task variant and drop the task from the result
    models, model_args, _ = load_model_ensemble_and_task(
        filenames, arg_overrides, task, strict, suffix,
    )
    return models, model_args
224
+
225
+
226
def load_model_ensemble_and_task(filenames, arg_overrides=None, task=None, strict=True, suffix=''):
    """Loads an ensemble of models plus the args and task used to build them.

    Args:
        filenames (List[str]): checkpoint files to load
        arg_overrides (Dict[str,Any], optional): override model args that
            were used during model training
        task (fairseq.tasks.FairseqTask, optional): task to use for loading;
            built from the first checkpoint's args when omitted.

    Returns:
        tuple: (list of models, args of the last checkpoint, task)
    """
    from fairseq import tasks

    ensemble = []
    for ckpt_name in filenames:
        ckpt_name = ckpt_name.replace(".pt", suffix + ".pt")
        if not PathManager.exists(ckpt_name):
            raise IOError("Model file not found: {}".format(ckpt_name))
        state = load_checkpoint_to_cpu(ckpt_name, arg_overrides)

        args = state["args"]
        # only build the task once, from the first checkpoint's args
        if task is None:
            task = tasks.setup_task(args)

        # build model for ensemble
        model = task.build_model(args)
        model.load_state_dict(state["model"], strict=strict, args=args)
        ensemble.append(model)
    return ensemble, args, task
245
+
246
+
247
def checkpoint_paths(path, pattern=r"checkpoint(\d+)\.pt"):
    """Retrieves all checkpoints found in `path` directory.

    Checkpoints are identified by matching filename to the specified pattern. If
    the pattern contains groups, the result will be sorted by the first group in
    descending order.
    """
    matcher = re.compile(pattern)

    scored = []
    for position, fname in enumerate(os.listdir(path)):
        hit = matcher.fullmatch(fname)
        if hit is None:
            continue
        # use the first capture group as the sort key when present,
        # otherwise fall back to directory-listing order
        score = float(hit.group(1)) if hit.groups() else position
        scored.append((score, fname))
    scored.sort(reverse=True)
    return [os.path.join(path, fname) for _, fname in scored]
264
+
265
+
266
def torch_persistent_save(*args, **kwargs):
    """Wrapper around ``torch.save`` that retries up to three times.

    Failures on the first two attempts are silently retried; the traceback of
    a third consecutive failure is logged but not re-raised, preserving the
    best-effort semantics callers rely on.
    """
    final_attempt = 2
    for attempt in range(3):
        try:
            return torch.save(*args, **kwargs)
        except Exception:
            if attempt == final_attempt:
                logger.error(traceback.format_exc())
273
+
274
+
275
def save_state(
    filename,
    args,
    model_state_dict,
    criterion,
    optimizer,
    lr_scheduler,
    num_updates,
    optim_history=None,
    extra_state=None,
):
    """Assemble a full training-state dictionary and write it to *filename*.

    The saved dict bundles the model weights, the training args, the optimizer
    history (extended with a fresh entry for the current criterion, optimizer
    and lr scheduler), and any *extra_state*.  The optimizer state itself is
    included unless ``args.no_save_optimizer_state`` is set.  Everything is
    moved to CPU before serialization.
    """
    from fairseq import utils

    history = [] if optim_history is None else optim_history
    latest_entry = {
        "criterion_name": criterion.__class__.__name__,
        "optimizer_name": optimizer.__class__.__name__,
        "lr_scheduler_state": lr_scheduler.state_dict(),
        "num_updates": num_updates,
    }
    state_dict = {
        "args": args,
        "model": model_state_dict if model_state_dict else {},
        "optimizer_history": history + [latest_entry],
        "extra_state": {} if extra_state is None else extra_state,
    }
    # criteria with learnable parameters carry their own state
    if utils.has_parameters(criterion):
        state_dict["criterion"] = criterion.state_dict()
    if not args.no_save_optimizer_state:
        state_dict["last_optimizer_state"] = optimizer.state_dict()

    # convert all state to CPU so the checkpoint loads on any device
    state_dict = utils.move_to_cpu(state_dict)

    with PathManager.open(filename, "wb") as f:
        torch_persistent_save(state_dict, f)
316
+
317
+
318
def _upgrade_state_dict(state):
    """Helper for upgrading old model checkpoints.

    Applies a sequence of in-place migrations so that checkpoints written by
    older fairseq versions match the current state-dict layout, then fills in
    any registry defaults missing from the stored ``args``.  The order of the
    migrations matters: later steps assume the keys created by earlier ones.
    """
    from fairseq import models, registry, tasks

    # add optimizer_history
    if "optimizer_history" not in state:
        state["optimizer_history"] = [
            {"criterion_name": "CrossEntropyCriterion", "best_loss": state["best_loss"]}
        ]
        state["last_optimizer_state"] = state["optimizer"]
        del state["optimizer"]
        del state["best_loss"]
    # move extra_state into sub-dictionary
    if "epoch" in state and "extra_state" not in state:
        state["extra_state"] = {
            "epoch": state["epoch"],
            "batch_offset": state["batch_offset"],
            "val_loss": state["val_loss"],
        }
        del state["epoch"]
        del state["batch_offset"]
        del state["val_loss"]
    # reduce optimizer history's memory usage (only keep the last state)
    if "optimizer" in state["optimizer_history"][-1]:
        state["last_optimizer_state"] = state["optimizer_history"][-1]["optimizer"]
        for optim_hist in state["optimizer_history"]:
            del optim_hist["optimizer"]
    # record the optimizer class name (very old checkpoints predate this field)
    if "optimizer_name" not in state["optimizer_history"][-1]:
        state["optimizer_history"][-1]["optimizer_name"] = "FairseqNAG"
    # move best_loss into lr_scheduler_state
    if "lr_scheduler_state" not in state["optimizer_history"][-1]:
        state["optimizer_history"][-1]["lr_scheduler_state"] = {
            "best": state["optimizer_history"][-1]["best_loss"]
        }
        del state["optimizer_history"][-1]["best_loss"]
    # keep track of number of updates
    if "num_updates" not in state["optimizer_history"][-1]:
        state["optimizer_history"][-1]["num_updates"] = 0
    # old model checkpoints may not have separate source/target positions
    if hasattr(state["args"], "max_positions") and not hasattr(
        state["args"], "max_source_positions"
    ):
        state["args"].max_source_positions = state["args"].max_positions
        state["args"].max_target_positions = state["args"].max_positions
    # use stateful training data iterator
    if "train_iterator" not in state["extra_state"]:
        state["extra_state"]["train_iterator"] = {
            "epoch": state["extra_state"]["epoch"],
            "iterations_in_epoch": state["extra_state"].get("batch_offset", 0),
        }
    # default to translation task
    if not hasattr(state["args"], "task"):
        state["args"].task = "translation"
    # --raw-text and --lazy-load are deprecated
    if getattr(state["args"], "raw_text", False):
        state["args"].dataset_impl = "raw"
    elif getattr(state["args"], "lazy_load", False):
        state["args"].dataset_impl = "lazy"
    # epochs start at 1
    if state["extra_state"]["train_iterator"] is not None:
        state["extra_state"]["train_iterator"]["epoch"] = max(
            state["extra_state"]["train_iterator"].get("epoch", 1),
            1,
        )

    # set any missing default values in the task, model or other registries
    registry.set_defaults(state["args"], tasks.TASK_REGISTRY[state["args"].task])
    registry.set_defaults(state["args"], models.ARCH_MODEL_REGISTRY[state["args"].arch])
    for registry_name, REGISTRY in registry.REGISTRIES.items():
        choice = getattr(state["args"], registry_name, None)
        if choice is not None:
            cls = REGISTRY["registry"][choice]
            registry.set_defaults(state["args"], cls)

    return state
394
+
395
+
396
def prune_state_dict(state_dict, args):
    """Prune the given state_dict if desired for LayerDrop
    (https://arxiv.org/abs/1909.11556).

    Training with LayerDrop allows models to be robust to pruning at inference
    time. This function prunes state_dict to allow smaller models to be loaded
    from a larger model and re-maps the existing state_dict for this to occur.

    It's called by functions that load models from checkpoints and does not
    need to be called directly.
    """
    if not args or args.arch == "ptt_transformer":
        # args should not be none, but don't crash if it is.
        return state_dict

    arg_vars = vars(args)
    encoder_layers_to_keep = arg_vars.get("encoder_layers_to_keep")
    decoder_layers_to_keep = arg_vars.get("decoder_layers_to_keep")

    if not encoder_layers_to_keep and not decoder_layers_to_keep:
        return state_dict

    # apply pruning
    logger.info(
        "Pruning model to specified layer configuration - this works best if the model was trained with LayerDrop"
    )

    def _build_pass(spec, module_prefix):
        # map each retained original layer index onto its compacted new index
        kept = sorted(int(tok) for tok in spec.split(","))
        renumber = {str(orig): str(new) for new, orig in enumerate(kept)}
        matcher = re.compile(r"^{layer}.*\.layers\.(\d+)".format(layer=module_prefix))
        return {"substitution_regex": matcher, "mapping_dict": renumber}

    passes = []
    if encoder_layers_to_keep:
        passes.append(_build_pass(encoder_layers_to_keep, "encoder"))
    if decoder_layers_to_keep:
        passes.append(_build_pass(decoder_layers_to_keep, "decoder"))

    pruned = {}
    for key in state_dict.keys():
        layer_match = re.search(r"\.layers\.(\d+)\.", key)
        # if layer has no number in it, it is a supporting layer, such as an
        # embedding
        if not layer_match:
            pruned[key] = state_dict[key]
            continue

        # otherwise, layer should be pruned.
        old_index = layer_match.group(1)
        # figure out which mapping dict to replace from
        for prune_pass in passes:
            sub_match = prune_pass["substitution_regex"].search(key)
            if old_index in prune_pass["mapping_dict"] and sub_match:
                new_index = prune_pass["mapping_dict"][old_index]
                new_key = (
                    key[: sub_match.start(1)]
                    + new_index
                    + key[sub_match.end(1):]
                )
                pruned[new_key] = state_dict[key]

    # Since layers are now pruned, *_layers_to_keep are no longer needed.
    # This is more of "It would make it work fix" rather than a proper fix.
    if "encoder_layers_to_keep" in arg_vars:
        args.encoder_layers_to_keep = None
    if "decoder_layers_to_keep" in arg_vars:
        args.decoder_layers_to_keep = None

    return pruned
478
+
479
+
480
def load_pretrained_component_from_model(
    component: Union[FairseqEncoder, FairseqDecoder], checkpoint: str
):
    """
    Load a pretrained FairseqEncoder or FairseqDecoder from checkpoint into the
    provided `component` object. If state_dict fails to load, there may be a
    mismatch in the architecture of the corresponding `component` found in the
    `checkpoint` file.
    """
    if not PathManager.exists(checkpoint):
        raise IOError("Model file not found: {}".format(checkpoint))
    state = load_checkpoint_to_cpu(checkpoint)

    # pick the key prefix matching the component's type
    if isinstance(component, FairseqEncoder):
        prefix = "encoder"
    elif isinstance(component, FairseqDecoder):
        prefix = "decoder"
    else:
        raise ValueError(
            "component to load must be either a FairseqEncoder or "
            "FairseqDecoder. Loading other component types are not supported."
        )

    filtered = OrderedDict()
    for full_key in state["model"].keys():
        if full_key.startswith(prefix):
            # e.g. encoder.input_layers.0.0.weight --> input_layers.0.0.weight
            stripped_key = full_key[len(prefix) + 1:]
            filtered[stripped_key] = state["model"][full_key]
    component.load_state_dict(filtered, strict=True)
    return component
509
+
510
+
511
def verify_checkpoint_directory(save_dir: str) -> None:
    """Ensure *save_dir* exists and is writable.

    Creates the directory when missing, then probes writability by creating
    and deleting a throwaway file.  Logs a warning and re-raises if the probe
    write fails.
    """
    if not os.path.exists(save_dir):
        os.makedirs(save_dir, exist_ok=True)
    probe_path = os.path.join(save_dir, "dummy")
    try:
        with open(probe_path, "w"):
            pass
    except OSError as err:
        logger.warning("Unable to access checkpoint save directory: {}".format(save_dir))
        raise err
    else:
        # probe succeeded: clean up after ourselves
        os.remove(probe_path)
fairseq/clib/libbleu/libbleu.cpp ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /**
2
+ * Copyright 2017-present, Facebook, Inc.
3
+ * All rights reserved.
4
+ *
5
+ * This source code is licensed under the license found in the
6
+ * LICENSE file in the root directory of this source tree.
7
+ */
8
+
9
+ #include <map>
10
+ #include <array>
11
+ #include <cstring>
12
+ #include <cstdio>
13
+
14
// Accumulated BLEU sufficient statistics: total reference/prediction lengths
// plus matched vs. total n-gram counts for n = 1..4.
typedef struct
{
  size_t reflen;   // total reference length (after pad/eos trimming)
  size_t predlen;  // total prediction length (after pad/eos trimming)
  size_t match1;   // matched 1-grams
  size_t count1;   // total 1-grams in predictions
  size_t match2;
  size_t count2;
  size_t match3;
  size_t count3;
  size_t match4;
  size_t count4;
} bleu_stat;
27
+
28
// left trim: advance past leading pad tokens, shrinking *len accordingly
void bleu_ltrim(size_t* len, int** sent, int pad) {
  size_t skip = 0;
  while (skip < *len && (*sent)[skip] == pad) {
    skip++;
  }
  *sent += skip;
  *len -= skip;
}

// right trim: drop trailing eos/pad tokens (always keeps at least one token)
void bleu_rtrim(size_t* len, int** sent, int pad, int eos) {
  size_t last = *len - 1;
  while (last > 0 && ((*sent)[last] == eos || (*sent)[last] == pad)) {
    last--;
  }
  *len = last + 1;
}

// left and right trim
void bleu_trim(size_t* len, int** sent, int pad, int eos) {
  bleu_ltrim(len, sent, pad);
  bleu_rtrim(len, sent, pad, eos);
}
54
+
55
// FNV-1a hash over the raw bytes of `len` ints; used to key n-grams.
size_t bleu_hash(int len, int* data) {
  const size_t prime = 0x100000001b3;
  size_t h = 14695981039346656037ul;
  const char* bytes = (const char*) data;
  const size_t nbytes = sizeof(int) * len;

  for (size_t i = 0; i < nbytes; i++) {
    h ^= bytes[i];
    h *= prime;
  }

  return h;
}
68
+
69
// Accumulate n-gram statistics for one sentence pair: *ntotal gains the
// number of n-grams in the prediction, *nmatch the number of those matched
// against the reference.  Matching is clipped: each predicted n-gram
// occurrence can be consumed at most once.
void bleu_addngram(
    size_t *ntotal, size_t *nmatch, size_t n,
    size_t reflen, int* ref, size_t predlen, int* pred) {

  // prediction shorter than n: contributes no n-grams at all
  if (predlen < n) { return; }

  predlen = predlen - n + 1;  // number of n-grams in the prediction
  (*ntotal) += predlen;

  if (reflen < n) { return; }

  reflen = reflen - n + 1;  // number of n-grams in the reference

  // histogram of predicted n-grams keyed by hash
  std::map<size_t, size_t> count;
  while (predlen > 0) {
    size_t w = bleu_hash(n, pred++);
    count[w]++;
    predlen--;
  }

  // consume one predicted occurrence per matching reference n-gram
  // (operator[] inserts a zero entry for unseen hashes, which is harmless)
  while (reflen > 0) {
    size_t w = bleu_hash(n, ref++);
    if (count[w] > 0) {
      (*nmatch)++;
      count[w] -=1;
    }
    reflen--;
  }
}
98
+
99
// C-linkage entry points exported from the shared library.
extern "C" {

#ifdef _WIN64
__declspec(dllexport)
#endif
// Reset every field of *stat to zero.
void bleu_zero_init(bleu_stat* stat) {
  std::memset(stat, 0, sizeof(bleu_stat));
}

#ifdef _WIN64
__declspec(dllexport)
#endif
// Initialize *stat with counts/matches of 1 for 2-, 3- and 4-grams while the
// unigram statistics stay 0 (presumably add-one smoothing so higher-order
// precisions are never 0/0 -- confirm against the Python caller).
void bleu_one_init(bleu_stat* stat) {
  bleu_zero_init(stat);
  stat->count1 = 0;
  stat->count2 = 1;
  stat->count3 = 1;
  stat->count4 = 1;
  stat->match1 = 0;
  stat->match2 = 1;
  stat->match3 = 1;
  stat->match4 = 1;
}

#ifdef _WIN64
__declspec(dllexport)
#endif
// Accumulate one sentence pair into *stat: trim pad/eos from both sequences,
// add the trimmed lengths, then update the 1..4-gram totals and matches.
void bleu_add(
    bleu_stat* stat,
    size_t reflen, int* ref, size_t predlen, int* pred, int pad, int eos) {

  bleu_trim(&reflen, &ref, pad, eos);
  bleu_trim(&predlen, &pred, pad, eos);
  stat->reflen += reflen;
  stat->predlen += predlen;

  bleu_addngram(&stat->count1, &stat->match1, 1, reflen, ref, predlen, pred);
  bleu_addngram(&stat->count2, &stat->match2, 2, reflen, ref, predlen, pred);
  bleu_addngram(&stat->count3, &stat->match3, 3, reflen, ref, predlen, pred);
  bleu_addngram(&stat->count4, &stat->match4, 4, reflen, ref, predlen, pred);
}

}
fairseq/clib/libbleu/module.cpp ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /**
2
+ * Copyright 2017-present, Facebook, Inc.
3
+ * All rights reserved.
4
+ *
5
+ * This source code is licensed under the license found in the
6
+ * LICENSE file in the root directory of this source tree.
7
+ */
8
+
9
+ #include <Python.h>
10
+
11
+
12
// Method table for the libbleu extension module.  It contains only the NULL
// sentinel entry: the module exports no Python-callable functions.
static PyMethodDef method_def[] = {
    {NULL, NULL, 0, NULL}
};

static struct PyModuleDef module_def = {
    PyModuleDef_HEAD_INIT,
    "libbleu",   /* name of module */
    NULL,        /* module documentation, may be NULL */
    -1,          /* size of per-interpreter state of the module,
                    or -1 if the module keeps state in global variables. */
    method_def
};


// Module entry point invoked by the interpreter on `import libbleu`.
// NOTE(review): the Python 2 branch would also run PyModule_Create, which
// exists only in Python 3 -- the Py2 spelling is presumably vestigial.
#if PY_MAJOR_VERSION == 2
PyMODINIT_FUNC init_libbleu()
#else
PyMODINIT_FUNC PyInit_libbleu()
#endif
{
  PyObject *m = PyModule_Create(&module_def);
  if (!m) {
    return NULL;
  }
  return m;
}
fairseq/clib/libnat/edit_dist.cpp ADDED
@@ -0,0 +1,231 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /**
2
+ * Copyright 2017-present, Facebook, Inc.
3
+ * All rights reserved.
4
+ *
5
+ * This source code is licensed under the license found in the
6
+ * LICENSE file in the root directory of this source tree.
7
+ */
8
+
9
+ #include <torch/torch.h> // @manual=//caffe2:torch_extension
10
+ #include <pybind11/detail/common.h>
11
+ #include <pybind11/pybind11.h>
12
+ #include <vector>
13
+ #include <algorithm>
14
+ #include <cstdint>
15
+ #include <iosfwd>
16
+ #include <memory>
17
+ #include <new>
18
+ #include <string>
19
+ #include <utility>
20
+
21
+ using namespace ::std;
22
+
23
// Fill the full (|x|+1) x (|y|+1) dynamic-programming table for the edit
// distance between x and y, where insertion/deletion cost 1 each and a
// substitution costs 2.
vector<vector<uint32_t>> edit_distance2_with_dp(
    vector<uint32_t>& x,
    vector<uint32_t>& y) {
  const uint32_t nx = x.size();
  const uint32_t ny = y.size();
  vector<vector<uint32_t>> table(nx + 1, vector<uint32_t>(ny + 1));

  // base cases: transforming to/from an empty prefix
  for (uint32_t i = 0; i <= nx; i++) {
    table[i][0] = i;
  }
  for (uint32_t j = 0; j <= ny; j++) {
    table[0][j] = j;
  }

  for (uint32_t i = 1; i <= nx; i++) {
    for (uint32_t j = 1; j <= ny; j++) {
      const uint32_t subst_cost = (x.at(i - 1) == y.at(j - 1)) ? 0 : 2;
      const uint32_t best_indel = min(table[i - 1][j], table[i][j - 1]) + 1;
      table[i][j] = min(best_indel, table[i - 1][j - 1] + subst_cost);
    }
  }
  return table;
}
44
+
45
// Back-trace the DP table `d` into edit sequences.  Returns x.size()+2
// vectors: slots 0..x.size() hold the tokens to insert before each source
// position, and the last slot holds the keep(0)/delete(1) sequence for the
// source tokens.  Empty slots are padded with `terminal_symbol`.
vector<vector<uint32_t>> edit_distance2_backtracking(
    vector<vector<uint32_t>>& d,
    vector<uint32_t>& x,
    vector<uint32_t>& y,
    uint32_t terminal_symbol) {
  vector<uint32_t> seq;
  vector<vector<uint32_t>> edit_seqs(x.size() + 2, vector<uint32_t>());
  /*
  edit_seqs:
  0~x.size() cell is the insertion sequences
  last cell is the delete sequence
  */

  if (x.size() == 0) {
    edit_seqs.at(0) = y;
    return edit_seqs;
  }

  uint32_t i = d.size() - 1;
  uint32_t j = d.at(0).size() - 1;

  // Walk back from the bottom-right corner, recording (op, token) pairs in
  // reverse order.  FIX: the original condition `(i >= 0) && (j >= 0)` is
  // always true for unsigned indices (the loop relied on an inner break);
  // `i > 0 || j > 0` is the equivalent, warning-free termination test.
  while (i > 0 || j > 0) {
    if ((j > 0) && (d.at(i).at(j - 1) < d.at(i).at(j))) {
      seq.push_back(1); // insert y[j-1]
      seq.push_back(y.at(j - 1));
      j--;
    } else if ((i > 0) && (d.at(i - 1).at(j) < d.at(i).at(j))) {
      seq.push_back(2); // delete x[i-1]
      seq.push_back(x.at(i - 1));
      i--;
    } else {
      seq.push_back(3); // keep x[i-1]
      seq.push_back(x.at(i - 1));
      i--;
      j--;
    }
  }

  // Replay the operations forward, grouping consecutive insertions into the
  // slot preceding the next consumed source position.
  uint32_t prev_op, op, s, word;
  prev_op = 0, s = 0;
  for (uint32_t k = 0; k < seq.size() / 2; k++) {
    op = seq.at(seq.size() - 2 * k - 2);
    word = seq.at(seq.size() - 2 * k - 1);
    if (prev_op != 1) {
      s++;
    }
    if (op == 1) // insert
    {
      edit_seqs.at(s - 1).push_back(word);
    } else if (op == 2) // delete
    {
      edit_seqs.at(x.size() + 1).push_back(1);
    } else {
      edit_seqs.at(x.size() + 1).push_back(0);
    }

    prev_op = op;
  }

  // pad every empty slot with the terminal symbol
  for (uint32_t k = 0; k < edit_seqs.size(); k++) {
    if (edit_seqs[k].size() == 0) {
      edit_seqs[k].push_back(terminal_symbol);
    }
  }
  return edit_seqs;
}
115
+
116
// Variant of edit_distance2_backtracking that encodes deletions inline:
// deleted source tokens are written as `deletion_symbol` into the same
// per-position slots as insertions, so only x.size()+1 slots are returned.
// Empty slots are padded with `terminal_symbol`.
vector<vector<uint32_t>> edit_distance2_backtracking_with_delete(
    vector<vector<uint32_t>>& d,
    vector<uint32_t>& x,
    vector<uint32_t>& y,
    uint32_t terminal_symbol,
    uint32_t deletion_symbol) {
  vector<uint32_t> seq;
  vector<vector<uint32_t>> edit_seqs(x.size() + 1, vector<uint32_t>());
  /*
  edit_seqs:
  0~x.size() cell is the insertion sequences
  last cell is the delete sequence
  */

  if (x.size() == 0) {
    edit_seqs.at(0) = y;
    return edit_seqs;
  }

  uint32_t i = d.size() - 1;
  uint32_t j = d.at(0).size() - 1;

  // FIX: the original condition `(i >= 0) && (j >= 0)` is always true for
  // unsigned indices (the loop relied on an inner break); `i > 0 || j > 0`
  // is the equivalent, warning-free termination test.
  while (i > 0 || j > 0) {
    if ((j > 0) && (d.at(i).at(j - 1) < d.at(i).at(j))) {
      seq.push_back(1); // insert y[j-1]
      seq.push_back(y.at(j - 1));
      j--;
    } else if ((i > 0) && (d.at(i - 1).at(j) < d.at(i).at(j))) {
      seq.push_back(2); // delete x[i-1]
      seq.push_back(x.at(i - 1));
      i--;
    } else {
      seq.push_back(3); // keep x[i-1]
      seq.push_back(x.at(i - 1));
      i--;
      j--;
    }
  }

  // Replay the operations forward; insertions and deletions share the slot
  // preceding the next consumed source position.
  uint32_t prev_op, op, s, word;
  prev_op = 0, s = 0;
  for (uint32_t k = 0; k < seq.size() / 2; k++) {
    op = seq.at(seq.size() - 2 * k - 2);
    word = seq.at(seq.size() - 2 * k - 1);
    if (prev_op != 1) {
      s++;
    }
    if (op == 1) // insert
    {
      edit_seqs.at(s - 1).push_back(word);
    } else if (op == 2) // delete
    {
      edit_seqs.at(s - 1).push_back(deletion_symbol);
    }

    prev_op = op;
  }

  // pad every empty slot with the terminal symbol
  for (uint32_t k = 0; k < edit_seqs.size(); k++) {
    if (edit_seqs.at(k).size() == 0) {
      edit_seqs.at(k).push_back(terminal_symbol);
    }
  }
  return edit_seqs;
}
185
+
186
+ vector<uint32_t> compute_ed2(
187
+ vector<vector<uint32_t>>& xs,
188
+ vector<vector<uint32_t>>& ys) {
189
+ vector<uint32_t> distances(xs.size());
190
+ for (uint32_t i = 0; i < xs.size(); i++) {
191
+ vector<vector<uint32_t>> d = edit_distance2_with_dp(xs.at(i), ys.at(i));
192
+ distances.at(i) = d.at(xs.at(i).size()).at(ys.at(i).size());
193
+ }
194
+ return distances;
195
+ }
196
+
197
+ vector<vector<vector<uint32_t>>> suggested_ed2_path(
198
+ vector<vector<uint32_t>>& xs,
199
+ vector<vector<uint32_t>>& ys,
200
+ uint32_t terminal_symbol) {
201
+ vector<vector<vector<uint32_t>>> seq(xs.size());
202
+ for (uint32_t i = 0; i < xs.size(); i++) {
203
+ vector<vector<uint32_t>> d = edit_distance2_with_dp(xs.at(i), ys.at(i));
204
+ seq.at(i) =
205
+ edit_distance2_backtracking(d, xs.at(i), ys.at(i), terminal_symbol);
206
+ }
207
+ return seq;
208
+ }
209
+
210
+ vector<vector<vector<uint32_t>>> suggested_ed2_path_with_delete(
211
+ vector<vector<uint32_t>>& xs,
212
+ vector<vector<uint32_t>>& ys,
213
+ uint32_t terminal_symbol,
214
+ uint32_t deletion_symbol) {
215
+ vector<vector<vector<uint32_t>>> seq(xs.size());
216
+ for (uint32_t i = 0; i < xs.size(); i++) {
217
+ vector<vector<uint32_t>> d = edit_distance2_with_dp(xs.at(i), ys.at(i));
218
+ seq.at(i) = edit_distance2_backtracking_with_delete(
219
+ d, xs.at(i), ys.at(i), terminal_symbol, deletion_symbol);
220
+ }
221
+ return seq;
222
+ }
223
+
224
// Python bindings for the CPU edit-distance helpers above.
PYBIND11_MODULE(libnat, m) {
  m.def("compute_ed2", &compute_ed2, "compute_ed2");
  m.def("suggested_ed2_path", &suggested_ed2_path, "suggested_ed2_path");
  m.def(
      "suggested_ed2_path_with_delete",
      &suggested_ed2_path_with_delete,
      "suggested_ed2_path_with_delete");
}
fairseq/clib/libnat_cuda/binding.cpp ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /**
2
+ * Copyright 2017-present, Facebook, Inc.
3
+ * All rights reserved.
4
+ *
5
+ * This source code is licensed under the license found in the
6
+ * LICENSE file in the root directory of this source tree.
7
+ */
8
+
9
+ /*
10
+ This code is partially adpoted from https://github.com/1ytic/pytorch-edit-distance
11
+ */
12
+
13
+ #include "edit_dist.h"
14
+ #include <torch/types.h>
15
+
16
+ #ifndef TORCH_CHECK
17
+ #define TORCH_CHECK AT_CHECK
18
+ #endif
19
+
20
+ #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
21
+ #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
22
+ #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
23
+
24
+
25
// Validate inputs on the C++ side (CUDA + contiguous via CHECK_INPUT), then
// dispatch to the CUDA implementations declared in edit_dist.h.

// Batched Levenshtein distance with back-traced edit operations.
torch::Tensor LevenshteinDistance(
    torch::Tensor source,
    torch::Tensor target,
    torch::Tensor source_length,
    torch::Tensor target_length) {

  CHECK_INPUT(source);
  CHECK_INPUT(target);
  CHECK_INPUT(source_length);
  CHECK_INPUT(target_length);
  return LevenshteinDistanceCuda(source, target, source_length, target_length);
}

// Per-token deletion labels derived from an edit-operation tensor.
torch::Tensor GenerateDeletionLabel(
    torch::Tensor source,
    torch::Tensor operations) {

  CHECK_INPUT(source);
  CHECK_INPUT(operations);
  return GenerateDeletionLabelCuda(source, operations);
}

// Per-position insertion labels and masks derived from an edit-operation
// tensor; returns the (labels, masks) pair produced by the CUDA kernel.
std::pair<torch::Tensor, torch::Tensor> GenerateInsertionLabel(
    torch::Tensor target,
    torch::Tensor operations) {

  CHECK_INPUT(target);
  CHECK_INPUT(operations);
  return GenerateInsertionLabelCuda(target, operations);
}

// Python bindings for the CUDA Levenshtein utilities.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("levenshtein_distance", &LevenshteinDistance, "Levenshtein distance");
  m.def("generate_deletion_labels", &GenerateDeletionLabel, "Generate Deletion Label");
  m.def("generate_insertion_labels", &GenerateInsertionLabel, "Generate Insertion Label");
}
fairseq/clib/libnat_cuda/edit_dist.cu ADDED
@@ -0,0 +1,332 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /**
2
+ * Copyright 2017-present, Facebook, Inc.
3
+ * All rights reserved.
4
+ *
5
+ * This source code is licensed under the license found in the
6
+ * LICENSE file in the root directory of this source tree.
7
+ */
8
+
9
+ #include "edit_dist.h"
10
+ #include <THC/THC.h>
11
+ #include <cuda.h>
12
+ #include <cuda_runtime.h>
13
+ #include <device_launch_parameters.h>
14
+ #include <utility> // std::pair
15
+
16
// One CUDA block per batch element (blockIdx.x).  Walks that element's
// edit-operation row and writes a deletion label per source token:
// op codes are 0 = padding/end, 1 = insertion, 2 = deletion, 3 = keep
// (see levenshtein_distance_kernel's back-trace).  Insertions consume no
// source token; keep/delete map to label 0/1 respectively via `3 - op`.
template <typename scalar_t>
__global__ void generate_deletion_label_kernel(
    const scalar_t* __restrict__ source,
    const size_t source_size,
    const size_t operation_size,
    int* __restrict__ operations,
    int* __restrict__ labels) {

  const int index = blockIdx.x;
  const int offset = index * operation_size;        // start of this row's ops
  const int offset_label = index * source_size;     // start of this row's labels

  // clear the label row for this sequence
  for (int i = 0; i < source_size; i++) {
    labels[offset_label + i] = 0;
  }

  int k = 0;  // next source position to label
  for (int i = 0; i < operation_size; i++){
    if (operations[offset + i] == 0){
      break;      // 0 terminates the operation list
    } else if (operations[offset + i] == 1){
      continue;   // insertion: no source token consumed
    } else {
      labels[offset_label + k] = 3 - operations[offset + i];  // keep->0, delete->1
      k++;
    }
  }
}
44
+
45
// One CUDA block per batch element (blockIdx.x).  From the edit-operation
// row (0 = padding/end, 1 = insertion, 2 = deletion, 3 = keep) derive, per
// kept target token, how many insertions immediately precede it (`labels`),
// and a per-target-position mask marking inserted positions (`masks`).
// NOTE(review): the loop stops at operation_size-1, skipping the final
// operation -- presumably the trailing eos slot; confirm against the caller.
template <typename scalar_t>
__global__ void generate_insertion_label_kernel(
    const scalar_t* __restrict__ target,
    const size_t target_size,
    const size_t operation_size,
    int* __restrict__ operations,
    int* __restrict__ labels,
    int* __restrict__ masks) {

  const int index = blockIdx.x;
  const int offset = index * operation_size;        // start of this row's ops
  const int offset_label = index * target_size;     // start of this row's outputs

  int k = 0;  // next label slot (kept target tokens)
  int u = 0;  // insertions accumulated since the last kept token
  int m = 0;  // next mask slot (all target positions)

  for (int i = 0; i < target_size; i++) {
    labels[offset_label + i] = 0;
    masks[offset_label + i] = 0;
  }

  for (int i = 0; i < operation_size-1; i++){
    if (operations[offset + i] == 0){
      break;      // 0 terminates the operation list
    } else if (operations[offset + i] == 2){
      continue;   // deletion: no target token consumed
    } else if (operations[offset + i] == 1){
      // insertion: mark the position and count it toward the next kept token
      masks[offset_label + m] = 1;
      u++; m++;
    } else {
      // keep: record the pending insertion count, then reset it
      labels[offset_label + k] = u;
      masks[offset_label + m] = 0;
      k++; m++;
      u = 0;
    }
  }
}
83
+
84
// One CUDA block per batch element (blockIdx.x).  Computes the full edit-
// distance DP table for that element's (source, target) pair -- substitution
// cost 2, insertion/deletion cost 1, matching the CPU implementation -- then
// back-traces it into an operation sequence (1 = insertion, 2 = deletion,
// 3 = keep, 0 = padding) written left-aligned into `operations`.
template <typename scalar_t>
__global__ void levenshtein_distance_kernel(
    const scalar_t* __restrict__ source,
    const scalar_t* __restrict__ target,
    const int* __restrict__ source_length,
    const int* __restrict__ target_length,
    const size_t source_size,
    const size_t target_size,
    int* __restrict__ operations,
    int* __restrict__ errors_curr) {

  const int index = blockIdx.x;
  const int offset = index * (source_size + target_size);      // this row's ops
  const int d = index * (source_size + 1) * (target_size + 1); // this row's DP table
  const int t = target_size + 1;

  // flatten (i, j) into the per-element DP table / operation row
  auto err_idx = [d, t](int i, int j) { return d + i * t + j; };
  auto opt_idx = [offset](int k) { return offset + k; };

  const int hyp_len = source_length[index];
  const int ref_len = target_length[index];
  const scalar_t* hyp_begin = source + index * source_size;
  const scalar_t* ref_begin = target + index * target_size;

  // dynamic programming
  for (int i = 0; i <= hyp_len; i++){
    errors_curr[err_idx(i, 0)] = i;
  }
  for (int j = 0; j <= ref_len; j++){
    errors_curr[err_idx(0, j)] = j;
  }
  for (int i = 1; i <= hyp_len; i++){
    for (int j = 1; j <= ref_len; j++){
      errors_curr[err_idx(i, j)] = min(
          min(
              errors_curr[err_idx(i-1, j)],
              errors_curr[err_idx(i, j-1)]
          ) + 1,
          errors_curr[err_idx(i-1, j-1)] + 2 * (
              *(hyp_begin+i-1) == *(ref_begin+j-1) ? 0 : 1
          )
      );
    }
  }

  // back-tracing: fill operations right-to-left from the table's corner
  int i = hyp_len;
  int j = ref_len;
  int o = hyp_len + ref_len;

  for (int k = 0; k < source_size + target_size; k++) {
    operations[opt_idx(k)] = 0;
  }

  while ((i >= 0) && (j >= 0)) {
    if ((i == 0) && (j == 0)) {
      break;
    }

    if ((j > 0) && (errors_curr[err_idx(i, j-1)] < errors_curr[err_idx(i, j)])) {
      o--; operations[opt_idx(o)] = 1; j--; // insertion
    } else if ((i > 0) && (errors_curr[err_idx(i-1, j)] < errors_curr[err_idx(i, j)])) {
      o--; operations[opt_idx(o)] = 2; i--; // deletion
    } else {
      o--; operations[opt_idx(o)] = 3; i--; j--; // do nothing
    }
  }

  // moving to the left: shift the ops to the row start, pad the tail with 0
  for (int k = 0; k < hyp_len + ref_len; k++) {
    if (k + o < hyp_len + ref_len){
      operations[opt_idx(k)] = operations[opt_idx(k+o)];
    } else{
      operations[opt_idx(k)] = 0; // padding
    }
  }

}
162
+
163
template <typename scalar_t>
__global__ void faster_levenshtein_distance_kernel(
    const scalar_t* __restrict__ source,
    const scalar_t* __restrict__ target,
    const int* __restrict__ source_length,
    const int* __restrict__ target_length,
    const size_t source_size,
    const size_t target_size,
    int* __restrict__ operations) {
  // Same algorithm as levenshtein_distance_kernel, but the DP table lives
  // in dynamically-sized shared memory (size supplied at kernel launch)
  // instead of a global-memory scratch tensor. `short` entries hold values
  // up to hyp_len + ref_len — presumably sufficient for the sequence
  // lengths this is launched with (host caps the table at 40000 bytes).
  // Op codes: 0 = padding, 1 = insertion, 2 = deletion, 3 = keep/substitute.

  extern __shared__ short errors[];
  auto errors_curr = errors;

  const int index = blockIdx.x;
  const int offset = index * (source_size + target_size);  // this row's start in `operations`
  const int t = target_size + 1;  // DP-table row stride

  // Flattened 2-D indexing; no per-example base needed — the table is
  // private to this block's shared memory.
  auto err_idx = [t](int i, int j) { return i * t + j; };
  auto opt_idx = [offset](int k) { return offset + k; };

  const int hyp_len = source_length[index];
  const int ref_len = target_length[index];
  const scalar_t* hyp_begin = source + index * source_size;
  const scalar_t* ref_begin = target + index * target_size;

  // dynamic programming
  // Base cases: transforming to/from the empty prefix costs its length.
  for (int i = 0; i <= hyp_len; i++){
    errors_curr[err_idx(i, 0)] = i;
  }
  for (int j = 0; j <= ref_len; j++){
    errors_curr[err_idx(0, j)] = j;
  }
  for (int i = 1; i <= hyp_len; i++){
    for (int j = 1; j <= ref_len; j++){
      // min(deletion, insertion) + 1, vs. diagonal step
      // (cost 0 on a match, 2 on a substitution).
      errors_curr[err_idx(i, j)] = min(
        min(
          errors_curr[err_idx(i-1, j)],
          errors_curr[err_idx(i, j-1)]
        ) + 1,
        errors_curr[err_idx(i-1, j-1)] + 2 * (
          *(hyp_begin+i-1) == *(ref_begin+j-1) ? 0 : 1
        )
      );
    }
  }

  // back-tracing
  // Walk from (hyp_len, ref_len) back to (0, 0), writing operations from
  // the right end of the buffer (index `o`) so they come out in order.
  int i = hyp_len;
  int j = ref_len;
  int o = hyp_len + ref_len;

  for (int k = 0; k < source_size + target_size; k++) {
    operations[opt_idx(k)] = 0;
  }

  while ((i >= 0) && (j >= 0)) {
    if ((i == 0) && (j == 0)) {
      break;
    }

    if ((j > 0) && (errors_curr[err_idx(i, j-1)] < errors_curr[err_idx(i, j)])) {
      o--; operations[opt_idx(o)] = 1; j--; // insertion
    } else if ((i > 0) && (errors_curr[err_idx(i-1, j)] < errors_curr[err_idx(i, j)])) {
      o--; operations[opt_idx(o)] = 2; i--; // deletion
    } else {
      o--; operations[opt_idx(o)] = 3; i--; j--; // do nothing
    }
  }

  // moving to the left
  // Shift the trailing operations to the front of the row and pad the
  // remainder with 0.
  for (int k = 0; k < hyp_len + ref_len; k++) {
    if (k + o < hyp_len + ref_len){
      operations[opt_idx(k)] = operations[opt_idx(k+o)];
    } else{
      operations[opt_idx(k)] = 0; // padding
    }
  }

}
242
+
243
+
244
torch::Tensor GenerateDeletionLabelCuda(
    torch::Tensor source,
    torch::Tensor operations) {
  // Derive per-token deletion labels on the GPU from the edit-operation
  // sequences produced by LevenshteinDistanceCuda.
  //
  // Args:
  //   source:     (batch, src_len) token tensor; only its sizes, device and
  //               dtype are used for dispatch here.
  //   operations: (batch, op_len) int tensor of edit-operation codes.
  // Returns:
  //   (batch, src_len) int tensor of labels on source's device.
  const auto batch_size = source.size(0);
  at::TensorOptions options(source.device());
  options = options.dtype(at::ScalarType::Int);
  auto labels = torch::empty({batch_size, source.size(1)}, options);
  // Launch on the current stream of source's device: one block per example.
  auto stream = at::cuda::getCurrentCUDAStream(source.device().index());

  AT_DISPATCH_ALL_TYPES(source.scalar_type(), "generate_deletion_labels", ([&] {
    // data_ptr<T>() replaces the deprecated Tensor::data<T>() accessor.
    generate_deletion_label_kernel<scalar_t><<<batch_size, 1, 0, stream>>>(
        source.data_ptr<scalar_t>(),
        source.size(1),
        operations.size(1),
        operations.data_ptr<int>(),
        labels.data_ptr<int>());
  }));

  return labels;
}
265
+
266
std::pair<torch::Tensor, torch::Tensor> GenerateInsertionLabelCuda(
    torch::Tensor target,
    torch::Tensor operations) {
  // Derive per-position insertion counts and masks on the GPU from the
  // edit-operation sequences produced by LevenshteinDistanceCuda.
  //
  // Args:
  //   target:     (batch, tgt_len) token tensor; only its sizes, device and
  //               dtype are used for dispatch here.
  //   operations: (batch, op_len) int tensor of edit-operation codes.
  // Returns:
  //   pair of (batch, tgt_len) int tensors (labels, masks) on target's device.
  const auto batch_size = target.size(0);
  at::TensorOptions options(target.device());
  options = options.dtype(at::ScalarType::Int);
  auto labels = torch::empty({batch_size, target.size(1)}, options);
  auto masks = torch::empty({batch_size, target.size(1)}, options);
  // Launch on the current stream of target's device: one block per example.
  auto stream = at::cuda::getCurrentCUDAStream(target.device().index());

  AT_DISPATCH_ALL_TYPES(target.scalar_type(), "generate_insertion_labels", ([&] {
    // data_ptr<T>() replaces the deprecated Tensor::data<T>() accessor.
    generate_insertion_label_kernel<scalar_t><<<batch_size, 1, 0, stream>>>(
        target.data_ptr<scalar_t>(),
        target.size(1),
        operations.size(1),
        operations.data_ptr<int>(),
        labels.data_ptr<int>(),
        masks.data_ptr<int>());
  }));

  return std::make_pair(labels, masks);
}
289
+
290
+
291
torch::Tensor LevenshteinDistanceCuda(
    torch::Tensor source,
    torch::Tensor target,
    torch::Tensor source_length,
    torch::Tensor target_length) {
  // Batched Levenshtein alignment on the GPU.
  //
  // Args:
  //   source, target:               (batch, len) token tensors.
  //   source_length, target_length: (batch,) int tensors of true lengths.
  // Returns:
  //   (batch, src_len + tgt_len) int tensor of edit-operation codes
  //   (0 = padding, 1 = insertion, 2 = deletion, 3 = keep/substitute).
  const auto batch_size = source.size(0);
  // Bytes of shared memory the faster kernel needs for its DP table of
  // `short` entries.
  const auto shared_size =
      (source.size(1) + 1) * (target.size(1) + 1) * sizeof(short);

  at::TensorOptions options(source.device());
  options = options.dtype(at::ScalarType::Int);
  auto operations =
      torch::empty({batch_size, source.size(1) + target.size(1)}, options);
  auto stream = at::cuda::getCurrentCUDAStream(source.device().index());

  if (shared_size > 40000) {
    // DP table too large for shared memory — NOTE(review): 40000 is
    // presumably headroom under the common 48 KB per-block shared-memory
    // limit; the threshold is device-dependent. Fall back to a
    // global-memory scratch tensor.
    auto distances = torch::empty(
        {batch_size, (source.size(1) + 1) * (target.size(1) + 1)}, options);
    AT_DISPATCH_ALL_TYPES(source.scalar_type(), "levenshtein_distance", ([&] {
      // data_ptr<T>() replaces the deprecated Tensor::data<T>() accessor.
      levenshtein_distance_kernel<scalar_t><<<batch_size, 1, 0, stream>>>(
          source.data_ptr<scalar_t>(),
          target.data_ptr<scalar_t>(),
          source_length.data_ptr<int>(),
          target_length.data_ptr<int>(),
          source.size(1),
          target.size(1),
          operations.data_ptr<int>(),
          distances.data_ptr<int>());
    }));
  } else {
    AT_DISPATCH_ALL_TYPES(source.scalar_type(), "faster_levenshtein_distance", ([&] {
      faster_levenshtein_distance_kernel<scalar_t><<<batch_size, 1, shared_size, stream>>>(
          source.data_ptr<scalar_t>(),
          target.data_ptr<scalar_t>(),
          source_length.data_ptr<int>(),
          target_length.data_ptr<int>(),
          source.size(1),
          target.size(1),
          operations.data_ptr<int>());
    }));
  }

  return operations;
}
fairseq/clib/libnat_cuda/edit_dist.h ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /**
2
+ * Copyright 2017-present, Facebook, Inc.
3
+ * All rights reserved.
4
+ *
5
+ * This source code is licensed under the license found in the
6
+ * LICENSE file in the root directory of this source tree.
7
+ */
8
+
9
+ #pragma once
10
+
11
+ #include <torch/extension.h>
12
+
13
+ torch::Tensor LevenshteinDistanceCuda(
14
+ torch::Tensor source,
15
+ torch::Tensor target,
16
+ torch::Tensor source_length,
17
+ torch::Tensor target_length);
18
+
19
+ torch::Tensor GenerateDeletionLabelCuda(
20
+ torch::Tensor source,
21
+ torch::Tensor operations);
22
+
23
+ std::pair<torch::Tensor, torch::Tensor> GenerateInsertionLabelCuda(
24
+ torch::Tensor source,
25
+ torch::Tensor operations);
fairseq/criterions/__init__.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import importlib
import os

from fairseq import registry
from fairseq.criterions.fairseq_criterion import FairseqCriterion, LegacyFairseqCriterion


# Set up the criterion registry: '--criterion' selects the implementation,
# falling back to 'cross_entropy' when unspecified.
build_criterion, register_criterion, CRITERION_REGISTRY = registry.setup_registry(
    '--criterion',
    base_class=FairseqCriterion,
    default='cross_entropy',
)


# Automatically import every Python module in the criterions/ directory so
# that their @register_criterion decorators run and populate the registry.
for _fname in os.listdir(os.path.dirname(__file__)):
    if not _fname.endswith('.py') or _fname.startswith('_'):
        continue
    _module = _fname[:_fname.find('.py')]
    importlib.import_module('fairseq.criterions.' + _module)
fairseq/criterions/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (697 Bytes). View file
 
fairseq/criterions/__pycache__/adaptive_loss.cpython-310.pyc ADDED
Binary file (4.1 kB). View file
 
fairseq/criterions/__pycache__/composite_loss.cpython-310.pyc ADDED
Binary file (4.36 kB). View file
 
fairseq/criterions/__pycache__/cross_entropy.cpython-310.pyc ADDED
Binary file (3.4 kB). View file
 
fairseq/criterions/__pycache__/ctc.cpython-310.pyc ADDED
Binary file (6.99 kB). View file
 
fairseq/criterions/__pycache__/fairseq_criterion.cpython-310.pyc ADDED
Binary file (4.19 kB). View file
 
fairseq/criterions/__pycache__/label_smoothed_cross_entropy.cpython-310.pyc ADDED
Binary file (4.39 kB). View file
 
fairseq/criterions/__pycache__/label_smoothed_cross_entropy_with_alignment.cpython-310.pyc ADDED
Binary file (4.55 kB). View file
 
fairseq/criterions/__pycache__/legacy_masked_lm.cpython-310.pyc ADDED
Binary file (5.36 kB). View file
 
fairseq/criterions/__pycache__/masked_lm.cpython-310.pyc ADDED
Binary file (3.11 kB). View file
 
fairseq/criterions/__pycache__/nat_loss.cpython-310.pyc ADDED
Binary file (5.92 kB). View file
 
fairseq/criterions/__pycache__/sentence_prediction.cpython-310.pyc ADDED
Binary file (3.85 kB). View file
 
fairseq/criterions/__pycache__/sentence_ranking.cpython-310.pyc ADDED
Binary file (4.5 kB). View file